diff --git a/src/google/adk/auth/auth_handler.py b/src/google/adk/auth/auth_handler.py index ec7c75716c..65a0e03b83 100644 --- a/src/google/adk/auth/auth_handler.py +++ b/src/google/adk/auth/auth_handler.py @@ -23,6 +23,7 @@ from .auth_schemes import OpenIdConnectWithConfig from .auth_tool import AuthConfig from .exchanger.oauth2_credential_exchanger import OAuth2CredentialExchanger +from .oauth2_credential_util import _normalize_oauth_scopes if TYPE_CHECKING: from ..sessions.state import State @@ -161,7 +162,7 @@ def generate_auth_uri( if isinstance(auth_scheme, OpenIdConnectWithConfig): authorization_endpoint = auth_scheme.authorization_endpoint - scopes = auth_scheme.scopes + scopes = _normalize_oauth_scopes(auth_scheme.scopes) else: authorization_endpoint = ( auth_scheme.flows.implicit @@ -173,17 +174,20 @@ def generate_auth_uri( or auth_scheme.flows.password and auth_scheme.flows.password.tokenUrl ) - scopes = ( - auth_scheme.flows.implicit - and auth_scheme.flows.implicit.scopes - or auth_scheme.flows.authorizationCode - and auth_scheme.flows.authorizationCode.scopes - or auth_scheme.flows.clientCredentials - and auth_scheme.flows.clientCredentials.scopes - or auth_scheme.flows.password - and auth_scheme.flows.password.scopes - ) - scopes = list(scopes.keys()) + if auth_scheme.flows.implicit: + scopes = _normalize_oauth_scopes(auth_scheme.flows.implicit.scopes) + elif auth_scheme.flows.authorizationCode: + scopes = _normalize_oauth_scopes( + auth_scheme.flows.authorizationCode.scopes + ) + elif auth_scheme.flows.clientCredentials: + scopes = _normalize_oauth_scopes( + auth_scheme.flows.clientCredentials.scopes + ) + elif auth_scheme.flows.password: + scopes = _normalize_oauth_scopes(auth_scheme.flows.password.scopes) + else: + scopes = [] client = OAuth2Session( auth_credential.oauth2.client_id, diff --git a/src/google/adk/auth/oauth2_credential_util.py b/src/google/adk/auth/oauth2_credential_util.py index df2f26c002..7a10a596d0 100644 --- a/src/google/adk/auth/oauth2_credential_util.py +++ b/src/google/adk/auth/oauth2_credential_util.py @@ -30,6 +30,15 @@ logger = logging.getLogger("google_adk." + __name__) +def _normalize_oauth_scopes(scopes: Optional[dict[str, str] | list[str]]) -> list[str]: + """Normalize OAuth scopes into the list shape expected by authlib.""" + if not scopes: + return [] + if isinstance(scopes, dict): + return list(scopes.keys()) + return list(scopes) + + @experimental def create_oauth2_session( auth_scheme: AuthScheme, @@ -49,7 +58,7 @@ def create_oauth2_session( logger.warning("OpenIdConnect scheme missing token_endpoint") return None, None token_endpoint = auth_scheme.token_endpoint - scopes = auth_scheme.scopes or [] + scopes = _normalize_oauth_scopes(auth_scheme.scopes) elif isinstance(auth_scheme, OAuth2): # Support both authorization code and client credentials flows if ( @@ -57,13 +66,17 @@ def create_oauth2_session( and auth_scheme.flows.authorizationCode.tokenUrl ): token_endpoint = auth_scheme.flows.authorizationCode.tokenUrl - scopes = list(auth_scheme.flows.authorizationCode.scopes.keys()) + scopes = _normalize_oauth_scopes( + auth_scheme.flows.authorizationCode.scopes + ) elif ( auth_scheme.flows.clientCredentials and auth_scheme.flows.clientCredentials.tokenUrl ): token_endpoint = auth_scheme.flows.clientCredentials.tokenUrl - scopes = list(auth_scheme.flows.clientCredentials.scopes.keys()) + scopes = _normalize_oauth_scopes( + auth_scheme.flows.clientCredentials.scopes + ) else: logger.warning( "OAuth2 scheme missing required flow configuration. Expected either" diff --git a/tests/unittests/auth/test_auth_handler.py b/tests/unittests/auth/test_auth_handler.py index 2faeeb158e..70b76258c0 100644 --- a/tests/unittests/auth/test_auth_handler.py +++ b/tests/unittests/auth/test_auth_handler.py @@ -22,6 +22,7 @@ from fastapi.openapi.models import APIKeyIn from fastapi.openapi.models import OAuth2 from fastapi.openapi.models import OAuthFlowAuthorizationCode +from fastapi.openapi.models import OAuthFlowClientCredentials from fastapi.openapi.models import OAuthFlows from google.adk.auth.auth_credential import AuthCredential from google.adk.auth.auth_credential import AuthCredentialTypes @@ -271,6 +272,35 @@ def test_generate_auth_uri_openid( assert "client_id=mock_client_id" in result.oauth2.auth_uri assert result.oauth2.state == "mock_state" + @patch("google.adk.auth.auth_handler.OAuth2Session", MockOAuth2Session) + def test_generate_auth_uri_client_credentials_with_missing_scopes( + self, oauth2_credentials + ): + """Test client credentials flow tolerates missing scopes.""" + auth_scheme = OAuth2( + flows=OAuthFlows( + clientCredentials=OAuthFlowClientCredentials( + tokenUrl="https://example.com/oauth2/token" + ) + ) + ) + auth_scheme.flows.clientCredentials.scopes = None + + config = AuthConfig( + auth_scheme=auth_scheme, + raw_auth_credential=oauth2_credentials, + exchanged_auth_credential=oauth2_credentials.model_copy(deep=True), + ) + + handler = AuthHandler(config) + result = handler.generate_auth_uri() + + assert ( + result.oauth2.auth_uri + == "https://example.com/oauth2/token?client_id=mock_client_id&scope=" + ) + assert result.oauth2.state == "mock_state" + class TestGenerateAuthRequest: """Tests for the generate_auth_request method.""" diff --git a/tests/unittests/auth/test_oauth2_credential_util.py b/tests/unittests/auth/test_oauth2_credential_util.py index b9d4da6711..e609377795 100644 --- a/tests/unittests/auth/test_oauth2_credential_util.py +++ b/tests/unittests/auth/test_oauth2_credential_util.py @@ -19,6 +19,7 @@ from authlib.oauth2.rfc6749 import OAuth2Token from fastapi.openapi.models import OAuth2 from fastapi.openapi.models import OAuthFlowAuthorizationCode +from fastapi.openapi.models import OAuthFlowClientCredentials from fastapi.openapi.models import OAuthFlows from google.adk.auth.auth_credential import AuthCredential from google.adk.auth.auth_credential import AuthCredentialTypes @@ -207,6 +208,30 @@ def test_create_oauth2_session_oauth2_scheme_with_token_endpoint_auth_method( assert token_endpoint == "https://example.com/token" assert client.token_endpoint_auth_method == "client_secret_jwt" + def test_create_oauth2_session_client_credentials_with_missing_scopes(self): + """Test client credentials flow tolerates missing scopes.""" + flows = OAuthFlows( + clientCredentials=OAuthFlowClientCredentials( + tokenUrl="https://example.com/token" + ) + ) + flows.clientCredentials.scopes = None + scheme = OAuth2(type_="oauth2", flows=flows) + credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id="test_client_id", + client_secret="test_client_secret", + redirect_uri="https://example.com/callback", + ), + ) + + client, token_endpoint = create_oauth2_session(scheme, credential) + + assert client is not None + assert token_endpoint == "https://example.com/token" + assert client.scope == "" + def test_update_credential_with_tokens(self): """Test update_credential_with_tokens function.""" credential = AuthCredential(