diff --git a/.gitignore b/.gitignore index de6ab7555..25feb0a75 100644 --- a/.gitignore +++ b/.gitignore @@ -10,3 +10,6 @@ upload.sh /tests/ .env sensitive_info_result.txt +.env.e2e +.env_e2e +.env.e2e.example diff --git a/README.md b/README.md index 7f77777db..468311753 100644 --- a/README.md +++ b/README.md @@ -56,6 +56,36 @@ it as a link for the user to open in Feishu/Lark. ```python import lark_oapi as lark +## ClientAssertion Keyless Mode + +For self-built apps that use an external signing service, the SDK can fetch +tenant tokens with `client_assertion` instead of `app_secret`. The SDK does not +generate, parse, sign, or store JWT private keys; your provider supplies the +final assertion string. + +```python +import os + +import lark_oapi as lark +from lark_oapi.core.client_assertion import ClientAssertionToken + + +class EnvClientAssertionProvider: + def retrieve_token(self, aud: str) -> ClientAssertionToken: + return ClientAssertionToken(os.environ["LARK_CLIENT_ASSERTION"]) + + +client = lark.Client.builder() \ + .app_id(os.environ["LARK_APP_ID"]) \ + .client_assertion_provider(EnvClientAssertionProvider()) \ + .build() +``` + +If you use a custom OpenAPI domain, also configure `oauth_base_url(...)` so the +SDK can derive the OAuth audience correctly. Keyless mode is for self-built +apps only and does not support AppAccessToken-only APIs. + +## Channel Module def on_qr_code(info): print(info["url"]) diff --git a/README.zh.md b/README.zh.md index 9f658de35..284c6d653 100644 --- a/README.zh.md +++ b/README.zh.md @@ -43,6 +43,35 @@ request = CreateMessageRequest.builder() \ response = client.im.v1.message.create(request) ``` +## ClientAssertion 无密钥模式 + +自建应用如果通过外部签发服务提供 `client_assertion`,SDK 可以在不配置 +`app_secret` 的情况下换取 tenant token。SDK 不生成、不解析、不签名 JWT, +也不保存私钥;provider 只需要返回最终的 assertion 字符串。 + +```python +import os + +import lark_oapi as lark +from lark_oapi.core.client_assertion import ClientAssertionToken + + +class EnvClientAssertionProvider: + def retrieve_token(self, aud: str) -> ClientAssertionToken: + return ClientAssertionToken(os.environ["LARK_CLIENT_ASSERTION"]) + + +client = lark.Client.builder() \ + .app_id(os.environ["LARK_APP_ID"]) \ + .client_assertion_provider(EnvClientAssertionProvider()) \ + .build() +``` + +如果使用自定义 OpenAPI 域名,需要同时配置 `oauth_base_url(...)`,以便 SDK +正确生成 OAuth audience。无密钥模式仅支持自建应用,不支持只依赖 +AppAccessToken 的 API。 + +## Channel 模块 ## 一键创建应用 `lark_oapi.register_app` 基于 OAuth device flow 创建应用。SDK 会在 diff --git a/lark_oapi/client.py b/lark_oapi/client.py index 98ae7ddff..16e46385d 100644 --- a/lark_oapi/client.py +++ b/lark_oapi/client.py @@ -8,6 +8,7 @@ from .core import logger, JSON from .core.model import * from .core.token import TokenManager, verify +from .core.access_token import AccessToken from .core.http import Transport from .api.auth.service import AuthService from .api.event.service import EventService @@ -155,6 +156,7 @@ def __init__(self) -> None: self.performance: Optional[PerformanceService] = None self.security_and_compliance: Optional[SecurityAndComplianceService] = None self.speech_to_text: Optional[SpeechToTextService] = None + self.access_token: Optional[AccessToken] = None @staticmethod def builder() -> "ClientBuilder": @@ -229,6 +231,14 @@ def app_secret(self, app_secret: str) -> "ClientBuilder": self._config.app_secret = app_secret return self + def client_assertion_provider(self, provider) -> "ClientBuilder": + self._config.client_assertion_provider = provider + return self + + def oauth_base_url(self, oauth_base_url: str) -> "ClientBuilder": + self._config.oauth_base_url = oauth_base_url + return self + def domain(self, domain: str) -> "ClientBuilder": _validate_domain(domain) self._config.domain = domain @@ -331,6 +341,7 @@ def build(self) -> Client: client.performance = PerformanceService(self._config) client.security_and_compliance = SecurityAndComplianceService(self._config) client.speech_to_text = SpeechToTextService(self._config) + client.access_token = AccessToken(self._config) return client diff --git a/lark_oapi/core/__init__.py b/lark_oapi/core/__init__.py index 2da1ab6d5..8dde57e8c 100644 --- a/lark_oapi/core/__init__.py +++ b/lark_oapi/core/__init__.py @@ -1,4 +1,5 @@ from .cache import ICache +from .client_assertion import * from .const import * from .enum import * from .env_var import * diff --git a/lark_oapi/core/access_token/__init__.py b/lark_oapi/core/access_token/__init__.py new file mode 100644 index 000000000..1c249e9ff --- /dev/null +++ b/lark_oapi/core/access_token/__init__.py @@ -0,0 +1,2 @@ +from .client import * +from .model import * diff --git a/lark_oapi/core/access_token/client.py b/lark_oapi/core/access_token/client.py new file mode 100644 index 000000000..44c579dd4 --- /dev/null +++ b/lark_oapi/core/access_token/client.py @@ -0,0 +1,116 @@ +import json +from typing import Dict, Optional + +from lark_oapi.core.client_assertion import build_proxy_url, resolve_oauth_aud, resolve_oauth_base_url +from lark_oapi.core.const import ( + APPLICATION_JSON, + CLIENT_ASSERTION_TYPE_JWT_BEARER, + CONTENT_TYPE, + ERR_CODE_APP_SECRET_AND_CLIENT_ASSERTION_EMPTY, + ERR_CODE_CLIENT_ASSERTION_RETRIEVE_FAILED, + ERR_CODE_CLIENT_ASSERTION_TOKEN_EMPTY, + GRANT_TYPE_AUTHORIZATION_CODE, + GRANT_TYPE_REFRESH_TOKEN, + OAUTH_TOKEN_URI, + UTF_8, + X_TARGET_SERVICE, +) +from lark_oapi.core.enum import HttpMethod +from lark_oapi.core.exception import AccessTokenException, ClientAssertionException +from lark_oapi.core.http import Transport +from lark_oapi.core.model import BaseRequest, Config, RequestOption +from lark_oapi.core.utils import Strings +from .model import AccessTokenResponse, value_if_not_empty + + +class AccessToken(object): + def __init__(self, config: Config) -> None: + self._config = config + + def retrieve_by_authorization_code( + self, + code: str, + redirect_uri: Optional[str] = None, + code_verifier: Optional[str] = None, + scope: Optional[str] = None, + headers: Optional[Dict[str, str]] = None, + ) -> AccessTokenResponse: + return self._do_request( + { + "grant_type": GRANT_TYPE_AUTHORIZATION_CODE, + "code": code, + "redirect_uri": redirect_uri, + "code_verifier": code_verifier, + "scope": scope, + }, + headers=headers, + ) + + def refresh( + self, + refresh_token: str, + scope: Optional[str] = None, + headers: Optional[Dict[str, str]] = None, + ) -> AccessTokenResponse: + return self._do_request( + { + "grant_type": GRANT_TYPE_REFRESH_TOKEN, + "refresh_token": refresh_token, + "scope": scope, + }, + headers=headers, + ) + + def _do_request(self, body: Dict[str, object], headers: Optional[Dict[str, str]] = None) -> AccessTokenResponse: + oauth_base_url = resolve_oauth_base_url(self._config) + aud = resolve_oauth_aud(self._config) + request_url = oauth_base_url + OAUTH_TOKEN_URI + body = {k: v for k, v in body.items() if v is not None} + body["client_id"] = self._config.app_id + option = RequestOption() + if headers: + option.headers.update(headers) + + if self._config.client_assertion_provider is not None: + try: + assertion_token = self._config.client_assertion_provider.retrieve_token(aud) + except Exception as e: + raise ClientAssertionException(ERR_CODE_CLIENT_ASSERTION_RETRIEVE_FAILED, str(e)) + if assertion_token is None or Strings.is_empty(assertion_token.value): + raise ClientAssertionException(ERR_CODE_CLIENT_ASSERTION_TOKEN_EMPTY, "client assertion token is empty") + body["client_assertion_type"] = CLIENT_ASSERTION_TYPE_JWT_BEARER + body["client_assertion"] = assertion_token.value + if assertion_token.target_info is not None: + request_url = build_proxy_url(assertion_token.target_info, OAUTH_TOKEN_URI) + option.headers[X_TARGET_SERVICE] = aud + elif Strings.is_not_empty(self._config.app_secret): + body["client_secret"] = self._config.app_secret + else: + raise ClientAssertionException( + ERR_CODE_APP_SECRET_AND_CLIENT_ASSERTION_EMPTY, + "AppSecret and ClientAssertionProvider cannot both be empty for AccessToken APIs", + ) + + req = BaseRequest() + req.http_method = HttpMethod.POST + req.uri = request_url + req.headers = {CONTENT_TYPE: APPLICATION_JSON} + req.body = body + raw = Transport.execute(self._config, req, option) + resp = json.loads(str(raw.content, UTF_8)) + if raw.status_code != 200: + raise AccessTokenException( + raw.status_code, + resp.get("code") or 0, + resp.get("error") or "", + resp.get("error_description") or "", + ) + return AccessTokenResponse( + access_token=value_if_not_empty(resp.get("access_token")), + token_type=value_if_not_empty(resp.get("token_type")), + expires_in=value_if_not_empty(resp.get("expires_in")), + refresh_token=value_if_not_empty(resp.get("refresh_token")), + refresh_token_expires_in=value_if_not_empty(resp.get("refresh_token_expires_in")), + scope=value_if_not_empty(resp.get("scope")), + raw=raw, + ) diff --git a/lark_oapi/core/access_token/model.py b/lark_oapi/core/access_token/model.py new file mode 100644 index 000000000..73aac9d29 --- /dev/null +++ b/lark_oapi/core/access_token/model.py @@ -0,0 +1,19 @@ +from dataclasses import dataclass +from typing import Optional + +from lark_oapi.core.model import RawResponse + + +@dataclass +class AccessTokenResponse: + access_token: Optional[str] = None + token_type: Optional[str] = None + expires_in: Optional[int] = None + refresh_token: Optional[str] = None + refresh_token_expires_in: Optional[int] = None + scope: Optional[str] = None + raw: Optional[RawResponse] = None + + +def value_if_not_empty(value): + return value if value not in ("", 0, None) else None diff --git a/lark_oapi/core/client_assertion.py b/lark_oapi/core/client_assertion.py new file mode 100644 index 000000000..9867408d5 --- /dev/null +++ b/lark_oapi/core/client_assertion.py @@ -0,0 +1,66 @@ +from dataclasses import dataclass +from typing import Optional, Protocol +from urllib.parse import urlparse + +from lark_oapi.core.const import FEISHU_DOMAIN, FEISHU_OAUTH_DOMAIN, LARK_DOMAIN, LARK_OAUTH_DOMAIN + + +@dataclass +class TargetInfo: + target_service: str + target_prefix: str = "" + + +@dataclass +class ClientAssertionToken: + value: str + target_info: Optional[TargetInfo] = None + + +class ClientAssertionProvider(Protocol): + def retrieve_token(self, aud: str) -> ClientAssertionToken: + raise NotImplementedError + + +def _normalize_base_url(base_url: str) -> str: + if "://" not in base_url: + base_url = "https://" + base_url + return base_url.rstrip("/") + + +def extract_aud_from_url(raw_url: str) -> str: + if "://" not in raw_url: + raw_url = "https://" + raw_url + parsed = urlparse(raw_url) + if parsed.netloc: + return parsed.netloc + if parsed.path and "/" not in parsed.path: + return parsed.path + raise ValueError(f"invalid url: {raw_url}") + + +def resolve_oauth_base_url(config) -> str: + oauth_base_url = getattr(config, "oauth_base_url", None) + if oauth_base_url: + return _normalize_base_url(oauth_base_url) + + aud = extract_aud_from_url(config.domain) + if aud == extract_aud_from_url(FEISHU_DOMAIN): + return FEISHU_OAUTH_DOMAIN + if aud == extract_aud_from_url(LARK_DOMAIN): + return LARK_OAUTH_DOMAIN + raise ValueError( + "OAuthBaseUrl is not configured. When domain is set to a non-default value " + "(neither open.feishu.cn nor open.larksuite.com), configure oauth_base_url explicitly." + ) + + +def resolve_oauth_aud(config) -> str: + return extract_aud_from_url(resolve_oauth_base_url(config)) + + +def build_proxy_url(target_info: TargetInfo, api_path: str) -> str: + target_service = target_info.target_service + if "://" not in target_service: + target_service = "https://" + target_service + return target_service + target_info.target_prefix + api_path diff --git a/lark_oapi/core/const.py b/lark_oapi/core/const.py index b8a082872..5f609053b 100644 --- a/lark_oapi/core/const.py +++ b/lark_oapi/core/const.py @@ -5,10 +5,13 @@ # Domain FEISHU_DOMAIN = "https://open.feishu.cn" LARK_DOMAIN = "https://open.larksuite.com" +FEISHU_OAUTH_DOMAIN = "https://accounts.feishu.cn" +LARK_OAUTH_DOMAIN = "https://accounts.larksuite.com" # Header USER_AGENT = "User-Agent" AUTHORIZATION = "Authorization" +X_TARGET_SERVICE = "X-Target-Service" X_TT_LOGID = "X-Tt-Logid" X_REQUEST_ID = "X-Request-Id" CONTENT_TYPE = "Content-Type" @@ -24,3 +27,17 @@ URL_VERIFICATION = "url_verification" UTF_8 = "UTF-8" + +# OAuth / ClientAssertion +OAUTH_TOKEN_URI = "/oauth/v3/token" +GRANT_TYPE_AUTHORIZATION_CODE = "authorization_code" +GRANT_TYPE_REFRESH_TOKEN = "refresh_token" +GRANT_TYPE_JWT_BEARER = "urn:ietf:params:oauth:grant-type:jwt-bearer" +CLIENT_ASSERTION_TYPE_JWT_BEARER = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" + +# ClientAssertion error codes, aligned with oapi-sdk-go. +ERR_CODE_CLIENT_ASSERTION_PROVIDER_NOT_CONFIGURED = 7100 +ERR_CODE_CLIENT_ASSERTION_TOKEN_EMPTY = 7101 +ERR_CODE_CLIENT_ASSERTION_RETRIEVE_FAILED = 7102 +ERR_CODE_CLIENT_ASSERTION_MODE_NOT_SUPPORTED = 7103 +ERR_CODE_APP_SECRET_AND_CLIENT_ASSERTION_EMPTY = 7104 diff --git a/lark_oapi/core/exception.py b/lark_oapi/core/exception.py index a55aaafc2..db8db19ae 100644 --- a/lark_oapi/core/exception.py +++ b/lark_oapi/core/exception.py @@ -3,6 +3,29 @@ def __init__(self, msg: str): self.msg: str = msg +class ClientAssertionException(Exception): + def __init__(self, code: int, msg: str): + super().__init__(msg) + self.code = code + self.msg = msg + + def __str__(self): + return f"{self.code}: {self.msg}" + + +class AccessTokenException(Exception): + def __init__(self, status_code: int, code: int, error: str, error_description: str): + super().__init__(error_description or error or "access token request failed") + self.status_code = status_code + self.code = code + self.error = error + self.error_description = error_description + + def __str__(self): + msg = self.error_description or self.error or "access token request failed" + return f"statusCode:{self.status_code}, code:{self.code}, msg:{msg}" + + class ObtainAccessTokenException(Exception): def __init__(self, desc: str, code: int, msg: str): self.desc = desc diff --git a/lark_oapi/core/http/transport.py b/lark_oapi/core/http/transport.py index 271c1efdd..47fd4e3d3 100644 --- a/lark_oapi/core/http/transport.py +++ b/lark_oapi/core/http/transport.py @@ -1,5 +1,6 @@ import json import urllib.parse +from typing import Optional import httpx import requests @@ -38,10 +39,12 @@ def execute(conf: Config, req: BaseRequest, option: Optional[RequestOption] = No timeout=conf.timeout, ) - logger.debug(f"{str(req.http_method.name)} {url} {response.status_code}, " - f"headers: {JSON.marshal(headers)}, " - f"params: {JSON.marshal(req.queries)}, " - f"body: {str(data, UTF_8) if isinstance(data, bytes) else data}") + logger.debug( + f"{str(req.http_method.name)} request completed with status {response.status_code}, " + f"headers_count: {len(headers)}, " + f"params_count: {len(req.queries)}, " + f"body_present: {data is not None}" + ) resp = RawResponse() resp.status_code = response.status_code @@ -84,10 +87,10 @@ async def aexecute(conf: Config, req: BaseRequest, option: Optional[RequestOptio ) logger.debug( - f"{str(req.http_method.name)} {url} {response.status_code}" - f"{f', headers: {JSON.marshal(headers)}' if headers else ''}" - f"{f', params: {JSON.marshal(req.queries)}' if req.queries else ''}" - f"{f', body: {JSON.marshal(_merge_dicts(json_, files, data))}' if json_ or files or data else ''}" + f"{str(req.http_method.name)} request completed with status {response.status_code}, " + f"headers_count: {len(headers)}, " + f"params_count: {len(req.queries)}, " + f"body_present: {json_ is not None or files is not None or data is not None}" ) resp = RawResponse() @@ -110,6 +113,8 @@ def _build_url(domain: str, uri: str, paths: Dict[str, str]) -> str: encoded = urllib.parse.quote(str(value), safe="") uri = uri.replace(":" + key, encoded) + if uri.startswith("http://") or uri.startswith("https://"): + return uri return domain + uri @@ -136,11 +141,3 @@ def _build_header(request: BaseRequest, option: RequestOption, conf: Optional[Co headers[AUTHORIZATION] = f"Bearer {option.user_access_token}" return headers - - -def _merge_dicts(*dicts): - res = {} - for d in dicts: - if d is not None: - res.update(d) - return res diff --git a/lark_oapi/core/model/config.py b/lark_oapi/core/model/config.py index db5e3b3fc..911719c55 100644 --- a/lark_oapi/core/model/config.py +++ b/lark_oapi/core/model/config.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Any, List, Optional from lark_oapi.core import AppType, LogLevel from lark_oapi.core.cache import ICache @@ -10,8 +10,9 @@ def __init__(self) -> None: self.app_id: Optional[str] = None self.app_secret: Optional[str] = None self.domain: str = FEISHU_DOMAIN # 域名, 默认为 https://open.feishu.cn - self.timeout: Optional[ - float] = 30 # client timeout in seconds (default 30s); override via ClientBuilder.timeout() + self.oauth_base_url: Optional[str] = None + self.client_assertion_provider: Optional[Any] = None + self.timeout: Optional[float] = 30 # client timeout in seconds (default 30s); override via ClientBuilder.timeout() self.app_type: AppType = AppType.SELF # 应用类型, 默认为自建应用; 若设为 ISV 需在 request_option 中配置 tenant_key self.enable_set_token: bool = False # 是否允许手动设置 token, 默认不开启; 开启后需在 request_option 中配置 token self.cache: Optional[ICache] = None # 自定义缓存, 默认使用预置的本地缓存 diff --git a/lark_oapi/core/tests/__init__.py b/lark_oapi/core/tests/__init__.py new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/lark_oapi/core/tests/__init__.py @@ -0,0 +1 @@ + diff --git a/lark_oapi/core/tests/e2e/__init__.py b/lark_oapi/core/tests/e2e/__init__.py new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/lark_oapi/core/tests/e2e/__init__.py @@ -0,0 +1 @@ + diff --git a/lark_oapi/core/tests/e2e/client_assertion_live_harness.py b/lark_oapi/core/tests/e2e/client_assertion_live_harness.py new file mode 100644 index 000000000..f255ef2ac --- /dev/null +++ b/lark_oapi/core/tests/e2e/client_assertion_live_harness.py @@ -0,0 +1,182 @@ +import base64 +import json +import os +import shlex +import socket +from dataclasses import dataclass +from pathlib import Path +from typing import Callable, Dict, Mapping, MutableMapping, Optional +from urllib.parse import urlencode + +from lark_oapi.core.cache import ICache +from lark_oapi.core.client_assertion import ClientAssertionToken, TargetInfo + + +DEFAULT_ENV_FILES = (".env.e2e", ".env.e2e.example") +_ORIGINAL_GETADDRINFO = None + + +@dataclass(frozen=True) +class DeployDomains: + openapi_domain: str + oauth_base_url: str + + +ONLINE_DOMAINS = DeployDomains( + openapi_domain="https://open.feishu.cn", + oauth_base_url="https://accounts.feishu.cn", +) +BOE_DOMAINS = DeployDomains( + openapi_domain="https://open.feishu-boe.cn", + oauth_base_url="https://accounts.feishu-boe.cn", +) + + +class MemoryCache(ICache): + def __init__(self) -> None: + self._data: Dict[str, str] = {} + + def get(self, key: str) -> str: + return self._data.get(key) + + def set(self, key: str, value: str, expire: int): + self._data[key] = value + + +class ModeEnvProvider: + def __init__(self, mode: str, env: Optional[Mapping[str, str]] = None) -> None: + if mode not in ("zti", "gdpr"): + raise ValueError("mode must be zti or gdpr") + self.mode = mode + self.env = env if env is not None else os.environ + self.auds = [] + + def retrieve_token(self, aud: str) -> ClientAssertionToken: + self.auds.append(aud) + if self.mode == "zti": + assertion = self.env.get("LARK_ZTI_CLIENT_ASSERTION") + if not assertion: + raise RuntimeError("LARK_ZTI_CLIENT_ASSERTION is required") + return ClientAssertionToken(assertion) + + assertion = self.env.get("LARK_GDPR_CLIENT_ASSERTION") + if not assertion: + raise RuntimeError("LARK_GDPR_CLIENT_ASSERTION is required") + target_service = self.env.get("LARK_GDPR_PROXY_SERVICE") + target_prefix = self.env.get("LARK_GDPR_PROXY_PREFIX") + if not target_service or not target_prefix: + raise RuntimeError("LARK_GDPR_PROXY_SERVICE and LARK_GDPR_PROXY_PREFIX are required") + if not target_prefix.startswith("/"): + raise RuntimeError("LARK_GDPR_PROXY_PREFIX must start with /") + return ClientAssertionToken(assertion, TargetInfo(target_service, target_prefix)) + + +def parse_bool(value: Optional[str], default: bool = False) -> bool: + if value is None: + return default + normalized = value.strip().lower() + if normalized in ("1", "true", "yes", "y", "on"): + return True + if normalized in ("0", "false", "no", "n", "off"): + return False + return default + + +def deploy_domains(deploy_env: Optional[str]) -> DeployDomains: + normalized = (deploy_env or "online").strip().lower() + if normalized in ("online", "prod", "production", "cn"): + return ONLINE_DOMAINS + if normalized == "boe": + return BOE_DOMAINS + raise ValueError("LARK_DEPLOY_ENV must be online or boe") + + +def load_env_file( + path: Path, + env: Optional[MutableMapping[str, str]] = None, + override: bool = False, +) -> bool: + env = env if env is not None else os.environ + if not path.exists(): + return False + + for raw_line in path.read_text(encoding="utf-8").splitlines(): + line = raw_line.strip() + if not line or line.startswith("#"): + continue + if line.startswith("export "): + line = line[len("export "):].strip() + parts = shlex.split(line, comments=True, posix=True) + if not parts or "=" not in parts[0]: + continue + key, value = parts[0].split("=", 1) + if not override and key in env: + continue + env[key] = value + return True + + +def load_e2e_env( + root: Optional[Path] = None, + env: Optional[MutableMapping[str, str]] = None, + override: bool = False, +) -> Optional[Path]: + root = root or Path.cwd() + env = env if env is not None else os.environ + for name in DEFAULT_ENV_FILES: + path = root / name + if load_env_file(path, env=env, override=override): + return path + return None + + +def build_host_override_getaddrinfo( + original_getaddrinfo: Callable, + target_host: str, + target_ip: str, +) -> Callable: + family = socket.AF_INET6 if ":" in target_ip else socket.AF_INET + + def getaddrinfo(host, port, family_arg=0, type_arg=0, proto_arg=0, flags_arg=0): + if host == target_host: + return original_getaddrinfo(target_ip, port, family, type_arg, proto_arg, flags_arg) + return original_getaddrinfo(host, port, family_arg, type_arg, proto_arg, flags_arg) + + return getaddrinfo + + +def install_host_resolver_override(target_host: str, target_ip: Optional[str]) -> bool: + if not target_ip: + return False + + global _ORIGINAL_GETADDRINFO + if _ORIGINAL_GETADDRINFO is None: + _ORIGINAL_GETADDRINFO = socket.getaddrinfo + socket.getaddrinfo = build_host_override_getaddrinfo(_ORIGINAL_GETADDRINFO, target_host, target_ip) + return True + + +def build_authorize_url( + oauth_base_url: str, + app_id: str, + redirect_uri: str, + scope: str, + state: str, +) -> str: + query = urlencode({ + "app_id": app_id, + "redirect_uri": redirect_uri, + "scope": scope, + "state": state, + }) + return oauth_base_url.rstrip("/") + "/open-apis/authen/v1/authorize?" + query + + +def decode_jwt_payload_unverified(token: str) -> Dict[str, object]: + parts = token.split(".") + if len(parts) < 2: + raise ValueError("invalid jwt") + payload = parts[1] + padding = "=" * (-len(payload) % 4) + decoded = base64.urlsafe_b64decode((payload + padding).encode("ascii")) + return json.loads(decoded.decode("utf-8")) diff --git a/lark_oapi/core/tests/e2e/test_client_assertion_keyless_live.py b/lark_oapi/core/tests/e2e/test_client_assertion_keyless_live.py new file mode 100644 index 000000000..e4a6d96e3 --- /dev/null +++ b/lark_oapi/core/tests/e2e/test_client_assertion_keyless_live.py @@ -0,0 +1,368 @@ +import asyncio +import json +import os +import queue +import secrets +import threading +import time +import uuid +import webbrowser +from contextlib import contextmanager +from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer +from typing import Dict, Iterable +from urllib.parse import parse_qs, urlsplit + +import pytest +import websockets + +from lark_oapi import Client +from lark_oapi.api.contact.v3.model.basic_batch_user_request import BasicBatchUserRequest +from lark_oapi.api.contact.v3.model.basic_batch_user_request_body import BasicBatchUserRequestBody +from lark_oapi.api.im.v1.model.create_message_request import CreateMessageRequest +from lark_oapi.api.im.v1.model.create_message_request_body import CreateMessageRequestBody +from lark_oapi.core.client_assertion import extract_aud_from_url +from lark_oapi.core.model import RequestOption +from lark_oapi.core.token import TokenManager +from lark_oapi.core.tests.e2e.client_assertion_live_harness import ( + MemoryCache, + ModeEnvProvider, + build_authorize_url, + deploy_domains, + install_host_resolver_override, + load_e2e_env, + parse_bool, +) +from lark_oapi.ws import client as ws_client + + +load_e2e_env() +install_host_resolver_override(os.environ.get("LARK_GDPR_PROXY_SERVICE"), os.environ.get("LARK_GDPR_PROXY_RESOLVE_IP")) + +pytestmark = pytest.mark.skipif( + os.environ.get("LARK_CLIENT_ASSERTION_E2E") != "1", + reason="set LARK_CLIENT_ASSERTION_E2E=1 to run live ClientAssertion E2E", +) + + +def _require_env(names: Iterable[str]) -> Dict[str, str]: + missing = [name for name in names if not os.environ.get(name)] + if missing: + pytest.skip("missing env: " + ", ".join(missing)) + return {name: os.environ[name] for name in names} + + +def _domains(): + return deploy_domains(os.environ.get("LARK_DEPLOY_ENV")) + + +def _build_app_secret_client(): + env = _require_env(["LARK_APP_ID", "LARK_APP_SECRET"]) + domains = _domains() + return ( + Client.builder() + .app_id(env["LARK_APP_ID"]) + .app_secret(env["LARK_APP_SECRET"]) + .domain(domains.openapi_domain) + .oauth_base_url(domains.oauth_base_url) + .cache(MemoryCache()) + .build() + ) + + +def _build_client_assertion_client(mode: str, provider: ModeEnvProvider = None): + env = _require_env(["LARK_APP_ID"]) + if mode == "zti": + _require_env(["LARK_ZTI_CLIENT_ASSERTION"]) + elif mode == "gdpr": + _require_env([ + "LARK_GDPR_CLIENT_ASSERTION", + "LARK_GDPR_PROXY_SERVICE", + "LARK_GDPR_PROXY_PREFIX", + ]) + if not os.environ["LARK_GDPR_PROXY_PREFIX"].startswith("/"): + pytest.fail("LARK_GDPR_PROXY_PREFIX must start with /") + if (os.environ.get("LARK_DEPLOY_ENV") or "online").lower() not in ("online", "prod", "production", "cn"): + pytest.fail("GDPR proxy E2E must use LARK_DEPLOY_ENV=online") + else: + raise ValueError("mode must be zti or gdpr") + + provider = provider or ModeEnvProvider(mode) + domains = _domains() + client = ( + Client.builder() + .app_id(env["LARK_APP_ID"]) + .client_assertion_provider(provider) + .domain(domains.openapi_domain) + .oauth_base_url(domains.oauth_base_url) + .cache(MemoryCache()) + .build() + ) + return client, provider + + +def _build_client_with_app_secret_and_provider(mode: str): + env = _require_env(["LARK_APP_ID", "LARK_APP_SECRET"]) + provider = ModeEnvProvider(mode) + domains = _domains() + client = ( + Client.builder() + .app_id(env["LARK_APP_ID"]) + .app_secret(env["LARK_APP_SECRET"]) + .client_assertion_provider(provider) + .domain(domains.openapi_domain) + .oauth_base_url(domains.oauth_base_url) + .cache(MemoryCache()) + .build() + ) + return client, provider + + +def _response_summary(resp) -> str: + return "code={}, msg={}, log_id={}".format(resp.code, resp.msg, resp.get_log_id()) + + +def _case_id(prefix: str) -> str: + return "{}-{}".format(prefix, uuid.uuid4().hex[:8]) + + +def _send_tat_message(client: Client, case_id: str) -> str: + env = _require_env(["LARK_OPEN_ID"]) + body = ( + CreateMessageRequestBody.builder() + .receive_id(env["LARK_OPEN_ID"]) + .msg_type("text") + .content(json.dumps({"text": "ClientAssertion E2E TAT message: " + case_id})) + .uuid(str(uuid.uuid4())) + .build() + ) + req = CreateMessageRequest.builder().receive_id_type("open_id").request_body(body).build() + + resp = client.im.v1.message.create(req) + + assert resp.success(), _response_summary(resp) + assert resp.data is not None and resp.data.message_id + return resp.data.message_id + + +def _call_basic_batch_with_uat(client: Client, user_access_token: str): + env = _require_env(["LARK_OPEN_ID"]) + body = BasicBatchUserRequestBody.builder().user_ids([env["LARK_OPEN_ID"]]).build() + req = BasicBatchUserRequest.builder().user_id_type("open_id").request_body(body).build() + option = RequestOption.builder().user_access_token(user_access_token).build() + + resp = client.contact.v3.user.basic_batch(req, option) + + assert resp.success(), _response_summary(resp) + assert resp.data is not None and resp.data.users + assert resp.data.users[0].user_id + + +def _oauth_timeout_seconds() -> int: + return int(os.environ.get("LARK_OAUTH_TIMEOUT_SECONDS") or "180") + + +@contextmanager +def _oauth_callback_server(redirect_uri: str, expected_state: str): + parsed = urlsplit(redirect_uri) + if parsed.scheme != "http" or parsed.hostname not in ("127.0.0.1", "localhost"): + raise ValueError("LARK_OAUTH_REDIRECT_URI must be a local http callback") + if not parsed.port: + raise ValueError("LARK_OAUTH_REDIRECT_URI must include a port") + + result_queue = queue.Queue(maxsize=1) + + class Handler(BaseHTTPRequestHandler): + def do_GET(self): + callback = urlsplit(self.path) + if callback.path != parsed.path: + self.send_response(404) + self.end_headers() + return + + params = {key: values[0] for key, values in parse_qs(callback.query).items()} + if params.get("state") != expected_state: + self._send_text(400, "OAuth state mismatch.") + result_queue.put({"error": "state_mismatch"}) + return + + if "code" not in params: + self._send_text(400, "OAuth callback missing code.") + params.setdefault("error", "missing_code") + result_queue.put(params) + return + + self._send_text(200, "OAuth callback received. You can close this tab.") + result_queue.put(params) + + def _send_text(self, status: int, text: str): + data = text.encode("utf-8") + self.send_response(status) + self.send_header("Content-Type", "text/plain; charset=utf-8") + self.send_header("Content-Length", str(len(data))) + self.end_headers() + self.wfile.write(data) + + def log_message(self, fmt, *args): + return + + server = ThreadingHTTPServer((parsed.hostname, parsed.port), Handler) + thread = threading.Thread(target=server.serve_forever, daemon=True) + thread.start() + try: + yield result_queue + finally: + server.shutdown() + server.server_close() + thread.join(timeout=5) + + +def _authorize_and_exchange_code(client: Client): + env = _require_env([ + "LARK_APP_ID", + "LARK_OAUTH_REDIRECT_URI", + "LARK_OAUTH_SCOPE", + ]) + if parse_bool(os.environ.get("LARK_OAUTH_PKCE_REQUIRED"), default=False): + pytest.fail("OAuth UAT E2E expects LARK_OAUTH_PKCE_REQUIRED=false") + + domains = _domains() + state = secrets.token_urlsafe(24) + auth_url = build_authorize_url( + oauth_base_url=domains.oauth_base_url, + app_id=env["LARK_APP_ID"], + redirect_uri=env["LARK_OAUTH_REDIRECT_URI"], + scope=env["LARK_OAUTH_SCOPE"], + state=state, + ) + + with _oauth_callback_server(env["LARK_OAUTH_REDIRECT_URI"], state) as callback_queue: + print("\nOpen this URL to authorize OAuth UAT E2E:\n{}".format(auth_url)) + webbrowser.open(auth_url) + try: + params = callback_queue.get(timeout=_oauth_timeout_seconds()) + except queue.Empty: + pytest.fail("OAuth callback timed out") + + if params.get("error"): + pytest.fail("OAuth callback failed: " + params["error"]) + return client.access_token.retrieve_by_authorization_code( + code=params["code"], + redirect_uri=env["LARK_OAUTH_REDIRECT_URI"], + scope=env["LARK_OAUTH_SCOPE"], + ) + + +async def _connect_ws_once(conn_url: str, listen_seconds: int): + kwargs = ws_client._ws_connect_kwargs() + kwargs.setdefault("open_timeout", 10) + kwargs.setdefault("close_timeout", 5) + async with websockets.connect(conn_url, **kwargs): + await asyncio.sleep(listen_seconds) + + +def _ws_listen_seconds() -> int: + seconds = int(os.environ.get("LARK_WS_LISTEN_SECONDS") or "30") + assert seconds > 0 + return seconds + + +def _assert_provider_received_aud(provider: ModeEnvProvider, expected_aud: str): + assert expected_aud in provider.auds + + +@pytest.mark.slow +def test_live_provider_takes_precedence_over_app_secret(): + client, provider = _build_client_with_app_secret_and_provider("zti") + + token = TokenManager.get_self_tenant_token(client.config) + + assert token + _assert_provider_received_aud(provider, extract_aud_from_url(_domains().oauth_base_url)) + + +@pytest.mark.slow +def test_live_app_secret_tenant_token_and_tat_message(): + client = _build_app_secret_client() + + token = TokenManager.get_self_tenant_token(client.config) + message_id = _send_tat_message(client, _case_id("SECRET-TAT")) + + assert token + assert message_id + + +@pytest.mark.slow +@pytest.mark.parametrize("mode", ["zti", "gdpr"]) +def test_live_client_assertion_tenant_token_and_tat_message(mode): + client, provider = _build_client_assertion_client(mode) + + token = TokenManager.get_self_tenant_token(client.config) + message_id = _send_tat_message(client, _case_id(mode.upper() + "-TAT")) + + assert token + assert message_id + _assert_provider_received_aud(provider, extract_aud_from_url(_domains().oauth_base_url)) + + +@pytest.mark.slow +@pytest.mark.parametrize("mode", ["zti", "gdpr"]) +def test_live_client_assertion_oauth_uat_authorization_code_refresh_and_basic_batch(mode): + client, provider = _build_client_assertion_client(mode) + + token = _authorize_and_exchange_code(client) + if not token.access_token: + pytest.fail("authorization code exchange did not return access_token") + _call_basic_batch_with_uat(client, token.access_token) + + if not token.refresh_token: + pytest.fail("authorization code exchange did not return refresh_token; verify offline_access is enabled") + refreshed = client.access_token.refresh( + refresh_token=token.refresh_token, + scope=os.environ.get("LARK_OAUTH_SCOPE"), + ) + if not refreshed.access_token: + pytest.fail("refresh token exchange did not return access_token") + _call_basic_batch_with_uat(client, refreshed.access_token) + _assert_provider_received_aud(provider, extract_aud_from_url(_domains().oauth_base_url)) + + +@pytest.mark.slow +def test_live_app_secret_ws_real_connect(): + if not parse_bool(os.environ.get("LARK_WS_CONNECT_E2E"), default=False): + pytest.skip("set LARK_WS_CONNECT_E2E=true to run real WS connect") + env = _require_env(["LARK_APP_ID", "LARK_APP_SECRET"]) + domains = _domains() + client = ws_client.Client( + env["LARK_APP_ID"], + env["LARK_APP_SECRET"], + domain=domains.openapi_domain, + auto_reconnect=False, + ) + + conn_url = client._get_conn_url() + asyncio.run(_connect_ws_once(conn_url, _ws_listen_seconds())) + + +@pytest.mark.slow +@pytest.mark.parametrize("mode", ["zti", "gdpr"]) +def test_live_client_assertion_ws_real_connect(mode): + if not parse_bool(os.environ.get("LARK_WS_CONNECT_E2E"), default=False): + pytest.skip("set LARK_WS_CONNECT_E2E=true to run real WS connect") + env = _require_env(["LARK_APP_ID"]) + if mode == "zti": + _require_env(["LARK_ZTI_CLIENT_ASSERTION"]) + else: + _require_env(["LARK_GDPR_CLIENT_ASSERTION", "LARK_GDPR_PROXY_SERVICE", "LARK_GDPR_PROXY_PREFIX"]) + domains = _domains() + provider = ModeEnvProvider(mode) + client = ws_client.Client( + env["LARK_APP_ID"], + "", + domain=domains.openapi_domain, + auto_reconnect=False, + client_assertion_provider=provider, + ) + + conn_url = client._get_conn_url() + asyncio.run(_connect_ws_once(conn_url, _ws_listen_seconds())) + _assert_provider_received_aud(provider, extract_aud_from_url(domains.openapi_domain)) diff --git a/lark_oapi/core/tests/e2e/test_client_assertion_keyless_local.py b/lark_oapi/core/tests/e2e/test_client_assertion_keyless_local.py new file mode 100644 index 000000000..631a23830 --- /dev/null +++ b/lark_oapi/core/tests/e2e/test_client_assertion_keyless_local.py @@ -0,0 +1,145 @@ +import json +import threading +from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer + +from lark_oapi import Client +from lark_oapi.core import AccessTokenType, HttpMethod +from lark_oapi.core.cache import ICache +from lark_oapi.core.client_assertion import ClientAssertionToken +from lark_oapi.core.model import BaseRequest +from lark_oapi.ws import client as ws_client + + +class AbsoluteCache(ICache): + def __init__(self): + self.data = {} + + def get(self, key): + item = self.data.get(key) + return None if item is None else item[0] + + def set(self, key, value, expire): + self.data[key] = (value, expire) + + +class RecordingProvider: + def __init__(self): + self.auds = [] + + def retrieve_token(self, aud): + self.auds.append(aud) + return ClientAssertionToken("local-assertion") + + +class LocalState: + def __init__(self): + self.oauth_bodies = [] + self.ping_authorizations = [] + self.ws_bodies = [] + + +def _serve(state): + class Handler(BaseHTTPRequestHandler): + def do_POST(self): + length = int(self.headers.get("Content-Length") or "0") + body = json.loads(self.rfile.read(length) or b"{}") + if self.path == "/oauth/v3/token": + state.oauth_bodies.append(body) + if body["grant_type"] == "authorization_code": + self._json({ + "access_token": "user-token", + "token_type": "Bearer", + "expires_in": 7200, + "refresh_token": "refresh-token", + }) + return + if body["grant_type"] == "refresh_token": + self._json({"access_token": "refreshed-user-token", "expires_in": 7200}) + return + self._json({"access_token": "tenant-token", "expires_in": 7200}) + return + if self.path == "/callback/ws/endpoint": + state.ws_bodies.append(body) + self._json({"code": 0, "data": {"URL": "ws://example.test/callback?device_id=device&service_id=42"}}) + return + self.send_response(404) + self.end_headers() + + def do_GET(self): + if self.path == "/open-apis/mock/v1/ping": + state.ping_authorizations.append(self.headers.get("Authorization")) + self._json({"code": 0, "msg": "ok"}) + return + self.send_response(404) + self.end_headers() + + def _json(self, payload): + data = json.dumps(payload).encode() + self.send_response(200) + self.send_header("Content-Type", "application/json") + self.send_header("Content-Length", str(len(data))) + self.end_headers() + self.wfile.write(data) + + def log_message(self, fmt, *args): + return + + server = ThreadingHTTPServer(("127.0.0.1", 0), Handler) + thread = threading.Thread(target=server.serve_forever, daemon=True) + thread.start() + return server + + +def _ping_request(): + req = BaseRequest() + req.http_method = HttpMethod.GET + req.uri = "/open-apis/mock/v1/ping" + req.token_types = {AccessTokenType.TENANT} + return req + + +def test_local_keyless_openapi_access_token_and_ws_e2e(): + state = LocalState() + server = _serve(state) + base_url = f"http://127.0.0.1:{server.server_port}" + provider = RecordingProvider() + cache = AbsoluteCache() + try: + client = ( + Client.builder() + .app_id("cli_local") + .domain(base_url) + .oauth_base_url(base_url) + .client_assertion_provider(provider) + .cache(cache) + .build() + ) + + first = client.request(_ping_request()) + second = client.request(_ping_request()) + + auth_code = client.access_token.retrieve_by_authorization_code(code="code") + refreshed = client.access_token.refresh(refresh_token="refresh-token") + + ws = ws_client.Client("cli_local", "", domain=base_url, client_assertion_provider=provider) + conn_url = ws._get_conn_url() + + assert first.code == 0 + assert second.code == 0 + assert state.ping_authorizations == ["Bearer tenant-token", "Bearer tenant-token"] + assert auth_code.access_token == "user-token" + assert refreshed.access_token == "refreshed-user-token" + assert conn_url == "ws://example.test/callback?device_id=device&service_id=42" + assert provider.auds == [ + f"127.0.0.1:{server.server_port}", + f"127.0.0.1:{server.server_port}", + f"127.0.0.1:{server.server_port}", + f"127.0.0.1:{server.server_port}", + ] + assert len(state.oauth_bodies) == 3 + assert state.oauth_bodies[0]["grant_type"] == "urn:ietf:params:oauth:grant-type:jwt-bearer" + assert state.oauth_bodies[0]["client_assertion"] == "local-assertion" + assert state.ws_bodies == [{"AppID": "cli_local", "AppSecret": "", "ClientAssertion": "local-assertion"}] + finally: + server.shutdown() + server.server_close() diff --git a/lark_oapi/core/tests/test_client_assertion_access_token.py b/lark_oapi/core/tests/test_client_assertion_access_token.py new file mode 100644 index 000000000..be81c1a4d --- /dev/null +++ b/lark_oapi/core/tests/test_client_assertion_access_token.py @@ -0,0 +1,179 @@ +import json +from types import SimpleNamespace + +import pytest + +from lark_oapi import Client +from lark_oapi.core.client_assertion import ClientAssertionToken, TargetInfo +from lark_oapi.core.exception import AccessTokenException, ClientAssertionException + + +class RecordingProvider: + def __init__(self, token=None): + self.token = token or ClientAssertionToken("client-assertion") + self.calls = [] + + def retrieve_token(self, aud): + self.calls.append(aud) + return self.token + + +def _response(payload, status=200): + return SimpleNamespace(status_code=status, headers={"Content-Type": "application/json"}, content=json.dumps(payload).encode()) + + +def _client(provider=None, app_secret="", oauth_base_url="https://accounts.feishu.cn"): + builder = Client.builder().app_id("cli_a").oauth_base_url(oauth_base_url) + if provider is not None: + builder.client_assertion_provider(provider) + if app_secret: + builder.app_secret(app_secret) + return builder.build() + + +def test_access_token_authorization_code_with_client_assertion(monkeypatch): + provider = RecordingProvider() + client = _client(provider) + captured = {} + + def fake_request(method, url, headers=None, params=None, data=None, timeout=None): + captured["url"] = url + captured["body"] = json.loads(data.decode()) + return _response({"access_token": "user-token", "token_type": "Bearer", "expires_in": 7200}) + + import lark_oapi.core.http.transport as transport + + monkeypatch.setattr(transport.requests, "request", fake_request) + + resp = client.access_token.retrieve_by_authorization_code( + code="code", + redirect_uri="https://example.com/cb", + code_verifier="verifier", + ) + + assert resp.access_token == "user-token" + assert captured["url"] == "https://accounts.feishu.cn/oauth/v3/token" + assert captured["body"]["grant_type"] == "authorization_code" + assert captured["body"]["client_id"] == "cli_a" + assert captured["body"]["client_assertion"] == "client-assertion" + assert captured["body"]["client_assertion_type"] == "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" + assert captured["body"]["code"] == "code" + assert captured["body"]["redirect_uri"] == "https://example.com/cb" + assert captured["body"]["code_verifier"] == "verifier" + assert provider.calls == ["accounts.feishu.cn"] + + +def test_access_token_refresh_with_client_assertion(monkeypatch): + client = _client(RecordingProvider()) + captured = {} + + def fake_request(method, url, headers=None, params=None, data=None, timeout=None): + captured["body"] = json.loads(data.decode()) + return _response({"access_token": "user-token", "expires_in": 7200}) + + import lark_oapi.core.http.transport as transport + + monkeypatch.setattr(transport.requests, "request", fake_request) + + assert client.access_token.refresh(refresh_token="refresh-token").access_token == "user-token" + assert captured["body"]["grant_type"] == "refresh_token" + assert captured["body"]["refresh_token"] == "refresh-token" + assert captured["body"]["client_assertion"] == "client-assertion" + + +def test_access_token_authorization_code_with_app_secret_fallback(monkeypatch): + client = _client(provider=None, app_secret="app-secret") + captured = {} + + def fake_request(method, url, headers=None, params=None, data=None, timeout=None): + captured["body"] = json.loads(data.decode()) + return _response({"access_token": "user-token", "expires_in": 7200}) + + import lark_oapi.core.http.transport as transport + + monkeypatch.setattr(transport.requests, "request", fake_request) + + client.access_token.retrieve_by_authorization_code(code="code") + + assert captured["body"]["client_secret"] == "app-secret" + assert "client_assertion" not in captured["body"] + assert "client_assertion_type" not in captured["body"] + + +def test_access_token_refresh_with_app_secret_fallback(monkeypatch): + client = _client(provider=None, app_secret="app-secret") + captured = {} + + def fake_request(method, url, headers=None, params=None, data=None, timeout=None): + captured["body"] = json.loads(data.decode()) + return _response({"access_token": "user-token", "expires_in": 7200}) + + import lark_oapi.core.http.transport as transport + + monkeypatch.setattr(transport.requests, "request", fake_request) + + client.access_token.refresh(refresh_token="refresh-token") + + assert captured["body"]["client_secret"] == "app-secret" + assert captured["body"]["refresh_token"] == "refresh-token" + + +def test_access_token_rejects_missing_credentials(): + client = _client(provider=None, app_secret="") + + with pytest.raises(ClientAssertionException) as err: + client.access_token.retrieve_by_authorization_code(code="code") + + assert err.value.code == 7104 + + +def test_access_token_returns_access_token_exception_for_non_200(monkeypatch): + client = _client(RecordingProvider()) + + def fake_request(method, url, headers=None, params=None, data=None, timeout=None): + return _response({ + "code": 20001, + "error": "invalid_client", + "error_description": "client assertion invalid", + }, status=401) + + import lark_oapi.core.http.transport as transport + + monkeypatch.setattr(transport.requests, "request", fake_request) + + with pytest.raises(AccessTokenException) as err: + client.access_token.retrieve_by_authorization_code(code="code") + + assert err.value.status_code == 401 + assert err.value.code == 20001 + assert err.value.error == "invalid_client" + assert err.value.error_description == "client assertion invalid" + + +def test_access_token_proxy_keeps_custom_headers(monkeypatch): + provider = RecordingProvider( + ClientAssertionToken( + "client-assertion", + TargetInfo(target_service="proxy.example.com", target_prefix="/proxy"), + ) + ) + client = _client(provider) + captured = {} + + def fake_request(method, url, headers=None, params=None, data=None, timeout=None): + captured["url"] = url + captured["headers"] = headers + return _response({"access_token": "user-token", "expires_in": 7200}) + + import lark_oapi.core.http.transport as transport + + monkeypatch.setattr(transport.requests, "request", fake_request) + + client.access_token.retrieve_by_authorization_code( + code="code", + headers={"X-Custom": "custom-value"}, + ) + + assert captured["url"] == "https://proxy.example.com/proxy/oauth/v3/token" + assert captured["headers"]["X-Custom"] == "custom-value" + assert captured["headers"]["X-Target-Service"] == "accounts.feishu.cn" diff --git a/lark_oapi/core/tests/test_client_assertion_auth.py b/lark_oapi/core/tests/test_client_assertion_auth.py new file mode 100644 index 000000000..4f3d3c4c8 --- /dev/null +++ b/lark_oapi/core/tests/test_client_assertion_auth.py @@ -0,0 +1,95 @@ +import pytest + +from lark_oapi.core import AccessTokenType, AppType +from lark_oapi.core.client_assertion import ClientAssertionToken +from lark_oapi.core.exception import ClientAssertionException, NoAuthorizationException +from lark_oapi.core.model import BaseRequest, Config, RequestOption +from lark_oapi.core.token import TokenManager, verify + + +class RecordingProvider: + def __init__(self): + self.calls = [] + + def retrieve_token(self, aud): + self.calls.append(aud) + return ClientAssertionToken("assertion") + + +def _request(*token_types): + req = BaseRequest() + req.token_types = set(token_types) + return req + + +def _config(provider=None): + config = Config() + config.app_id = "cli_a" + config.app_secret = "" + config.client_assertion_provider = provider + return config + + +def test_verify_client_assertion_prefers_tenant_over_app(monkeypatch): + provider = RecordingProvider() + config = _config(provider) + option = RequestOption() + req = _request(AccessTokenType.APP, AccessTokenType.TENANT) + + monkeypatch.setattr(TokenManager, "get_self_tenant_token", staticmethod(lambda conf: "tenant-token")) + + verify(config, req, option) + + assert req.token_types == {AccessTokenType.TENANT} + assert option.tenant_access_token == "tenant-token" + + +def test_verify_client_assertion_manual_user_token_wins(monkeypatch): + provider = RecordingProvider() + config = _config(provider) + option = RequestOption() + option.user_access_token = "user-token" + req = _request(AccessTokenType.TENANT, AccessTokenType.USER) + + def fail_if_called(conf): + raise AssertionError("tenant token should not be requested") + + monkeypatch.setattr(TokenManager, "get_self_tenant_token", staticmethod(fail_if_called)) + + verify(config, req, option) + + assert req.token_types == {AccessTokenType.USER} + assert provider.calls == [] + + +def test_verify_client_assertion_rejects_app_only(): + config = _config(RecordingProvider()) + req = _request(AccessTokenType.APP) + + with pytest.raises(ClientAssertionException) as err: + verify(config, req, RequestOption()) + + assert err.value.code == 7103 + + +def test_verify_client_assertion_rejects_isv(): + config = _config(RecordingProvider()) + config.app_type = AppType.ISV + req = _request(AccessTokenType.TENANT) + + with pytest.raises(ClientAssertionException) as err: + verify(config, req, RequestOption()) + + assert err.value.code == 7100 + + +def test_verify_app_secret_mode_still_requires_app_secret(): + config = Config() + config.app_id = "cli_a" + config.app_secret = "" + req = _request(AccessTokenType.TENANT) + + with pytest.raises(NoAuthorizationException) as err: + verify(config, req, RequestOption()) + + assert err.value.msg == "app_id or app_secret not found" diff --git a/lark_oapi/core/tests/test_client_assertion_core.py b/lark_oapi/core/tests/test_client_assertion_core.py new file mode 100644 index 000000000..7420b8614 --- /dev/null +++ b/lark_oapi/core/tests/test_client_assertion_core.py @@ -0,0 +1,74 @@ +import pytest + +from lark_oapi.core.const import ( + CLIENT_ASSERTION_TYPE_JWT_BEARER, + ERR_CODE_APP_SECRET_AND_CLIENT_ASSERTION_EMPTY, + ERR_CODE_CLIENT_ASSERTION_MODE_NOT_SUPPORTED, + ERR_CODE_CLIENT_ASSERTION_PROVIDER_NOT_CONFIGURED, + ERR_CODE_CLIENT_ASSERTION_RETRIEVE_FAILED, + ERR_CODE_CLIENT_ASSERTION_TOKEN_EMPTY, + FEISHU_OAUTH_DOMAIN, + GRANT_TYPE_JWT_BEARER, + LARK_OAUTH_DOMAIN, + OAUTH_TOKEN_URI, + X_TARGET_SERVICE, +) +from lark_oapi.core.client_assertion import ( + TargetInfo, + build_proxy_url, + resolve_oauth_aud, + resolve_oauth_base_url, +) +from lark_oapi.core.model import Config + + +def test_client_assertion_constants_match_go_sdk(): + assert FEISHU_OAUTH_DOMAIN == "https://accounts.feishu.cn" + assert LARK_OAUTH_DOMAIN == "https://accounts.larksuite.com" + assert OAUTH_TOKEN_URI == "/oauth/v3/token" + assert GRANT_TYPE_JWT_BEARER == "urn:ietf:params:oauth:grant-type:jwt-bearer" + assert CLIENT_ASSERTION_TYPE_JWT_BEARER == "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" + assert X_TARGET_SERVICE == "X-Target-Service" + assert ERR_CODE_CLIENT_ASSERTION_PROVIDER_NOT_CONFIGURED == 7100 + assert ERR_CODE_CLIENT_ASSERTION_TOKEN_EMPTY == 7101 + assert ERR_CODE_CLIENT_ASSERTION_RETRIEVE_FAILED == 7102 + assert ERR_CODE_CLIENT_ASSERTION_MODE_NOT_SUPPORTED == 7103 + assert ERR_CODE_APP_SECRET_AND_CLIENT_ASSERTION_EMPTY == 7104 + + +def test_resolve_oauth_base_url_default_feishu(): + config = Config() + config.domain = "https://open.feishu.cn" + + assert resolve_oauth_base_url(config) == "https://accounts.feishu.cn" + assert resolve_oauth_aud(config) == "accounts.feishu.cn" + + +def test_resolve_oauth_base_url_default_lark(): + config = Config() + config.domain = "https://open.larksuite.com" + + assert resolve_oauth_base_url(config) == "https://accounts.larksuite.com" + assert resolve_oauth_aud(config) == "accounts.larksuite.com" + + +def test_resolve_oauth_base_url_explicit_localhost(): + config = Config() + config.oauth_base_url = "http://127.0.0.1:18080/" + + assert resolve_oauth_base_url(config) == "http://127.0.0.1:18080" + assert resolve_oauth_aud(config) == "127.0.0.1:18080" + + +def test_resolve_oauth_base_url_requires_explicit_for_custom_domain(): + config = Config() + config.domain = "https://open.feishu-boe.cn" + + with pytest.raises(ValueError, match="OAuthBaseUrl is not configured"): + resolve_oauth_base_url(config) + + +def test_build_proxy_url_adds_https_when_scheme_missing(): + target_info = TargetInfo(target_service="proxy.example.com", target_prefix="/proxy") + + assert build_proxy_url(target_info, "/oauth/v3/token") == "https://proxy.example.com/proxy/oauth/v3/token" diff --git a/lark_oapi/core/tests/test_client_assertion_live_harness.py b/lark_oapi/core/tests/test_client_assertion_live_harness.py new file mode 100644 index 000000000..9777a39d8 --- /dev/null +++ b/lark_oapi/core/tests/test_client_assertion_live_harness.py @@ -0,0 +1,124 @@ +from pathlib import Path + +import pytest + +from lark_oapi.core.client_assertion import TargetInfo +from lark_oapi.core.tests.e2e.client_assertion_live_harness import ( + BOE_DOMAINS, + ONLINE_DOMAINS, + ModeEnvProvider, + build_host_override_getaddrinfo, + build_authorize_url, + decode_jwt_payload_unverified, + deploy_domains, + load_env_file, + parse_bool, +) + + +def test_parse_bool_accepts_common_values(): + assert parse_bool("true") is True + assert parse_bool("1") is True + assert parse_bool("yes") is True + assert parse_bool("false") is False + assert parse_bool("0") is False + assert parse_bool(None, default=True) is True + + +def test_deploy_domains_supports_online_and_boe(): + assert deploy_domains("online") == ONLINE_DOMAINS + assert deploy_domains("cn") == ONLINE_DOMAINS + assert deploy_domains("boe") == BOE_DOMAINS + + with pytest.raises(ValueError, match="LARK_DEPLOY_ENV"): + deploy_domains("lark") + + +def test_load_env_file_preserves_existing_env_and_handles_quotes(tmp_path: Path): + env_file = tmp_path / ".env.e2e" + env_file.write_text( + "\n".join( + [ + "# comment", + "export LARK_APP_ID=cli_file", + 'LARK_OAUTH_SCOPE="contact:user.basic_profile:readonly offline_access"', + "LARK_APP_SECRET=from_file # inline comment", + ] + ), + encoding="utf-8", + ) + env = {"LARK_APP_ID": "cli_existing"} + + load_env_file(env_file, env=env) + + assert env["LARK_APP_ID"] == "cli_existing" + assert env["LARK_OAUTH_SCOPE"] == "contact:user.basic_profile:readonly offline_access" + assert env["LARK_APP_SECRET"] == "from_file" + + +def test_mode_env_provider_returns_zti_without_target_info(): + env = {"LARK_ZTI_CLIENT_ASSERTION": "zti-token"} + provider = ModeEnvProvider("zti", env=env) + + token = provider.retrieve_token("accounts.feishu.cn") + + assert token.value == "zti-token" + assert token.target_info is None + assert provider.auds == ["accounts.feishu.cn"] + + +def test_mode_env_provider_returns_gdpr_with_target_info(): + env = { + "LARK_GDPR_CLIENT_ASSERTION": "gdpr-token", + "LARK_GDPR_PROXY_SERVICE": "gdpr-proxy.example.internal", + "LARK_GDPR_PROXY_PREFIX": "/proxy/example", + } + provider = ModeEnvProvider("gdpr", env=env) + + token = provider.retrieve_token("open.feishu.cn") + + assert token.value == "gdpr-token" + assert token.target_info == TargetInfo("gdpr-proxy.example.internal", "/proxy/example") + assert provider.auds == ["open.feishu.cn"] + + +def test_build_authorize_url_uses_local_redirect_without_pkce(): + url = build_authorize_url( + oauth_base_url="https://accounts.feishu.cn", + app_id="cli_xxx", + redirect_uri="http://127.0.0.1:8765/uat_e2e/callback", + scope="contact:user.basic_profile:readonly offline_access", + state="state-123", + ) + + assert url.startswith("https://accounts.feishu.cn/open-apis/authen/v1/authorize?") + assert "app_id=cli_xxx" in url + assert "redirect_uri=http%3A%2F%2F127.0.0.1%3A8765%2Fuat_e2e%2Fcallback" in url + assert "scope=contact%3Auser.basic_profile%3Areadonly+offline_access" in url + assert "state=state-123" in url + assert "code_challenge" not in url + + +def test_decode_jwt_payload_unverified_decodes_payload(): + token = "header.eyJleHAiOjEyMywiYXVkIjpbImFjY291bnRzLmZlaXNodS5jbiJdfQ.signature" + + payload = decode_jwt_payload_unverified(token) + + assert payload == {"exp": 123, "aud": ["accounts.feishu.cn"]} + + +def test_build_host_override_getaddrinfo_only_rewrites_target_host(): + calls = [] + + def fake_getaddrinfo(host, port, family, type_arg, proto_arg, flags_arg): + calls.append((host, port, family, type_arg, proto_arg, flags_arg)) + return [("result", host, family)] + + resolver = build_host_override_getaddrinfo( + fake_getaddrinfo, + "gdpr-proxy.example.internal", + "192.0.2.10", + ) + + assert resolver("gdpr-proxy.example.internal", 443, 0, 1, 2, 3) == [("result", "192.0.2.10", 2)] + assert resolver("open.feishu.cn", 443, 0, 1, 2, 3) == [("result", "open.feishu.cn", 0)] diff --git a/lark_oapi/core/tests/test_client_assertion_token_manager.py b/lark_oapi/core/tests/test_client_assertion_token_manager.py new file mode 100644 index 000000000..688f33d97 --- /dev/null +++ b/lark_oapi/core/tests/test_client_assertion_token_manager.py @@ -0,0 +1,174 @@ +import json +from types import SimpleNamespace + +import pytest + +from lark_oapi.core.client_assertion import ClientAssertionToken, TargetInfo +from lark_oapi.core.exception import ClientAssertionException +from lark_oapi.core.model import Config +from lark_oapi.core.token import TokenManager + + +class DictCache: + def __init__(self): + self.data = {} + self.expires = {} + + def get(self, key): + return self.data.get(key) + + def set(self, key, value, expire): + self.data[key] = value + self.expires[key] = expire + + +_DEFAULT = object() + + +class RecordingProvider: + def __init__(self, token=_DEFAULT, err=None): + self.token = ClientAssertionToken("client-assertion") if token is _DEFAULT else token + self.err = err + self.calls = [] + + def retrieve_token(self, aud): + self.calls.append(aud) + if self.err: + raise self.err + return self.token + + +def _config(provider): + config = Config() + config.app_id = "cli_a" + config.app_secret = "" + config.domain = "https://open.feishu.cn" + config.oauth_base_url = "https://accounts.feishu.cn" + config.client_assertion_provider = provider + return config + + +def _response(payload, status=200): + return SimpleNamespace(status_code=status, headers={"Content-Type": "application/json"}, content=json.dumps(payload).encode()) + + +def test_get_self_tenant_token_by_client_assertion_requests_oauth_token(monkeypatch): + provider = RecordingProvider() + config = _config(provider) + cache = DictCache() + monkeypatch.setattr(TokenManager, "cache", cache) + captured = {} + + def fake_request(method, url, headers=None, params=None, data=None, timeout=None): + captured["method"] = method + captured["url"] = url + captured["headers"] = headers + captured["body"] = json.loads(data.decode()) + return _response({"access_token": "tenant-token", "expires_in": 7200}) + + import lark_oapi.core.http.transport as transport + + monkeypatch.setattr(transport.requests, "request", fake_request) + + token = TokenManager.get_self_tenant_token(config) + + assert token == "tenant-token" + assert captured["url"] == "https://accounts.feishu.cn/oauth/v3/token" + assert captured["body"] == { + "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer", + "client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer", + "client_assertion": "client-assertion", + "client_id": "cli_a", + } + assert provider.calls == ["accounts.feishu.cn"] + assert cache.data["self_tenant_token:cli_a"] == "tenant-token" + + +def test_get_self_tenant_token_by_client_assertion_cache_hit_skips_provider(monkeypatch): + provider = RecordingProvider() + config = _config(provider) + cache = DictCache() + cache.data["self_tenant_token:cli_a"] = "cached-token" + monkeypatch.setattr(TokenManager, "cache", cache) + + assert TokenManager.get_self_tenant_token(config) == "cached-token" + assert provider.calls == [] + + +def test_get_self_tenant_token_by_client_assertion_with_proxy(monkeypatch): + provider = RecordingProvider( + ClientAssertionToken( + "client-assertion", + TargetInfo(target_service="proxy.example.com", target_prefix="/proxy"), + ) + ) + config = _config(provider) + cache = DictCache() + monkeypatch.setattr(TokenManager, "cache", cache) + captured = {} + + def fake_request(method, url, headers=None, params=None, data=None, timeout=None): + captured["url"] = url + captured["headers"] = headers + return _response({"access_token": "tenant-token", "expires_in": 7200}) + + import lark_oapi.core.http.transport as transport + + monkeypatch.setattr(transport.requests, "request", fake_request) + + assert TokenManager.get_self_tenant_token(config) == "tenant-token" + assert captured["url"] == "https://proxy.example.com/proxy/oauth/v3/token" + assert captured["headers"]["X-Target-Service"] == "accounts.feishu.cn" + + +@pytest.mark.parametrize("token", [None, ClientAssertionToken("")]) +def test_get_self_tenant_token_by_client_assertion_empty_token(monkeypatch, token): + config = _config(RecordingProvider(token=token)) + monkeypatch.setattr(TokenManager, "cache", DictCache()) + + with pytest.raises(ClientAssertionException) as err: + TokenManager.get_self_tenant_token(config) + + assert err.value.code == 7101 + + +def test_get_self_tenant_token_by_client_assertion_provider_error(monkeypatch): + config = _config(RecordingProvider(err=RuntimeError("boom"))) + monkeypatch.setattr(TokenManager, "cache", DictCache()) + + with pytest.raises(ClientAssertionException) as err: + TokenManager.get_self_tenant_token(config) + + assert err.value.code == 7102 + assert "boom" in err.value.msg + + +def test_get_self_tenant_token_by_client_assertion_oauth_error_message_priority(monkeypatch): + config = _config(RecordingProvider()) + monkeypatch.setattr(TokenManager, "cache", DictCache()) + + def fake_request(method, url, headers=None, params=None, data=None, timeout=None): + return _response({ + "code": 20001, + "error": "invalid_client", + "error_description": "client assertion invalid", + }) + + import lark_oapi.core.http.transport as transport + + monkeypatch.setattr(transport.requests, "request", fake_request) + + with pytest.raises(ClientAssertionException) as err: + TokenManager.get_self_tenant_token(config) + + assert err.value.code == 20001 + assert err.value.msg == "client assertion invalid" + + +def test_get_self_app_token_blocked_in_client_assertion_mode(): + config = _config(RecordingProvider()) + + with pytest.raises(ClientAssertionException) as err: + TokenManager.get_self_app_token(config) + + assert err.value.code == 7100 diff --git a/lark_oapi/core/tests/test_transport_absolute_url.py b/lark_oapi/core/tests/test_transport_absolute_url.py new file mode 100644 index 000000000..fff8a03f2 --- /dev/null +++ b/lark_oapi/core/tests/test_transport_absolute_url.py @@ -0,0 +1,31 @@ +from lark_oapi.core.http.transport import _build_url + + +def test_build_url_keeps_absolute_http_url(): + url = _build_url( + "https://open.feishu.cn", + "http://127.0.0.1:18080/oauth/v3/token", + {}, + ) + + assert url == "http://127.0.0.1:18080/oauth/v3/token" + + +def test_build_url_keeps_absolute_https_url(): + url = _build_url( + "https://open.feishu.cn", + "https://accounts.feishu.cn/oauth/v3/token", + {}, + ) + + assert url == "https://accounts.feishu.cn/oauth/v3/token" + + +def test_build_url_relative_path_unchanged(): + url = _build_url( + "https://open.feishu.cn", + "/open-apis/mock/v1/ping", + {}, + ) + + assert url == "https://open.feishu.cn/open-apis/mock/v1/ping" diff --git a/lark_oapi/core/tests/test_transport_log_redaction.py b/lark_oapi/core/tests/test_transport_log_redaction.py new file mode 100644 index 000000000..0ec9cc814 --- /dev/null +++ b/lark_oapi/core/tests/test_transport_log_redaction.py @@ -0,0 +1,135 @@ +from types import SimpleNamespace + +import pytest + +from lark_oapi.core import AccessTokenType, HttpMethod +from lark_oapi.core.http import transport +from lark_oapi.core.json import JSON +from lark_oapi.core.model import BaseRequest, Config, RequestOption + + +def _request(body): + req = BaseRequest() + req.http_method = HttpMethod.POST + req.uri = "/open-apis/mock" + req.token_types = {AccessTokenType.TENANT} + req.body = body + return req + + +def test_execute_omits_sensitive_headers_queries_and_body_from_debug_log(monkeypatch): + captured = {} + debug_logs = [] + body = { + "client_assertion": "assertion-secret", + "client_secret": "client-secret", + "nested": {"refresh_token": "refresh-secret"}, + "items": [{"AppSecret": "app-secret"}, {"ClientAssertion": "ws-assertion"}], + } + + def fake_request(method, url, *, headers=None, params=None, data=None, timeout=None): + captured["headers"] = dict(headers) + captured["params"] = list(params) + captured["data"] = data + return SimpleNamespace(status_code=200, headers={}, content=b"{}") + + monkeypatch.setattr(transport.requests, "request", fake_request) + monkeypatch.setattr(transport.logger, "debug", lambda msg: debug_logs.append(msg)) + + conf = Config() + req = _request(body) + req.add_query("access_token", "query-token-secret") + option = RequestOption() + option.tenant_access_token = "tenant-token-secret" + option.headers = {"X-Api-Token": "header-token-secret"} + + transport.Transport.execute(conf, req, option) + + log_output = "\n".join(debug_logs) + for secret in ( + "assertion-secret", + "client-secret", + "refresh-secret", + "app-secret", + "ws-assertion", + "query-token-secret", + "tenant-token-secret", + "header-token-secret", + ): + assert secret not in log_output + for key in ( + "Authorization", + "X-Api-Token", + "access_token", + "client_assertion", + "client_secret", + "refresh_token", + "AppSecret", + "ClientAssertion", + ): + assert key not in log_output + assert "headers_count:" in log_output + assert "params_count: 1" in log_output + assert "body_present: True" in log_output + + assert captured["headers"]["Authorization"] == "Bearer tenant-token-secret" + assert captured["headers"]["X-Api-Token"] == "header-token-secret" + assert captured["params"] == [("access_token", "query-token-secret")] + assert JSON.unmarshal(captured["data"].decode("utf-8"), dict)["client_assertion"] == "assertion-secret" + + +@pytest.mark.asyncio +async def test_aexecute_omits_sensitive_headers_queries_and_body_from_debug_log(monkeypatch): + captured = {} + debug_logs = [] + body = {"client_assertion": "async-assertion-secret", "client_secret": "async-client-secret"} + + class FakeAsyncClient: + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return None + + async def request(self, method, url, *, headers=None, params=None, json=None, data=None, files=None, + timeout=None): + captured["headers"] = dict(headers) + captured["params"] = list(params) + captured["json"] = json + return SimpleNamespace(status_code=200, headers={}, content=b"{}") + + monkeypatch.setattr(transport.httpx, "AsyncClient", FakeAsyncClient) + monkeypatch.setattr(transport.logger, "debug", lambda msg: debug_logs.append(msg)) + + conf = Config() + req = _request(body) + req.token_types = {AccessTokenType.USER} + req.add_query("refresh_token", "async-query-refresh-secret") + option = RequestOption() + option.user_access_token = "user-token-secret" + option.headers = {"X-Password": "async-header-password"} + + await transport.Transport.aexecute(conf, req, option) + + log_output = "\n".join(debug_logs) + assert "async-assertion-secret" not in log_output + assert "async-client-secret" not in log_output + assert "async-query-refresh-secret" not in log_output + assert "user-token-secret" not in log_output + assert "async-header-password" not in log_output + for key in ( + "Authorization", + "X-Password", + "refresh_token", + "client_assertion", + "client_secret", + ): + assert key not in log_output + assert "headers_count:" in log_output + assert "params_count: 1" in log_output + assert "body_present: True" in log_output + + assert captured["headers"]["Authorization"] == "Bearer user-token-secret" + assert captured["headers"]["X-Password"] == "async-header-password" + assert captured["params"] == [("refresh_token", "async-query-refresh-secret")] + assert captured["json"]["client_assertion"] == "async-assertion-secret" diff --git a/lark_oapi/core/token/auth.py b/lark_oapi/core/token/auth.py index 13836c352..2898a78bd 100644 --- a/lark_oapi/core/token/auth.py +++ b/lark_oapi/core/token/auth.py @@ -1,4 +1,8 @@ -from lark_oapi.core.exception import NoAuthorizationException +from lark_oapi.core.const import ( + ERR_CODE_CLIENT_ASSERTION_MODE_NOT_SUPPORTED, + ERR_CODE_CLIENT_ASSERTION_PROVIDER_NOT_CONFIGURED, +) +from lark_oapi.core.exception import ClientAssertionException, NoAuthorizationException from lark_oapi.core.model import * from lark_oapi.core.utils import Strings from .manager import TokenManager @@ -9,6 +13,28 @@ def verify(config: Config, request: BaseRequest, option: RequestOption) -> None: if len(request.token_types) == 0: return + if config.client_assertion_provider is not None: + if Strings.is_empty(config.app_id): + raise NoAuthorizationException("app_id not found") + if AppType.ISV == config.app_type: + raise ClientAssertionException( + ERR_CODE_CLIENT_ASSERTION_PROVIDER_NOT_CONFIGURED, + "ClientAssertion mode is not supported for ISV apps", + ) + if Strings.is_not_empty(option.user_access_token) and AccessTokenType.USER in request.token_types: + request.token_types = {AccessTokenType.USER} + return + if AccessTokenType.TENANT in request.token_types: + option.tenant_access_token = TokenManager.get_self_tenant_token(config) + request.token_types = {AccessTokenType.TENANT} + return + if AccessTokenType.APP in request.token_types: + raise ClientAssertionException( + ERR_CODE_CLIENT_ASSERTION_MODE_NOT_SUPPORTED, + "AppAccessToken APIs are not available in ClientAssertion mode", + ) + return + # 如开启token配置,需手动传入token if config.enable_set_token: if Strings.is_not_empty(option.tenant_access_token) and AccessTokenType.TENANT in request.token_types: diff --git a/lark_oapi/core/token/manager.py b/lark_oapi/core/token/manager.py index f637a70c3..0d5491223 100644 --- a/lark_oapi/core/token/manager.py +++ b/lark_oapi/core/token/manager.py @@ -1,9 +1,25 @@ +import json +import time + from lark_oapi.core import JSON, Strings from lark_oapi.core.cache import * -from lark_oapi.core.const import UTF_8 -from lark_oapi.core.exception import ObtainAccessTokenException +from lark_oapi.core.client_assertion import build_proxy_url, resolve_oauth_aud, resolve_oauth_base_url +from lark_oapi.core.const import ( + APPLICATION_JSON, + CLIENT_ASSERTION_TYPE_JWT_BEARER, + CONTENT_TYPE, + ERR_CODE_CLIENT_ASSERTION_PROVIDER_NOT_CONFIGURED, + ERR_CODE_CLIENT_ASSERTION_RETRIEVE_FAILED, + ERR_CODE_CLIENT_ASSERTION_TOKEN_EMPTY, + GRANT_TYPE_JWT_BEARER, + OAUTH_TOKEN_URI, + UTF_8, + X_TARGET_SERVICE, +) +from lark_oapi.core.exception import ClientAssertionException, ObtainAccessTokenException from lark_oapi.core.http import Transport -from lark_oapi.core.model import Config, RawResponse +from lark_oapi.core.model import BaseRequest, Config, RawResponse, RequestOption +from lark_oapi.core.enum import HttpMethod from .access_token_response import AccessTokenResponse from .create_isv_app_token_request import CreateIsvAppTokenRequest from .create_isv_tenant_token_request import CreateIsvTenantTokenRequest @@ -17,6 +33,12 @@ class TokenManager(object): @staticmethod def get_self_app_token(conf: Config) -> str: + if conf.client_assertion_provider is not None: + raise ClientAssertionException( + ERR_CODE_CLIENT_ASSERTION_PROVIDER_NOT_CONFIGURED, + "AppAccessToken is not available in ClientAssertion mode", + ) + # 读缓存 cache_key = f"self_app_token:{conf.app_id}" token = TokenManager.cache.get(cache_key) @@ -45,6 +67,9 @@ def get_self_app_token(conf: Config) -> str: @staticmethod def get_self_tenant_token(config: Config) -> str: + if config.client_assertion_provider is not None: + return TokenManager._get_self_tenant_token_by_client_assertion(config) + # 读缓存 cache_key = f"self_tenant_token:{config.app_id}" token = TokenManager.cache.get(cache_key) @@ -71,6 +96,49 @@ def get_self_tenant_token(config: Config) -> str: return token + @staticmethod + def _get_self_tenant_token_by_client_assertion(config: Config) -> str: + cache_key = f"self_tenant_token:{config.app_id}" + token = TokenManager.cache.get(cache_key) + if Strings.is_not_empty(token): + return token + + oauth_base_url = resolve_oauth_base_url(config) + aud = resolve_oauth_aud(config) + try: + assertion_token = config.client_assertion_provider.retrieve_token(aud) + except Exception as e: + raise ClientAssertionException(ERR_CODE_CLIENT_ASSERTION_RETRIEVE_FAILED, str(e)) + if assertion_token is None or Strings.is_empty(assertion_token.value): + raise ClientAssertionException(ERR_CODE_CLIENT_ASSERTION_TOKEN_EMPTY, "client assertion token is empty") + + req = BaseRequest() + req.http_method = HttpMethod.POST + req.uri = oauth_base_url + OAUTH_TOKEN_URI + req.headers = {CONTENT_TYPE: APPLICATION_JSON} + req.body = { + "grant_type": GRANT_TYPE_JWT_BEARER, + "client_assertion_type": CLIENT_ASSERTION_TYPE_JWT_BEARER, + "client_assertion": assertion_token.value, + "client_id": config.app_id, + } + option = RequestOption() + if assertion_token.target_info is not None: + req.uri = build_proxy_url(assertion_token.target_info, OAUTH_TOKEN_URI) + option.headers[X_TARGET_SERVICE] = aud + + raw = Transport.execute(config, req, option) + resp = json.loads(str(raw.content, UTF_8)) + access_token = resp.get("access_token") + if Strings.is_empty(access_token): + msg = resp.get("error_description") or resp.get("error") or "oauth token response missing access token" + raise ClientAssertionException(resp.get("code") or 0, msg) + + expires_in = int(resp.get("expires_in") or 0) + expire = time.time() + max(expires_in - 180, 0) + TokenManager.cache.set(cache_key, access_token, int(expire)) + return access_token + @staticmethod def get_isv_app_token(config: Config, app_ticket: str) -> str: # 读缓存 diff --git a/lark_oapi/ws/client.py b/lark_oapi/ws/client.py index 622e842af..8ee991838 100644 --- a/lark_oapi/ws/client.py +++ b/lark_oapi/ws/client.py @@ -2,6 +2,7 @@ import base64 import http import inspect +import json import random import time from typing import Callable, Dict, Mapping, Optional @@ -12,7 +13,8 @@ from websockets.exceptions import InvalidHandshake from lark_oapi.core.cache import ExpiringCache -from lark_oapi.core.const import UTF_8, FEISHU_DOMAIN, USER_AGENT +from lark_oapi.core.client_assertion import build_proxy_url, extract_aud_from_url +from lark_oapi.core.const import UTF_8, FEISHU_DOMAIN, USER_AGENT, X_TARGET_SERVICE from lark_oapi.core.enum import LogLevel from lark_oapi.core.json import JSON from lark_oapi.core.log import logger @@ -122,13 +124,15 @@ def __init__(self, auto_reconnect: bool = True, source: Optional[str] = None, extra_ua_tags: Optional[list] = None, - headers: Optional[Mapping[str, str]] = None) -> None: + headers: Optional[Mapping[str, str]] = None, + client_assertion_provider=None) -> None: self._app_id: str = app_id self._app_secret: str = app_secret self._log_level: LogLevel = log_level self._event_handler: EventDispatcherHandler = event_handler self._auto_reconnect: bool = auto_reconnect self._domain: str = domain + self._client_assertion_provider = client_assertion_provider self._headers: Dict[str, str] = dict(headers or {}) # UA used on the endpoint-discovery POST (and any future HTTP/WS # handshakes from this client). ``extra_ua_tags`` is internal — sub- @@ -226,24 +230,47 @@ async def _receive_message_loop(self): raise e def _get_conn_url(self) -> str: - if Strings.is_empty(self._app_id) or Strings.is_empty(self._app_secret): - raise ClientException(NO_CREDENTIAL, "app_id or app_secret is null") + if Strings.is_empty(self._app_id) or ( + self._client_assertion_provider is None and Strings.is_empty(self._app_secret) + ): + raise ClientException( + NO_CREDENTIAL, + "app_id is required and either app_secret or client_assertion_provider is required", + ) headers = dict(self._headers) headers.update({ "locale": "zh", USER_AGENT: self._user_agent, }) + url = self._domain + GEN_ENDPOINT_URI + body = {"AppID": self._app_id} + if self._client_assertion_provider is not None: + aud = extract_aud_from_url(self._domain) + assertion_token = self._client_assertion_provider.retrieve_token(aud) + if assertion_token is None or Strings.is_empty(assertion_token.value): + raise ClientException(7101, "client assertion token is empty") + body["AppSecret"] = "" + body["ClientAssertion"] = assertion_token.value + if assertion_token.target_info is not None: + url = build_proxy_url(assertion_token.target_info, GEN_ENDPOINT_URI) + headers[X_TARGET_SERVICE] = aud + else: + body["AppSecret"] = self._app_secret + response = requests.post( - self._domain + GEN_ENDPOINT_URI, + url, headers=headers, - json={ - "AppID": self._app_id, - "AppSecret": self._app_secret, - }, + json=body, ) if response.status_code != http.HTTPStatus.OK: - raise ServerException(response.status_code, "system busy") + msg = "system busy" + try: + payload = json.loads(str(response.content, UTF_8)) + msg = payload.get("msg") or msg + except Exception: + pass + raise ServerException(response.status_code, msg) resp = JSON.unmarshal(str(response.content, UTF_8), EndpointResp) if resp.code == OK: diff --git a/lark_oapi/ws/tests/test_client_assertion.py b/lark_oapi/ws/tests/test_client_assertion.py new file mode 100644 index 000000000..9d412f1f3 --- /dev/null +++ b/lark_oapi/ws/tests/test_client_assertion.py @@ -0,0 +1,167 @@ +from types import SimpleNamespace + +import pytest + +from lark_oapi.core.client_assertion import ClientAssertionToken, TargetInfo +from lark_oapi.ws import client as ws_client +from lark_oapi.ws.exception import ClientException, ServerException + + +class RecordingProvider: + def __init__(self, tokens=None, err=None): + self.tokens = list(tokens or [ClientAssertionToken("assertion")]) + self.err = err + self.calls = [] + + def retrieve_token(self, aud): + self.calls.append(aud) + if self.err: + raise self.err + if len(self.tokens) == 1: + return self.tokens[0] + return self.tokens.pop(0) + + +def _ok_response(): + return SimpleNamespace( + status_code=200, + content=b'{"code":0,"data":{"URL":"ws://example.test/callback?device_id=device&service_id=42"}}', + ) + + +def test_ws_get_conn_url_with_app_secret_keeps_existing_behavior(monkeypatch): + captured = {} + + def fake_post(url, *, headers=None, json=None): + captured["url"] = url + captured["headers"] = headers + captured["json"] = json + return _ok_response() + + monkeypatch.setattr(ws_client.requests, "post", fake_post) + client = ws_client.Client("app_id", "app_secret") + + assert client._get_conn_url() == "ws://example.test/callback?device_id=device&service_id=42" + assert captured["json"] == {"AppID": "app_id", "AppSecret": "app_secret"} + + +def test_ws_get_conn_url_with_client_assertion(monkeypatch): + captured = {} + provider = RecordingProvider() + + def fake_post(url, *, headers=None, json=None): + captured["url"] = url + captured["headers"] = headers + captured["json"] = json + return _ok_response() + + monkeypatch.setattr(ws_client.requests, "post", fake_post) + client = ws_client.Client("app_id", "", client_assertion_provider=provider) + + assert client._get_conn_url() == "ws://example.test/callback?device_id=device&service_id=42" + assert captured["json"] == {"AppID": "app_id", "AppSecret": "", "ClientAssertion": "assertion"} + assert provider.calls == ["open.feishu.cn"] + + +def test_ws_get_conn_url_with_client_assertion_proxy(monkeypatch): + captured = {} + provider = RecordingProvider([ + ClientAssertionToken( + "assertion", + TargetInfo(target_service="proxy.example.com", target_prefix="/proxy"), + ) + ]) + + def fake_post(url, *, headers=None, json=None): + captured["url"] = url + captured["headers"] = headers + captured["json"] = json + return _ok_response() + + monkeypatch.setattr(ws_client.requests, "post", fake_post) + client = ws_client.Client( + "app_id", + "", + domain="https://open.feishu.cn", + headers={"X-Target-Service": "caller-value", "X-Custom": "custom-value"}, + client_assertion_provider=provider, + ) + + client._get_conn_url() + + assert captured["url"] == "https://proxy.example.com/proxy/callback/ws/endpoint" + assert captured["headers"]["X-Target-Service"] == "open.feishu.cn" + assert captured["headers"]["X-Custom"] == "custom-value" + assert captured["json"] == {"AppID": "app_id", "AppSecret": "", "ClientAssertion": "assertion"} + + +def test_ws_get_conn_url_retrieves_token_each_time(monkeypatch): + assertions = [] + provider = RecordingProvider([ + ClientAssertionToken("assertion-1"), + ClientAssertionToken("assertion-2"), + ]) + + def fake_post(url, *, headers=None, json=None): + assertions.append(json["ClientAssertion"]) + return _ok_response() + + monkeypatch.setattr(ws_client.requests, "post", fake_post) + client = ws_client.Client("app_id", "", client_assertion_provider=provider) + + client._get_conn_url() + client._get_conn_url() + + assert assertions == ["assertion-1", "assertion-2"] + assert provider.calls == ["open.feishu.cn", "open.feishu.cn"] + + +def test_ws_get_conn_url_empty_client_assertion_token(): + provider = RecordingProvider([ClientAssertionToken("")]) + client = ws_client.Client("app_id", "", client_assertion_provider=provider) + + with pytest.raises(ClientException) as err: + client._get_conn_url() + + assert err.value.code == 7101 + + +def test_ws_get_conn_url_missing_credentials_message(): + client = ws_client.Client("app_id", "") + + with pytest.raises(ClientException) as err: + client._get_conn_url() + + assert str(err.value) == ( + "1000040344: app_id is required and either app_secret or client_assertion_provider is required" + ) + + +def test_ws_provider_error_is_not_wrapped(): + err = RuntimeError("boom") + provider = RecordingProvider(err=err) + client = ws_client.Client("app_id", "", client_assertion_provider=provider) + + with pytest.raises(RuntimeError) as raised: + client._get_conn_url() + + assert raised.value is err + + +def test_ws_non_200_uses_server_msg_when_available(monkeypatch): + provider = RecordingProvider() + + def fake_post(url, *, headers=None, json=None): + return SimpleNamespace( + status_code=500, + content=b'{"code":20050,"msg":"target service unavailable"}', + ) + + monkeypatch.setattr(ws_client.requests, "post", fake_post) + client = ws_client.Client("app_id", "", client_assertion_provider=provider) + + with pytest.raises(ServerException) as err: + client._get_conn_url() + + assert err.value.code == 500 + assert str(err.value) == "500: target service unavailable" diff --git a/samples/client_assertion/access_token_authorization_code_sample.py b/samples/client_assertion/access_token_authorization_code_sample.py new file mode 100644 index 000000000..e1c5ede44 --- /dev/null +++ b/samples/client_assertion/access_token_authorization_code_sample.py @@ -0,0 +1,27 @@ +import os + +import lark_oapi as lark +from lark_oapi.core.client_assertion import ClientAssertionToken + + +class EnvClientAssertionProvider: + def retrieve_token(self, aud: str) -> ClientAssertionToken: + return ClientAssertionToken(os.environ["LARK_CLIENT_ASSERTION"]) + + +builder = ( + lark.Client.builder() + .app_id(os.environ["LARK_APP_ID"]) + .client_assertion_provider(EnvClientAssertionProvider()) +) +if os.environ.get("LARK_OAUTH_BASE_URL"): + builder.oauth_base_url(os.environ["LARK_OAUTH_BASE_URL"]) + +client = builder.build() +resp = client.access_token.retrieve_by_authorization_code( + code=os.environ["LARK_OAUTH_CODE"], + redirect_uri=os.environ.get("LARK_REDIRECT_URI"), + code_verifier=os.environ.get("LARK_CODE_VERIFIER"), +) + +print(resp.access_token) diff --git a/samples/client_assertion/client_assertion_provider_sample.py b/samples/client_assertion/client_assertion_provider_sample.py new file mode 100644 index 000000000..3d17c0aab --- /dev/null +++ b/samples/client_assertion/client_assertion_provider_sample.py @@ -0,0 +1,21 @@ +import os + +import lark_oapi as lark +from lark_oapi.core.client_assertion import ClientAssertionToken + + +class EnvClientAssertionProvider: + def retrieve_token(self, aud: str) -> ClientAssertionToken: + # The SDK does not generate or sign JWTs. Fetch the assertion from + # your signing service, KMS, Vault, or another secure source. + return ClientAssertionToken(os.environ["LARK_CLIENT_ASSERTION"]) + + +client = ( + lark.Client.builder() + .app_id(os.environ["LARK_APP_ID"]) + .client_assertion_provider(EnvClientAssertionProvider()) + .build() +) + +print(client.config.app_id) diff --git a/samples/ws/client_assertion_sample.py b/samples/ws/client_assertion_sample.py new file mode 100644 index 000000000..32acb15fa --- /dev/null +++ b/samples/ws/client_assertion_sample.py @@ -0,0 +1,18 @@ +import os + +from lark_oapi.core.client_assertion import ClientAssertionToken +from lark_oapi.ws import Client + + +class EnvClientAssertionProvider: + def retrieve_token(self, aud: str) -> ClientAssertionToken: + return ClientAssertionToken(os.environ["LARK_CLIENT_ASSERTION"]) + + +client = Client( + os.environ["LARK_APP_ID"], + "", + client_assertion_provider=EnvClientAssertionProvider(), +) + +client.start()