Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 16 additions & 12 deletions src/google/adk/auth/auth_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from .auth_schemes import OpenIdConnectWithConfig
from .auth_tool import AuthConfig
from .exchanger.oauth2_credential_exchanger import OAuth2CredentialExchanger
from .oauth2_credential_util import _normalize_oauth_scopes

if TYPE_CHECKING:
from ..sessions.state import State
Expand Down Expand Up @@ -161,7 +162,7 @@ def generate_auth_uri(

if isinstance(auth_scheme, OpenIdConnectWithConfig):
authorization_endpoint = auth_scheme.authorization_endpoint
scopes = auth_scheme.scopes
scopes = _normalize_oauth_scopes(auth_scheme.scopes)
else:
authorization_endpoint = (
auth_scheme.flows.implicit
Expand All @@ -173,17 +174,20 @@ def generate_auth_uri(
or auth_scheme.flows.password
and auth_scheme.flows.password.tokenUrl
)
scopes = (
auth_scheme.flows.implicit
and auth_scheme.flows.implicit.scopes
or auth_scheme.flows.authorizationCode
and auth_scheme.flows.authorizationCode.scopes
or auth_scheme.flows.clientCredentials
and auth_scheme.flows.clientCredentials.scopes
or auth_scheme.flows.password
and auth_scheme.flows.password.scopes
)
scopes = list(scopes.keys())
if auth_scheme.flows.implicit:
scopes = _normalize_oauth_scopes(auth_scheme.flows.implicit.scopes)
elif auth_scheme.flows.authorizationCode:
scopes = _normalize_oauth_scopes(
auth_scheme.flows.authorizationCode.scopes
)
elif auth_scheme.flows.clientCredentials:
scopes = _normalize_oauth_scopes(
auth_scheme.flows.clientCredentials.scopes
)
elif auth_scheme.flows.password:
scopes = _normalize_oauth_scopes(auth_scheme.flows.password.scopes)
else:
scopes = []

client = OAuth2Session(
auth_credential.oauth2.client_id,
Expand Down
19 changes: 16 additions & 3 deletions src/google/adk/auth/oauth2_credential_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,15 @@
logger = logging.getLogger("google_adk." + __name__)


def _normalize_oauth_scopes(scopes: Optional[dict[str, str] | list[str]]) -> list[str]:
"""Normalize OAuth scopes into the list shape expected by authlib."""
if not scopes:
return []
if isinstance(scopes, dict):
return list(scopes.keys())
return list(scopes)


@experimental
def create_oauth2_session(
auth_scheme: AuthScheme,
Expand All @@ -49,21 +58,25 @@ def create_oauth2_session(
logger.warning("OpenIdConnect scheme missing token_endpoint")
return None, None
token_endpoint = auth_scheme.token_endpoint
scopes = auth_scheme.scopes or []
scopes = _normalize_oauth_scopes(auth_scheme.scopes)
elif isinstance(auth_scheme, OAuth2):
# Support both authorization code and client credentials flows
if (
auth_scheme.flows.authorizationCode
and auth_scheme.flows.authorizationCode.tokenUrl
):
token_endpoint = auth_scheme.flows.authorizationCode.tokenUrl
scopes = list(auth_scheme.flows.authorizationCode.scopes.keys())
scopes = _normalize_oauth_scopes(
auth_scheme.flows.authorizationCode.scopes
)
elif (
auth_scheme.flows.clientCredentials
and auth_scheme.flows.clientCredentials.tokenUrl
):
token_endpoint = auth_scheme.flows.clientCredentials.tokenUrl
scopes = list(auth_scheme.flows.clientCredentials.scopes.keys())
scopes = _normalize_oauth_scopes(
auth_scheme.flows.clientCredentials.scopes
)
else:
logger.warning(
"OAuth2 scheme missing required flow configuration. Expected either"
Expand Down
30 changes: 30 additions & 0 deletions tests/unittests/auth/test_auth_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from fastapi.openapi.models import APIKeyIn
from fastapi.openapi.models import OAuth2
from fastapi.openapi.models import OAuthFlowAuthorizationCode
from fastapi.openapi.models import OAuthFlowClientCredentials
from fastapi.openapi.models import OAuthFlows
from google.adk.auth.auth_credential import AuthCredential
from google.adk.auth.auth_credential import AuthCredentialTypes
Expand Down Expand Up @@ -271,6 +272,35 @@ def test_generate_auth_uri_openid(
assert "client_id=mock_client_id" in result.oauth2.auth_uri
assert result.oauth2.state == "mock_state"

@patch("google.adk.auth.auth_handler.OAuth2Session", MockOAuth2Session)
def test_generate_auth_uri_client_credentials_with_missing_scopes(
self, oauth2_credentials
):
"""Test client credentials flow tolerates missing scopes."""
auth_scheme = OAuth2(
flows=OAuthFlows(
clientCredentials=OAuthFlowClientCredentials(
tokenUrl="https://example.com/oauth2/token"
)
)
)
auth_scheme.flows.clientCredentials.scopes = None

config = AuthConfig(
auth_scheme=auth_scheme,
raw_auth_credential=oauth2_credentials,
exchanged_auth_credential=oauth2_credentials.model_copy(deep=True),
)

handler = AuthHandler(config)
result = handler.generate_auth_uri()

assert (
result.oauth2.auth_uri
== "https://example.com/oauth2/token?client_id=mock_client_id&scope="
)
assert result.oauth2.state == "mock_state"


class TestGenerateAuthRequest:
"""Tests for the generate_auth_request method."""
Expand Down
25 changes: 25 additions & 0 deletions tests/unittests/auth/test_oauth2_credential_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from authlib.oauth2.rfc6749 import OAuth2Token
from fastapi.openapi.models import OAuth2
from fastapi.openapi.models import OAuthFlowAuthorizationCode
from fastapi.openapi.models import OAuthFlowClientCredentials
from fastapi.openapi.models import OAuthFlows
from google.adk.auth.auth_credential import AuthCredential
from google.adk.auth.auth_credential import AuthCredentialTypes
Expand Down Expand Up @@ -207,6 +208,30 @@ def test_create_oauth2_session_oauth2_scheme_with_token_endpoint_auth_method(
assert token_endpoint == "https://example.com/token"
assert client.token_endpoint_auth_method == "client_secret_jwt"

def test_create_oauth2_session_client_credentials_with_missing_scopes(self):
"""Test client credentials flow tolerates missing scopes."""
flows = OAuthFlows(
clientCredentials=OAuthFlowClientCredentials(
tokenUrl="https://example.com/token"
)
)
flows.clientCredentials.scopes = None
scheme = OAuth2(type_="oauth2", flows=flows)
credential = AuthCredential(
auth_type=AuthCredentialTypes.OAUTH2,
oauth2=OAuth2Auth(
client_id="test_client_id",
client_secret="test_client_secret",
redirect_uri="https://example.com/callback",
),
)

client, token_endpoint = create_oauth2_session(scheme, credential)

assert client is not None
assert token_endpoint == "https://example.com/token"
assert client.scope == ""

def test_update_credential_with_tokens(self):
"""Test update_credential_with_tokens function."""
credential = AuthCredential(
Expand Down