diff --git a/.github/actions/conformance/client.py b/.github/actions/conformance/client.py index 39150add5d..2cd9d59815 100644 --- a/.github/actions/conformance/client.py +++ b/.github/actions/conformance/client.py @@ -1,32 +1,11 @@ -"""MCP unified conformance test client. - -This client is designed to work with the @modelcontextprotocol/conformance npm package. -It handles all conformance test scenarios via environment variables and CLI arguments. +"""MCP conformance test client for the @modelcontextprotocol/conformance harness. Contract: - - MCP_CONFORMANCE_SCENARIO env var -> scenario name - - MCP_CONFORMANCE_CONTEXT env var -> optional JSON (for client-credentials scenarios) - - MCP_CONFORMANCE_PROTOCOL_VERSION env var -> spec version the harness mock - server is speaking (e.g. "2025-11-25", "2026-07-28"). Always set; when - --spec-version is omitted the harness picks per-scenario (LATEST_SPEC_VERSION - for active scenarios, DRAFT_PROTOCOL_VERSION for draft-only ones). + - MCP_CONFORMANCE_SCENARIO env var -> scenario name (see HANDLERS; auth/* falls back to the auth code flow) + - MCP_CONFORMANCE_CONTEXT env var -> optional JSON (client-credentials scenarios) + - MCP_CONFORMANCE_PROTOCOL_VERSION env var -> spec version the harness mock server speaks - Server URL as last CLI argument (sys.argv[1]) - Must exit 0 within 30 seconds - -Scenarios: - initialize - Connect, initialize, list tools, close - tools_call - Connect, call add_numbers(a=5, b=3), close - sse-retry - Connect, call test_reconnection, close - json-schema-ref-no-deref - Connect, list tools (no $ref deref) - request-metadata - Connect with all callbacks; client stamps _meta - http-standard-headers - Connect, call a tool (Mcp-* headers checked) - http-invalid-tool-headers - List tools, call every surfaced tool (x-mcp-header filter) - elicitation-sep1034-client-defaults - Elicitation with default accept callback - sep-2322-client-request-state - Drive the MRTR auto-loop (SEP-2322) - auth/client-credentials-jwt - Client credentials with private_key_jwt - auth/client-credentials-basic - Client credentials with client_secret_basic - auth/enterprise-managed-authorization - SEP-990 ID-JAG (RFC 8693 + RFC 7523 jwt-bearer) - auth/* - Authorization code flow (default for auth scenarios) """ import asyncio @@ -64,32 +43,24 @@ ) logger = logging.getLogger(__name__) -#: Spec version the harness is running this scenario at (e.g. "2025-11-25", -#: "2026-07-28"). The harness always sets this (when --spec-version is omitted -#: it picks per-scenario: LATEST_SPEC_VERSION for active scenarios, -#: DRAFT_PROTOCOL_VERSION for draft-only ones), so None means we were invoked -#: outside the harness. +#: Spec version the harness mock server speaks (e.g. "2025-11-25"). The harness +#: always sets this, so None means we were invoked outside it. PROTOCOL_VERSION: str | None = os.environ.get("MCP_CONFORMANCE_PROTOCOL_VERSION") def client_mode() -> str: """Pick the Client(mode=) for the harness leg. - On a modern leg (2026-07-28+) -> 'auto' so Client.discover() runs and the - _meta envelope + MCP-Protocol-Version header are stamped on every request. - On a handshake-era leg -> 'legacy' so the initialize handshake runs exactly - as before (no server/discover probe is sent against a mock that would 400 it). - Outside the harness -> 'auto' (probe + fallback). + 'auto' on modern legs (2026-07-28+) and outside the harness; 'legacy' on handshake-era + legs so no server/discover probe is sent against a mock that would 400 it. """ if PROTOCOL_VERSION is None or PROTOCOL_VERSION in MODERN_PROTOCOL_VERSIONS: return "auto" return "legacy" -# Type for async scenario handler functions ScenarioHandler = Callable[[str], Coroutine[Any, None, None]] -# Registry of scenario handlers HANDLERS: dict[str, ScenarioHandler] = {} @@ -138,9 +109,7 @@ async def set_client_info(self, client_info: OAuthClientInformationFull) -> None class ConformanceOAuthCallbackHandler: - """OAuth callback handler that automatically fetches the authorization URL - and extracts the auth code, without requiring user interaction. - """ + """Fetches the authorization URL and extracts the auth code without user interaction.""" def __init__(self) -> None: self._auth_code: str | None = None @@ -148,7 +117,6 @@ def __init__(self) -> None: self._iss: str | None = None async def handle_redirect(self, authorization_url: str) -> None: - """Fetch the authorization URL and extract the auth code from the redirect.""" logger.debug(f"Fetching authorization URL: {authorization_url}") async with httpx.AsyncClient() as client: @@ -179,7 +147,6 @@ async def handle_redirect(self, authorization_url: str) -> None: raise RuntimeError(f"Expected redirect response, got {response.status_code} from {authorization_url}") async def handle_callback(self) -> AuthorizationCodeResult: - """Return the captured auth code, state, and iss.""" if self._auth_code is None: raise RuntimeError("No authorization code available - was handle_redirect called?") result = AuthorizationCodeResult(code=self._auth_code, state=self._state, iss=self._iss) @@ -189,9 +156,6 @@ async def handle_callback(self) -> AuthorizationCodeResult: return result -# --- Stub callbacks (declare capabilities in _meta without doing real work) --- - - async def stub_sampling_callback( context: ClientRequestContext, params: types.CreateMessageRequestParams, @@ -214,7 +178,6 @@ async def default_elicitation_callback( """Accept elicitation and apply defaults from the schema (SEP-1034).""" content: dict[str, str | int | float | bool | list[str] | None] = {} - # For form mode, extract defaults from the requested_schema if isinstance(params, types.ElicitRequestFormParams): schema = params.requested_schema logger.debug(f"Elicitation schema: {schema}") @@ -227,12 +190,8 @@ async def default_elicitation_callback( return types.ElicitResult(action="accept", content=content) -# --- Scenario Handlers --- - - @register("initialize") async def run_initialize(server_url: str) -> None: - """Connect, initialize, list tools, close.""" async with Client(server_url, mode=client_mode()) as client: logger.debug("Initialized successfully") await client.list_tools() @@ -241,12 +200,10 @@ async def run_initialize(server_url: str) -> None: @register("json-schema-ref-no-deref") async def run_json_schema_ref_no_deref(server_url: str) -> None: - """Initialize and list tools; the scenario fails only if the client fetches a network $ref. + """List tools; the scenario fails only if the client fetches a network $ref (SEP-2106). - The client never walks inputSchema or resolves $refs, so listing is enough (SEP-2106). - Pinned to mode='legacy': the harness reports PROTOCOL_VERSION=2026-07-28 for this - scenario but its mock server only speaks the handshake-era lifecycle and 400s a - modern-stamped tools/list. The check is lifecycle-agnostic so this is harmless. + Pinned to mode='legacy': the harness reports PROTOCOL_VERSION=2026-07-28 here, but its + mock only speaks the handshake-era lifecycle and 400s a modern-stamped tools/list. """ async with Client(server_url, mode="legacy") as client: await client.list_tools() @@ -254,7 +211,6 @@ async def run_json_schema_ref_no_deref(server_url: str) -> None: @register("tools_call") async def run_tools_call(server_url: str) -> None: - """Connect, list tools, call add_numbers(a=5, b=3), close.""" async with Client(server_url, mode=client_mode()) as client: await client.list_tools() result = await client.call_tool("add_numbers", {"a": 5, "b": 3}) @@ -263,7 +219,6 @@ async def run_tools_call(server_url: str) -> None: @register("sse-retry") async def run_sse_retry(server_url: str) -> None: - """Connect, list tools, call test_reconnection, close.""" async with Client(server_url, mode=client_mode()) as client: await client.list_tools() result = await client.call_tool("test_reconnection", {}) @@ -274,11 +229,9 @@ async def run_sse_retry(server_url: str) -> None: async def run_request_metadata(server_url: str) -> None: """Connect on the modern path with every client capability declared. - The scenario inspects every request's `_meta` envelope (SEP-2575) for - protocolVersion / clientInfo / clientCapabilities, and the matching - MCP-Protocol-Version header. mode='auto' makes the SDK send - server/discover (covering the unsupported-version retry check), then adopt - and stamp the envelope on the follow-up requests. + The scenario inspects each request's `_meta` envelope (SEP-2575) and MCP-Protocol-Version + header; mode='auto' sends server/discover (covering the unsupported-version retry check), + then stamps the envelope on follow-up requests. """ async with Client( server_url, @@ -320,13 +273,9 @@ def _stub_required_args(input_schema: dict[str, Any]) -> dict[str, Any]: async def run_http_invalid_tool_headers(server_url: str) -> None: """List tools, then call every tool the SDK surfaces (SEP-2243). - The harness mock advertises one valid tool plus several with malformed - x-mcp-header annotations (empty, non-primitive type, duplicate, invalid - chars). The scenario passes if valid_tool is called and the malformed - ones are not -- so a conforming client filters them out of the list_tools - result and the loop below never sees them. The scenario sets - allowClientError, so a per-call failure is logged and skipped rather - than aborting the whole run. + The mock advertises one valid tool plus several with malformed x-mcp-header annotations; + a conforming client filters those out of the list_tools result so the loop never sees + them. The scenario sets allowClientError, so per-call failures are logged, not fatal. """ async with Client(server_url, mode=client_mode()) as client: listed = await client.list_tools() @@ -340,13 +289,11 @@ async def run_http_invalid_tool_headers(server_url: str) -> None: @register("http-custom-headers") async def run_http_custom_headers(server_url: str) -> None: - """List tools, then replay the harness's `toolCalls` so x-mcp-header args mirror into headers (SEP-2243). + """Replay the harness's `toolCalls` verbatim so x-mcp-header args mirror into headers (SEP-2243). - The scenario supplies the exact arguments to send (including the null/edge-case values that - exercise omission and Base64 encoding) via the context `toolCalls`; using them verbatim is - what drives every per-parameter check. `list_tools` first so the SDK caches each tool's - annotations; a tool the SDK dropped (invalid annotations) is skipped. Per-call failures are - logged and skipped rather than aborting the run. + The context supplies the exact arguments (including null/edge-case values exercising omission + and Base64 encoding). `list_tools` first so the SDK caches each tool's annotations; tools the + SDK dropped (invalid annotations) are skipped, and per-call failures are logged, not fatal. """ tool_calls: list[dict[str, Any]] = [] if os.environ.get("MCP_CONFORMANCE_CONTEXT"): @@ -368,7 +315,6 @@ async def run_http_custom_headers(server_url: str) -> None: @register("elicitation-sep1034-client-defaults") async def run_elicitation_defaults(server_url: str) -> None: - """Connect with elicitation callback that applies schema defaults.""" async with Client(server_url, mode=client_mode(), elicitation_callback=default_elicitation_callback) as client: await client.list_tools() result = await client.call_tool("test_client_elicitation_defaults", {}) @@ -379,13 +325,10 @@ async def run_elicitation_defaults(server_url: str) -> None: async def run_mrtr_client(server_url: str) -> None: """Drive the SEP-2322 client mock through `Client.call_tool`'s auto-loop. - The mock inspects raw `tools/call` params, so registering an - `elicitation_callback` and letting the driver run is enough to satisfy - all five wire-shape checks: the driver echoes `request_state` byte-exact - and omits it when the server sent none, every retry mints a fresh - JSON-RPC id, the unrelated call between auto-loops carries no MRTR - params, and the no-`resultType` response parses as a terminal - `CallToolResult` so the driver never retries it. + The mock inspects raw `tools/call` params: the driver must echo `request_state` + byte-exact (omitting it when the server sent none), mint a fresh JSON-RPC id per + retry, keep MRTR params off unrelated calls, and treat a no-`resultType` response + as a terminal `CallToolResult`. """ async def confirm( @@ -459,9 +402,7 @@ async def run_client_credentials_basic(server_url: str) -> None: @register("auth/enterprise-managed-authorization") async def run_enterprise_managed_authorization(server_url: str) -> None: - """SEP-990 enterprise-managed authorization: RFC 8693 token-exchange at the - enterprise IdP for an ID-JAG, then RFC 7523 jwt-bearer at the MCP - authorization server.""" + """SEP-990: RFC 8693 token-exchange at the IdP for an ID-JAG, then RFC 7523 jwt-bearer at the MCP AS.""" context = get_conformance_context() client_id = context.get("client_id") client_secret = context.get("client_secret") @@ -480,11 +421,9 @@ async def run_enterprise_managed_authorization(server_url: str) -> None: if not idp_token_endpoint: raise RuntimeError("MCP_CONFORMANCE_CONTEXT missing 'idp_token_endpoint'") - # IdentityAssertionOAuthProvider takes the AS issuer as configuration (the - # SEP-990 trust model: the resource server is never asked which AS to use). - # The harness does not put the issuer in context, so for conformance we - # learn it from the harness's PRM document (RFC 9728); production - # deployments would supply it as static configuration instead. + # SEP-990 trust model: the AS issuer is client configuration, never asked of the resource + # server. The harness omits it from context, so learn it from the PRM document (RFC 9728); + # production deployments would configure it statically. prm_url = build_protected_resource_metadata_discovery_urls(None, server_url)[0] async with httpx.AsyncClient(timeout=30.0) as http: prm = (await http.get(prm_url)).raise_for_status().json() @@ -526,7 +465,6 @@ async def run_auth_code_client(server_url: str) -> None: callback_handler = ConformanceOAuthCallbackHandler() storage = InMemoryTokenStorage() - # Check for pre-registered client credentials from context context_json = os.environ.get("MCP_CONFORMANCE_CONTEXT") if context_json: try: @@ -573,7 +511,7 @@ async def _run_auth_session(server_url: str, oauth_auth: httpx.Auth) -> None: tools_result = await client.list_tools() logger.debug(f"Listed tools: {[t.name for t in tools_result.tools]}") - # Call the first available tool (different tests have different tools) + # Different tests expose different tools; call the first one if tools_result.tools: tool_name = tools_result.tools[0].name try: @@ -586,7 +524,6 @@ async def _run_auth_session(server_url: str, oauth_auth: httpx.Auth) -> None: def main() -> None: - """Main entry point for the conformance client.""" if len(sys.argv) < 2: print(f"Usage: {sys.argv[0]} ", file=sys.stderr) sys.exit(1) diff --git a/docs/hooks/gen_ref_pages.py b/docs/hooks/gen_ref_pages.py index 8e1afeee68..df0e3acfc3 100644 --- a/docs/hooks/gen_ref_pages.py +++ b/docs/hooks/gen_ref_pages.py @@ -9,9 +9,8 @@ root = Path(__file__).parent.parent.parent src = root / "src" -# `src/mcp-types` is a distribution directory, not an import package, so each -# package's dotted module path is taken relative to its own parent: deriving it -# from `src/` would emit the unimportable `mcp-types.mcp_types.*`. +# `src/mcp-types` is a distribution directory, not an import package, so each package's dotted module path +# is taken relative to its own parent; deriving it from `src/` would emit the unimportable `mcp-types.mcp_types.*`. for package in (src / "mcp", src / "mcp-types" / "mcp_types"): base = package.parent for path in sorted(package.rglob("*.py")): diff --git a/docs_src/__init__.py b/docs_src/__init__.py index d19acbc088..417b5ce5b6 100644 --- a/docs_src/__init__.py +++ b/docs_src/__init__.py @@ -1,7 +1,5 @@ """Complete, runnable source for every code example in `docs/`. -Each `docs/.md` includes its examples from `docs_src//tutorialNNN.py` -via `--8<--`, and `tests/docs_src/test_.py` imports the same module and -exercises it through the in-memory `mcp.Client`. The file you read in the docs is -the file CI runs. +Docs pages include these files via `--8<--`, and `tests/docs_src/` exercises the same +modules through the in-memory `mcp.Client` β€” the file you read in the docs is the file CI runs. """ diff --git a/examples/clients/simple-auth-client/mcp_simple_auth_client/main.py b/examples/clients/simple-auth-client/mcp_simple_auth_client/main.py index 0d461d5d11..f1e03e3ed4 100644 --- a/examples/clients/simple-auth-client/mcp_simple_auth_client/main.py +++ b/examples/clients/simple-auth-client/mcp_simple_auth_client/main.py @@ -1,9 +1,5 @@ #!/usr/bin/env python3 -"""Simple MCP client example with OAuth authentication support. - -This client connects to an MCP server using streamable HTTP transport with OAuth. - -""" +"""Simple MCP client example with OAuth authentication support.""" from __future__ import annotations as _annotations @@ -57,12 +53,10 @@ def __init__( server: socketserver.BaseServer, callback_data: dict[str, Any], ): - """Initialize with callback data storage.""" self.callback_data = callback_data super().__init__(request, client_address, server) def do_GET(self): - """Handle GET request from OAuth redirect.""" parsed = urlparse(self.path) query_params = parse_qs(parsed.query) @@ -116,7 +110,6 @@ def __init__(self, port: int = 3000): self.callback_data = {"authorization_code": None, "state": None, "iss": None, "error": None} def _create_handler_with_data(self): - """Create a handler class with access to callback data.""" callback_data = self.callback_data class DataCallbackHandler(CallbackHandler): @@ -131,7 +124,6 @@ def __init__( return DataCallbackHandler def start(self): - """Start the callback server in a background thread.""" handler_class = self._create_handler_with_data() self.server = HTTPServer(("localhost", self.port), handler_class) self.thread = threading.Thread(target=self.server.serve_forever, daemon=True) @@ -139,7 +131,6 @@ def start(self): print(f"πŸ–₯️ Started callback server on http://localhost:{self.port}") def stop(self): - """Stop the callback server.""" if self.server: self.server.shutdown() self.server.server_close() @@ -147,7 +138,6 @@ def stop(self): self.thread.join(timeout=1) def wait_for_callback(self, timeout: int = 300): - """Wait for OAuth callback with timeout.""" start_time = time.time() while time.time() - start_time < timeout: if self.callback_data["authorization_code"]: @@ -159,12 +149,10 @@ def wait_for_callback(self, timeout: int = 300): @property def state(self): - """The received state parameter.""" return self.callback_data["state"] @property def iss(self): - """The received iss parameter.""" return self.callback_data["iss"] @@ -183,7 +171,6 @@ def __init__( self.session: ClientSession | None = None async def connect(self): - """Connect to the MCP server.""" print(f"πŸ”— Attempting to connect to {self.server_url}...") try: @@ -191,7 +178,6 @@ async def connect(self): callback_server.start() async def callback_handler() -> AuthorizationCodeResult: - """Wait for OAuth callback and return auth code, state, and iss.""" print("⏳ Waiting for authorization callback...") try: auth_code = callback_server.wait_for_callback(timeout=300) @@ -207,12 +193,10 @@ async def callback_handler() -> AuthorizationCodeResult: } async def _default_redirect_handler(authorization_url: str) -> None: - """Default redirect handler that opens the URL in a browser.""" print(f"Opening browser for authorization: {authorization_url}") webbrowser.open(authorization_url) - # Create OAuth authentication handler using the new interface - # Use client_metadata_url to enable CIMD when the server supports it + # client_metadata_url enables CIMD when the server supports it oauth_auth = OAuthClientProvider( server_url=self.server_url.replace("/mcp", ""), client_metadata=OAuthClientMetadata.model_validate(client_metadata_dict), @@ -222,7 +206,6 @@ async def _default_redirect_handler(authorization_url: str) -> None: client_metadata_url=self.client_metadata_url, ) - # Create transport with auth handler based on transport type if self.transport_type == "sse": print("πŸ“‘ Opening SSE transport connection with auth...") async with sse_client( @@ -251,7 +234,6 @@ async def _run_session( read_stream: ReadStream[SessionMessage | Exception], write_stream: WriteStream[SessionMessage], ): - """Run the MCP session with the given streams.""" print("🀝 Initializing MCP session...") async with ClientSession(read_stream, write_stream) as session: self.session = session @@ -261,11 +243,9 @@ async def _run_session( print(f"\nβœ… Connected to MCP server at {self.server_url}") - # Run interactive loop await self.interactive_loop() async def list_tools(self): - """List available tools from the server.""" if not self.session: print("❌ Not connected to server") return @@ -285,7 +265,6 @@ async def list_tools(self): print(f"❌ Failed to list tools: {e}") async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None = None): - """Call a specific tool.""" if not self.session: print("❌ Not connected to server") return @@ -305,7 +284,6 @@ async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None = Non print(f"❌ Failed to call tool '{tool_name}': {e}") async def interactive_loop(self): - """Run interactive command loop.""" print("\n🎯 Interactive MCP Client") print("Commands:") print(" list - List available tools") @@ -334,7 +312,6 @@ async def interactive_loop(self): print("❌ Please specify a tool name") continue - # Parse arguments (simple JSON-like format) arguments: dict[str, Any] = {} if len(parts) > 2: import json @@ -358,8 +335,6 @@ async def interactive_loop(self): async def main(): - """Main entry point.""" - # Default server URL - can be overridden with environment variable # Most MCP streamable HTTP servers use /mcp as the endpoint server_url = os.getenv("MCP_SERVER_PORT", 8000) transport_type = os.getenv("MCP_TRANSPORT_TYPE", "streamable-http") @@ -376,7 +351,7 @@ async def main(): if client_metadata_url: print(f"Client metadata URL: {client_metadata_url}") - # Start connection flow - OAuth will be handled automatically + # OAuth is handled automatically during connect client = SimpleAuthClient(server_url, transport_type, client_metadata_url) await client.connect() diff --git a/examples/clients/simple-chatbot/mcp_simple_chatbot/main.py b/examples/clients/simple-chatbot/mcp_simple_chatbot/main.py index 72b1a6f204..59f062baea 100644 --- a/examples/clients/simple-chatbot/mcp_simple_chatbot/main.py +++ b/examples/clients/simple-chatbot/mcp_simple_chatbot/main.py @@ -13,7 +13,6 @@ from mcp import ClientSession, StdioServerParameters from mcp.client.stdio import stdio_client -# Configure logging logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") @@ -21,7 +20,6 @@ class Configuration: """Manages configuration and environment variables for the MCP client.""" def __init__(self) -> None: - """Initialize configuration with environment variables.""" self.load_env() self.api_key = os.getenv("LLM_API_KEY") @@ -32,30 +30,21 @@ def load_env() -> None: @staticmethod def load_config(file_path: str) -> dict[str, Any]: - """Load server configuration from JSON file. - - Args: - file_path: Path to the JSON configuration file. - - Returns: - Dict containing server configuration. + """Load server configuration from a JSON file. Raises: - FileNotFoundError: If configuration file doesn't exist. - JSONDecodeError: If configuration file is invalid JSON. + FileNotFoundError: If the file doesn't exist. + JSONDecodeError: If the file isn't valid JSON. """ with open(file_path, "r") as f: return json.load(f) @property def llm_api_key(self) -> str: - """Get the LLM API key. - - Returns: - The API key as a string. + """The LLM API key. Raises: - ValueError: If the API key is not found in environment variables. + ValueError: If LLM_API_KEY is not set in the environment. """ if not self.api_key: raise ValueError("LLM_API_KEY not found in environment variables") @@ -98,9 +87,6 @@ async def initialize(self) -> None: async def list_tools(self) -> list[Tool]: """List available tools from the server. - Returns: - A list of available tools. - Raises: RuntimeError: If the server is not initialized. """ @@ -123,19 +109,10 @@ async def execute_tool( retries: int = 2, delay: float = 1.0, ) -> Any: - """Execute a tool with retry mechanism. - - Args: - tool_name: Name of the tool to execute. - arguments: Tool arguments. - retries: Number of retry attempts. - delay: Delay between retries in seconds. - - Returns: - Tool execution result. + """Execute a tool, making up to `retries` attempts with `delay` seconds between them. Raises: - RuntimeError: If server is not initialized. + RuntimeError: If the server is not initialized. Exception: If tool execution fails after all retries. """ if not self.session: @@ -186,11 +163,7 @@ def __init__( self.input_schema: dict[str, Any] = input_schema def format_for_llm(self) -> str: - """Format tool information for LLM. - - Returns: - A formatted string describing the tool. - """ + """Format tool information for inclusion in the LLM system prompt.""" args_desc: list[str] = [] if "properties" in self.input_schema: for param_name, param_info in self.input_schema["properties"].items(): @@ -199,10 +172,8 @@ def format_for_llm(self) -> str: arg_desc += " (required)" args_desc.append(arg_desc) - # Build the formatted output with title as a separate field output = f"Tool: {self.name}\n" - # Add human-readable title if available if self.title: output += f"User-readable title: {self.title}\n" @@ -223,12 +194,6 @@ def __init__(self, api_key: str) -> None: def get_response(self, messages: list[dict[str, str]]) -> str: """Get a response from the LLM. - Args: - messages: A list of message dictionaries. - - Returns: - The LLM's response as a string. - Raises: httpx.RequestError: If the request to the LLM fails. """ @@ -283,14 +248,7 @@ async def cleanup_servers(self) -> None: logging.warning(f"Warning during final cleanup: {e}") async def process_llm_response(self, llm_response: str) -> str: - """Process the LLM response and execute tools if needed. - - Args: - llm_response: The response from the LLM. - - Returns: - The result of tool execution or the original response. - """ + """Execute the requested tool if the response is a JSON tool call, otherwise return it unchanged.""" import json def _clean_json_string(json_string: str) -> str: diff --git a/examples/clients/sse-polling-client/mcp_sse_polling_client/main.py b/examples/clients/sse-polling-client/mcp_sse_polling_client/main.py index e91ed9d527..322782d799 100644 --- a/examples/clients/sse-polling-client/mcp_sse_polling_client/main.py +++ b/examples/clients/sse-polling-client/mcp_sse_polling_client/main.py @@ -1,17 +1,10 @@ -"""SSE Polling Demo Client +"""SSE polling demo client. -Demonstrates the client-side auto-reconnect for SSE polling pattern. +Calls process_batch on the demo server, which periodically closes the SSE stream; +the client auto-reconnects via Last-Event-ID and resumes receiving messages. -This client connects to the SSE Polling Demo server and calls process_batch, -which triggers periodic server-side stream closes. The client automatically -reconnects using Last-Event-ID and resumes receiving messages. - -Run with: - # First start the server: - uv run mcp-sse-polling-demo --port 3000 - - # Then run this client: - uv run mcp-sse-polling-client --url http://localhost:3000/mcp +Start the server (`uv run mcp-sse-polling-demo --port 3000`), then run +`uv run mcp-sse-polling-client --url http://localhost:3000/mcp`. """ import asyncio @@ -23,7 +16,6 @@ async def run_demo(url: str, items: int, checkpoint_every: int) -> None: - """Run the SSE polling demo.""" print(f"\n{'=' * 60}") print("SSE Polling Demo Client") print(f"{'=' * 60}") @@ -33,16 +25,13 @@ async def run_demo(url: str, items: int, checkpoint_every: int) -> None: async with streamable_http_client(url) as (read_stream, write_stream): async with ClientSession(read_stream, write_stream) as session: - # Initialize the connection print("Initializing connection...") await session.initialize() print("Connected!\n") - # List available tools tools = await session.list_tools() print(f"Available tools: {[t.name for t in tools.tools]}\n") - # Call the process_batch tool print(f"Calling process_batch(items={items}, checkpoint_every={checkpoint_every})...\n") print("-" * 40) diff --git a/examples/mcpserver/complex_inputs.py b/examples/mcpserver/complex_inputs.py index 93a42d1c89..0e6295f194 100644 --- a/examples/mcpserver/complex_inputs.py +++ b/examples/mcpserver/complex_inputs.py @@ -1,7 +1,4 @@ -"""MCPServer Complex inputs Example - -Demonstrates validation via pydantic with complex models. -""" +"""Demonstrates validation via pydantic with complex models.""" from typing import Annotated diff --git a/examples/mcpserver/desktop.py b/examples/mcpserver/desktop.py index 804184516d..cab11b47e4 100644 --- a/examples/mcpserver/desktop.py +++ b/examples/mcpserver/desktop.py @@ -1,13 +1,9 @@ -"""MCPServer Desktop Example - -A simple example that exposes the desktop directory as a resource. -""" +"""A simple example that exposes the desktop directory as a resource.""" from pathlib import Path from mcp.server.mcpserver import MCPServer -# Create server mcp = MCPServer("Demo") diff --git a/examples/mcpserver/echo.py b/examples/mcpserver/echo.py index 501c47069b..2686862940 100644 --- a/examples/mcpserver/echo.py +++ b/examples/mcpserver/echo.py @@ -2,7 +2,6 @@ from mcp.server.mcpserver import MCPServer -# Create server mcp = MCPServer("Echo Server") diff --git a/examples/mcpserver/icons_demo.py b/examples/mcpserver/icons_demo.py index f50389f32f..02fc4f7c0f 100644 --- a/examples/mcpserver/icons_demo.py +++ b/examples/mcpserver/icons_demo.py @@ -1,7 +1,4 @@ -"""MCPServer Icons Demo Server - -Demonstrates using icons with tools, resources, prompts, and implementation. -""" +"""Demonstrates using icons with tools, resources, prompts, and the server implementation.""" import base64 from pathlib import Path @@ -15,7 +12,6 @@ icon_data = Icon(src=icon_data_uri, mime_type="image/png", sizes=["64x64"]) -# Create server with icons in implementation mcp = MCPServer( "Icons Demo Server", website_url="https://github.com/modelcontextprotocol/python-sdk", icons=[icon_data] ) @@ -52,5 +48,4 @@ def multi_icon_tool(action: str) -> str: if __name__ == "__main__": - # Run the server mcp.run() diff --git a/examples/mcpserver/logging_and_progress.py b/examples/mcpserver/logging_and_progress.py index b157f9dd05..5edcce4669 100644 --- a/examples/mcpserver/logging_and_progress.py +++ b/examples/mcpserver/logging_and_progress.py @@ -4,7 +4,6 @@ from mcp.server.mcpserver import Context, MCPServer -# Create server mcp = MCPServer("Echo Server with logging and progress updates") @@ -24,8 +23,7 @@ async def echo(text: str, ctx: Context) -> str: await ctx.info("Finished processing echo for input: " + text) await ctx.report_progress(progress=100, total=100) - # Progress notifications are process asynchronously by the client. - # A small delay here helps ensure the last notification is processed by the client. + # Clients process progress notifications asynchronously; a short delay lets the last one arrive. await asyncio.sleep(0.1) return text diff --git a/examples/mcpserver/memory.py b/examples/mcpserver/memory.py index fd0bd93627..b16db56c47 100644 --- a/examples/mcpserver/memory.py +++ b/examples/mcpserver/memory.py @@ -2,12 +2,7 @@ # dependencies = ["pydantic-ai-slim[openai]", "asyncpg", "numpy", "pgvector"] # /// -# uv pip install 'pydantic-ai-slim[openai]' asyncpg numpy pgvector - -"""Recursive memory system inspired by the human brain's clustering of memories. -Uses OpenAI's 'text-embedding-3-small' model and pgvector for efficient -similarity search. -""" +"""Recursive memory system inspired by the human brain's clustering of memories.""" import asyncio import math @@ -20,7 +15,7 @@ import asyncpg import numpy as np from openai import AsyncOpenAI -from pgvector.asyncpg import register_vector # Import register_vector +from pgvector.asyncpg import register_vector from pydantic import BaseModel, Field from pydantic_ai import Agent @@ -141,7 +136,6 @@ async def merge_with(self, other: Self, deps: Deps): self.embedding = [(a + b) / 2 for a, b in zip(self.embedding, other.embedding)] self.summary = await do_ai(self.content, "Summarize the following text concisely.", str, deps) await self.save(deps) - # Delete the merged node from the database if other.id is not None: await delete_memory(other.id, deps) diff --git a/examples/mcpserver/parameter_descriptions.py b/examples/mcpserver/parameter_descriptions.py index 59a1caf3f6..e26540f9cd 100644 --- a/examples/mcpserver/parameter_descriptions.py +++ b/examples/mcpserver/parameter_descriptions.py @@ -4,7 +4,6 @@ from mcp.server.mcpserver import MCPServer -# Create server mcp = MCPServer("Parameter Descriptions Server") diff --git a/examples/mcpserver/readme-quickstart.py b/examples/mcpserver/readme-quickstart.py index 864b774a9e..a330c5385d 100644 --- a/examples/mcpserver/readme-quickstart.py +++ b/examples/mcpserver/readme-quickstart.py @@ -1,10 +1,8 @@ from mcp.server.mcpserver import MCPServer -# Create an MCP server mcp = MCPServer("Demo") -# Add an addition tool @mcp.tool() def sum(a: int, b: int) -> int: """Add two numbers""" diff --git a/examples/mcpserver/screenshot.py b/examples/mcpserver/screenshot.py index e7b3ee6fbd..cc4dd20b73 100644 --- a/examples/mcpserver/screenshot.py +++ b/examples/mcpserver/screenshot.py @@ -1,14 +1,10 @@ -"""MCPServer Screenshot Example - -Give Claude a tool to capture and view screenshots. -""" +"""Give Claude a tool to capture and view screenshots.""" import io from mcp.server.mcpserver import MCPServer from mcp.server.mcpserver.utilities.types import Image -# Create server mcp = MCPServer("Screenshot Demo") diff --git a/examples/mcpserver/simple_echo.py b/examples/mcpserver/simple_echo.py index 3d8142a665..f26d7c862b 100644 --- a/examples/mcpserver/simple_echo.py +++ b/examples/mcpserver/simple_echo.py @@ -2,7 +2,6 @@ from mcp.server.mcpserver import MCPServer -# Create server mcp = MCPServer("Echo Server") diff --git a/examples/mcpserver/text_me.py b/examples/mcpserver/text_me.py index 7aeb543621..23f9f64120 100644 --- a/examples/mcpserver/text_me.py +++ b/examples/mcpserver/text_me.py @@ -2,19 +2,10 @@ # dependencies = [] # /// -"""MCPServer Text Me Server --------------------------------- -This defines a simple MCPServer server that sends a text message to a phone number via https://surgemsg.com/. +"""MCPServer that sends a text message to a phone number via https://surgemsg.com/. -To run this example, create a `.env` file with the following values: - -SURGE_API_KEY=... -SURGE_ACCOUNT_ID=... -SURGE_MY_PHONE_NUMBER=... -SURGE_MY_FIRST_NAME=... -SURGE_MY_LAST_NAME=... - -Visit https://surgemsg.com/ and click "Get Started" to obtain these values. +Requires a `.env` file with SURGE_API_KEY, SURGE_ACCOUNT_ID, SURGE_MY_PHONE_NUMBER, +SURGE_MY_FIRST_NAME, and SURGE_MY_LAST_NAME β€” visit https://surgemsg.com/ to obtain them. """ from typing import Annotated @@ -36,7 +27,6 @@ class SurgeSettings(BaseSettings): my_last_name: str -# Create server mcp = MCPServer("Text me") surge_settings = SurgeSettings() # type: ignore diff --git a/examples/mcpserver/unicode_example.py b/examples/mcpserver/unicode_example.py index 012633ec76..2a6684d7a7 100644 --- a/examples/mcpserver/unicode_example.py +++ b/examples/mcpserver/unicode_example.py @@ -1,6 +1,4 @@ -"""Example MCPServer server that uses Unicode characters in various places to help test -Unicode handling in tools and inspectors. -""" +"""MCPServer example using Unicode characters to exercise Unicode handling in tools and inspectors.""" from mcp.server.mcpserver import MCPServer @@ -9,11 +7,7 @@ @mcp.tool(description="🌟 A tool that uses various Unicode characters in its description: Γ‘ Γ© Γ­ Γ³ ΓΊ Γ± ζΌ’ε­— πŸŽ‰") def hello_unicode(name: str = "δΈ–η•Œ", greeting: str = "Β‘Hola") -> str: - """A simple tool that demonstrates Unicode handling in: - - Tool description (emojis, accents, CJK characters) - - Parameter defaults (CJK characters) - - Return values (Spanish punctuation, emojis) - """ + """Demonstrates Unicode in the tool description, parameter defaults, and return value.""" return f"{greeting}, {name}! πŸ‘‹" diff --git a/examples/mcpserver/weather_structured.py b/examples/mcpserver/weather_structured.py index 958c7d3197..eb48a346e1 100644 --- a/examples/mcpserver/weather_structured.py +++ b/examples/mcpserver/weather_structured.py @@ -1,8 +1,4 @@ -"""MCPServer Weather Example with Structured Output - -Demonstrates how to use structured output with tools to return -well-typed, validated data that clients can easily process. -""" +"""MCPServer structured output: tools returning Pydantic models, TypedDicts, dataclasses, dicts, and primitives.""" import asyncio import json @@ -16,13 +12,12 @@ from mcp.client import Client from mcp.server.mcpserver import MCPServer -# Create server mcp = MCPServer("Weather Service") # Example 1: Using a Pydantic model for structured output class WeatherData(BaseModel): - """Structured weather data response""" + """Structured weather data response.""" temperature: float = Field(description="Temperature in Celsius") humidity: float = Field(description="Humidity percentage (0-100)") @@ -34,15 +29,13 @@ class WeatherData(BaseModel): @mcp.tool() def get_weather(city: str) -> WeatherData: - """Get current weather for a city with full structured data""" - # In a real implementation, this would fetch from a weather API + """Get current weather for a city with full structured data.""" + # A real implementation would fetch from a weather API return WeatherData(temperature=22.5, humidity=65.0, condition="partly cloudy", wind_speed=12.3, location=city) # Example 2: Using TypedDict for a simpler structure class WeatherSummary(TypedDict): - """Simple weather summary""" - city: str temp_c: float description: str @@ -50,18 +43,14 @@ class WeatherSummary(TypedDict): @mcp.tool() def get_weather_summary(city: str) -> WeatherSummary: - """Get a brief weather summary for a city""" + """Get a brief weather summary for a city.""" return WeatherSummary(city=city, temp_c=22.5, description="Partly cloudy with light breeze") -# Example 3: Using dict[str, Any] for flexible schemas +# Example 3: Using nested dicts for flexible schemas @mcp.tool() def get_weather_metrics(cities: list[str]) -> dict[str, dict[str, float]]: - """Get weather metrics for multiple cities - - Returns a dictionary mapping city names to their metrics - """ - # Returns nested dictionaries with weather metrics + """Get weather metrics for multiple cities.""" return { city: {"temperature": 20.0 + i * 2, "humidity": 60.0 + i * 5, "pressure": 1013.0 + i * 0.5} for i, city in enumerate(cities) @@ -71,8 +60,6 @@ def get_weather_metrics(cities: list[str]) -> dict[str, dict[str, float]]: # Example 4: Using dataclass for weather alerts @dataclass class WeatherAlert: - """Weather alert information""" - severity: str # "low", "medium", "high" title: str description: str @@ -82,8 +69,7 @@ class WeatherAlert: @mcp.tool() def get_weather_alerts(region: str) -> list[WeatherAlert]: - """Get active weather alerts for a region""" - # In production, this would fetch real alerts + """Get active weather alerts for a region.""" if region.lower() == "california": return [ WeatherAlert( @@ -104,14 +90,10 @@ def get_weather_alerts(region: str) -> list[WeatherAlert]: return [] -# Example 5: Returning primitives with structured output +# Example 5: Primitive returns are wrapped in {"result": value} as structured output @mcp.tool() def get_temperature(city: str, unit: str = "celsius") -> float: - """Get just the temperature for a city - - When returning primitives as structured output, - the result is wrapped in {"result": value} - """ + """Get just the temperature for a city.""" base_temp = 22.5 if unit.lower() == "fahrenheit": return base_temp * 9 / 5 + 32 @@ -120,16 +102,12 @@ def get_temperature(city: str, unit: str = "celsius") -> float: # Example 6: Weather statistics with nested models class DailyStats(BaseModel): - """Statistics for a single day""" - high: float low: float mean: float class WeatherStats(BaseModel): - """Weather statistics over a period""" - location: str period_days: int temperature: DailyStats @@ -139,7 +117,7 @@ class WeatherStats(BaseModel): @mcp.tool() def get_weather_stats(city: str, days: int = 7) -> WeatherStats: - """Get weather statistics for the past N days""" + """Get weather statistics for the past N days.""" return WeatherStats( location=city, period_days=days, @@ -152,49 +130,42 @@ def get_weather_stats(city: str, days: int = 7) -> WeatherStats: if __name__ == "__main__": async def test() -> None: - """Test the tools by calling them through the server as a client would""" + """Call each tool through an in-memory client session.""" print("Testing Weather Service Tools (via MCP protocol)\n") print("=" * 80) async with Client(mcp) as client: - # Test get_weather result = await client.call_tool("get_weather", {"city": "London"}) print("\nWeather in London:") print(json.dumps(result.structured_content, indent=2)) - # Test get_weather_summary result = await client.call_tool("get_weather_summary", {"city": "Paris"}) print("\nWeather summary for Paris:") print(json.dumps(result.structured_content, indent=2)) - # Test get_weather_metrics result = await client.call_tool("get_weather_metrics", {"cities": ["Tokyo", "Sydney", "Mumbai"]}) print("\nWeather metrics:") print(json.dumps(result.structured_content, indent=2)) - # Test get_weather_alerts result = await client.call_tool("get_weather_alerts", {"region": "California"}) print("\nWeather alerts for California:") print(json.dumps(result.structured_content, indent=2)) - # Test get_temperature result = await client.call_tool("get_temperature", {"city": "Berlin", "unit": "fahrenheit"}) print("\nTemperature in Berlin:") print(json.dumps(result.structured_content, indent=2)) - # Test get_weather_stats result = await client.call_tool("get_weather_stats", {"city": "Seattle", "days": 30}) print("\nWeather stats for Seattle (30 days):") print(json.dumps(result.structured_content, indent=2)) - # Also show the text content for comparison + # Structured results also carry a text content block print("\nText content for last result:") for content in result.content: if content.type == "text": print(content.text) async def print_schemas() -> None: - """Print all tool schemas""" print("Tool Schemas for Weather Service\n") print("=" * 80) @@ -213,7 +184,6 @@ async def print_schemas() -> None: print("-" * 80) - # Check command line arguments if len(sys.argv) > 1 and sys.argv[1] == "--schemas": asyncio.run(print_schemas()) else: diff --git a/examples/servers/everything-server/mcp_everything_server/__main__.py b/examples/servers/everything-server/mcp_everything_server/__main__.py index 2eff688f02..f5f6e402df 100644 --- a/examples/servers/everything-server/mcp_everything_server/__main__.py +++ b/examples/servers/everything-server/mcp_everything_server/__main__.py @@ -1,5 +1,3 @@ -"""CLI entry point for the MCP Everything Server.""" - from .server import main if __name__ == "__main__": diff --git a/examples/servers/everything-server/mcp_everything_server/server.py b/examples/servers/everything-server/mcp_everything_server/server.py index f622aac7a3..540e93b937 100644 --- a/examples/servers/everything-server/mcp_everything_server/server.py +++ b/examples/servers/everything-server/mcp_everything_server/server.py @@ -1,8 +1,5 @@ #!/usr/bin/env python3 -"""MCP Everything Server - Conformance Test Server - -Server implementing all MCP features for conformance testing based on Conformance Server Specification. -""" +"""MCP Everything Server: implements all MCP features per the Conformance Server Specification.""" import asyncio import base64 @@ -52,7 +49,6 @@ logger = logging.getLogger(__name__) -# Type aliases for event store StreamId = str EventId = str @@ -65,14 +61,12 @@ def __init__(self) -> None: self._event_id_counter = 0 async def store_event(self, stream_id: StreamId, message: JSONRPCMessage | None) -> EventId: - """Store an event and return its ID.""" self._event_id_counter += 1 event_id = str(self._event_id_counter) self._events.append((stream_id, event_id, message)) return event_id async def replay_events_after(self, last_event_id: EventId, send_callback: EventCallback) -> StreamId | None: - """Replay events after the specified ID.""" target_stream_id = None for stream_id, event_id, _ in self._events: if event_id == last_event_id: @@ -89,15 +83,13 @@ async def replay_events_after(self, last_event_id: EventId, send_callback: Event return target_stream_id -# Test data TEST_IMAGE_BASE64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8DwHwAFBQIAX8jx0gAAAABJRU5ErkJggg==" TEST_AUDIO_BASE64 = "UklGRiYAAABXQVZFZm10IBAAAAABAAEAQB8AAAB9AAACABAAZGF0YQIAAAA=" -# Server state resource_subscriptions: set[str] = set() watched_resource_content = "Watched resource content" -# Create event store for SSE resumability (SEP-1699) +# Event store for SSE resumability (SEP-1699) event_store = InMemoryEventStore() mcp = MCPServer( @@ -105,7 +97,6 @@ async def replay_events_after(self, last_event_id: EventId, send_callback: Event ) -# Tools @mcp.tool() def test_simple_text() -> str: """Tests simple text content response""" @@ -180,7 +171,6 @@ async def test_tool_with_progress(ctx: Context) -> str: await ctx.report_progress(progress=100, total=100, message="Completed step 100 of 100") - # Return progress token as string progress_token = ( ctx.request_context.meta.get("progress_token") if ctx.request_context and ctx.request_context.meta else 0 ) @@ -191,7 +181,6 @@ async def test_tool_with_progress(ctx: Context) -> str: async def test_sampling(prompt: str, ctx: Context) -> str: """Tests server-initiated sampling (LLM completion request)""" try: - # Request sampling from client result = await ctx.session.create_message( # pyright: ignore[reportDeprecated] messages=[SamplingMessage(role="user", content=TextContent(type="text", text=prompt))], max_tokens=100, @@ -216,7 +205,6 @@ class UserResponse(BaseModel): async def test_elicitation(message: str, ctx: Context) -> str: """Tests server-initiated elicitation (user input request)""" try: - # Request user input from client result = await ctx.elicit(message=message, schema=UserResponse) # Type-safe discriminated union narrowing using action field @@ -248,10 +236,8 @@ class SEP1034DefaultsSchema(BaseModel): async def test_elicitation_sep1034_defaults(ctx: Context) -> str: """Tests elicitation with default values for all primitive types (SEP-1034)""" try: - # Request user input with defaults for all primitive types result = await ctx.elicit(message="Please provide user information", schema=SEP1034DefaultsSchema) - # Type-safe discriminated union narrowing using action field if result.action == "accept": content = result.data.model_dump_json() else: # decline or cancel @@ -327,15 +313,11 @@ def test_error_handling() -> str: raise RuntimeError("This tool intentionally returns an error for testing") +# Tool dispatch re-raises MCPError as a protocol-level error rather than wrapping it in +# `CallToolResult.isError`, so the harness observes a JSON-RPC error with `data.requiredCapabilities`. @mcp.tool() async def test_missing_capability(ctx: Context) -> str: - """Tests that a handler-raised MISSING_REQUIRED_CLIENT_CAPABILITY surfaces as a top-level JSON-RPC error. - - Requires the client to declare the ``sampling`` capability. When absent, raises - `MCPError` (which the tool dispatch re-raises rather than wrapping in - ``CallToolResult.isError``) so the conformance harness observes a protocol-level - error response with ``data.requiredCapabilities``. - """ + """Tests that a handler-raised MISSING_REQUIRED_CLIENT_CAPABILITY surfaces as a top-level JSON-RPC error.""" client_params = ctx.session.client_params sampling_declared = client_params is not None and client_params.capabilities.sampling is not None if not sampling_declared: @@ -532,8 +514,7 @@ async def test_input_required_result_capabilities(ctx: Context) -> InputRequired return InputRequiredResult(input_requests=requests, request_state="capability-gated") -# SEP-1613 / SEP-2106 JSON Schema 2020-12 fixture: a tool whose inputSchema carries -# the full set of 2020-12 keywords the conformance scenario asserts on. +# SEP-1613 / SEP-2106 fixture: inputSchema carries the JSON Schema 2020-12 keywords the scenario asserts on. JSON_SCHEMA_2020_12_INPUT_SCHEMA: dict[str, Any] = { "$schema": "https://json-schema.org/draft/2020-12/schema", @@ -585,7 +566,6 @@ async def test_reconnection(ctx: Context) -> str: return "Reconnection test completed" -# Resources @mcp.resource("test://static-text") def static_text_resource() -> str: """A static text resource for testing""" @@ -610,7 +590,6 @@ def watched_resource() -> str: return watched_resource_content -# Prompts @mcp.prompt() def test_simple_prompt() -> list[UserMessage]: """A simple prompt without arguments""" @@ -655,24 +634,20 @@ def test_prompt_with_image() -> list[UserMessage]: ] -# Custom request handlers # TODO(felix): Add public APIs to MCPServer for subscribe_resource, unsubscribe_resource, # and set_logging_level to avoid accessing protected _lowlevel_server attribute. async def handle_set_logging_level(ctx: ServerRequestContext, params: SetLevelRequestParams) -> EmptyResult: - """Handle logging level changes""" logger.info(f"Log level set to: {params.level}") return EmptyResult() async def handle_subscribe(ctx: ServerRequestContext, params: SubscribeRequestParams) -> EmptyResult: - """Handle resource subscription""" resource_subscriptions.add(str(params.uri)) logger.info(f"Subscribed to resource: {params.uri}") return EmptyResult() async def handle_unsubscribe(ctx: ServerRequestContext, params: UnsubscribeRequestParams) -> EmptyResult: - """Handle resource unsubscription""" resource_subscriptions.discard(str(params.uri)) logger.info(f"Unsubscribed from resource: {params.uri}") return EmptyResult() @@ -695,13 +670,10 @@ async def _handle_completion( argument: CompletionArgument, context: CompletionContext | None, ) -> Completion: - """Handle completion requests""" - # Basic completion support - returns empty array for conformance - # Real implementations would provide contextual suggestions + # Empty values satisfy conformance; a real server would return contextual suggestions. return Completion(values=[], total=0, has_more=False) -# CLI @click.command() @click.option("--port", default=3001, help="Port to listen on for HTTP") @click.option( diff --git a/examples/servers/simple-auth/mcp_simple_auth/__main__.py b/examples/servers/simple-auth/mcp_simple_auth/__main__.py index 2365ff5a1b..ef3f6a3f62 100644 --- a/examples/servers/simple-auth/mcp_simple_auth/__main__.py +++ b/examples/servers/simple-auth/mcp_simple_auth/__main__.py @@ -1,5 +1,3 @@ -"""Main entry point for simple MCP server with GitHub OAuth authentication.""" - import sys from mcp_simple_auth.server import main diff --git a/examples/servers/simple-auth/mcp_simple_auth/auth_server.py b/examples/servers/simple-auth/mcp_simple_auth/auth_server.py index 26c87c5ef2..ab96687377 100644 --- a/examples/servers/simple-auth/mcp_simple_auth/auth_server.py +++ b/examples/servers/simple-auth/mcp_simple_auth/auth_server.py @@ -1,11 +1,7 @@ -"""Authorization Server for MCP Split Demo. - -This server handles OAuth flows, client registration, and token issuance. -Can be replaced with enterprise authorization servers like Auth0, Entra ID, etc. - -NOTE: this is a simplified example for demonstration purposes. -This is not a production-ready implementation. +"""Authorization Server for the MCP split demo: OAuth flows, client registration, token issuance. +Simplified for demonstration β€” in production this role is filled by an enterprise +authorization server (Auth0, Entra ID, etc). """ import asyncio @@ -32,7 +28,6 @@ class AuthServerSettings(BaseModel): """Settings for the Authorization Server.""" - # Server settings host: str = "localhost" port: int = 9000 server_url: AnyHttpUrl = AnyHttpUrl("http://localhost:9000") @@ -40,12 +35,7 @@ class AuthServerSettings(BaseModel): class SimpleAuthProvider(SimpleOAuthProvider): - """Authorization Server provider with simple demo authentication. - - This provider: - 1. Issues MCP tokens after simple credential authentication - 2. Stores token state for introspection by Resource Servers - """ + """Demo provider: issues MCP tokens after credential auth and stores token state for introspection.""" def __init__(self, auth_settings: SimpleAuthSettings, auth_callback_path: str, server_url: str): super().__init__(auth_settings, auth_callback_path, server_url) @@ -68,7 +58,6 @@ def create_authorization_server(server_settings: AuthServerSettings, auth_settin resource_server_url=None, ) - # Create OAuth routes routes = create_auth_routes( provider=oauth_provider, issuer_url=mcp_auth_settings.issuer_url, @@ -77,9 +66,7 @@ def create_authorization_server(server_settings: AuthServerSettings, auth_settin revocation_options=mcp_auth_settings.revocation_options, ) - # Add login page route (GET) async def login_page_handler(request: Request) -> Response: - """Show login form.""" state = request.query_params.get("state") if not state: raise HTTPException(400, "Missing state parameter") @@ -87,26 +74,18 @@ async def login_page_handler(request: Request) -> Response: routes.append(Route("/login", endpoint=login_page_handler, methods=["GET"])) - # Add login callback route (POST) async def login_callback_handler(request: Request) -> Response: - """Handle simple authentication callback.""" return await oauth_provider.handle_login_callback(request) routes.append(Route("/login/callback", endpoint=login_callback_handler, methods=["POST"])) - # Add token introspection endpoint (RFC 7662) for Resource Servers async def introspect_handler(request: Request) -> Response: - """Token introspection endpoint for Resource Servers. - - Resource Servers call this endpoint to validate tokens without - needing direct access to token storage. - """ + """RFC 7662 introspection: lets Resource Servers validate tokens without access to token storage.""" form = await request.form() token = form.get("token") if not token or not isinstance(token, str): return JSONResponse({"active": False}, status_code=400) - # Look up token in provider access_token = await oauth_provider.load_access_token(token) if not access_token: return JSONResponse({"active": False}) @@ -137,7 +116,6 @@ async def introspect_handler(request: Request) -> Response: async def run_server(server_settings: AuthServerSettings, auth_settings: SimpleAuthSettings): - """Run the Authorization Server.""" auth_server = create_authorization_server(server_settings, auth_settings) config = Config( @@ -156,18 +134,11 @@ async def run_server(server_settings: AuthServerSettings, auth_settings: SimpleA @click.command() @click.option("--port", default=9000, help="Port to listen on") def main(port: int) -> int: - """Run the MCP Authorization Server. - - This server handles OAuth flows and can be used by multiple Resource Servers. - - Uses simple hardcoded credentials for demo purposes. - """ + """Run the MCP Authorization Server (demo credentials; usable by multiple Resource Servers).""" logging.basicConfig(level=logging.INFO) - # Load simple auth settings auth_settings = SimpleAuthSettings() - # Create server settings host = "localhost" server_url = f"http://{host}:{port}" server_settings = AuthServerSettings( diff --git a/examples/servers/simple-auth/mcp_simple_auth/legacy_as_server.py b/examples/servers/simple-auth/mcp_simple_auth/legacy_as_server.py index ab7773b5bb..1f1f1b09be 100644 --- a/examples/servers/simple-auth/mcp_simple_auth/legacy_as_server.py +++ b/examples/servers/simple-auth/mcp_simple_auth/legacy_as_server.py @@ -1,11 +1,7 @@ -"""Legacy Combined Authorization Server + Resource Server for MCP. - -This server implements the old spec where MCP servers could act as both AS and RS. -Used for backwards compatibility testing with the new split AS/RS architecture. - -NOTE: this is a simplified example for demonstration purposes. -This is not a production-ready implementation. +"""Legacy combined Authorization Server + Resource Server for MCP. +Implements the pre-split spec where one server acts as both AS and RS, +for backwards compatibility testing. Simplified demo, not production-ready. """ import datetime @@ -29,7 +25,6 @@ class ServerSettings(BaseModel): """Settings for the simple auth MCP server.""" - # Server settings host: str = "localhost" port: int = 8000 server_url: AnyHttpUrl = AnyHttpUrl("http://localhost:8000") @@ -57,7 +52,7 @@ def create_simple_mcp_server(server_settings: ServerSettings, auth_settings: Sim default_scopes=[auth_settings.mcp_scope], ), required_scopes=[auth_settings.mcp_scope], - # No resource_server_url parameter in legacy mode + # Legacy combined AS/RS mode: no separate resource server URL resource_server_url=None, ) @@ -86,11 +81,7 @@ async def login_callback_handler(request: Request) -> Response: @app.tool() async def get_time() -> dict[str, Any]: - """Get the current server time. - - This tool demonstrates that system information can be protected - by OAuth authentication. User must be authenticated to access it. - """ + """Get the current server time (requires OAuth authentication).""" now = datetime.datetime.now() @@ -117,7 +108,6 @@ def main(port: int, transport: Literal["sse", "streamable-http"]) -> int: logging.basicConfig(level=logging.INFO) auth_settings = SimpleAuthSettings() - # Create server settings host = "localhost" server_url = f"http://{host}:{port}" server_settings = ServerSettings( diff --git a/examples/servers/simple-auth/mcp_simple_auth/server.py b/examples/servers/simple-auth/mcp_simple_auth/server.py index 0320871b12..51a964cebe 100644 --- a/examples/servers/simple-auth/mcp_simple_auth/server.py +++ b/examples/servers/simple-auth/mcp_simple_auth/server.py @@ -1,10 +1,7 @@ -"""MCP Resource Server with Token Introspection. +"""MCP Resource Server that validates tokens via Authorization Server introspection. -This server validates tokens via Authorization Server introspection and serves MCP resources. Demonstrates RFC 9728 Protected Resource Metadata for AS/RS separation. - -NOTE: this is a simplified example for demonstration purposes. -This is not a production-ready implementation. +Simplified for demonstration; not production-ready. """ import datetime @@ -28,17 +25,14 @@ class ResourceServerSettings(BaseSettings): model_config = SettingsConfigDict(env_prefix="MCP_RESOURCE_") - # Server settings host: str = "localhost" port: int = 8001 server_url: AnyHttpUrl = AnyHttpUrl("http://localhost:8001/mcp") - # Authorization Server settings + # No user endpoint needed - user data comes from token introspection auth_server_url: AnyHttpUrl = AnyHttpUrl("http://localhost:9000") auth_server_introspection_endpoint: str = "http://localhost:9000/introspect" - # No user endpoint needed - we get user data from token introspection - # MCP settings mcp_scope: str = "user" # RFC 8707 resource validation @@ -46,26 +40,17 @@ class ResourceServerSettings(BaseSettings): def create_resource_server(settings: ResourceServerSettings) -> MCPServer: - """Create MCP Resource Server with token introspection. - - This server: - 1. Provides protected resource metadata (RFC 9728) - 2. Validates tokens via Authorization Server introspection - 3. Serves MCP tools and resources - """ - # Create token verifier for introspection with RFC 8707 resource validation + """Create a Resource Server that serves RFC 9728 metadata and validates tokens via AS introspection.""" token_verifier = IntrospectionTokenVerifier( introspection_endpoint=settings.auth_server_introspection_endpoint, server_url=str(settings.server_url), - validate_resource=settings.oauth_strict, # Only validate when --oauth-strict is set + validate_resource=settings.oauth_strict, # RFC 8707 validation, only when --oauth-strict is set ) - # Create MCPServer server as a Resource Server app = MCPServer( name="MCP Resource Server", instructions="Resource Server that validates tokens via Authorization Server introspection", debug=True, - # Auth configuration for RS mode token_verifier=token_verifier, auth=AuthSettings( issuer_url=settings.auth_server_url, @@ -78,12 +63,7 @@ def create_resource_server(settings: ResourceServerSettings) -> MCPServer: @app.tool() async def get_time() -> dict[str, Any]: - """Get the current server time. - - This tool demonstrates that system information can be protected - by OAuth authentication. User must be authenticated to access it. - """ - + """Get the current server time (requires OAuth authentication).""" now = datetime.datetime.now() return { @@ -113,20 +93,13 @@ async def get_time() -> dict[str, Any]: def main(port: int, auth_server: str, transport: Literal["sse", "streamable-http"], oauth_strict: bool) -> int: """Run the MCP Resource Server. - This server: - - Provides RFC 9728 Protected Resource Metadata - - Validates tokens via Authorization Server introspection - - Serves MCP tools requiring authentication - Must be used with a running Authorization Server. """ logging.basicConfig(level=logging.INFO) try: - # Parse auth server URL auth_server_url = AnyHttpUrl(auth_server) - # Create settings host = "localhost" server_url = f"http://{host}:{port}/mcp" settings = ResourceServerSettings( @@ -148,7 +121,6 @@ def main(port: int, auth_server: str, transport: Literal["sse", "streamable-http logger.info(f"πŸš€ MCP Resource Server running on {settings.server_url}") logger.info(f"πŸ”‘ Using Authorization Server: {settings.auth_server_url}") - # Run the server - this should block and keep running mcp_server.run(transport=transport, host=host, port=port) logger.info("Server stopped") return 0 diff --git a/examples/servers/simple-auth/mcp_simple_auth/simple_auth_provider.py b/examples/servers/simple-auth/mcp_simple_auth/simple_auth_provider.py index 48eb9a8414..78736c635c 100644 --- a/examples/servers/simple-auth/mcp_simple_auth/simple_auth_provider.py +++ b/examples/servers/simple-auth/mcp_simple_auth/simple_auth_provider.py @@ -1,11 +1,7 @@ """Simple OAuth provider for MCP servers. -This module contains a basic OAuth implementation using hardcoded user credentials -for demonstration purposes. No external authentication provider is required. - -NOTE: this is a simplified example for demonstration purposes. -This is not a production-ready implementation. - +Demo-only implementation using hardcoded user credentials and no external +authentication provider. Not production-ready. """ import secrets @@ -34,22 +30,14 @@ class SimpleAuthSettings(BaseSettings): model_config = SettingsConfigDict(env_prefix="MCP_") - # Demo user credentials demo_username: str = "demo_user" demo_password: str = "demo_password" - # MCP OAuth scope mcp_scope: str = "user" class SimpleOAuthProvider(OAuthAuthorizationServerProvider[AuthorizationCode, RefreshToken, AccessToken]): - """Simple OAuth provider for demo purposes. - - This provider handles the OAuth flow by: - 1. Providing a simple login form for demo credentials - 2. Issuing MCP tokens after successful authentication - 3. Maintaining token state for introspection - """ + """Demo OAuth provider: serves a login form, issues MCP tokens, and keeps token state in memory.""" def __init__(self, settings: SimpleAuthSettings, auth_callback_url: str, server_url: str): self.settings = settings @@ -59,24 +47,21 @@ def __init__(self, settings: SimpleAuthSettings, auth_callback_url: str, server_ self.auth_codes: dict[str, AuthorizationCode] = {} self.tokens: dict[str, AccessToken] = {} self.state_mapping: dict[str, dict[str, str | None]] = {} - # Store authenticated user information self.user_data: dict[str, dict[str, Any]] = {} async def get_client(self, client_id: str) -> OAuthClientInformationFull | None: - """Get OAuth client information.""" return self.clients.get(client_id) async def register_client(self, client_info: OAuthClientInformationFull): - """Register a new OAuth client.""" if not client_info.client_id: raise ValueError("No client_id provided") self.clients[client_info.client_id] = client_info async def authorize(self, client: OAuthClientInformationFull, params: AuthorizationParams) -> str: - """Generate an authorization URL for simple login flow.""" + """Generate an authorization URL pointing at the demo login page.""" state = params.state or secrets.token_hex(16) - # Store state mapping for callback + # Stash the OAuth params so the login callback can complete the flow self.state_mapping[state] = { "redirect_uri": str(params.redirect_uri), "code_challenge": params.code_challenge, @@ -85,17 +70,14 @@ async def authorize(self, client: OAuthClientInformationFull, params: Authorizat "resource": params.resource, # RFC 8707 } - # Build simple login URL that points to login page auth_url = f"{self.auth_callback_url}?state={state}&client_id={client.client_id}" return auth_url async def get_login_page(self, state: str) -> HTMLResponse: - """Generate login page HTML for the given state.""" if not state: raise HTTPException(400, "Missing state parameter") - # Create simple login form HTML html_content = f""" @@ -133,7 +115,6 @@ async def get_login_page(self, state: str) -> HTMLResponse: return HTMLResponse(content=html_content) async def handle_login_callback(self, request: Request) -> Response: - """Handle login form submission callback.""" form = await request.form() username = form.get("username") password = form.get("password") @@ -150,7 +131,7 @@ async def handle_login_callback(self, request: Request) -> Response: return RedirectResponse(url=redirect_uri, status_code=302) async def handle_simple_callback(self, username: str, password: str, state: str) -> str: - """Handle simple authentication callback and return redirect URI.""" + """Validate demo credentials and redirect back to the client with an authorization code.""" state_data = self.state_mapping.get(state) if not state_data: raise HTTPException(400, "Invalid state parameter") @@ -166,11 +147,9 @@ async def handle_simple_callback(self, username: str, password: str, state: str) assert code_challenge is not None assert client_id is not None - # Validate demo credentials if username != self.settings.demo_username or password != self.settings.demo_password: raise HTTPException(401, "Invalid credentials") - # Create MCP authorization code new_code = f"mcp_{secrets.token_hex(16)}" auth_code = AuthorizationCode( code=new_code, @@ -185,7 +164,6 @@ async def handle_simple_callback(self, username: str, password: str, state: str) ) self.auth_codes[new_code] = auth_code - # Store user data self.user_data[username] = { "username": username, "user_id": f"user_{secrets.token_hex(8)}", @@ -198,22 +176,18 @@ async def handle_simple_callback(self, username: str, password: str, state: str) async def load_authorization_code( self, client: OAuthClientInformationFull, authorization_code: str ) -> AuthorizationCode | None: - """Load an authorization code.""" return self.auth_codes.get(authorization_code) async def exchange_authorization_code( self, client: OAuthClientInformationFull, authorization_code: AuthorizationCode ) -> OAuthToken: - """Exchange authorization code for tokens.""" if authorization_code.code not in self.auth_codes: raise ValueError("Invalid authorization code") if not client.client_id: raise ValueError("No client_id provided") - # Generate MCP access token mcp_token = f"mcp_{secrets.token_hex(32)}" - # Store MCP token self.tokens[mcp_token] = AccessToken( token=mcp_token, client_id=client.client_id, @@ -223,7 +197,6 @@ async def exchange_authorization_code( subject=authorization_code.subject, ) - # Store user data mapping for this token self.user_data[mcp_token] = { "username": self.settings.demo_username, "user_id": f"user_{secrets.token_hex(8)}", @@ -245,7 +218,6 @@ async def load_access_token(self, token: str) -> AccessToken | None: if not access_token: return None - # Check if expired if access_token.expires_at and access_token.expires_at < time.time(): del self.tokens[token] return None @@ -262,11 +234,9 @@ async def exchange_refresh_token( refresh_token: RefreshToken, scopes: list[str], ) -> OAuthToken: - """Exchange refresh token - not supported in this example.""" raise NotImplementedError("Refresh tokens not supported") # TODO(Marcelo): The type hint is wrong. We need to fix, and test to check if it works. async def revoke_token(self, token: str, token_type_hint: str | None = None) -> None: # type: ignore - """Revoke a token.""" if token in self.tokens: del self.tokens[token] diff --git a/examples/servers/simple-auth/mcp_simple_auth/token_verifier.py b/examples/servers/simple-auth/mcp_simple_auth/token_verifier.py index 641095a125..51f40fe771 100644 --- a/examples/servers/simple-auth/mcp_simple_auth/token_verifier.py +++ b/examples/servers/simple-auth/mcp_simple_auth/token_verifier.py @@ -10,14 +10,9 @@ class IntrospectionTokenVerifier(TokenVerifier): - """Example token verifier that uses OAuth 2.0 Token Introspection (RFC 7662). - - This is a simple example implementation for demonstration purposes. - Production implementations should consider: - - Connection pooling and reuse - - More sophisticated error handling - - Rate limiting and retry logic - - Comprehensive configuration options + """Example token verifier using OAuth 2.0 Token Introspection (RFC 7662). + + Demonstration only; production code needs connection pooling, retries, and richer error handling. """ def __init__( @@ -40,7 +35,6 @@ async def verify_token(self, token: str) -> AccessToken | None: logger.warning(f"Rejecting introspection endpoint with unsafe scheme: {self.introspection_endpoint}") return None - # Configure secure HTTP client timeout = httpx.Timeout(10.0, connect=5.0) limits = httpx.Limits(max_connections=10, max_keepalive_connections=5) @@ -74,7 +68,7 @@ async def verify_token(self, token: str) -> AccessToken | None: client_id=data.get("client_id", "unknown"), scopes=data.get("scope", "").split() if data.get("scope") else [], expires_at=data.get("exp"), - resource=data.get("aud"), # Include resource in token + resource=data.get("aud"), subject=data.get("sub"), # RFC 7662 subject (resource owner) claims=data, ) @@ -87,7 +81,6 @@ def _validate_resource(self, token_data: dict[str, Any]) -> bool: if not self.server_url or not self.resource_url: return False # Fail if strict validation requested but URLs missing - # Check 'aud' claim first (standard JWT audience) aud: list[str] | str | None = token_data.get("aud") if isinstance(aud, list): for audience in aud: diff --git a/examples/servers/simple-pagination/mcp_simple_pagination/server.py b/examples/servers/simple-pagination/mcp_simple_pagination/server.py index 9aca87f730..be0f44c1d9 100644 --- a/examples/servers/simple-pagination/mcp_simple_pagination/server.py +++ b/examples/servers/simple-pagination/mcp_simple_pagination/server.py @@ -1,8 +1,4 @@ -"""Simple MCP server demonstrating pagination for tools, resources, and prompts. - -This example shows how to implement pagination with the low-level server API -to handle large lists of items that need to be split across multiple pages. -""" +"""Low-level server example demonstrating cursor pagination for tools, resources, and prompts.""" from typing import TypeVar @@ -13,7 +9,6 @@ T = TypeVar("T") -# Sample data - in real scenarios, this might come from a database SAMPLE_TOOLS = [ types.Tool( name=f"tool_{i}", @@ -46,7 +41,7 @@ def _paginate(cursor: str | None, items: list[T], page_size: int) -> tuple[list[T], str | None]: - """Helper to paginate a list of items given a cursor.""" + """Slice one page from items; the cursor is a stringified start index, invalid cursors yield an empty page.""" if cursor is not None: try: start_idx = int(cursor) @@ -60,7 +55,6 @@ def _paginate(cursor: str | None, items: list[T], page_size: int) -> tuple[list[ return page, next_cursor -# Paginated list_tools - returns 5 tools per page async def handle_list_tools( ctx: ServerRequestContext, params: types.PaginatedRequestParams | None ) -> types.ListToolsResult: @@ -69,7 +63,6 @@ async def handle_list_tools( return types.ListToolsResult(tools=page, next_cursor=next_cursor) -# Paginated list_resources - returns 10 resources per page async def handle_list_resources( ctx: ServerRequestContext, params: types.PaginatedRequestParams | None ) -> types.ListResourcesResult: @@ -78,7 +71,6 @@ async def handle_list_resources( return types.ListResourcesResult(resources=page, next_cursor=next_cursor) -# Paginated list_prompts - returns 7 prompts per page async def handle_list_prompts( ctx: ServerRequestContext, params: types.PaginatedRequestParams | None ) -> types.ListPromptsResult: @@ -88,7 +80,6 @@ async def handle_list_prompts( async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult: - # Find the tool in our sample data tool = next((t for t in SAMPLE_TOOLS if t.name == params.name), None) if not tool: raise ValueError(f"Unknown tool: {params.name}") diff --git a/examples/servers/simple-prompt/mcp_simple_prompt/server.py b/examples/servers/simple-prompt/mcp_simple_prompt/server.py index 31e3eb7d76..2b457d202d 100644 --- a/examples/servers/simple-prompt/mcp_simple_prompt/server.py +++ b/examples/servers/simple-prompt/mcp_simple_prompt/server.py @@ -5,10 +5,8 @@ def create_messages(context: str | None = None, topic: str | None = None) -> list[types.PromptMessage]: - """Create the messages for the prompt.""" messages: list[types.PromptMessage] = [] - # Add context if provided if context: messages.append( types.PromptMessage( @@ -17,7 +15,6 @@ def create_messages(context: str | None = None, topic: str | None = None) -> lis ) ) - # Add the main prompt prompt = "Please help me with " if topic: prompt += f"the following topic: {topic}" diff --git a/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/__main__.py b/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/__main__.py index 1664737e3a..f99ade1211 100644 --- a/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/__main__.py +++ b/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/__main__.py @@ -1,7 +1,6 @@ from .server import main if __name__ == "__main__": - # Click will handle CLI arguments import sys sys.exit(main()) # type: ignore[call-arg] diff --git a/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/server.py b/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/server.py index 9df18cc6a2..3c2236b4b3 100644 --- a/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/server.py +++ b/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/server.py @@ -47,7 +47,6 @@ async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequ count = arguments.get("count", 5) caller = arguments.get("caller", "unknown") - # Send the specified number of notifications with the given interval for i in range(count): await ctx.session.send_log_message( # pyright: ignore[reportDeprecated] level="info", @@ -86,7 +85,6 @@ def main( log_level: str, json_response: bool, ) -> None: - # Configure logging logging.basicConfig( level=getattr(logging, log_level.upper()), format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", @@ -104,8 +102,7 @@ def main( debug=True, ) - # Wrap ASGI application with CORS middleware to expose Mcp-Session-Id header - # for browser-based clients (ensures 500 errors get proper CORS headers) + # CORS so browser clients can read Mcp-Session-Id; wrapping the ASGI app keeps headers on error responses starlette_app = CORSMiddleware( starlette_app, allow_origins=["*"], # Note: streamable_http_app() enforces localhost-only Origin by default diff --git a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/event_store.py b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/event_store.py index c9369cfc2c..a7b623d225 100644 --- a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/event_store.py +++ b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/event_store.py @@ -1,8 +1,4 @@ -"""In-memory event store for demonstrating resumability functionality. - -This is a simple implementation intended for examples and testing, -not for production use where a persistent storage solution would be more appropriate. -""" +"""In-memory event store for demonstrating resumability; production servers should use persistent storage.""" import logging from collections import deque @@ -17,31 +13,17 @@ @dataclass class EventEntry: - """Represents an event entry in the event store.""" - event_id: EventId stream_id: StreamId message: JSONRPCMessage | None class InMemoryEventStore(EventStore): - """Simple in-memory implementation of the EventStore interface for resumability. - This is primarily intended for examples and testing, not for production use - where a persistent storage solution would be more appropriate. - - This implementation keeps only the last N events per stream for memory efficiency. - """ + """In-memory EventStore that keeps only the last N events per stream.""" def __init__(self, max_events_per_stream: int = 100): - """Initialize the event store. - - Args: - max_events_per_stream: Maximum number of events to keep per stream - """ self.max_events_per_stream = max_events_per_stream - # for maintaining last N events per stream self.streams: dict[StreamId, deque[EventEntry]] = {} - # event_id -> EventEntry for quick lookup self.event_index: dict[EventId, EventEntry] = {} async def store_event(self, stream_id: StreamId, message: JSONRPCMessage | None) -> EventId: @@ -49,17 +31,14 @@ async def store_event(self, stream_id: StreamId, message: JSONRPCMessage | None) event_id = str(uuid4()) event_entry = EventEntry(event_id=event_id, stream_id=stream_id, message=message) - # Get or create deque for this stream if stream_id not in self.streams: self.streams[stream_id] = deque(maxlen=self.max_events_per_stream) - # If deque is full, the oldest event will be automatically removed - # We need to remove it from the event_index as well + # A full deque silently evicts its oldest event on append; mirror that removal in event_index if len(self.streams[stream_id]) == self.max_events_per_stream: oldest_event = self.streams[stream_id][0] self.event_index.pop(oldest_event.event_id, None) - # Add new event self.streams[stream_id].append(event_entry) self.event_index[event_id] = event_entry @@ -75,12 +54,11 @@ async def replay_events_after( logger.warning(f"Event ID {last_event_id} not found in store") return None - # Get the stream and find events after the last one last_event = self.event_index[last_event_id] stream_id = last_event.stream_id stream_events = self.streams.get(last_event.stream_id, deque()) - # Events in deque are already in chronological order + # The deque is in chronological order, so replay everything after the last-seen event found_last = False for event in stream_events: if found_last: diff --git a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py index e650b35732..f5469fa910 100644 --- a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py +++ b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py @@ -9,7 +9,6 @@ from .event_store import InMemoryEventStore -# Configure logging logger = logging.getLogger(__name__) @@ -50,27 +49,21 @@ async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequ count = arguments.get("count", 5) caller = arguments.get("caller", "unknown") - # Send the specified number of notifications with the given interval for i in range(count): - # Include more detailed message for resumability demonstration notification_msg = f"[{i + 1}/{count}] Event from '{caller}' - Use Last-Event-ID to resume if disconnected" await ctx.session.send_log_message( # pyright: ignore[reportDeprecated] level="info", data=notification_msg, logger="notification_stream", - # Associates this notification with the original request - # Ensures notifications are sent to the correct response stream - # Without this, notifications will either go to: - # - a standalone SSE stream (if GET request is supported) - # - nowhere (if GET request isn't supported) + # Routes the notification to this request's response stream; without it, + # notifications go to the standalone SSE stream (or nowhere if GET is unsupported). related_request_id=ctx.request_id, ) logger.debug(f"Sent notification {i + 1}/{count} for caller: {caller}") - if i < count - 1: # Don't wait after the last notification + if i < count - 1: await anyio.sleep(interval) - # This will send a resource notification through standalone SSE - # established by GET request + # No related_request_id, so this goes out over the standalone SSE stream (GET request) await ctx.session.send_resource_updated(uri="http:///test_resource") return types.CallToolResult( content=[ @@ -100,7 +93,6 @@ def main( log_level: str, json_response: bool, ) -> int: - # Configure logging logging.basicConfig( level=getattr(logging, log_level.upper()), format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", @@ -112,14 +104,8 @@ def main( on_call_tool=handle_call_tool, ) - # Create event store for resumability - # The InMemoryEventStore enables resumability support for StreamableHTTP transport. - # It stores SSE events with unique IDs, allowing clients to: - # 1. Receive event IDs for each SSE message - # 2. Resume streams by sending Last-Event-ID in GET requests - # 3. Replay missed events after reconnection - # Note: This in-memory implementation is for demonstration ONLY. - # For production, use a persistent storage solution. + # Event store enables resumability: clients replay missed SSE events by sending + # Last-Event-ID on reconnect. In-memory is for demos; use persistent storage in production. event_store = InMemoryEventStore() starlette_app = app.streamable_http_app( @@ -128,12 +114,11 @@ def main( debug=True, ) - # Wrap ASGI application with CORS middleware to expose Mcp-Session-Id header - # for browser-based clients (ensures 500 errors get proper CORS headers) + # CORS so browser clients can read Mcp-Session-Id; wrapping the ASGI app keeps headers on error responses starlette_app = CORSMiddleware( starlette_app, - allow_origins=["*"], # Note: streamable_http_app() enforces localhost-only Origin by default - allow_methods=["GET", "POST", "DELETE"], # MCP streamable HTTP methods + allow_origins=["*"], # streamable_http_app() enforces localhost-only Origin by default + allow_methods=["GET", "POST", "DELETE"], expose_headers=["Mcp-Session-Id"], ) diff --git a/examples/servers/sse-polling-demo/mcp_sse_polling_demo/__main__.py b/examples/servers/sse-polling-demo/mcp_sse_polling_demo/__main__.py index 23cfc85e11..f5f6e402df 100644 --- a/examples/servers/sse-polling-demo/mcp_sse_polling_demo/__main__.py +++ b/examples/servers/sse-polling-demo/mcp_sse_polling_demo/__main__.py @@ -1,5 +1,3 @@ -"""Entry point for the SSE Polling Demo server.""" - from .server import main if __name__ == "__main__": diff --git a/examples/servers/sse-polling-demo/mcp_sse_polling_demo/event_store.py b/examples/servers/sse-polling-demo/mcp_sse_polling_demo/event_store.py index e2cca4a2eb..ab172403ef 100644 --- a/examples/servers/sse-polling-demo/mcp_sse_polling_demo/event_store.py +++ b/examples/servers/sse-polling-demo/mcp_sse_polling_demo/event_store.py @@ -1,8 +1,4 @@ -"""In-memory event store for demonstrating resumability functionality. - -This is a simple implementation intended for examples and testing, -not for production use where a persistent storage solution would be more appropriate. -""" +"""In-memory event store demonstrating resumability; examples/testing only, not production.""" import logging from collections import deque @@ -17,54 +13,32 @@ @dataclass class EventEntry: - """Represents an event entry in the event store.""" - event_id: EventId stream_id: StreamId message: JSONRPCMessage | None # None for priming events class InMemoryEventStore(EventStore): - """Simple in-memory implementation of the EventStore interface for resumability. - This is primarily intended for examples and testing, not for production use - where a persistent storage solution would be more appropriate. - - This implementation keeps only the last N events per stream for memory efficiency. - """ + """In-memory EventStore keeping the last N events per stream; for examples/testing, not production.""" def __init__(self, max_events_per_stream: int = 100): - """Initialize the event store. - - Args: - max_events_per_stream: Maximum number of events to keep per stream - """ self.max_events_per_stream = max_events_per_stream - # for maintaining last N events per stream self.streams: dict[StreamId, deque[EventEntry]] = {} - # event_id -> EventEntry for quick lookup self.event_index: dict[EventId, EventEntry] = {} async def store_event(self, stream_id: StreamId, message: JSONRPCMessage | None) -> EventId: - """Stores an event with a generated event ID. - - Args: - stream_id: ID of the stream the event belongs to - message: The message to store, or None for priming events - """ + """Store an event with a generated event ID; a None message records a priming event.""" event_id = str(uuid4()) event_entry = EventEntry(event_id=event_id, stream_id=stream_id, message=message) - # Get or create deque for this stream if stream_id not in self.streams: self.streams[stream_id] = deque(maxlen=self.max_events_per_stream) - # If deque is full, the oldest event will be automatically removed - # We need to remove it from the event_index as well + # A full deque silently evicts its oldest event on append; mirror that in event_index if len(self.streams[stream_id]) == self.max_events_per_stream: oldest_event = self.streams[stream_id][0] self.event_index.pop(oldest_event.event_id, None) - # Add new event self.streams[stream_id].append(event_entry) self.event_index[event_id] = event_entry @@ -80,16 +54,14 @@ async def replay_events_after( logger.warning(f"Event ID {last_event_id} not found in store") return None - # Get the stream and find events after the last one last_event = self.event_index[last_event_id] stream_id = last_event.stream_id stream_events = self.streams.get(last_event.stream_id, deque()) - # Events in deque are already in chronological order found_last = False for event in stream_events: if found_last: - # Skip priming events (None messages) during replay + # Priming events (None message) are not replayed if event.message is not None: await send_callback(EventMessage(event.message, event.event_id)) elif event.event_id == last_event_id: diff --git a/examples/servers/sse-polling-demo/mcp_sse_polling_demo/server.py b/examples/servers/sse-polling-demo/mcp_sse_polling_demo/server.py index 7d2c60fa32..850780c59b 100644 --- a/examples/servers/sse-polling-demo/mcp_sse_polling_demo/server.py +++ b/examples/servers/sse-polling-demo/mcp_sse_polling_demo/server.py @@ -1,15 +1,7 @@ -"""SSE Polling Demo Server +"""SSE polling demo: a long-running tool closes its SSE stream at checkpoints via `close_sse_stream`. -Demonstrates the SSE polling pattern with close_sse_stream() for long-running tasks. - -Features demonstrated: -- Priming events (automatic with EventStore) -- Server-initiated stream close via close_sse_stream callback -- Client auto-reconnect with Last-Event-ID -- Progress notifications during long-running tasks - -Run with: - uv run mcp-sse-polling-demo --port 3000 +The client auto-reconnects with Last-Event-ID and the EventStore replays missed events. +Run with: uv run mcp-sse-polling-demo --port 3000 """ import logging @@ -28,7 +20,6 @@ async def handle_list_tools( ctx: ServerRequestContext, params: types.PaginatedRequestParams | None ) -> types.ListToolsResult: - """List available tools.""" return types.ListToolsResult( tools=[ types.Tool( @@ -58,7 +49,6 @@ async def handle_list_tools( async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult: - """Handle tool calls.""" arguments = params.arguments or {} if params.name == "process_batch": @@ -85,7 +75,6 @@ async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequ # Simulate work await anyio.sleep(0.5) - # Report progress await ctx.session.send_log_message( # pyright: ignore[reportDeprecated] level="info", data=f"[{i}/{items}] Processing item {i}", diff --git a/examples/servers/structured-output-lowlevel/mcp_structured_output_lowlevel/__main__.py b/examples/servers/structured-output-lowlevel/mcp_structured_output_lowlevel/__main__.py index 393ff7a5a0..5961468d72 100644 --- a/examples/servers/structured-output-lowlevel/mcp_structured_output_lowlevel/__main__.py +++ b/examples/servers/structured-output-lowlevel/mcp_structured_output_lowlevel/__main__.py @@ -1,9 +1,5 @@ #!/usr/bin/env python3 -"""Example low-level MCP server demonstrating structured output support. - -This example shows how to use the low-level server API to return -structured data from tools. -""" +"""Low-level MCP server example returning structured output from tools.""" import asyncio import json diff --git a/examples/snippets/clients/completion_client.py b/examples/snippets/clients/completion_client.py index 52957d97d8..a53685559c 100644 --- a/examples/snippets/clients/completion_client.py +++ b/examples/snippets/clients/completion_client.py @@ -10,10 +10,9 @@ from mcp import ClientSession, StdioServerParameters from mcp.client.stdio import stdio_client -# Create server parameters for stdio connection server_params = StdioServerParameters( - command="uv", # Using uv to run the server - args=["run", "server", "completion", "stdio"], # Server with completion support + command="uv", + args=["run", "server", "completion", "stdio"], env={"UV_INDEX": os.environ.get("UV_INDEX", "")}, ) @@ -22,22 +21,18 @@ async def run(): """Run the completion client example.""" async with stdio_client(server_params) as (read, write): async with ClientSession(read, write) as session: - # Initialize the connection await session.initialize() - # List available resource templates templates = await session.list_resource_templates() print("Available resource templates:") for template in templates.resource_templates: print(f" - {template.uri_template}") - # List available prompts prompts = await session.list_prompts() print("\nAvailable prompts:") for prompt in prompts.prompts: print(f" - {prompt.name}") - # Complete resource template arguments if templates.resource_templates: template = templates.resource_templates[0] print(f"\nCompleting arguments for resource template: {template.uri_template}") @@ -57,7 +52,6 @@ async def run(): ) print(f"Completions for 'repo' with owner='modelcontextprotocol': {result.completion.values}") - # Complete prompt arguments if prompts.prompts: prompt_name = prompts.prompts[0].name print(f"\nCompleting arguments for prompt: {prompt_name}") diff --git a/examples/snippets/clients/display_utilities.py b/examples/snippets/clients/display_utilities.py index baa2765a8f..a4a52d8fe6 100644 --- a/examples/snippets/clients/display_utilities.py +++ b/examples/snippets/clients/display_utilities.py @@ -9,9 +9,8 @@ from mcp.client.stdio import stdio_client from mcp.shared.metadata_utils import get_display_name -# Create server parameters for stdio connection server_params = StdioServerParameters( - command="uv", # Using uv to run the server + command="uv", args=["run", "server", "mcpserver_quickstart", "stdio"], env={"UV_INDEX": os.environ.get("UV_INDEX", "")}, ) @@ -44,10 +43,8 @@ async def display_resources(session: ClientSession): async def run(): - """Run the display utilities example.""" async with stdio_client(server_params) as (read, write): async with ClientSession(read, write) as session: - # Initialize the connection await session.initialize() print("=== Available Tools ===") @@ -58,7 +55,6 @@ async def run(): def main(): - """Entry point for the display utilities client.""" asyncio.run(run()) diff --git a/examples/snippets/clients/identity_assertion_client.py b/examples/snippets/clients/identity_assertion_client.py index 218df4bcfc..efabb44ed2 100644 --- a/examples/snippets/clients/identity_assertion_client.py +++ b/examples/snippets/clients/identity_assertion_client.py @@ -1,17 +1,11 @@ """Client side of SEP-990 (enterprise IdP policy controls). -`IdentityAssertionOAuthProvider` presents an Identity Assertion Authorization Grant (ID-JAG) issued -by the enterprise IdP to the MCP authorization server using the RFC 7523 jwt-bearer grant -(`grant_type=urn:ietf:params:oauth:grant-type:jwt-bearer`, ID-JAG as `assertion`), and receives an -MCP access token. No browser redirect or dynamic client registration is involved. - -Obtaining the ID-JAG (logging into the IdP and the leg-1 exchange against it) is deployment-specific -and out of scope for the SDK; supply it through the `assertion_provider` callback. The callback -receives the authorization server's issuer (the ID-JAG `aud`) and the MCP server's resource -identifier (the ID-JAG `resource` claim). SEP-990 requires a confidential client, so a client secret -is mandatory, and `issuer` is the authorization server the credentials are provisioned for - the -provider fetches metadata from that issuer's well-known and never asks the resource server which AS -to use. +`IdentityAssertionOAuthProvider` presents an enterprise-IdP-issued Identity Assertion Authorization +Grant (ID-JAG) to the MCP authorization server via the RFC 7523 jwt-bearer grant to obtain an MCP +access token - no browser redirect or dynamic client registration. Obtaining the ID-JAG is +deployment-specific and out of SDK scope; supply it via the `assertion_provider` callback. SEP-990 +requires a confidential client (client secret mandatory), and the provider fetches AS metadata from +`issuer`, never asking the resource server which AS to use. """ import asyncio @@ -45,12 +39,10 @@ async def set_client_info(self, client_info: OAuthClientInformationFull) -> None async def fetch_id_jag(audience: str, resource: str) -> str: - """Return the ID-JAG to present. + """Return the ID-JAG to present (in production: exchange the user's IdP ID token at the enterprise IdP). - `audience` is the MCP authorization server's issuer (the ID-JAG `aud` claim); `resource` is the - MCP server's RFC 9728 identifier (the ID-JAG `resource` claim, which the AS audience-restricts - the issued token against). In production this exchanges the user's IdP ID token for an ID-JAG - against the enterprise identity provider. + `audience` is the authorization server's issuer (the ID-JAG `aud` claim); `resource` is the MCP + server's RFC 9728 identifier (the ID-JAG `resource` claim the issued token is audience-restricted to). """ raise NotImplementedError("Obtain the ID-JAG from your enterprise identity provider") diff --git a/examples/snippets/clients/oauth_client.py b/examples/snippets/clients/oauth_client.py index 2085b9a1db..7957245b3a 100644 --- a/examples/snippets/clients/oauth_client.py +++ b/examples/snippets/clients/oauth_client.py @@ -1,9 +1,6 @@ -"""Before running, specify running MCP RS server URL. -To spin up RS server locally, see - examples/servers/simple-auth/README.md +"""Run from the `examples/snippets` directory: uv run oauth-client -cd to the `examples/snippets` directory and run: - uv run oauth-client +Requires a running MCP resource server; see examples/servers/simple-auth/README.md. """ import asyncio @@ -19,26 +16,22 @@ class InMemoryTokenStorage(TokenStorage): - """Demo In-memory token storage implementation.""" + """Demo-only storage; production clients should persist tokens securely.""" def __init__(self): self.tokens: OAuthToken | None = None self.client_info: OAuthClientInformationFull | None = None async def get_tokens(self) -> OAuthToken | None: - """Get stored tokens.""" return self.tokens async def set_tokens(self, tokens: OAuthToken) -> None: - """Store tokens.""" self.tokens = tokens async def get_client_info(self) -> OAuthClientInformationFull | None: - """Get stored client information.""" return self.client_info async def set_client_info(self, client_info: OAuthClientInformationFull) -> None: - """Store client information.""" self.client_info = client_info @@ -57,7 +50,6 @@ async def handle_callback() -> AuthorizationCodeResult: async def main(): - """Run the OAuth client example.""" oauth_auth = OAuthClientProvider( server_url="http://localhost:8001", client_metadata=OAuthClientMetadata( diff --git a/examples/snippets/clients/pagination_client.py b/examples/snippets/clients/pagination_client.py index 00663ef038..d15d236668 100644 --- a/examples/snippets/clients/pagination_client.py +++ b/examples/snippets/clients/pagination_client.py @@ -21,13 +21,12 @@ async def list_all_resources() -> None: cursor = None while True: - # Fetch a page of resources result = await session.list_resources(params=PaginatedRequestParams(cursor=cursor)) all_resources.extend(result.resources) print(f"Fetched {len(result.resources)} resources") - # Check if there are more pages + # A next_cursor means there are more pages to fetch if result.next_cursor: cursor = result.next_cursor else: diff --git a/examples/snippets/clients/parsing_tool_results.py b/examples/snippets/clients/parsing_tool_results.py index f9aade41e3..903d8ffe6a 100644 --- a/examples/snippets/clients/parsing_tool_results.py +++ b/examples/snippets/clients/parsing_tool_results.py @@ -25,7 +25,6 @@ async def parse_tool_results(): # Example 2: Parsing structured content from JSON tools result = await session.call_tool("get_user", {"id": "123"}) if hasattr(result, "structured_content") and result.structured_content: - # Access structured data directly user_data = result.structured_content print(f"User: {user_data.get('name')}, Age: {user_data.get('age')}") diff --git a/examples/snippets/clients/stdio_client.py b/examples/snippets/clients/stdio_client.py index 6fff083853..247f5c2fb9 100644 --- a/examples/snippets/clients/stdio_client.py +++ b/examples/snippets/clients/stdio_client.py @@ -1,6 +1,4 @@ -"""cd to the `examples/snippets/clients` directory and run: -uv run client -""" +"""cd to the `examples/snippets/clients` directory and run: `uv run client`.""" import asyncio import os @@ -11,15 +9,14 @@ from mcp.client.context import ClientRequestContext from mcp.client.stdio import stdio_client -# Create server parameters for stdio connection server_params = StdioServerParameters( - command="uv", # Using uv to run the server + command="uv", args=["run", "server", "mcpserver_quickstart", "stdio"], # We're already in snippets dir env={"UV_INDEX": os.environ.get("UV_INDEX", "")}, ) -# Optional: create a sampling callback +# Optional: sampling callback async def handle_sampling_message( context: ClientRequestContext, params: types.CreateMessageRequestParams ) -> types.CreateMessageResult: @@ -38,33 +35,26 @@ async def handle_sampling_message( async def run(): async with stdio_client(server_params) as (read, write): async with ClientSession(read, write, sampling_callback=handle_sampling_message) as session: - # Initialize the connection await session.initialize() - # List available prompts prompts = await session.list_prompts() print(f"Available prompts: {[p.name for p in prompts.prompts]}") - # Get a prompt (greet_user prompt from mcpserver_quickstart) if prompts.prompts: prompt = await session.get_prompt("greet_user", arguments={"name": "Alice", "style": "friendly"}) print(f"Prompt result: {prompt.messages[0].content}") - # List available resources resources = await session.list_resources() print(f"Available resources: {[r.uri for r in resources.resources]}") - # List available tools tools = await session.list_tools() print(f"Available tools: {[t.name for t in tools.tools]}") - # Read a resource (greeting resource from mcpserver_quickstart) resource_content = await session.read_resource("greeting://World") content_block = resource_content.contents[0] if isinstance(content_block, types.TextResourceContents): print(f"Resource content: {content_block.text}") - # Call a tool (add tool from mcpserver_quickstart) result = await session.call_tool("add", arguments={"a": 5, "b": 3}) result_unstructured = result.content[0] if isinstance(result_unstructured, types.TextContent): @@ -74,7 +64,6 @@ async def run(): def main(): - """Entry point for the client script.""" asyncio.run(run()) diff --git a/examples/snippets/clients/streamable_basic.py b/examples/snippets/clients/streamable_basic.py index 43bb6396c6..e2f4074ee8 100644 --- a/examples/snippets/clients/streamable_basic.py +++ b/examples/snippets/clients/streamable_basic.py @@ -9,13 +9,9 @@ async def main(): - # Connect to a streamable HTTP server async with streamable_http_client("http://localhost:8000/mcp") as (read_stream, write_stream): - # Create a session using the client streams async with ClientSession(read_stream, write_stream) as session: - # Initialize the connection await session.initialize() - # List available tools tools = await session.list_tools() print(f"Available tools: {[tool.name for tool in tools.tools]}") diff --git a/examples/snippets/clients/url_elicitation_client.py b/examples/snippets/clients/url_elicitation_client.py index de962eb718..2daba8246b 100644 --- a/examples/snippets/clients/url_elicitation_client.py +++ b/examples/snippets/clients/url_elicitation_client.py @@ -1,23 +1,7 @@ -"""URL Elicitation Client Example. - -Demonstrates how clients handle URL elicitation requests from servers. -This is the Python equivalent of TypeScript SDK's elicitationUrlExample.ts, -focused on URL elicitation patterns without OAuth complexity. - -Features demonstrated: -1. Client elicitation capability declaration -2. Handling elicitation requests from servers via callback -3. Catching UrlElicitationRequiredError from tool calls -4. Browser interaction with security warnings -5. Interactive CLI for testing - -Run with: - cd examples/snippets - uv run elicitation-client - -Requires a server with URL elicitation tools running. Start the elicitation -server first: - uv run server elicitation sse +"""Interactive client demonstrating how to handle URL elicitation requests from servers. + +Start the elicitation server first (`cd examples/snippets && uv run server elicitation sse`), +then run `uv run elicitation-client` from the same directory. """ from __future__ import annotations @@ -41,15 +25,10 @@ async def handle_elicitation( context: ClientRequestContext, params: types.ElicitRequestParams, ) -> types.ElicitResult | types.ErrorData: - """Handle elicitation requests from the server. - - This callback is invoked when the server sends an elicitation/request. - For URL mode, we prompt the user and optionally open their browser. - """ + """Elicitation callback invoked for each server elicitation request; only URL mode is supported here.""" if params.mode == "url": return await handle_url_elicitation(params) else: - # We only support URL mode in this example return types.ErrorData( code=types.INVALID_REQUEST, message=f"Unsupported elicitation mode: {params.mode}", @@ -62,15 +41,8 @@ async def handle_elicitation( async def handle_url_elicitation( params: types.ElicitRequestParams, ) -> types.ElicitResult: - """Handle URL mode elicitation - show security warning and optionally open browser. - - This function demonstrates the security-conscious approach to URL elicitation: - 1. Validate the URL scheme before prompting the user - 2. Display the full URL and domain for user inspection - 3. Show the server's reason for requesting this interaction - 4. Require explicit user consent before opening any URL - """ - # Extract URL parameters - these are available on URL mode requests + """Show a security warning and open the URL in a browser only with explicit user consent.""" + # url and elicitationId are only present on URL mode requests url = getattr(params, "url", None) elicitation_id = getattr(params, "elicitationId", None) message = params.message @@ -85,10 +57,9 @@ async def handle_url_elicitation( print(f"\nRejecting URL with disallowed scheme '{parsed.scheme}': {url}") return types.ElicitResult(action="decline") - # Extract domain for security display domain = extract_domain(url) - # Security warning - always show the user what they're being asked to do + # Always show the user what they're being asked to open print("\n" + "=" * 60) print("SECURITY WARNING: External URL Request") print("=" * 60) @@ -100,7 +71,6 @@ async def handle_url_elicitation( print(f"\n Elicitation ID: {elicitation_id}") print("\n" + "-" * 60) - # Get explicit user consent try: response = input("\nOpen this URL in your browser? (y/n): ").strip().lower() except EOFError: @@ -113,7 +83,6 @@ async def handle_url_elicitation( print("Invalid response. Cancelling.") return types.ElicitResult(action="cancel") - # Open the browser print(f"\nOpening browser to: {url}") try: webbrowser.open(url) @@ -142,18 +111,13 @@ async def call_tool_with_error_handling( ) -> types.CallToolResult | None: """Call a tool, handling UrlElicitationRequiredError if raised. - When a server tool needs URL elicitation before it can proceed, - it can either: - 1. Send an elicitation request directly (handled by elicitation_callback) - 2. Return an error with code -32042 (URL_ELICITATION_REQUIRED) - - This function demonstrates handling case 2 - catching the error - and processing the required URL elicitations. + A server tool needing URL elicitation can send an elicitation request directly + (handled by the elicitation callback) or return error -32042 + (URL_ELICITATION_REQUIRED); this demonstrates catching the error form. """ try: result = await session.call_tool(tool_name, arguments) - # Check if the tool returned an error in the result if result.is_error: print(f"Tool returned error: {result.content}") return None @@ -161,26 +125,22 @@ async def call_tool_with_error_handling( return result except MCPError as e: - # Check if this is a URL elicitation required error if e.code == URL_ELICITATION_REQUIRED: print("\n[Tool requires URL elicitation to proceed]") # Convert to typed error to access elicitations url_error = UrlElicitationRequiredError.from_error(e.error) - # Process each required elicitation for elicitation in url_error.elicitations: await handle_url_elicitation(elicitation) return None else: - # Re-raise other MCP errors print(f"MCP Error: {e.error.message} (code: {e.error.code})") return None def print_help() -> None: - """Print available commands.""" print("\nAvailable commands:") print(" list-tools - List available tools") print(" call [json-args] - Call a tool with optional JSON arguments") @@ -191,7 +151,6 @@ def print_help() -> None: def print_tool_result(result: types.CallToolResult | None) -> None: - """Print a tool call result.""" if not result: return print("\nTool result:") @@ -203,7 +162,6 @@ def print_tool_result(result: types.CallToolResult | None) -> None: async def handle_list_tools(session: ClientSession) -> None: - """Handle the list-tools command.""" tools = await session.list_tools() if tools.tools: print("\nAvailable tools:") @@ -214,7 +172,6 @@ async def handle_list_tools(session: ClientSession) -> None: async def handle_call_command(session: ClientSession, command: str) -> None: - """Handle the call command.""" parts = command.split(maxsplit=2) if len(parts) < 2: print("Usage: call [json-args]") @@ -262,7 +219,6 @@ async def process_command(session: ClientSession, command: str) -> bool: async def run_command_loop(session: ClientSession) -> None: - """Run the interactive command loop.""" while True: try: command = input("> ").strip() diff --git a/examples/snippets/servers/__init__.py b/examples/snippets/servers/__init__.py index f132f875f5..9b1cd8cabd 100644 --- a/examples/snippets/servers/__init__.py +++ b/examples/snippets/servers/__init__.py @@ -1,10 +1,6 @@ -"""MCP Snippets. +"""MCP server snippets, each demonstrating a single feature. -This package contains simple examples of MCP server features. -Each server demonstrates a single feature and can be run as a standalone server. - -To run a server, use the command: - uv run server basic_tool sse +Run one standalone: `uv run server basic_tool sse` """ import importlib @@ -13,11 +9,7 @@ def run_server(): - """Run a server by name with optional transport. - - Usage: server [transport] - Example: server basic_tool sse - """ + """Run a snippet server: `server [transport]`.""" if len(sys.argv) < 2: print("Usage: server [transport]") print("Available servers: basic_tool, basic_resource, basic_prompt, tool_progress,") diff --git a/examples/snippets/servers/direct_call_tool_result.py b/examples/snippets/servers/direct_call_tool_result.py index f3035338b3..c82896afbd 100644 --- a/examples/snippets/servers/direct_call_tool_result.py +++ b/examples/snippets/servers/direct_call_tool_result.py @@ -11,8 +11,6 @@ class ValidationModel(BaseModel): - """Model for validating structured output.""" - status: str data: dict[str, int] diff --git a/examples/snippets/servers/direct_execution.py b/examples/snippets/servers/direct_execution.py index acf7151d3b..1ab7611ef1 100644 --- a/examples/snippets/servers/direct_execution.py +++ b/examples/snippets/servers/direct_execution.py @@ -1,10 +1,6 @@ -"""Example showing direct execution of an MCP server. +"""Simplest way to run an MCP server: execute the file directly. -This is the simplest way to run an MCP server directly. -cd to the `examples/snippets` directory and run: - uv run direct-execution-server - or - python servers/direct_execution.py +From `examples/snippets`: `uv run direct-execution-server` or `python servers/direct_execution.py`. """ from mcp.server.mcpserver import MCPServer @@ -19,7 +15,6 @@ def hello(name: str = "World") -> str: def main(): - """Entry point for the direct execution server.""" mcp.run() diff --git a/examples/snippets/servers/elicitation.py b/examples/snippets/servers/elicitation.py index 97e847b510..38cc542fc7 100644 --- a/examples/snippets/servers/elicitation.py +++ b/examples/snippets/servers/elicitation.py @@ -1,8 +1,7 @@ -"""Elicitation examples demonstrating form and URL mode elicitation. +"""Elicitation examples. -Form mode elicitation collects structured, non-sensitive data through a schema. -URL mode elicitation directs users to external URLs for sensitive operations -like OAuth flows, credential collection, or payment processing. +Form mode collects structured, non-sensitive data through a schema; URL mode +directs the user to an external URL for sensitive operations like OAuth or payments. """ import uuid @@ -28,13 +27,9 @@ class BookingPreferences(BaseModel): @mcp.tool() async def book_table(date: str, time: str, party_size: int, ctx: Context) -> str: - """Book a table with date availability check. - - This demonstrates form mode elicitation for collecting non-sensitive user input. - """ - # Check if date is available + """Book a table with date availability check (form mode elicitation).""" if date == "2024-12-25": - # Date unavailable - ask user for alternative + # Date unavailable - use form elicitation to ask for an alternative result = await ctx.elicit( message=(f"No tables available for {party_size} on {date}. Would you like to try another date?"), schema=BookingPreferences, @@ -46,17 +41,12 @@ async def book_table(date: str, time: str, party_size: int, ctx: Context) -> str return "[CANCELLED] No booking made" return "[CANCELLED] Booking cancelled" - # Date available return f"[SUCCESS] Booked for {date} at {time}" @mcp.tool() async def secure_payment(amount: float, ctx: Context) -> str: - """Process a secure payment requiring URL confirmation. - - This demonstrates URL mode elicitation using ctx.elicit_url() for - operations that require out-of-band user interaction. - """ + """Process a secure payment requiring URL confirmation (URL mode elicitation via `ctx.elicit_url`).""" elicitation_id = str(uuid.uuid4()) result = await ctx.elicit_url( @@ -66,8 +56,7 @@ async def secure_payment(amount: float, ctx: Context) -> str: ) if result.action == "accept": - # In a real app, the payment confirmation would happen out-of-band - # and you'd verify the payment status from your backend + # In a real app, confirmation happens out-of-band; verify payment status from your backend return f"Payment of ${amount:.2f} initiated - check your browser to complete" elif result.action == "decline": return "Payment declined by user" @@ -76,16 +65,11 @@ async def secure_payment(amount: float, ctx: Context) -> str: @mcp.tool() async def connect_service(service_name: str, ctx: Context) -> str: - """Connect to a third-party service requiring OAuth authorization. - - This demonstrates the "throw error" pattern using UrlElicitationRequiredError. - Use this pattern when the tool cannot proceed without user authorization. - """ + """Connect to a third-party service requiring OAuth authorization.""" elicitation_id = str(uuid.uuid4()) - # Raise UrlElicitationRequiredError to signal that the client must complete - # a URL elicitation before this request can be processed. - # The MCP framework will convert this to a -32042 error response. + # When the tool cannot proceed without user authorization, raise UrlElicitationRequiredError: + # the framework converts it to a -32042 error telling the client to complete a URL elicitation. raise UrlElicitationRequiredError( [ ElicitRequestURLParams( diff --git a/examples/snippets/servers/identity_assertion_server.py b/examples/snippets/servers/identity_assertion_server.py index 9406111f8b..7dd9e640b8 100644 --- a/examples/snippets/servers/identity_assertion_server.py +++ b/examples/snippets/servers/identity_assertion_server.py @@ -1,24 +1,8 @@ -"""Authorization-server side of SEP-990 (enterprise IdP policy controls). +"""Authorization-server side of SEP-990 (Identity Assertion Authorization Grant). -An authorization server enables the Identity Assertion Authorization Grant by setting -`identity_assertion_enabled=True` and implementing `exchange_identity_assertion` on its provider. -The client presents the IdP-issued ID-JAG using the RFC 7523 jwt-bearer grant; the provider -validates the assertion and mints an MCP access token. - -Validating the ID-JAG is the provider's responsibility and is only stubbed here. A real -implementation MUST, per RFC 7523 Β§3 and SEP-990 Β§5.1: - -- verify the JWT signature, `iss`, and `exp`, and that `typ` is `oauth-id-jag+jwt`; -- require `aud` to identify this authorization server; -- require the ID-JAG's `client_id` claim to match the authenticated client; -- audience-restrict the issued token to the resource named in the ID-JAG's `resource` claim - (NOT the client-supplied `params.resource`); -- derive the granted scopes from the ID-JAG and policy. - -`_decode_and_validate_id_jag` below raises `NotImplementedError` so this snippet fails closed and -forces a real implementation. Wire the returned routes into a Starlette app with -`create_auth_routes(..., identity_assertion_enabled=True)`, or set -`AuthSettings(identity_assertion_enabled=True)` with `MCPServer`/`Server`. +Enable with `identity_assertion_enabled=True` (via `create_auth_routes` or `AuthSettings`) and +implement `exchange_identity_assertion` on the provider: the client presents the IdP-issued ID-JAG +using the RFC 7523 jwt-bearer grant, and the provider validates it and mints an MCP access token. """ import secrets @@ -39,8 +23,8 @@ class IdJagClaims: """The trusted claims extracted from a validated ID-JAG.""" - subject: str # the end user the ID-JAG was issued for - client_id: str # the ID-JAG `client_id` claim; Β§5.1 requires it to match the authenticated client + subject: str + client_id: str # must match the authenticated client (SEP-990 Β§5.1) resource: str # the MCP server the issued token must be audience-restricted to scopes: list[str] @@ -50,9 +34,8 @@ class IdentityAssertionProvider(OAuthAuthorizationServerProvider[AuthorizationCo def __init__(self) -> None: self.access_tokens: dict[str, AccessToken] = {} - # SEP-990 clients are pre-registered out of band (DCR refuses the grant) and must be - # confidential. `get_client` must return them, or the token endpoint 401s before the - # exchange runs. Real deployments load these from their own store. + # SEP-990 clients are pre-registered out of band (DCR refuses this grant) and must be + # confidential; `get_client` must return them or the token endpoint 401s before the exchange. self.clients: dict[str, OAuthClientInformationFull] = { "enterprise-mcp-client": OAuthClientInformationFull( client_id="enterprise-mcp-client", @@ -92,10 +75,10 @@ async def exchange_identity_assertion( def _decode_and_validate_id_jag(self, assertion: str, client: OAuthClientInformationFull) -> IdJagClaims: """Verify the ID-JAG and return its trusted claims, or reject the request. - Replace this stub with real RFC 7523 Β§3 / SEP-990 Β§5.1 validation. It fails closed - it - raises rather than trusting the assertion - so a copy of this example cannot accidentally - accept unverified tokens. RFC 7523 Β§3.1 / RFC 6749 Β§5.2 specify `invalid_grant` for a - rejected assertion. + Replace this stub per RFC 7523 Β§3 / SEP-990 Β§5.1: verify the signature, `iss`, `exp`, and + `typ` (`oauth-id-jag+jwt`); require `aud` to identify this server and `client_id` to match + the authenticated client; derive scopes from the ID-JAG and policy. Reject with + `invalid_grant` (RFC 7523 Β§3.1 / RFC 6749 Β§5.2). The stub raises so copies fail closed. """ raise NotImplementedError("Validate the ID-JAG (signature, iss/aud/exp/typ, client_id, resource)") diff --git a/examples/snippets/servers/lifespan_example.py b/examples/snippets/servers/lifespan_example.py index f290d31dd3..d0e92330e9 100644 --- a/examples/snippets/servers/lifespan_example.py +++ b/examples/snippets/servers/lifespan_example.py @@ -7,21 +7,17 @@ from mcp.server.mcpserver import Context, MCPServer -# Mock database class for example class Database: """Mock database class for example.""" @classmethod async def connect(cls) -> "Database": - """Connect to database.""" return cls() async def disconnect(self) -> None: - """Disconnect from database.""" pass def query(self) -> str: - """Execute a query.""" return "Query result" @@ -44,7 +40,6 @@ async def app_lifespan(server: MCPServer) -> AsyncIterator[AppContext]: await db.disconnect() -# Pass lifespan to server mcp = MCPServer("My App", lifespan=app_lifespan) diff --git a/examples/snippets/servers/lowlevel/basic.py b/examples/snippets/servers/lowlevel/basic.py index ff9b0a2c49..0affe62245 100644 --- a/examples/snippets/servers/lowlevel/basic.py +++ b/examples/snippets/servers/lowlevel/basic.py @@ -51,7 +51,6 @@ async def handle_get_prompt(ctx: ServerRequestContext, params: types.GetPromptRe async def run(): - """Run the basic low-level server.""" async with mcp.server.stdio.stdio_server() as (read_stream, write_stream): await server.run( read_stream, diff --git a/examples/snippets/servers/lowlevel/direct_call_tool_result.py b/examples/snippets/servers/lowlevel/direct_call_tool_result.py index 4d6607d2ff..3840f587a6 100644 --- a/examples/snippets/servers/lowlevel/direct_call_tool_result.py +++ b/examples/snippets/servers/lowlevel/direct_call_tool_result.py @@ -50,7 +50,6 @@ async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequ async def run(): - """Run the server.""" async with mcp.server.stdio.stdio_server() as (read_stream, write_stream): await server.run( read_stream, diff --git a/examples/snippets/servers/lowlevel/lifespan.py b/examples/snippets/servers/lowlevel/lifespan.py index 46db9ecc07..2a6001429c 100644 --- a/examples/snippets/servers/lowlevel/lifespan.py +++ b/examples/snippets/servers/lowlevel/lifespan.py @@ -12,23 +12,18 @@ from mcp.server import Server, ServerRequestContext -# Mock database class for example class Database: """Mock database class for example.""" @classmethod async def connect(cls) -> "Database": - """Connect to database.""" print("Database connected") return cls() async def disconnect(self) -> None: - """Disconnect from database.""" print("Database disconnected") async def query(self, query_str: str) -> list[dict[str, str]]: - """Execute a query.""" - # Simulate database query return [{"id": "1", "name": "Example", "query": query_str}] @@ -49,7 +44,6 @@ async def server_lifespan(_server: Server[AppContext]) -> AsyncIterator[AppConte async def handle_list_tools( ctx: ServerRequestContext[AppContext], params: types.PaginatedRequestParams | None ) -> types.ListToolsResult: - """List available tools.""" return types.ListToolsResult( tools=[ types.Tool( @@ -68,7 +62,6 @@ async def handle_list_tools( async def handle_call_tool( ctx: ServerRequestContext[AppContext], params: types.CallToolRequestParams ) -> types.CallToolResult: - """Handle database query tool call.""" if params.name != "query_db": raise ValueError(f"Unknown tool: {params.name}") @@ -87,7 +80,6 @@ async def handle_call_tool( async def run(): - """Run the server with lifespan management.""" async with mcp.server.stdio.stdio_server() as (read_stream, write_stream): await server.run( read_stream, diff --git a/examples/snippets/servers/lowlevel/structured_output.py b/examples/snippets/servers/lowlevel/structured_output.py index 84e411ff55..bf74e25672 100644 --- a/examples/snippets/servers/lowlevel/structured_output.py +++ b/examples/snippets/servers/lowlevel/structured_output.py @@ -68,7 +68,6 @@ async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequ async def run(): - """Run the structured output server.""" async with mcp.server.stdio.stdio_server() as (read_stream, write_stream): await server.run( read_stream, diff --git a/examples/snippets/servers/mcpserver_quickstart.py b/examples/snippets/servers/mcpserver_quickstart.py index 70a83a56e4..a4724f3516 100644 --- a/examples/snippets/servers/mcpserver_quickstart.py +++ b/examples/snippets/servers/mcpserver_quickstart.py @@ -6,25 +6,21 @@ from mcp.server.mcpserver import MCPServer -# Create an MCP server mcp = MCPServer("Demo") -# Add an addition tool @mcp.tool() def add(a: int, b: int) -> int: """Add two numbers""" return a + b -# Add a dynamic greeting resource @mcp.resource("greeting://{name}") def get_greeting(name: str) -> str: """Get a personalized greeting""" return f"Hello, {name}!" -# Add a prompt @mcp.prompt() def greet_user(name: str, style: str = "friendly") -> str: """Generate a greeting prompt""" @@ -37,6 +33,5 @@ def greet_user(name: str, style: str = "friendly") -> str: return f"{styles.get(style, styles['friendly'])} for someone named {name}." -# Run with streamable HTTP transport if __name__ == "__main__": mcp.run(transport="streamable-http", json_response=True) diff --git a/examples/snippets/servers/notifications.py b/examples/snippets/servers/notifications.py index 05c0fbf331..d11e4e315d 100644 --- a/examples/snippets/servers/notifications.py +++ b/examples/snippets/servers/notifications.py @@ -6,13 +6,11 @@ @mcp.tool() async def process_data(data: str, ctx: Context) -> str: """Process data with logging.""" - # Different log levels await ctx.debug(f"Debug: Processing '{data}'") # pyright: ignore[reportDeprecated] await ctx.info("Info: Starting processing") # pyright: ignore[reportDeprecated] await ctx.warning("Warning: This is experimental") # pyright: ignore[reportDeprecated] await ctx.error("Error: (This is just a demo)") # pyright: ignore[reportDeprecated] - # Notify about resource changes await ctx.session.send_resource_list_changed() return f"Processed: {data}" diff --git a/examples/snippets/servers/oauth_server.py b/examples/snippets/servers/oauth_server.py index 962ef0615e..99e224bdcb 100644 --- a/examples/snippets/servers/oauth_server.py +++ b/examples/snippets/servers/oauth_server.py @@ -16,10 +16,9 @@ async def verify_token(self, token: str) -> AccessToken | None: pass # This is where you would implement actual token validation -# Create MCPServer instance as a Resource Server +# This server acts as a Resource Server: it validates tokens but does not issue them mcp = MCPServer( "Weather Service", - # Token verifier for authentication token_verifier=SimpleTokenVerifier(), # Auth settings for RFC 9728 Protected Resource Metadata auth=AuthSettings( diff --git a/examples/snippets/servers/pagination_example.py b/examples/snippets/servers/pagination_example.py index 4f7435acf6..1b4e54daa2 100644 --- a/examples/snippets/servers/pagination_example.py +++ b/examples/snippets/servers/pagination_example.py @@ -4,8 +4,7 @@ from mcp.server import Server, ServerRequestContext -# Sample data to paginate -ITEMS = [f"Item {i}" for i in range(1, 101)] # 100 items +ITEMS = [f"Item {i}" for i in range(1, 101)] async def handle_list_resources( @@ -14,20 +13,17 @@ async def handle_list_resources( """List resources with pagination support.""" page_size = 10 - # Extract cursor from request params cursor = params.cursor if params is not None else None - # Parse cursor to get offset + # The cursor is an opaque string; this server encodes the list offset in it start = 0 if cursor is None else int(cursor) end = start + page_size - # Get page of resources page_items = [ types.Resource(uri=f"resource://items/{item}", name=item, description=f"Description for {item}") for item in ITEMS[start:end] ] - # Determine next cursor next_cursor = str(end) if end < len(ITEMS) else None return types.ListResourcesResult(resources=page_items, next_cursor=next_cursor) diff --git a/examples/snippets/servers/sampling.py b/examples/snippets/servers/sampling.py index 83ec5066dd..b5a7080735 100644 --- a/examples/snippets/servers/sampling.py +++ b/examples/snippets/servers/sampling.py @@ -20,7 +20,7 @@ async def generate_poem(topic: str, ctx: Context) -> str: max_tokens=100, ) - # Since we're not passing tools param, result.content is single content + # Without the tools param, result.content is a single content block (not a list) if result.content.type == "text": return result.content.text return str(result.content) diff --git a/examples/snippets/servers/streamable_config.py b/examples/snippets/servers/streamable_config.py index 622e67063c..b99a4022b1 100644 --- a/examples/snippets/servers/streamable_config.py +++ b/examples/snippets/servers/streamable_config.py @@ -7,14 +7,12 @@ mcp = MCPServer("StatelessServer") -# Add a simple tool to demonstrate the server @mcp.tool() def greet(name: str = "World") -> str: """Greet someone by name.""" return f"Hello, {name}!" -# Run server with streamable_http transport # Transport-specific options (stateless_http, json_response) are passed to run() if __name__ == "__main__": # Stateless server with JSON responses (recommended) diff --git a/examples/snippets/servers/streamable_http_basic_mounting.py b/examples/snippets/servers/streamable_http_basic_mounting.py index 9a53034f16..886dff9693 100644 --- a/examples/snippets/servers/streamable_http_basic_mounting.py +++ b/examples/snippets/servers/streamable_http_basic_mounting.py @@ -11,7 +11,6 @@ from mcp.server.mcpserver import MCPServer -# Create MCP server mcp = MCPServer("My App") @@ -21,14 +20,13 @@ def hello() -> str: return "Hello from MCP!" -# Create a lifespan context manager to run the session manager +# The session manager must be running for the transport to handle requests @contextlib.asynccontextmanager async def lifespan(app: Starlette): async with mcp.session_manager.run(): yield -# Mount the StreamableHTTP server to the existing ASGI server # Transport-specific options are passed to streamable_http_app() app = Starlette( routes=[ diff --git a/examples/snippets/servers/streamable_http_host_mounting.py b/examples/snippets/servers/streamable_http_host_mounting.py index 2a41f74a59..dab7c3bb8c 100644 --- a/examples/snippets/servers/streamable_http_host_mounting.py +++ b/examples/snippets/servers/streamable_http_host_mounting.py @@ -11,7 +11,6 @@ from mcp.server.mcpserver import MCPServer -# Create MCP server mcp = MCPServer("MCP Host App") @@ -21,14 +20,13 @@ def domain_info() -> str: return "This is served from mcp.acme.corp" -# Create a lifespan context manager to run the session manager +# The session manager must be running for the server to handle requests @contextlib.asynccontextmanager async def lifespan(app: Starlette): async with mcp.session_manager.run(): yield -# Mount using Host-based routing # Transport-specific options are passed to streamable_http_app() app = Starlette( routes=[ diff --git a/examples/snippets/servers/streamable_http_multiple_servers.py b/examples/snippets/servers/streamable_http_multiple_servers.py index 71217bdfed..00fa258714 100644 --- a/examples/snippets/servers/streamable_http_multiple_servers.py +++ b/examples/snippets/servers/streamable_http_multiple_servers.py @@ -11,7 +11,6 @@ from mcp.server.mcpserver import MCPServer -# Create multiple MCP servers api_mcp = MCPServer("API Server") chat_mcp = MCPServer("Chat Server") @@ -28,7 +27,7 @@ def send_message(message: str) -> str: return f"Message sent: {message}" -# Create a combined lifespan to manage both session managers +# A combined lifespan must run both servers' session managers @contextlib.asynccontextmanager async def lifespan(app: Starlette): async with contextlib.AsyncExitStack() as stack: @@ -37,8 +36,7 @@ async def lifespan(app: Starlette): yield -# Mount the servers with transport-specific options passed to streamable_http_app() -# streamable_http_path="/" means endpoints will be at /api and /chat instead of /api/mcp and /chat/mcp +# streamable_http_path="/" puts endpoints at /api and /chat instead of /api/mcp and /chat/mcp app = Starlette( routes=[ Mount("/api", app=api_mcp.streamable_http_app(json_response=True, streamable_http_path="/")), diff --git a/examples/snippets/servers/streamable_http_path_config.py b/examples/snippets/servers/streamable_http_path_config.py index 4c65ffdd79..ed2790e0b7 100644 --- a/examples/snippets/servers/streamable_http_path_config.py +++ b/examples/snippets/servers/streamable_http_path_config.py @@ -9,7 +9,6 @@ from mcp.server.mcpserver import MCPServer -# Create a simple MCPServer server mcp_at_root = MCPServer("My Server") diff --git a/examples/snippets/servers/streamable_starlette_mount.py b/examples/snippets/servers/streamable_starlette_mount.py index eb6f1b8093..26bd86fc1f 100644 --- a/examples/snippets/servers/streamable_starlette_mount.py +++ b/examples/snippets/servers/streamable_starlette_mount.py @@ -9,7 +9,6 @@ from mcp.server.mcpserver import MCPServer -# Create the Echo server echo_mcp = MCPServer(name="EchoServer") @@ -19,7 +18,6 @@ def echo(message: str) -> str: return f"Echo: {message}" -# Create the Math server math_mcp = MCPServer(name="MathServer") @@ -29,7 +27,7 @@ def add_two(n: int) -> int: return n + 2 -# Create a combined lifespan to manage both session managers +# A combined lifespan must run both servers' session managers @contextlib.asynccontextmanager async def lifespan(app: Starlette): async with contextlib.AsyncExitStack() as stack: @@ -38,7 +36,6 @@ async def lifespan(app: Starlette): yield -# Create the Starlette app and mount the MCP servers app = Starlette( routes=[ Mount("/echo", echo_mcp.streamable_http_app(stateless_http=True, json_response=True)), @@ -47,7 +44,6 @@ async def lifespan(app: Starlette): lifespan=lifespan, ) -# Note: Clients connect to http://localhost:8000/echo/mcp and http://localhost:8000/math/mcp -# To mount at the root of each path (e.g., /echo instead of /echo/mcp): -# echo_mcp.streamable_http_app(streamable_http_path="/", stateless_http=True, json_response=True) -# math_mcp.streamable_http_app(streamable_http_path="/", stateless_http=True, json_response=True) +# Clients connect to http://localhost:8000/echo/mcp and http://localhost:8000/math/mcp. +# To mount at the root of each path (/echo instead of /echo/mcp), pass +# streamable_http_path="/" to streamable_http_app(). diff --git a/examples/snippets/servers/structured_output.py b/examples/snippets/servers/structured_output.py index bea7b22c16..2720b98aa3 100644 --- a/examples/snippets/servers/structured_output.py +++ b/examples/snippets/servers/structured_output.py @@ -22,7 +22,6 @@ class WeatherData(BaseModel): @mcp.tool() def get_weather(city: str) -> WeatherData: """Get weather for a city - returns structured data.""" - # Simulated weather data return WeatherData( temperature=22.5, humidity=45.0, diff --git a/examples/stories/__init__.py b/examples/stories/__init__.py index 6f4d6055a7..5e95300f28 100644 --- a/examples/stories/__init__.py +++ b/examples/stories/__init__.py @@ -1,6 +1,6 @@ """Self-verifying example suite for the MCP Python SDK. -Each story directory holds a ``server.py`` (and usually ``server_lowlevel.py``) -plus a ``client.py`` whose ``main(target, *, mode)`` runs against both. -``tests/examples/`` drives every story over an in-process matrix. +Each story directory holds a `server.py` (and usually `server_lowlevel.py`) +plus a `client.py` whose `main(target, *, mode)` runs against both. +`tests/examples/` drives every story over an in-process matrix. """ diff --git a/examples/stories/_harness.py b/examples/stories/_harness.py index c7036acd68..6a859b0931 100644 --- a/examples/stories/_harness.py +++ b/examples/stories/_harness.py @@ -1,9 +1,7 @@ """Client-side scaffold for story examples. -A story's ``client.py`` imports ``Target`` (or ``TargetFactory``) for its ``main`` -signature and calls ``run_client(main)`` from ``__main__``. The story owns the -``Client(target, mode=...)`` construction; this module only decides WHICH target -``__main__`` hands it. +A story's `client.py` calls `run_client(main)` from `__main__`; the story owns the +`Client(target, mode=...)` construction β€” this module only decides which target it gets. """ from __future__ import annotations @@ -33,17 +31,17 @@ import tomli as tomllib Target: TypeAlias = "Server[Any] | MCPServer | Transport | str" -"""Anything ``Client(...)`` accepts: an in-process server, a ``Transport``, or an HTTP URL.""" +"""Anything `Client(...)` accepts: an in-process server, a `Transport`, or an HTTP URL.""" TargetFactory = Callable[[], Target] -"""Yields a FRESH target against the same server/app on every call (``multi_connection`` stories).""" +"""Yields a FRESH target against the same server/app on every call (`multi_connection` stories).""" AuthBuilder = Callable[[httpx.AsyncClient], httpx.Auth] -"""Builds an ``httpx.Auth`` bound to the in-process HTTP client (auth-story harness seam).""" +"""Builds an `httpx.Auth` bound to the in-process HTTP client (auth-story harness seam).""" def argv_after(flag: str, *, default: str | None = None) -> str: - """Return the argv token following ``flag``, or ``default`` when the flag is absent.""" + """Return the argv token following `flag`, or `default` when the flag is absent.""" try: return sys.argv[sys.argv.index(flag) + 1] except ValueError: @@ -53,11 +51,7 @@ def argv_after(flag: str, *, default: str | None = None) -> str: def target_from_args(file: str, url: str | None) -> TargetFactory: - """Build a ``TargetFactory`` for the sibling server of the ``client.py`` at ``file``. - - ``url`` (already resolved by ``run_client``) targets that streamable-HTTP endpoint; ``None`` - spawns ``.py`` over stdio per call, ```` from ``--server`` (default ``server``). - """ + """Build a `TargetFactory`: `url` as-is, or `None` to spawn the sibling `--server` script over stdio.""" if url is not None: return lambda: url # stdio is legacy-only until serve_stdio() lands; the modern arm is --http only for now. @@ -67,7 +61,7 @@ def target_from_args(file: str, url: str | None) -> TargetFactory: def _explicit_http_url() -> str | None: - """The URL token after ``--http``, or ``None`` when the flag stands alone (self-host).""" + """The URL token after `--http`, or `None` when the flag stands alone (self-host).""" rest = sys.argv[sys.argv.index("--http") + 1 :] return rest[0] if rest and not rest[0].startswith("-") else None @@ -80,7 +74,6 @@ def _free_port() -> int: async def _accepting(port: int) -> bool: - """Whether something accepts a TCP connect on ``127.0.0.1:port`` right now.""" try: stream = await anyio.connect_tcp("127.0.0.1", port) except OSError: @@ -91,17 +84,14 @@ async def _accepting(port: int) -> bool: @asynccontextmanager async def _self_hosted(name: str, cfg: dict[str, Any]) -> AsyncIterator[str]: - """Serve the story's sibling server from a subprocess on a port this process owns; yield its URL. + """Serve the story's sibling server in a subprocess; yield its URL once it accepts TCP. - Readiness is the first accepted TCP connect (bounded by ``run_client``'s - ``anyio.fail_after``); exiting terminates the subprocess. Nothing to background or kill. - A subprocess that dies before serving, or a ``fixed_port`` someone else already holds, - is a loud ``SystemExit`` rather than a hang or a run against the wrong server. + A child that dies before serving, or a `fixed_port` someone else holds, is a loud + `SystemExit` rather than a hang; the readiness poll is bounded by `run_client`'s timeout. """ port: int = cfg["fixed_port"] or _free_port() if cfg["fixed_port"] and await _accepting(port): - # The readiness probe below can't tell our child from a server already on the - # story's pinned port, so a foreign listener would be tested in its place. + # The readiness probe can't tell our child from a foreign listener already on the pinned port. raise SystemExit( f"{name} self-hosts on :{port} but something is already serving there; " f"stop it, or connect to it with --http " @@ -122,22 +112,22 @@ async def _self_hosted(name: str, cfg: dict[str, Any]) -> AsyncIterator[str]: def _story_cfg(name: str) -> dict[str, Any]: - """The manifest entry for the story ``name`` with ``[defaults]`` applied.""" + """The manifest entry for the story `name` with `[defaults]` applied.""" manifest: dict[str, Any] = tomllib.loads((Path(__file__).parent / "manifest.toml").read_text()) return manifest["defaults"] | manifest["story"].get(name, {}) def _authed_targets(url: str, http: httpx.AsyncClient) -> TargetFactory: - """Fresh streamable-HTTP transports over an already-authed ``httpx`` client.""" + """Fresh streamable-HTTP transports over an already-authed `httpx` client.""" return lambda: streamable_http_client(url, http_client=http) def run_client(main: Callable[..., Awaitable[None]]) -> None: - """Entry point for ``if __name__ == "__main__"`` in every ``client.py``. + """Entry point for `if __name__ == "__main__"` in every `client.py`. - Resolves the argv target β€” stdio (the default), ``--http `` for a server you run, or - bare ``--http`` to self-host the sibling server in a subprocess it owns β€” and calls ``main`` - with an explicit ``mode=``. A ``build_auth`` export auths the HTTP target. ``OK``/``FAIL``, exit 0/1. + Resolves the argv target β€” stdio (default), `--http `, or bare `--http` to self-host + the sibling server β€” and calls `main` with an explicit `mode=`; a `build_auth` export + auths the HTTP target. `OK`/`FAIL`, exit 0/1. """ globals_ = getattr(main, "__globals__", {}) file = str(globals_.get("__file__", "")) @@ -153,16 +143,14 @@ def run_client(main: Callable[..., Awaitable[None]]) -> None: if cfg["needs_http"] and transport != "http": raise SystemExit(f"{name} asserts on raw HTTP responses; run it with --http") explicit_url = _explicit_http_url() if transport == "http" else None - # The era is an axis of the story matrix, so ``mode=`` is always passed explicitly - # even though it often matches the ``Client`` default of "auto". stdio is legacy-only - # until the SDK's stdio entry can negotiate the era, so only --http gets a modern arm. + # Era is an axis of the story matrix, so `mode=` is always passed explicitly even when it + # matches the `Client` default. stdio can't negotiate the era yet, so only --http gets a modern arm. era = "modern" if transport == "http" and "--legacy" not in sys.argv else "legacy" if cfg["era"] in ("legacy", "modern"): era = cfg["era"] if cfg["era"] == "dual-in-body": - # The story pins its connection modes inside ``main`` itself, so hand it "auto" - # (the ``Client`` default) and let those in-body pins decide. A hard version pin - # here would skip the discover probe and leave ``server_info`` blank. + # The story pins its connection modes inside `main`, so hand it "auto"; a hard + # version pin here would skip the discover probe and leave `server_info` blank. era = "in-body" mode = {"modern": LATEST_MODERN_VERSION, "legacy": "legacy", "in-body": "auto"}[era] @@ -176,10 +164,8 @@ async def _run() -> None: if url is None or (build_auth is None and not cfg["needs_http"]): await main(targets if cfg["multi_connection"] else targets(), mode=mode) return - # Auth and needs_http stories want the raw httpx client underneath the transport: - # build_auth threads an httpx.Auth onto it (Client(url, auth=...) doesn't exist - # yet), and needs_http stories assert on raw responses, so root the client at the - # server origin and relative paths like "/mcp" resolve. + # build_auth threads an httpx.Auth onto the raw client (Client(url, auth=...) doesn't + # exist yet); needs_http asserts on raw responses; origin base_url makes "/mcp" resolve. parts = urlsplit(url) base = f"{parts.scheme}://{parts.netloc}" http = await stack.enter_async_context(httpx.AsyncClient(base_url=base)) diff --git a/examples/stories/_hosting.py b/examples/stories/_hosting.py index 041778677d..80df056189 100644 --- a/examples/stories/_hosting.py +++ b/examples/stories/_hosting.py @@ -1,8 +1,7 @@ """Server-side hosting scaffold for story examples. -A story's ``server.py`` / ``server_lowlevel.py`` imports only from here. The -marked lines touch entry-point APIs that a later release reshapes into -free-function entries; isolating them here keeps story bodies stable. +Story `server*.py` files import only from here; the marked lines touch entry-point +APIs that a later release reshapes into free functions, keeping story bodies stable. """ from __future__ import annotations @@ -29,7 +28,7 @@ def argv_after(flag: str, *, default: str | None = None) -> str: - """Return the argv token following ``flag``, or ``default`` when the flag is absent.""" + """Return the argv token following `flag`, or `default` when the flag is absent.""" try: return sys.argv[sys.argv.index(flag) + 1] except ValueError: @@ -48,10 +47,9 @@ def asgi_from(server: AnyServer, *, path: str = "/mcp") -> Starlette: def run_server_from_args(build_server: ServerFactory) -> None: - """Entry point for ``if __name__ == "__main__"`` in every ``server*.py``. + """Entry point for `if __name__ == "__main__"` in every `server*.py`. - Bare argv serves over stdio; ``--http --port N [--path /mcp]`` serves over - uvicorn on 127.0.0.1:N. + Bare argv serves over stdio; `--http --port N [--path /mcp]` serves over uvicorn on 127.0.0.1:N. """ server = build_server() if "--http" in sys.argv: @@ -77,10 +75,9 @@ async def _serve_http(server: AnyServer, port: int, path: str) -> None: def run_app_from_args(build_app: AppFactory) -> None: - """Entry point for ``if __name__ == "__main__"`` in app-exporting ``server*.py``. + """Entry point for app-exporting `server*.py` (HTTP-only, no stdio leg). - App-exporting stories are HTTP-only; ``--port N`` serves the Starlette app over - uvicorn on 127.0.0.1:N (uvicorn drives the app's own lifespan). No stdio leg. + `--port N` serves the Starlette app over uvicorn on 127.0.0.1:N; uvicorn drives the app's lifespan. """ port = int(argv_after("--port", default="8000")) config = uvicorn.Config(build_app(), host="127.0.0.1", port=port, log_level="error") diff --git a/examples/stories/_shared/auth.py b/examples/stories/_shared/auth.py index 3bedcd3ab9..244092b1b4 100644 --- a/examples/stories/_shared/auth.py +++ b/examples/stories/_shared/auth.py @@ -1,7 +1,4 @@ -"""Minimal in-process OAuth pieces for the auth stories. - -A story-shaped subset; ``tests/interaction/auth`` keeps its own (richer) provider. -""" +"""Minimal in-process OAuth pieces for the auth stories; `tests/interaction/auth` keeps its richer provider.""" from __future__ import annotations @@ -30,7 +27,7 @@ class InMemoryTokenStorage: - """A ``TokenStorage`` that keeps tokens and DCR client info on instance attributes.""" + """A `TokenStorage` that keeps tokens and DCR client info on instance attributes.""" tokens: OAuthToken | None = None client_info: OAuthClientInformationFull | None = None @@ -49,7 +46,7 @@ async def set_client_info(self, client_info: OAuthClientInformationFull) -> None class HeadlessOAuth: - """Completes the authorize redirect in-process via the bound ``httpx`` client.""" + """Completes the authorize redirect in-process via the bound `httpx` client.""" def __init__(self) -> None: self.authorize_url: str | None = None @@ -62,7 +59,7 @@ def bind(self, http_client: httpx.AsyncClient) -> None: async def redirect_handler(self, authorization_url: str) -> None: assert self._http is not None self.authorize_url = authorization_url - # ``auth=None`` is load-bearing: re-entering the locked auth flow would deadlock. + # `auth=None` is load-bearing: re-entering the locked auth flow would deadlock. response = await self._http.get(authorization_url, follow_redirects=False, auth=None) assert response.status_code == 302, f"authorize returned {response.status_code}: {response.text}" params = parse_qs(urlsplit(response.headers["location"]).query) @@ -77,8 +74,8 @@ class InMemoryAuthorizationServerProvider( ): """Minimal demo AS: DCR + authorize + auth-code exchange held in instance dicts. - ``authorize`` auto-consents only when ``OAUTH_DEMO_AUTO_CONSENT=1``; otherwise it redirects - with ``error=interaction_required`` so a manual run shows where a real browser would open. + `authorize` auto-consents only when `OAUTH_DEMO_AUTO_CONSENT=1`; otherwise it redirects + with `error=interaction_required` so a manual run shows where a real browser would open. """ def __init__(self) -> None: @@ -158,9 +155,9 @@ async def revoke_token(self, token: AccessToken | RefreshToken) -> None: def auth_settings( *, required_scopes: list[str] | None = None, identity_assertion_enabled: bool = False ) -> AuthSettings: - """``AuthSettings`` for the co-hosted demo AS+RS on the loopback origin, DCR enabled. + """`AuthSettings` for the co-hosted demo AS+RS on the loopback origin, DCR enabled. - ``identity_assertion_enabled`` passes through to the SEP-990 jwt-bearer grant flag. + `identity_assertion_enabled` passes through to the SEP-990 jwt-bearer grant flag. """ scopes = required_scopes or ["mcp"] return AuthSettings( diff --git a/examples/stories/apps/client.py b/examples/stories/apps/client.py index 8a238f469e..b7e1d77ee9 100644 --- a/examples/stories/apps/client.py +++ b/examples/stories/apps/client.py @@ -8,11 +8,9 @@ async def main(target: Target, *, mode: str = "auto") -> None: - # Advertise MCP Apps support so the server returns the UI-enabled result; a - # client that omits this gets the text-only fallback (graceful degradation). + # Advertise MCP Apps support; a client that omits this gets the text-only fallback. async with Client(target, mode=mode, extensions={EXTENSION_ID: {"mimeTypes": [APP_MIME_TYPE]}}) as client: - # The extensions capability map rides `server/discover` (modern only). On a - # legacy connection (today's stdio) it is absent, so assert it only when present. + # The extensions capability map rides `server/discover` (modern only), so it's absent on legacy stdio. if client.server_capabilities.extensions is not None: assert client.server_capabilities.extensions == {EXTENSION_ID: {}}, client.server_capabilities.extensions diff --git a/examples/stories/apps/server.py b/examples/stories/apps/server.py index 74d412e02c..fd330a3c20 100644 --- a/examples/stories/apps/server.py +++ b/examples/stories/apps/server.py @@ -1,10 +1,8 @@ """MCP Apps: a tool bound to a `ui://` resource the host renders as an interactive surface. -`Apps` is an opt-in `Extension` passed to `MCPServer(extensions=[...])`. The -`@apps.tool(resource_uri=...)` decorator stamps `_meta.ui.resourceUri` onto the -tool; `add_html_resource` registers the matching `ui://` HTML resource. The tool -degrades gracefully: `client_supports_apps(ctx)` reports whether the client -negotiated Apps, so it returns text-only output otherwise. +`Apps` is an opt-in extension: `@apps.tool(resource_uri=...)` stamps `_meta.ui.resourceUri` onto the tool, +`add_html_resource` registers the matching `ui://` HTML, and `client_supports_apps(ctx)` enables a +text-only fallback for clients that didn't negotiate Apps. """ from mcp.server.apps import Apps, client_supports_apps diff --git a/examples/stories/bearer_auth/client.py b/examples/stories/bearer_auth/client.py index 5c419a0716..8fc07e103e 100644 --- a/examples/stories/bearer_auth/client.py +++ b/examples/stories/bearer_auth/client.py @@ -1,4 +1,4 @@ -"""Call the bearer-gated server through an already-authed (``build_auth``, HTTP-only) transport; assert ``whoami``.""" +"""Call the bearer-gated server through a pre-authed transport (`build_auth`, HTTP-only) and check `whoami`.""" from collections.abc import Generator @@ -11,7 +11,7 @@ class StaticBearerAuth(httpx.Auth): - """``httpx.Auth`` that attaches a fixed ``Authorization: Bearer `` to every request.""" + """Attach a fixed `Authorization: Bearer ` header to every request.""" def __init__(self, token: str) -> None: self.token = token @@ -22,10 +22,9 @@ def auth_flow(self, request: httpx.Request) -> Generator[httpx.Request, httpx.Re def build_auth(_http: httpx.AsyncClient) -> httpx.Auth: - """The demo bearer token as an ``httpx.Auth``. + """The demo bearer token as an `httpx.Auth`. - ``Client(url, auth=...)`` doesn't exist yet, so the harness threads this onto the underlying - ``httpx.AsyncClient`` and the target ``main`` receives is already routed through it. + `Client(url, auth=...)` doesn't exist yet, so the harness threads this onto the underlying `httpx.AsyncClient`. """ return StaticBearerAuth(DEMO_TOKEN) diff --git a/examples/stories/bearer_auth/server.py b/examples/stories/bearer_auth/server.py index 45c9872c3a..64b1d8f97c 100644 --- a/examples/stories/bearer_auth/server.py +++ b/examples/stories/bearer_auth/server.py @@ -1,4 +1,4 @@ -"""Resource-server-only bearer auth: ``TokenVerifier``/``AuthSettings`` β†’ 401/PRM/principal. Exports ``build_app()``.""" +"""Resource-server-only bearer auth: `TokenVerifier`/`AuthSettings` β†’ 401/PRM/principal. Exports `build_app()`.""" import time diff --git a/examples/stories/bearer_auth/server_lowlevel.py b/examples/stories/bearer_auth/server_lowlevel.py index f5abfc08c4..7658d8b18d 100644 --- a/examples/stories/bearer_auth/server_lowlevel.py +++ b/examples/stories/bearer_auth/server_lowlevel.py @@ -1,4 +1,4 @@ -"""Resource-server-only bearer auth (lowlevel API): same gate, hand-built ``CallToolResult``.""" +"""Resource-server-only bearer auth (lowlevel API): same gate, hand-built `CallToolResult`.""" from typing import Any diff --git a/examples/stories/custom_methods/client.py b/examples/stories/custom_methods/client.py index 4003885fa4..76f5eb12f1 100644 --- a/examples/stories/custom_methods/client.py +++ b/examples/stories/custom_methods/client.py @@ -24,12 +24,9 @@ class SearchResult(types.Result): async def main(target: Target, *, mode: str = "auto") -> None: async with Client(target, mode=mode) as client: - # `Client` only exposes spec-defined verbs, so vendor methods have to drop one - # layer to `client.session` today β€” there is no `Client`-level API for them - # yet, and whether `.session` stays public is undecided. `send_request` is - # typed against the closed `ClientRequest` union, hence the cast; at runtime - # the body only calls `.model_dump()` and the unknown method skips the - # per-spec result-validation registry. + # `Client` only exposes spec-defined verbs, so vendor methods drop to `client.session` today. + # `send_request` is typed against the closed `ClientRequest` union, hence the cast; at runtime it + # only calls `.model_dump()`, and an unknown method skips the per-spec result-validation registry. request = SearchRequest(params=SearchParams(query="mcp", limit=3)) result = await client.session.send_request(cast("types.ClientRequest", request), SearchResult) assert result.items == ["mcp-0", "mcp-1", "mcp-2"], result diff --git a/examples/stories/dual_era/client.py b/examples/stories/dual_era/client.py index ba9acf5d99..c306d634cb 100644 --- a/examples/stories/dual_era/client.py +++ b/examples/stories/dual_era/client.py @@ -8,9 +8,8 @@ async def main(targets: TargetFactory, *, mode: str = "auto") -> None: - # ── modern arm: the caller's mode (the real-user "auto" default) probes - # ``server/discover`` and adopts the result β€” no ``initialize`` handshake runs. - # The version/info/capabilities accessors are era-neutral. + # Modern arm: the caller's mode (the real-user "auto" default) probes `server/discover` + # and adopts the result β€” no `initialize` handshake. The accessors below are era-neutral. async with Client(targets(), mode=mode) as modern: assert modern.protocol_version == LATEST_MODERN_VERSION assert modern.server_info.name == "dual-era-example" @@ -24,8 +23,8 @@ async def main(targets: TargetFactory, *, mode: str = "auto") -> None: assert isinstance(first, types.TextContent) assert first.text == f"Hello, 2026 client! (served on the modern era at {LATEST_MODERN_VERSION})" - # ── legacy arm: a fresh connection to the SAME server, pinned to the handshake era. - # The same accessors are populated identically β€” here by ``initialize``. + # Legacy arm: a fresh connection to the SAME server, pinned to the handshake era. + # The same accessors are populated identically β€” here by `initialize`. async with Client(targets(), mode="legacy") as legacy: assert legacy.protocol_version == LATEST_HANDSHAKE_VERSION assert legacy.server_info.name == "dual-era-example" diff --git a/examples/stories/dual_era/server.py b/examples/stories/dual_era/server.py index 3f70ee63c9..4d8e77406b 100644 --- a/examples/stories/dual_era/server.py +++ b/examples/stories/dual_era/server.py @@ -7,8 +7,7 @@ def build_server() -> MCPServer: - # The same factory serves both eras with no configuration. Which era a request is - # on is decided by the entry point / transport, never by the server. + # One factory, both eras: which era a request is on is decided by the entry point / transport, never the server. mcp = MCPServer("dual-era-example", instructions="A small dual-era demo server.") @mcp.tool() diff --git a/examples/stories/dual_era/server_lowlevel.py b/examples/stories/dual_era/server_lowlevel.py index b209135e6d..39834c1a9f 100644 --- a/examples/stories/dual_era/server_lowlevel.py +++ b/examples/stories/dual_era/server_lowlevel.py @@ -36,8 +36,7 @@ async def call_tool(ctx: ServerRequestContext[Any], params: types.CallToolReques text = f"Hello, {params.arguments['name']}! (served on the {era} era at {ctx.protocol_version})" return types.CallToolResult(content=[types.TextContent(text=text)]) - # The same factory serves both eras with no configuration. Which era a request is - # on is decided by the entry point / transport, never by the server. + # One factory, both eras: which era a request is on is decided by the entry point / transport, never the server. return Server( "dual-era-example", instructions="A small dual-era demo server.", diff --git a/examples/stories/error_handling/client.py b/examples/stories/error_handling/client.py index 872ec7fe31..dd36344a28 100644 --- a/examples/stories/error_handling/client.py +++ b/examples/stories/error_handling/client.py @@ -19,8 +19,7 @@ async def main(target: Target, *, mode: str = "auto") -> None: failed = await client.call_tool("divide", {"a": 1, "b": 0}) assert failed.is_error is True, "execution errors ride CallToolResult, not an exception" assert isinstance(failed.content[0], TextContent) - # MCPServer prefixes "Error executing tool divide: ..."; lowlevel returns - # the message verbatim. Assert the substring both produce. + # MCPServer prefixes "Error executing tool divide: ..."; lowlevel is verbatim β€” assert the shared substring. assert "cannot divide by zero" in failed.content[0].text # Protocol error: arrives as a raised MCPError. diff --git a/examples/stories/error_handling/server.py b/examples/stories/error_handling/server.py index e4f3554433..dd4c03ac35 100644 --- a/examples/stories/error_handling/server.py +++ b/examples/stories/error_handling/server.py @@ -15,17 +15,14 @@ def build_server() -> MCPServer: def divide(a: float, b: float) -> float: """Divide a by b. Division by zero is an execution error the LLM should see.""" if b == 0: - # ToolError is caught by the tool wrapper and returned as - # CallToolResult(is_error=True) β€” the LLM reads the message and can - # self-correct. + # ToolError becomes CallToolResult(is_error=True) β€” the LLM reads the message and can self-correct. raise ToolError("cannot divide by zero") return a / b @mcp.tool() def restricted() -> str: """A tool that always rejects the caller at the protocol level.""" - # MCPError escapes the tool wrapper and becomes a JSON-RPC error - # response β€” the *host* sees code/message/data, not the LLM. + # MCPError escapes the tool wrapper as a JSON-RPC error β€” the host sees code/message/data, not the LLM. raise MCPError(code=INVALID_PARAMS, message="this tool is gated", data={"reason": "demo"}) return mcp diff --git a/examples/stories/error_handling/server_lowlevel.py b/examples/stories/error_handling/server_lowlevel.py index 9bb9aef86a..c9f8a9a9cc 100644 --- a/examples/stories/error_handling/server_lowlevel.py +++ b/examples/stories/error_handling/server_lowlevel.py @@ -33,8 +33,7 @@ async def call_tool(ctx: ServerRequestContext[Any], params: types.CallToolReques ) return types.CallToolResult(content=[types.TextContent(text=str(a / b))]) if params.name == "restricted": - # Protocol error: raise MCPError; the dispatcher serialises it as a - # JSON-RPC error response with this code/message/data. + # Protocol error: the dispatcher serialises MCPError as a JSON-RPC error with this code/message/data. raise MCPError(code=types.INVALID_PARAMS, message="this tool is gated", data={"reason": "demo"}) raise MCPError(code=types.INVALID_PARAMS, message=f"Unknown tool: {params.name}") diff --git a/examples/stories/extensions/client.py b/examples/stories/extensions/client.py index d3aacc140f..9442cd0697 100644 --- a/examples/stories/extensions/client.py +++ b/examples/stories/extensions/client.py @@ -29,8 +29,7 @@ async def main(target: Target, *, mode: str = "auto") -> None: # Declare the extension client-side so the server's `require_client_extension` # gate on `com.example/search` passes. async with Client(target, mode=mode, extensions={EXTENSION_ID: {}}) as client: - # The extensions capability map rides `server/discover` (modern only). On a - # legacy connection it is absent, so assert it only when present. + # The extensions capability map rides `server/discover` (modern only); absent on legacy connections. if client.server_capabilities.extensions is not None: assert client.server_capabilities.extensions == {EXTENSION_ID: {"suggest": True}}, ( client.server_capabilities.extensions diff --git a/examples/stories/extensions/server.py b/examples/stories/extensions/server.py index 837c668dc5..b75dfa0efd 100644 --- a/examples/stories/extensions/server.py +++ b/examples/stories/extensions/server.py @@ -1,9 +1,7 @@ """Package a vendor verb and a tool as a reusable, advertised extension (SEP-2133). -`custom_methods/` registers a verb on the lowlevel `Server` by hand; this story -bundles the same idea as an `Extension`: declared contributions, a settings entry -under `ServerCapabilities.extensions`, and a `require_client_extension` gate on -the vendor method. +Unlike the hand-registered verb in `custom_methods/`, an `Extension` bundles contributions, +settings under `ServerCapabilities.extensions`, and a `require_client_extension` gate. """ from collections.abc import Sequence diff --git a/examples/stories/identity_assertion/client.py b/examples/stories/identity_assertion/client.py index bd13909801..50a9974112 100644 --- a/examples/stories/identity_assertion/client.py +++ b/examples/stories/identity_assertion/client.py @@ -15,12 +15,12 @@ async def fetch_id_jag(audience: str, resource: str) -> str: - """Step one, the part the SDK does not do: obtain a fresh ID-JAG from the enterprise IdP. + """Obtain a fresh ID-JAG from the enterprise IdP β€” the one step the SDK does not do. - A real implementation makes an RFC 8693 token-exchange request to the IdP, presenting the - signed-in user's ID token; `audience` (the authorization server's issuer) and `resource` (the - MCP server's identifier) pass straight through into the ID-JAG's `aud` and `resource` claims. - Here the stand-in IdP signs one in-process instead. + A real implementation makes an RFC 8693 token-exchange request with the signed-in user's ID + token; `audience` (the authorization server's issuer) and `resource` (the MCP server's + identifier) become the ID-JAG's `aud` and `resource` claims. The stand-in IdP signs one + in-process instead. """ return issue_id_jag( subject=DEMO_SUBJECT, client_id=DEMO_CLIENT_ID, audience=audience, resource=resource, scope=DEMO_SCOPE @@ -30,11 +30,10 @@ async def fetch_id_jag(audience: str, resource: str) -> str: def build_auth(_http: httpx.AsyncClient) -> httpx.Auth: """An `IdentityAssertionOAuthProvider` for the pre-registered confidential client. - `issuer` is configuration, not discovery: the provider fetches metadata from this issuer's - well-known and never asks the MCP server which authorization server to use. The string must - equal the `issuer` its metadata serves byte for byte (note the trailing slash). + `issuer` is configuration, not discovery β€” metadata comes from its well-known, never from the + MCP server β€” and must match the served `issuer` byte for byte (note the trailing slash). `Client(url, auth=...)` doesn't exist yet, so the harness threads this onto the underlying - `httpx.AsyncClient` and hands `main` a target that is already routed through it. + `httpx.AsyncClient` and hands `main` a target already routed through it. """ return IdentityAssertionOAuthProvider( server_url=MCP_URL, @@ -48,10 +47,9 @@ def build_auth(_http: httpx.AsyncClient) -> httpx.Auth: async def main(target: Target, *, mode: str = "auto") -> None: - # The target is already routed through `build_auth`'s provider. The first request 401s; the - # provider fetches the authorization server's metadata from the configured issuer (never from - # the MCP server), mints a fresh ID-JAG through `fetch_id_jag`, exchanges it at `/token` under - # the jwt-bearer grant, and retries with the bearer. No `/authorize`, no `/register`, no browser. + # The first request 401s; the provider fetches the authorization server's metadata from the + # configured issuer, mints an ID-JAG via `fetch_id_jag`, exchanges it at `/token` under the + # jwt-bearer grant, and retries with the bearer. No `/authorize`, no `/register`, no browser. async with Client(target, mode=mode) as client: listed = await client.list_tools() assert [t.name for t in listed.tools] == ["whoami"] diff --git a/examples/stories/identity_assertion/idp.py b/examples/stories/identity_assertion/idp.py index 5d77c665f1..7c4d01059d 100644 --- a/examples/stories/identity_assertion/idp.py +++ b/examples/stories/identity_assertion/idp.py @@ -1,9 +1,7 @@ """A stand-in enterprise identity provider: it signs the ID-JAGs the demo authorization server trusts. -In production the IdP is a separate service (Okta, Microsoft Entra ID, ...) and the client obtains -the ID-JAG from it with an RFC 8693 token-exchange request, presenting the signed-in user's ID -token. `issue_id_jag` collapses that whole step into one in-process signing call so the story runs -unattended; the README's caveats spell out what a real deployment changes. +A real IdP (Okta, Microsoft Entra ID, ...) issues the ID-JAG via an RFC 8693 token exchange; +`issue_id_jag` collapses that step into one in-process signing call so the story runs unattended. """ import time @@ -12,17 +10,15 @@ import jwt IDP_ISSUER = "https://idp.example.com" -# Demo only: a real IdP signs with its private key and the authorization server verifies the -# signature against the IdP's published JWKS. A shared HMAC secret keeps this story self-contained. +# Demo only: a real IdP signs with its private key, verified against its published JWKS. IDP_SIGNING_KEY = "demo-idp-signing-key" def issue_id_jag(*, subject: str, client_id: str, audience: str, resource: str, scope: str) -> str: """The IdP's short-lived, signed statement that `subject`, via `client_id`, may reach `resource`. - This is where the enterprise enforces policy: an IdP that does not authorize the combination - simply never issues the ID-JAG, and there is nothing for the client to present. The `typ` - header and the claim set are fixed by the Identity Assertion JWT Authorization Grant profile. + Enterprise policy is enforced here: an unauthorized combination never gets an ID-JAG. The `typ` + header and claim set are fixed by the Identity Assertion JWT Authorization Grant profile. """ now = int(time.time()) return jwt.encode( diff --git a/examples/stories/identity_assertion/server.py b/examples/stories/identity_assertion/server.py index 8b0c8f4019..2aea643a97 100644 --- a/examples/stories/identity_assertion/server.py +++ b/examples/stories/identity_assertion/server.py @@ -1,8 +1,7 @@ """SEP-990 authorization server + bearer-gated MCP server on one app; exports `build_app()`. -`identity_assertion_enabled=True` turns the RFC 7523 jwt-bearer grant on, and the provider's -`exchange_identity_assertion` validates the IdP-signed ID-JAG and mints an access token bound to -the user and resource the assertion names. The MCP server half is ordinary bearer auth. +`identity_assertion_enabled=True` turns on the RFC 7523 jwt-bearer grant; the provider's +`exchange_identity_assertion` validates the IdP-signed ID-JAG and mints a bound access token. """ import jwt @@ -22,9 +21,8 @@ DEMO_CLIENT_ID = "finance-agent" DEMO_CLIENT_SECRET = "demo-finance-agent-secret" DEMO_SCOPE = "mcp" -# The exact `issuer` string this authorization server's metadata serves. The client must configure -# the byte-identical string: RFC 8414 issuer comparison is character for character, and the -# settings' `AnyHttpUrl` renders the path-less loopback origin with a trailing slash. +# Clients must configure this byte-identical string: RFC 8414 issuer comparison is character for +# character, and the settings' `AnyHttpUrl` renders the path-less loopback origin with a trailing slash. ISSUER = str(auth_settings().issuer_url) @@ -75,10 +73,9 @@ async def exchange_identity_assertion( if claims["jti"] in self.seen_jtis: raise TokenError("invalid_grant", "the assertion has already been used") self.seen_jtis.add(claims["jti"]) - # Everything on the issued token comes from the validated assertion, the audience - # restriction above all: it binds the token to the ID-JAG's `resource` claim, never to - # the client-controlled `params.resource`. No refresh token is returned either; the IdP - # owns session lifetime by deciding whether to issue the next ID-JAG. + # Everything on the issued token comes from the validated assertion β€” bound to the ID-JAG's + # `resource` claim, never the client-controlled `params.resource`. No refresh token either: + # the IdP owns session lifetime by deciding whether to issue the next ID-JAG. scopes = claims["scope"].split() access = self.mint_access_token( client_id=claims["client_id"], scopes=scopes, resource=claims["resource"], subject=claims["sub"] @@ -88,8 +85,7 @@ async def exchange_identity_assertion( def build_app() -> Starlette: provider = IdentityAssertionAuthorizationServer() - # `auth_server_provider=` alone is enough: MCPServer derives a token verifier from it - # (passing both trips the mutex guard). + # `auth_server_provider=` alone is enough: MCPServer derives a token verifier (passing both trips the mutex guard). mcp = MCPServer( "identity-assertion-example", auth=auth_settings(required_scopes=[DEMO_SCOPE], identity_assertion_enabled=True), diff --git a/examples/stories/json_response/client.py b/examples/stories/json_response/client.py index 08af5ef914..d321a66489 100644 --- a/examples/stories/json_response/client.py +++ b/examples/stories/json_response/client.py @@ -1,9 +1,6 @@ -"""Plain ``Client`` against a JSON-only server: mid-call progress drops. HTTP-only β€” ``main`` also takes ``http``. +"""Plain `Client` against a JSON-only server: mid-call progress drops. HTTP-only β€” `main` also takes `http`. -``RAW_ENVELOPE_BODY`` / ``MODERN_HEADERS`` are the exact wire shape a 2026-era client -sends β€” this is the only story that shows it. ``main`` posts that body by hand and -asserts the response is a single ``application/json`` body with no session id. -""" +`RAW_ENVELOPE_BODY`/`MODERN_HEADERS` are the exact wire shape a 2026-era client sends β€” the only story that shows it.""" import httpx from mcp_types import TextContent @@ -12,11 +9,9 @@ from mcp.client import Client from stories._harness import Target, run_client -# The raw 2026-07-28 POST envelope: per-request `_meta` replaces the initialize handshake. -# The key/header strings are spelled out on purpose β€” this is the raw-wire story. In code -# use the named constants instead: `mcp_types.PROTOCOL_VERSION_META_KEY` / -# `CLIENT_INFO_META_KEY` / `CLIENT_CAPABILITIES_META_KEY` and -# `mcp.shared.inbound.MCP_PROTOCOL_VERSION_HEADER` (`legacy_routing/` shows that form). +# Raw 2026-07-28 POST envelope: per-request `_meta` replaces the initialize handshake. The literal +# key/header strings are deliberate here; real code uses the `*_META_KEY` constants from `mcp_types` +# and `mcp.shared.inbound.MCP_PROTOCOL_VERSION_HEADER` (`legacy_routing/` shows that form). RAW_ENVELOPE_BODY: dict[str, object] = { "jsonrpc": "2.0", "id": 1, diff --git a/examples/stories/json_response/server.py b/examples/stories/json_response/server.py index c09aca78f3..25ebf66c3d 100644 --- a/examples/stories/json_response/server.py +++ b/examples/stories/json_response/server.py @@ -1,8 +1,7 @@ -"""Serve over Streamable HTTP with JSON responses (no SSE stream); HTTP-only, so this exports ``build_app()``. +"""Serve over Streamable HTTP with JSON responses (no SSE stream); HTTP-only, so this exports `build_app()`. -The 2026-07-28 path is stateless and JSON-only by construction today; the -``json_response=True`` flag also forces JSON for the legacy (2025-era) branch on -the same endpoint. Mid-call notifications are dropped. +The 2026-07-28 path is stateless and JSON-only by construction; `json_response=True` also forces +JSON for the legacy (2025-era) branch on the same endpoint. Mid-call notifications are dropped. """ from starlette.applications import Starlette diff --git a/examples/stories/legacy_elicitation/client.py b/examples/stories/legacy_elicitation/client.py index 52bb95e516..53f9fbe355 100644 --- a/examples/stories/legacy_elicitation/client.py +++ b/examples/stories/legacy_elicitation/client.py @@ -8,9 +8,8 @@ async def on_elicit(context: ClientRequestContext, params: types.ElicitRequestParams) -> types.ElicitResult: if isinstance(params, types.ElicitRequestURLParams): - # A real client would ask consent and open params.url in a browser, returning - # `accept` right away; the server's notifications/elicitation/complete arrives - # afterward (once the out-of-band flow finishes) for the client to correlate. + # A real client would ask consent and open params.url in a browser, returning `accept` right away; + # notifications/elicitation/complete arrives once the out-of-band flow finishes, for the client to correlate. assert params.url.startswith("https://example.com/") return types.ElicitResult(action="accept") assert "username" in params.requested_schema["properties"] diff --git a/examples/stories/legacy_routing/client.py b/examples/stories/legacy_routing/client.py index b9b401a2d3..f874ec8548 100644 --- a/examples/stories/legacy_routing/client.py +++ b/examples/stories/legacy_routing/client.py @@ -20,20 +20,18 @@ def _arm(result: types.CallToolResult) -> str: async def main(targets: TargetFactory, *, mode: str = "auto") -> None: - # ── modern arm: the caller's mode (the real-user "auto" default) probes - # ``server/discover`` β†’ the stateless 2026 path. + # Modern arm: the default `auto` mode probes `server/discover` β†’ the stateless 2026 path. async with Client(targets(), mode=mode) as modern: assert modern.protocol_version == LATEST_MODERN_VERSION assert _arm(await modern.call_tool("which_arm", {})) == "modern" - # ── legacy arm: the SAME /mcp endpoint, ``initialize`` handshake β†’ sessionful 2025 path. + # Legacy arm: the SAME /mcp endpoint, `initialize` handshake β†’ sessionful 2025 path. async with Client(targets(), mode="legacy") as legacy: assert legacy.protocol_version == LATEST_HANDSHAKE_VERSION assert _arm(await legacy.call_tool("which_arm", {})) == "legacy" - # ── the exported predicate, shown directly. A 2026 _meta envelope whose - # `Mcp-Protocol-Version`/`Mcp-Method` headers mirror it is modern; a bare - # initialize body is legacy; a header that disagrees is a rejection (NOT legacy). + # The exported predicate: a 2026 _meta envelope with matching `Mcp-Protocol-Version`/`Mcp-Method` + # headers is modern; a bare initialize body is legacy; a header that disagrees is a rejection (NOT legacy). modern_body: dict[str, Any] = { "jsonrpc": "2.0", "id": 1, diff --git a/examples/stories/legacy_routing/server.py b/examples/stories/legacy_routing/server.py index 79cc2afa67..bd28731311 100644 --- a/examples/stories/legacy_routing/server.py +++ b/examples/stories/legacy_routing/server.py @@ -1,4 +1,4 @@ -"""Exported era classifier: the body-primary predicate, the built-in dual-era app, and CORS β€” exports `build_app()`.""" +"""Dual-era routing: a standalone era classifier for custom ingress, plus the built-in dual-era app with CORS.""" from collections.abc import Mapping from typing import Any, Literal @@ -23,11 +23,10 @@ def classify_era( body: Mapping[str, Any], headers: Mapping[str, str] ) -> Literal["modern", "legacy"] | InboundLadderRejection: - """Tri-state era classifier built on the exported `classify_inbound_request` predicate. + """Tri-state era classifier for ingress layers that route the two eras to different backends. - Compose this in your own ASGI/ingress layer when the two eras need different - backends. Only a rung-1 ``INVALID_PARAMS`` rejection (no envelope keys) means - "treat as legacy"; other rejections are malformed-modern and should be refused. + Only a rung-1 `INVALID_PARAMS` rejection (no envelope keys) means legacy; + other rejections are malformed-modern and should be refused. """ verdict = classify_inbound_request(body, headers=headers) if isinstance(verdict, InboundModernRoute): diff --git a/examples/stories/middleware/client.py b/examples/stories/middleware/client.py index 60ebbbc305..e25d90496f 100644 --- a/examples/stories/middleware/client.py +++ b/examples/stories/middleware/client.py @@ -13,13 +13,11 @@ async def main(target: Target, *, mode: str = "auto") -> None: assert not result.is_error assert result.structured_content is not None, result - # Era-neutral: legacy adds initialize + notifications/initialized; modern HTTP - # adds server/discover; modern in-memory adds nothing. Filter to the methods - # this client drove. + # The log also holds era-dependent bookkeeping (legacy: initialize + notifications/initialized; + # modern HTTP: server/discover). Keep only the tools/* methods this client drove. seen = [m for m in result.structured_content["result"] if m.startswith("tools/")] - # The tail ends at tools/call with no :done β€” the handler ran inside the - # middleware frame. Assert the tail (not the whole list) so a re-run against - # a long-lived server, whose log accumulates across clients, still passes. + # No :done after tools/call β€” the handler ran inside the middleware frame. Assert + # only the tail: a long-lived server's log accumulates across clients. assert seen[-3:] == ["tools/list", "tools/list:done", "tools/call"], seen diff --git a/examples/stories/middleware/server.py b/examples/stories/middleware/server.py index 076120dccd..60de3fc70f 100644 --- a/examples/stories/middleware/server.py +++ b/examples/stories/middleware/server.py @@ -1,7 +1,6 @@ -"""Dispatch-layer middleware: `Server.middleware` is the public hook. +"""Dispatch-layer middleware via the `middleware` list on lowlevel `Server`. -A lowlevel-only story: `MCPServer` has no public middleware accessor yet, so the -one supported registration point is the `middleware` list on `lowlevel.Server`. +`MCPServer` has no public middleware accessor yet, so this story is lowlevel-only. """ import json diff --git a/examples/stories/mrtr/client.py b/examples/stories/mrtr/client.py index 5b686c3c9c..dc8c875add 100644 --- a/examples/stories/mrtr/client.py +++ b/examples/stories/mrtr/client.py @@ -7,8 +7,7 @@ async def on_elicit(context: ClientRequestContext, params: types.ElicitRequestParams) -> types.ElicitResult: - # The same callback serves legacy push-style elicitation/create requests AND embedded - # InputRequiredResult.input_requests entries β€” the driver dispatches both here. + # One callback serves both legacy elicitation/create requests and embedded input_requests entries. assert isinstance(params, types.ElicitRequestFormParams) assert "confirm" in params.requested_schema["properties"] return types.ElicitResult(action="accept", content={"confirm": True}) @@ -16,14 +15,13 @@ async def on_elicit(context: ClientRequestContext, params: types.ElicitRequestPa async def main(target: Target, *, mode: str = "auto") -> None: async with Client(target, mode=mode, elicitation_callback=on_elicit) as client: - # ── auto-loop: Client.call_tool dispatches input_requests to on_elicit and retries - # internally; the caller just sees the final CallToolResult. + # Auto-loop: call_tool dispatches input_requests to on_elicit and retries; the caller sees the final result. deployed = await client.call_tool("deploy", {"env": "production"}) assert isinstance(deployed.content[0], types.TextContent) assert deployed.content[0].text == "deployed to production", deployed - # ── manual loop: drop to client.session for the raw InputRequiredResult so the - # request_state can be persisted between rounds (e.g. across a process restart). + # Manual loop: client.session yields the raw InputRequiredResult so request_state + # can be persisted between rounds (e.g. across a process restart). first = await client.session.call_tool("deploy", {"env": "staging"}, allow_input_required=True) assert isinstance(first, types.InputRequiredResult) assert first.input_requests is not None and "confirm" in first.input_requests diff --git a/examples/stories/mrtr/server.py b/examples/stories/mrtr/server.py index d83c2e9835..91ac79a272 100644 --- a/examples/stories/mrtr/server.py +++ b/examples/stories/mrtr/server.py @@ -19,8 +19,7 @@ def build_server() -> MCPServer: async def deploy(env: str, ctx: Context) -> str | InputRequiredResult: responses = ctx.input_responses if responses is None or "confirm" not in responses: - # First round: ask the client to elicit confirmation. request_state is opaque - # to the client; here it carries the step name so the retry can verify the echo. + # First round: request_state is opaque to the client and carries the step name for the retry to verify. ask = ElicitRequest( params=ElicitRequestFormParams(message=f"Deploy to {env}?", requested_schema=CONFIRM_SCHEMA) ) diff --git a/examples/stories/oauth/client.py b/examples/stories/oauth/client.py index c55307f633..c444432853 100644 --- a/examples/stories/oauth/client.py +++ b/examples/stories/oauth/client.py @@ -8,17 +8,14 @@ from mcp.shared.auth import OAuthClientMetadata from stories._harness import TargetFactory, run_client -# MCP_URL pins the resource to :8000. The demo AS's own metadata (issuer, PRM `resource`) -# is built from the same constant on the server side, so the whole story is bound to that -# port β€” run the server on 8000 or both halves of the discovery chain point at the wrong origin. +# The demo AS builds its issuer and PRM `resource` from the same MCP_URL constant, so the story is pinned to :8000. from stories._shared.auth import MCP_URL, REDIRECT_URI, HeadlessOAuth, InMemoryTokenStorage def build_auth(http_client: httpx.AsyncClient) -> httpx.Auth: - """An `OAuthClientProvider` over fresh storage, completing the authorize redirect headlessly. + """Build an `OAuthClientProvider` that completes the authorize redirect headlessly. - `Client(url, auth=...)` doesn't exist yet, so the harness threads this onto the underlying - `httpx.AsyncClient` and every target `main` receives is already routed through it. + `Client(url, auth=...)` doesn't exist yet; the harness threads this onto the underlying `httpx.AsyncClient`. """ headless = HeadlessOAuth() headless.bind(http_client) @@ -36,21 +33,17 @@ def build_auth(http_client: httpx.AsyncClient) -> httpx.Auth: async def main(targets: TargetFactory, *, mode: str = "auto") -> None: - # The target is already authed with build_auth's OAuthClientProvider. The first request to - # hit the wire 401s, and the provider walks PRM discovery β†’ AS metadata β†’ DCR β†’ PKCE - # authorize β†’ token exchange β†’ bearer retry before any result reaches this body. No - # UnauthorizedError ever surfaces. + # The first request 401s and the provider transparently walks PRM discovery β†’ AS metadata β†’ + # DCR β†’ PKCE authorize β†’ token exchange β†’ bearer retry; no UnauthorizedError surfaces here. async with Client(targets(), mode=mode) as client: first = await client.call_tool("whoami", {}) assert first.structured_content is not None assert "mcp" in first.structured_content["scopes"], first registered_id = first.structured_content["client_id"] - # A Client cannot be re-entered after __aexit__; reconnecting means constructing a new one. - # The provider's TokenStorage persisted both the issued tokens and the DCR registration, so - # this connection sends `Authorization: Bearer ...` on its very first request β€” no second - # /authorize, no second /register. The demo AS mints a fresh client_id per DCR call, so the - # same principal coming back IS the reuse proof. + # A Client can't be re-entered after `__aexit__`; reconnecting means a new one. TokenStorage kept + # the tokens and DCR registration, so this connection sends a bearer token on its first request β€” + # and since the demo AS mints a fresh client_id per DCR call, a matching client_id proves reuse. async with Client(targets(), mode=mode) as reconnected: again = await reconnected.call_tool("whoami", {}) assert again.structured_content is not None diff --git a/examples/stories/oauth/server.py b/examples/stories/oauth/server.py index 6d4c706b00..2387dfb014 100644 --- a/examples/stories/oauth/server.py +++ b/examples/stories/oauth/server.py @@ -15,12 +15,10 @@ class Principal(BaseModel): def build_app() -> Starlette: - # The provider is both the Authorization Server (DCR/authorize/token) and the - # token store the bearer middleware validates against β€” one in-memory dict. + # The provider is both the Authorization Server and the token store the bearer middleware validates against. provider = InMemoryAuthorizationServerProvider() - # ``auth_server_provider=`` alone is enough β€” MCPServer derives a token verifier - # from it (passing both trips the mutex guard). + # `auth_server_provider=` alone is enough β€” MCPServer derives a token verifier from it (passing both is an error). mcp = MCPServer( "oauth-example", auth=auth_settings(required_scopes=["mcp"]), diff --git a/examples/stories/oauth/server_lowlevel.py b/examples/stories/oauth/server_lowlevel.py index 0bc7799c1e..14e57636ae 100644 --- a/examples/stories/oauth/server_lowlevel.py +++ b/examples/stories/oauth/server_lowlevel.py @@ -44,8 +44,7 @@ async def call_tool(ctx: ServerRequestContext[Any], params: types.CallToolReques return types.CallToolResult(content=[types.TextContent(text=token.client_id)], structured_content=payload) server = Server("oauth-example", on_list_tools=list_tools, on_call_tool=call_tool) - # Unlike MCPServer (auth on the constructor), lowlevel.Server takes auth as - # streamable_http_app() kwargs β€” same wired routes, different entry point. + # Unlike MCPServer (auth on the constructor), lowlevel.Server takes auth as streamable_http_app() kwargs. return server.streamable_http_app( auth=auth_settings(required_scopes=["mcp"]), token_verifier=ProviderTokenVerifier(provider), diff --git a/examples/stories/oauth_client_credentials/client.py b/examples/stories/oauth_client_credentials/client.py index 318523ee70..7c04308e2b 100644 --- a/examples/stories/oauth_client_credentials/client.py +++ b/examples/stories/oauth_client_credentials/client.py @@ -1,4 +1,4 @@ -"""HTTP-only: ``build_auth`` returns a ``ClientCredentialsOAuthProvider``; ``whoami`` round-trips client_id + scopes.""" +"""HTTP-only: `build_auth` returns a `ClientCredentialsOAuthProvider`; `whoami` round-trips client_id + scopes.""" import httpx @@ -6,20 +6,17 @@ from mcp.client.auth.extensions.client_credentials import ClientCredentialsOAuthProvider from stories._harness import Target, run_client -# MCP_URL pins the resource to :8000, and the server side builds its PRM/AS metadata from -# the same constant β€” run the server on 8000 or the discovery chain points at the wrong origin. +# The server builds PRM/AS metadata from this same MCP_URL β€” run it on :8000 or discovery points at the wrong origin. from stories._shared.auth import MCP_URL, InMemoryTokenStorage from .server import DEMO_CLIENT_ID, DEMO_CLIENT_SECRET, DEMO_SCOPE def build_auth(_http: httpx.AsyncClient) -> httpx.Auth: - """The ``httpx.Auth`` for the ``client_credentials`` grant β€” five lines of provider config. + """Build the `httpx.Auth` for the `client_credentials` grant. - The SDK then handles 401 β†’ RFC 9728 PRM β†’ RFC 8414 AS-metadata discovery β†’ token POST β†’ - Bearer attachment automatically. ``Client(url)`` has no ``auth=`` passthrough yet, so the - harness threads this onto the transport's ``httpx.AsyncClient`` and hands ``main`` the - already-authed ``target``. + The SDK drives 401 β†’ RFC 9728 PRM β†’ RFC 8414 AS metadata β†’ token POST β†’ Bearer. `Client(url)` has + no `auth=` passthrough yet, so the harness threads this onto the transport's `httpx.AsyncClient`. """ return ClientCredentialsOAuthProvider( server_url=MCP_URL, diff --git a/examples/stories/oauth_client_credentials/server.py b/examples/stories/oauth_client_credentials/server.py index 7e3d910e8f..0cd993c19d 100644 --- a/examples/stories/oauth_client_credentials/server.py +++ b/examples/stories/oauth_client_credentials/server.py @@ -1,4 +1,4 @@ -"""Bearer-gated resource server + a minimal in-process ``client_credentials`` AS, one app; exports ``build_app()``.""" +"""Bearer-gated resource server + a minimal in-process `client_credentials` AS, one app; exports `build_app()`.""" import base64 import secrets diff --git a/examples/stories/oauth_client_credentials/server_lowlevel.py b/examples/stories/oauth_client_credentials/server_lowlevel.py index ba2003dedf..71a2c25325 100644 --- a/examples/stories/oauth_client_credentials/server_lowlevel.py +++ b/examples/stories/oauth_client_credentials/server_lowlevel.py @@ -1,4 +1,4 @@ -"""Bearer-gated MCP resource server (lowlevel API) + the same minimal ``client_credentials`` AS.""" +"""Bearer-gated MCP resource server (lowlevel API) + the same minimal `client_credentials` AS.""" import base64 import json diff --git a/examples/stories/parallel_calls/client.py b/examples/stories/parallel_calls/client.py index 945e5410a6..90db824fc4 100644 --- a/examples/stories/parallel_calls/client.py +++ b/examples/stories/parallel_calls/client.py @@ -16,16 +16,14 @@ async def attend(tag: str) -> None: async def on_progress(progress: float, total: float | None, message: str | None) -> None: received[tag].append(message) - # targets() yields a fresh connection target on every call; both land on the SAME - # server instance, so the two `meet` handlers can observe each other's arrival. + # Each targets() call is a fresh connection to the SAME server, so the two `meet` handlers can rendezvous. async with Client(targets(), mode=mode) as client: result = await client.call_tool("meet", {"tag": tag, "party": party}, progress_callback=on_progress) assert not result.is_error, result assert isinstance(result.content[0], TextContent) results[tag] = result.content[0].text - # Neither call can return until both handlers are running at once; a server that processed - # requests one-at-a-time would never set the second event and we'd time out here. + # Neither call returns until both handlers run concurrently β€” a serial server would never set the second event. with anyio.fail_after(5): async with anyio.create_task_group() as tg: tg.start_soon(attend, "a") diff --git a/examples/stories/parallel_calls/server.py b/examples/stories/parallel_calls/server.py index dc6d805e4a..d13f2fbb6e 100644 --- a/examples/stories/parallel_calls/server.py +++ b/examples/stories/parallel_calls/server.py @@ -10,9 +10,8 @@ def build_server() -> MCPServer: mcp = MCPServer("parallel-calls-example") - # One Event per tag, shared across every call to this server instance. A handler sets its - # own tag's event, then waits for every peer's β€” so no call can return until all named - # peers are concurrently in-flight. A sequential dispatcher would deadlock here. + # One Event per tag, shared across calls. A handler sets its own tag's event, then waits for every + # peer's β€” no call returns until all peers are concurrently in-flight; sequential dispatch would deadlock. arrivals: dict[str, anyio.Event] = defaultdict(anyio.Event) @mcp.tool() diff --git a/examples/stories/reconnect/client.py b/examples/stories/reconnect/client.py index aab2312dc9..8e7dc57a52 100644 --- a/examples/stories/reconnect/client.py +++ b/examples/stories/reconnect/client.py @@ -8,9 +8,7 @@ async def main(targets: TargetFactory, *, mode: str = "auto") -> None: - # The caller's mode (the real-user "auto" default) probes server/discover inside - # __aenter__ and caches the result; a hard version pin would skip the probe and - # never see the server's real DiscoverResult. + # mode="auto" probes server/discover inside __aenter__ and caches the result; a hard version pin skips the probe. async with Client(targets(), mode=mode) as client: discovered = client.session.discover_result assert discovered is not None, "mode='auto' against a modern server populates discover_result" @@ -26,10 +24,8 @@ async def main(targets: TargetFactory, *, mode: str = "auto") -> None: rehydrated = DiscoverResult.model_validate_json(saved) assert rehydrated == discovered - # Reconnect: a version pin plus the cached DiscoverResult adopts the prior state with - # zero round-trips on entry. A Client cannot be re-entered after exit, so targets() - # yields a fresh one. Without prior_discover= a bare pin would synthesize a blank - # server_info β€” the cache is what makes the era-neutral accessors useful here. + # Reconnect: a version pin plus prior_discover= adopts the prior state with zero round-trips; a bare pin + # would synthesize a blank server_info. A Client cannot be re-entered after exit, so targets() yields a fresh one. async with Client(targets(), mode=LATEST_MODERN_VERSION, prior_discover=rehydrated) as second: assert second.protocol_version == LATEST_MODERN_VERSION assert second.server_info.name == "reconnect-example" diff --git a/examples/stories/refund_desk/client.py b/examples/stories/refund_desk/client.py index 0ff8d28fca..dce2a3f69e 100644 --- a/examples/stories/refund_desk/client.py +++ b/examples/stories/refund_desk/client.py @@ -7,7 +7,6 @@ async def main(target: Target, *, mode: str = "auto") -> None: - # Scripted answers + per-topic counters; topics in `declines` are refused. counts = {"scope": 0, "restock": 0} answers: dict[str, dict[str, str | int | float | bool | list[str] | None]] = { "scope": {"full": True}, @@ -40,10 +39,8 @@ async def on_elicit(context: ClientRequestContext, params: types.ElicitRequestPa }, receipt.structured_content assert counts == {"scope": 0, "restock": 0}, counts - # Full refund of a three-line order. The scope question fires exactly ONCE even though - # both refund_amount and ask_restock consume it β€” asked at most once per call on either - # era. ask_restock needs the scope ANSWER, so at 2026 the two questions land in - # successive rounds, never one concurrent batch: counts and order are era-independent. + # Scope fires exactly ONCE per call even though refund_amount and ask_restock both consume it. + # ask_restock needs scope's ANSWER, so at 2026 the two land in successive rounds β€” era-independent. receipt = await client.call_tool("refund_order", {"order_id": "ORD-7002", "reason": "arrived broken"}) assert receipt.structured_content == { "order_id": "ORD-7002", @@ -53,9 +50,8 @@ async def on_elicit(context: ClientRequestContext, params: types.ElicitRequestPa }, receipt.structured_content assert counts == {"scope": 1, "restock": 1}, counts - # Declining restock still refunds: the tool keeps the ElicitationResult union for - # `restock`, sees the decline, and just skips the restock. The scope counter moves - # again β€” questions are deduped per call, not per connection. + # Declining restock still refunds: the tool takes `restock` as an ElicitationResult union and + # skips the restock on decline. Scope is asked again β€” deduped per call, not per connection. declines.add("restock") answers["scope"] = {"full": False, "sku": "canvas-tote"} receipt = await client.call_tool("refund_order", {"order_id": "ORD-7002", "reason": "wrong colour"}) @@ -68,17 +64,15 @@ async def on_elicit(context: ClientRequestContext, params: types.ElicitRequestPa assert counts == {"scope": 2, "restock": 2}, counts declines.clear() - # An elicited SKU is human-typed: the server validates it against the order before - # any money is computed. + # An elicited SKU is human-typed, so the server validates it against the order before computing money. answers["scope"] = {"full": False, "sku": "mystery-hat"} result = await client.call_tool("refund_order", {"order_id": "ORD-7002", "reason": "lost parcel"}) assert result.is_error, result assert isinstance(result.content[0], types.TextContent) assert "order has no item 'mystery-hat'" in result.content[0].text, result.content[0].text - # Declining scope aborts the whole call: refund_amount and ask_restock both consume scope - # unwrapped, so whichever resolves first (`cents`, in signature order) aborts, and - # ask_restock never runs under any order. + # Declining scope aborts the whole call: both resolvers consume scope unwrapped, so whichever + # resolves first aborts and ask_restock never runs. declines.add("scope") restock_before = counts["restock"] result = await client.call_tool("refund_order", {"order_id": "ORD-7002", "reason": "changed mind"}) @@ -96,8 +90,7 @@ async def on_elicit(context: ClientRequestContext, params: types.ElicitRequestPa assert isinstance(result.content[0], types.TextContent) assert "unknown order 'ORD-9999'" in result.content[0].text, result.content[0].text - # Full elicitation trajectory: scope fired in legs 2-5 (memoized within each call), - # restock only in the two calls that reached it. + # Final tally: scope fired in legs 2-5, restock only in the two calls that reached it. assert counts == {"scope": 4, "restock": 2}, counts diff --git a/examples/stories/refund_desk/server.py b/examples/stories/refund_desk/server.py index f29a266f0b..b0ef9ee023 100644 --- a/examples/stories/refund_desk/server.py +++ b/examples/stories/refund_desk/server.py @@ -112,9 +112,8 @@ def refund_order( cents: Annotated[int, Resolve(refund_amount)], restock: Annotated[ElicitationResult[RestockChoice], Resolve(ask_restock)], ) -> Receipt: - # `restock` keeps the full elicitation outcome: a declined restock still refunds. A plain - # (non-Elicit) resolver return arrives wrapped as an accepted outcome, so the fast path - # lands in the same `AcceptedElicitation` branch. + # `restock` keeps the full elicitation outcome: a declined restock still refunds. A non-Elicit + # resolver return arrives wrapped as accepted, so the fast path hits the same `AcceptedElicitation` branch. restocked = isinstance(restock, AcceptedElicitation) and restock.data.restock return Receipt(order_id=order_id, refunded_cents=cents, restocked=restocked, reason=reason) diff --git a/examples/stories/sampling/client.py b/examples/stories/sampling/client.py index 0ca88db996..00268ce4aa 100644 --- a/examples/stories/sampling/client.py +++ b/examples/stories/sampling/client.py @@ -7,8 +7,7 @@ async def on_sample(context: ClientRequestContext, params: CreateMessageRequestParams) -> CreateMessageResult: - # A real host would call its LLM provider here; the example returns a deterministic - # canned answer so the round-trip is assertable. + # A real host would call its LLM provider here; a canned answer keeps the round-trip assertable. return CreateMessageResult( role="assistant", content=TextContent(text="[canned summary]"), diff --git a/examples/stories/schema_validators/server.py b/examples/stories/schema_validators/server.py index 8648e211df..306fb3441d 100644 --- a/examples/stories/schema_validators/server.py +++ b/examples/stories/schema_validators/server.py @@ -5,8 +5,7 @@ from pydantic import BaseModel -# pydantic requires typing_extensions.TypedDict (not typing.TypedDict) on Python < 3.12 -# when a TypedDict is used as a field/parameter type. +# pydantic requires typing_extensions.TypedDict (not typing.TypedDict) for parameter types on Python < 3.12. from typing_extensions import TypedDict from mcp.server.mcpserver import MCPServer diff --git a/examples/stories/schema_validators/server_lowlevel.py b/examples/stories/schema_validators/server_lowlevel.py index 02dca8d162..93c7b3ed16 100644 --- a/examples/stories/schema_validators/server_lowlevel.py +++ b/examples/stories/schema_validators/server_lowlevel.py @@ -8,8 +8,7 @@ from mcp.server.lowlevel import Server from stories._hosting import run_server_from_args -# With lowlevel.Server there is no reflection layer: you author the JSON Schema -# yourself and validate/unpack `params.arguments` in the handler. +# lowlevel.Server has no reflection layer: author the JSON Schema and validate `params.arguments` yourself. PERSON_SCHEMA: dict[str, Any] = { "type": "object", "properties": {"name": {"type": "string"}, "title": {"type": "string"}}, diff --git a/examples/stories/serve_one/client.py b/examples/stories/serve_one/client.py index 73bd457e10..28553a17b3 100644 --- a/examples/stories/serve_one/client.py +++ b/examples/stories/serve_one/client.py @@ -9,8 +9,7 @@ async def main(target: Target, *, mode: str = "auto") -> None: - # ── direct: the namesake recipe β€” Connection.from_envelope + serve_one β†’ raw result dict. - # The entry enters lifespan once and threads it to every per-request handle_one(). + # Direct: Connection.from_envelope + serve_one β†’ raw result dict; lifespan entered once, threaded to handle_one(). server = build_server() params = { "name": "add", @@ -26,7 +25,7 @@ async def main(target: Target, *, mode: str = "auto") -> None: assert raw["structuredContent"] == {"result": 5}, raw assert raw["content"][0] == {"type": "text", "text": "5"}, raw - # ── over the wire: the loop-mode driver behind the connected client. + # Over the wire: the loop-mode driver behind the connected client. async with Client(target, mode=mode) as client: listed = await client.list_tools() assert [t.name for t in listed.tools] == ["add"] diff --git a/examples/stories/serve_one/server.py b/examples/stories/serve_one/server.py index 447e4a82b8..5b108391f3 100644 --- a/examples/stories/serve_one/server.py +++ b/examples/stories/serve_one/server.py @@ -1,11 +1,7 @@ """serve_one / serve_connection mechanics: the kernel drivers a transport entry composes. -`handle_one()` is the modern single-exchange recipe (`Connection.from_envelope` -+ `serve_one` β†’ raw result dict). `main()` is the loop recipe -(`JSONRPCDispatcher` + `Connection.for_loop` + `serve_connection`) β€” what -`Server.run()` does for stdio. Both drivers take a `lowlevel.Server`, so this is -a lowlevel-only story: `MCPServer` has no public accessor for its underlying -`Server` yet. +`handle_one()` is the single-exchange recipe; `main()` is the loop recipe (what `Server.run()` +does for stdio). Lowlevel-only: `MCPServer` has no public accessor for its underlying `Server` yet. """ from collections.abc import Mapping @@ -48,8 +44,7 @@ async def call_tool(ctx: ServerRequestContext[Any], params: types.CallToolReques class SingleExchangeContext: """Minimal `DispatchContext` for one inbound request with no back-channel. - A custom transport entry hand-builds one of these per request. The SDK - ships no public concrete class for this yet; this is the structural minimum. + A custom transport entry hand-builds one per request; the SDK ships no public concrete class yet. """ request_id: int | str | None @@ -73,10 +68,8 @@ async def handle_one( ) -> dict[str, Any]: """Serve exactly one modern-era request and return its raw result dict. - Reads the envelope from `params._meta` (the 2026 wire shape), builds a - born-ready `Connection.from_envelope`, and drives `serve_one`. The transport - entry enters `server.lifespan(server)` once and threads `lifespan_state` to - every call β€” never enter the lifespan per-request. + The envelope rides in `params._meta` (the 2026 wire shape). Enter `server.lifespan(server)` + once and thread `lifespan_state` to every call β€” never enter the lifespan per-request. """ meta = params.get("_meta", {}) connection = Connection.from_envelope( diff --git a/examples/stories/sse_polling/client.py b/examples/stories/sse_polling/client.py index d2f3918952..9efadf2d48 100644 --- a/examples/stories/sse_polling/client.py +++ b/examples/stories/sse_polling/client.py @@ -17,14 +17,12 @@ async def on_progress(progress: float, total: float | None, message: str | None) with anyio.fail_after(10): result = await client.call_tool("long_operation", {}, progress_callback=on_progress) - # The result arrived β€” the client transport survived the server-initiated close, - # reconnected with Last-Event-ID, and received the replayed response. + # The transport survived the server's close: it reconnected with Last-Event-ID and got the replayed response. assert not result.is_error, result assert isinstance(result.content[0], TextContent) assert result.content[0].text == "resumed" - # "after-close" was emitted while no SSE stream was open; receiving it proves the - # event store buffered it and the reconnect replayed it. + # "after-close" was emitted while no SSE stream was open β€” proof the event store buffered it for the replay. assert messages == ["before-close", "after-close"], messages diff --git a/examples/stories/sse_polling/event_store.py b/examples/stories/sse_polling/event_store.py index 95d2b8accf..0e36c13409 100644 --- a/examples/stories/sse_polling/event_store.py +++ b/examples/stories/sse_polling/event_store.py @@ -1,7 +1,6 @@ """Minimal in-memory `EventStore` for the SSE-resumability example. -Sequential integer IDs so the wire is readable; a production server would back -this interface with persistent storage so replay survives a process restart. +Sequential integer IDs keep the wire readable; production would use persistent storage so replay survives restarts. """ from mcp_types import JSONRPCMessage diff --git a/examples/stories/sse_polling/server.py b/examples/stories/sse_polling/server.py index 1098ca6d56..5c9a3b14d4 100644 --- a/examples/stories/sse_polling/server.py +++ b/examples/stories/sse_polling/server.py @@ -12,11 +12,7 @@ def build_app() -> Starlette: @mcp.tool() async def long_operation(ctx: Context) -> str: - """Emit progress, close this call's SSE stream, emit more progress, then return. - - Everything sent after `close_sse_stream()` lands in the event store and is - replayed when the client reconnects with `Last-Event-ID`. - """ + """Emit progress, close this call's SSE stream, emit more progress, then return.""" await ctx.report_progress(0.5, total=1.0, message="before-close") await ctx.close_sse_stream() await ctx.report_progress(1.0, total=1.0, message="after-close") diff --git a/examples/stories/sse_polling/server_lowlevel.py b/examples/stories/sse_polling/server_lowlevel.py index fcf3199861..473b1a1b0b 100644 --- a/examples/stories/sse_polling/server_lowlevel.py +++ b/examples/stories/sse_polling/server_lowlevel.py @@ -26,8 +26,7 @@ async def list_tools( async def call_tool(ctx: ServerRequestContext[Any], params: types.CallToolRequestParams) -> types.CallToolResult: assert params.name == "long_operation" await ctx.session.report_progress(0.5, total=1.0, message="before-close") - # The transport only wires this callback when an event_store is configured and the - # negotiated version is in the 2025 era; it is None otherwise. + # Only wired when an event_store is configured and the negotiated version is in the 2025 era; None otherwise. if ctx.close_sse_stream is not None: await ctx.close_sse_stream() await ctx.session.report_progress(1.0, total=1.0, message="after-close") diff --git a/examples/stories/standalone_get/client.py b/examples/stories/standalone_get/client.py index aaf870f0e7..bdec9f738b 100644 --- a/examples/stories/standalone_get/client.py +++ b/examples/stories/standalone_get/client.py @@ -8,8 +8,7 @@ async def main(target: Target, *, mode: str = "auto") -> None: - # `message_handler` is constructor-only on `Client`, so the event it sets - # has to exist before the connection does. + # `message_handler` is constructor-only on `Client`, so the event it sets must exist before the connection does. received: list[types.ResourceListChangedNotification] = [] seen = anyio.Event() @@ -25,8 +24,7 @@ async def on_message(message: object) -> None: result = await client.call_tool("add_note", {"content": "hello"}) assert not result.is_error, result - # The notification rides the standalone GET stream, not the call's POST stream β€” - # delivery order vs the tool result is not guaranteed, so wait. + # The notification arrives on the standalone GET stream, so its order vs the tool result is not guaranteed. with anyio.fail_after(5): await seen.wait() assert len(received) == 1, received diff --git a/examples/stories/standalone_get/server.py b/examples/stories/standalone_get/server.py index 4b0c956841..b85912b0da 100644 --- a/examples/stories/standalone_get/server.py +++ b/examples/stories/standalone_get/server.py @@ -18,8 +18,7 @@ async def add_note(content: str, ctx: Context) -> str: """Register a new resource and announce it via `notifications/resources/list_changed`.""" name = f"note-{next(counter)}" mcp.add_resource(TextResource(uri=f"note://{name}", name=name, text=content)) - # MCPServer does not auto-emit on add_resource; send explicitly. With no - # related_request_id this routes to the standalone GET stream. + # Not auto-emitted on add_resource; with no related_request_id this routes to the standalone GET stream. await ctx.session.send_resource_list_changed() return f"registered {name}" diff --git a/examples/stories/starlette_mount/server.py b/examples/stories/starlette_mount/server.py index 858abc9203..bca49d445d 100644 --- a/examples/stories/starlette_mount/server.py +++ b/examples/stories/starlette_mount/server.py @@ -20,10 +20,8 @@ def greet(name: str) -> str: """Return a greeting.""" return f"Hello, {name}! (served from a Starlette sub-mount)" - # streamable_http_path="/" so Mount("/api", ...) serves the MCP endpoint at - # /api itself, not /api/mcp. The returned sub-app has its own lifespan, but - # Starlette does not run nested lifespans under Mount β€” the parent app below - # must enter mcp.session_manager.run() itself. + # streamable_http_path="/" puts the MCP endpoint at /api itself, not /api/mcp. Starlette does not run + # nested lifespans under Mount, so the parent app's lifespan below must enter mcp.session_manager.run(). mcp_app = mcp.streamable_http_app(streamable_http_path="/", transport_security=NO_DNS_REBIND) async def health(_request: Request) -> JSONResponse: diff --git a/examples/stories/stateless_legacy/client.py b/examples/stories/stateless_legacy/client.py index d21ff850cf..6cfc000c88 100644 --- a/examples/stories/stateless_legacy/client.py +++ b/examples/stories/stateless_legacy/client.py @@ -8,8 +8,7 @@ async def main(targets: TargetFactory, *, mode: str = "auto") -> None: - # ── modern era: the caller's mode (the real-user "auto" default) routes this connection - # through the 2026 envelope path. No initialize handshake, no session id. + # Modern era: the caller's mode (default "auto") takes the 2026 envelope path β€” no handshake, no session id. async with Client(targets(), mode=mode) as client: assert client.protocol_version == LATEST_MODERN_VERSION @@ -21,9 +20,8 @@ async def main(targets: TargetFactory, *, mode: str = "auto") -> None: assert isinstance(result.content[0], TextContent) assert result.content[0].text == "Hello, world!", result - # ── legacy era: a fresh mode="legacy" client runs the initialize handshake against the - # SAME stateless app. It is answered statelessly (no Mcp-Session-Id) and the same tool - # gives the same answer β€” the era is invisible to the server body. + # Legacy era: a fresh mode="legacy" client runs the initialize handshake against the SAME app, + # answered statelessly (no Mcp-Session-Id); the era is invisible to the server body. async with Client(targets(), mode="legacy") as legacy: assert legacy.protocol_version == LATEST_HANDSHAKE_VERSION diff --git a/examples/stories/stickynotes/client.py b/examples/stories/stickynotes/client.py index 56ca10f551..3918e63286 100644 --- a/examples/stories/stickynotes/client.py +++ b/examples/stories/stickynotes/client.py @@ -25,7 +25,6 @@ async def on_message(message: object) -> None: async with Client(target, mode=mode, elicitation_callback=on_elicit, message_handler=on_message) as client: legacy = client.protocol_version in HANDSHAKE_PROTOCOL_VERSIONS - # Add two notes. first = await client.call_tool("add_note", {"text": "Buy milk"}) assert first.structured_content is not None first_id, first_uri = first.structured_content["id"], first.structured_content["uri"] @@ -35,7 +34,6 @@ async def on_message(message: object) -> None: second_id, second_uri = second.structured_content["id"], second.structured_content["uri"] assert first_id != second_id - # List + read β€” both notes appear as resources; first reads back its text. listed = await client.list_resources() uris = {str(r.uri) for r in listed.resources} assert first_uri in uris and second_uri in uris, uris @@ -48,7 +46,6 @@ async def on_message(message: object) -> None: with anyio.fail_after(5): await list_changed.wait() - # Remove one. removed = await client.call_tool("remove_note", {"note_id": first_id}) assert removed.structured_content == {"result": True} after = await client.list_resources() diff --git a/examples/stories/stickynotes/server.py b/examples/stories/stickynotes/server.py index 4c6c9d0a7e..41aa6b1f37 100644 --- a/examples/stories/stickynotes/server.py +++ b/examples/stories/stickynotes/server.py @@ -44,11 +44,9 @@ async def lifespan(_: MCPServer) -> AsyncIterator[Board]: mcp = MCPServer("stickynotes-example", lifespan=lifespan) def unregister_note(note_id: str) -> None: - # DO NOT copy this line into your own server. `MCPServer` has no public - # `remove_resource()` yet (only `add_resource`), so unregistering a runtime-added - # resource has to reach a private attribute. `server_lowlevel.py` shows the clean - # shape: `on_list_resources` rebuilds the list from the board on every call, so - # removal never touches a registry at all. + # Don't copy this: `MCPServer` has no public `remove_resource()` yet, so this reaches a + # private attribute. `server_lowlevel.py` shows the clean shape β€” `on_list_resources` + # rebuilds the list from the board on every call, so removal never touches a registry. mcp._resource_manager._resources.pop(f"note:///{note_id}", None) # pyright: ignore[reportPrivateUsage] @mcp.tool() diff --git a/examples/stories/streaming/client.py b/examples/stories/streaming/client.py index e584b4c1ef..41c92e22e2 100644 --- a/examples/stories/streaming/client.py +++ b/examples/stories/streaming/client.py @@ -8,15 +8,13 @@ async def main(target: Target, *, mode: str = "auto") -> None: - # `logging_callback` is constructor-only on `Client`, so the list it fills - # has to exist before the connection does. + # `logging_callback` is constructor-only on `Client`, so the list it fills must exist first. logs: list[LoggingMessageNotificationParams] = [] async def on_log(params: LoggingMessageNotificationParams) -> None: logs.append(params) async with Client(target, mode=mode, logging_callback=on_log) as client: - # ── progress + logging: a short countdown delivers exactly `steps` of each, in order ── updates: list[tuple[float, float | None, str | None]] = [] async def collect(progress: float, total: float | None, message: str | None) -> None: @@ -31,7 +29,7 @@ async def collect(progress: float, total: float | None, message: str | None) -> ("info", "countdown", "step 3/3"), ] - # ── cancellation: abandon the awaiting scope once the call is provably in flight ── + # Cancel mid-flight by abandoning the awaiting scope once the call is provably started. in_flight = anyio.Event() with anyio.fail_after(5): with anyio.CancelScope() as scope: diff --git a/examples/stories/streaming/server.py b/examples/stories/streaming/server.py index ced59878d7..1fb88ea74b 100644 --- a/examples/stories/streaming/server.py +++ b/examples/stories/streaming/server.py @@ -16,9 +16,8 @@ async def countdown(steps: int, ctx: Context) -> dict[str, int]: try: for i in range(1, steps + 1): await ctx.report_progress(float(i), float(steps), f"step {i}/{steps}") - # No non-deprecated logging helper on Context yet, so send the raw - # notification. `related_request_id` keeps it on this request's response - # stream (matters over streamable HTTP). + # No non-deprecated logging helper on Context yet, so send the raw notification. + # `related_request_id` keeps it on this request's response stream over streamable HTTP. await ctx.request_context.session.send_notification( types.LoggingMessageNotification( params=types.LoggingMessageNotificationParams( @@ -28,8 +27,7 @@ async def countdown(steps: int, ctx: Context) -> dict[str, int]: related_request_id=ctx.request_context.request_id, ) except anyio.get_cancelled_exc_class(): - # The client abandoned the call. Release resources here, then re-raise so - # the dispatcher unwinds the request β€” never swallow cancellation. + # The client abandoned the call: clean up here, then re-raise β€” never swallow cancellation. raise return {"completed": steps, "total": steps} diff --git a/scripts/gen_surface_types.py b/scripts/gen_surface_types.py index f338629095..1d60d7d66f 100644 --- a/scripts/gen_surface_types.py +++ b/scripts/gen_surface_types.py @@ -1,11 +1,8 @@ """Regenerate the per-version wire-shape surface packages from vendored schemas. -Runs `datamodel-code-generator` over each `schema/PINNED.json` entry and -writes the result to `src/mcp-types/mcp_types/v/__init__.py` with only the -fixes the raw output needs: a small JSON pre-patch for the known -`number`-as-`integer` schema.json defect, a header, full URLs for the spec's -site-absolute doc links, and per-version epilogue aliases. Run with -`uv run --frozen --group codegen python scripts/gen_surface_types.py [--check]`. +Runs `datamodel-code-generator` over each `schema/PINNED.json` entry, applies the minimal +fixes the raw output needs, and writes `src/mcp-types/mcp_types/v/__init__.py`. +Run with `uv run --frozen --group codegen python scripts/gen_surface_types.py [--check]`. """ from __future__ import annotations @@ -25,11 +22,10 @@ SCHEMA_DIR = REPO_ROOT / "schema" TYPES_DIR = REPO_ROOT / "src" / "mcp-types" / "mcp_types" -# schema.ts -> schema.json renders TypeScript `number` as JSON Schema -# `integer` at these sites; patch the JSON before codegen so floats validate. -# Patched to `["integer", "number"]` (not bare `"number"`) so codegen emits -# `int | float` and pydantic's smart-union preserves ints on round-trip. -# TODO: drop once modelcontextprotocol/modelcontextprotocol fixes the schema.ts -> schema.json number rendering. +# schema.ts -> schema.json renders TypeScript `number` as `integer` at these sites; patch to +# `["integer", "number"]` (not bare `"number"`) so codegen emits `int | float` and pydantic's +# smart-union preserves ints on round-trip. +# TODO: drop once modelcontextprotocol/modelcontextprotocol fixes the number rendering. SCHEMA_PATCHES: dict[str, list[tuple[str, Any, Any]]] = { "2025-11-25": [ ("$defs/NumberSchema/properties/default/type", "integer", ["integer", "number"]), @@ -41,10 +37,9 @@ ["string", "integer", "boolean"], ["string", "integer", "number", "boolean", "null"], ), - # Older python-sdk releases emit `anyOf` for Optional fields; the callback's - # own schema validation is the real gate, so accept any property shape inbound. - # PrimitiveSchemaDefinition becomes an orphan $def after this patch but - # datamodel-codegen still emits it; elicitation.py imports it as the gate type. + # Older python-sdk releases emit `anyOf` for Optional fields; the callback's own schema + # validation is the real gate, so accept any property shape inbound. PrimitiveSchemaDefinition + # becomes an orphan $def but codegen still emits it; elicitation.py imports it as the gate type. ( "$defs/ElicitRequestFormParams/properties/requestedSchema/properties/properties/additionalProperties", {"$ref": "#/$defs/PrimitiveSchemaDefinition"}, @@ -67,8 +62,7 @@ ["string", "integer", "boolean"], ["string", "integer", "number", "boolean", "null"], ), - # Older python-sdk releases emit `anyOf` for Optional fields; the callback's - # own schema validation is the real gate, so accept any property shape inbound. + # Same rationale as the 2025-11-25 ElicitRequestFormParams patch above. ( "$defs/ElicitRequestFormParams/properties/requestedSchema/properties/properties/additionalProperties", {"$ref": "#/$defs/PrimitiveSchemaDefinition"}, @@ -77,11 +71,9 @@ ], } -# Classes the spec defines as open key-value bags: `_meta` content, the -# JSON-Schema-document fields on `Tool`, and the schemas with explicit -# `additionalProperties: {}`. These keep `extra="allow"` so the sieve preserves -# arbitrary keys; every other class ignores extras. Per-version because codegen -# reuses class names across versions for unrelated schemas (e.g. `Data`). +# Classes the spec defines as open key-value bags (`_meta` content, Tool's JSON-Schema-document +# fields, explicit `additionalProperties: {}`): these keep `extra="allow"`; every other class +# ignores extras. Per-version because codegen reuses class names across versions (e.g. `Data`). OPEN_CLASSES: dict[str, frozenset[str]] = { "2025-11-25": frozenset({"Meta", "InputSchema", "OutputSchema", "Result", "GetTaskPayloadResult", "Data"}), "2026-07-28": frozenset( @@ -169,11 +161,7 @@ def run_codegen(schema_path: Path, output_path: Path) -> None: def allow_open_class_extras(source: str, open_classes: frozenset[str]) -> str: - """Restore `extra="allow"` on `open_classes` only. - - Every other class uses `extra="ignore"` so the surface acts as a sieve; - `open_classes` are the places the spec defines as open key-value bags. - """ + """Restore `extra="allow"` on `open_classes` only; every other class keeps `extra="ignore"` as a sieve.""" def patch(match: re.Match[str]) -> str: if match.group(1) not in open_classes: @@ -209,9 +197,8 @@ def build(entry: dict[str, str]) -> str: # Codegen appends `| None` to forward refs of nullable models, which is a # runtime TypeError on a string ref and redundant since `JSONValue` includes None. source = source.replace('"JSONValue" | None', '"JSONValue"') - # Schema descriptions link to spec-site pages with site-absolute paths; expand - # them to full URLs so they resolve from the rendered API docs and pass the - # strict mkdocs link validation. + # Expand the spec's site-absolute doc links to full URLs so they resolve from the + # rendered API docs and pass strict mkdocs link validation. source = source.replace("](/", "](https://modelcontextprotocol.io/") source = allow_open_class_extras(source, OPEN_CLASSES[version]) if epilogue := EPILOGUES.get(version, ""): diff --git a/scripts/update_readme_snippets.py b/scripts/update_readme_snippets.py index 99e9237a4c..74e09ef4a1 100755 --- a/scripts/update_readme_snippets.py +++ b/scripts/update_readme_snippets.py @@ -1,12 +1,7 @@ #!/usr/bin/env python3 -"""Update README.md with live code snippets from example files. +"""Update README.md code blocks marked with snippet-source comments from the referenced files. -This script finds specially marked code blocks in README.md and updates them -with the actual code from the referenced files. - -Usage: - python scripts/update_readme_snippets.py - python scripts/update_readme_snippets.py --check # Check mode for CI +Usage: python scripts/update_readme_snippets.py [--check] """ import argparse @@ -16,37 +11,21 @@ def get_github_url(file_path: str) -> str: - """Generate a GitHub URL for the file. - - Args: - file_path: Path to the file relative to repo root - - Returns: - GitHub URL - """ + """Return the GitHub URL for a repo-relative file path.""" base_url = "https://github.com/modelcontextprotocol/python-sdk/blob/main" return f"{base_url}/{file_path}" def process_snippet_block(match: re.Match[str], check_mode: bool = False) -> str: - """Process a single snippet-source block. - - Args: - match: The regex match object - check_mode: If True, return original if no changes needed - - Returns: - The updated block content - """ + """Return the regenerated block, or the original in check mode when the code is unchanged.""" full_match = match.group(0) indent = match.group(1) file_path = match.group(2) try: - # Read the entire file. A missing source file must be fatal: a "Warning" - # that returns the stale block lets --check pass with exit 0, so a - # renamed or deleted snippet is invisible to CI. SystemExit deliberately - # escapes the `except Exception` below. + # A missing source file must be fatal: returning the stale block would let --check + # exit 0, hiding renamed/deleted snippets from CI. SystemExit escapes the + # `except Exception` below. file = Path(file_path) if not file.exists(): sys.exit(f"Error: snippet-source file not found: {file_path}") @@ -54,7 +33,6 @@ def process_snippet_block(match: re.Match[str], check_mode: bool = False) -> str code = file.read_text().rstrip() github_url = get_github_url(file_path) - # Build the replacement block indented_code = code.replace("\n", f"\n{indent}") replacement = f"""{indent} {indent}```python @@ -64,13 +42,10 @@ def process_snippet_block(match: re.Match[str], check_mode: bool = False) -> str {indent}_Full example: [{file_path}]({github_url})_ {indent}""" - # In check mode, only check if code has changed if check_mode: - # Extract existing code from the match existing_content = match.group(3) if existing_content is not None: existing_lines = existing_content.strip().split("\n") - # Find code between ```python and ``` code_lines: list[str] = [] in_code = False for line in existing_lines: @@ -81,7 +56,6 @@ def process_snippet_block(match: re.Match[str], check_mode: bool = False) -> str elif in_code: code_lines.append(line) existing_code = "\n".join(code_lines).strip() - # Compare with the indented version we would generate expected_code = code.replace("\n", f"\n{indent}").strip() if existing_code == expected_code: return full_match @@ -94,13 +68,9 @@ def process_snippet_block(match: re.Match[str], check_mode: bool = False) -> str def update_readme_snippets(check_mode: bool = False) -> bool: - """Update code snippets in README.md with live code from source files. - - Args: - check_mode: If True, only check if updates are needed without modifying + """Update README.md snippet blocks from their source files. - Returns: - True if file is up to date or was updated, False if check failed + Returns False when README.md is missing or check_mode finds stale snippets. """ readme_path = Path("README.md") if not readme_path.exists(): @@ -110,13 +80,9 @@ def update_readme_snippets(check_mode: bool = False) -> bool: content = readme_path.read_text() original_content = content - # Pattern to match snippet-source blocks - # Matches: - # ... any content ... - # + # Matches `` ... `` blocks pattern = r"^(\s*)\n" r"(.*?)" r"^\1" - # Process all snippet-source blocks updated_content = re.sub( pattern, lambda m: process_snippet_block(m, check_mode), content, flags=re.MULTILINE | re.DOTALL ) @@ -141,7 +107,6 @@ def update_readme_snippets(check_mode: bool = False) -> bool: def main(): - """Main entry point.""" parser = argparse.ArgumentParser(description="Update README code snippets from source files") parser.add_argument( "--check", action="store_true", help="Check mode - verify snippets are up to date without modifying" diff --git a/src/mcp-types/mcp_types/__init__.py b/src/mcp-types/mcp_types/__init__.py index 2ed97cba33..6b7d1ea213 100644 --- a/src/mcp-types/mcp_types/__init__.py +++ b/src/mcp-types/mcp_types/__init__.py @@ -1,10 +1,8 @@ -"""This module defines the types for the MCP protocol. +"""Types for the MCP protocol. -Check the latest schema at: -https://github.com/modelcontextprotocol/modelcontextprotocol/blob/main/schema/draft/schema.json +Schema: https://github.com/modelcontextprotocol/modelcontextprotocol/blob/main/schema/draft/schema.json """ -# Re-export everything from _types for backward compatibility from mcp_types._types import ( CLIENT_CAPABILITIES_META_KEY, CLIENT_INFO_META_KEY, @@ -195,8 +193,6 @@ server_request_adapter, server_result_adapter, ) - -# Re-export JSONRPC types from mcp_types.jsonrpc import ( CONNECTION_CLOSED, HEADER_MISMATCH, @@ -222,15 +218,12 @@ from mcp_types.version import LATEST_PROTOCOL_VERSION __all__ = [ - # Protocol version constants "LATEST_PROTOCOL_VERSION", "DEFAULT_NEGOTIATED_VERSION", - # Reserved request _meta keys "PROTOCOL_VERSION_META_KEY", "CLIENT_INFO_META_KEY", "CLIENT_CAPABILITIES_META_KEY", "LOG_LEVEL_META_KEY", - # Type aliases and variables "ContentBlock", "ElicitRequestedSchema", "ElicitRequestParams", @@ -247,7 +240,6 @@ "SamplingMessageContentBlock", "StopReason", "TaskStatus", - # Base classes "BaseMetadata", "Request", "Notification", @@ -261,7 +253,6 @@ "PaginatedResult", "CacheableResult", "EmptyResult", - # Capabilities "ClientCapabilities", "ClientTasksCapability", "ClientTasksRequestsCapability", @@ -288,7 +279,6 @@ "TasksToolsCapability", "ToolsCapability", "UrlElicitationCapability", - # Content types "Annotations", "AudioContent", "BlobResourceContents", @@ -302,7 +292,6 @@ "TextResourceContents", "ToolResultContent", "ToolUseContent", - # Entity types "Completion", "CompletionArgument", "CompletionContext", @@ -326,7 +315,6 @@ "ToolAnnotations", "ToolChoice", "ToolExecution", - # Requests "CallToolRequest", "CallToolRequestParams", "CompleteRequest", @@ -364,7 +352,6 @@ "SubscriptionsListenRequestParams", "UnsubscribeRequest", "UnsubscribeRequestParams", - # Results "CallToolResult", "CancelTaskResult", "CompleteResult", @@ -387,10 +374,8 @@ "ListToolsResult", "ReadResourceResult", "SubscriptionsListenResult", - # Error data payloads "MissingRequiredClientCapabilityErrorData", "UnsupportedProtocolVersionErrorData", - # Notifications "CancelledNotification", "CancelledNotificationParams", "ElicitCompleteNotification", @@ -410,21 +395,18 @@ "TaskStatusNotification", "TaskStatusNotificationParams", "ToolListChangedNotification", - # Union types for request/response routing "ClientNotification", "ClientRequest", "ClientResult", "ServerNotification", "ServerRequest", "ServerResult", - # Type adapters "client_notification_adapter", "client_request_adapter", "client_result_adapter", "server_notification_adapter", "server_request_adapter", "server_result_adapter", - # JSON-RPC types "CONNECTION_CLOSED", "HEADER_MISMATCH", "INTERNAL_ERROR", diff --git a/src/mcp-types/mcp_types/_types.py b/src/mcp-types/mcp_types/_types.py index 34dc10083b..2257288f35 100644 --- a/src/mcp-types/mcp_types/_types.py +++ b/src/mcp-types/mcp_types/_types.py @@ -24,13 +24,9 @@ from mcp_types.jsonrpc import RequestId DEFAULT_NEGOTIATED_VERSION: Final[str] = "2025-03-26" -"""The default negotiated version of the Model Context Protocol when no version is specified. +"""Protocol version the server must assume when the client specifies none, per the MCP spec. -We need this to satisfy the MCP specification, which requires the server to assume a specific version if none is -provided by the client. - -See the "Protocol Version Header" at -https://modelcontextprotocol.io/specification/2025-11-25/basic/transports#protocol-version-header. +See https://modelcontextprotocol.io/specification/2025-11-25/basic/transports#protocol-version-header. """ ProgressToken = str | int @@ -78,12 +74,9 @@ class RequestParamsMeta(TypedDict, extra_items=Any): """ progress_token: NotRequired[ProgressToken] - """ - If specified, the caller requests out-of-band progress notifications for - this request (as represented by notifications/progress). The value of this - parameter is an opaque token that will be attached to any subsequent - notifications. The receiver is not obligated to provide these notifications. - """ + """An opaque token requesting out-of-band progress notifications (notifications/progress) + for this request; it is attached to any subsequent notifications. The receiver is not + obligated to provide them.""" class RequestParams(MCPModel): @@ -99,18 +92,12 @@ class RequestParams(MCPModel): class PaginatedRequestParams(RequestParams): cursor: str | None = None - """An opaque token representing the current pagination position. - - If provided, the server should return results starting after this cursor. - """ + """Opaque pagination position; if provided, the server should return results starting after it.""" class NotificationParams(MCPModel): meta: Meta | None = Field(alias="_meta", default=None) - """ - See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) - for notes on _meta usage. - """ + """See the MCP specification for notes on `_meta` usage.""" RequestParamsT = TypeVar("RequestParamsT", bound=RequestParams | dict[str, Any] | None) @@ -130,7 +117,7 @@ class Request(MCPModel, Generic[RequestParamsT, MethodT]): class PaginatedRequest(Request[PaginatedRequestParams | None, MethodT], Generic[MethodT]): - """Base class for paginated requests, matching the schema's PaginatedRequest interface.""" + """Base class for paginated requests.""" params: PaginatedRequestParams | None = None """Pagination params. Required on the 2026-07-28+ wire (because `_meta` is); @@ -163,18 +150,12 @@ class Result(MCPModel): """ meta: Meta | None = Field(alias="_meta", default=None) - """ - See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) - for notes on _meta usage. - """ + """See the MCP specification for notes on `_meta` usage.""" class PaginatedResult(Result): next_cursor: str | None = None - """ - An opaque token representing the pagination position after the last returned result. - If present, there may be more results available. - """ + """Opaque pagination position after the last returned result; if present, more may be available.""" class CacheableResult(Result): @@ -217,13 +198,10 @@ class BaseMetadata(MCPModel): specs or fallback (if title isn't present).""" title: str | None = None - """ - Intended for UI and end-user contexts β€” optimized to be human-readable and easily understood, - even by those unfamiliar with domain-specific terminology. + """Human-readable display name for UI and end-user contexts. - If not provided, the name should be used for display (except for Tool, - where `annotations.title` should be given precedence over using `name`, - if present). + If not provided, `name` should be used for display (except for Tool, where + `annotations.title` takes precedence over `name`). """ @@ -274,19 +252,14 @@ class RootsCapability(MCPModel): class SamplingContextCapability(MCPModel): - """Capability for context inclusion during sampling. + """Support for non-'none' `includeContext` values during sampling. - Indicates support for non-'none' values in the includeContext parameter. - SOFT-DEPRECATED: New implementations should use tools parameter instead. + SOFT-DEPRECATED: new implementations should use the tools parameter instead. """ class SamplingToolsCapability(MCPModel): - """Capability indicating support for tool calling during sampling. - - When present in ClientCapabilities.sampling, indicates that the client - supports the tools and toolChoice parameters in sampling requests. - """ + """The client supports the tools and toolChoice parameters in sampling requests.""" class FormElicitationCapability(MCPModel): @@ -314,15 +287,10 @@ class SamplingCapability(MCPModel): """Sampling capability structure. Deprecated in 2026-07-28 (SEP-2577); shape unchanged.""" context: SamplingContextCapability | None = None - """ - Present if the client supports non-'none' values for includeContext parameter. - SOFT-DEPRECATED: New implementations should use tools parameter instead. - """ + """Present if the client supports non-'none' `includeContext` values. + SOFT-DEPRECATED: new implementations should use the tools parameter instead.""" tools: SamplingToolsCapability | None = None - """ - Present if the client supports tools and toolChoice parameters in sampling requests. - Presence indicates full tool calling support during sampling. - """ + """Present if the client supports the tools and toolChoice sampling parameters.""" class TasksListCapability(MCPModel): @@ -378,10 +346,7 @@ class ClientCapabilities(MCPModel): experimental: dict[str, dict[str, Any]] | None = None """Experimental, non-standard capabilities that the client supports.""" sampling: SamplingCapability | None = None - """ - Present if the client supports sampling from an LLM. - Can contain fine-grained capabilities like context and tools support. - """ + """Present if the client supports sampling from an LLM.""" elicitation: ElicitationCapability | None = None """Present if the client supports elicitation from the user.""" roots: RootsCapability | None = None @@ -507,8 +472,7 @@ class InitializeRequestParams(RequestParams): class InitializeRequest(Request[InitializeRequestParams, Literal["initialize"]]): - """This request is sent from the client to the server when it first connects, asking it - to begin initialization. + """Sent from the client when it first connects, asking the server to begin initialization. Removed in protocol 2026-07-28; sent/received on sessions negotiating <= 2025-11-25. On 2026-07-28 the handshake is `server/discover` plus per-request `_meta`. @@ -530,16 +494,12 @@ class InitializeResult(Result): capabilities: ServerCapabilities server_info: Implementation instructions: str | None = None - """Instructions describing how to use the server and its features. - - Clients may use this to improve an LLM's understanding of available tools, - resources, etc., for example by adding it to the system prompt. - """ + """Instructions describing how to use the server, e.g. added to a system prompt + to improve the LLM's understanding of available tools and resources.""" class InitializedNotification(Notification[NotificationParams | None, Literal["notifications/initialized"]]): - """This notification is sent from the client to the server after initialization has - finished. + """Sent from the client after initialization has finished. Removed in protocol 2026-07-28; sent/received on sessions negotiating <= 2025-11-25. """ @@ -591,9 +551,8 @@ class DiscoverResult(CacheableResult): ignored by older peers, and defaulted on inbound bodies that omit it.""" -# Tasks: introduced in 2025-11-25, removed from the core spec in 2026-07-28 -# (continuing as an extension). Defined here types-only; their methods are not -# in the request/notification unions below, so they are never dispatched. +# Tasks: introduced in 2025-11-25, removed from the core spec in 2026-07-28 (now an +# extension). Types-only: their methods are not in the unions below, so never dispatched. class ToolExecution(MCPModel): @@ -729,22 +688,13 @@ class ProgressNotificationParams(NotificationParams): """Parameters for progress notifications.""" progress_token: ProgressToken - """ - The progress token which was given in the initial request, used to associate this - notification with the request that is proceeding. - """ + """The token from the original request, associating this notification with it.""" progress: float - """ - The progress thus far. This should increase every time progress is made, even if the - total is unknown. - """ + """Progress thus far; should increase on every update, even if the total is unknown.""" total: float | None = None """Total number of items to process (or total progress required), if known.""" message: str | None = None - """Message related to progress. - - This should provide relevant human-readable progress information. - """ + """Optional human-readable progress information.""" class ProgressNotification(Notification[ProgressNotificationParams, Literal["notifications/progress"]]): @@ -787,10 +737,8 @@ class Resource(BaseMetadata): """The MIME type of this resource, if known.""" size: int | None = None - """The size of the raw resource content, in bytes (i.e., before base64 encoding or any tokenization), if known. - - This can be used by Hosts to display file sizes and estimate context window usage. - """ + """Raw content size in bytes (before base64 encoding or tokenization), if known. + Lets hosts display file sizes and estimate context window usage.""" icons: list[Icon] | None = None """Optional set of sized icons that the client can display in a user interface.""" @@ -812,10 +760,7 @@ class ResourceTemplate(BaseMetadata): """A description of what this template is for.""" mime_type: str | None = None - """The MIME type for all resources that match this template. - - This should only be included if all resources matching this template have the same type. - """ + """The MIME type of all matching resources; include only if they all share the same type.""" icons: list[Icon] | None = None """An optional set of sized icons that the client can display in a user interface.""" @@ -824,10 +769,7 @@ class ResourceTemplate(BaseMetadata): """Optional annotations for the client.""" meta: Meta | None = Field(alias="_meta", default=None) - """ - See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) - for notes on _meta usage. - """ + """See the MCP specification for notes on `_meta` usage.""" class ListResourcesResult(PaginatedResult, CacheableResult): @@ -868,10 +810,7 @@ class InputResponseRequestParams(RequestParams): class ReadResourceRequestParams(InputResponseRequestParams): uri: str - """ - The URI of the resource. The URI can use any protocol; it is up to the server - how to interpret it. - """ + """The URI of the resource; any protocol, interpreted by the server.""" class ReadResourceRequest(Request[ReadResourceRequestParams, Literal["resources/read"]]): @@ -889,20 +828,14 @@ class ResourceContents(MCPModel): mime_type: str | None = None """The MIME type of this resource, if known.""" meta: Meta | None = Field(alias="_meta", default=None) - """ - See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) - for notes on _meta usage. - """ + """See the MCP specification for notes on `_meta` usage.""" class TextResourceContents(ResourceContents): """Text contents of a resource.""" text: str - """ - The text of the item. This must only be set if the item can actually be represented - as text (not binary data). - """ + """The text of the item; only for items representable as text (not binary data).""" class BlobResourceContents(ResourceContents): @@ -925,8 +858,7 @@ class ReadResourceResult(CacheableResult): class ResourceListChangedNotification( Notification[NotificationParams | None, Literal["notifications/resources/list_changed"]] ): - """An optional notification from the server to the client, informing it that the list - of resources it can read from has changed. + """Optional server notification that the list of readable resources has changed. May be sent spontaneously through 2025-11-25; on 2026-07-28 sessions the client must opt in via `subscriptions/listen`. @@ -943,15 +875,11 @@ class SubscribeRequestParams(RequestParams): """ uri: str - """ - The URI of the resource to subscribe to. The URI can use any protocol; it is up to - the server how to interpret it. - """ + """The URI of the resource to subscribe to; any protocol, interpreted by the server.""" class SubscribeRequest(Request[SubscribeRequestParams, Literal["resources/subscribe"]]): - """Sent from the client to request resources/updated notifications from the server - whenever a particular resource changes. + """Requests resources/updated notifications whenever a particular resource changes. Removed in protocol 2026-07-28; sent/received on sessions negotiating <= 2025-11-25. On 2026-07-28 use `subscriptions/listen` instead. @@ -972,8 +900,7 @@ class UnsubscribeRequestParams(RequestParams): class UnsubscribeRequest(Request[UnsubscribeRequestParams, Literal["resources/unsubscribe"]]): - """Sent from the client to request cancellation of resources/updated notifications - from the server. This should follow a previous resources/subscribe request. + """Cancels resources/updated notifications from a previous resources/subscribe request. Removed in protocol 2026-07-28; sent/received on sessions negotiating <= 2025-11-25. On 2026-07-28 use `subscriptions/listen` instead. @@ -985,17 +912,13 @@ class UnsubscribeRequest(Request[UnsubscribeRequestParams, Literal["resources/un class ResourceUpdatedNotificationParams(NotificationParams): uri: str - """ - The URI of the resource that has been updated. This might be a sub-resource of the - one that the client actually subscribed to. - """ + """The URI of the updated resource; may be a sub-resource of the one subscribed to.""" class ResourceUpdatedNotification( Notification[ResourceUpdatedNotificationParams, Literal["notifications/resources/updated"]] ): - """A notification from the server to the client, informing it that a resource has - changed and may need to be read again. + """Server notification that a resource has changed and may need to be read again. Only sent if the client subscribed: via `resources/subscribe` through 2025-11-25, or `subscriptions/listen` on 2026-07-28. @@ -1103,10 +1026,7 @@ class Prompt(BaseMetadata): icons: list[Icon] | None = None """An optional list of icons for this prompt.""" meta: Meta | None = Field(alias="_meta", default=None) - """ - See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) - for notes on _meta usage. - """ + """See the MCP specification for notes on `_meta` usage.""" class ListPromptsResult(PaginatedResult, CacheableResult): @@ -1140,10 +1060,7 @@ class TextContent(MCPModel): annotations: Annotations | None = None """Optional annotations for the client.""" meta: Meta | None = Field(alias="_meta", default=None) - """ - See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) - for notes on _meta usage. - """ + """See the MCP specification for notes on `_meta` usage.""" class ImageContent(MCPModel): @@ -1153,14 +1070,11 @@ class ImageContent(MCPModel): data: str """The base64-encoded image data.""" mime_type: str - """ - The MIME type of the image. Different providers may support different - image types. - """ + """The MIME type of the image. Different providers may support different image types.""" annotations: Annotations | None = None """Optional annotations for the client.""" meta: Meta | None = Field(alias="_meta", default=None) - """See the MCP specification's "General fields: _meta" section for notes on _meta usage.""" + """See the MCP specification for notes on `_meta` usage.""" class AudioContent(MCPModel): @@ -1170,17 +1084,11 @@ class AudioContent(MCPModel): data: str """The base64-encoded audio data.""" mime_type: str - """ - The MIME type of the audio. Different providers may support different - audio types. - """ + """The MIME type of the audio. Different providers may support different audio types.""" annotations: Annotations | None = None """Optional annotations for the client.""" meta: Meta | None = Field(alias="_meta", default=None) - """ - See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) - for notes on _meta usage. - """ + """See the MCP specification for notes on `_meta` usage.""" class ToolUseContent(MCPModel): @@ -1256,20 +1164,13 @@ class SamplingMessage(MCPModel): role: Role content: SamplingMessageContentBlock | list[SamplingMessageContentBlock] - """ - Message content. Can be a single content block or an array of content blocks - for multi-modal messages and tool interactions. - """ + """A single content block or an array, for multi-modal messages and tool interactions.""" meta: Meta | None = Field(alias="_meta", default=None) - """ - See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) - for notes on _meta usage. - """ + """See the MCP specification for notes on `_meta` usage.""" @property def content_as_list(self) -> list[SamplingMessageContentBlock]: - """Returns the content as a list of content blocks, regardless of whether - it was originally a single block or a list.""" + """The content as a list, whether it was originally a single block or a list.""" return self.content if isinstance(self.content, list) else [self.content] @@ -1285,10 +1186,7 @@ class EmbeddedResource(MCPModel): annotations: Annotations | None = None """Optional annotations for the client.""" meta: Meta | None = Field(alias="_meta", default=None) - """ - See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) - for notes on _meta usage. - """ + """See the MCP specification for notes on `_meta` usage.""" class ResourceLink(Resource): @@ -1329,8 +1227,7 @@ class GetPromptResult(Result): class PromptListChangedNotification( Notification[NotificationParams | None, Literal["notifications/prompts/list_changed"]] ): - """An optional notification from the server to the client, informing it that the list - of prompts it offers has changed. + """Optional server notification that the list of offered prompts has changed. May be sent spontaneously through 2025-11-25; on 2026-07-28 sessions the client must opt in via `subscriptions/listen`. @@ -1349,47 +1246,28 @@ class ListToolsRequest(PaginatedRequest[Literal["tools/list"]]): class ToolAnnotations(MCPModel): """Additional properties describing a Tool to clients. - NOTE: all properties in ToolAnnotations are **hints**. - They are not guaranteed to provide a faithful description of - tool behavior (including descriptive properties like `title`). - - Clients should never make tool use decisions based on ToolAnnotations - received from untrusted servers. + All properties are hints, not guaranteed to faithfully describe tool + behavior. Clients should never make tool-use decisions based on + annotations received from untrusted servers. """ title: str | None = None """A human-readable title for the tool.""" read_only_hint: bool | None = None - """ - If true, the tool does not modify its environment. - Default: false - """ + """If true, the tool does not modify its environment. Default: false.""" destructive_hint: bool | None = None - """ - If true, the tool may perform destructive updates to its environment. - If false, the tool performs only additive updates. - (This property is meaningful only when `read_only_hint == false`) - Default: true - """ + """If true, the tool may perform destructive (rather than only additive) updates. + Meaningful only when `read_only_hint` is false. Default: true.""" idempotent_hint: bool | None = None - """ - If true, calling the tool repeatedly with the same arguments - will have no additional effect on its environment. - (This property is meaningful only when `read_only_hint == false`) - Default: false - """ + """If true, repeated calls with the same arguments have no additional effect. + Meaningful only when `read_only_hint` is false. Default: false.""" open_world_hint: bool | None = None - """ - If true, this tool may interact with an "open world" of external - entities. If false, the tool's domain of interaction is closed. - For example, the world of a web search tool is open, whereas that - of a memory tool is not. - Default: true - """ + """If true, the tool may interact with an "open world" of external entities + (e.g. web search); if false, its domain is closed (e.g. memory). Default: true.""" class Tool(BaseMetadata): @@ -1469,8 +1347,7 @@ class CallToolResult(Result): class ToolListChangedNotification(Notification[NotificationParams | None, Literal["notifications/tools/list_changed"]]): - """An optional notification from the server to the client, informing it that the list - of tools it offers has changed. + """Optional server notification that the list of offered tools has changed. May be sent spontaneously through 2025-11-25; on 2026-07-28 sessions the client must opt in via `subscriptions/listen`. @@ -1516,10 +1393,7 @@ class LoggingMessageNotificationParams(NotificationParams): logger: str | None = None """An optional name of the logger issuing this message.""" data: Any - """ - The data to be logged, such as a string message or an object. Any JSON serializable - type is allowed here. - """ + """The data to log: any JSON-serializable value, such as a string message or an object.""" class LoggingMessageNotification(Notification[LoggingMessageNotificationParams, Literal["notifications/message"]]): @@ -1549,62 +1423,33 @@ class ModelHint(MCPModel): """ name: str | None = None - """A hint for a model name. - - The client SHOULD treat this as a substring (e.g. `sonnet` matches - `claude-3-5-sonnet-20241022`) and MAY map it to another provider's model - that fills a similar niche. - """ + """A model-name hint. The client SHOULD treat it as a substring (e.g. `sonnet` + matches `claude-3-5-sonnet-20241022`) and MAY map it to an equivalent model + from another provider.""" class ModelPreferences(MCPModel): - """The server's preferences for model selection, requested of the client during - sampling. + """The server's preferences for model selection, requested of the client during sampling. - Because LLMs can vary along multiple dimensions, choosing the "best" model is - rarely straightforward. Different models excel in different areasβ€”some are - faster but less capable, others are more capable but more expensive, and so - on. This interface allows servers to express their priorities across multiple - dimensions to help clients make an appropriate selection for their use case. - - These preferences are always advisory. The client MAY ignore them. It is also - up to the client to decide how to interpret these preferences and how to - balance them against other considerations. + Expresses the server's priorities across cost, speed, and intelligence. + Always advisory: the client MAY ignore them, and decides how to interpret + and balance them against other considerations. Deprecated in 2026-07-28 (SEP-2577) with the rest of sampling. """ hints: list[ModelHint] | None = None - """ - Optional hints to use for model selection. - - If multiple hints are specified, the client MUST evaluate them in order - (such that the first match is taken). - - The client SHOULD prioritize these hints over the numeric priorities, but - MAY still use the priorities to select from ambiguous matches. - """ + """Hints the client MUST evaluate in order (first match is taken). The client SHOULD + prioritize these over the numeric priorities, but MAY use those for ambiguous matches.""" cost_priority: float | None = None - """ - How much to prioritize cost when selecting a model. A value of 0 means cost - is not important, while a value of 1 means cost is the most important - factor. - """ + """How much to prioritize cost: 0 means not important, 1 means most important.""" speed_priority: float | None = None - """ - How much to prioritize sampling speed (latency) when selecting a model. A - value of 0 means speed is not important, while a value of 1 means speed is - the most important factor. - """ + """How much to prioritize sampling speed (latency): 0 means not important, 1 means most important.""" intelligence_priority: float | None = None - """ - How much to prioritize intelligence and capabilities when selecting a - model. A value of 0 means intelligence is not important, while a value of 1 - means intelligence is the most important factor. - """ + """How much to prioritize intelligence and capabilities: 0 means not important, 1 means most important.""" class ToolChoice(MCPModel): @@ -1615,30 +1460,21 @@ class ToolChoice(MCPModel): """ mode: Literal["auto", "required", "none"] | None = None - """ - Controls the tool use ability of the model: - - "auto": Model decides whether to use tools (default) - - "required": Model MUST use at least one tool before completing - - "none": Model MUST NOT use any tools - """ + """Tool-use mode: "auto" = model decides (default); "required" = MUST use at + least one tool before completing; "none" = MUST NOT use any tools.""" class CreateMessageRequestParams(RequestParams): messages: list[SamplingMessage] """The conversation to sample from.""" model_preferences: ModelPreferences | None = None - """ - The server's preferences for which model to select. The client MAY ignore - these preferences. - """ + """The server's model selection preferences; the client MAY ignore them.""" system_prompt: str | None = None """An optional system prompt the server wants to use for sampling.""" include_context: IncludeContext | None = None - """ - A request to include context from one or more MCP servers (including the - caller), to be attached to the prompt. The client MAY ignore this request. - Default is "none". "thisServer" and "allServers" are deprecated (SEP-2596). - """ + """Request to attach context from MCP servers (including the caller) to the prompt; + the client MAY ignore it. Default "none". "thisServer" and "allServers" are + deprecated (SEP-2596).""" temperature: float | None = None max_tokens: int """The maximum number of tokens to sample, as requested by the server.""" @@ -1704,22 +1540,15 @@ class CreateMessageResultWithTools(Result): role: Role """The role of the message sender (typically 'assistant' for LLM responses).""" content: SamplingMessageContentBlock | list[SamplingMessageContentBlock] - """ - Response content. May be a single content block or an array. - May include ToolUseContent if stop_reason is 'toolUse'. - """ + """A single content block or an array; may include ToolUseContent when stop_reason is 'toolUse'.""" model: str """The name of the model that generated the message.""" stop_reason: StopReason | None = None - """ - The reason why sampling stopped, if known. - 'toolUse' indicates the model wants to use a tool. - """ + """Why sampling stopped, if known; 'toolUse' means the model wants to use a tool.""" @property def content_as_list(self) -> list[SamplingMessageContentBlock]: - """Returns the content as a list of content blocks, regardless of whether - it was originally a single block or a list.""" + """The content as a list, whether it was originally a single block or a list.""" return self.content if isinstance(self.content, list) else [self.content] @@ -1779,15 +1608,9 @@ class Completion(MCPModel): values: list[str] """An array of completion values. Must not exceed 100 items.""" total: int | None = None - """ - The total number of completion options available. This can exceed the number of - values actually sent in the response. - """ + """Total completion options available; can exceed the number of values returned.""" has_more: bool | None = None - """ - Indicates whether there are additional completion options beyond those provided in - the current response, even if the exact total is unknown. - """ + """Whether more completion options exist beyond those returned, even if the total is unknown.""" class CompleteResult(Result): @@ -1801,13 +1624,7 @@ class CompleteResult(Result): class ListRootsRequest(Request[RequestParams | None, Literal["roots/list"]]): - """Sent from the server to request a list of root URIs from the client. Roots allow - servers to ask for specific directories or files to operate on. A common example - for roots is providing a set of repositories or directories a server should operate - on. - - This request is typically used when the server needs to understand the file system - structure or access specific locations that the client has permission to read from. + """Requests the list of root URIs (directories/files the server may operate on) from the client. A standalone JSON-RPC request through 2025-11-25; on 2026-07-28 it is embedded in `InputRequiredResult.input_requests`. Deprecated in 2026-07-28 (SEP-2577). @@ -1826,30 +1643,17 @@ class Root(MCPModel): """ uri: FileUrl - """ - The URI identifying the root. This *must* start with file:// for now. - This restriction may be relaxed in future versions of the protocol to allow - other URI schemes. - """ + """The URI identifying the root; must start with `file://` for now (future + protocol versions may allow other schemes).""" name: str | None = None - """ - An optional name for the root. This can be used to provide a human-readable - identifier for the root, which may be useful for display purposes or for - referencing the root in other parts of the application. - """ + """Optional human-readable identifier for the root, for display or referencing.""" meta: Meta | None = Field(alias="_meta", default=None) - """ - See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) - for notes on _meta usage. - """ + """See the MCP specification for notes on `_meta` usage.""" class ListRootsResult(Result): """The client's response to a roots/list request from the server. - This result contains an array of Root objects, each representing a root - directory or file that the server can operate on. - On 2026-07-28 this is carried as an `InputResponses` entry, not a JSON-RPC result. Deprecated in 2026-07-28 (SEP-2577). """ @@ -1860,12 +1664,8 @@ class ListRootsResult(Result): class RootsListChangedNotification( Notification[NotificationParams | None, Literal["notifications/roots/list_changed"]] ): - """A notification from the client to the server, informing it that the list of - roots has changed. - - This notification should be sent whenever the client adds, removes, or - modifies any root. The server should then request an updated list of roots - using the ListRootsRequest. + """Client notification that the roots list has changed (any root added, removed, + or modified); the server should re-request the list via ListRootsRequest. Removed in protocol 2026-07-28; sent/received on sessions negotiating <= 2025-11-25. """ @@ -1876,25 +1676,22 @@ class RootsListChangedNotification( class CancelledNotificationParams(NotificationParams): request_id: RequestId | None = None - """ - The ID of the request to cancel. + """The ID of the request to cancel; MUST match a request previously issued in the same direction. - This MUST correspond to the ID of a request previously issued in the same direction. Required on the wire through 2025-06-18; optional at 2025-11-25; required again from - 2026-07-28, where it must name a request the client previously issued (servers send - this notification only to terminate a `subscriptions/listen` stream). + 2026-07-28, where it must name a request the client issued (servers send this only + to terminate a `subscriptions/listen` stream). """ reason: str | None = None """An optional string describing the reason for the cancellation.""" class CancelledNotification(Notification[CancelledNotificationParams, Literal["notifications/cancelled"]]): - """This notification can be sent by either side to indicate that it is canceling a - previously-issued request. + """Sent by either side to cancel a previously-issued request. - The request SHOULD still be in-flight, but due to communication latency, it - is always possible that this notification MAY arrive after the request has - already finished. A client MUST NOT attempt to cancel its `initialize` request. + The request SHOULD still be in-flight, but due to communication latency it MAY + arrive after the request has already finished. A client MUST NOT attempt to + cancel its `initialize` request. """ method: Literal["notifications/cancelled"] = "notifications/cancelled" @@ -1911,15 +1708,11 @@ class ElicitCompleteNotificationParams(NotificationParams): class ElicitCompleteNotification( Notification[ElicitCompleteNotificationParams, Literal["notifications/elicitation/complete"]] ): - """A notification from the server to the client, informing it that a URL mode - elicitation has been completed. + """Server notification that a URL mode elicitation has been completed. - Clients MAY use the notification to automatically retry requests that received a - URLElicitationRequiredError, update the user interface, or otherwise continue - an interaction. However, because delivery of the notification is not guaranteed, - clients must not wait indefinitely for a notification from the server. - - New in protocol 2025-11-25 with URL mode itself. + Clients MAY use it to retry requests that received a URLElicitationRequiredError + or update the UI; delivery is not guaranteed, so clients must not wait + indefinitely. New in protocol 2025-11-25 with URL mode itself. """ method: Literal["notifications/elicitation/complete"] = "notifications/elicitation/complete" @@ -1945,10 +1738,7 @@ class ElicitRequestFormParams(RequestParams): """The message to present to the user describing what information is being requested.""" requested_schema: ElicitRequestedSchema - """ - A restricted subset of JSON Schema defining the structure of the expected response. - Only top-level properties are allowed, without nesting. - """ + """Restricted JSON Schema subset for the expected response: top-level properties only, no nesting.""" task: TaskMetadata | None = None """If specified, the caller requests task-augmented execution (2025-11-25 only).""" @@ -1971,17 +1761,13 @@ class ElicitRequestURLParams(RequestParams): """The URL that the user should navigate to.""" elicitation_id: str | None = None - """The ID of the elicitation, which must be unique within the context of the server. - - The client MUST treat this ID as an opaque value. Required on the wire at - 2025-11-25; removed at 2026-07-28. - """ + """Server-unique elicitation ID; the client MUST treat it as opaque. + Required on the wire at 2025-11-25; removed at 2026-07-28.""" task: TaskMetadata | None = None """If specified, the caller requests task-augmented execution (2025-11-25 only).""" -# Union type for elicitation request parameters ElicitRequestParams: TypeAlias = ElicitRequestURLParams | ElicitRequestFormParams """Parameters for elicitation requests - either form or URL mode.""" @@ -1997,27 +1783,18 @@ class ElicitResult(Result): """The client's response to an elicitation request.""" action: Literal["accept", "decline", "cancel"] - """ - The user action in response to the elicitation. - - "accept": User submitted the form/confirmed the action (or consented to URL navigation) - - "decline": User explicitly declined the action - - "cancel": User dismissed without making an explicit choice - """ + """The user action: "accept" = submitted the form / confirmed (or consented to URL + navigation); "decline" = explicitly declined; "cancel" = dismissed without an + explicit choice.""" content: dict[str, str | int | float | bool | list[str] | None] | None = None - """ - The submitted form data, only present when action is "accept" in form mode. - Contains values matching the requested schema. Values can be strings, integers, floats, - booleans, arrays of strings, or null. - For URL mode, this field is omitted. - """ + """Submitted form data matching the requested schema; present only when action + is "accept" in form mode (omitted for URL mode).""" class ElicitationRequiredErrorData(MCPModel): - """Error data for the -32042 URL-elicitation-required error. - - Servers return this when a request cannot be processed until one or more - URL mode elicitations are completed. + """Error data for the -32042 URL-elicitation-required error: the request cannot + proceed until the listed URL mode elicitations are completed. Removed in protocol 2026-07-28; sent/received on sessions negotiating 2025-11-25. """ diff --git a/src/mcp-types/mcp_types/jsonrpc.py b/src/mcp-types/mcp_types/jsonrpc.py index fcc3317d86..2f974f93c6 100644 --- a/src/mcp-types/mcp_types/jsonrpc.py +++ b/src/mcp-types/mcp_types/jsonrpc.py @@ -41,11 +41,9 @@ class JSONRPCResponse(BaseModel): result: dict[str, Any] -# MCP error codes occupy the JSON-RPC server-error range -32000..-32099. -# Per the 2026-07-28 spec's allocation policy: -# -32000..-32019 implementation-defined -# -32020..-32099 reserved for spec-defined codes, allocated sequentially from -32020 -# -32002, -32042 reserved-never-reused (retired by earlier protocol versions) +# MCP error codes occupy the JSON-RPC server-error range -32000..-32099. 2026-07-28 allocation policy: +# -32000..-32019 implementation-defined; -32020..-32099 spec-defined, allocated sequentially from -32020; +# -32002 and -32042 reserved-never-reused (retired by earlier protocol versions). HEADER_MISMATCH = -32020 """HTTP headers do not match the request body, or required headers are missing/malformed (protocol 2026-07-28).""" @@ -59,15 +57,13 @@ class JSONRPCResponse(BaseModel): URL_ELICITATION_REQUIRED = -32042 """A URL-mode elicitation is required before the request can be processed (protocol 2025-11-25 only).""" -# SDK error codes: SDK-internal allocations in the implementation-defined band -# -32000..-32019; not defined by the MCP schema. +# SDK-internal allocations in the implementation-defined band; not defined by the MCP schema. CONNECTION_CLOSED = -32000 """SDK-only: the connection closed before a response arrived; never emitted on the wire.""" REQUEST_TIMEOUT = -32001 """SDK-only: a request timed out waiting for its response.""" -# Standard JSON-RPC error codes PARSE_ERROR = -32700 """Standard JSON-RPC: invalid JSON was received.""" @@ -94,16 +90,10 @@ class ErrorData(BaseModel): """The error type that occurred.""" message: str - """A short description of the error. - - The message SHOULD be limited to a concise single sentence. - """ + """A short description of the error; SHOULD be limited to a concise single sentence.""" data: Any = None - """Additional information about the error. - - The value of this member is defined by the sender (e.g. detailed error information, nested errors, etc.). - """ + """Additional information about the error, defined by the sender (e.g. detailed or nested errors).""" class JSONRPCError(BaseModel): @@ -111,10 +101,7 @@ class JSONRPCError(BaseModel): jsonrpc: Literal["2.0"] id: RequestId | None - """The id of the request this error responds to. - - Required but nullable per JSON-RPC 2.0: `None` encodes `"id": null` (the id could not be determined). - """ + """Required but nullable per JSON-RPC 2.0: `None` encodes `"id": null` (the id could not be determined).""" error: ErrorData diff --git a/src/mcp-types/mcp_types/methods.py b/src/mcp-types/mcp_types/methods.py index 824dcfdfe6..20c7302670 100644 --- a/src/mcp-types/mcp_types/methods.py +++ b/src/mcp-types/mcp_types/methods.py @@ -48,8 +48,6 @@ ] -# --- Surface maps: client-to-server --- - CLIENT_REQUESTS: Final[Mapping[tuple[str, str], type[BaseModel]]] = MappingProxyType( { # 2024-11-05 @@ -150,8 +148,6 @@ ) -# --- Surface maps: server-to-client --- - SERVER_REQUESTS: Final[Mapping[tuple[str, str], type[BaseModel]]] = MappingProxyType( { # 2024-11-05 @@ -224,8 +220,6 @@ ) -# --- Surface maps: results --- - SERVER_RESULTS: Final[Mapping[tuple[str, str], type[BaseModel] | UnionType]] = MappingProxyType( { # 2024-11-05 @@ -325,8 +319,6 @@ """Results clients send, keyed by the originating server request's (method, version).""" -# --- Direction-specific method sets --- - SPEC_CLIENT_METHODS: Final[frozenset[str]] = frozenset(m for m, _ in CLIENT_REQUESTS) """Spec request methods a client may send (any version); the server-side spec-method discriminator.""" @@ -334,8 +326,6 @@ """Spec notification methods a client may send (any version); the server-side spec-method discriminator.""" -# --- Monolith maps --- - MONOLITH_REQUESTS: Final[Mapping[str, type[types.Request[Any, Any]]]] = MappingProxyType( { "completion/complete": types.CompleteRequest, @@ -404,8 +394,6 @@ """Monolith result model (or two-arm union) per request method.""" -# --- Parse functions --- - # Envelope stubs merged into bodies for surface validation (surface classes are full frames). _REQUEST_STUB: Final[Mapping[str, Any]] = MappingProxyType({"jsonrpc": "2.0", "id": 0}) _NOTIFICATION_STUB: Final[Mapping[str, Any]] = MappingProxyType({"jsonrpc": "2.0"}) @@ -418,7 +406,6 @@ def _check_known_version(version: str) -> None: def _body(method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: - """Build a JSON-RPC body, omitting `params` when None.""" body: dict[str, Any] = {"method": method} if params is not None: body["params"] = params @@ -473,12 +460,7 @@ def parse_client_request( ) -> types.Request[Any, Any]: """Validate a client request against `surface`, then parse and return its `monolith` model. - Args: - surface: `(method, version)` to wire-type map; the version-gate lookup - and (per-schema-era) shape check run against this. Pass an extended - map to admit custom methods. - monolith: `method` to version-free model map; the returned instance is - parsed from this row. Must cover every method `surface` admits. + Pass extended maps to admit custom methods; `monolith` must cover every method `surface` admits. Raises: ValueError: `version` is not a known protocol version. @@ -500,12 +482,7 @@ def parse_server_request( ) -> types.Request[Any, Any]: """Validate a server request against `surface`, then parse and return its `monolith` model. - Args: - surface: `(method, version)` to wire-type map; the version-gate lookup - and (per-schema-era) shape check run against this. Pass an extended - map to admit custom methods. - monolith: `method` to version-free model map; the returned instance is - parsed from this row. Must cover every method `surface` admits. + Pass extended maps to admit custom methods; `monolith` must cover every method `surface` admits. Raises: ValueError: `version` is not a known protocol version. @@ -547,12 +524,7 @@ def parse_client_notification( ) -> types.Notification[Any, Any]: """Validate a client notification against `surface`, then parse and return its `monolith` model. - Args: - surface: `(method, version)` to wire-type map; the version-gate lookup - and (per-schema-era) shape check run against this. Pass an extended - map to admit custom methods. - monolith: `method` to version-free model map; the returned instance is - parsed from this row. Must cover every method `surface` admits. + Pass extended maps to admit custom methods; `monolith` must cover every method `surface` admits. Raises: ValueError: `version` is not a known protocol version. @@ -574,12 +546,7 @@ def parse_server_notification( ) -> types.Notification[Any, Any]: """Validate a server notification against `surface`, then parse and return its `monolith` model. - Args: - surface: `(method, version)` to wire-type map; the version-gate lookup - and (per-schema-era) shape check run against this. Pass an extended - map to admit custom methods. - monolith: `method` to version-free model map; the returned instance is - parsed from this row. Must cover every method `surface` admits. + Pass extended maps to admit custom methods; `monolith` must cover every method `surface` admits. Raises: ValueError: `version` is not a known protocol version. @@ -644,12 +611,7 @@ def parse_server_result( ) -> types.Result: """Validate a server result against `surface`, then parse and return its `monolith` model. - Args: - surface: `(method, version)` to wire-type map; the version-gate lookup - and (per-schema-era) shape check run against this. Pass an extended - map to admit custom methods. - monolith: `method` to version-free model map; the returned instance is - parsed from this row. Must cover every method `surface` admits. + Pass extended maps to admit custom methods; `monolith` must cover every method `surface` admits. Raises: ValueError: `version` is not a known protocol version. @@ -690,12 +652,7 @@ def parse_client_result( ) -> types.Result: """Validate a client result against `surface`, then parse and return its `monolith` model. - Args: - surface: `(method, version)` to wire-type map; the version-gate lookup - and (per-schema-era) shape check run against this. Pass an extended - map to admit custom methods. - monolith: `method` to version-free model map; the returned instance is - parsed from this row. Must cover every method `surface` admits. + Pass extended maps to admit custom methods; `monolith` must cover every method `surface` admits. Raises: ValueError: `version` is not a known protocol version. diff --git a/src/mcp-types/mcp_types/version.py b/src/mcp-types/mcp_types/version.py index c5c2233274..30a6222f2d 100644 --- a/src/mcp-types/mcp_types/version.py +++ b/src/mcp-types/mcp_types/version.py @@ -1,10 +1,9 @@ """Protocol-version registry and comparison helpers. -Date-string protocol revisions happen to sort lexicographically, but versions -are an enumerated set, not an ordered scalar: future identifiers are not -guaranteed to be date-shaped, and unrecognized peer strings must compare -conservatively instead of accidentally (e.g. "zzz" > "2025-11-25"). All -ordering questions go through KNOWN_PROTOCOL_VERSIONS. +Versions are an enumerated set, not an ordered scalar: future identifiers may +not be date-shaped, and unrecognized peer strings must compare conservatively +(lexicographic comparison would put "zzz" above "2025-11-25"). All ordering +goes through KNOWN_PROTOCOL_VERSIONS. """ from typing import Final @@ -30,30 +29,26 @@ """Protocol revisions that use the stateless per-request envelope.""" SUPPORTED_PROTOCOL_VERSIONS: tuple[str, ...] = (*HANDSHAKE_PROTOCOL_VERSIONS, *MODERN_PROTOCOL_VERSIONS) -"""Deprecated: prefer HANDSHAKE_PROTOCOL_VERSIONS or MODERN_PROTOCOL_VERSIONS. - -Kept as the union for v1.x compatibility. -""" +"""Deprecated: use HANDSHAKE_PROTOCOL_VERSIONS or MODERN_PROTOCOL_VERSIONS; kept as their union for v1.x compat.""" LATEST_PROTOCOL_VERSION: Final[str] = KNOWN_PROTOCOL_VERSIONS[-1] """Newest protocol revision this SDK speaks (any era).""" LATEST_HANDSHAKE_VERSION: Final[str] = HANDSHAKE_PROTOCOL_VERSIONS[-1] -"""Newest revision reachable via the ``initialize`` handshake; the client's offer and server's counter-offer default.""" +"""Newest revision reachable via the `initialize` handshake; the client's offer and server's counter-offer default.""" LATEST_MODERN_VERSION: Final[str] = MODERN_PROTOCOL_VERSIONS[-1] -"""Newest per-request-envelope revision; the ``server/discover`` probe default.""" +"""Newest per-request-envelope revision; the `server/discover` probe default.""" OLDEST_SUPPORTED_VERSION: Final[str] = HANDSHAKE_PROTOCOL_VERSIONS[0] -"""Oldest revision this SDK still negotiates via the ``initialize`` handshake.""" +"""Oldest revision this SDK still negotiates via the `initialize` handshake.""" def is_version_at_least(version: str, minimum: str) -> bool: """Return True if `version` is a known revision at least as new as `minimum`. - Unknown `version` strings return False (treat unrecognized peers - conservatively). `minimum` must be a member of KNOWN_PROTOCOL_VERSIONS; - passing anything else is programmer error and raises ValueError. + Unknown `version` strings return False (unrecognized peers compare conservatively). + `minimum` must be in KNOWN_PROTOCOL_VERSIONS; anything else raises ValueError. """ if minimum not in KNOWN_PROTOCOL_VERSIONS: raise ValueError(f"minimum must be a known protocol version, got {minimum!r}") diff --git a/src/mcp/cli/claude.py b/src/mcp/cli/claude.py index e65379682a..9c4179fbf3 100644 --- a/src/mcp/cli/claude.py +++ b/src/mcp/cli/claude.py @@ -16,12 +16,9 @@ def mcp_requirement(package: str = "mcp") -> str: """Requirement string pinning spawned environments to the running SDK version. - `uv run --with mcp` resolves the requirement in a fresh environment, where - an unpinned `mcp` means the latest stable release β€” not necessarily the - version the user installed (pre-releases in particular are never selected - without an explicit pin). Source builds carry dev/local version segments - that are not published to PyPI, so they fall back to the unpinned form, - as does a missing distribution (no metadata to pin from). + An unpinned `mcp` in a fresh `uv run --with mcp` environment resolves to the latest + stable release, not the installed one (pre-releases are never selected without a pin). + Dev/local builds and missing distributions have no published version, so they stay unpinned. """ try: version = importlib.metadata.version("mcp") @@ -55,7 +52,7 @@ def get_uv_path() -> str: logger.error( "uv executable not found in PATH, falling back to 'uv'. Please ensure uv is installed and in your PATH" ) - return "uv" # Fall back to just "uv" if not found + return "uv" return uv_path @@ -70,16 +67,11 @@ def update_claude_config( """Add or update an MCP server in Claude's configuration. Args: - file_spec: Path to the server file, optionally with :object suffix - server_name: Name for the server in Claude's config - with_editable: Optional directory to install in editable mode - with_packages: Optional list of additional packages to install - env_vars: Optional dictionary of environment variables. These are merged with - any existing variables, with new values taking precedence. + file_spec: Path to the server file, optionally with `:object` suffix. + env_vars: Merged with any existing variables, with new values taking precedence. Raises: - RuntimeError: If Claude Desktop's config directory is not found, indicating - Claude Desktop may not be installed or properly set up. + RuntimeError: If Claude Desktop's config directory is not found. """ config_dir = get_claude_config_path() uv_path = get_uv_path() @@ -107,48 +99,39 @@ def update_claude_config( if "mcpServers" not in config: config["mcpServers"] = {} - # Always preserve existing env vars and merge with new ones if server_name in config["mcpServers"] and "env" in config["mcpServers"][server_name]: existing_env = config["mcpServers"][server_name]["env"] if env_vars: - # New vars take precedence over existing ones env_vars = {**existing_env, **env_vars} else: env_vars = existing_env - # Build uv run command args = ["run", "--frozen"] - # Collect all packages in a set to deduplicate packages = {mcp_requirement("mcp[cli]")} if with_packages: packages.update(pkg for pkg in with_packages if pkg) - # Add all packages with --with for pkg in sorted(packages): args.extend(["--with", pkg]) if with_editable: args.extend(["--with-editable", str(with_editable)]) - # Convert file path to absolute before adding to command - # Split off any :object suffix first - # First check if we have a Windows path (e.g., C:\...) + # Resolve to an absolute path, splitting any :object suffix on the last colon + # without mistaking a Windows drive letter (C:\...) for one. has_windows_drive = len(file_spec) > 1 and file_spec[1] == ":" - # Split on the last colon, but only if it's not part of the Windows drive letter if ":" in (file_spec[2:] if has_windows_drive else file_spec): file_path, server_object = file_spec.rsplit(":", 1) file_spec = f"{Path(file_path).resolve()}:{server_object}" else: file_spec = str(Path(file_spec).resolve()) - # Add mcp run command args.extend(["mcp", "run", file_spec]) server_config: dict[str, Any] = {"command": uv_path, "args": args} - # Add environment variables if specified if env_vars: server_config["env"] = env_vars diff --git a/src/mcp/cli/cli.py b/src/mcp/cli/cli.py index eb06bf087a..70bed9c6cc 100644 --- a/src/mcp/cli/cli.py +++ b/src/mcp/cli/cli.py @@ -35,14 +35,13 @@ name="mcp", help="MCP development tools", add_completion=False, - no_args_is_help=True, # Show help if no args provided + no_args_is_help=True, ) def _get_npx_command(): """Get the correct npx command for the current platform.""" if sys.platform == "win32": - # Try both npx.cmd and npx.exe on Windows for cmd in ["npx.cmd", "npx.exe", "npx"]: try: subprocess.run([cmd, "--version"], check=True, capture_output=True, shell=True) @@ -50,7 +49,7 @@ def _get_npx_command(): except subprocess.CalledProcessError: continue return None - return "npx" # On Unix-like systems, just use npx + return "npx" def _parse_env_var(env_var: str) -> tuple[str, str]: # pragma: no cover @@ -80,31 +79,20 @@ def _build_uv_command( if pkg: # pragma: no branch cmd.extend(["--with", pkg]) - # Add mcp run command cmd.extend(["mcp", "run", file_spec]) return cmd def _parse_file_path(file_spec: str) -> tuple[Path, str | None]: - """Parse a file path that may include a server object specification. - - Args: - file_spec: Path to file, optionally with :object suffix - - Returns: - Tuple of (file_path, server_object) - """ - # First check if we have a Windows path (e.g., C:\...) + """Parse a `path` or `path:object` spec into a resolved file path and optional server object name.""" + # Split on the last colon, ignoring a Windows drive-letter colon (e.g. C:\...) has_windows_drive = len(file_spec) > 1 and file_spec[1] == ":" - # Split on the last colon, but only if it's not part of the Windows drive letter - # and there's actually another colon in the string after the drive letter if ":" in (file_spec[2:] if has_windows_drive else file_spec): file_str, server_object = file_spec.rsplit(":", 1) else: file_str, server_object = file_spec, None - # Resolve the file path file_path = Path(file_str).expanduser().resolve() if not file_path.exists(): logger.error(f"File not found: {file_path}") @@ -117,21 +105,12 @@ def _parse_file_path(file_spec: str) -> tuple[Path, str | None]: def _import_server(file: Path, server_object: str | None = None): # pragma: no cover - """Import an MCP server from a file. - - Args: - file: Path to the file - server_object: Optional object name in format "module:object" or just "object" - - Returns: - The server object - """ + """Import an MCP server from a file; server_object may be `object` or `module:object`.""" # Add parent directory to Python path so imports can be resolved file_dir = str(file.parent) if file_dir not in sys.path: sys.path.insert(0, file_dir) - # Import the module spec = importlib.util.spec_from_file_location("server_module", file) if not spec or not spec.loader: logger.error("Could not load module", extra={"file": str(file)}) @@ -141,14 +120,7 @@ def _import_server(file: Path, server_object: str | None = None): # pragma: no spec.loader.exec_module(module) def _check_server_object(server_object: Any, object_name: str): - """Helper function to check that the server object is supported - - Args: - server_object: The server object to check. - - Returns: - True if it's supported. - """ + """Check that the object is a supported MCPServer instance.""" if not isinstance(server_object, MCPServer): logger.error(f"The server object {object_name} is of type {type(server_object)} (expecting {MCPServer}).") if isinstance(server_object, LowLevelServer): @@ -156,9 +128,7 @@ def _check_server_object(server_object: Any, object_name: str): return False return True - # If no object specified, try common server names if not server_object: - # Look for the most common server object names for name in ["mcp", "server", "app"]: if hasattr(module, name): if not _check_server_object(getattr(module, name), f"{file}:{name}"): @@ -177,7 +147,6 @@ def _check_server_object(server_object: Any, object_name: str): ) sys.exit(1) - # Handle module:object syntax if ":" in server_object: module_name, object_name = server_object.split(":", 1) try: @@ -190,7 +159,6 @@ def _check_server_object(server_object: Any, object_name: str): ) sys.exit(1) else: - # Just object name server = getattr(module, server_object, None) if server is None: @@ -256,14 +224,12 @@ def dev( ) try: - # Import server to get dependencies server = _import_server(file, server_object) if hasattr(server, "dependencies"): with_packages = list(set(with_packages + server.dependencies)) uv_cmd = _build_uv_command(file_spec, with_editable, with_packages) - # Get the correct npx command npx_cmd = _get_npx_command() if not npx_cmd: logger.error( @@ -271,13 +237,12 @@ def dev( ) sys.exit(1) - # Run the MCP Inspector command with shell=True on Windows shell = sys.platform == "win32" process = subprocess.run( [npx_cmd, "@modelcontextprotocol/inspector"] + uv_cmd, check=True, shell=shell, - env=dict(os.environ.items()), # Copy the environment for subprocess launch + env=dict(os.environ.items()), ) sys.exit(process.returncode) except subprocess.CalledProcessError as e: @@ -337,10 +302,8 @@ def run( ) try: - # Import and get server object server = _import_server(file, server_object) - # Run the server kwargs = {} if transport: kwargs["transport"] = transport @@ -432,8 +395,6 @@ def install( logger.error("Claude app not found") sys.exit(1) - # Try to import server to get its name, but fall back to file name if dependencies - # missing name = server_name server = None if not name: @@ -447,16 +408,13 @@ def install( ) name = file.stem - # Get server dependencies if available server_dependencies = getattr(server, "dependencies", []) if server else [] if server_dependencies: with_packages = list(set(with_packages + server_dependencies)) - # Process environment variables if provided env_dict: dict[str, str] | None = None if env_file or env_vars: env_dict = {} - # Load from .env file if specified if env_file: if dotenv: try: @@ -468,7 +426,6 @@ def install( logger.error("python-dotenv is not installed. Cannot load .env file.") sys.exit(1) - # Add command line environment variables for env_var in env_vars: key, value = _parse_env_var(env_var) env_dict[key] = value diff --git a/src/mcp/client/__main__.py b/src/mcp/client/__main__.py index 5fa3ce109b..24141c873c 100644 --- a/src/mcp/client/__main__.py +++ b/src/mcp/client/__main__.py @@ -52,11 +52,9 @@ async def main(command_or_url: str, args: list[str], env: list[tuple[str, str]]) env_dict = dict(env) if urlparse(command_or_url).scheme in ("http", "https"): - # Use SSE client for HTTP(S) URLs async with sse_client(command_or_url) as streams: await run_session(*streams) else: - # Use stdio client for commands server_parameters = StdioServerParameters(command=command_or_url, args=args, env=env_dict) async with stdio_client(server_parameters) as streams: await run_session(*streams) diff --git a/src/mcp/client/_input_required.py b/src/mcp/client/_input_required.py index fe3f59e175..35f648f581 100644 --- a/src/mcp/client/_input_required.py +++ b/src/mcp/client/_input_required.py @@ -1,12 +1,8 @@ """SEP-2322 client-side multi-round-trip driver. -When a server returns `InputRequiredResult` instead of the normal result of a -`tools/call` / `prompts/get` / `resources/read`, the client fulfils the -embedded `input_requests` (sampling, elicitation, roots) and retries the -original request carrying the responses and the echoed opaque `request_state`. -This module implements that retry loop as a pure function so it can drive any -of the three methods identically; `Client` builds the `dispatch` and `retry` -closures, `ClientSession` stays mechanics-only. +Fulfils the `input_requests` embedded in an `InputRequiredResult` (sampling, +elicitation, roots) and retries the original `tools/call` / `prompts/get` / +`resources/read` with the responses and the echoed opaque `request_state`. """ from __future__ import annotations @@ -21,11 +17,7 @@ from mcp.shared.exceptions import MCPError DEFAULT_INPUT_REQUIRED_MAX_ROUNDS = 10 -"""Default cap on `InputRequiredResult` retry rounds before the driver gives up. - -Matches the typescript-sdk default; csharp-sdk and go-sdk use the same value -as a hard constant. -""" +"""Default cap on retry rounds; matches the typescript-sdk default (csharp-sdk and go-sdk hardcode the same).""" _STATE_ONLY_BACKOFF_INITIAL_SECONDS = 0.05 """First sleep when an `InputRequiredResult` carries only `request_state` (no input requests).""" @@ -58,21 +50,15 @@ async def run_input_required_driver( ) -> ResultT: """Resolve an `InputRequiredResult` to its terminal result. - Loops until `retry` returns a non-`InputRequiredResult`, or `max_rounds` is - exhausted. Each round either dispatches all `input_requests` concurrently - and retries with the collected responses, or β€” when the server sent only - `request_state` β€” sleeps with exponential backoff (50ms doubling to a 250ms - cap, reset by any leg that carries input requests) and retries empty. - `request_state` is passed through byte-exact and never inspected. + Each round dispatches all `input_requests` concurrently and retries with the + responses; a state-only leg instead sleeps with exponential backoff (reset by + any leg carrying requests) and retries empty. `request_state` is echoed + byte-exact, never inspected. Args: - first: The `InputRequiredResult` the original call returned. - dispatch: Runs one embedded `InputRequest` through the client's - sampling / elicitation / roots callbacks. Called concurrently per - request key. An `ErrorData` return aborts the loop as an `MCPError`. - retry: Re-issues the original request with the collected responses and - the latest `request_state`. Each call mints a fresh JSON-RPC id. - max_rounds: Cap on retry rounds. + dispatch: Fulfils one `InputRequest` via the client's sampling/elicitation/ + roots callbacks; an `ErrorData` return aborts the loop as `MCPError`. + retry: Re-issues the original request; each call mints a fresh JSON-RPC id. Raises: InputRequiredRoundsExceededError: `max_rounds` exhausted. @@ -102,10 +88,8 @@ async def _dispatch_all( ) -> InputResponses: """Run `dispatch` concurrently for every key, raising `MCPError` on the first `ErrorData`. - The first task to return `ErrorData` cancels its siblings via the task - group's cancel scope, so a refused input does not wait on a slow peer. - A callback that *raises* propagates as an `ExceptionGroup` like any other - task-group failure. + The first `ErrorData` cancels its sibling tasks so a refused input does not wait + on a slow peer; a callback that raises propagates as an `ExceptionGroup`. """ responses: InputResponses = {} refused: ErrorData | None = None diff --git a/src/mcp/client/_memory.py b/src/mcp/client/_memory.py index 187131e380..d4195965d5 100644 --- a/src/mcp/client/_memory.py +++ b/src/mcp/client/_memory.py @@ -19,28 +19,15 @@ class InMemoryTransport: - """In-memory transport for testing MCP servers without network overhead. - - This transport starts the server in a background task and provides - streams for client-side communication. The server is automatically - stopped when the context manager exits. - """ + """In-memory transport that runs the server in a background task and stops it on context exit.""" def __init__(self, server: Server[Any] | MCPServer, *, raise_exceptions: bool = False) -> None: - """Initialize the in-memory transport. - - Args: - server: The MCP server to connect to (Server or MCPServer instance) - raise_exceptions: Whether to raise exceptions from the server - """ self._server = server self._raise_exceptions = raise_exceptions self._cm: AbstractAsyncContextManager[TransportStreams] | None = None @asynccontextmanager async def _connect(self) -> AsyncIterator[TransportStreams]: - """Connect to the server and yield streams for communication.""" - # Unwrap MCPServer to get underlying Server if isinstance(self._server, MCPServer): # TODO(Marcelo): Make `lowlevel_server` public. actual_server: Server[Any] = self._server._lowlevel_server # type: ignore[reportPrivateUsage] @@ -70,40 +57,26 @@ async def _run_server() -> None: try: yield client_read, client_write finally: - # EOF the server (and our own read side) instead of - # cancelling outright. The dispatcher's run() cancels its - # own in-flight handlers on read-stream EOF, so for a - # well-behaved server the task exits naturally and the - # task-group join below is immediate. Cancelling here - # unconditionally would `coro.throw()` into this task, - # which on CPython 3.11 (gh-106749) drops `'call'` trace - # events for the outer await chain and desyncs coverage's - # CTracer past the test frame. + # EOF the server instead of cancelling: the dispatcher's run() exits on + # read-stream EOF, while cancelling would `coro.throw()` into this task β€” on + # CPython 3.11 (gh-106749) that drops `'call'` trace events and desyncs coverage's CTracer. await client_write.aclose() await server_write.aclose() - # Backstop: the dispatcher exits on EOF, but the server's - # own teardown (lifespan __aexit__, connection.exit_stack - # callbacks) runs after that and is user code. If it never - # completes the join would hang forever, so bound the wait - # and fall back to cancelling. The healthy path returns - # from wait() without the timeout firing, so the cancel is - # never reached and gh-106749 stays avoided. If the cancel - # does fire, the checkpoint at the end of - # `create_client_server_memory_streams` resyncs the tracer. + # Backstop: server teardown (lifespan __aexit__, exit_stack callbacks) is user code + # and may never finish, so bound the wait before falling back to cancelling. If the + # cancel fires, the checkpoint ending `create_client_server_memory_streams` resyncs the tracer. with anyio.move_on_after(SERVER_SHUTDOWN_GRACE): await server_done.wait() if not server_done.is_set(): tg.cancel_scope.cancel() async def __aenter__(self) -> TransportStreams: - """Connect to the server and return streams for communication.""" self._cm = self._connect() return await self._cm.__aenter__() async def __aexit__( self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None ) -> None: - """Close the transport and stop the server.""" if self._cm is not None: # pragma: no branch await self._cm.__aexit__(exc_type, exc_val, exc_tb) self._cm = None diff --git a/src/mcp/client/_probe.py b/src/mcp/client/_probe.py index 39a5c52964..6aa1b1ed73 100644 --- a/src/mcp/client/_probe.py +++ b/src/mcp/client/_probe.py @@ -1,16 +1,9 @@ -"""Connect-time era negotiation for ``mode='auto'``. +"""Connect-time era negotiation for `mode='auto'`. -The ``server/discover`` probe is sent at the newest modern version. Anything -that is not positive evidence the peer is a modern MCP server falls back to -the legacy ``initialize`` handshake β€” a *denylist* (only the disjoint-modern -case raises) rather than an allowlist of fallback codes. - -Every ``MCPError`` falls back except ``-32022`` with a disjoint modern-only -``supported`` list. The streamable-HTTP transport already maps HTTP-layer -4xx rejections (no JSON-RPC body) into ``MCPError`` codes, so those reach -the same path. Any non-``MCPError`` exception (network/connection errors, -anyio cancellation, the ``RuntimeError`` from ``adopt()`` on no-mutual) -propagates to the caller; an outage or in-process bug is never an era verdict. +Fallback to legacy `initialize` is a denylist: every `MCPError` falls back +except `-32022` with a disjoint modern-only `supported` list. Streamable HTTP +maps HTTP-layer 4xx rejections into `MCPError` codes, so they take the same +path. Non-`MCPError` exceptions propagate β€” an outage is never an era verdict. """ from __future__ import annotations @@ -31,7 +24,7 @@ def _parse_supported(data: Any) -> list[str] | None: - """Pull ``data.supported`` off a -32022 error, or ``None`` if not actionable.""" + """Pull `data.supported` off a -32022 error, or `None` if not actionable.""" try: return types.UnsupportedProtocolVersionErrorData.model_validate(data).supported except ValidationError: @@ -39,18 +32,15 @@ def _parse_supported(data: Any) -> list[str] | None: async def negotiate_auto(session: ClientSession) -> None: - """Drive the ``mode='auto'`` connect-time policy on ``session``. + """Drive the `mode='auto'` connect-time policy on `session`. - Probes ``server/discover`` once (twice if the server names a mutual - modern version via -32022), then either ``adopt()``s the result or falls - back to ``initialize()``. Idempotent only in the sense that one of - ``session.discover_result`` / ``session.initialize_result`` is set on - return. + Probes `server/discover` (retrying once at a mutual modern version on + -32022), then `adopt()`s the result or falls back to `initialize()`; one of + `session.discover_result`/`session.initialize_result` is set on return. Raises: - MCPError: The server is modern-only and shares no version with this - client (-32022 with a disjoint ``supported`` list). - Exception: Any transport/network error from the probe propagates as-is. + MCPError: Server is modern-only with a disjoint `supported` list (-32022). + Exception: Transport/network errors from the probe propagate as-is. """ version = LATEST_MODERN_VERSION for attempt in range(2): @@ -65,10 +55,8 @@ async def negotiate_auto(session: ClientSession) -> None: continue if supported is not None and not any(v in HANDSHAKE_PROTOCOL_VERSIONS for v in supported): raise # server is modern-only and disjoint β€” real incompatibility - await session.initialize() # every other rpc-error β†’ legacy (the denylist) + await session.initialize() # any other MCPError β†’ legacy (the denylist) return - # any other exception (httpx.TransportError, ConnectionError, anyio errors, - # RuntimeError from adopt) β†’ propagate try: result = types.DiscoverResult.model_validate(raw) except ValidationError: diff --git a/src/mcp/client/_transport.py b/src/mcp/client/_transport.py index 0163fef950..c4755fcef5 100644 --- a/src/mcp/client/_transport.py +++ b/src/mcp/client/_transport.py @@ -14,8 +14,4 @@ class Transport(AbstractAsyncContextManager[TransportStreams], Protocol): - """Protocol for MCP transports. - - A transport is an async context manager that yields read and write streams - for bidirectional communication with an MCP server. - """ + """An async context manager yielding read/write streams for bidirectional communication with an MCP server.""" diff --git a/src/mcp/client/auth/__init__.py b/src/mcp/client/auth/__init__.py index 9d00fc700f..ab35b2b348 100644 --- a/src/mcp/client/auth/__init__.py +++ b/src/mcp/client/auth/__init__.py @@ -1,7 +1,4 @@ -"""OAuth2 Authentication implementation for HTTPX. - -Implements authorization code flow with PKCE and automatic token refresh. -""" +"""OAuth2 authentication for HTTPX: authorization code flow with PKCE and automatic token refresh.""" from mcp.client.auth.exceptions import OAuthFlowError, OAuthRegistrationError, OAuthTokenError from mcp.client.auth.oauth2 import ( diff --git a/src/mcp/client/auth/extensions/client_credentials.py b/src/mcp/client/auth/extensions/client_credentials.py index 1daf55c1c5..b161b16231 100644 --- a/src/mcp/client/auth/extensions/client_credentials.py +++ b/src/mcp/client/auth/extensions/client_credentials.py @@ -1,10 +1,7 @@ -"""OAuth client credential extensions for MCP. +"""OAuth client credential extensions for machine-to-machine authentication flows. -Provides OAuth providers for machine-to-machine authentication flows: -- ClientCredentialsOAuthProvider: For client_credentials with client_id + client_secret -- PrivateKeyJWTOAuthProvider: For client_credentials with private_key_jwt authentication - (typically using a pre-built JWT from workload identity federation) -- RFC7523OAuthClientProvider: For jwt-bearer grant (RFC 7523 Section 2.1) +Provides client_credentials providers (client secret and private_key_jwt) and the +deprecated RFC 7523 jwt-bearer grant provider. """ import time @@ -25,18 +22,8 @@ class ClientCredentialsOAuthProvider(OAuthClientProvider): """OAuth provider for client_credentials grant with client_id + client_secret. - This provider sets client_info directly, bypassing dynamic client registration. - Use this when you already have client credentials (client_id and client_secret). - - Example: - ```python - provider = ClientCredentialsOAuthProvider( - server_url="https://api.example.com", - storage=my_token_storage, - client_id="my-client-id", - client_secret="my-client-secret", - ) - ``` + Sets client_info directly, bypassing dynamic client registration; use when you + already have client credentials. """ def __init__( @@ -48,18 +35,11 @@ def __init__( token_endpoint_auth_method: Literal["client_secret_basic", "client_secret_post"] = "client_secret_basic", scopes: str | None = None, ) -> None: - """Initialize client_credentials OAuth provider. + """Initialize the provider. Args: - server_url: The MCP server URL. - storage: Token storage implementation. - client_id: The OAuth client ID. - client_secret: The OAuth client secret. - token_endpoint_auth_method: Authentication method for token endpoint. - Either "client_secret_basic" (default) or "client_secret_post". - scopes: Optional space-separated list of scopes to request. + scopes: Optional space-separated scopes to request. """ - # Build minimal client_metadata for the base class client_metadata = OAuthClientMetadata( redirect_uris=None, grant_types=["client_credentials"], @@ -67,7 +47,7 @@ def __init__( scope=scopes, ) super().__init__(server_url, client_metadata, storage, None, None, 300.0) - # Store client_info to be set during _initialize - no dynamic registration needed + # Applied in _initialize instead of dynamic client registration self._fixed_client_info = OAuthClientInformationFull( redirect_uris=None, client_id=client_id, @@ -84,18 +64,15 @@ async def _initialize(self) -> None: self._initialized = True async def _perform_authorization(self) -> httpx.Request: - """Perform client_credentials authorization.""" return await self._exchange_token_client_credentials() async def _exchange_token_client_credentials(self) -> httpx.Request: - """Build token exchange request for client_credentials grant.""" token_data: dict[str, Any] = { "grant_type": "client_credentials", } headers: dict[str, str] = {"Content-Type": "application/x-www-form-urlencoded"} - # Use standard auth methods (client_secret_basic, client_secret_post, none) token_data, headers = self.context.prepare_token_auth(token_data, headers) if self.context.should_include_resource_param(self.context.protocol_version): @@ -109,26 +86,10 @@ async def _exchange_token_client_credentials(self) -> httpx.Request: def static_assertion_provider(token: str) -> Callable[[str], Awaitable[str]]: - """Create an assertion provider that returns a static JWT token. - - Use this when you have a pre-built JWT (e.g., from workload identity federation) - that doesn't need the audience parameter. - - Example: - ```python - provider = PrivateKeyJWTOAuthProvider( - server_url="https://api.example.com", - storage=my_token_storage, - client_id="my-client-id", - assertion_provider=static_assertion_provider(my_prebuilt_jwt), - ) - ``` + """Create an assertion provider that returns `token` unchanged, ignoring the audience. - Args: - token: The pre-built JWT assertion string. - - Returns: - An async callback suitable for use as an assertion_provider. + Use for a pre-built JWT (e.g. from workload identity federation) that doesn't + depend on the audience parameter. """ async def provider(audience: str) -> str: @@ -140,23 +101,7 @@ async def provider(audience: str) -> str: class SignedJWTParameters(BaseModel): """Parameters for creating SDK-signed JWT assertions. - Use `create_assertion_provider()` to create an assertion provider callback - for use with `PrivateKeyJWTOAuthProvider`. - - Example: - ```python - jwt_params = SignedJWTParameters( - issuer="my-client-id", - subject="my-client-id", - signing_key=private_key_pem, - ) - provider = PrivateKeyJWTOAuthProvider( - server_url="https://api.example.com", - storage=my_token_storage, - client_id="my-client-id", - assertion_provider=jwt_params.create_assertion_provider(), - ) - ``` + `create_assertion_provider()` yields a callback for `PrivateKeyJWTOAuthProvider`. """ issuer: str = Field(description="Issuer for JWT assertions (typically client_id).") @@ -167,11 +112,10 @@ class SignedJWTParameters(BaseModel): additional_claims: dict[str, Any] | None = Field(default=None, description="Additional claims.") def create_assertion_provider(self) -> Callable[[str], Awaitable[str]]: - """Create an assertion provider callback for use with PrivateKeyJWTOAuthProvider. + """Create an assertion provider for `PrivateKeyJWTOAuthProvider`. - Returns: - An async callback that takes the audience (authorization server issuer URL) - and returns a signed JWT assertion. + The returned callback signs a fresh JWT whose audience is the authorization + server issuer URL it receives. """ async def provider(audience: str) -> str: @@ -195,61 +139,14 @@ async def provider(audience: str) -> str: class PrivateKeyJWTOAuthProvider(OAuthClientProvider): """OAuth provider for client_credentials grant with private_key_jwt authentication. - Uses RFC 7523 Section 2.2 for client authentication via JWT assertion. - - The JWT assertion's audience MUST be the authorization server's issuer identifier - (per RFC 7523bis security updates). The `assertion_provider` callback receives - this audience value and must return a JWT with that audience. - - **Option 1: Pre-built JWT via Workload Identity Federation** - - In production scenarios, the JWT assertion is typically obtained from a workload - identity provider (e.g., GCP, AWS IAM, Azure AD): - - ```python - async def get_workload_identity_token(audience: str) -> str: - # Fetch JWT from your identity provider - # The JWT's audience must match the provided audience parameter - return await fetch_token_from_identity_provider(audience=audience) - - provider = PrivateKeyJWTOAuthProvider( - server_url="https://api.example.com", - storage=my_token_storage, - client_id="my-client-id", - assertion_provider=get_workload_identity_token, - ) - ``` - - **Option 2: Static pre-built JWT** - - If you have a static JWT that doesn't need the audience parameter: - - ```python - provider = PrivateKeyJWTOAuthProvider( - server_url="https://api.example.com", - storage=my_token_storage, - client_id="my-client-id", - assertion_provider=static_assertion_provider(my_prebuilt_jwt), - ) - ``` - - **Option 3: SDK-signed JWT (for testing/simple setups)** - - For testing or simple deployments, use `SignedJWTParameters.create_assertion_provider()`: - - ```python - jwt_params = SignedJWTParameters( - issuer="my-client-id", - subject="my-client-id", - signing_key=private_key_pem, - ) - provider = PrivateKeyJWTOAuthProvider( - server_url="https://api.example.com", - storage=my_token_storage, - client_id="my-client-id", - assertion_provider=jwt_params.create_assertion_provider(), - ) - ``` + Client authentication uses a JWT assertion (RFC 7523 Section 2.2). The assertion's + audience MUST be the authorization server's issuer identifier (per RFC 7523bis); + the `assertion_provider` callback receives this audience value and must return a + JWT with that audience. Supply the callback from a workload identity provider + (e.g. GCP, AWS IAM, Azure AD β€” fetch a JWT for the given audience), from + `static_assertion_provider()` for a static pre-built JWT, or from + `SignedJWTParameters.create_assertion_provider()` for SDK-signed JWTs + (testing/simple setups). """ def __init__( @@ -260,20 +157,13 @@ def __init__( assertion_provider: Callable[[str], Awaitable[str]], scopes: str | None = None, ) -> None: - """Initialize private_key_jwt OAuth provider. + """Initialize the provider. Args: - server_url: The MCP server URL. - storage: Token storage implementation. - client_id: The OAuth client ID. assertion_provider: Async callback that takes the audience (authorization - server's issuer identifier) and returns a JWT assertion. Use - `SignedJWTParameters.create_assertion_provider()` for SDK-signed JWTs, - `static_assertion_provider()` for pre-built JWTs, or provide your own - callback for workload identity federation. - scopes: Optional space-separated list of scopes to request. + server's issuer identifier) and returns a JWT assertion. + scopes: Optional space-separated scopes to request. """ - # Build minimal client_metadata for the base class client_metadata = OAuthClientMetadata( redirect_uris=None, grant_types=["client_credentials"], @@ -282,7 +172,7 @@ def __init__( ) super().__init__(server_url, client_metadata, storage, None, None, 300.0) self._assertion_provider = assertion_provider - # Store client_info to be set during _initialize - no dynamic registration needed + # Applied in _initialize instead of dynamic client registration self._fixed_client_info = OAuthClientInformationFull( redirect_uris=None, client_id=client_id, @@ -298,11 +188,9 @@ async def _initialize(self) -> None: self._initialized = True async def _perform_authorization(self) -> httpx.Request: - """Perform client_credentials authorization with private_key_jwt.""" return await self._exchange_token_client_credentials() async def _add_client_authentication_jwt(self, *, token_data: dict[str, Any]) -> None: - """Add JWT assertion for client authentication to token endpoint parameters.""" if not self.context.oauth_metadata: raise OAuthFlowError("Missing OAuth metadata for private_key_jwt flow") # pragma: no cover @@ -316,14 +204,12 @@ async def _add_client_authentication_jwt(self, *, token_data: dict[str, Any]) -> token_data["client_assertion_type"] = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" async def _exchange_token_client_credentials(self) -> httpx.Request: - """Build token exchange request for client_credentials grant with private_key_jwt.""" token_data: dict[str, Any] = { "grant_type": "client_credentials", } headers: dict[str, str] = {"Content-Type": "application/x-www-form-urlencoded"} - # Add JWT client authentication (RFC 7523 Section 2.2) await self._add_client_authentication_jwt(token_data=token_data) if self.context.should_include_resource_param(self.context.protocol_version): @@ -389,15 +275,11 @@ def to_assertion(self, with_audience_fallback: str | None = None) -> str: class RFC7523OAuthClientProvider(OAuthClientProvider): - """OAuth client provider for RFC 7523 jwt-bearer grant. - - .. deprecated:: - Use :class:`ClientCredentialsOAuthProvider` for client_credentials with - client_id + client_secret, or :class:`PrivateKeyJWTOAuthProvider` for - client_credentials with private_key_jwt authentication instead. + """OAuth client provider for the jwt-bearer authorization grant (RFC 7523 Section 2.1). - This provider supports the jwt-bearer authorization grant (RFC 7523 Section 2.1) - where the JWT itself is the authorization grant. + The JWT itself is the authorization grant. Deprecated: use + `ClientCredentialsOAuthProvider` (client_id + client_secret) or + `PrivateKeyJWTOAuthProvider` (private_key_jwt) instead. """ def __init__( @@ -422,14 +304,12 @@ def __init__( async def _exchange_token_authorization_code( self, auth_code: str, code_verifier: str, *, token_data: dict[str, Any] | None = None ) -> httpx.Request: # pragma: no cover - """Build token exchange request for authorization_code flow.""" token_data = token_data or {} if self.context.client_metadata.token_endpoint_auth_method == "private_key_jwt": self._add_client_authentication_jwt(token_data=token_data) return await super()._exchange_token_authorization_code(auth_code, code_verifier, token_data=token_data) async def _perform_authorization(self) -> httpx.Request: # pragma: no cover - """Perform the authorization flow.""" if "urn:ietf:params:oauth:grant-type:jwt-bearer" in self.context.client_metadata.grant_types: token_request = await self._exchange_token_jwt_bearer() return token_request @@ -437,26 +317,23 @@ async def _perform_authorization(self) -> httpx.Request: # pragma: no cover return await super()._perform_authorization() def _add_client_authentication_jwt(self, *, token_data: dict[str, Any]): # pragma: no cover - """Add JWT assertion for client authentication to token endpoint parameters.""" if not self.jwt_parameters: raise OAuthTokenError("Missing JWT parameters for private_key_jwt flow") if not self.context.oauth_metadata: raise OAuthTokenError("Missing OAuth metadata for private_key_jwt flow") - # We need to set the audience to the issuer identifier of the authorization server + # JWT audience MUST be the issuer identifier of the authorization server # https://datatracker.ietf.org/doc/html/draft-ietf-oauth-rfc7523bis-01#name-updates-to-rfc-7523 issuer = str(self.context.oauth_metadata.issuer) assertion = self.jwt_parameters.to_assertion(with_audience_fallback=issuer) - # When using private_key_jwt, in a client_credentials flow, we use RFC 7523 Section 2.2 + # RFC 7523 Section 2.2: client authentication via JWT token_data["client_assertion"] = assertion token_data["client_assertion_type"] = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" - # We need to set the audience to the resource server, the audience is different from the one in claims - # it represents the resource server that will validate the token + # Unlike the JWT's aud claim, this audience is the resource server that will validate the token token_data["audience"] = self.context.get_resource_url() async def _exchange_token_jwt_bearer(self) -> httpx.Request: - """Build token exchange request for JWT bearer grant.""" if not self.context.client_info: raise OAuthFlowError("Missing client info") # pragma: no cover if not self.jwt_parameters: @@ -464,7 +341,7 @@ async def _exchange_token_jwt_bearer(self) -> httpx.Request: if not self.context.oauth_metadata: raise OAuthTokenError("Missing OAuth metadata") # pragma: no cover - # We need to set the audience to the issuer identifier of the authorization server + # JWT audience MUST be the issuer identifier of the authorization server # https://datatracker.ietf.org/doc/html/draft-ietf-oauth-rfc7523bis-01#name-updates-to-rfc-7523 issuer = str(self.context.oauth_metadata.issuer) assertion = self.jwt_parameters.to_assertion(with_audience_fallback=issuer) diff --git a/src/mcp/client/auth/extensions/identity_assertion.py b/src/mcp/client/auth/extensions/identity_assertion.py index 2d97e5eff2..26669146a2 100644 --- a/src/mcp/client/auth/extensions/identity_assertion.py +++ b/src/mcp/client/auth/extensions/identity_assertion.py @@ -1,22 +1,7 @@ -"""SEP-990 Identity Assertion Authorization Grant (RFC 7523 jwt-bearer) client provider. - -`IdentityAssertionOAuthProvider` is the client side of SEP-990 leg 2: it presents an Identity -Assertion Authorization Grant (ID-JAG) - a signed JWT issued by the enterprise identity provider - -to the MCP authorization server's token endpoint using the RFC 7523 jwt-bearer grant -(`grant_type=urn:ietf:params:oauth:grant-type:jwt-bearer`, ID-JAG as `assertion`), and receives an -MCP access token. - -The authorization server is configuration, not discovery. SEP-990's trust model is the inverse of -the default OAuth client's: the AS issuer is supplied at construction, authorization-server metadata -is fetched from that issuer's own RFC 8414 well-known, and the resource server is never asked which -AS to use - so it cannot redirect the ID-JAG or client secret elsewhere. There is no protected -resource metadata fetch, no dynamic client registration, and no server-driven scope selection. - -Obtaining the ID-JAG (logging into the IdP and the leg-1 token exchange against it) is -deployment-specific and out of scope for the SDK. The caller supplies it through the -`assertion_provider` callback, which receives the configured issuer (the `aud` the ID-JAG must -carry) and the MCP server's resource identifier (the `resource` claim it must carry, per ext-auth -section 4.3), and returns the ID-JAG. +"""SEP-990 Identity Assertion Authorization Grant (ID-JAG) client provider. + +The client side of SEP-990 leg 2: exchange an enterprise-IdP-issued ID-JAG for an MCP access token +via the RFC 7523 jwt-bearer grant at a statically configured authorization server. """ import base64 @@ -46,11 +31,7 @@ def _origin(url: str) -> tuple[str, str, int | None]: - """Return the (scheme, host, port) origin of a URL for same-origin comparison. - - The port is normalized to the scheme's default so an explicit `:443`/`:80` compares equal to the - same origin written without a port. - """ + """Return a URL's (scheme, host, port) origin for comparison; a missing port becomes the scheme's default.""" parsed = urlsplit(url) port = parsed.port if parsed.port is not None else _DEFAULT_PORTS.get(parsed.scheme) return (parsed.scheme, parsed.hostname or "", port) @@ -59,17 +40,16 @@ def _origin(url: str) -> tuple[str, str, int | None]: class IdentityAssertionOAuthProvider(httpx.Auth): """`httpx.Auth` for the SEP-990 ID-JAG flow (RFC 7523 jwt-bearer grant) against a configured AS. - The authorization server `issuer` is fixed at construction; metadata is fetched from its - RFC 8414 well-known and the ID-JAG and client secret are sent only to that issuer's token - endpoint. The resource server is never consulted for AS selection. The ID-JAG is fetched lazily - from `assertion_provider` so a fresh assertion is used on each exchange. + The AS `issuer` is fixed at construction; metadata comes from its RFC 8414 well-known, and the + ID-JAG and client secret are sent only to its token endpoint. The resource server is never asked + which AS to use, so it cannot redirect them elsewhere - there is no protected-resource metadata + fetch, dynamic client registration, or server-driven scope selection. The ID-JAG is fetched + lazily from `assertion_provider` so each exchange uses a fresh assertion. Example: ```python async def fetch_id_jag(audience: str, resource: str) -> str: - # `audience` is the configured issuer (the ID-JAG `aud`); `resource` is the MCP - # server's identifier (the ID-JAG `resource` claim). Obtaining the ID-JAG from the - # enterprise IdP is deployment-specific and not handled by the SDK. + # Obtaining the ID-JAG from the enterprise IdP is deployment-specific; the SDK does not handle it. return await my_idp.issue_id_jag(audience=audience, resource=resource) @@ -100,18 +80,10 @@ def __init__( """Initialize the identity-assertion OAuth provider. Args: - server_url: The MCP server URL. - storage: Token storage implementation. - client_id: The OAuth client ID registered with the MCP authorization server. - client_secret: The client secret. SEP-990 section 5.1 requires a confidential client. - issuer: The issuer identifier of the MCP authorization server this client is provisioned - for. Authorization-server metadata is fetched from this issuer's well-known and the - ID-JAG and secret are sent only to its token endpoint. - assertion_provider: Async callback taking `(audience, resource)` - the configured issuer - and the MCP server's resource identifier - and returning the ID-JAG. - scope: Optional space-separated list of scopes to request. - token_endpoint_auth_method: Confidential-client auth method, either `client_secret_post` - (default) or `client_secret_basic`. + client_secret: Required; SEP-990 section 5.1 mandates a confidential client. + assertion_provider: Async callback `(audience, resource) -> ID-JAG`: `audience` is the + configured issuer (the ID-JAG `aud`), `resource` the MCP server's identifier (its + `resource` claim, per ext-auth section 4.3). """ if not client_secret: raise ValueError("client_secret is required: SEP-990 mandates a confidential client") diff --git a/src/mcp/client/auth/oauth2.py b/src/mcp/client/auth/oauth2.py index 711848d724..66a040feb4 100644 --- a/src/mcp/client/auth/oauth2.py +++ b/src/mcp/client/auth/oauth2.py @@ -1,7 +1,4 @@ -"""OAuth2 Authentication implementation for HTTPX. - -Implements authorization code flow with PKCE and automatic token refresh. -""" +"""OAuth2 authentication for HTTPX: authorization code flow with PKCE and automatic token refresh.""" import base64 import hashlib @@ -106,20 +103,16 @@ class OAuthContext: timeout: float = 300.0 client_metadata_url: str | None = None - # Discovered metadata protected_resource_metadata: ProtectedResourceMetadata | None = None oauth_metadata: OAuthMetadata | None = None auth_server_url: str | None = None protocol_version: str | None = None - # Client registration client_info: OAuthClientInformationFull | None = None - # Token management current_tokens: OAuthToken | None = None token_expiry_time: float | None = None - # State lock: anyio.Lock = field(default_factory=anyio.Lock) def get_authorization_base_url(self, server_url: str) -> str: @@ -128,11 +121,9 @@ def get_authorization_base_url(self, server_url: str) -> str: return f"{parsed.scheme}://{parsed.netloc}" def update_token_expiry(self, token: OAuthToken) -> None: - """Update token expiry time using shared util function.""" self.token_expiry_time = calculate_token_expiry(token.expires_in) def is_token_valid(self) -> bool: - """Check if current token is valid.""" return bool( self.current_tokens and self.current_tokens.access_token @@ -140,22 +131,16 @@ def is_token_valid(self) -> bool: ) def can_refresh_token(self) -> bool: - """Check if token can be refreshed.""" return bool(self.current_tokens and self.current_tokens.refresh_token and self.client_info) def clear_tokens(self) -> None: - """Clear current tokens.""" self.current_tokens = None self.token_expiry_time = None def get_resource_url(self) -> str: - """Get resource URL for RFC 8707. - - Uses PRM resource if it's a valid parent, otherwise uses canonical server URL. - """ + """Get the RFC 8707 resource URL, preferring the PRM resource when it's a valid parent.""" resource = resource_url_from_server_url(self.server_url) - # If PRM provides a resource that's a valid parent, use it if self.protected_resource_metadata and self.protected_resource_metadata.resource: prm_resource = str(self.protected_resource_metadata.resource) if check_resource_allowed(requested_resource=resource, configured_resource=prm_resource): @@ -164,17 +149,10 @@ def get_resource_url(self) -> str: return resource def should_include_resource_param(self, protocol_version: str | None = None) -> bool: - """Determine if the resource parameter should be included in OAuth requests. - - Returns True if: - - Protected resource metadata is available, OR - - MCP-Protocol-Version header is 2025-06-18 or later - """ - # If we have protected resource metadata, include the resource param + """True when PRM is available or the protocol version is 2025-06-18 or later (RFC 8707).""" if self.protected_resource_metadata is not None: return True - # If no protocol version provided, don't include resource param if not protocol_version: return False @@ -183,15 +161,7 @@ def should_include_resource_param(self, protocol_version: str | None = None) -> def prepare_token_auth( self, data: dict[str, str], headers: dict[str, str] | None = None ) -> tuple[dict[str, str], dict[str, str]]: - """Prepare authentication for token requests. - - Args: - data: The form data to send - headers: Optional headers dict to update - - Returns: - Tuple of (updated_data, updated_headers) - """ + """Apply the client's token endpoint auth method; returns (updated_data, updated_headers).""" if headers is None: headers = {} # pragma: no cover @@ -210,19 +180,14 @@ def prepare_token_auth( # Don't include client_secret in body for basic auth data = {k: v for k, v in data.items() if k != "client_secret"} elif auth_method == "client_secret_post" and self.client_info.client_id and self.client_info.client_secret: - # Include client_id and client_secret in request body (RFC 6749 Β§2.3.1) data["client_id"] = self.client_info.client_id data["client_secret"] = self.client_info.client_secret - # For auth_method == "none", don't add any client_secret return data, headers class OAuthClientProvider(httpx.Auth): - """OAuth2 authentication for httpx. - - Handles OAuth flow with automatic client registration and token storage. - """ + """OAuth2 authentication for httpx with automatic client registration and token storage.""" requires_response_body = True @@ -240,26 +205,17 @@ def __init__( """Initialize OAuth2 authentication. Args: - server_url: The MCP server URL. - client_metadata: OAuth client metadata for registration. - storage: Token storage implementation. - redirect_handler: Handler for authorization redirects. - callback_handler: Handler for authorization callbacks. - timeout: Timeout for the OAuth flow. - client_metadata_url: URL-based client ID. When provided and the server - advertises client_id_metadata_document_supported=True, this URL will be - used as the client_id instead of performing dynamic client registration. - Must be a valid HTTPS URL with a non-root pathname. - validate_resource_url: Optional callback to override resource URL validation. - Called with (server_url, prm_resource) where prm_resource is the resource - from Protected Resource Metadata (or None if not present). If not provided, - default validation rejects mismatched resources per RFC 8707. + client_metadata_url: URL-based client ID (CIMD). When the server advertises + client_id_metadata_document_supported=True, this URL is used as the client_id + instead of dynamic client registration. Must be a valid HTTPS URL with a + non-root pathname. + validate_resource_url: Optional override for resource URL validation, called with + (server_url, prm_resource) where prm_resource may be None. Default validation + rejects mismatched resources per RFC 8707. Raises: - ValueError: If client_metadata_url is provided but not a valid HTTPS URL - with a non-root pathname. + ValueError: If client_metadata_url is not a valid HTTPS URL with a non-root pathname. """ - # Validate client_metadata_url if provided if client_metadata_url is not None and not is_valid_client_metadata_url(client_metadata_url): raise ValueError( f"client_metadata_url must be a valid HTTPS URL with a non-root pathname, got: {client_metadata_url}" @@ -278,13 +234,7 @@ def __init__( self._initialized = False async def _handle_protected_resource_response(self, response: httpx.Response) -> bool: - """Handle protected resource metadata discovery response. - - Per SEP-985, supports fallback when discovery fails at one URL. - - Returns: - True if metadata was successfully discovered, False if we should try next URL - """ + """Handle PRM discovery response; False means try the next URL (SEP-985 fallback).""" if response.status_code == 200: try: content = await response.aread() @@ -295,21 +245,17 @@ async def _handle_protected_resource_response(self, response: httpx.Response) -> return True except ValidationError: # pragma: no cover - # Invalid metadata - try next URL logger.warning(f"Invalid protected resource metadata at {response.request.url}") return False elif response.status_code == 404: # pragma: no cover - # Not found - try next URL in fallback chain logger.debug(f"Protected resource metadata not found at {response.request.url}, trying next URL") return False else: - # Other error - fail immediately raise OAuthFlowError( f"Protected Resource Metadata request failed: {response.status_code}" ) # pragma: no cover async def _perform_authorization(self) -> httpx.Request: - """Perform the authorization flow.""" auth_code, code_verifier = await self._perform_authorization_code_grant() token_request = await self._exchange_token_authorization_code(auth_code, code_verifier) return token_request @@ -332,7 +278,6 @@ async def _perform_authorization_code_grant(self) -> tuple[str, str]: if not self.context.client_info: raise OAuthFlowError("No client info available for authorization") # pragma: no cover - # Generate PKCE parameters pkce_params = PKCEParameters.generate() state = secrets.token_urlsafe(32) @@ -345,7 +290,6 @@ async def _perform_authorization_code_grant(self) -> tuple[str, str]: "code_challenge_method": "S256", } - # Only include resource param if conditions are met if self.context.should_include_resource_param(self.context.protocol_version): auth_params["resource"] = self.context.get_resource_url() # RFC 8707 @@ -360,7 +304,6 @@ async def _perform_authorization_code_grant(self) -> tuple[str, str]: authorization_url = f"{auth_endpoint}?{urlencode(auth_params)}" await self.context.redirect_handler(authorization_url) - # Wait for callback result = await self.context.callback_handler() if result.state is None or not secrets.compare_digest(result.state, state): @@ -372,7 +315,6 @@ async def _perform_authorization_code_grant(self) -> tuple[str, str]: if not result.code: raise OAuthFlowError("No authorization code received") - # Return auth code and code verifier for token exchange return result.code, pkce_params.code_verifier def _get_token_endpoint(self) -> str: @@ -404,34 +346,28 @@ async def _exchange_token_authorization_code( } ) - # Only include resource param if conditions are met if self.context.should_include_resource_param(self.context.protocol_version): token_data["resource"] = self.context.get_resource_url() # RFC 8707 - # Prepare authentication based on preferred method headers = {"Content-Type": "application/x-www-form-urlencoded"} token_data, headers = self.context.prepare_token_auth(token_data, headers) return httpx.Request("POST", token_url, data=token_data, headers=headers) async def _handle_token_response(self, response: httpx.Response) -> None: - """Handle token exchange response.""" if response.status_code not in {200, 201}: body = await response.aread() body_text = body.decode("utf-8") raise OAuthTokenError(f"Token exchange failed ({response.status_code}): {body_text}") - # Parse and validate response with scope validation token_response = await handle_token_response_scopes(response) - # RFC 6749 Β§5.1: an omitted scope means the granted scope equals the requested - # scope. Record it explicitly so the persisted token is self-describing β€” the - # SEP-2350 step-up union reads it after a restart, when client_metadata.scope - # has reverted to its constructor value. + # RFC 6749 Β§5.1: omitted scope means granted == requested. Record it so the persisted + # token stays self-describing for the SEP-2350 step-up union after a restart, when + # client_metadata.scope has reverted to its constructor value. if token_response.scope is None: token_response.scope = self.context.client_metadata.scope - # Store tokens in context self.context.current_tokens = token_response self.context.update_token_expiry(token_response) await self.context.storage.set_tokens(token_response) @@ -456,11 +392,9 @@ async def _refresh_token(self) -> httpx.Request: "client_id": self.context.client_info.client_id, } - # Only include resource param if conditions are met if self.context.should_include_resource_param(self.context.protocol_version): refresh_data["resource"] = self.context.get_resource_url() # RFC 8707 - # Prepare authentication based on preferred method headers = {"Content-Type": "application/x-www-form-urlencoded"} refresh_data, headers = self.context.prepare_token_auth(refresh_data, headers) @@ -477,10 +411,9 @@ async def _handle_refresh_response(self, response: httpx.Response) -> bool: content = await response.aread() token_response = OAuthToken.model_validate_json(content) - # RFC 6749 Β§6: a refresh response may omit scope (unchanged) and refresh_token - # (the AS does not rotate). Carry both forward so the persisted token stays - # self-describing for the SEP-2350 step-up union and the next expiry can - # still refresh instead of forcing a full re-authorization. + # RFC 6749 Β§6: a refresh response may omit scope (unchanged) and refresh_token (not + # rotated). Carry both forward so the persisted token stays self-describing for the + # SEP-2350 step-up union and the next expiry can refresh instead of fully re-authorizing. prior = self.context.current_tokens if token_response.scope is None and prior is not None: token_response.scope = prior.scope @@ -504,7 +437,6 @@ async def _initialize(self) -> None: self._initialized = True def _add_auth_header(self, request: httpx.Request) -> None: - """Add authorization header to request if we have valid tokens.""" if self.context.current_tokens and self.context.current_tokens.access_token: # pragma: no branch request.headers["Authorization"] = f"Bearer {self.context.current_tokens.access_token}" @@ -533,16 +465,14 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. if not self._initialized: await self._initialize() - # Capture protocol version from request headers self.context.protocol_version = request.headers.get(MCP_PROTOCOL_VERSION_HEADER) if not self.context.is_token_valid() and self.context.can_refresh_token(): - # Try to refresh token refresh_request = await self._refresh_token() refresh_response = yield refresh_request if not await self._handle_refresh_response(refresh_response): - # Refresh failed, need full re-authentication + # Refresh failed: force a full re-authentication self._initialized = False if self.context.is_token_valid(): @@ -551,12 +481,11 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. response = yield request if response.status_code == 401: - # Perform full OAuth flow + # Full OAuth flow, written inline due to generator constraints try: - # OAuth flow must be inline due to generator constraints www_auth_resource_metadata_url = extract_resource_metadata_from_www_auth(response) - # Step 1: Discover protected resource metadata (SEP-985 with fallback support) + # Step 1: discover protected resource metadata (SEP-985 fallback chain) prm_discovery_urls = build_protected_resource_metadata_discovery_urls( www_auth_resource_metadata_url, self.context.server_url ) @@ -564,11 +493,10 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. for url in prm_discovery_urls: # pragma: no branch discovery_request = create_oauth_metadata_request(url) - discovery_response = yield discovery_request # sending request + discovery_response = yield discovery_request prm = await handle_protected_resource_response(discovery_response) if prm: - # Validate PRM resource matches server URL (RFC 8707) await self._validate_resource_match(prm) self.context.protected_resource_metadata = prm @@ -582,9 +510,8 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. else: logger.debug(f"Protected resource metadata discovery failed: {url}") - # SEP-2352: stored credentials are bound to the issuer that registered them. - # If the authorization server changed, drop them (and the old tokens) so the - # flow re-registers instead of presenting another server's credentials. + # SEP-2352: stored credentials are bound to their registering issuer. If the + # authorization server changed, drop them (and the old tokens) so the flow re-registers. if ( self.context.client_info is not None and self.context.auth_server_url is not None @@ -603,7 +530,7 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. self.context.auth_server_url, self.context.server_url ) - # Step 2: Discover OAuth Authorization Server Metadata (OASM) (with fallback for legacy servers) + # Step 2: discover Authorization Server Metadata (with fallback for legacy servers) for url in asm_discovery_urls: # pragma: no branch oauth_metadata_request = create_oauth_metadata_request(url) oauth_metadata_response = yield oauth_metadata_request @@ -620,9 +547,8 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. else: logger.debug(f"OAuth metadata discovery failed: {url}") - # SEP-2352: on the legacy no-PRM path the issuer is only known after ASM - # discovery, so re-evaluate the binding here using the discovered metadata - # issuer (mirroring the bound_issuer fallback in Step 4). + # SEP-2352: on the legacy no-PRM path the issuer is only known after ASM discovery, + # so re-evaluate the binding using the discovered metadata issuer (mirrors Step 4). if ( self.context.client_info is not None and self.context.auth_server_url is None @@ -637,7 +563,7 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. self.context.client_info = None self.context.clear_tokens() - # Step 3: Apply scope selection strategy + # Step 3: apply scope selection strategy self.context.client_metadata.scope = get_client_metadata_scopes( extract_scope_from_www_auth(response), self.context.protected_resource_metadata, @@ -645,7 +571,7 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. self.context.client_metadata.grant_types, ) - # Step 4: Register client or use URL-based client ID (CIMD) + # Step 4: register client or use URL-based client ID (CIMD) if not self.context.client_info: # SEP-2352: the issuer to bind these credentials to, when known. discovered_issuer: str | None = None @@ -655,8 +581,8 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. if should_use_client_metadata_url( self.context.oauth_metadata, self.context.client_metadata_url ): - # Use URL-based client ID (CIMD). CIMD records are portable across - # authorization servers, so the issuer stamp is informational. + # CIMD records are portable across authorization servers, so the + # issuer stamp is informational. logger.debug(f"Using URL-based client ID (CIMD): {self.context.client_metadata_url}") client_information = create_client_info_from_metadata_url( self.context.client_metadata_url, # type: ignore[arg-type] @@ -666,19 +592,15 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. self.context.client_info = client_information await self.context.storage.set_client_info(client_information) else: - # Fallback to Dynamic Client Registration fallback_base = self.context.get_authorization_base_url(self.context.server_url) registration_request = create_client_registration_request( self.context.oauth_metadata, self.context.client_metadata, fallback_base ) registration_response = yield registration_request client_information = await handle_registration_response(registration_response) - # Only record the issuer when the registration above actually targeted - # the discovered AS β€” either via its published registration_endpoint, - # or because the resource-origin /register fallback is on the issuer's - # own host (legacy same-origin embedded AS). Otherwise the fallback hit - # a different server and recording a binding to the PRM-advertised AS - # would persist a binding that was never established. + # Stamp the issuer only when registration actually targeted the discovered + # AS β€” via its published registration_endpoint, or a same-origin /register + # fallback. Otherwise the binding was never established with that issuer. if ( self.context.oauth_metadata is not None and discovered_issuer is not None @@ -691,7 +613,7 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. self.context.client_info = client_information await self.context.storage.set_client_info(client_information) - # Step 5: Perform authorization and complete token exchange + # Step 5: perform authorization and complete token exchange token_response = yield await self._perform_authorization() await self._handle_token_response(token_response) except Exception: @@ -702,16 +624,13 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. self._add_auth_header(request) yield request elif response.status_code == 403: - # Step 1: Extract error field from WWW-Authenticate header error = extract_field_from_www_auth(response, "error") - # Step 2: Check if we need to step-up authorization if error == "insufficient_scope": # pragma: no branch try: - # Step 2a: Union previously requested scopes with the newly challenged - # scopes (SEP-2350) so escalating one operation keeps the others' grants. - # Fold in the stored token's scope too: on a restart the token is reloaded - # but client_metadata.scope is not, so it would otherwise be the only basis. + # SEP-2350 step-up: union previously requested scopes with the newly challenged + # ones so escalating one operation keeps the others' grants. Fold in the stored + # token's scope too: on a restart the token is reloaded but client_metadata.scope is not. challenged_scope = get_client_metadata_scopes( extract_scope_from_www_auth(response), self.context.protected_resource_metadata, @@ -722,7 +641,6 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. prior_scope = union_scopes(self.context.client_metadata.scope, granted_scope) self.context.client_metadata.scope = union_scopes(prior_scope, challenged_scope) - # Step 2b: Perform (re-)authorization and token exchange token_response = yield await self._perform_authorization() await self._handle_token_response(token_response) except Exception: # pragma: no cover diff --git a/src/mcp/client/auth/utils.py b/src/mcp/client/auth/utils.py index d6b05e0667..4b656399ab 100644 --- a/src/mcp/client/auth/utils.py +++ b/src/mcp/client/auth/utils.py @@ -17,41 +17,28 @@ def extract_field_from_www_auth(response: Response, field_name: str) -> str | None: - """Extract field from WWW-Authenticate header. - - Returns: - Field value if found in WWW-Authenticate header, None otherwise - """ + """Extract a field value from the WWW-Authenticate header, or None if absent.""" www_auth_header = response.headers.get("WWW-Authenticate") if not www_auth_header: return None - # Pattern matches: field_name="value" or field_name=value (unquoted) + # Matches field_name="value" or field_name=value (unquoted) pattern = rf'{field_name}=(?:"([^"]+)"|([^\s,]+))' match = re.search(pattern, www_auth_header) if match: - # Return quoted value if present, otherwise unquoted value return match.group(1) or match.group(2) return None def extract_scope_from_www_auth(response: Response) -> str | None: - """Extract scope parameter from WWW-Authenticate header as per RFC 6750. - - Returns: - Scope string if found in WWW-Authenticate header, None otherwise - """ + """Extract the scope parameter from the WWW-Authenticate header (RFC 6750).""" return extract_field_from_www_auth(response, "scope") def extract_resource_metadata_from_www_auth(response: Response) -> str | None: - """Extract protected resource metadata URL from WWW-Authenticate header as per RFC 9728. - - Returns: - Resource metadata URL if found in WWW-Authenticate header, None otherwise - """ + """Extract the protected resource metadata URL from the WWW-Authenticate header (RFC 9728).""" if not response or response.status_code != 401: return None # pragma: no cover @@ -59,36 +46,23 @@ def extract_resource_metadata_from_www_auth(response: Response) -> str | None: def build_protected_resource_metadata_discovery_urls(www_auth_url: str | None, server_url: str) -> list[str]: - """Build ordered list of URLs to try for protected resource metadata discovery. - - Per SEP-985, the client MUST: - 1. Try resource_metadata from WWW-Authenticate header (if present) - 2. Fall back to path-based well-known URI: /.well-known/oauth-protected-resource/{path} - 3. Fall back to root-based well-known URI: /.well-known/oauth-protected-resource + """Build the ordered list of URLs to try for protected resource metadata discovery. - Args: - www_auth_url: Optional resource_metadata URL extracted from the WWW-Authenticate header - server_url: Server URL - - Returns: - Ordered list of URLs to try for discovery + Per SEP-985: the WWW-Authenticate `resource_metadata` URL first (if present), then the + path-based well-known URI, then the root-based well-known URI (RFC 9728). """ urls: list[str] = [] - # Priority 1: WWW-Authenticate header with resource_metadata parameter if www_auth_url: urls.append(www_auth_url) - # Priority 2-3: Well-known URIs (RFC 9728) parsed = urlparse(server_url) base_url = f"{parsed.scheme}://{parsed.netloc}" - # Priority 2: Path-based well-known URI (if server has a path component) if parsed.path and parsed.path != "/": path_based_url = urljoin(base_url, f"/.well-known/oauth-protected-resource{parsed.path}") urls.append(path_based_url) - # Priority 3: Root-based well-known URI root_based_url = urljoin(base_url, "/.well-known/oauth-protected-resource") urls.append(root_based_url) @@ -104,11 +78,7 @@ def get_client_metadata_scopes( """Select effective scopes and augment for refresh token support.""" selected_scope: str | None = None - # MCP spec scope selection priority: - # 1. WWW-Authenticate header scope - # 2. PRM scopes_supported - # 3. AS scopes_supported (SDK fallback) - # 4. Omit scope parameter + # MCP spec scope priority: WWW-Authenticate scope > PRM scopes_supported > AS scopes_supported > omit if www_authenticate_scope is not None: selected_scope = www_authenticate_scope elif protected_resource_metadata is not None and protected_resource_metadata.scopes_supported is not None: @@ -135,9 +105,8 @@ def union_scopes(previous_scope: str | None, new_scope: str | None) -> str | Non """Merge two space-delimited scope strings, preserving order and dropping duplicates. SEP-2350: on step-up re-authorization the client requests the union of previously requested - scopes and the newly challenged scopes, so escalating one operation does not drop the - permissions granted for another. Previously requested scopes come first; new scopes are - appended in order. + and newly challenged scopes, so escalating one operation does not drop permissions granted + for another. """ if not previous_scope: return new_scope @@ -154,16 +123,10 @@ def union_scopes(previous_scope: str | None, new_scope: str | None) -> str | Non def build_oauth_authorization_server_metadata_discovery_urls(auth_server_url: str | None, server_url: str) -> list[str]: - """Generate an ordered list of URLs for authorization server metadata discovery. - - Args: - auth_server_url: OAuth Authorization Server Metadata URL if found, otherwise None - server_url: URL for the MCP server, used as a fallback if auth_server_url is None - """ + """Generate an ordered list of URLs for authorization server metadata discovery.""" if not auth_server_url: - # Legacy path using the 2025-03-26 spec: - # link: https://modelcontextprotocol.io/specification/2025-03-26/basic/authorization + # Legacy 2025-03-26 spec path: https://modelcontextprotocol.io/specification/2025-03-26/basic/authorization parsed = urlparse(server_url) return [f"{parsed.scheme}://{parsed.netloc}/.well-known/oauth-authorization-server"] @@ -171,26 +134,23 @@ def build_oauth_authorization_server_metadata_discovery_urls(auth_server_url: st parsed = urlparse(auth_server_url) base_url = f"{parsed.scheme}://{parsed.netloc}" - # RFC 8414: Path-aware OAuth discovery + # RFC 8414: path-aware OAuth discovery if parsed.path and parsed.path != "/": oauth_path = f"/.well-known/oauth-authorization-server{parsed.path.rstrip('/')}" urls.append(urljoin(base_url, oauth_path)) - # RFC 8414 section 5: Path-aware OIDC discovery - # See https://www.rfc-editor.org/rfc/rfc8414.html#section-5 + # RFC 8414 section 5: path-aware OIDC discovery oidc_path = f"/.well-known/openid-configuration{parsed.path.rstrip('/')}" urls.append(urljoin(base_url, oidc_path)) - # https://openid.net/specs/openid-connect-discovery-1_0.html + # OIDC discovery 1.0: well-known suffix appended after the path oidc_path = f"{parsed.path.rstrip('/')}/.well-known/openid-configuration" urls.append(urljoin(base_url, oidc_path)) return urls - # OAuth root urls.append(urljoin(base_url, "/.well-known/oauth-authorization-server")) - # OIDC 1.0 fallback (appends to full URL per OIDC spec) - # https://openid.net/specs/openid-connect-discovery-1_0.html + # OIDC 1.0 fallback (https://openid.net/specs/openid-connect-discovery-1_0.html) urls.append(urljoin(base_url, "/.well-known/openid-configuration")) return urls @@ -199,12 +159,9 @@ def build_oauth_authorization_server_metadata_discovery_urls(auth_server_url: st async def handle_protected_resource_response( response: Response, ) -> ProtectedResourceMetadata | None: - """Handle protected resource metadata discovery response. + """Parse a protected resource metadata discovery response. - Per SEP-985, supports fallback when discovery fails at one URL. - - Returns: - ProtectedResourceMetadata if successfully discovered, None if we should try next URL + Returns None when discovery failed at this URL and the next one should be tried (SEP-985). """ if response.status_code == 200: try: @@ -213,10 +170,8 @@ async def handle_protected_resource_response( return metadata except ValidationError: # pragma: no cover - # Invalid metadata - try next URL return None else: - # Not found - try next URL in fallback chain return None @@ -236,15 +191,12 @@ async def handle_auth_metadata_response(response: Response) -> tuple[bool, OAuth def validate_authorization_response_iss(iss: str | None, oauth_metadata: OAuthMetadata | None) -> None: """Validate the RFC 9207 `iss` authorization-response parameter. - Per RFC 9207 section 2.4, the client compares `iss` against the issuer of the - authorization server the request was sent to, using simple string comparison - (RFC 3986 section 6.2.1, i.e. without URL normalization), and rejects on mismatch. - A response that omits `iss` is rejected only when the server advertised support via - `authorization_response_iss_parameter_supported`. + Per RFC 9207 section 2.4, `iss` is compared to the issuer of the authorization server the + request was sent to by simple string comparison (RFC 3986 section 6.2.1); a missing `iss` is + rejected only when the server advertised `authorization_response_iss_parameter_supported`. Raises: - OAuthFlowError: If `iss` is present and does not match, or is absent when the - authorization server advertised support. + OAuthFlowError: On mismatch, or when `iss` is absent but the server advertised support. """ expected = str(oauth_metadata.issuer) if oauth_metadata else None @@ -260,8 +212,8 @@ def validate_authorization_response_iss(iss: str | None, oauth_metadata: OAuthMe def validate_metadata_issuer(oauth_metadata: OAuthMetadata, expected_issuer: str) -> None: """Validate that authorization server metadata `issuer` matches the discovery issuer. - Per RFC 8414 section 3.3 / SEP-2468, the `issuer` in the metadata must match the issuer - used to construct the well-known URL, compared as a simple string (RFC 3986 section 6.2.1). + RFC 8414 section 3.3 / SEP-2468: compared as a simple string (RFC 3986 section 6.2.1) against + the issuer used to construct the well-known URL. Raises: OAuthFlowError: If the metadata issuer does not match `expected_issuer`. @@ -279,8 +231,6 @@ def create_oauth_metadata_request(url: str) -> Request: def create_client_registration_request( auth_server_metadata: OAuthMetadata | None, client_metadata: OAuthClientMetadata, auth_base_url: str ) -> Request: - """Build a client registration request.""" - if auth_server_metadata and auth_server_metadata.registration_endpoint: registration_url = str(auth_server_metadata.registration_endpoint) else: @@ -292,7 +242,6 @@ def create_client_registration_request( async def handle_registration_response(response: Response) -> OAuthClientInformationFull: - """Handle registration response.""" if response.status_code not in (200, 201): await response.aread() raise OAuthRegistrationError(f"Registration failed: {response.status_code} {response.text}") @@ -306,16 +255,7 @@ async def handle_registration_response(response: Response) -> OAuthClientInforma def is_valid_client_metadata_url(url: str | None) -> bool: - """Validate that a URL is suitable for use as a client_id (CIMD). - - The URL must be HTTPS with a non-root pathname. - - Args: - url: The URL to validate - - Returns: - True if the URL is a valid HTTPS URL with a non-root pathname - """ + """Whether `url` is usable as a URL-based client ID (CIMD): HTTPS with a non-root path.""" if not url: return False try: @@ -330,13 +270,11 @@ def credentials_match_issuer( ) -> bool: """Whether stored client credentials may be reused against `issuer` (SEP-2352). - A URL-based client ID (CIMD) is portable across authorization servers β€” the same self-hosted - document is resolved by whichever server is in use β€” so it always matches; CIMD is identified - by the client ID being the configured `client_metadata_url`, not by URL shape (a registration - server may also issue URL-shaped IDs that are bound to it). Credentials with a recorded issuer - match only when it equals `issuer` (simple string comparison). Credentials with no recorded - issuer (pre-registered, or stored before issuer binding existed) carry no binding to enforce - and are left as-is. + A CIMD client ID is portable across authorization servers, so it always matches; CIMD is + identified by the client ID equalling the configured `client_metadata_url`, not by URL shape + (registration servers may also issue URL-shaped IDs bound to them). A recorded issuer must + equal `issuer` (simple string comparison); credentials with no recorded issuer (pre-registered, + or stored before issuer binding existed) carry no binding to enforce. """ if client_metadata_url is not None and client_info.client_id == client_metadata_url: return True @@ -349,19 +287,7 @@ def should_use_client_metadata_url( oauth_metadata: OAuthMetadata | None, client_metadata_url: str | None, ) -> bool: - """Determine if URL-based client ID (CIMD) should be used instead of DCR. - - URL-based client IDs should be used when: - 1. The server advertises client_id_metadata_document_supported=True - 2. The client has a valid client_metadata_url configured - - Args: - oauth_metadata: OAuth authorization server metadata - client_metadata_url: URL-based client ID (already validated) - - Returns: - True if CIMD should be used, False if DCR should be used - """ + """Whether to use a URL-based client ID (CIMD) instead of dynamic client registration.""" if not client_metadata_url: return False @@ -376,16 +302,8 @@ def create_client_info_from_metadata_url( ) -> OAuthClientInformationFull: """Create client information using a URL-based client ID (CIMD). - When using URL-based client IDs, the URL itself becomes the client_id - and no client_secret is used (token_endpoint_auth_method="none"). - - Args: - client_metadata_url: The URL to use as the client_id - redirect_uris: The redirect URIs from the client metadata (passed through for - compatibility with OAuthClientInformationFull which inherits from OAuthClientMetadata) - - Returns: - OAuthClientInformationFull with the URL as client_id + The URL itself becomes the client_id and no client_secret is used + (`token_endpoint_auth_method="none"`). """ return OAuthClientInformationFull( client_id=client_metadata_url, @@ -397,18 +315,10 @@ def create_client_info_from_metadata_url( async def handle_token_response_scopes( response: Response, ) -> OAuthToken: - """Parse and validate a token response. - - Parses token response JSON. Callers should check response.status_code before calling. - - Args: - response: HTTP response from token endpoint (status already checked by caller) - - Returns: - Validated OAuthToken model + """Parse and validate a token response; callers must check `response.status_code` first. Raises: - OAuthTokenError: If response JSON is invalid + OAuthTokenError: If the response JSON is invalid. """ try: content = await response.aread() diff --git a/src/mcp/client/client.py b/src/mcp/client/client.py index d3290f3080..d59a6f454e 100644 --- a/src/mcp/client/client.py +++ b/src/mcp/client/client.py @@ -58,22 +58,20 @@ from mcp.shared.jsonrpc_dispatcher import JSONRPCDispatcher ConnectMode = Literal["legacy", "auto"] | str -"""``mode=`` value: ``"legacy"`` (initialize handshake), ``"auto"`` (discover, fall back to -initialize), or a modern protocol-version string (adopt directly). The ``str`` arm is for -forward-compat; ``Client.__post_init__`` rejects anything outside that set at construction.""" +"""`mode=` value: `"legacy"` (initialize handshake), `"auto"` (discover, fall back to initialize), or a +modern protocol-version string (adopt directly); the `str` arm is forward-compat, validated in `__post_init__`.""" _T = TypeVar("_T") _ResultT = TypeVar("_ResultT") _Connector = Callable[[AsyncExitStack, ConnectMode, bool], Awaitable["Dispatcher[Any]"]] -"""Resolved at ``__post_init__`` from the shape of ``server`` alone: enter whatever resources -are needed onto the exit stack and hand back the ``Dispatcher`` ``ClientSession`` will drive. -``mode`` and ``raise_exceptions`` are passed at call time so they're read at the same moment -``__aenter__`` reads them for the handshake step.""" +"""Resolved at `__post_init__` from the shape of `server` alone: enters resources onto the exit +stack and returns the `Dispatcher` that `ClientSession` will drive. `mode` and `raise_exceptions` +are passed at call time so they're read at the same moment `__aenter__` reads them.""" def _connect_transport(transport: Transport) -> _Connector: - """Connector for the stream-backed paths (URL, user-supplied ``Transport``).""" + """Connector for the stream-backed paths (URL, user-supplied `Transport`).""" async def connect(exit_stack: AsyncExitStack, _mode: ConnectMode, _raise_exceptions: bool) -> Dispatcher[Any]: read_stream, write_stream = await exit_stack.enter_async_context(transport) @@ -83,9 +81,8 @@ async def connect(exit_stack: AsyncExitStack, _mode: ConnectMode, _raise_excepti def _connect_inproc(server: Server[Any]) -> _Connector: - """Connector for an in-process ``Server``: legacy mode drives the stream loop via - ``InMemoryTransport``; any other mode drives the modern per-request path through a - ``DirectDispatcher`` peer pair (no streams, no JSON-RPC framing, no initialize handshake).""" + """Connector for an in-process `Server`: legacy mode drives the stream loop via `InMemoryTransport`; + any other mode uses a `DirectDispatcher` peer pair (no streams, no framing, no initialize handshake).""" async def connect(exit_stack: AsyncExitStack, mode: ConnectMode, raise_exceptions: bool) -> Dispatcher[Any]: if mode == "legacy": @@ -104,11 +101,10 @@ async def connect(exit_stack: AsyncExitStack, mode: ConnectMode, raise_exception def _connected(value: _T | None) -> _T: - """Narrow a post-handshake session attribute from ``T | None`` to ``T``. + """Narrow a post-handshake session attribute from `T | None` to `T`. - ``Client.__aenter__`` only assigns ``_session`` after the handshake succeeds, so inside - ``async with Client(...)`` these attributes are always populated; the ``.session`` gate - raises before this is reached otherwise. The guard exists for pyright, not runtime. + `Client.__aenter__` assigns `_session` only after the handshake succeeds, so inside + `async with` these attributes are always populated; the guard exists for pyright, not runtime. """ if value is None: # pragma: no cover raise RuntimeError("Client must be used within an async context manager") @@ -127,13 +123,10 @@ def _synthesize_discover(protocol_version: str) -> types.DiscoverResult: async def _no_inbound_client_notifications(_dctx: Any, _method: str, _params: Mapping[str, Any] | None) -> None: - """Server-side inbound ``OnNotify`` for the modern in-process path β€” receives nothing. + """Server-side inbound `OnNotify` for the modern in-process path β€” receives nothing. - At 2026-07-28 the spec defines no clientβ†’server notifications: ``initialized`` and - ``roots/list_changed`` are removed, and cancellation is structural (anyio scope cancel - through the direct await, not a notify). Serverβ†’client notifications (progress, log - messages) flow the other way via the per-request ``DispatchContext`` into the client's - callbacks, and are not seen here. + At 2026-07-28 the spec defines no client-to-server notifications: `initialized` and + `roots/list_changed` are removed, and cancellation is structural (anyio scope cancel). """ @@ -164,12 +157,8 @@ async def main(): """ server: Server[Any] | MCPServer | Transport | str - """The MCP server to connect to. - - If the server is a `Server` or `MCPServer` instance, it will be connected in-process. - If the server is a URL string, it will be used as the URL for a `streamable_http_client` transport. - If the server is a `Transport` instance, it will be used directly. - """ + """The server to connect to: a `Server`/`MCPServer` runs in-process, a URL string becomes a + `streamable_http_client` transport, and a `Transport` is used directly.""" _: KW_ONLY @@ -199,15 +188,14 @@ async def main(): mode: ConnectMode = "auto" """How to negotiate the protocol version. - 'auto' (the default) probes `server/discover` and falls back to the initialize handshake on legacy servers; - for an in-process `Server`/`MCPServer` it dispatches directly without JSON-RPC framing. 'legacy' forces the - initialize handshake (byte-identical pre-2026 behavior). A modern protocol-version string (e.g. '2026-07-28') - adopts that version directly without a probe β€” supply `prior_discover` to reuse a known DiscoverResult, or - omit it to synthesize a minimal one.""" + 'auto' (default) probes `server/discover`, falling back to the initialize handshake on legacy + servers (an in-process `Server`/`MCPServer` dispatches directly, no framing). 'legacy' forces the + handshake (byte-identical pre-2026 behavior). A modern protocol-version string (e.g. '2026-07-28') + adopts that version directly without a probe.""" prior_discover: types.DiscoverResult | None = None - """A previously-obtained DiscoverResult to install via .adopt() when mode is a version pin. - Ignored when mode='legacy'.""" + """A previously-obtained DiscoverResult to install via `.adopt()` when mode is a version pin + (when omitted, a minimal one is synthesized). Ignored when mode='legacy'.""" elicitation_callback: ElicitationFnT | None = None """Callback for handling elicitation requests.""" @@ -263,7 +251,6 @@ async def _build_session(self, exit_stack: AsyncExitStack) -> ClientSession: ) async def __aenter__(self) -> Client: - """Enter the async context manager.""" if self._entered: raise RuntimeError("Client is already entered; cannot reenter") self._entered = True @@ -279,24 +266,20 @@ async def __aenter__(self) -> Client: else: session.adopt(self.prior_discover or _synthesize_discover(self.mode)) - # Only publish the session after the handshake succeeds, so `_session is not None` - # implies the protocol_version/server_info/server_capabilities are populated. If the - # handshake raised above, the local exit_stack unwinds the transport for us. + # Publish only after the handshake succeeds: `_session is not None` implies the + # negotiated attributes are populated; if it raised, the local exit_stack unwinds. self._session = session self._exit_stack = exit_stack.pop_all() return self async def __aexit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: Any) -> None: - """Exit the async context manager.""" if self._exit_stack: # pragma: no branch await self._exit_stack.__aexit__(exc_type, exc_val, exc_tb) self._session = None @property def session(self) -> ClientSession: - """Get the underlying ClientSession. - - This provides access to the full ClientSession API for advanced use cases. + """The underlying ClientSession, for advanced use cases. Raises: RuntimeError: If accessed before entering the context manager. @@ -305,23 +288,22 @@ def session(self) -> ClientSession: raise RuntimeError("Client must be used within an async context manager") return self._session - # TODO(maxisbey): the by-construction shape is for __aenter__ to return a connected-view - # type whose protocol_version/server_info/server_capabilities are non-Optional fields, - # eliminating these guards (and the one in .session). Same family as resolving the - # transport/connector at __post_init__ so the Optional internal fields disappear. + # TODO(maxisbey): the by-construction shape is for __aenter__ to return a connected-view type + # with non-Optional protocol_version/server_info/server_capabilities, eliminating these guards + # and the .session gate β€” same family as resolving the connector at __post_init__. @property def protocol_version(self) -> str: - """Negotiated protocol version (set by initialize/discover/adopt during ``__aenter__``).""" + """Negotiated protocol version (set by initialize/discover/adopt during `__aenter__`).""" return _connected(self.session.protocol_version) @property def server_info(self) -> Implementation: - """Server name/version (set by initialize/discover/adopt during ``__aenter__``).""" + """Server name/version (set by initialize/discover/adopt during `__aenter__`).""" return _connected(self.session.server_info) @property def server_capabilities(self) -> ServerCapabilities: - """Server capabilities (set by initialize/discover/adopt during ``__aenter__``).""" + """Server capabilities (set by initialize/discover/adopt during `__aenter__`).""" return _connected(self.session.server_capabilities) @property @@ -389,20 +371,10 @@ async def read_resource( ) -> ReadResourceResult: """Read a resource from the server. - If the server returns an `InputRequiredResult`, the embedded input - requests are dispatched to this client's sampling / elicitation / roots - callbacks and the read is retried automatically (up to - `input_required_max_rounds`). - - Args: - uri: The URI of the resource to read. - input_responses: Responses to seed the first call with (e.g. when - resuming from a persisted `InputRequiredResult`). - request_state: Opaque state to seed the first call with. - meta: Additional metadata for the request. - - Returns: - The resource content. + If the server returns an `InputRequiredResult`, the embedded input requests are dispatched + to this client's sampling / elicitation / roots callbacks and the read is retried + automatically (up to `input_required_max_rounds`). Pass `input_responses` / `request_state` + to seed the first call, e.g. when resuming from a persisted `InputRequiredResult`. Raises: InputRequiredRoundsExceededError: `input_required_max_rounds` exhausted. @@ -437,26 +409,14 @@ async def call_tool( ) -> CallToolResult: """Call a tool on the server. - If the server returns an `InputRequiredResult`, the embedded input - requests are dispatched to this client's sampling / elicitation / roots - callbacks and the call is retried automatically (up to - `input_required_max_rounds`). To drive the loop yourself β€” e.g. to - persist `request_state` across process restarts β€” use + If the server returns an `InputRequiredResult`, the embedded input requests are dispatched + to this client's sampling / elicitation / roots callbacks and the call is retried + automatically (up to `input_required_max_rounds`); `read_timeout_seconds` bounds each + underlying `tools/call` round. Pass `input_responses` / `request_state` to seed the first + call, e.g. when resuming from a persisted `InputRequiredResult`. To drive the loop yourself + (e.g. persisting `request_state` across process restarts), use `client.session.call_tool(..., allow_input_required=True)`. - Args: - name: The name of the tool to call. - arguments: Arguments to pass to the tool. - read_timeout_seconds: Timeout for each underlying `tools/call` round. - progress_callback: Callback for progress updates. - input_responses: Responses to seed the first call with (e.g. when - resuming from a persisted `InputRequiredResult`). - request_state: Opaque state to seed the first call with. - meta: Additional metadata for the request. - - Returns: - The tool result. - Raises: InputRequiredRoundsExceededError: `input_required_max_rounds` exhausted. MCPError: A callback returned `ErrorData` for an embedded input request. @@ -496,21 +456,10 @@ async def get_prompt( ) -> GetPromptResult: """Get a prompt from the server. - If the server returns an `InputRequiredResult`, the embedded input - requests are dispatched to this client's sampling / elicitation / roots - callbacks and the get is retried automatically (up to - `input_required_max_rounds`). - - Args: - name: The name of the prompt. - arguments: Arguments to pass to the prompt. - input_responses: Responses to seed the first call with (e.g. when - resuming from a persisted `InputRequiredResult`). - request_state: Opaque state to seed the first call with. - meta: Additional metadata for the request. - - Returns: - The prompt content. + If the server returns an `InputRequiredResult`, the embedded input requests are dispatched + to this client's sampling / elicitation / roots callbacks and the get is retried + automatically (up to `input_required_max_rounds`). Pass `input_responses` / `request_state` + to seed the first call, e.g. when resuming from a persisted `InputRequiredResult`. Raises: InputRequiredRoundsExceededError: `input_required_max_rounds` exhausted. @@ -531,9 +480,8 @@ async def _drive_input_required( ) -> _ResultT: """Hand an `InputRequiredResult` to the SEP-2322 driver, or pass a terminal result through. - `dispatch` routes each embedded request through the same callback table - that serves legacy serverβ†’client RPCs, so the two paths stay - behaviourally identical by construction. + `dispatch` routes each embedded request through the same callback table that serves legacy + server-to-client RPCs, so the two paths stay behaviourally identical by construction. """ if not isinstance(first, InputRequiredResult): return first @@ -553,16 +501,7 @@ async def complete( argument: dict[str, str], context_arguments: dict[str, str] | None = None, ) -> CompleteResult: - """Get completions for a prompt or resource template argument. - - Args: - ref: Reference to the prompt or resource template - argument: The argument to complete - context_arguments: Additional context arguments - - Returns: - Completion suggestions. - """ + """Get completions for a prompt or resource template argument.""" return await self.session.complete(ref=ref, argument=argument, context_arguments=context_arguments) async def list_tools(self, *, cursor: str | None = None, meta: RequestParamsMeta | None = None) -> ListToolsResult: diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 3cebb569ec..4644e01f99 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -203,12 +203,11 @@ def _input_required_unexpected(method: str) -> RuntimeError: class ClientSession: """Client half of an MCP connection, running on a `Dispatcher`. - Construct it over a transport's stream pair (or pass a pre-built - `dispatcher=`), enter as an async context manager, then call - `initialize()`. The dispatcher owns the receive loop and request - correlation; this class owns the typed MCP layer and the constructor - callbacks. Transport `Exception` items reach `message_handler` only when - the session builds its own dispatcher from a stream pair. + Construct it over a transport's stream pair (or pass a pre-built `dispatcher=`), + enter as an async context manager, then call `initialize()`. The dispatcher owns + the receive loop and request correlation; this class owns the typed MCP layer and + the constructor callbacks. Transport `Exception` items reach `message_handler` + only when the session builds its own dispatcher from a stream pair. """ def __init__( @@ -248,12 +247,10 @@ def __init__( raise ValueError("pass read_stream/write_stream or dispatcher, not both") self._dispatcher: Dispatcher[Any] = dispatcher if isinstance(dispatcher, JSONRPCDispatcher) and dispatcher.on_stream_exception is None: - # Route transport-level Exception items into message_handler β€” only - # stream-backed dispatchers carry these; DirectDispatcher has none. - # Don't clobber a caller-supplied hook. - # TODO(L78): this leaves a bound-method ref on the dispatcher after the - # session exits (memory pin) and a second wrap of the same dispatcher would - # skip install. The Transport-as-Dispatcher rework (L77) removes this seam. + # Route transport Exception items into message_handler; only stream-backed + # dispatchers carry these. Don't clobber a caller-supplied hook. + # TODO(L78): leaves a bound-method ref after session exit and double-wrap + # skips install; the Transport-as-Dispatcher rework (L77) removes this seam. dispatcher.on_stream_exception = self._on_stream_exception else: if read_stream is None or write_stream is None: @@ -269,9 +266,8 @@ async def __aenter__(self) -> Self: try: await self._task_group.start(self._dispatcher.run, self._on_request, self._on_notify) except BaseException: - # Unwind the entered task group before propagating: a cancellation - # landing here (e.g. `move_on_after` around connect) would abandon - # it and anyio would later raise "exited non-innermost cancel scope". + # Unwind the entered task group before propagating: abandoning it (e.g. a + # `move_on_after` cancel landing here) makes anyio raise "exited non-innermost cancel scope". task_group = self._task_group self._task_group = None task_group.cancel_scope.cancel() @@ -344,8 +340,7 @@ async def send_request( async def send_notification(self, notification: types.ClientNotification) -> None: """Send a one-way notification. Usable before entering the context manager. - Fire-and-forget: after the connection has closed, the notification is - dropped with a debug log instead of raising. + Fire-and-forget: after the connection closes, it is dropped with a debug log instead of raising. """ data = notification.model_dump(by_alias=True, mode="json", exclude_none=True) opts: CallOptions = {} @@ -364,9 +359,7 @@ def _build_capabilities(self) -> types.ClientCapabilities: else None ) roots = ( - # TODO: Should this be based on whether we - # _will_ send notifications, or only whether - # they're supported? + # TODO: base this on whether we _will_ send notifications, or only whether they're supported? types.RootsCapability(list_changed=True) if self._list_roots_callback is not _default_list_roots_callback else None @@ -401,8 +394,7 @@ async def initialize(self) -> types.InitializeResult: def adopt(self, result: types.InitializeResult | types.DiscoverResult) -> None: """Install negotiated state from a result the caller already holds (no wire traffic). - Clears the opposite slot, so at most one of `initialize_result` / - `discover_result` is ever non-None. + Clears the opposite slot, so at most one of `initialize_result` / `discover_result` is ever non-None. Raises: RuntimeError: `result` is a `DiscoverResult` whose `supported_versions` @@ -429,17 +421,15 @@ def adopt(self, result: types.InitializeResult | types.DiscoverResult) -> None: self._negotiated_version = result.protocol_version async def send_discover(self, version: str) -> dict[str, Any]: - """Send a single ``server/discover`` at ``version`` and return the raw result dict. + """Send a single `server/discover` at `version` and return the raw result dict. - No retry, no ``adopt()``. The ``_meta`` envelope and the - ``Mcp-Protocol-Version`` header are stamped at ``version`` so the - server-side era router sees a coherent probe. Used by ``discover()`` and - the connect-time auto-negotiation policy. + No retry, no `adopt()`; used by `discover()` and connect-time auto-negotiation. + The `_meta` envelope and `Mcp-Protocol-Version` header are stamped at `version` + so the server-side era router sees a coherent probe. Raises: - MCPError: The server returned a JSON-RPC error, or the transport - bounced the request at its own layer (a bare HTTP 4xx is - synthesized into a JSON-RPC error by the transport). + MCPError: JSON-RPC error response, or a transport-layer bounce (a bare + HTTP 4xx is synthesized into a JSON-RPC error by the transport). """ client_info = self._client_info.model_dump(by_alias=True, mode="json", exclude_none=True) capabilities = self._build_capabilities().model_dump(by_alias=True, mode="json", exclude_none=True) @@ -463,18 +453,15 @@ async def send_discover(self, version: str) -> dict[str, Any]: async def discover(self) -> types.DiscoverResult: """Probe `server/discover` and adopt the result. - Sends a single `server/discover` proposing the newest modern protocol - version. On `UNSUPPORTED_PROTOCOL_VERSION` (-32022) the server's - `supported` list is intersected with `MODERN_PROTOCOL_VERSIONS` and the - probe is retried once at the highest mutual version. Any other error β€” - including `METHOD_NOT_FOUND` (-32601) and `REQUEST_TIMEOUT` (-32001) β€” - propagates; the legacy `initialize()` fallback is the caller's policy. + Proposes the newest modern protocol version; on `UNSUPPORTED_PROTOCOL_VERSION` + (-32022) retries once at the highest mutual version. Any other error β€” + including `METHOD_NOT_FOUND` and `REQUEST_TIMEOUT` β€” propagates; the legacy + `initialize()` fallback is the caller's policy. Raises: - MCPError: The server rejected `server/discover`, the probe timed - out, or the -32022 retry found no mutual version / failed again. - RuntimeError: `adopt()` found no mutual version in the returned - `supported_versions`. + MCPError: The probe was rejected or timed out, or the -32022 retry + found no mutual version / failed again. + RuntimeError: `adopt()` found no mutual version in `supported_versions`. """ if self._discover_result is not None: return self._discover_result @@ -507,8 +494,7 @@ def initialize_result(self) -> types.InitializeResult | None: def discover_result(self) -> types.DiscoverResult | None: """The server's DiscoverResult. None unless `discover()` ran (or was adopted). - Retained intact (supported_versions, ttl_ms, cache_scope) so callers - can round-trip it as ``prior_discover=``. + Retained intact (supported_versions, ttl_ms, cache_scope) so callers can round-trip it as `prior_discover=`. """ return self._discover_result @@ -588,21 +574,13 @@ async def set_logging_level( ) async def list_resources(self, *, params: types.PaginatedRequestParams | None = None) -> types.ListResourcesResult: - """Send a resources/list request. - - Args: - params: Full pagination parameters including cursor and any future fields - """ + """Send a resources/list request.""" return await self.send_request(types.ListResourcesRequest(params=params), types.ListResourcesResult) async def list_resource_templates( self, *, params: types.PaginatedRequestParams | None = None ) -> types.ListResourceTemplatesResult: - """Send a resources/templates/list request. - - Args: - params: Full pagination parameters including cursor and any future fields - """ + """Send a resources/templates/list request.""" return await self.send_request( types.ListResourceTemplatesRequest(params=params), types.ListResourceTemplatesResult, @@ -644,13 +622,11 @@ async def read_resource( Args: input_responses: Responses to a prior `InputRequiredResult.input_requests`. request_state: Opaque state echoed from a prior `InputRequiredResult`. - allow_input_required: When `False` (default), an `InputRequiredResult` - from the server raises `RuntimeError`; when `True`, it is returned + allow_input_required: Return an `InputRequiredResult` instead of raising, so the caller can resolve the requests and retry. Raises: - RuntimeError: If the server returns an `InputRequiredResult` and - `allow_input_required` is `False`. + RuntimeError: The server returned an `InputRequiredResult` with `allow_input_required` False. """ result = await self.send_request( types.ReadResourceRequest( @@ -723,21 +699,18 @@ async def call_tool( ) -> types.CallToolResult | types.InputRequiredResult: """Send a tools/call request with optional progress callback support. - On a modern (2026-07-28) connection, arguments annotated with `x-mcp-header` - in the tool's input schema are mirrored into `Mcp-Param-*` request headers. - The annotations are read from the tool's last `list_tools` entry, so list - the tool before calling it to enable header emission. + On a modern (2026-07-28) connection, arguments annotated with `x-mcp-header` in + the tool's input schema are mirrored into `Mcp-Param-*` request headers. The + annotations come from the tool's last `list_tools` entry, so list before calling. Args: input_responses: Responses to a prior `InputRequiredResult.input_requests`. request_state: Opaque state echoed from a prior `InputRequiredResult`. - allow_input_required: When ``False`` (default), an `InputRequiredResult` - from the server raises `RuntimeError`; when ``True``, it is returned + allow_input_required: Return an `InputRequiredResult` instead of raising, so the caller can resolve the requests and retry. Raises: - RuntimeError: If the server returns an `InputRequiredResult` and - ``allow_input_required`` is ``False``. + RuntimeError: The server returned an `InputRequiredResult` with `allow_input_required` False. """ result = await self.send_request( types.CallToolRequest( @@ -793,11 +766,7 @@ async def _validate_tool_result(self, name: str, result: types.CallToolResult) - raise RuntimeError(f"Invalid schema for tool {name}: {e}") # pragma: no cover async def list_prompts(self, *, params: types.PaginatedRequestParams | None = None) -> types.ListPromptsResult: - """Send a prompts/list request. - - Args: - params: Full pagination parameters including cursor and any future fields - """ + """Send a prompts/list request.""" return await self.send_request(types.ListPromptsRequest(params=params), types.ListPromptsResult) @overload @@ -839,13 +808,11 @@ async def get_prompt( Args: input_responses: Responses to a prior `InputRequiredResult.input_requests`. request_state: Opaque state echoed from a prior `InputRequiredResult`. - allow_input_required: When `False` (default), an `InputRequiredResult` - from the server raises `RuntimeError`; when `True`, it is returned + allow_input_required: Return an `InputRequiredResult` instead of raising, so the caller can resolve the requests and retry. Raises: - RuntimeError: If the server returns an `InputRequiredResult` and - `allow_input_required` is `False`. + RuntimeError: The server returned an `InputRequiredResult` with `allow_input_required` False. """ result = await self.send_request( types.GetPromptRequest( @@ -886,11 +853,7 @@ async def complete( ) async def list_tools(self, *, params: types.PaginatedRequestParams | None = None) -> types.ListToolsResult: - """Send a tools/list request. - - Args: - params: Full pagination parameters including cursor and any future fields - """ + """Send a tools/list request.""" result = await self.send_request( types.ListToolsRequest(params=params), types.ListToolsResult, @@ -902,17 +865,14 @@ async def list_tools(self, *, params: types.PaginatedRequestParams | None = None for tool in result.tools: if (reason := find_invalid_x_mcp_header(tool.input_schema)) is not None: logger.warning("dropping tool %r: invalid x-mcp-header (%s)", tool.name, reason) - # Evict any map cached from a prior valid listing so a stale entry can't - # mirror headers for a tool this listing dropped. + # Evict any map cached from a prior valid listing so it can't mirror headers for a dropped tool. self._x_mcp_header_maps.pop(tool.name, None) continue - # Cache the argβ†’header map so a later tools/call mirrors it into Mcp-Param-* headers. self._x_mcp_header_maps[tool.name] = x_mcp_header_map(tool.input_schema) kept.append(tool) result.tools = kept - # Cache tool output schemas for future validation - # Note: don't clear the cache, as we may be using a cursor + # Don't clear the output-schema cache: this may be one cursor page of a longer listing. for tool in result.tools: self._tool_output_schemas[tool.name] = tool.output_schema @@ -927,9 +887,8 @@ async def _on_request( self, dctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None ) -> dict[str, Any]: """Answer a server-initiated request via the registered callbacks.""" - # Literal, not LATEST_PROTOCOL_VERSION: the fallback covers the initialize - # handshake (which only exists at <=2025) and stateless until the header - # is plumbed; its meaning is fixed regardless of LATEST bumps. + # Literal, not LATEST_PROTOCOL_VERSION: the fallback covers the initialize handshake + # (<=2025 only) and stateless until the header is plumbed; fixed regardless of LATEST bumps. version = self._negotiated_version or "2025-11-25" try: request = cast(types.ServerRequest, _methods.parse_server_request(method, version, params)) @@ -962,9 +921,8 @@ async def _dispatch_input_request( ) -> types.InputResponse | types.ErrorData: """Route a server-initiated input request to the matching constructor callback. - Shared by the legacy serverβ†’client RPC path (`_on_request`) and the - 2026-07-28 multi-round-trip driver, which dispatches the embedded - `InputRequiredResult.input_requests` through the same callbacks. + Shared by the legacy serverβ†’client RPC path (`_on_request`) and the 2026-07-28 + multi-round-trip driver, which feeds `InputRequiredResult.input_requests` here too. """ match req: case types.CreateMessageRequest(params=p): @@ -996,17 +954,15 @@ async def _on_notify( await self._logging_callback(notification.params) await self._message_handler(notification) except Exception: - # Contain here, not in the dispatcher: DirectDispatcher awaits this - # handler inline in the peer's notify() call, so a raising callback - # would otherwise fail the peer's send. A raising logging_callback - # skips the message_handler tee for that notification (v1 parity). + # Contain here, not in the dispatcher: DirectDispatcher awaits this handler + # inline in the peer's notify(), so raising would fail the peer's send. A raising + # logging_callback skips the message_handler tee for that notification (v1 parity). logger.exception("notification callback for %r raised", method) async def _on_stream_exception(self, exc: Exception) -> None: """Deliver a transport-level fault to message_handler via a spawned task. - Running the handler inline would park the dispatcher's read loop and - deadlock handlers that await session I/O. + Running it inline would park the dispatcher's read loop and deadlock handlers that await session I/O. """ assert self._task_group is not None self._task_group.start_soon(self._deliver_stream_exception, exc) diff --git a/src/mcp/client/session_group.py b/src/mcp/client/session_group.py index 40f0232594..1941d91dc6 100644 --- a/src/mcp/client/session_group.py +++ b/src/mcp/client/session_group.py @@ -1,10 +1,4 @@ -"""SessionGroup concurrently manages multiple MCP session connections. - -Tools, resources, and prompts are aggregated across servers. Servers may -be connected to or disconnected from at any point after initialization. - -This abstraction can handle naming collisions using a custom user-provided hook. -""" +"""Manage concurrent sessions to multiple MCP servers, aggregating their tools, resources, and prompts.""" import contextlib import logging @@ -32,43 +26,29 @@ class SseServerParameters(BaseModel): """Parameters for initializing an sse_client.""" - # The endpoint URL. url: str - - # Optional headers to include in requests. headers: dict[str, Any] | None = None - - # HTTP timeout for regular operations (in seconds). + # Timeouts in seconds: `timeout` for regular HTTP operations, `sse_read_timeout` for SSE reads. timeout: float = 5.0 - - # Timeout for SSE read operations (in seconds). sse_read_timeout: float = 300.0 class StreamableHttpParameters(BaseModel): """Parameters for initializing a streamable_http_client.""" - # The endpoint URL. url: str - - # Optional headers to include in requests. headers: dict[str, Any] | None = None - - # HTTP timeout for regular operations (in seconds). + # Timeouts in seconds: `timeout` for regular HTTP operations, `sse_read_timeout` for SSE reads. timeout: float = 30.0 - - # Timeout for SSE read operations (in seconds). sse_read_timeout: float = 300.0 - - # Close the client session when the transport closes. + # Terminate the server session when the transport closes. terminate_on_close: bool = True ServerParameters: TypeAlias = StdioServerParameters | SseServerParameters | StreamableHttpParameters -# Use dataclass instead of Pydantic BaseModel -# because Pydantic BaseModel cannot handle Protocol fields. +# Dataclass rather than pydantic BaseModel: pydantic cannot handle Protocol-typed fields. @dataclass class ClientSessionParameters: """Parameters for establishing a client session to an MCP server.""" @@ -83,13 +63,10 @@ class ClientSessionParameters: class ClientSessionGroup: - """Client for managing connections to multiple MCP servers. + """Manages connections to multiple MCP servers, aggregating their tools, resources, and prompts. - This class is responsible for encapsulating management of server connections. - It aggregates tools, resources, and prompts from all connected servers. - - For auxiliary handlers, such as resource subscription, this is delegated to - the client and can be accessed via the session. + Auxiliary operations such as resource subscription are performed through the + individual sessions. Example: ```python @@ -102,26 +79,22 @@ class ClientSessionGroup: """ class _ComponentNames(BaseModel): - """Used for reverse index to find components.""" + """Names of the components owned by a single session.""" prompts: set[str] = Field(default_factory=set) resources: set[str] = Field(default_factory=set) tools: set[str] = Field(default_factory=set) - # Standard MCP components. _prompts: dict[str, types.Prompt] _resources: dict[str, types.Resource] _tools: dict[str, types.Tool] - # Client-server connection management. _sessions: dict[mcp.ClientSession, _ComponentNames] _tool_to_session: dict[str, mcp.ClientSession] _exit_stack: contextlib.AsyncExitStack _session_exit_stacks: dict[mcp.ClientSession, contextlib.AsyncExitStack] - # Optional fn consuming (component_name, server_info) for custom names. - # This is to provide a means to mitigate naming conflicts across servers. - # Example: (tool_name, server_info) => "{result.server_info.name}.{tool_name}" + # Optional hook mapping (component_name, server_info) to a custom name, to avoid collisions across servers. _ComponentNameHook: TypeAlias = Callable[[str, types.Implementation], str] _component_name_hook: _ComponentNameHook | None @@ -130,8 +103,6 @@ def __init__( exit_stack: contextlib.AsyncExitStack | None = None, component_name_hook: _ComponentNameHook | None = None, ) -> None: - """Initializes the MCP client.""" - self._tools = {} self._resources = {} self._prompts = {} @@ -148,7 +119,6 @@ def __init__( self._component_name_hook = component_name_hook async def __aenter__(self) -> Self: # pragma: no cover - # Enter the exit stack only if we created it ourselves if self._owns_exit_stack: await self._exit_stack.__aenter__() return self @@ -159,35 +129,31 @@ async def __aexit__( _exc_val: BaseException | None, _exc_tb: TracebackType | None, ) -> bool | None: # pragma: no cover - """Closes session exit stacks and main exit stack upon completion.""" - - # Only close the main exit stack if we created it if self._owns_exit_stack: await self._exit_stack.aclose() - # Concurrently close session stacks. async with anyio.create_task_group() as tg: for exit_stack in self._session_exit_stacks.values(): tg.start_soon(exit_stack.aclose) @property def sessions(self) -> list[mcp.ClientSession]: - """Returns the list of sessions being managed.""" + """The list of managed sessions.""" return list(self._sessions.keys()) # pragma: no cover @property def prompts(self) -> dict[str, types.Prompt]: - """Returns the prompts as a dictionary of names to prompts.""" + """Prompts aggregated from all servers, keyed by name.""" return self._prompts @property def resources(self) -> dict[str, types.Resource]: - """Returns the resources as a dictionary of names to resources.""" + """Resources aggregated from all servers, keyed by name.""" return self._resources @property def tools(self) -> dict[str, types.Tool]: - """Returns the tools as a dictionary of names to tools.""" + """Tools aggregated from all servers, keyed by name.""" return self._tools @overload @@ -233,8 +199,7 @@ async def call_tool( """Executes a tool given its name and arguments. Raises: - RuntimeError: If the server returns an `InputRequiredResult` and - ``allow_input_required`` is ``False``. + RuntimeError: If the server returns an `InputRequiredResult` and `allow_input_required` is `False`. """ session = self._tool_to_session[name] session_tool_name = self.tools[name].name @@ -262,24 +227,20 @@ async def disconnect_from_server(self, session: mcp.ClientSession) -> None: ) if session_known_for_components: # pragma: no branch - component_names = self._sessions.pop(session) # Pop from _sessions tracking + component_names = self._sessions.pop(session) - # Remove prompts associated with the session. for name in component_names.prompts: if name in self._prompts: # pragma: no branch del self._prompts[name] - # Remove resources associated with the session. for name in component_names.resources: if name in self._resources: # pragma: no branch del self._resources[name] - # Remove tools associated with the session. for name in component_names.tools: if name in self._tools: # pragma: no branch del self._tools[name] if name in self._tool_to_session: # pragma: no branch del self._tool_to_session[name] - # Clean up the session's resources via its dedicated exit stack if session_known_for_stack: session_stack_to_close = self._session_exit_stacks.pop(session) # pragma: no cover await session_stack_to_close.aclose() # pragma: no cover @@ -287,7 +248,7 @@ async def disconnect_from_server(self, session: mcp.ClientSession) -> None: async def connect_with_session( self, server_info: types.Implementation, session: mcp.ClientSession ) -> mcp.ClientSession: - """Connects to a single MCP server.""" + """Adds an already-established session to the group and aggregates its components.""" await self._aggregate_components(server_info, session) return session @@ -296,7 +257,7 @@ async def connect_to_server( server_params: ServerParameters, session_params: ClientSessionParameters | None = None, ) -> mcp.ClientSession: - """Connects to a single MCP server.""" + """Connects to a single MCP server and aggregates its components.""" server_info, session = await self._establish_session(server_params, session_params or ClientSessionParameters()) return await self.connect_with_session(server_info, session) @@ -309,7 +270,6 @@ async def _establish_session( session_stack = contextlib.AsyncExitStack() try: - # Create read and write streams that facilitate io with the server. if isinstance(server_params, StdioServerParameters): client = mcp.stdio_client(server_params) read, write = await session_stack.enter_async_context(client) @@ -354,36 +314,27 @@ async def _establish_session( result = await session.initialize() - # Session successfully initialized. - # Store its stack and register the stack with the main group stack. + # The session stack itself becomes a resource managed by the group's exit stack. self._session_exit_stacks[session] = session_stack - # session_stack itself becomes a resource managed by the - # main _exit_stack. await self._exit_stack.enter_async_context(session_stack) return result.server_info, session except Exception: # pragma: no cover - # If anything during this setup fails, ensure the session-specific - # stack is closed. await session_stack.aclose() raise async def _aggregate_components(self, server_info: types.Implementation, session: mcp.ClientSession) -> None: """Aggregates prompts, resources, and tools from a given session.""" - # Create a reverse index so we can find all prompts, resources, and - # tools belonging to this session. Used for removing components from - # the session group via self.disconnect_from_server. + # Reverse index used by disconnect_from_server to remove this session's components. component_names = self._ComponentNames() - # Temporary components dicts. We do not want to modify the aggregate - # lists in case of an intermediate failure. + # Stage into temporary dicts so an intermediate failure leaves the group state untouched. prompts_temp: dict[str, types.Prompt] = {} resources_temp: dict[str, types.Resource] = {} tools_temp: dict[str, types.Tool] = {} tool_to_session_temp: dict[str, mcp.ClientSession] = {} - # Query the server for its prompts and aggregate to list. try: prompts = (await session.list_prompts()).prompts for prompt in prompts: @@ -393,7 +344,6 @@ async def _aggregate_components(self, server_info: types.Implementation, session except MCPError as err: # pragma: no cover logging.warning(f"Could not fetch prompts: {err}") - # Query the server for its resources and aggregate to list. try: resources = (await session.list_resources()).resources for resource in resources: @@ -403,7 +353,6 @@ async def _aggregate_components(self, server_info: types.Implementation, session except MCPError as err: # pragma: no cover logging.warning(f"Could not fetch resources: {err}") - # Query the server for its tools and aggregate to list. try: tools = (await session.list_tools()).tools for tool in tools: @@ -414,12 +363,9 @@ async def _aggregate_components(self, server_info: types.Implementation, session except MCPError as err: # pragma: no cover logging.warning(f"Could not fetch tools: {err}") - # Clean up exit stack for session if we couldn't retrieve anything - # from the server. if not any((prompts_temp, resources_temp, tools_temp)): del self._session_exit_stacks[session] # pragma: no cover - # Check for duplicates. matching_prompts = prompts_temp.keys() & self._prompts.keys() if matching_prompts: raise MCPError( # pragma: no cover @@ -436,7 +382,6 @@ async def _aggregate_components(self, server_info: types.Implementation, session if matching_tools: raise MCPError(code=types.INVALID_PARAMS, message=f"{matching_tools} already exist in group tools.") - # Aggregate components. self._sessions[session] = component_names self._prompts.update(prompts_temp) self._resources.update(resources_temp) diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index 8b482932aa..04e4d21ec1 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -39,16 +39,10 @@ async def sse_client( ): """Client transport for SSE. - `sse_read_timeout` determines how long (in seconds) the client will wait for a new - event before disconnecting. All other HTTP operations are controlled by `timeout`. + `sse_read_timeout` is how long (in seconds) to wait for a new SSE event before + disconnecting; all other HTTP operations are governed by `timeout`. Args: - url: The SSE endpoint URL. - headers: Optional headers to include in requests. - timeout: HTTP timeout for regular operations (in seconds). - sse_read_timeout: Timeout for SSE read operations (in seconds). - httpx_client_factory: Factory function for creating the HTTPX client. - auth: Optional HTTPX authentication handler. on_session_created: Optional callback invoked with the session ID when received. """ logger.debug(f"Connecting to SSE endpoint: {remove_request_params(url)}") @@ -142,9 +136,8 @@ async def _send_message(session_message: SessionMessage) -> None: except Exception: # pragma: lax no cover logger.exception("Error in post_writer") - # On Python 3.14, coverage.py reports a phantom branch arc on this - # line (->yield) when nested two async-with levels deep. The branch - # is the unreachable "did __aexit__ suppress?" arm for memory streams. + # coverage.py on 3.14 reports a phantom ->yield branch arc here (the unreachable + # "did __aexit__ suppress?" arm) when nested two async-with levels deep. async with ( # pragma: no branch read_stream_writer, read_stream, diff --git a/src/mcp/client/stdio.py b/src/mcp/client/stdio.py index 3e03eef9ef..3b8693d299 100644 --- a/src/mcp/client/stdio.py +++ b/src/mcp/client/stdio.py @@ -1,11 +1,7 @@ -"""stdio client transport. - -Runs an MCP server as a subprocess and exchanges newline-delimited JSON-RPC -messages with it over stdin/stdout. Two pipe tasks bridge the server's pipes -to the session's in-memory streams; shutdown follows the MCP spec sequence -(close stdin, wait, then kill the process tree) inside a cancellation shield -with every wait bounded, so a cancelled caller can neither leak a live server -process nor hang on one. +"""stdio client transport: runs an MCP server as a subprocess, speaking newline-delimited JSON-RPC. + +Shutdown (close stdin, wait, then kill the process tree) runs inside a cancellation shield with +every wait bounded, so a cancelled caller can neither leak a live server process nor hang on one. """ import logging @@ -36,7 +32,6 @@ logger = logging.getLogger(__name__) -# Environment variables to inherit by default DEFAULT_INHERITED_ENV_VARS = ( [ "APPDATA", @@ -130,8 +125,7 @@ async def stdio_client( cwd=server.cwd, ) - # The spawn succeeded; no awaits until the task group is entered, or a - # cancellation delivered in the gap would leak the live process. + # No awaits between spawn and task-group entry: a cancellation in the gap would leak the live process. read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](0) write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0) @@ -182,11 +176,9 @@ async def stdin_writer() -> None: writer_done.set() async def shutdown() -> None: - """Winds the transport down: stop traffic, flush, stop the server, release the streams.""" # Unblock the reader into its drain: a server stuck writing stdout cannot # read its stdin, so draining is what lets the flush below complete. read_stream.close() - # Bounded window for the writer to flush already-accepted messages. write_stream.close() with anyio.move_on_after(_WRITER_FLUSH_TIMEOUT) as flush_scope: await writer_done.wait() @@ -226,11 +218,10 @@ def _parse_line(line: str) -> SessionMessage | Exception: async def _drain_stdout(process: ServerProcess) -> None: - """Consumes and discards the server's remaining stdout. + """Discards the server's remaining stdout. - Keeps a server flushing buffered output from blocking on a full pipe and - missing its chance to exit; shielded, raw bytes, ends when shutdown closes - the pipe. + A server flushing buffered output would otherwise block on a full pipe and miss its + chance to exit; runs shielded until shutdown closes the pipe. """ assert process.stdout with anyio.CancelScope(shield=True): @@ -268,7 +259,6 @@ async def _stop_server_process(process: ServerProcess) -> None: async def _close_pipe(stream: AsyncResource) -> None: - """Closes a pipe stream, tolerating one already closed, broken, or contended.""" with suppress(OSError, anyio.BrokenResourceError, anyio.ClosedResourceError): await stream.aclose() @@ -276,8 +266,8 @@ async def _close_pipe(stream: AsyncResource) -> None: async def _wait_for_process_exit(process: ServerProcess, timeout: float) -> bool: """Returns whether the process died within the timeout, by polling returncode. - Not process.wait(): on asyncio 3.11+ it also waits for pipe EOF, and a - child that inherited the pipes makes an exited server look hung. + Not process.wait(): on asyncio 3.11+ it also waits for pipe EOF, and a child that + inherited the pipes makes an exited server look hung. """ deadline = anyio.current_time() + timeout while process.returncode is None: @@ -288,11 +278,7 @@ async def _wait_for_process_exit(process: ServerProcess, timeout: float) -> bool async def _terminate_process_tree(process: ServerProcess) -> None: - """Kills the process and all its descendants. - - POSIX: SIGTERM to the process group, SIGKILL after FORCE_KILL_TIMEOUT. - Windows: immediate Job Object termination (already a hard kill). - """ + """Kills the process and all its descendants.""" if sys.platform == "win32": # pragma: no cover await terminate_windows_process_tree(process) else: # pragma: lax no cover @@ -304,9 +290,9 @@ async def _terminate_process_tree(process: ServerProcess) -> None: def _close_subprocess_transport(process: ServerProcess) -> None: """Closes the asyncio subprocess transport, if there is one. - The transport otherwise stays open (and warns at GC) while a surviving - descendant holds a pipe end; nothing public exposes it, hence the attribute - walk. No-op on trio and the Windows fallback. + The transport otherwise stays open (and warns at GC) while a surviving descendant + holds a pipe end; nothing public exposes it, hence the attribute walk. No-op on trio + and the Windows fallback. """ transport = getattr(getattr(process, "_process", None), "_transport", None) # Duck-typed: uvloop's UVProcessTransport is not an asyncio.SubprocessTransport. @@ -318,7 +304,6 @@ def _close_subprocess_transport(process: ServerProcess) -> None: def _get_executable_command(command: str) -> str: - """Normalizes the command for the current platform.""" if sys.platform == "win32": # pragma: no cover return get_windows_executable_command(command) else: # pragma: lax no cover @@ -332,10 +317,7 @@ async def _create_platform_compatible_process( errlog: TextIO = sys.stderr, cwd: Path | str | None = None, ) -> ServerProcess: - """Spawns the server in its own kill scope. - - A new session/process group on POSIX, a Job Object on Windows. - """ + """Spawns the server in its own kill scope: a new session on POSIX, a Job Object on Windows.""" if sys.platform == "win32": # pragma: no cover return await create_windows_process(command, args, env, errlog, cwd) else: # pragma: lax no cover @@ -349,6 +331,5 @@ async def _create_platform_compatible_process( async def _aclose_all(*streams: AsyncResource) -> None: - """Closes every given stream.""" for stream in streams: await stream.aclose() diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index f28eb7c7ab..93ff213edf 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -46,9 +46,8 @@ MCP_SESSION_ID = "mcp-session-id" LAST_EVENT_ID = "last-event-id" -# Reconnection defaults -DEFAULT_RECONNECTION_DELAY_MS = 1000 # 1 second fallback when server doesn't provide retry -MAX_RECONNECTION_ATTEMPTS = 2 # Max retry attempts before giving up +DEFAULT_RECONNECTION_DELAY_MS = 1000 # fallback when the server's SSE retry field is absent +MAX_RECONNECTION_ATTEMPTS = 2 class StreamableHTTPError(Exception): @@ -74,29 +73,15 @@ class StreamableHTTPTransport: """StreamableHTTP client transport implementation.""" def __init__(self, url: str) -> None: - """Initialize the StreamableHTTP transport. - - Args: - url: The endpoint URL. - """ self.url = url self.session_id: str | None = None - # Captured from each stamped POST's metadata. Reused on outbound HTTP that carries - # no per-message header (transport-internal GET/DELETE, and dispatcher-written - # response/error/cancel POSTs that bypass the session's stamp). Cleared when an - # `initialize` POST goes out so a probe-stamped value cannot leak onto the handshake. + # Captured from each stamped POST's metadata and reused on outbound HTTP that carries no + # per-message header (transport-internal GET/DELETE, dispatcher-written response/error/cancel + # POSTs). Cleared on `initialize` POSTs so a probe-stamped value can't leak onto the handshake. self._protocol_version_header: str | None = None def _prepare_headers(self) -> dict[str, str]: - """Build MCP-specific request headers for any outbound HTTP request. - - These are merged with the ``httpx.AsyncClient`` defaults (these take - precedence). The cached ``MCP-Protocol-Version`` is included whenever - present so messages that don't pass through the session's stamp β€” - response/error/cancel POSTs, transport-internal GET/DELETE β€” still - carry the negotiated version. Per-message headers are layered on top - by the caller. - """ + """Build MCP headers, overriding `httpx.AsyncClient` defaults; callers layer per-message headers on top.""" headers: dict[str, str] = { "accept": "application/json, text/event-stream", "content-type": "application/json", @@ -108,15 +93,12 @@ def _prepare_headers(self) -> dict[str, str]: return headers def _is_initialization_request(self, message: JSONRPCMessage) -> bool: - """Check if the message is an initialization request.""" return isinstance(message, JSONRPCRequest) and message.method == "initialize" def _is_initialized_notification(self, message: JSONRPCMessage) -> bool: - """Check if the message is an initialized notification.""" return isinstance(message, JSONRPCNotification) and message.method == "notifications/initialized" def _maybe_extract_session_id_from_response(self, response: httpx.Response) -> None: - """Extract and store session ID from response headers.""" new_session_id = response.headers.get(MCP_SESSION_ID) if new_session_id: self.session_id = new_session_id @@ -131,9 +113,8 @@ async def _handle_sse_event( ) -> bool: """Handle an SSE event, returning True if the response is complete.""" if sse.event == "message": - # Handle priming events (empty data with ID) for resumability + # Priming event (empty data with ID) for resumability if not sse.data: - # Call resumption callback for priming events that have an ID if sse.id and resumption_callback: await resumption_callback(sse.id) return False @@ -141,24 +122,19 @@ async def _handle_sse_event( message = jsonrpc_message_adapter.validate_json(sse.data, by_name=False) logger.debug(f"SSE message: {message}") - # If this is a response and we have original_request_id, replace it if original_request_id is not None and isinstance(message, JSONRPCResponse | JSONRPCError): message.id = original_request_id session_message = SessionMessage(message) await read_stream_writer.send(session_message) - # Call resumption token callback if we have an ID if sse.id and resumption_callback: await resumption_callback(sse.id) - # If this is a response or error return True indicating completion - # Otherwise, return False to continue listening return isinstance(message, JSONRPCResponse | JSONRPCError) # Forwarding to a closed read stream lands here when the caller cancels mid-SSE - # (BrokenResourceError, not a parse failure); coverage is timing-dependent in the - # streaming story's modern HTTP cancellation leg. + # (BrokenResourceError, not a parse failure); coverage is timing-dependent. except Exception as exc: # pragma: lax no cover logger.exception("Error parsing SSE message") if original_request_id is not None: @@ -192,16 +168,14 @@ async def handle_get_stream(self, client: httpx.AsyncClient, read_stream_writer: logger.debug("GET SSE connection established") async for sse in event_source.aiter_sse(): - # Track last event ID for reconnection if sse.id: last_event_id = sse.id - # Track retry interval from server if sse.retry is not None: retry_interval_ms = sse.retry await self._handle_sse_event(sse, read_stream_writer) - # Stream ended normally (server closed) - reset attempt counter + # Stream ended normally (server closed) β€” reset attempt counter attempt = 0 except Exception: @@ -212,7 +186,6 @@ async def handle_get_stream(self, client: httpx.AsyncClient, read_stream_writer: logger.debug(f"GET stream max reconnection attempts ({MAX_RECONNECTION_ATTEMPTS}) exceeded") return - # Wait before reconnecting delay_ms = retry_interval_ms if retry_interval_ms is not None else DEFAULT_RECONNECTION_DELAY_MS logger.info(f"GET stream disconnected, reconnecting in {delay_ms}ms...") await anyio.sleep(delay_ms / 1000.0) @@ -225,7 +198,6 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None: else: raise ResumptionError("Resumption request requires a resumption token") # pragma: no cover - # Extract original request ID to map responses original_request_id = None if isinstance(ctx.session_message.message, JSONRPCRequest): # pragma: no branch original_request_id = ctx.session_message.message.id @@ -246,12 +218,11 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None: break async def _handle_post_request(self, ctx: RequestContext) -> None: - """Handle a POST request with response processing.""" message = ctx.session_message.message is_initialization = self._is_initialization_request(message) if is_initialization: # `initialize` is the negotiation, not a "subsequent request" β€” discard any - # probe-stamped value so the discoverβ†’fallback path can't leak it onto the handshake. + # probe-stamped value so it can't leak onto the handshake. self._protocol_version_header = None headers = self._prepare_headers() if ctx.metadata is not None and ctx.metadata.headers is not None: @@ -271,10 +242,9 @@ async def _handle_post_request(self, ctx: RequestContext) -> None: if response.status_code >= 400: if isinstance(message, JSONRPCRequest): - # A spec-correct server may return the JSON-RPC error in the - # body at a non-2xx status (e.g. 400 for INVALID_PARAMS, 404 - # for METHOD_NOT_FOUND). Surface that error rather than the - # status-derived stand-in below. + # A spec-correct server may return the JSON-RPC error in the body at a non-2xx + # status (e.g. 400 for INVALID_PARAMS, 404 for METHOD_NOT_FOUND); surface it + # rather than the status-derived stand-in below. if response.headers.get("content-type", "").lower().startswith("application/json"): try: body = await response.aread() @@ -290,9 +260,8 @@ async def _handle_post_request(self, ctx: RequestContext) -> None: logger.debug("Non-2xx body was not a JSON-RPC error; using fallback") if response.status_code == 404: if self.session_id is None: - # No session yet β†’ 404 is the HTTP-level spelling of - # METHOD_NOT_FOUND (gateway / legacy server doesn't know - # this method); "Session terminated" would be a lie here. + # No session yet β†’ 404 is the HTTP-level spelling of METHOD_NOT_FOUND + # (gateway/legacy server); "Session terminated" would be a lie here. error_data = ErrorData(code=METHOD_NOT_FOUND, message="Not Found") else: error_data = ErrorData(code=INVALID_REQUEST, message="Session terminated") @@ -326,7 +295,6 @@ async def _handle_json_response( *, request_id: RequestId, ) -> None: - """Handle JSON response from the server.""" try: content = await response.aread() message = jsonrpc_message_adapter.validate_json(content, by_name=False) @@ -343,23 +311,19 @@ async def _handle_sse_response( response: httpx.Response, ctx: RequestContext, ) -> None: - """Handle SSE response from the server.""" last_event_id: str | None = None retry_interval_ms: int | None = None - # The caller (_handle_post_request) only reaches here inside - # isinstance(message, JSONRPCRequest), so this is always a JSONRPCRequest. + # _handle_post_request only calls this for JSONRPCRequest messages. assert isinstance(ctx.session_message.message, JSONRPCRequest) original_request_id = ctx.session_message.message.id try: event_source = EventSource(response) async for sse in event_source.aiter_sse(): # pragma: no branch - # Track last event ID for potential reconnection if sse.id: last_event_id = sse.id - # Track retry interval from server if sse.retry is not None: retry_interval_ms = sse.retry @@ -369,15 +333,13 @@ async def _handle_sse_response( original_request_id=original_request_id, resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None), ) - # If the SSE event indicates completion, like returning response/error - # break the loop if is_complete: await response.aclose() return # Normal completion, no reconnect needed except Exception: logger.debug("SSE stream ended", exc_info=True) # pragma: lax no cover - # Stream ended without response - reconnect if we received an event with ID + # Stream ended without a response β€” reconnect if we saw an event with an ID if last_event_id is not None: # pragma: no branch logger.info("SSE stream disconnected, reconnecting...") await self._handle_reconnection(ctx, last_event_id, retry_interval_ms) @@ -390,19 +352,16 @@ async def _handle_reconnection( attempt: int = 0, ) -> None: """Reconnect with Last-Event-ID to resume stream after server disconnect.""" - # Bail if max retries exceeded if attempt >= MAX_RECONNECTION_ATTEMPTS: # pragma: no cover logger.debug(f"Max reconnection attempts ({MAX_RECONNECTION_ATTEMPTS}) exceeded") return - # Always wait - use server value or default delay_ms = retry_interval_ms if retry_interval_ms is not None else DEFAULT_RECONNECTION_DELAY_MS await anyio.sleep(delay_ms / 1000.0) headers = self._prepare_headers() headers[LAST_EVENT_ID] = last_event_id - # Extract original request ID to map responses original_request_id = None if isinstance(ctx.session_message.message, JSONRPCRequest): # pragma: no branch original_request_id = ctx.session_message.message.id @@ -412,7 +371,6 @@ async def _handle_reconnection( event_source.response.raise_for_status() logger.info("Reconnected to SSE stream") - # Track for potential further reconnection reconnect_last_event_id: str = last_event_id reconnect_retry_ms = retry_interval_ms @@ -432,12 +390,11 @@ async def _handle_reconnection( await event_source.response.aclose() return - # Stream ended again without response - reconnect again (reset attempt counter) + # Stream ended again without a response β€” reconnect with a fresh attempt counter logger.info("SSE stream disconnected, reconnecting...") await self._handle_reconnection(ctx, reconnect_last_event_id, reconnect_retry_ms, 0) except Exception as e: # pragma: no cover logger.debug(f"Reconnection failed: {e}") - # Try to reconnect again if we still have an event ID await self._handle_reconnection(ctx, last_event_id, retry_interval_ms, attempt + 1) async def post_writer( @@ -461,12 +418,10 @@ async def _handle_message(session_message: SessionMessage) -> None: else None ) - # Check if this is a resumption request is_resumption = bool(metadata and metadata.resumption_token) logger.debug(f"Sending client message: {message}") - # Handle initialized notification if self._is_initialized_notification(message): start_get_stream() @@ -484,7 +439,6 @@ async def handle_request_async(): else: await self._handle_post_request(ctx) - # If this is a request, start a new task to handle it if isinstance(message, JSONRPCRequest): tg.start_soon(handle_request_async) else: @@ -533,29 +487,17 @@ async def streamable_http_client( http_client: httpx.AsyncClient | None = None, terminate_on_close: bool = True, ) -> AsyncGenerator[TransportStreams, None]: - """Client transport for StreamableHTTP. + """Client transport for StreamableHTTP, yielding (read_stream, write_stream). Args: - url: The MCP server endpoint URL. - http_client: Optional pre-configured httpx.AsyncClient. If None, a default - client with recommended MCP timeouts will be created. To configure headers, - authentication, or other HTTP settings, create an httpx.AsyncClient and pass it here. - terminate_on_close: If True, send a DELETE request to terminate the session when the context exits. - - Yields: - Tuple containing: - - read_stream: Stream for reading messages from the server - - write_stream: Stream for sending messages to the server - - Example: - See examples/snippets/clients/ for usage patterns. + http_client: Pre-configured `httpx.AsyncClient`; defaults to one with recommended + MCP timeouts. Pass your own to configure headers, auth, or other HTTP settings. + terminate_on_close: Send a session-terminating DELETE request when the context exits. """ - # Determine if we need to create and manage the client client_provided = http_client is not None client = http_client if client is None: - # Create default client with recommended MCP timeouts client = create_mcp_http_client() transport = StreamableHTTPTransport(url) @@ -563,7 +505,7 @@ async def streamable_http_client( logger.debug(f"Connecting to StreamableHTTP endpoint: {url}") async with contextlib.AsyncExitStack() as stack: - # Only manage client lifecycle if we created it + # Only manage the client's lifecycle if we created it if not client_provided: await stack.enter_async_context(client) diff --git a/src/mcp/os/posix/utilities.py b/src/mcp/os/posix/utilities.py index d15be17194..723d9f5848 100644 --- a/src/mcp/os/posix/utilities.py +++ b/src/mcp/os/posix/utilities.py @@ -10,19 +10,17 @@ logger = logging.getLogger(__name__) -# How often to probe for surviving group members between SIGTERM and SIGKILL. _GROUP_POLL_INTERVAL = 0.01 async def terminate_posix_process_tree(process: Process, timeout_seconds: float = 2.0) -> None: - """Terminates a process and all its descendants on POSIX. - - SIGTERMs the process group, waits up to timeout_seconds for it to - disappear, then SIGKILLs whatever remains. killpg reaches every descendant - atomically, even ones whose parent already exited; daemonizers that left - the group escape by design. A group only disappears once every member is - dead and reaped, so a client running as PID 1 should reap orphans (e.g. - docker run --init) or the wait below runs its full timeout. + """SIGTERM the process group, wait up to timeout_seconds, then SIGKILL whatever remains. + + killpg reaches every descendant atomically, even ones whose parent already + exited; daemonizers that left the group escape by design. A group only + disappears once every member is dead and reaped, so a client running as + PID 1 should reap orphans (e.g. docker run --init) or the wait below runs + its full timeout. """ # The leader's pid is the pgid (start_new_session). Never use getpgid(): # it fails once the leader is reaped, even with live members left. diff --git a/src/mcp/os/win32/utilities.py b/src/mcp/os/win32/utilities.py index 1cc867d4fa..6431b7727d 100644 --- a/src/mcp/os/win32/utilities.py +++ b/src/mcp/os/win32/utilities.py @@ -15,35 +15,26 @@ logger = logging.getLogger(__name__) -# Windows-specific imports for Job Objects if sys.platform == "win32": import pywintypes import win32api import win32con import win32job else: - # Type stubs for non-Windows platforms win32api = None win32con = None win32job = None pywintypes = None -# How often FallbackProcess polls the underlying Popen for exit. _EXIT_POLL_INTERVAL = 0.01 -# Job Object handle per spawned process, for tree termination at shutdown. -# Values stay pywin32 PyHANDLEs: if no pop site ever runs, the dying weak entry -# drops the last reference and the PyHANDLE destructor closes the handle, which -# is what makes KILL_ON_JOB_CLOSE reap an abandoned tree. +# Job Object handle per spawned process, for tree termination at shutdown. Values stay pywin32 PyHANDLEs: if +# no pop site ever runs, the dying weak entry drops the last reference and KILL_ON_JOB_CLOSE reaps the tree. _process_jobs: "weakref.WeakKeyDictionary[Process | FallbackProcess, object]" = weakref.WeakKeyDictionary() def get_windows_executable_command(command: str) -> str: - """Resolves the command to a Windows executable path. - - Tries the bare name first, then the common script extensions (.cmd, .bat, - .exe, .ps1). - """ + """Resolves the command to a Windows executable path, trying .cmd/.bat/.exe/.ps1 after the bare name.""" try: if command_path := shutil.which(command): return command_path @@ -59,11 +50,7 @@ def get_windows_executable_command(command: str) -> str: class FallbackProcess: - """Async wrapper around subprocess.Popen for SelectorEventLoop. - - Windows event loops without async subprocess support get this Popen-backed - fallback, with anyio file streams wrapping the pipes. - """ + """Async Popen wrapper for Windows event loops without async subprocess support (SelectorEventLoop).""" def __init__(self, popen_obj: subprocess.Popen[bytes]) -> None: self.popen: subprocess.Popen[bytes] = popen_obj @@ -74,17 +61,12 @@ def __init__(self, popen_obj: subprocess.Popen[bytes]) -> None: self.stdout = FileReadStream(cast(BinaryIO, stdout)) if stdout else None async def wait(self) -> int: - """Waits for exit by polling the Popen. - - A thread blocked in Popen.wait() cannot be cancelled by anyio, which - would defeat every timeout placed around this call. - """ + """Polls for exit; a thread blocked in Popen.wait() can't be cancelled, defeating anyio timeouts.""" while (returncode := self.popen.poll()) is None: await anyio.sleep(_EXIT_POLL_INTERVAL) return returncode def terminate(self) -> None: - """Terminates the subprocess.""" self.popen.terminate() def kill(self) -> None: @@ -93,20 +75,15 @@ def kill(self) -> None: @property def pid(self) -> int: - """Returns the process ID.""" return self.popen.pid @property def returncode(self) -> int | None: - """The exit code, or None while the process is still running. - - Polls the Popen so death is observable without anyone calling wait(). - """ + """Exit code, or None while running; polls the Popen so death is observable without wait().""" return self.popen.poll() -# The process handle stdio_client drives: anyio's Process, or the Popen-backed -# fallback used on Windows event loops without async subprocess support. +# The process handle stdio_client drives: anyio's Process or the Popen-backed Windows fallback. ServerProcess: TypeAlias = Process | FallbackProcess @@ -117,34 +94,24 @@ async def create_windows_process( errlog: TextIO | None = sys.stderr, cwd: Path | str | None = None, ) -> Process | FallbackProcess: - """Creates a subprocess with Job Object support for tree termination. - - Spawns via anyio's open_process; event loops without async subprocess - support (notably the SelectorEventLoop) raise NotImplementedError, in which - case the spawn falls back to a Popen-backed FallbackProcess. Either way the - process is then assigned to a Job Object so its children can be terminated - with it; children spawned before the assignment completes are not captured - (see the inline note below). + """Creates a subprocess assigned to a Job Object so its children can be terminated with it. - Returns: - Process | FallbackProcess: The spawned process with async stdin/stdout streams. + Event loops without async subprocess support (SelectorEventLoop) raise + NotImplementedError; the spawn then falls back to a Popen-backed FallbackProcess. """ try: process = await anyio.open_process( [command, *args], env=env, - # Ensure we don't create console windows for each process creationflags=getattr(subprocess, "CREATE_NO_WINDOW", 0), stderr=errlog, cwd=cwd, ) except NotImplementedError: - # Windows event loops without async subprocess support (SelectorEventLoop) process = await _create_windows_fallback_process(command, args, env, errlog, cwd) - # Children spawned before the assignment completes land outside the job - # (membership is inherited at CreateProcess, never acquired retroactively); - # if that ever bites, the fix is a CREATE_SUSPENDED spawn -> assign -> resume. + # Children spawned before the assignment completes land outside the job (membership is inherited + # at CreateProcess, never acquired retroactively); the fix would be CREATE_SUSPENDED spawn -> assign -> resume. job = _create_job_object() _maybe_assign_process_to_job(process, job) return process @@ -165,7 +132,7 @@ async def _create_windows_fallback_process( stderr=errlog, env=env, cwd=cwd, - bufsize=0, # Unbuffered output + bufsize=0, creationflags=getattr(subprocess, "CREATE_NO_WINDOW", 0), ) return FallbackProcess(popen_obj) @@ -193,10 +160,7 @@ def _create_job_object() -> object | None: def _maybe_assign_process_to_job(process: Process | FallbackProcess, job: object | None) -> None: - """Assigns the process to the job and records it for tree termination. - - On any failure the job handle is closed instead. - """ + """Assigns the process to the job and records it for tree termination; on failure the job handle is closed.""" if job is None: return @@ -225,9 +189,8 @@ def _maybe_assign_process_to_job(process: Process | FallbackProcess, job: object def close_process_job(process: Process | FallbackProcess) -> None: """Closes the process's Job Object handle, if it still has one. - KILL_ON_JOB_CLOSE makes the close also kill any members still alive, - deterministically rather than at GC time; a deliberate divergence from - POSIX, where a graceful server's children are left alive. + KILL_ON_JOB_CLOSE makes the close also kill surviving members β€” deterministically, not at GC + time, and deliberately diverging from POSIX, where a graceful server's children are left alive. """ if sys.platform != "win32": return @@ -238,11 +201,9 @@ def close_process_job(process: Process | FallbackProcess) -> None: async def terminate_windows_process_tree(process: Process | FallbackProcess) -> None: - """Terminates the process's job, or just the process if it has no job. + """Hard-kills the process's job and every member, or just the process if it has no job. - Job termination is an immediate hard kill of every member. Windows has no - tree-wide SIGTERM; the stdin-close grace period is the server's chance to - exit cleanly. + Windows has no tree-wide SIGTERM; the stdin-close grace period is the server's chance to exit cleanly. """ if sys.platform != "win32": return diff --git a/src/mcp/server/_otel.py b/src/mcp/server/_otel.py index ff722eb903..485d165998 100644 --- a/src/mcp/server/_otel.py +++ b/src/mcp/server/_otel.py @@ -59,10 +59,9 @@ async def __call__(self, ctx: ServerRequestContext[Any, Any], call_next: CallNex span.set_status(StatusCode.ERROR, str(e)) raise if ctx.method == "tools/call": - # Tool errors are detected pre-serialization, so only shapes that reach the wire as an error - # count: the model, or the camelCase alias (`is_error` is dropped by the alias-only wire - # validation). A raw-dict `isError` is matched as a literal bool only - non-bool coercible - # values (1, "true") would serialize to an error but are rare enough to leave undetected. + # Detection runs pre-serialization, so only shapes that reach the wire as an error count: the + # model, or the camelCase alias (alias-only wire validation drops `is_error`). A raw-dict + # `isError` is matched as a literal bool; coercible non-bools (1, "true") are rare enough to ignore. match result: case CallToolResult(is_error=True) | {"isError": True}: span.set_attribute("error.type", "tool_error") diff --git a/src/mcp/server/_streamable_http_modern.py b/src/mcp/server/_streamable_http_modern.py index e36ac7dd4e..51cd07a27d 100644 --- a/src/mcp/server/_streamable_http_modern.py +++ b/src/mcp/server/_streamable_http_modern.py @@ -1,19 +1,9 @@ """Single-exchange HTTP serving for protocol version 2026-07-28. -Private module β€” entry is via `StreamableHTTPSessionManager.handle_request`. -The legacy streamable-HTTP transport is untouched and remains the supported -path for earlier protocol revisions. - -A 2026-07-28 request is a self-contained POST: no `initialize` handshake, no -`Mcp-Session-Id`, one JSON-RPC request in, one JSON-RPC response out. JSON -mode handles the request directly in the ASGI task. SSE mode runs the handler -as a sibling task and defers committing to `text/event-stream` until the -handler emits a notification or `_SSE_PING_INTERVAL` elapses, whichever -comes first: a handler that completes (or raises) within that window without -emitting still gets a JSON response with the table-mapped HTTP status, so -the spec's `404`/`400` MUSTs hold for kernel-dispatch errors; a handler that -runs silent past the window commits SSE so the keepalive ping can keep the -connection open behind a proxy idle-read timeout. +Private module β€” entry is via `StreamableHTTPSessionManager.handle_request`; +the legacy streamable-HTTP transport remains the path for earlier revisions. +A 2026-07-28 request is a self-contained POST with no `initialize` handshake +and no `Mcp-Session-Id`: one JSON-RPC request in, one response out. """ from __future__ import annotations @@ -72,12 +62,10 @@ @dataclass class _SingleExchangeDispatchContext: - """`DispatchContext` for one inbound HTTP request. + """Structural `mcp.shared.dispatcher.DispatchContext` for one inbound HTTP request. - Structurally satisfies `mcp.shared.dispatcher.DispatchContext`. The - back-channel is closed by construction: a 2026-07-28 server cannot send - requests to the client. The SSE sink, when present, carries request-scoped - notifications onto this request's response stream. + Back-channel is closed by construction β€” a 2026-07-28 server cannot send requests + to the client. The optional sink carries notifications onto this request's SSE stream. """ transport: TransportContext @@ -119,9 +107,7 @@ async def progress(self, progress: float, total: float | None = None, message: s def _typed(model: type[_ModelT], raw: Any) -> _ModelT | None: """Validate the classifier's raw envelope value into a typed model. - Rung 1 guarantees the envelope key was present; a ``null`` or mis-shaped - value falls through to ``ValidationError`` and is treated as not supplied - so the request still routes. + Rung 1 guaranteed key presence; a `null` or mis-shaped value is treated as not supplied so the request routes. """ try: return model.model_validate(raw, by_name=False) @@ -132,12 +118,10 @@ def _typed(model: type[_ModelT], raw: Any) -> _ModelT | None: async def _to_jsonrpc_response( request_id: RequestId, coro: Awaitable[dict[str, Any]] ) -> JSONRPCResponse | JSONRPCError: - """Await ``coro`` and wrap its outcome as the JSON-RPC reply for ``request_id``. + """Await `coro` and wrap its outcome as the JSON-RPC reply for `request_id`. - The exception-to-wire boundary for the modern HTTP entry, composed around - `serve_one`. `MCPError` and `ValidationError` map via the shared - `handler_exception_to_error_data` ladder; any other exception is logged and - surfaced as `INTERNAL_ERROR` so handler internals never reach the wire. + `MCPError`/`ValidationError` map via the `handler_exception_to_error_data` ladder; + anything else is logged and surfaced as `INTERNAL_ERROR` so handler internals never reach the wire. """ try: result = await coro @@ -151,7 +135,7 @@ async def _to_jsonrpc_response( _SSE_PING_INTERVAL: float = 15.0 -"""Seconds between SSE comment-line keepalives once `text/event-stream` has committed.""" +"""Seconds between SSE keepalive pings, and the deferral window before committing to `text/event-stream`.""" _SSE_HEADERS: Final[list[tuple[bytes, bytes]]] = [ (b"content-type", b"text/event-stream"), @@ -164,8 +148,8 @@ async def _to_jsonrpc_response( def _sse_event(msg: JSONRPCResponse | JSONRPCError | JSONRPCNotification) -> bytes: """Serialise a JSON-RPC message as one SSE `event: message` frame. - SSE mode begins after the handler has emitted, so a `JSONRPCError` here - always carries the request's id; the `id: null` case lives in `_write`. + A `JSONRPCError` here always carries the request's id (unparseable-id + rejections never reach SSE mode), so `exclude_none` cannot drop `id: null`. """ body = msg.model_dump(mode="json", by_alias=True, exclude_none=True) data = json.dumps(body, separators=(",", ":")) @@ -182,8 +166,7 @@ async def _write( status = ERROR_CODE_HTTP_STATUS.get(msg.error.code, _OK_STATUS) if isinstance(msg, JSONRPCError) else _OK_STATUS body = msg.model_dump(mode="json", by_alias=True, exclude_none=True) if isinstance(msg, JSONRPCError) and msg.id is None: - # JSON-RPC requires `id: null` to appear on the wire when the request - # id couldn't be parsed; `exclude_none` would otherwise drop it. + # JSON-RPC requires `id: null` on the wire for unparseable request ids; `exclude_none` drops it. body["id"] = None await Response( json.dumps(body, separators=(",", ":")), @@ -203,10 +186,8 @@ async def handle_modern_request( ) -> None: """ASGI handler for a single stateless-era POST. - Called from `StreamableHTTPSessionManager.handle_request` when the - `MCP-Protocol-Version` header names a modern revision; the manager enters - `app.lifespan` once at startup and passes the state in. Never sets - `Mcp-Session-Id`. + Routed here when `MCP-Protocol-Version` names a modern revision; the session manager + enters `app.lifespan` once at startup and passes the state in. Never sets `Mcp-Session-Id`. """ request = Request(scope, receive) @@ -217,8 +198,7 @@ async def handle_modern_request( return if request.method != "POST": - # HTTP-layer rejection (Allow accompanies 405 per RFC 9110) β€” happens - # before JSON-RPC parsing, so it doesn't go through `_write`. + # HTTP-layer rejection, before JSON-RPC parsing; Allow accompanies 405 per RFC 9110. await Response(status_code=405, headers={"Allow": "POST"})(scope, receive, send) return @@ -237,13 +217,9 @@ async def handle_modern_request( try: req = JSONRPCRequest.model_validate(decoded) except ValidationError: - # Well-formed JSON that isn't a single request object. The transport - # spec permits notification POSTs and gives the server two responses - # (202 accept / 4xx cannot-accept; streamable-http Β§Sending Messages - # item 5). The core protocol defines no clientβ†’server notifications - # over HTTP at 2026-07-28 (cancellation is SSE-stream close), so this - # entry takes the cannot-accept branch. TODO(L57): S4 owns the - # strict-vs-lenient choice. + # Well-formed JSON but not a single request object. The spec permits notification POSTs + # (202 accept / 4xx cannot-accept; streamable-http Β§Sending Messages item 5), but 2026-07-28 has + # no clientβ†’server HTTP notifications (cancellation is SSE close) β€” reject. TODO(L57): strict-vs-lenient. rej = JSONRPCError( jsonrpc="2.0", id=None, @@ -310,15 +286,12 @@ async def watch_disconnect(cancel_scope: anyio.CancelScope) -> None: done = True if done: - # Handler completed within the deferral window without emitting: - # `application/json` with the table-mapped status. Kernel-dispatch - # errors (METHOD_NOT_FOUND, missing-capability, INVALID_PARAMS) - # resolve here in practice. + # Completed within the deferral window without emitting: plain JSON with the + # table-mapped status, so the spec's 404/400 MUSTs hold for kernel-dispatch errors. await _write(result[0], scope, receive, send) else: - # First notification arrived, or the deferral window elapsed: commit - # `text/event-stream` and start pinging so a proxy idle-read timeout - # cannot close the stream (which on this path cancels the handler). + # First notification arrived or the window elapsed: commit `text/event-stream` and + # ping so a proxy idle-read timeout can't close the stream (which would cancel the handler). await send({"type": "http.response.start", "status": _OK_STATUS, "headers": _SSE_HEADERS}) while not done: await send({"type": "http.response.body", "body": event or b": ping\r\n\r\n", "more_body": True}) diff --git a/src/mcp/server/apps.py b/src/mcp/server/apps.py index d5b9d9ed85..ed703c2c06 100644 --- a/src/mcp/server/apps.py +++ b/src/mcp/server/apps.py @@ -1,30 +1,12 @@ """MCP Apps extension (`io.modelcontextprotocol/ui`). -MCP Apps lets a tool carry a reference to an interactive UI: the tool's -`_meta.ui.resourceUri` points at a `ui://` resource (an HTML document served -with the `text/html;profile=mcp-app` MIME type) that the host renders in a -sandboxed iframe. See https://modelcontextprotocol.io/specification/draft/extensions/apps -and the ext-apps spec for the wire format, and SEP-2133 for the extension framework. - -This is a self-contained, additive `Extension`: it contributes tools and -resources and advertises the capability, but does not intercept any core method. -A server opts in by passing an `Apps` instance to `MCPServer(extensions=[...])`. - - apps = Apps() - - @apps.tool(resource_uri="ui://clock/app.html", description="Current time") - def get_time(ctx: Context) -> str: - return datetime.now(timezone.utc).isoformat() - - apps.add_html_resource("ui://clock/app.html", CLOCK_HTML) - - mcp = MCPServer("clock", extensions=[apps]) - -Per SEP-2133, an extension MUST degrade gracefully: a UI-enabled tool should -still return meaningful text for clients that did not negotiate Apps. Use -`client_supports_apps(ctx)` to branch on the client's advertised support. (The SDK -keeps Apps in-core under `mcp.server.apps` rather than a separate package; the -TypeScript and C# SDKs ship it as a standalone package.) +A tool's `_meta.ui.resourceUri` points at a `ui://` resource (HTML served as +`text/html;profile=mcp-app`) that the host renders in a sandboxed iframe; a +server opts in via `MCPServer(extensions=[Apps()])`. Per SEP-2133 a UI-enabled +tool must degrade gracefully for clients that did not negotiate Apps β€” branch on +`client_supports_apps(ctx)`. Ships in-core (the TypeScript and C# SDKs package it +separately). Wire format: +https://modelcontextprotocol.io/specification/draft/extensions/apps """ from __future__ import annotations @@ -41,7 +23,7 @@ def get_time(ctx: Context) -> str: from mcp.server.mcpserver.resources import Resource, TextResource EXTENSION_ID = "io.modelcontextprotocol/ui" -"""The MCP Apps extension identifier (the shipped TS/C# constant).""" +"""The MCP Apps extension identifier.""" APP_MIME_TYPE = "text/html;profile=mcp-app" """MIME type for a `ui://` app resource.""" @@ -77,9 +59,8 @@ class ResourceCsp(BaseModel): class Apps(Extension): """The MCP Apps extension: bind tools to `ui://` UI resources. - Register UI-bound tools with `@apps.tool(resource_uri=...)` and their HTML - with `add_html_resource(...)`, then pass the instance to - `MCPServer(extensions=[apps])`. + Register tools with `@apps.tool(resource_uri=...)`, their HTML with + `add_html_resource(...)`, then pass the instance to `MCPServer(extensions=[apps])`. """ identifier = EXTENSION_ID @@ -99,14 +80,8 @@ def tool( """Decorator registering a tool bound to a `ui://` resource. Stamps `_meta.ui.resourceUri` (and `_meta.ui.visibility` when given) on the - tool. `tool_kwargs` are forwarded to `MCPServer.add_tool` (name, title, - description, annotations, ...); pass `meta=` to merge extra `_meta` keys - alongside the `ui` entry. - - Args: - resource_uri: The `ui://` URI of the UI resource this tool renders. - visibility: Where the tool is surfaced (`["model", "app"]`). - meta: Additional `_meta` keys to merge with the `ui` entry. + tool; `tool_kwargs` are forwarded to `MCPServer.add_tool` and `meta` keys + merge alongside the `ui` entry. Raises: ValueError: If `resource_uri` does not use the `ui://` scheme, or @@ -144,10 +119,6 @@ def add_html_resource( `csp`, `permissions`, `domain`, and `prefers_border` populate the resource's `_meta.ui` per the ext-apps spec. - Args: - uri: The `ui://` URI; a tool references it via `resource_uri`. - html: The HTML document the host renders. - Raises: ValueError: If `uri` does not use the `ui://` scheme. """ @@ -173,17 +144,15 @@ def add_html_resource( ) def add_resource(self, resource: Resource) -> None: - """Register a pre-built `ui://` resource. + """Register a pre-built `ui://` resource (e.g. a `FileResource` serving HTML from disk). - The escape hatch for resources `add_html_resource` cannot express (e.g. a - `FileResource` serving HTML from disk). A resource without an explicit - `mime_type` is served as `text/html;profile=mcp-app` β€” hosts will not - render a `ui://` resource under any other MIME type, so an explicit - mismatch is rejected. + Without an explicit `mime_type` the resource is served as + `text/html;profile=mcp-app`; hosts render `ui://` resources only under + that MIME type, so an explicit mismatch is rejected. Raises: - ValueError: If the resource URI does not use the `ui://` scheme, or - its explicit `mime_type` is not `text/html;profile=mcp-app`. + ValueError: If the URI does not use the `ui://` scheme, or an + explicit `mime_type` is not `text/html;profile=mcp-app`. """ _require_ui_scheme(resource.uri) if "mime_type" not in resource.model_fields_set: @@ -196,10 +165,9 @@ def tools(self) -> Sequence[ToolBinding]: """The bound tools. Raises: - ValueError: If a tool's `resource_uri` has no matching resource - registered on this instance β€” a tool advertising a - `_meta.ui.resourceUri` that 404s on `resources/read` is a - misconfiguration, caught when the server consumes the extension. + ValueError: If a tool's `resource_uri` has no resource registered on + this instance β€” a `_meta.ui.resourceUri` that 404s on + `resources/read` is a misconfiguration. """ registered = {binding.resource.uri for binding in self._resources} for tool, uri in self._tools: @@ -217,9 +185,9 @@ def resources(self) -> Sequence[ResourceBinding]: def client_supports_apps(ctx: Context[Any] | ServerRequestContext[Any, Any]) -> bool: """Whether the connected client negotiated MCP Apps support. - Returns `True` only when the client advertised the extension AND listed the - `text/html;profile=mcp-app` MIME type in its settings, so a UI-enabled tool - can fall back to text-only output otherwise. + True only when the client advertised the extension AND listed + `text/html;profile=mcp-app` in its settings; UI-enabled tools should fall + back to text-only output otherwise. """ capabilities = _client_capabilities(ctx) extensions = capabilities.extensions if capabilities else None diff --git a/src/mcp/server/auth/handlers/authorize.py b/src/mcp/server/auth/handlers/authorize.py index 5cf93cf8c2..8008be42c9 100644 --- a/src/mcp/server/auth/handlers/authorize.py +++ b/src/mcp/server/auth/handlers/authorize.py @@ -27,7 +27,6 @@ class AuthorizationRequest(BaseModel): client_id: str = Field(..., description="The client ID") redirect_uri: AnyUrl | None = Field(None, description="URL to redirect to after authorization") - # see OAuthClientMetadata; we only support `code` response_type: Literal["code"] = Field(..., description="Must be 'code' for authorization code flow") code_challenge: str = Field(..., description="PKCE code challenge") code_challenge_method: Literal["S256"] = Field("S256", description="PKCE code challenge method, must be S256") @@ -68,8 +67,7 @@ class AuthorizationHandler: provider: OAuthAuthorizationServerProvider[Any, Any, Any] async def handle(self, request: Request) -> Response: - # implements authorization requests for grant_type=code; - # see https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.1 + # authorization endpoint for the code grant; see https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.1 state = None redirect_uri = None @@ -81,28 +79,14 @@ async def error_response( error_description: str | None, attempt_load_client: bool = True, ): - # Error responses take two different formats: - # 1. The request has a valid client ID & redirect_uri: we issue a redirect - # back to the redirect_uri with the error response fields as query - # parameters. This allows the client to be notified of the error. - # 2. Otherwise, we return an error response directly to the end user; - # we choose to do so in JSON, but this is left undefined in the - # specification. - # See https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.2.1 - # - # This logic is a bit awkward to handle, because the error might be thrown - # very early in request validation, before we've done the usual Pydantic - # validation, loaded the client, etc. To handle this, error_response() - # contains fallback logic which attempts to load the parameters directly - # from the request. - + # Per https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.2.1: with a valid client and + # redirect_uri we redirect back with the error in query params, else respond directly in JSON (a + # case the spec leaves undefined). Errors may predate validation, so fall back to the raw params. nonlocal client, redirect_uri, state if client is None and attempt_load_client: - # make last-ditch attempt to load the client client_id = best_effort_extract_string("client_id", params) client = await self.provider.get_client(client_id) if client_id else None if redirect_uri is None and client: - # make last-ditch effort to load the redirect uri try: if params is not None and "redirect_uri" not in params: raw_redirect_uri = None @@ -112,13 +96,11 @@ async def error_response( ).root redirect_uri = client.validate_redirect_uri(raw_redirect_uri) except (ValidationError, InvalidRedirectUriError): - # if the redirect URI is invalid, ignore it & just return the - # initial error + # invalid redirect_uri: ignore it and return the original error directly pass # the error response MUST contain the state specified by the client, if any if state is None: - # make last-ditch effort to load state state = best_effort_extract_string("state", params) error_resp = AuthorizationErrorResponse( @@ -141,20 +123,17 @@ async def error_response( ) try: - # Parse request parameters if request.method == "GET": - # Convert query_params to dict for pydantic validation params = request.query_params else: - # Parse form data for POST requests params = await request.form() - # Save state if it exists, even before validation + # capture state before validation so even early error responses can echo it state = best_effort_extract_string("state", params) try: auth_request = AuthorizationRequest.model_validate(params) - state = auth_request.state # Update with validated state + state = auth_request.state except ValidationError as validation_error: error: AuthorizationErrorCode = "invalid_request" for e in validation_error.errors(): @@ -163,39 +142,33 @@ async def error_response( break return await error_response(error, stringify_pydantic_error(validation_error)) - # Get client information client = await self.provider.get_client( auth_request.client_id, ) if not client: - # For client_id validation errors, return direct error (no redirect) return await error_response( error="invalid_request", error_description=f"Client ID '{auth_request.client_id}' not found", attempt_load_client=False, ) - # Validate redirect_uri against client's registered URIs try: redirect_uri = client.validate_redirect_uri(auth_request.redirect_uri) except InvalidRedirectUriError as validation_error: - # For redirect_uri validation errors, return direct error (no redirect) return await error_response( error="invalid_request", error_description=validation_error.message, ) - # Validate scope - for scope errors, we can redirect + # unlike client_id/redirect_uri errors above, scope errors may redirect back to the client try: scopes = client.validate_scope(auth_request.scope) except InvalidScopeError as validation_error: - # For scope errors, redirect with error parameters return await error_response( error="invalid_scope", error_description=validation_error.message, ) - # Setup authorization parameters auth_params = AuthorizationParams( state=state, scopes=scopes, @@ -206,7 +179,6 @@ async def error_response( ) try: - # Let the provider pick the next URI to redirect to return RedirectResponse( url=await self.provider.authorize( client, @@ -216,10 +188,8 @@ async def error_response( headers={"Cache-Control": "no-store"}, ) except AuthorizeError as e: - # Handle authorization errors as defined in RFC 6749 Section 4.1.2.1 return await error_response(error=e.error, error_description=e.error_description) except Exception as validation_error: # pragma: no cover - # Catch-all for unexpected errors logger.exception("Unexpected error in authorization_handler", exc_info=validation_error) return await error_response(error="server_error", error_description="An unexpected error occurred") diff --git a/src/mcp/server/auth/handlers/metadata.py b/src/mcp/server/auth/handlers/metadata.py index f126442150..4606a47b47 100644 --- a/src/mcp/server/auth/handlers/metadata.py +++ b/src/mcp/server/auth/handlers/metadata.py @@ -14,7 +14,7 @@ class MetadataHandler: async def handle(self, request: Request) -> Response: return PydanticJSONResponse( content=self.metadata, - headers={"Cache-Control": "public, max-age=3600"}, # Cache for 1 hour + headers={"Cache-Control": "public, max-age=3600"}, ) @@ -25,5 +25,5 @@ class ProtectedResourceMetadataHandler: async def handle(self, request: Request) -> Response: return PydanticJSONResponse( content=self.metadata, - headers={"Cache-Control": "public, max-age=3600"}, # Cache for 1 hour + headers={"Cache-Control": "public, max-age=3600"}, ) diff --git a/src/mcp/server/auth/handlers/register.py b/src/mcp/server/auth/handlers/register.py index e565b27383..9a36105123 100644 --- a/src/mcp/server/auth/handlers/register.py +++ b/src/mcp/server/auth/handlers/register.py @@ -14,8 +14,7 @@ from mcp.server.auth.settings import ClientRegistrationOptions from mcp.shared.auth import JWT_BEARER_GRANT_TYPE, OAuthClientInformationFull, OAuthClientMetadata -# this alias is a no-op; it's just to separate out the types exposed to the -# provider from what we use in the HTTP handler +# No-op alias separating the provider-facing type from the HTTP handler's RegistrationRequest = OAuthClientMetadata @@ -30,12 +29,10 @@ class RegistrationHandler: options: ClientRegistrationOptions async def handle(self, request: Request) -> Response: - # Implements dynamic client registration as defined in https://datatracker.ietf.org/doc/html/rfc7591#section-3.1 + # Dynamic client registration per https://datatracker.ietf.org/doc/html/rfc7591#section-3.1 try: body = await request.body() client_metadata = OAuthClientMetadata.model_validate_json(body) - - # Scope validation is handled below except ValidationError as validation_error: return PydanticJSONResponse( content=RegistrationErrorResponse( @@ -47,13 +44,11 @@ async def handle(self, request: Request) -> Response: client_id = str(uuid4()) - # If auth method is None, default to client_secret_post if client_metadata.token_endpoint_auth_method is None: client_metadata.token_endpoint_auth_method = "client_secret_post" client_secret = None if client_metadata.token_endpoint_auth_method != "none": # pragma: no branch - # cryptographically secure random 32-byte hex string client_secret = secrets.token_hex(32) if client_metadata.scope is None and self.options.default_scopes is not None: @@ -79,9 +74,8 @@ async def handle(self, request: Request) -> Response: status_code=400, ) - # SEP-990 Β§5.1 / draft-ietf-oauth-identity-assertion-authz-grant Β§8.1: the ID-JAG flow is - # for confidential clients provisioned out of band. Refuse to grant it through DCR so a - # self-registered client cannot reach the identity-assertion provider hook. + # SEP-990 Β§5.1 / draft-ietf-oauth-identity-assertion-authz-grant Β§8.1: the ID-JAG flow is for + # out-of-band-provisioned confidential clients, so refuse to grant it to self-registered ones. if JWT_BEARER_GRANT_TYPE in client_metadata.grant_types: return PydanticJSONResponse( content=RegistrationErrorResponse( @@ -94,8 +88,7 @@ async def handle(self, request: Request) -> Response: status_code=400, ) - # The MCP spec requires servers to use the authorization `code` flow - # with PKCE + # The MCP spec requires the authorization code flow with PKCE if "code" not in client_metadata.response_types: return PydanticJSONResponse( content=RegistrationErrorResponse( @@ -117,7 +110,6 @@ async def handle(self, request: Request) -> Response: client_id_issued_at=client_id_issued_at, client_secret=client_secret, client_secret_expires_at=client_secret_expires_at, - # passthrough information from the client request redirect_uris=client_metadata.redirect_uris, token_endpoint_auth_method=client_metadata.token_endpoint_auth_method, grant_types=client_metadata.grant_types, @@ -135,13 +127,10 @@ async def handle(self, request: Request) -> Response: software_version=client_metadata.software_version, ) try: - # Register client await self.provider.register_client(client_info) - - # Return client information return PydanticJSONResponse(content=client_info, status_code=201) except RegistrationError as e: - # Handle registration errors as defined in RFC 7591 Section 3.2.2 + # Error response per RFC 7591 Β§3.2.2 return PydanticJSONResponse( content=RegistrationErrorResponse(error=e.error, error_description=e.error_description), status_code=400, diff --git a/src/mcp/server/auth/handlers/revoke.py b/src/mcp/server/auth/handlers/revoke.py index 4efd154001..ddc21b8350 100644 --- a/src/mcp/server/auth/handlers/revoke.py +++ b/src/mcp/server/auth/handlers/revoke.py @@ -71,13 +71,11 @@ async def handle(self, request: Request) -> Response: if token is not None: break - # if token is not found, just return HTTP 200 per the RFC + # Unknown or mismatched tokens still get HTTP 200 per RFC 7009 section 2.2 if token and token.client_id == client.client_id: - # Revoke token; provider is not meant to be able to do validation - # at this point that would result in an error + # The provider is not expected to raise here; validation happened when loading the token await self.provider.revoke_token(token) - # Return successful empty response return Response( status_code=200, headers={ diff --git a/src/mcp/server/auth/handlers/token.py b/src/mcp/server/auth/handlers/token.py index 0e644c378a..6208e702e0 100644 --- a/src/mcp/server/auth/handlers/token.py +++ b/src/mcp/server/auth/handlers/token.py @@ -39,23 +39,17 @@ class RefreshTokenRequest(BaseModel): refresh_token: str = Field(..., description="The refresh token") scope: str | None = Field(None, description="Optional scope parameter") client_id: str - # we use the client_secret param, per https://datatracker.ietf.org/doc/html/rfc6749#section-2.3.1 client_secret: str | None = None - # RFC 8707 resource indicator resource: str | None = Field(None, description="Resource indicator for the token") class JwtBearerRequest(BaseModel): - # RFC 7523 Β§2.1 JWT bearer authorization grant. SEP-990 leg 2: the client presents the - # enterprise IdP-issued ID-JAG to the MCP authorization server as the `assertion`. + # RFC 7523 Β§2.1 JWT bearer grant. SEP-990 leg 2: client presents the enterprise IdP-issued ID-JAG as `assertion`. grant_type: Literal["urn:ietf:params:oauth:grant-type:jwt-bearer"] - # See https://datatracker.ietf.org/doc/html/rfc7523#section-2.1 assertion: str = Field(..., description="The ID-JAG (a signed JWT) being presented as the grant") scope: str | None = Field(None, description="Optional scope parameter") client_id: str - # we use the client_secret param, per https://datatracker.ietf.org/doc/html/rfc6749#section-2.3.1 client_secret: str | None = None - # RFC 8707 resource indicator resource: str | None = Field(None, description="Resource indicator for the token") @@ -74,9 +68,7 @@ class TokenErrorResponse(BaseModel): error_uri: AnyHttpUrl | None = None -# this is just an alias over OAuthToken; the only reason we do this -# is to have some separation between the HTTP response type, and the -# type returned by the provider +# alias to separate the HTTP response type from the provider's return type TokenSuccessResponse = OAuthToken @@ -104,7 +96,6 @@ async def handle(self, request: Request): try: client_info = await self.client_authenticator.authenticate_request(request) except AuthenticationError as e: - # Authentication failures should return 401 return PydanticJSONResponse( content=TokenErrorResponse( error="invalid_client", @@ -151,8 +142,7 @@ async def handle(self, request: Request): ) ) - # make auth codes expire after a deadline - # see https://datatracker.ietf.org/doc/html/rfc6749#section-10.5 + # enforce expiry per https://datatracker.ietf.org/doc/html/rfc6749#section-10.5 if auth_code.expires_at < time.time(): return self.response( TokenErrorResponse( @@ -161,14 +151,13 @@ async def handle(self, request: Request): ) ) - # verify redirect_uri doesn't change between /authorize and /tokens - # see https://datatracker.ietf.org/doc/html/rfc6749#section-10.6 + # redirect_uri must match /authorize, see https://datatracker.ietf.org/doc/html/rfc6749#section-10.6 if auth_code.redirect_uri_provided_explicitly: authorize_request_redirect_uri = auth_code.redirect_uri else: # pragma: no cover authorize_request_redirect_uri = None - # Convert both sides to strings for comparison to handle AnyUrl vs string issues + # compare as strings to handle AnyUrl vs str mismatches token_redirect_str = str(token_request.redirect_uri) if token_request.redirect_uri is not None else None auth_redirect_str = ( str(authorize_request_redirect_uri) if authorize_request_redirect_uri is not None else None @@ -182,7 +171,6 @@ async def handle(self, request: Request): ) ) - # Verify PKCE code verifier sha256 = hashlib.sha256(token_request.code_verifier.encode()).digest() hashed_code_verifier = base64.urlsafe_b64encode(sha256).decode().rstrip("=") @@ -196,7 +184,6 @@ async def handle(self, request: Request): ) try: - # Exchange authorization code for tokens tokens = await self.provider.exchange_authorization_code(client_info, auth_code) except TokenError as e: return self.response(TokenErrorResponse(error=e.error, error_description=e.error_description)) @@ -213,7 +200,6 @@ async def handle(self, request: Request): ) if refresh_token.expires_at and refresh_token.expires_at < time.time(): - # if the refresh token has expired, pretend it doesn't exist return self.response( TokenErrorResponse( error="invalid_grant", @@ -221,7 +207,6 @@ async def handle(self, request: Request): ) ) - # Parse scopes if provided scopes = token_request.scope.split(" ") if token_request.scope else refresh_token.scopes for scope in scopes: @@ -234,7 +219,6 @@ async def handle(self, request: Request): ) try: - # Exchange refresh token for new tokens tokens = await self.provider.exchange_refresh_token(client_info, refresh_token, scopes) except TokenError as e: return self.response(TokenErrorResponse(error=e.error, error_description=e.error_description)) @@ -248,13 +232,11 @@ async def handle(self, request: Request): ) ) - # SEP-990 Β§5.1: only confidential clients may present an ID-JAG. ClientAuthenticator - # already rejects a secret-based method with no stored secret; this additionally - # rejects the public `none` method so an unauthenticated client never reaches the - # provider hook. + # SEP-990 Β§5.1: only confidential clients may present an ID-JAG. ClientAuthenticator already + # rejects secret-based methods lacking a stored secret; this blocks `none` before the provider hook. if not client_info.client_secret: - # RFC 6749 Β§5.2: the client authenticated but is not permitted this grant, so - # unauthorized_client (not invalid_client, which is for failed authentication). + # RFC 6749 Β§5.2: authenticated but not permitted this grant, so unauthorized_client + # rather than invalid_client (which signals failed authentication). return self.response( TokenErrorResponse( error="unauthorized_client", diff --git a/src/mcp/server/auth/json_response.py b/src/mcp/server/auth/json_response.py index bd95bd693b..b18cc37639 100644 --- a/src/mcp/server/auth/json_response.py +++ b/src/mcp/server/auth/json_response.py @@ -4,7 +4,6 @@ class PydanticJSONResponse(JSONResponse): - # use pydantic json serialization instead of the stock `json.dumps`, - # so that we can handle serializing pydantic models like AnyHttpUrl + # Pydantic serialization instead of stock json.dumps, so models with fields like AnyHttpUrl serialize. def render(self, content: Any) -> bytes: return content.model_dump_json(exclude_none=True).encode("utf-8") diff --git a/src/mcp/server/auth/middleware/auth_context.py b/src/mcp/server/auth/middleware/auth_context.py index 1d34a5546b..4817fbb028 100644 --- a/src/mcp/server/auth/middleware/auth_context.py +++ b/src/mcp/server/auth/middleware/auth_context.py @@ -5,28 +5,19 @@ from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser from mcp.server.auth.provider import AccessToken -# Create a contextvar to store the authenticated user -# The default is None, indicating no authenticated user is present auth_context_var = contextvars.ContextVar[AuthenticatedUser | None]("auth_context", default=None) def get_access_token() -> AccessToken | None: - """Get the access token from the current context. - - Returns: - The access token if an authenticated user is available, None otherwise. - """ + """Get the authenticated user's access token from the current context, or None.""" auth_user = auth_context_var.get() return auth_user.access_token if auth_user else None class AuthContextMiddleware: - """Middleware that extracts the authenticated user from the request - and sets it in a contextvar for easy access throughout the request lifecycle. + """Stores the authenticated user in a contextvar for the duration of the request. - This middleware should be added after the AuthenticationMiddleware in the - middleware stack to ensure that the user is properly authenticated before - being stored in the context. + Must be added after AuthenticationMiddleware so `scope["user"]` is populated. """ def __init__(self, app: ASGIApp): @@ -35,12 +26,10 @@ def __init__(self, app: ASGIApp): async def __call__(self, scope: Scope, receive: Receive, send: Send): user = scope.get("user") if isinstance(user, AuthenticatedUser): - # Set the authenticated user in the contextvar token = auth_context_var.set(user) try: await self.app(scope, receive, send) finally: auth_context_var.reset(token) else: - # No authenticated user, just process the request await self.app(scope, receive, send) diff --git a/src/mcp/server/auth/middleware/bearer_auth.py b/src/mcp/server/auth/middleware/bearer_auth.py index ba66e94226..4f2a5b3dae 100644 --- a/src/mcp/server/auth/middleware/bearer_auth.py +++ b/src/mcp/server/auth/middleware/bearer_auth.py @@ -26,14 +26,11 @@ class AuthorizationContext(TypedDict): def authorization_context(user: AuthenticatedUser) -> AuthorizationContext: - """Identify the principal `user` represents, for transports to compare - against the principal that created a session. Components the token - verifier does not supply are `None`, so the comparison degrades to the - remaining components. - - See `examples/servers/simple-auth/mcp_simple_auth/token_verifier.py` for - a verifier that populates `subject` and `claims` from an introspection - response.""" + """Identify the principal `user` represents, for transports to compare against a session's creator. + + Components the token verifier does not supply are `None`, so the comparison degrades to the rest. + See `examples/servers/simple-auth/mcp_simple_auth/token_verifier.py` for a populating verifier. + """ token = user.access_token issuer = (token.claims or {}).get("iss") return AuthorizationContext( @@ -59,7 +56,6 @@ async def authenticate(self, conn: HTTPConnection): token = auth_header[7:] # Remove "Bearer " prefix - # Validate the token with the verifier auth_info = await self.token_verifier.verify_token(token) if not auth_info: @@ -72,10 +68,9 @@ async def authenticate(self, conn: HTTPConnection): class RequireAuthMiddleware: - """Middleware that requires a valid Bearer token in the Authorization header. + """Middleware that rejects requests lacking a valid Bearer token with the required scopes. - This will validate the token with the auth provider and store the resulting - auth info in the request state. + Error responses carry a WWW-Authenticate header, advertising `resource_metadata_url` when set. """ def __init__( @@ -84,13 +79,6 @@ def __init__( required_scopes: list[str], resource_metadata_url: AnyHttpUrl | None = None, ): - """Initialize the middleware. - - Args: - app: ASGI application - required_scopes: List of scopes that the token must have - resource_metadata_url: Optional protected resource metadata URL for WWW-Authenticate header - """ self.app = app self.required_scopes = required_scopes self.resource_metadata_url = resource_metadata_url @@ -116,15 +104,12 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: await self.app(scope, receive, send) async def _send_auth_error(self, send: Send, status_code: int, error: str, description: str) -> None: - """Send an authentication error response with WWW-Authenticate header.""" - # Build WWW-Authenticate header value www_auth_parts = [f'error="{error}"', f'error_description="{description}"'] if self.resource_metadata_url: www_auth_parts.append(f'resource_metadata="{self.resource_metadata_url}"') www_authenticate = f"Bearer {', '.join(www_auth_parts)}" - # Send response body = {"error": error, "error_description": description} body_bytes = json.dumps(body).encode() diff --git a/src/mcp/server/auth/middleware/client_auth.py b/src/mcp/server/auth/middleware/client_auth.py index 3d5067d611..9a82af8540 100644 --- a/src/mcp/server/auth/middleware/client_auth.py +++ b/src/mcp/server/auth/middleware/client_auth.py @@ -17,39 +17,20 @@ def __init__(self, message: str): class ClientAuthenticator: - """ClientAuthenticator is a callable which validates requests from a client - application, used to verify /token calls. + """Validates client credentials on /token calls. - If, during registration, the client requested to be issued a secret, the - authenticator asserts that /token calls must be authenticated with - that same secret. - - NOTE: clients can opt for no authentication during registration, in which case this - logic is skipped. + A client that was issued a secret at registration must present that secret; + clients that registered with no authentication skip this check. """ def __init__(self, provider: OAuthAuthorizationServerProvider[Any, Any, Any]): - """Initialize the authenticator. - - Args: - provider: Provider to look up client information - """ self.provider = provider async def authenticate_request(self, request: Request) -> OAuthClientInformationFull: - """Authenticate a client from an HTTP request. - - Extracts client credentials from the appropriate location based on the - client's registered authentication method and validates them. - - Args: - request: The HTTP request containing client credentials - - Returns: - The authenticated client information + """Validate the request's client credentials per the client's registered auth method. Raises: - AuthenticationError: If authentication fails + AuthenticationError: If authentication fails. """ form_data = await request.form() client_id = form_data.get("client_id") @@ -85,7 +66,7 @@ async def authenticate_request(self, request: Request) -> OAuthClientInformation elif client.token_endpoint_auth_method == "client_secret_post": raw_form_data = form_data.get("client_secret") - # form_data.get() can return an UploadFile or None, so we need to check if it's a string + # form_data.get() can return an UploadFile, not just str/None if isinstance(raw_form_data, str): request_client_secret = str(raw_form_data) @@ -96,20 +77,15 @@ async def authenticate_request(self, request: Request) -> OAuthClientInformation f"Unsupported auth method: {client.token_endpoint_auth_method}" ) - # A client registered for a secret-based auth method but with no stored secret is - # misconfigured: nothing was actually verified above, so it must not pass authentication. + # Secret-based auth method with no stored secret: nothing was verified above, so reject. if client.token_endpoint_auth_method != "none" and not client.client_secret: raise AuthenticationError("Client is registered for secret-based authentication but has no stored secret") - # If client from the store expects a secret, validate that the request provides - # that secret if client.client_secret: if not request_client_secret: raise AuthenticationError("Client secret is required") - # hmac.compare_digest requires that both arguments are either bytes or a `str` containing - # only ASCII characters. Since we do not control `request_client_secret`, we encode both - # arguments to bytes. + # Encode to bytes: compare_digest requires bytes or ASCII-only str, and we don't control the input if not hmac.compare_digest(client.client_secret.encode(), request_client_secret.encode()): raise AuthenticationError("Invalid client_secret") diff --git a/src/mcp/server/auth/provider.py b/src/mcp/server/auth/provider.py index eeb371f1c2..40f7881f8e 100644 --- a/src/mcp/server/auth/provider.py +++ b/src/mcp/server/auth/provider.py @@ -19,9 +19,8 @@ class AuthorizationParams(BaseModel): class IdentityAssertionParams(BaseModel): """Validated parameters of a SEP-990 identity-assertion (RFC 7523 jwt-bearer) request. - Passed to ``OAuthAuthorizationServerProvider.exchange_identity_assertion``. ``assertion`` is the - ID-JAG (a signed JWT) the enterprise identity provider issued; the provider validates it per - RFC 7523 Β§3 and the SEP-990 Β§5.1 processing rules before issuing an access token. + Passed to `OAuthAuthorizationServerProvider.exchange_identity_assertion`; `assertion` is the + ID-JAG (a signed JWT) issued by the enterprise identity provider. """ assertion: str # RFC 7523 Β§2.1: the JWT (ID-JAG) presented as the authorization grant @@ -98,7 +97,7 @@ class AuthorizeError(Exception): "unauthorized_client", "unsupported_grant_type", "invalid_scope", - # RFC 8707 Β§2: the requested resource (RFC 8707 indicator) is unknown or unsupported. + # RFC 8707 Β§2: the requested resource indicator is unknown or unsupported. "invalid_target", ] @@ -116,8 +115,7 @@ async def verify_token(self, token: str) -> AccessToken | None: """Verify a bearer token and return access info if valid.""" -# NOTE: MCPServer doesn't render any of these types in the user response, so it's -# OK to add fields to subclasses which should not be exposed externally. +# MCPServer never renders these types in user responses, so subclasses may add fields not meant for external exposure. AuthorizationCodeT = TypeVar("AuthorizationCodeT", bound=AuthorizationCode) RefreshTokenT = TypeVar("RefreshTokenT", bound=RefreshToken) AccessTokenT = TypeVar("AccessTokenT", bound=AccessToken) @@ -125,68 +123,32 @@ async def verify_token(self, token: str) -> AccessToken | None: class OAuthAuthorizationServerProvider(Protocol, Generic[AuthorizationCodeT, RefreshTokenT, AccessTokenT]): async def get_client(self, client_id: str) -> OAuthClientInformationFull | None: - """Retrieves client information by client ID. + """Retrieve client information by client ID, or None if the client does not exist. Implementors MAY raise NotImplementedError if dynamic client registration is disabled in ClientRegistrationOptions. - - Args: - client_id: The ID of the client to retrieve. - - Returns: - The client information, or None if the client does not exist. """ async def register_client(self, client_info: OAuthClientInformationFull) -> None: - """Saves client information as part of registering it. + """Save client information as part of registering it. Implementors MAY raise NotImplementedError if dynamic client registration is disabled in ClientRegistrationOptions. - Args: - client_info: The client metadata to register. - Raises: RegistrationError: If the client metadata is invalid. """ async def authorize(self, client: OAuthClientInformationFull, params: AuthorizationParams) -> str: - """Handle the /authorize endpoint and return a URL that the client - will be redirected to. - - Many MCP implementations will redirect to a third-party provider to perform - a second OAuth exchange with that provider. In this sort of setup, the client - has an OAuth connection with the MCP server, and the MCP server has an OAuth - connection with the 3rd-party provider. At the end of this flow, the client - should be redirected to the redirect_uri from params.redirect_uri. - - +--------+ +------------+ +-------------------+ - | | | | | | - | Client | --> | MCP Server | --> | 3rd Party OAuth | - | | | | | Server | - +--------+ +------------+ +-------------------+ - | ^ | - +------------+ | | | - | | | | Redirect | - |redirect_uri|<-----+ +------------------+ - | | - +------------+ - - Implementations will need to define another handler on the MCP server's return - flow to perform the second redirect, and generate and store an authorization - code as part of completing the OAuth authorization step. - - Implementations SHOULD generate an authorization code with at least 160 bits of - entropy, - and MUST generate an authorization code with at least 128 bits of entropy. - See https://datatracker.ietf.org/doc/html/rfc6749#section-10.10. - - Args: - client: The client requesting authorization. - params: The parameters of the authorization request. + """Handle the /authorize endpoint and return the URL to redirect the client to. - Returns: - A URL to redirect the client to for authorization. + Many MCP implementations redirect here to a third-party OAuth provider for a second + exchange (client <-> MCP server <-> third-party server). Such setups need another + handler on the return flow that generates and stores an authorization code and + finally redirects the client to `params.redirect_uri`. + + Authorization codes MUST have at least 128 bits of entropy and SHOULD have at least + 160 (https://datatracker.ietf.org/doc/html/rfc6749#section-10.10). Raises: AuthorizeError: If the authorization request is invalid. @@ -196,28 +158,13 @@ async def authorize(self, client: OAuthClientInformationFull, params: Authorizat async def load_authorization_code( self, client: OAuthClientInformationFull, authorization_code: str ) -> AuthorizationCodeT | None: - """Loads an AuthorizationCode by its code. - - Args: - client: The client that requested the authorization code. - authorization_code: The authorization code to get the challenge for. - - Returns: - The AuthorizationCode, or None if not found. - """ + """Load an AuthorizationCode by its code string, or None if not found.""" ... async def exchange_authorization_code( self, client: OAuthClientInformationFull, authorization_code: AuthorizationCodeT ) -> OAuthToken: - """Exchanges an authorization code for an access token and refresh token. - - Args: - client: The client exchanging the authorization code. - authorization_code: The authorization code to exchange. - - Returns: - The OAuth token, containing access and refresh tokens. + """Exchange an authorization code for an access token and refresh token. Raises: TokenError: If the request is invalid. @@ -225,15 +172,7 @@ async def exchange_authorization_code( ... async def load_refresh_token(self, client: OAuthClientInformationFull, refresh_token: str) -> RefreshTokenT | None: - """Loads a RefreshToken by its token string. - - Args: - client: The client that is requesting to load the refresh token. - refresh_token: The refresh token string to load. - - Returns: - The RefreshToken object if found, or None if not found. - """ + """Load a RefreshToken by its token string, or None if not found.""" ... async def exchange_refresh_token( @@ -242,47 +181,26 @@ async def exchange_refresh_token( refresh_token: RefreshTokenT, scopes: list[str], ) -> OAuthToken: - """Exchanges a refresh token for an access token and refresh token. + """Exchange a refresh token for an access token and refresh token. Implementations SHOULD rotate both the access token and refresh token. - Args: - client: The client exchanging the refresh token. - refresh_token: The refresh token to exchange. - scopes: Optional scopes to request with the new access token. - - Returns: - The OAuth token, containing access and refresh tokens. - Raises: TokenError: If the request is invalid. """ ... async def load_access_token(self, token: str) -> AccessTokenT | None: - """Loads an access token by its token string. - - Args: - token: The access token to verify. - - Returns: - The access token, or None if the token is invalid. - """ + """Load an access token by its token string, or None if the token is invalid.""" async def revoke_token( self, token: AccessTokenT | RefreshTokenT, ) -> None: - """Revokes an access or refresh token. + """Revoke an access or refresh token; do nothing if it is invalid or already revoked. - If the given token is invalid or already revoked, this method should do nothing. - - Implementations SHOULD revoke both the access token and its corresponding - refresh token, regardless of which of the access token or refresh token is - provided. - - Args: - token: The token to revoke. + Implementations SHOULD revoke both the access token and its corresponding refresh + token, regardless of which one is provided. """ async def exchange_identity_assertion( @@ -290,42 +208,38 @@ async def exchange_identity_assertion( client: OAuthClientInformationFull, params: IdentityAssertionParams, ) -> OAuthToken: - """Exchanges an Identity Assertion Authorization Grant (ID-JAG) for an access token. - - This is leg 2 of SEP-990: the client presents an ID-JAG - issued by the enterprise - identity provider - using the RFC 7523 ``urn:ietf:params:oauth:grant-type:jwt-bearer`` - grant, and receives an access token for this MCP server. The default implementation - rejects every request as an unsupported grant type; override it to enable the grant. - - The implementation is responsible for validating ``params.assertion`` per RFC 7523 Β§3 - and the SEP-990 Β§5.1 processing rules, in particular: - - - verify the JWT signature, ``iss``, and ``exp``, and that ``typ`` is ``oauth-id-jag+jwt``; - - require ``aud`` to identify this authorization server (its own issuer); - - require a ``sub`` (RFC 7523 Β§3 makes it mandatory) identifying the end user; - - reject replays - enforce ``exp``, and track ``jti`` for the assertion's lifetime; - - require the ID-JAG's ``client_id`` claim to match the authenticated ``client`` - do - NOT derive authorization from ``client.client_id`` alone, which for a confidential - client is authenticated but for any client is ultimately self-asserted in the request; - - audience-restrict the issued access token to the resource named in the ID-JAG's - ``resource`` claim, not merely ``params.resource`` (which the client controls); + """Exchange an Identity Assertion Authorization Grant (ID-JAG) for an access token. + + Leg 2 of SEP-990: the client presents an IdP-issued ID-JAG via the RFC 7523 + `urn:ietf:params:oauth:grant-type:jwt-bearer` grant and receives an access token for + this MCP server. The default implementation rejects every request as an unsupported + grant type; override it to enable the grant. + + The implementation must validate `params.assertion` per RFC 7523 Β§3 and the + SEP-990 Β§5.1 processing rules, in particular: + + - verify the JWT signature, `iss`, and `exp`, and that `typ` is `oauth-id-jag+jwt`; + - require `aud` to identify this authorization server (its own issuer); + - require a `sub` (mandatory per RFC 7523 Β§3) identifying the end user; + - reject replays: enforce `exp` and track `jti` for the assertion's lifetime; + - require the ID-JAG's `client_id` claim to match the authenticated `client`; do NOT + derive authorization from `client.client_id` alone, which is ultimately + self-asserted in the request even for an authenticated confidential client; + - audience-restrict the issued access token to the ID-JAG's `resource` claim, not + merely `params.resource` (which the client controls); - derive the granted scopes from the ID-JAG and policy rather than granting - ``params.scopes`` verbatim. + `params.scopes` verbatim. - The handler guarantees ``client`` is confidential (it rejects clients without a stored + The handler guarantees `client` is confidential (it rejects clients without a stored secret before calling this hook), but the ID-JAG remains the authoritative grant. - Args: - client: The authenticated client presenting the assertion. - params: The validated jwt-bearer request parameters (the ID-JAG and indicators). - Returns: - The OAuth token, containing the issued access token. A refresh token SHOULD NOT be - issued: SEP-990 relies on the IdP to control session lifetime via re-issued ID-JAGs. + The OAuth token. A refresh token SHOULD NOT be issued: SEP-990 relies on the IdP + to control session lifetime via re-issued ID-JAGs. Raises: - TokenError: If the assertion or request is invalid. Use ``invalid_grant`` for a - rejected assertion and ``invalid_target`` for an unknown ``resource``. + TokenError: If the assertion or request is invalid. Use `invalid_grant` for a + rejected assertion and `invalid_target` for an unknown `resource`. """ raise TokenError( error="unsupported_grant_type", @@ -345,11 +259,11 @@ def construct_redirect_uri(redirect_uri_base: str, **params: str | None) -> str: class ProviderTokenVerifier(TokenVerifier): - """Token verifier that uses an OAuthAuthorizationServerProvider. + """Token verifier backed by an OAuthAuthorizationServerProvider. - This is provided for backwards compatibility with existing auth_server_provider - configurations. For new implementations using AS/RS separation, consider using - the TokenVerifier protocol with a dedicated implementation like IntrospectionTokenVerifier. + Provided for backwards compatibility with existing auth_server_provider configurations; + new AS/RS-separated implementations should prefer a dedicated TokenVerifier such as + IntrospectionTokenVerifier. """ def __init__(self, provider: "OAuthAuthorizationServerProvider[AuthorizationCode, RefreshToken, AccessToken]"): diff --git a/src/mcp/server/auth/routes.py b/src/mcp/server/auth/routes.py index fa88dddcf4..039ea4e840 100644 --- a/src/mcp/server/auth/routes.py +++ b/src/mcp/server/auth/routes.py @@ -24,9 +24,6 @@ def validate_issuer_url(url: AnyHttpUrl): """Validate that the issuer URL meets OAuth 2.0 requirements. - Args: - url: The issuer URL to validate. - Raises: ValueError: If the issuer URL is invalid. """ @@ -35,7 +32,6 @@ def validate_issuer_url(url: AnyHttpUrl): if url.scheme != "https" and url.host not in ("localhost", "127.0.0.1", "[::1]"): raise ValueError("Issuer URL must be HTTPS") - # No fragments or query parameters allowed if url.fragment: raise ValueError("Issuer URL must not have a fragment") if url.query: @@ -85,10 +81,7 @@ def create_auth_routes( ) client_authenticator = ClientAuthenticator(provider) - # Create routes - # Allow CORS requests for endpoints meant to be hit by the OAuth client - # (with the client secret). This is intended to support things like MCP Inspector, - # where the client runs in a web browser. + # Allow CORS on endpoints the OAuth client calls directly, so browser-based clients (e.g. MCP Inspector) work. routes = [ Route( "/.well-known/oauth-authorization-server", @@ -100,8 +93,7 @@ def create_auth_routes( ), Route( AUTHORIZATION_PATH, - # do not allow CORS for authorization endpoint; - # clients should just redirect to this + # No CORS: clients redirect users to the authorization endpoint rather than calling it endpoint=AuthorizationHandler(provider).handle, methods=["GET", "POST"], ), @@ -167,7 +159,6 @@ def build_metadata( grant_types_supported.append(JWT_BEARER_GRANT_TYPE) authorization_grant_profiles_supported = [ID_JAG_GRANT_PROFILE] - # Create metadata metadata = OAuthMetadata( issuer=issuer_url, authorization_endpoint=authorization_url, @@ -187,11 +178,9 @@ def build_metadata( authorization_grant_profiles_supported=authorization_grant_profiles_supported, ) - # Add registration endpoint if supported if client_registration_options.enabled: # pragma: no branch metadata.registration_endpoint = AnyHttpUrl(str(issuer_url).rstrip("/") + REGISTRATION_PATH) - # Add revocation endpoint if supported if revocation_options.enabled: # pragma: no branch metadata.revocation_endpoint = AnyHttpUrl(str(issuer_url).rstrip("/") + REVOCATION_PATH) metadata.revocation_endpoint_auth_methods_supported = ["client_secret_post", "client_secret_basic"] @@ -200,19 +189,13 @@ def build_metadata( def build_resource_metadata_url(resource_server_url: AnyHttpUrl) -> AnyHttpUrl: - """Build RFC 9728 compliant protected resource metadata URL. - - Inserts /.well-known/oauth-protected-resource between host and resource path - as specified in RFC 9728 Β§3.1. - - Args: - resource_server_url: The resource server URL (e.g., https://example.com/mcp) + """Build the RFC 9728 Β§3.1 protected resource metadata URL. - Returns: - The metadata URL (e.g., https://example.com/.well-known/oauth-protected-resource/mcp) + Inserts /.well-known/oauth-protected-resource between host and resource path: + https://example.com/mcp -> https://example.com/.well-known/oauth-protected-resource/mcp """ parsed = urlparse(str(resource_server_url)) - # Handle trailing slash: if path is just "/", treat as empty + # A bare "/" path would otherwise leave a trailing slash on the metadata URL resource_path = parsed.path if parsed.path != "/" else "" return AnyHttpUrl(f"{parsed.scheme}://{parsed.netloc}/.well-known/oauth-protected-resource{resource_path}") @@ -224,18 +207,7 @@ def create_protected_resource_routes( resource_name: str | None = None, resource_documentation: AnyHttpUrl | None = None, ) -> list[Route]: - """Create routes for OAuth 2.0 Protected Resource Metadata (RFC 9728). - - Args: - resource_url: The URL of this resource server - authorization_servers: List of authorization servers that can issue tokens - scopes_supported: Optional list of scopes supported by this resource - resource_name: Optional human-readable name for this resource - resource_documentation: Optional URL to documentation for this resource - - Returns: - List of Starlette routes for protected resource metadata - """ + """Create routes serving OAuth 2.0 Protected Resource Metadata (RFC 9728).""" metadata = ProtectedResourceMetadata( resource=resource_url, authorization_servers=authorization_servers, @@ -247,9 +219,7 @@ def create_protected_resource_routes( handler = ProtectedResourceMetadataHandler(metadata) - # RFC 9728 Β§3.1: Register route at /.well-known/oauth-protected-resource + resource path metadata_url = build_resource_metadata_url(resource_url) - # Extract just the path part for route registration parsed = urlparse(str(metadata_url)) well_known_path = parsed.path diff --git a/src/mcp/server/auth/settings.py b/src/mcp/server/auth/settings.py index ae2083a38b..704d505e90 100644 --- a/src/mcp/server/auth/settings.py +++ b/src/mcp/server/auth/settings.py @@ -13,10 +13,9 @@ class RevocationOptions(BaseModel): class AuthSettings(BaseModel): - # Preserve empty URL paths so a path-less issuer/resource passed as a string keeps its - # canonical form (no trailing slash). RFC 8414/9207 issuer comparison is exact string - # comparison, so a spurious trailing slash would break it. See PR #2925 for the metadata - # models; this applies the same to the server's own configured URLs. + # Preserve empty URL paths: RFC 8414/9207 issuer comparison is exact-string, so a spurious trailing + # slash on a path-less issuer/resource passed as a string would break it. Same as the metadata + # models (PR #2925). model_config = ConfigDict(url_preserve_empty_path=True) issuer_url: AnyHttpUrl = Field( @@ -34,7 +33,6 @@ class AuthSettings(BaseModel): "IdP flows. The provider must implement `exchange_identity_assertion`.", ) - # Resource Server settings (when operating as RS only) resource_server_url: AnyHttpUrl | None = Field( ..., description="The URL of the MCP server to be used as the resource identifier " diff --git a/src/mcp/server/caching.py b/src/mcp/server/caching.py index a8a2a470c6..72912050ac 100644 --- a/src/mcp/server/caching.py +++ b/src/mcp/server/caching.py @@ -1,10 +1,7 @@ """Server-side caching hints (SEP-2549, protocol revision 2026-07-28). -Results for the cacheable methods carry `ttlMs`/`cacheScope` freshness hints. -A handler sets them by returning a result with explicit `ttl_ms`/`cache_scope` -values; `Server(cache_hints={method: CacheHint(...)})` fills them for handlers -that don't. Fields the handler set win, per field, so a server-wide hint never -overrides a handler's explicit choice. +`Server(cache_hints={method: CacheHint(...)})` fills `ttlMs`/`cacheScope` on +cacheable results per field, never overriding a value the handler set explicitly. """ from __future__ import annotations @@ -25,9 +22,7 @@ "server/discover", "tools/list", ] -"""The methods whose results carry `ttlMs`/`cacheScope`. Closed set: the spec -defines caching hints on exactly these six (tests pin it to which result models -mix in `CacheableResult`).""" +"""Methods whose results carry `ttlMs`/`cacheScope`; a closed set, fixed by the spec.""" CACHEABLE_METHODS: Final[frozenset[str]] = frozenset(get_args(CacheableMethod)) """Runtime mirror of `CacheableMethod`, for callers the type checker can't see.""" @@ -37,10 +32,9 @@ class CacheHint: """Freshness hint for one cacheable method's results. - `ttl_ms` is how long, in milliseconds, a client may consider the result - fresh (`0` means immediately stale). `scope` is whether a cached result may - be shared across authorization contexts (`"public"`) or only reused within - the one that produced it (`"private"`). + `ttl_ms` is how long (in ms) a client may treat the result as fresh, `0` meaning + immediately stale; `scope` is whether a cached result may be shared across + authorization contexts (`"public"`) or only the one that produced it (`"private"`). """ ttl_ms: int = 0 @@ -57,12 +51,10 @@ def __post_init__(self) -> None: def apply_cache_hint(result: CacheableResultT, hint: CacheHint) -> CacheableResultT: - """Fill `ttl_ms`/`cache_scope` on `result` from `hint`. + """Fill unset `ttl_ms`/`cache_scope` fields on `result` from `hint`. - Per-field: a field the handler set explicitly - even to its default value, - tracked via `model_fields_set` - is left alone; only unset fields take the - hint. A handler constructing results with `model_construct` bypasses that - tracking and is treated as having set nothing. + Explicitly set fields win even at their defaults (per `model_fields_set`); + `model_construct` bypasses that tracking and counts as having set nothing. """ update: dict[str, int | str] = {} if "ttl_ms" not in result.model_fields_set: @@ -75,11 +67,9 @@ def apply_cache_hint(result: CacheableResultT, hint: CacheHint) -> CacheableResu def validate_cache_hints(cache_hints: Mapping[Any, Any] | None) -> dict[str, CacheHint]: """Validate a `cache_hints` constructor argument into a plain dict. - The `Server`/`MCPServer` signatures already close the key set and value - type for type-checked callers; this runtime gate is deliberately loose in - its parameter so it covers everyone else (e.g. a map deserialized from - config) - a bad entry fails at construction, not on the first request to - that method. + Deliberately loose parameter type: covers callers the `Server`/`MCPServer` + signatures can't (e.g. maps from config), failing at construction rather + than on the first request. Raises: ValueError: If a key is not a cacheable method. diff --git a/src/mcp/server/connection.py b/src/mcp/server/connection.py index 4d9496fef1..250e508183 100644 --- a/src/mcp/server/connection.py +++ b/src/mcp/server/connection.py @@ -1,20 +1,7 @@ """`Connection` - per-client connection state and the standalone outbound channel. -Always present on `Context` (never `None`), even in stateless deployments. -Holds peer info, per-connection scratch `state` and an `exit_stack` for -teardown, and an `Outbound` for the standalone stream (the SSE GET stream in -streamable HTTP, or the single duplex stream in stdio). - -Construct via the factories: `Connection.from_envelope` for the 2026-era -single-exchange path (born ready, no back-channel) and `Connection.for_loop` -for the handshake-driven loop path. Both populate `protocol_version` so the -kernel reads it as a fact. - -`notify` is best-effort: it never raises. If there's no standalone channel -or the stream has been dropped, the notification is debug-logged and silently -discarded - server-initiated notifications are inherently advisory. -`send_raw_request` raises `NoBackChannelError` when there's no channel; `ping` -is the only spec-sanctioned standalone request. +Always present on `Context` (never `None`), even in stateless deployments. The +standalone stream is the SSE GET in streamable HTTP or the duplex stream in stdio. """ from __future__ import annotations @@ -55,11 +42,7 @@ ResultT = TypeVar("ResultT", bound=BaseModel) -# Result types for the spec's server-to-client request set, used by -# `Connection.send_request` to infer the result type. If the spec's request -# set grows substantially, consider declaring the result mapping on the -# request types themselves (a `__mcp_result__` ClassVar read via a structural -# protocol) so this table and the overload ladder don't need maintaining. +# Spec server-to-client requests -> result types; lets `Connection.send_request` infer the result type. _RESULT_FOR: dict[type[Request[Any, Any]], type[BaseModel]] = { CreateMessageRequest: CreateMessageResult, ElicitRequest: ElicitResult, @@ -77,12 +60,11 @@ def _notification_params(payload: dict[str, Any] | None, meta: Meta | None) -> d class _NoChannelOutbound: - """Connection-scoped `Outbound` for the no-back-channel case. + """No-back-channel `Outbound`: requests raise, notifications drop. - The structural answer to "this connection cannot push to its peer": - `send_raw_request` raises `NoBackChannelError`; `notify` drops with a - debug log. `Connection.from_envelope` installs this so the modern - single-exchange path never needs a mode flag - the channel itself says no. + `send_raw_request` raises `NoBackChannelError`; `notify` drops with a debug + log. Installed by `Connection.from_envelope` so the single-exchange path + never needs a mode flag - the channel itself says no. """ async def send_raw_request( @@ -103,10 +85,8 @@ async def notify(self, method: str, params: Mapping[str, Any] | None, opts: Call class Connection: """Per-client connection state and standalone-stream `Outbound`. - Construct via `from_envelope` (modern single-exchange: born ready, no - back-channel) or `for_loop` (handshake-driven: ready once the client's - `notifications/initialized` arrives). Either way `protocol_version` is - populated at construction. + Construct via `from_envelope` (born ready, no back-channel) or `for_loop` + (ready once the client's `notifications/initialized` arrives). """ outbound: Outbound @@ -115,26 +95,21 @@ class Connection: session_id: str | None client_params: InitializeRequestParams | None - """The full `initialize` request params, or the equivalent built from the - 2026-era envelope. `None` when no client info was supplied.""" + """The `initialize` request params (or the 2026-era envelope equivalent); `None` when none supplied.""" protocol_version: str - """The protocol version this connection speaks. Populated at construction - by the factory and overwritten by `_handle_initialize` once the handshake - commits on the loop path.""" + """The protocol version this connection speaks; seeded at construction and + overwritten once the loop-path handshake commits.""" initialized: anyio.Event - """Set when `notifications/initialized` arrives (matches TS `oninitialized`); - the point from which the spec permits server-initiated requests beyond - ping/logging. Pre-set on connections built via `from_envelope`.""" + """Set when `notifications/initialized` arrives (pre-set by `from_envelope`); + from here the spec permits server-initiated requests beyond ping/logging.""" state: dict[str, Any] """Per-connection scratch state; persists across requests on this connection.""" exit_stack: AsyncExitStack - """Per-connection teardown, unwound LIFO (shielded) when the connection - closes. Push cleanup from handlers or middleware; exceptions are logged - and swallowed.""" + """Per-connection teardown, unwound LIFO (shielded) on close; cleanup exceptions are logged and swallowed.""" def __init__( self, @@ -163,11 +138,8 @@ def from_envelope( ) -> Connection: """A born-ready connection populated from a request's `_meta` envelope. - `initialized` is set and the envelope's client info/capabilities (when - both supplied) are recorded as `client_params` so capability checks - work. `outbound` defaults to the no-channel sentinel for the - single-exchange HTTP path; duplex modern transports (e.g. stdio) pass - the dispatcher so server-initiated messages have a back-channel. + `outbound` defaults to the no-channel sentinel for the single-exchange + HTTP path; duplex transports (e.g. stdio) pass the dispatcher. """ client_params = None if client_info is not None and client_capabilities is not None: @@ -190,10 +162,8 @@ def for_loop( ) -> Connection: """A connection for the handshake-driven loop path. - Not born-ready: `initialized` is set later by the kernel when - `notifications/initialized` arrives. `protocol_version` is seeded from - the transport hint (or `LATEST_HANDSHAKE_VERSION`) so it's never `None`; - the handshake overwrites it once negotiated. + Not born-ready: the kernel sets `initialized` when `notifications/initialized` + arrives; the handshake overwrites the seeded `protocol_version` once negotiated. """ return cls( outbound, @@ -203,16 +173,13 @@ def for_loop( @property def has_standalone_channel(self) -> bool: - """Whether this connection has a real back-channel for server-initiated - messages. Derived from `outbound` - the no-channel sentinel is the only - case that doesn't.""" + """Whether this connection has a real back-channel for server-initiated messages.""" return self.outbound is not _NO_CHANNEL @property def initialize_accepted(self) -> bool: """True once the inbound request gate is open: `initialize` recorded the - peer info, or the handshake completed outright (born-ready, or a bare - `notifications/initialized`). Derived, never stored.""" + peer info, or the handshake completed outright.""" return self.client_params is not None or self.initialized.is_set() async def send_raw_request( @@ -223,10 +190,8 @@ async def send_raw_request( ) -> dict[str, Any]: """Send a raw request on the standalone stream. - Low-level `Outbound` channel. Prefer the typed `send_request` or the - convenience methods below; use this directly only for off-spec - messages. `opts` carries per-call `timeout` / `on_progress` / - resumption hints; see `CallOptions`. + Prefer the typed `send_request` or the convenience methods below; use + this directly only for off-spec messages. Raises: MCPError: The peer responded with an error. @@ -278,8 +243,8 @@ async def send_request( async def notify(self, method: str, params: Mapping[str, Any] | None, opts: CallOptions | None = None) -> None: """Send a best-effort notification on the standalone stream. - Never raises. If there's no standalone channel or the stream is broken, - the notification is dropped and debug-logged. + Never raises (server-initiated notifications are advisory): with no + channel or a broken stream the notification is dropped and debug-logged. """ try: await self.outbound.notify(method, params, opts) @@ -287,7 +252,7 @@ async def notify(self, method: str, params: Mapping[str, Any] | None, opts: Call logger.debug("dropped %s: standalone stream closed", method) async def ping(self, *, meta: Meta | None = None, opts: CallOptions | None = None) -> None: - """Send a `ping` request on the standalone stream. + """Send a `ping` request - the only spec-sanctioned standalone request. Raises: MCPError: The peer responded with an error. @@ -316,12 +281,8 @@ async def send_resource_updated(self, uri: str, *, meta: Meta | None = None) -> await self.notify("notifications/resources/updated", _notification_params({"uri": uri}, meta)) def check_capability(self, capability: ClientCapabilities) -> bool: - """Return whether the connected client declared the given capability. - - Returns `False` when no client info has been recorded. - """ - # TODO(L53): redesign - mirrors v1 ServerSession.check_client_capability - # verbatim for parity. + """Return whether the connected client declared the given capability; `False` when no client info recorded.""" + # TODO(L53): redesign - mirrors v1 ServerSession.check_client_capability verbatim for parity. if self.client_params is None: return False have = self.client_params.capabilities @@ -346,10 +307,8 @@ def check_capability(self, capability: ClientCapabilities) -> bool: if k not in have.experimental or have.experimental[k] != v: return False if capability.extensions is not None: - # SEP-2133: an extension is supported when the client declares its - # identifier. Settings are negotiated per-extension (the client may - # advertise more than the server asks for), so presence - not value - # equality - is the meaningful check. + # SEP-2133: support means the client declares the identifier; settings are + # negotiated per-extension, so presence - not value equality - is the check. if have.extensions is None: return False for identifier in capability.extensions: diff --git a/src/mcp/server/context.py b/src/mcp/server/context.py index b5c356075a..135a6483a5 100644 --- a/src/mcp/server/context.py +++ b/src/mcp/server/context.py @@ -24,10 +24,8 @@ class ServerRequestContext(Generic[LifespanContextT, RequestT]): """Per-request context handed to lowlevel request and notification handlers. - Built by `ServerRunner._make_context` for each inbound message. Carries the - connection-scoped `ServerSession` (server-to-client requests and - notifications), per-request metadata, and any per-message data the - transport attached (the HTTP request, SSE stream-close callbacks). + Built by `ServerRunner._make_context`; carries the connection-scoped `ServerSession`, + per-request metadata, and per-message transport data (the HTTP request, SSE stream-close callbacks). """ session: ServerSession @@ -49,11 +47,8 @@ class ServerRequestContext(Generic[LifespanContextT, RequestT]): class Context(BaseContext[TransportContext], Generic[LifespanT_co]): """Server-side per-request context. - Extends `BaseContext` (transport metadata, the raw back-channel, progress - reporting) with `lifespan`, `connection`, and request-scoped `log`. - - Not currently constructed by `ServerRunner`, which hands handlers a - `ServerRequestContext` instead. + Extends `BaseContext` with `lifespan`, `connection`, and request-scoped `log`. Not currently + constructed by `ServerRunner`, which hands handlers a `ServerRequestContext` instead. """ def __init__( @@ -75,33 +70,24 @@ def lifespan(self) -> LifespanT_co: @property def connection(self) -> Connection: - """The per-client `Connection` for this request's connection.""" + """The per-client `Connection` this request belongs to.""" return self._connection @property def session_id(self) -> str | None: - """The transport's session id for this connection, when one exists. - - Convenience for `ctx.connection.session_id`. `None` on stdio and - stateless HTTP. - """ + """Convenience for `ctx.connection.session_id`; `None` on stdio and stateless HTTP.""" return self._connection.session_id @property def headers(self) -> Mapping[str, str] | None: - """Request headers carried by this message, when the transport has them. - - Convenience for `ctx.transport.headers`. `None` on stdio. - """ + """Convenience for `ctx.transport.headers`; `None` on stdio.""" return self.transport.headers @deprecated("The logging capability is deprecated as of 2026-07-28 (SEP-2577).", category=MCPDeprecationWarning) async def log(self, level: LoggingLevel, data: Any, logger: str | None = None, *, meta: Meta | None = None) -> None: - """Send a request-scoped `notifications/message` log entry. + """Send a `notifications/message` log entry on this request's back-channel. - Uses this request's back-channel (so the entry rides the request's SSE - stream in streamable HTTP), not the standalone stream - use - `ctx.connection.log(...)` for that. + Rides the request's SSE stream in streamable HTTP; `ctx.connection.log(...)` uses the standalone stream. """ params: dict[str, Any] = {"level": level, "data": data} if logger is not None: @@ -112,12 +98,10 @@ async def log(self, level: LoggingLevel, data: Any, logger: str | None = None, * HandlerResult = BaseModel | dict[str, Any] | None -"""What a request handler (or middleware) may return. `ServerRunner` serializes -all three to a result dict.""" +"""What a request handler (or middleware) may return; `ServerRunner` serializes all three to a result dict.""" CallNext = Callable[["ServerRequestContext[Any, Any]"], Awaitable[HandlerResult]] -"""Invokes the rest of the chain. Pass the `ctx` through; rewrite `method` or -`params` with `dataclasses.replace(ctx, ...)` to alter what the handler sees.""" +"""Invokes the rest of the chain; rewrite `method`/`params` via `dataclasses.replace(ctx, ...)` first.""" _MwLifespanT = TypeVar("_MwLifespanT") @@ -125,46 +109,33 @@ async def log(self, level: LoggingLevel, data: Any, logger: str | None = None, * class ServerMiddleware(Protocol[_MwLifespanT]): """Context-tier middleware: `(ctx, call_next) -> result`. - Runs at the top of `ServerRunner._on_request` / `_on_notify` after `ctx` - is built but before any validation, lookup, or handshake. Wraps every - inbound request and notification: `initialize`, the pre-init gate, - `METHOD_NOT_FOUND`, params validation, the handler call, and - `notifications/initialized` all run inside `call_next(ctx)`. - `notifications/cancelled` is observed too; the dispatcher applies the - cancellation itself, then forwards the notification. A request-side - failure reaches the middleware as a raised `MCPError` (or - `ValidationError` for malformed params) so observation/logging middleware - can record it. Listed outermost-first on `Server.middleware`. - - The method and the raw inbound params are `ctx.method` and `ctx.params` (no - model validation has happened yet). To rewrite either before the handler - runs, pass an adjusted context: `await call_next(replace(ctx, params=...))`. - `ctx.request_id is None` distinguishes a notification from a request. For - notifications `call_next(ctx)` returns `None` (a dropped or unhandled - notification also returns `None`) and the middleware's own return value is - discarded. + Wraps every inbound request and notification before any validation, lookup, or handshake: + `initialize`, the pre-init gate, `METHOD_NOT_FOUND`, params validation, the handler call, and + `notifications/initialized` all run inside `call_next(ctx)`. `notifications/cancelled` is + observed too; the dispatcher applies the cancellation itself, then forwards it. A request-side + failure reaches the middleware as a raised `MCPError` (or `ValidationError` for malformed + params). Listed outermost-first on `Server.middleware`. + + `ctx.method` and `ctx.params` are the raw inbound values (no model validation yet); to rewrite + either, pass an adjusted context: `await call_next(replace(ctx, params=...))`. + `ctx.request_id is None` distinguishes a notification, for which `call_next(ctx)` returns + `None` and the middleware's own return value is discarded. !!! warning - `initialize` is handled inline - the dispatcher does not read - further inbound messages until the middleware chain returns. Awaiting a - server-to-client request (`ctx.session.send_request`, `send_ping`, ...) - while handling `initialize` therefore deadlocks the connection: the - response can never be dequeued. Send-and-forget notifications are safe. - `initialize` is observed but not rewritable: the post-chain handshake - commit reads the wire params, so to veto the handshake raise *before* + `initialize` is handled inline - the dispatcher reads no further inbound messages until + the chain returns, so awaiting a server-to-client request (`ctx.session.send_request`, + `send_ping`, ...) while handling `initialize` deadlocks the connection. Send-and-forget + notifications are safe. `initialize` is observed but not rewritable: the post-chain + handshake commit reads the wire params, so to veto the handshake raise *before* `call_next()`. - `Server[L].middleware` holds `ServerMiddleware[L]`, so an app-specific - middleware sees `ctx.lifespan_context: L`. While the context is the - mutable `ServerRequestContext` dataclass it is invariant in `L`, so a - reusable middleware should be typed `ServerMiddleware[Any]` to register on - any `Server[L]`. + `Server[L].middleware` holds `ServerMiddleware[L]`; `ServerRequestContext` is invariant in `L`, + so reusable middleware should be typed `ServerMiddleware[Any]` to register on any `Server[L]`. """ - # TODO(maxisbey): once `_make_context` returns the (covariant) `Context[L]` - # again, restore `_MwLifespanT` to `contravariant=True` and retype `ctx` - # below to `Context[_MwLifespanT]` so reusable middleware can be - # `ServerMiddleware[object]` instead of `ServerMiddleware[Any]`. + # TODO(maxisbey): once `_make_context` returns the covariant `Context[L]` again, restore + # `contravariant=True` on `_MwLifespanT` and retype `ctx` below to `Context[_MwLifespanT]` so + # reusable middleware can be `ServerMiddleware[object]` instead of `ServerMiddleware[Any]`. async def __call__( self, diff --git a/src/mcp/server/elicitation.py b/src/mcp/server/elicitation.py index 5a4acdd6c3..1e7c0c39da 100644 --- a/src/mcp/server/elicitation.py +++ b/src/mcp/server/elicitation.py @@ -56,9 +56,8 @@ class AcceptedUrlElicitation(BaseModel): class _ElicitationJsonSchema(GenerateJsonSchema): """JSON-Schema generator that flattens `T | None` to `T` and drops `None` defaults. - The spec's `PrimitiveSchemaDefinition` admits no `anyOf` or null type; an - optional field is expressed by leaving it out of `required`, which pydantic - already does for any field with a default. + The spec's `PrimitiveSchemaDefinition` admits no `anyOf` or null type; optionality is + expressed by omission from `required`, which pydantic already does for defaulted fields. """ def nullable_schema(self, schema: core_schema.NullableSchema) -> JsonSchemaValue: @@ -74,8 +73,7 @@ def default_schema(self, schema: core_schema.WithDefaultSchema) -> JsonSchemaVal def _validate_rendered_properties(json_schema: dict[str, Any]) -> None: """Reject any `properties` entry the spec's `PrimitiveSchemaDefinition` won't accept. - Catches whatever the renderer let through that isn't spec-valid: bare - `list[str]` (no enum), multi-primitive unions, nested models. + Catches non-spec-valid renderings: bare `list[str]` (no enum), multi-primitive unions, nested models. """ for field_name, prop in json_schema.get("properties", {}).items(): try: @@ -91,8 +89,7 @@ def render_elicitation_schema(schema: type[BaseModel]) -> dict[str, Any]: """Render a model as the spec-valid `requested_schema` for an elicitation. Raises: - TypeError: If a field renders as something the spec's - `PrimitiveSchemaDefinition` does not accept. + TypeError: If a field renders as something the spec's `PrimitiveSchemaDefinition` does not accept. """ json_schema = schema.model_json_schema(schema_generator=_ElicitationJsonSchema) _validate_rendered_properties(json_schema) @@ -107,17 +104,11 @@ async def elicit_with_validation( ) -> ElicitationResult[ElicitSchemaModelT]: """Elicit information from the client/user with schema validation (form mode). - This method can be used to interactively ask for additional information from the - client within a tool's execution. The client might display the message to the - user and collect a response according to the provided schema. If the client - is an agent, it might decide how to handle the elicitation -- either by asking - the user or automatically generating a response. - - For sensitive data like credentials or OAuth flows, use elicit_url() instead. + The client may show `message` to the user or, if an agent, generate the response itself. + For sensitive data like credentials or OAuth flows, use `elicit_url` instead. Raises: - ValueError: If the client accepted the elicitation without supplying - content, or with content that does not match the requested schema. + ValueError: If the client accepted with no content, or content not matching the requested schema. """ json_schema = render_elicitation_schema(schema) @@ -151,26 +142,10 @@ async def elicit_url( ) -> UrlElicitationResult: """Elicit information from the user via out-of-band URL navigation (URL mode). - This method directs the user to an external URL where sensitive interactions can - occur without passing data through the MCP client. Use this for: - - Collecting sensitive credentials (API keys, passwords) - - OAuth authorization flows with third-party services - - Payment and subscription flows - - Any interaction where data should not pass through the LLM context - - The response indicates whether the user consented to navigate to the URL. - The actual interaction happens out-of-band. When the elicitation completes, - the server should send an ElicitCompleteNotification to notify the client. - - Args: - session: The server session - message: Human-readable explanation of why the interaction is needed - url: The URL the user should navigate to - elicitation_id: Unique identifier for tracking this elicitation - related_request_id: Optional ID of the request that triggered this elicitation - - Returns: - UrlElicitationResult indicating accept, decline, or cancel + Directs the user to an external URL where sensitive interactions (credentials, OAuth, + payments) happen without passing data through the MCP client or LLM context. The result + only indicates whether the user consented to navigate; when the out-of-band interaction + completes, the server should send an ElicitCompleteNotification. """ result = await session.elicit_url( message=message, @@ -186,5 +161,4 @@ async def elicit_url( elif result.action == "cancel": return CancelledElicitation() else: # pragma: no cover - # This should never happen, but handle it just in case raise ValueError(f"Unexpected elicitation action: {result.action}") diff --git a/src/mcp/server/extension.py b/src/mcp/server/extension.py index e045e6f29d..d1be063823 100644 --- a/src/mcp/server/extension.py +++ b/src/mcp/server/extension.py @@ -1,20 +1,11 @@ """Pluggable extension interface for MCP servers (SEP-2133). -An extension is a self-contained, opt-in bundle of MCP behaviour, identified by -a reverse-DNS string (e.g. `io.modelcontextprotocol/ui`). It is passed to -`MCPServer(extensions=[...])`, and the server applies a *closed* set of -contribution kinds: tools, resources, new request methods, and one `tools/call` -interceptor. The server never hands itself to an extension; the extension -declares what it adds, and the server consumes it. - -The shape follows the HTTPX `Transport`/`Auth` pattern: a narrow base class whose -methods have sensible defaults, so an extension overrides only what it needs. A -purely additive extension (Apps) overrides `tools`/`resources`; an interceptive -one overrides `methods`/`intercept_tool_call`. - -This module lives at the `mcp.server` tier (not `mcp.server.mcpserver`) so the -base class itself never drags in the composition tier that consumes it; -extensions remain importable without constructing an `MCPServer`. +An extension is an opt-in bundle of MCP behaviour, identified by a reverse-DNS +string (e.g. `io.modelcontextprotocol/ui`) and passed to `MCPServer(extensions=[...])`. +The server applies a closed set of contribution kinds β€” tools, resources, new +request methods, one `tools/call` interceptor β€” and never hands itself to the +extension. Lives at the `mcp.server` tier so extensions stay importable without +the composition tier that consumes them. """ from __future__ import annotations @@ -35,20 +26,14 @@ RequestHandler = Callable[[ServerRequestContext[Any, Any], Any], Awaitable[HandlerResult]] -# Extension identifiers follow the `_meta` key grammar with a mandatory prefix -# (SEP-2133 / basic/index.mdx): dot-separated labels, each starting with a -# letter and ending with a letter or digit (hyphens interior), then `/`, then a -# name that starts and ends alphanumeric (`.`/`_`/`-` interior). +# Extension identifiers follow the `_meta` key grammar with a mandatory vendor prefix (SEP-2133 / basic/index.mdx). _LABEL = r"[A-Za-z](?:[A-Za-z0-9-]*[A-Za-z0-9])?" _NAME = r"[A-Za-z0-9](?:[A-Za-z0-9._-]*[A-Za-z0-9])?" _IDENTIFIER_RE = re.compile(rf"{_LABEL}(?:\.{_LABEL})*/{_NAME}") def validate_extension_identifier(identifier: Any, *, owner: str) -> None: - """Raise `TypeError` unless `identifier` is a `vendor-prefix/name` string. - - SEP-2133 requires extension identifiers to carry a reverse-DNS prefix. - """ + """Raise `TypeError` unless `identifier` is a SEP-2133 `vendor-prefix/name` string.""" if not isinstance(identifier, str) or not _IDENTIFIER_RE.fullmatch(identifier): raise TypeError( f"{owner}.identifier must be a `vendor-prefix/name` string " @@ -76,21 +61,14 @@ class ResourceBinding: class MethodBinding: """A new request method an extension serves, e.g. `tasks/get`. - `params_type` validates incoming params before `handler` runs; it should - subclass `RequestParams` so `_meta` parses uniformly. `protocol_versions`, - when set, restricts the method to those wire versions - a request for the - method at any other version is rejected as `METHOD_NOT_FOUND`, mirroring the - spec's `(method, version)` boundary table. `None` (the default) admits the - method at every version. - - Extension methods are additive: `method` must not name a spec-defined - request method (`tools/list`, `completion/complete`, ...) β€” those handlers - belong to the server, and an extension binding one would silently shadow or - be shadowed by it. Both constraints are enforced at construction. To - re-provide a spec method the 2026 revision removed (e.g. `logging/setLevel` - for legacy clients), use the lowlevel `Server.add_request_handler` API - instead β€” the runner's per-version surface gate would never route such a - method to an extension handler anyway. + `params_type` validates params before `handler` runs; subclass `RequestParams` + so `_meta` parses uniformly. `protocol_versions` restricts the method to those + wire versions (others get `METHOD_NOT_FOUND`); `None` admits every version. + + Binding a spec-defined method raises at construction β€” it would shadow or be + shadowed by the server's handler. To re-provide a spec method the 2026 revision + removed (e.g. `logging/setLevel`), use the lowlevel `Server.add_request_handler` + instead β€” the runner's per-version surface gate would never route it here anyway. """ method: str @@ -114,10 +92,8 @@ def __post_init__(self) -> None: class Extension: """Base class for an opt-in MCP extension. Override only the methods you need. - Subclass and set `identifier`, then override the contribution methods that - apply. Every method has a default, so a minimal extension overrides nothing - but `identifier` and one of `tools`/`resources`/`methods`. `identifier` is - enforced at subclass-definition time. + Subclass and set `identifier` (validated at class-definition time), then + override whichever contribution methods apply β€” every method has a default. """ #: Reverse-DNS extension identifier, advertised under `ServerCapabilities.extensions`. @@ -125,19 +101,14 @@ class Extension: def __init_subclass__(cls, **kwargs: Any) -> None: super().__init_subclass__(**kwargs) - # Validate a class-level `identifier` at definition time. A subclass may - # instead assign `identifier` in `__init__` (per-instance ids); that case - # is validated when the extension is applied, since no class attribute - # exists to inspect here. + # A subclass may instead assign `identifier` in `__init__` (per-instance + # ids); that case is validated when the extension is applied. identifier = cls.__dict__.get("identifier") if identifier is not None: validate_extension_identifier(identifier, owner=cls.__name__) def settings(self) -> dict[str, Any]: - """Per-extension settings advertised at `capabilities.extensions[identifier]`. - - An empty dict (the default) advertises the extension with no settings. - """ + """Settings advertised at `capabilities.extensions[identifier]`; empty dict (default) means none.""" return {} def tools(self) -> Sequence[ToolBinding]: @@ -160,9 +131,8 @@ async def intercept_tool_call( ) -> HandlerResult: """Wrap `tools/call`. Default: pass through unchanged. - Override to short-circuit (return a result without calling `call_next`) - or to observe the call. `params` is the validated `tools/call` params; - `call_next(ctx)` runs the rest of the chain and the real handler. + Override to observe the call or short-circuit (return without calling + `call_next(ctx)`, which runs the rest of the chain and the real handler). """ return await call_next(ctx) @@ -170,9 +140,8 @@ async def intercept_tool_call( def compose_tool_call_interceptor(extensions: Sequence[Extension]) -> ServerMiddleware[Any]: """Fold every extension's `intercept_tool_call` into one `ServerMiddleware`. - The returned middleware nests the interceptors (first extension outermost) - and is a no-op for any method other than `tools/call`. It validates the - `tools/call` params once and threads them to each interceptor. + Nests the interceptors (first extension outermost), no-ops for methods other + than `tools/call`, and validates the params once for all interceptors. """ async def middleware(ctx: ServerRequestContext[Any, Any], call_next: CallNext) -> HandlerResult: diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 97b5557e20..918ce56222 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -1,37 +1,8 @@ -"""MCP Server Module - -This module provides a framework for creating an MCP (Model Context Protocol) server. -It allows you to easily define and handle various types of requests and notifications -using constructor-based handler registration. - -Usage: -1. Define handler functions: - async def my_list_tools(ctx, params): - return types.ListToolsResult(tools=[...]) - - async def my_call_tool(ctx, params): - return types.CallToolResult(content=[...]) - -2. Create a Server instance with on_* handlers: - server = Server( - "your_server_name", - on_list_tools=my_list_tools, - on_call_tool=my_call_tool, - ) - -3. Run the server: - async def main(): - async with mcp.server.stdio.stdio_server() as (read_stream, write_stream): - await server.run( - read_stream, - write_stream, - server.create_initialization_options(), - ) - - asyncio.run(main()) - -The Server class dispatches incoming requests and notifications to registered -handler callables by method string. +"""Low-level MCP server framework. + +The `Server` class dispatches incoming requests and notifications to handler +callables registered by method string (constructor `on_*` kwargs or +`add_request_handler`/`add_notification_handler`). """ from __future__ import annotations @@ -87,12 +58,10 @@ async def main(): class HandlerEntry(Generic[LifespanResultT]): """A registered handler and the params model to validate incoming params against. - Stored in `Server._request_handlers` / `_notification_handlers` and consumed - by `ServerRunner` to validate, build `Context`, and invoke. The handler's - second-argument type is erased to `Any` in storage (each entry has a - different concrete params type and `Callable` parameters are contravariant); - the precise type is recoverable via `params_type`. The correlation is - enforced at registration time by `Server.add_request_handler`. + The handler's second-argument type is erased to `Any` in storage (each entry has + a different concrete params type and `Callable` parameters are contravariant); + `params_type` carries the precise type, correlated at registration time by + `Server.add_request_handler`. """ params_type: type[BaseModel] @@ -108,11 +77,7 @@ def __init__(self, prompts_changed: bool = False, resources_changed: bool = Fals @asynccontextmanager async def lifespan(_: Server[Any]) -> AsyncIterator[dict[str, Any]]: - """Default lifespan context manager that does nothing. - - Returns: - An empty context object - """ + """Default no-op lifespan: yields an empty context.""" yield {} @@ -146,7 +111,6 @@ def __init__( [Server[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT], ] = lifespan, - # Request handlers on_list_tools: Callable[ [ServerRequestContext[LifespanResultT], types.PaginatedRequestParams | None], Awaitable[types.ListToolsResult], @@ -229,7 +193,6 @@ def __init__( [Server[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT], ] = lifespan, - # Request handlers on_list_tools: Callable[ [ServerRequestContext[LifespanResultT], types.PaginatedRequestParams | None], Awaitable[types.ListToolsResult], @@ -294,7 +257,6 @@ def __init__( [ServerRequestContext[LifespanResultT], types.RequestParams | None], Awaitable[types.EmptyResult], ] = _ping_handler, - # Notification handlers on_roots_list_changed: Callable[ [ServerRequestContext[LifespanResultT], types.NotificationParams | None], Awaitable[None], @@ -321,7 +283,6 @@ def __init__( [Server[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT], ] = lifespan, - # Request handlers on_list_tools: Callable[ [ServerRequestContext[LifespanResultT], types.PaginatedRequestParams | None], Awaitable[types.ListToolsResult], @@ -386,7 +347,6 @@ def __init__( [ServerRequestContext[LifespanResultT], types.RequestParams | None], Awaitable[types.EmptyResult], ] = _ping_handler, - # Notification handlers on_roots_list_changed: Callable[ [ServerRequestContext[LifespanResultT], types.NotificationParams | None], Awaitable[None], @@ -431,19 +391,14 @@ def __init__( self._request_handlers: dict[str, HandlerEntry[LifespanResultT]] = {} self._notification_handlers: dict[str, HandlerEntry[LifespanResultT]] = {} self._session_manager: StreamableHTTPSessionManager | None = None - # Context-tier middleware: wraps every inbound request (including - # `initialize`, lookup, validation, handler) with - # `(ctx, call_next)`. Applied in `ServerRunner._on_request`. - # `OpenTelemetryMiddleware` ships on by default so every server emits a - # SERVER span per message; it is a no-op until an OTel exporter is - # installed. Drop it from this list to opt out. - # TODO(L54): provisional - signature and semantics change with the - # Context/middleware rework (covariant `Context[L]`, outbound seam) before - # v2 final. + # Context-tier middleware: wraps every inbound request (including `initialize`) + # with `(ctx, call_next)`; applied in `ServerRunner._on_request`. OpenTelemetry + # ships on by default (no-op until an exporter is installed); drop it to opt out. + # TODO(L54): provisional - signature and semantics change with the Context/middleware + # rework (covariant `Context[L]`, outbound seam) before v2 final. self.middleware: list[ServerMiddleware[LifespanResultT]] = [OpenTelemetryMiddleware()] - # SEP-2133 extension settings advertised under `ServerCapabilities.extensions` - # (identifier -> settings). Higher layers (e.g. `MCPServer(extensions=...)`) - # populate it; `get_capabilities` reads it when no explicit map is passed. + # SEP-2133 extension settings (identifier -> settings) for `ServerCapabilities.extensions`; + # higher layers populate it, `get_capabilities` reads it when no explicit map is passed. self.extensions: dict[str, dict[str, Any]] = {} logger.debug("Initializing server %r", name) @@ -479,17 +434,14 @@ def add_request_handler( params_type: type[_ParamsT], handler: RequestHandler[LifespanResultT, _ParamsT], ) -> None: - """Register a request handler for `method`. - - `params_type` is the model incoming params are validated against - before the handler is invoked. It should subclass `RequestParams` so - `_meta` parses uniformly. A message with no `params` member validates - `{}` against `params_type`: models with required fields reject it as - INVALID_PARAMS, all-optional models reach the handler with their - defaults - the handler never receives `None`. Replaces any existing - handler for the same method, except `initialize`, which is reserved: - the runner owns the handshake, so registering it raises `ValueError`. - Use `Server.middleware` to observe or wrap initialization. + """Register a request handler for `method`, replacing any existing one. + + `params_type` validates incoming params before the handler is invoked; it + should subclass `RequestParams` so `_meta` parses uniformly. A message with + no `params` member validates `{}`: required fields reject as INVALID_PARAMS, + all-optional models reach the handler with their defaults - never `None`. + `initialize` is reserved (the runner owns the handshake) and raises + `ValueError`; use `Server.middleware` to observe or wrap initialization. """ if method == "initialize": raise ValueError( @@ -504,14 +456,12 @@ def add_notification_handler( params_type: type[_ParamsT], handler: NotificationHandler[LifespanResultT, _ParamsT], ) -> None: - """Register a notification handler for `method`. - - `params_type` should subclass `NotificationParams` so `_meta` - parses uniformly. Absent params follow the same contract as requests: - `{}` is validated, so the handler receives the model with its defaults, - never `None`. Replaces any existing handler. A handler for - `notifications/initialized` runs after the runner has marked the - connection initialized. + """Register a notification handler for `method`, replacing any existing one. + + `params_type` should subclass `NotificationParams` so `_meta` parses + uniformly; absent params validate `{}` as for requests, so the handler never + receives `None`. A `notifications/initialized` handler runs after the runner + has marked the connection initialized. """ self._notification_handlers[method] = HandlerEntry(params_type, handler) @@ -523,11 +473,9 @@ def get_notification_handler(self, method: str) -> HandlerEntry[LifespanResultT] """Return the registered entry for a notification method, or `None`.""" return self._notification_handlers.get(method) - # TODO(L53): Rethink capabilities API. Currently capabilities are derived from registered - # handlers but require NotificationOptions to be passed externally for list_changed - # flags, and experimental_capabilities as a separate dict. Consider deriving capabilities - # entirely from server state (e.g. constructor params for list_changed) instead of - # requiring callers to assemble them at create_initialization_options() time. + # TODO(L53): rethink capabilities API - derive capabilities entirely from server state + # (e.g. constructor params for list_changed) instead of requiring callers to assemble + # NotificationOptions/experimental_capabilities at create_initialization_options() time. def create_initialization_options( self, notification_options: NotificationOptions | None = None, @@ -536,10 +484,8 @@ def create_initialization_options( ) -> InitializationOptions: """Create initialization options from this server instance. - `extensions` advertises SEP-2133 extension support under - `ServerCapabilities.extensions`; keys are extension identifiers (e.g. - `io.modelcontextprotocol/ui`), values are per-extension settings. - Defaults to `self.extensions`, which higher layers populate. + `extensions` advertises SEP-2133 extension support (identifier -> settings) + under `ServerCapabilities.extensions`; defaults to `self.extensions`. """ return InitializationOptions( server_name=self.name, @@ -564,9 +510,8 @@ def get_capabilities( ) -> types.ServerCapabilities: """Convert existing handlers to a ServerCapabilities object. - `extensions` is the SEP-2133 extension map (identifier -> settings) - advertised under `ServerCapabilities.extensions`; it defaults to - `self.extensions`. + `extensions` is the SEP-2133 extension map (identifier -> settings); + defaults to `self.extensions`. """ notification_options = notification_options or NotificationOptions() prompts_capability = None @@ -575,26 +520,21 @@ def get_capabilities( logging_capability = None completions_capability = None - # Set prompt capabilities if handler exists if "prompts/list" in self._request_handlers: prompts_capability = types.PromptsCapability(list_changed=notification_options.prompts_changed) - # Set resource capabilities if handler exists if "resources/list" in self._request_handlers: resources_capability = types.ResourcesCapability( subscribe="resources/subscribe" in self._request_handlers, list_changed=notification_options.resources_changed, ) - # Set tool capabilities if handler exists if "tools/list" in self._request_handlers: tools_capability = types.ToolsCapability(list_changed=notification_options.tools_changed) - # Set logging capabilities if handler exists if "logging/setLevel" in self._request_handlers: logging_capability = types.LoggingCapability() - # Set completions capabilities if handler exists if "completion/complete" in self._request_handlers: completions_capability = types.CompletionsCapability() @@ -613,8 +553,7 @@ def get_capabilities( def server_info(self) -> types.Implementation: """The `serverInfo` block describing this implementation. - Derived from the constructor's identity fields. `version` falls back to - the installed `mcp` package version when not supplied explicitly. + `version` falls back to the installed `mcp` package version when not supplied. """ return types.Implementation( name=self.name, @@ -630,9 +569,7 @@ async def _handle_discover( ) -> types.DiscoverResult: """Default `server/discover` handler. - Auto-derived from server state at call time, so capabilities reflect - whatever has been registered (constructor `on_*` kwargs and later - `add_request_handler` calls). Operators can replace it wholesale via + Capabilities derive from server state at call time; replace wholesale via `add_request_handler("server/discover", ...)`. Reachability for legacy peers is decided at the boundary (`types.methods`), not here. """ @@ -645,10 +582,10 @@ async def _handle_discover( @property def session_manager(self) -> StreamableHTTPSessionManager: - """Get the StreamableHTTP session manager. + """The StreamableHTTP session manager. Raises: - RuntimeError: If called before streamable_http_app() has been called. + RuntimeError: If accessed before `streamable_http_app()` has created it. """ if self._session_manager is None: raise RuntimeError( # pragma: no cover @@ -662,10 +599,8 @@ async def run( read_stream: ReadStream[SessionMessage | Exception], write_stream: WriteStream[SessionMessage], initialization_options: InitializationOptions, - # When False, exceptions are returned as messages to the client. - # When True, exceptions are raised, which will cause the server to shut down - # but also make tracing exceptions much easier during testing and when using - # in-process servers. + # True re-raises handler exceptions (shutting the server down) instead of + # returning error responses - eases tracing in tests and in-process servers. raise_exceptions: bool = False, ) -> None: """Serve a single connection over the given streams until the read side closes. @@ -719,19 +654,15 @@ def streamable_http_app( ) self._session_manager = session_manager - # Create the ASGI handler streamable_http_app = StreamableHTTPASGIApp(session_manager) - # Create routes routes: list[Route | Mount] = [] middleware: list[Middleware] = [] required_scopes: list[str] = [] - # Set up auth if configured if auth: required_scopes = auth.required_scopes or [] - # Add auth middleware if token verifier is available if token_verifier: middleware = [ Middleware( @@ -741,7 +672,6 @@ def streamable_http_app( Middleware(AuthContextMiddleware), ] - # Add auth endpoints if auth server provider is configured if auth_server_provider: routes.extend( create_auth_routes( @@ -754,9 +684,7 @@ def streamable_http_app( ) ) - # Set up routes with or without auth if token_verifier: - # Determine resource metadata URL resource_metadata_url = None if auth and auth.resource_server_url: # pragma: no branch # Build compliant metadata URL for WWW-Authenticate header @@ -769,7 +697,6 @@ def streamable_http_app( ) ) else: - # Auth is disabled, no wrapper needed routes.append( Route( streamable_http_path, @@ -777,7 +704,6 @@ def streamable_http_app( ) ) - # Add protected resource metadata endpoint if configured as RS if auth and auth.resource_server_url: routes.extend( create_protected_resource_routes( diff --git a/src/mcp/server/mcpserver/context.py b/src/mcp/server/mcpserver/context.py index 82a6fa2b6e..498191d54b 100644 --- a/src/mcp/server/mcpserver/context.py +++ b/src/mcp/server/mcpserver/context.py @@ -23,37 +23,18 @@ class Context(BaseModel, Generic[LifespanContextT, RequestT]): - """Context object providing access to MCP capabilities. + """Request-scoped access to MCP capabilities. - This provides a cleaner interface to MCP's RequestContext functionality. - It gets injected into tool and resource functions that request it via type hints. - - To use context in a tool function, add a parameter with the Context type annotation: + Injected into tool and resource functions that request it via a `Context`-annotated + parameter (any name; optional - functions that don't need it can omit it): ```python @server.tool() async def my_tool(x: int, ctx: Context) -> str: - # Log messages to the client - await ctx.info(f"Processing {x}") - await ctx.debug("Debug info") - await ctx.warning("Warning message") - await ctx.error("Error message") - - # Report progress await ctx.report_progress(50, 100) - - # Access resources data = await ctx.read_resource("resource://data") - - # Get request info - request_id = ctx.request_id - client_id = ctx.client_id - return str(x) ``` - - The context parameter name can be anything as long as it's annotated with Context. - The context is optional - tools that don't need it can omit the parameter. """ _request_context: ServerRequestContext[LifespanContextT, RequestT] | None @@ -90,24 +71,12 @@ def request_context(self) -> ServerRequestContext[LifespanContextT, RequestT]: return self._request_context async def report_progress(self, progress: float, total: float | None = None, message: str | None = None) -> None: - """Report progress for the current operation. - - Args: - progress: Current progress value (e.g., 24) - total: Optional total value (e.g., 100) - message: Optional message (e.g., "Starting render...") - """ + """Report progress for the current operation.""" await self.request_context.session.report_progress(progress, total, message) async def read_resource(self, uri: str | AnyUrl) -> Iterable[ReadResourceContents]: """Read a resource by URI. - Args: - uri: Resource URI to read - - Returns: - The resource content as either text or bytes - Raises: ResourceNotFoundError: If no resource or template matches the URI. ResourceError: If template creation or resource reading fails. @@ -120,25 +89,11 @@ async def elicit( message: str, schema: type[ElicitSchemaModelT], ) -> ElicitationResult[ElicitSchemaModelT]: - """Elicit information from the client/user. - - This method can be used to interactively ask for additional information from the - client within a tool's execution. The client might display the message to the - user and collect a response according to the provided schema. If the client - is an agent, it might decide how to handle the elicitation -- either by asking - the user or automatically generating a response. + """Elicit information from the client/user during a tool's execution. - Args: - message: Message to present to the user - schema: A Pydantic model class defining the expected response structure. - According to the specification, only primitive types are allowed. - - Returns: - An ElicitationResult containing the action taken and the data if accepted - - Note: - Check the result.action to determine if the user accepted, declined, or cancelled. - The result.data will only be populated if action is "accept" and validation succeeded. + Per the specification, `schema` may only contain primitive-typed fields. Check + `result.action` for accept/decline/cancel; `result.data` is populated only when + the action is "accept" and validation succeeded. """ return await elicit_with_validation( @@ -156,24 +111,11 @@ async def elicit_url( ) -> UrlElicitationResult: """Request URL mode elicitation from the client. - This directs the user to an external URL for out-of-band interactions - that must not pass through the MCP client. Use this for: - - Collecting sensitive credentials (API keys, passwords) - - OAuth authorization flows with third-party services - - Payment and subscription flows - - Any interaction where data should not pass through the LLM context - - The response indicates whether the user consented to navigate to the URL. - The actual interaction happens out-of-band. When the elicitation completes, - call `ctx.session.send_elicit_complete(elicitation_id)` to notify the client. - - Args: - message: Human-readable explanation of why the interaction is needed - url: The URL the user should navigate to - elicitation_id: Unique identifier for tracking this elicitation - - Returns: - UrlElicitationResult indicating accept, decline, or cancel + Directs the user to an external URL for out-of-band interactions whose data + must not pass through the MCP client or LLM context (credentials, OAuth flows, + payments). The result only indicates whether the user consented to navigate; + when the interaction completes, call + `ctx.session.send_elicit_complete(elicitation_id)` to notify the client. """ return await elicit_url( session=self.request_context.session, @@ -191,15 +133,7 @@ async def log( *, logger_name: str | None = None, ) -> None: - """Send a log message to the client. - - Args: - level: Log level (debug, info, notice, warning, error, critical, - alert, emergency) - data: The data to be logged. Any JSON serializable type is allowed - (string, dict, list, number, bool, etc.) per the MCP specification. - logger_name: Optional logger name - """ + """Send a log message to the client. `data` may be any JSON-serializable value.""" await self.request_context.session.send_log_message( # pyright: ignore[reportDeprecated] level=level, data=data, @@ -210,10 +144,9 @@ async def log( # TODO(maxisbey): see if this is needed otherwise remove @property def client_id(self) -> str | None: - """Get the client ID if available. + """The client ID from the MCP request's `_meta` params, if available. - Note: this reads from the MCP request's `_meta` params, not the OAuth - bearer token. For that, use `get_access_token().client_id`. + Not the OAuth bearer token's client ID - for that, use `get_access_token().client_id`. """ return self.request_context.meta.get("client_id") if self.request_context.meta else None # pragma: no cover @@ -221,9 +154,8 @@ def client_id(self) -> str | None: def headers(self) -> Mapping[str, str] | None: """Request headers carried by this message, when the transport has them. - Populated by HTTP-based transports; `None` on stdio or when the - transport's request object carries no headers. Headers are - client-supplied input - never treat one as an identity assertion. + Populated by HTTP-based transports; `None` on stdio. Headers are client-supplied + input - never treat one as an identity assertion. """ return cast("Mapping[str, str] | None", getattr(self.request_context.request, "headers", None)) @@ -241,17 +173,13 @@ def protocol_version(self) -> str | None: def input_responses(self) -> InputResponses | None: """Client responses to a prior `InputRequiredResult.input_requests`. - `None` on the initial round, or when the client retried without - responses. + `None` on the initial round, or when the client retried without responses. """ return self._input_params.input_responses if self._input_params else None @property def request_state(self) -> str | None: - """Opaque state echoed from a prior `InputRequiredResult.request_state`. - - `None` on the initial round. - """ + """Opaque state echoed from a prior `InputRequiredResult.request_state`; `None` on the initial round.""" return self._input_params.request_state if self._input_params else None @property @@ -270,38 +198,25 @@ def session(self): return self.request_context.session async def close_sse_stream(self) -> None: - """Close the SSE stream to trigger client reconnection. - - This method closes the HTTP connection for the current request, triggering - client reconnection. Events continue to be stored in the event store and will - be replayed when the client reconnects with Last-Event-ID. + """Close the current request's SSE stream to trigger client reconnection. - Use this to implement polling behavior during long-running operations - - the client will reconnect after the retry interval specified in the priming event. - - Note: - This is a no-op if not using StreamableHTTP transport with event_store. - The callback is only available when event_store is configured. + Events keep accruing in the event store and are replayed when the client + reconnects with Last-Event-ID, enabling polling behavior during long-running + operations. No-op unless using StreamableHTTP transport with an event_store. """ if self._request_context and self._request_context.close_sse_stream: # pragma: no branch await self._request_context.close_sse_stream() async def close_standalone_sse_stream(self) -> None: - """Close the standalone GET SSE stream to trigger client reconnection. - - This method closes the HTTP connection for the standalone GET stream used - for unsolicited server-to-client notifications. The client SHOULD reconnect - with Last-Event-ID to resume receiving notifications. + """Close the standalone GET SSE stream used for unsolicited server-to-client notifications. - Note: - This is a no-op if not using StreamableHTTP transport with event_store. - Currently, client reconnection for standalone GET streams is NOT - implemented - this is a known gap. + The client SHOULD reconnect with Last-Event-ID to resume. No-op unless using + StreamableHTTP transport with an event_store. Known gap: client reconnection + for standalone GET streams is not implemented. """ if self._request_context and self._request_context.close_standalone_sse_stream: # pragma: no cover await self._request_context.close_standalone_sse_stream() - # Convenience methods for common log levels @deprecated("The logging capability is deprecated as of 2026-07-28 (SEP-2577).", category=MCPDeprecationWarning) async def debug(self, data: Any, *, logger_name: str | None = None) -> None: """Send a debug log message.""" diff --git a/src/mcp/server/mcpserver/prompts/base.py b/src/mcp/server/mcpserver/prompts/base.py index 338cb1f870..cb23a098d1 100644 --- a/src/mcp/server/mcpserver/prompts/base.py +++ b/src/mcp/server/mcpserver/prompts/base.py @@ -87,29 +87,23 @@ def from_function( ) -> Prompt: """Create a Prompt from a function. - The function can return: - - A string (converted to a message) - - A Message object - - A dict (converted to a message) - - A sequence of any of the above + The function may return a string, a Message, a dict, or a sequence of these; + each item is converted to a message. """ func_name = name or fn.__name__ if func_name == "": # pragma: no cover raise ValueError("You must provide a name for lambda functions") - # Find context parameter if it exists if context_kwarg is None: # pragma: no branch context_kwarg = find_context_parameter(fn) - # Get schema from func_metadata, excluding context parameter func_arg_metadata = func_metadata( fn, skip_names=[context_kwarg] if context_kwarg is not None else [], ) parameters = func_arg_metadata.arg_model.model_json_schema() - # Convert parameters to PromptArguments arguments: list[PromptArgument] = [] if "properties" in parameters: # pragma: no branch for param_name, param in parameters["properties"].items(): @@ -122,7 +116,7 @@ def from_function( ) ) - # ensure the arguments are properly cast + # validate_call coerces incoming arguments to the declared parameter types fn = validate_call(fn) return cls( @@ -145,7 +139,6 @@ async def render( Raises: ValueError: If required arguments are missing, or if rendering fails. """ - # Validate required arguments if self.arguments: required = {arg.name for arg in self.arguments if arg.required} provided = set(arguments or {}) @@ -154,7 +147,6 @@ async def render( raise ValueError(f"Missing required arguments: {missing}") try: - # Add context to arguments if needed call_args = inject_context(self.fn, arguments or {}, context, self.context_kwarg) fn = self.fn @@ -163,11 +155,9 @@ async def render( else: result = await anyio.to_thread.run_sync(functools.partial(self.fn, **call_args)) - # Validate messages if not isinstance(result, list | tuple): result = [result] - # Convert result to messages messages: list[Message] = [] for msg in result: # type: ignore[reportUnknownVariableType] try: diff --git a/src/mcp/server/mcpserver/prompts/manager.py b/src/mcp/server/mcpserver/prompts/manager.py index 28a7a6e98c..c5ca139d97 100644 --- a/src/mcp/server/mcpserver/prompts/manager.py +++ b/src/mcp/server/mcpserver/prompts/manager.py @@ -22,20 +22,16 @@ def __init__(self, warn_on_duplicate_prompts: bool = True): self.warn_on_duplicate_prompts = warn_on_duplicate_prompts def get_prompt(self, name: str) -> Prompt | None: - """Get prompt by name.""" return self._prompts.get(name) def list_prompts(self) -> list[Prompt]: - """List all registered prompts.""" return list(self._prompts.values()) def add_prompt( self, prompt: Prompt, ) -> Prompt: - """Add a prompt to the manager.""" - - # Check for duplicates + """Register a prompt; if the name is already registered, the existing prompt is returned unchanged.""" existing = self._prompts.get(prompt.name) if existing: if self.warn_on_duplicate_prompts: @@ -51,7 +47,6 @@ async def render_prompt( arguments: dict[str, Any] | None, context: Context[LifespanContextT, RequestT], ) -> list[Message]: - """Render a prompt by name with arguments.""" prompt = self.get_prompt(name) if not prompt: raise ValueError(f"Unknown prompt: {name}") diff --git a/src/mcp/server/mcpserver/resolve.py b/src/mcp/server/mcpserver/resolve.py index 323ce5cddb..5024da7868 100644 --- a/src/mcp/server/mcpserver/resolve.py +++ b/src/mcp/server/mcpserver/resolve.py @@ -4,24 +4,12 @@ resolver `fn` before the tool body, instead of from the LLM-supplied arguments. Resolvers form a DAG: a resolver may declare its own `Resolve(...)` dependencies, take tool arguments by name, and take the `Context`. A resolver may return -`Elicit[T]` to ask the client; the framework runs the elicitation and injects the -answer. - -The framework picks the elicitation transport from the negotiated protocol. At ->= 2026-07-28 it returns an `InputRequiredResult` carrying the batched questions -and resumes when the client retries with `input_responses`/`request_state` -(independent resolvers are asked in one round; a resolver depending on another's -answer is asked in a later round). At <= 2025-11-25 it issues a synchronous -`elicitation/create` request mid-call. Only *elicited* outcomes are carried in -`request_state` across rounds (so the user is asked each question once); a -resolver that returns a value without eliciting is pure and may re-run each round. - -Whether the consumer receives the unwrapped model or the full -`ElicitationResult` union is decided by the consumer's annotation: - -- `Annotated[T, Resolve(fn)]` -> unwrapped `T`; decline/cancel aborts the call. -- `Annotated[ElicitationResult[T], Resolve(fn)]` (or a specific member) -> the - full outcome; the consumer branches on accept/decline/cancel. +`Elicit[T]` to ask the client; each question is asked once per call, but a +resolver that returns without eliciting is pure and may re-run each round. + +Annotating the consumer `Annotated[ElicitationResult[T], Resolve(fn)]` (or a +specific member) injects the full accept/decline/cancel outcome; the bare `T` +form injects the unwrapped model and aborts the call on decline/cancel. """ from __future__ import annotations @@ -64,12 +52,10 @@ T = TypeVar("T", bound=BaseModel) -# The union members the framework injects when a consumer opts into the outcome. _ELICITATION_RESULT_MEMBERS = (AcceptedElicitation, DeclinedElicitation, CancelledElicitation) -# First protocol revision whose `tools/call` carries elicitation inside -# `InputRequiredResult` rather than as a standalone server-to-client request. -# Pinned (not `LATEST_MODERN_VERSION`, which moves when newer revisions are added). +# First revision carrying elicitation inside `InputRequiredResult`; pinned, not +# `LATEST_MODERN_VERSION`, which moves when newer revisions are added. _INPUT_REQUIRED_VERSION = "2026-07-28" _STATE_VERSION = 1 @@ -82,11 +68,7 @@ def __init__(self, fn: Callable[..., Any]) -> None: class Elicit(Generic[T]): - """A resolver's request to ask the client. - - Returned from a resolver to signal that the value must be elicited. The - framework runs `ctx.elicit(message, schema)` and injects the outcome. - """ + """Returned from a resolver to ask the client; the framework runs the elicitation and injects the outcome.""" def __init__(self, message: str, schema: type[T]) -> None: self.message = message @@ -120,22 +102,17 @@ def __init__( self.fn = fn self.params = params self.is_async = is_async - # The `T` from the resolver's `Elicit[T]` return arm, if annotated. Used to - # re-validate an outcome restored from `request_state` into a model. + # `T` from the `Elicit[T]` return arm; re-validates outcomes restored from `request_state`. self.elicit_schema = elicit_schema - # Deterministic, collision-free key for this resolver's elicitation on the - # wire (`input_requests`/`request_state`). Assigned at registration so it is - # stable across rounds even when `module:qualname` collides (closures). + # Wire key for `input_requests`/`request_state`; assigned at registration so it + # stays stable across rounds even when `module:qualname` collides (closures). self.wire_key = wire_key def _type_hints(fn: Callable[..., Any]) -> dict[str, Any]: - """Resolve type hints for a function or a callable object. + """Resolve type hints, falling back to `__call__` for callable instances (`get_type_hints` raises on them). - `typing.get_type_hints` raises on a callable *instance*; fall back to its - `__call__`. Returns an empty mapping when hints cannot be resolved, matching - `find_context_parameter`'s tolerance so callables without annotations (or with - unresolvable ones) simply have no resolved parameters. + Unresolvable hints yield an empty mapping: such callables simply have no resolved parameters. """ target = fn if inspect.isroutine(fn) else getattr(type(fn), "__call__", fn) try: @@ -150,9 +127,8 @@ def _resolver_name(fn: Callable[..., Any]) -> str: def find_resolved_parameters(fn: Callable[..., Any]) -> dict[str, tuple[Resolve, bool]]: - """Find parameters of `fn` annotated `Annotated[_, Resolve(...)]`. + """Map parameters of `fn` annotated `Annotated[_, Resolve(...)]` to `(Resolve, wants_union)`. - Returns a mapping of parameter name to `(Resolve, wants_union)`, where `wants_union` is True when the annotated type is an `ElicitationResult` member (the consumer wants the full outcome rather than the unwrapped model). """ @@ -161,8 +137,7 @@ def find_resolved_parameters(fn: Callable[..., Any]) -> dict[str, tuple[Resolve, for name in inspect.signature(fn).parameters: annotation = hints.get(name) if get_origin(annotation) is not Annotated: - # A `Resolve` marker is only honored at the top level; flag (rather than - # silently drop) one buried in a union, e.g. `Annotated[T, Resolve(f)] | None`. + # Flag (rather than silently drop) a `Resolve` buried in a union, e.g. `Annotated[T, Resolve(f)] | None`. if _contains_resolve(annotation): raise InvalidSignature( f"Parameter {name!r} of {_resolver_name(fn)!r} wraps `Resolve(...)` in a " @@ -177,7 +152,6 @@ def find_resolved_parameters(fn: Callable[..., Any]) -> dict[str, tuple[Resolve, def _contains_resolve(annotation: Any) -> bool: - """True when a `Resolve` marker is nested inside `annotation` (e.g. a union member).""" if get_origin(annotation) is Annotated: return any(isinstance(m, Resolve) for m in get_args(annotation)[1:]) return any(_contains_resolve(arg) for arg in get_args(annotation)) @@ -186,15 +160,11 @@ def _contains_resolve(annotation: Any) -> bool: def _elicit_return_schema(return_annotation: Any, name: str) -> type[BaseModel] | None: """Extract `T` from a resolver return type's `Elicit[T]` arm, if present. - Handles a bare `-> Elicit[T]` and a `-> T | Elicit[T]` union. Lets an elicited - outcome restored from `request_state` (a plain dict) be re-validated into its - model so dependent resolvers and tools receive a typed value. + Used to re-validate an outcome restored from `request_state` (a plain dict) into its model. Raises: - InvalidSignature: If the annotation has more than one `Elicit[...]` arm; - the runtime can honor only one static question schema per resolver. + InvalidSignature: If the annotation has more than one `Elicit[...]` arm. """ - # A bare `Elicit[T]` is itself a candidate; a union contributes its members. candidates = get_args(return_annotation) if _is_union(return_annotation) else (return_annotation,) # Typing dedupes equal union members, so two arms here are genuinely distinct. arms = [c for c in candidates if get_origin(c) is Elicit] @@ -216,12 +186,9 @@ def _is_union(annotation: Any) -> bool: def _wants_union(type_arg: Any) -> bool: """True when `type_arg` is an `ElicitationResult` member (or a union of them). - Handles the subscripted `ElicitationResult[T]` alias (a `TypeAliasType` whose - union is on the origin's `__value__`), the bare `ElicitationResult` alias (the - `__value__` is on `type_arg` itself), an explicit `AcceptedElicitation[T] | ...` - union, and a single member. + The `ElicitationResult` `TypeAliasType` carries its union on `__value__`: on + `type_arg` itself when bare, on the origin when subscripted. """ - # Unwrap the `ElicitationResult` alias whether it is bare or subscripted. value = getattr(type_arg, "__value__", None) or getattr(get_origin(type_arg), "__value__", None) if value is not None: type_arg = value @@ -232,17 +199,13 @@ def _wants_union(type_arg: Any) -> bool: def _resolver_key(fn: Callable[..., Any]) -> Hashable: """Identity key for memoizing a resolver. - A bound method - pure-python (`inspect.ismethod`) or built-in (e.g. `obj.meth` - on a C-extension type) - is recreated on each attribute access, so `id(fn)` - differs every time. Key it by its underlying function (or name) plus its - `__self__` identity so `auth.login` referenced in two places memoizes to one - call. Everything else keys by `id`, so two distinct callables never collide - even if they compare equal. + A bound method (pure-python or built-in) is recreated on each attribute access, + so `id(fn)` differs every time; key it by its underlying function id (or, for + built-ins, `__name__`) plus `__self__` identity. Everything else keys by `id`, + so two distinct callables never collide even if they compare equal. """ bound_self = getattr(fn, "__self__", None) if bound_self is not None: - # `__func__` (pure-python) has a stable identity; built-ins expose only a - # stable `__name__`. Use the function's id or the name's value accordingly. func = getattr(fn, "__func__", None) underlying: Hashable = id(func) if func is not None else getattr(fn, "__name__", id(fn)) return (underlying, id(bound_self)) @@ -256,9 +219,8 @@ def build_resolver_plans( """Statically analyze the resolver DAG rooted at a tool's resolved parameters. Raises: - InvalidSignature: If a resolver has a cyclic dependency, or a resolver - parameter cannot be classified (not a `Context`, a nested `Resolve`, - or a tool argument by name). + InvalidSignature: On a cyclic dependency, or a resolver parameter that is + not a `Context`, a nested `Resolve`, or a tool argument by name. """ plans: dict[Hashable, _ResolverPlan] = {} # Count how many distinct resolvers share each `module:qualname` base so closures @@ -329,12 +291,7 @@ class _Pending(Exception): class _Resolution: - """Per-`tools/call` resolution state, shared across the DAG walk. - - `input_required` selects the transport: at >= 2026-07-28 elicitations are - batched into `pending` and surfaced as an `InputRequiredResult`; at older - revisions each `Elicit` is answered synchronously via `ctx.elicit`. - """ + """Per-`tools/call` resolution state, shared across the DAG walk.""" def __init__( self, @@ -349,25 +306,21 @@ def __init__( self.input_required = input_required self.answers: InputResponses = context.input_responses or {} if input_required else {} self.state = _decode_state(context.request_state) if input_required else {} - # In-call dedup keyed by resolver identity (distinguishes two instances of - # the same bound method); `persist` holds the wire-shaped record of each - # elicited outcome, keyed by its wire key - exactly what the next round's - # `request_state` carries. Entries are the client's own (validated) wire - # data, never re-derived from a model, so encode-restore is the identity. - # Pure resolvers are cheap to re-run each round and are not persisted. + # `cache` dedups within the call by resolver identity. `persist` holds elicited + # outcomes keyed by wire key - the next round's `request_state` - as the client's + # validated wire data, never re-derived from a model, so encode-restore is the + # identity. Pure resolvers are not persisted; they re-run each round. self.cache: dict[Hashable, ElicitationResult[Any]] = {} self.persist: dict[str, _StateEntry] = {} self.pending: InputRequests = {} def _state_key(fn: Callable[..., Any]) -> str: - """Worker-stable base wire key for a resolver, derived only from registration data. + """Worker-stable base wire key for a resolver: `module:qualname`, never `id(...)`. - `input_requests`/`request_state` must round-trip through the client and resume on - any worker (stateless HTTP), so the key carries no `id(...)`: it is the resolver's - `module:qualname` (a callable object uses its type's). Distinct resolvers that - share this base - two instances of one method, two closures from one factory - are - disambiguated deterministically by `build_resolver_plans` (`base`, `base#1`, ...). + `input_requests`/`request_state` round-trip through the client and may resume on + any worker (stateless HTTP). Resolvers sharing a base (bound methods, closures) + are disambiguated deterministically by `build_resolver_plans`. """ qualname = getattr(fn, "__qualname__", None) or type(fn).__qualname__ module = getattr(fn, "__module__", None) or type(fn).__module__ @@ -387,18 +340,13 @@ async def resolve_arguments( negotiated protocol is >= 2026-07-28), returns an `InputRequiredResult` carrying the batched questions instead; the tool body is not run. - An eliciting resolver asks its question once - its answer is carried in - `request_state` across rounds - while a resolver that resolves without - eliciting is pure and may re-run on each round. - Raises: ToolError: If an elicited value is declined or cancelled and the consumer asked for the unwrapped model (rather than the result union). """ - # `ctx.protocol_version` is `None` outside an active request: `MCPServer.call_tool()` - # called directly builds such a `Context`, and a tool whose resolvers never elicit - # must still work there. A missing version means the synchronous (non-input_required) - # transport, which never reaches a server-to-client request anyway. + # `ctx.protocol_version` is None outside an active request (e.g. `MCPServer.call_tool()` + # called directly); that maps to the synchronous transport, which never reaches a + # server-to-client request anyway. res = _Resolution(plans, tool_args, context, _uses_input_required(context.protocol_version)) injected: dict[str, Any] = {} for name, (marker, wants_union) in resolved_params.items(): @@ -414,10 +362,9 @@ async def resolve_arguments( async def _resolve(fn: Callable[..., Any], res: _Resolution) -> ElicitationResult[Any]: - """Resolve one resolver, deduped within the call by its resolver identity. + """Resolve one resolver, deduped within the call by resolver identity. - Raises `_Pending` when the resolver (or one of its dependencies) needs client - input that has not arrived yet. + Raises `_Pending` when it (or a dependency) needs client input that has not arrived yet. """ cache_key = _resolver_key(fn) if cache_key in res.cache: @@ -426,12 +373,11 @@ async def _resolve(fn: Callable[..., Any], res: _Resolution) -> ElicitationResul plan = res.plans[cache_key] wire_key = plan.wire_key if wire_key in res.pending: - # Already asked this round by another consumer; don't run the resolver again. + # Already asked this round by another consumer. raise _Pending - # Restore a prior round's outcome directly only when its model is known from the - # `Elicit[T]` return arm. Without that (a resolver that elicits but isn't annotated - # `-> ... Elicit[T]`), fall through and re-run the resolver so `_elicit` can - # re-validate the stored answer against the live `Elicit.schema`. + # Restore a prior outcome directly only when its model is known from the `Elicit[T]` + # return arm; otherwise re-run the resolver so `_elicit` can re-validate the stored + # answer against the live `Elicit.schema`. if wire_key in res.state and (plan.elicit_schema is not None or res.state[wire_key].action != "accept"): outcome = _restore_outcome(res, wire_key, plan.elicit_schema) if outcome is not None: @@ -448,8 +394,7 @@ async def _resolve(fn: Callable[..., Any], res: _Resolution) -> ElicitationResul else: assert param_plan.resolve is not None try: - # Visit every dependency so independent ones that need input are all - # collected into `res.pending` and batched into a single round. + # Keep visiting so all pending dependencies are batched into one round. dep_outcome = await _resolve(param_plan.resolve.fn, res) except _Pending: dep_pending = True @@ -467,9 +412,7 @@ async def _resolve(fn: Callable[..., Any], res: _Resolution) -> ElicitationResul if _is_elicit(result): outcome = await _elicit(result, wire_key, res) else: - # A resolver may return any type (not just `BaseModel`), so accept it as the - # outcome without validating against the schema bound. Plain outcomes are not - # persisted in `request_state`; the resolver re-runs next round instead. + # Plain outcomes are not persisted in `request_state`; the resolver re-runs next round. outcome = _accepted(result) res.cache[cache_key] = outcome @@ -481,10 +424,9 @@ async def _elicit(elicit: Elicit[Any], key: str, res: _Resolution) -> Elicitatio if not res.input_required: return await res.context.elicit(elicit.message, elicit.schema) - # Answered in a prior round (restored without a known schema, e.g. an unannotated - # resolver): re-validate the stored entry against the live `Elicit.schema`. A - # recorded outcome wins over a re-sent answer; an invalid entry self-deletes and - # falls through to the fresh answer (or to re-asking). + # A prior round's entry is re-validated against the live `Elicit.schema`; a recorded + # outcome wins over a re-sent answer, and an invalid entry self-deletes, falling + # through to the fresh answer (or to re-asking). outcome = _restore_outcome(res, key, elicit.schema) if outcome is not None: return outcome @@ -505,8 +447,7 @@ async def _elicit(elicit: Elicit[Any], key: str, res: _Resolution) -> Elicitatio raise ToolError( f"Resolver {key!r} received an accepted elicitation whose content does not match the requested schema" ) from e - # Persist the exact wire content that just passed validation - never the - # model - so restoring next round revalidates the same bytes the client sent. + # Persist the wire content, never the model: next round revalidates the same bytes. res.persist[key] = _StateEntry(action="accept", data=answer.content) return AcceptedElicitation(data=data) if answer.action == "decline": @@ -523,38 +464,31 @@ def _unwrap(outcome: ElicitationResult[Any], name: str) -> Any: def _is_elicit(value: Any) -> TypeGuard[Elicit[Any]]: - """Runtime narrow of a resolver's return value to a (parameter-erased) `Elicit`.""" return isinstance(value, Elicit) def _accepted(data: Any) -> AcceptedElicitation[Any]: - """Wrap a resolved value as an accepted outcome without schema validation. + """Wrap a value as an accepted outcome without validation. - A resolver may return any type (the schema bound only constrains `Elicit[T]`), - and a value restored from `request_state` is already validated. + Resolvers may return any type; restored `request_state` values are already validated. """ return AcceptedElicitation[Any].model_construct(data=data) def _uses_input_required(protocol_version: str | None) -> bool: - """True when this request must elicit via `InputRequiredResult` (>= 2026-07-28). - - Older revisions still carry a standalone `elicitation/create` server-to-client - request, so the framework keeps the synchronous `ctx.elicit()` path for them. - """ + """True when elicitation uses `InputRequiredResult` (>= 2026-07-28) rather than `elicitation/create`.""" return protocol_version is not None and is_version_at_least(protocol_version, _INPUT_REQUIRED_VERSION) def _require_form_elicitation(context: Context[Any, Any], key: str) -> None: """Assert the client declared form elicitation before queueing a question for it. - The spec forbids sending an `input_requests` entry the client has not declared a - capability for. A bare `elicitation: {}` declaration (the only shape before modes - existed) counts as form support; an explicit url-only declaration does not. + The spec forbids `input_requests` entries the client lacks a capability for. A bare + `elicitation: {}` (the only shape before modes existed) counts as form support; an + explicit url-only declaration does not. Raises: - MCPError: With code `MISSING_REQUIRED_CLIENT_CAPABILITY` and a - `requiredCapabilities` payload when form elicitation is not declared. + MCPError: `MISSING_REQUIRED_CLIENT_CAPABILITY` with a `requiredCapabilities` payload. """ capabilities = context.client_capabilities elicitation = capabilities.elicitation if capabilities is not None else None @@ -591,10 +525,9 @@ class _State(BaseModel): def _decode_state(request_state: str | None) -> dict[str, _StateEntry]: - """Decode the per-call resolution progress from `request_state`. + """Decode per-call progress from client-trusted `request_state` (integrity sealing is a follow-up). - `request_state` is client-trusted (integrity sealing is a follow-up); validate - it through `_State` and treat anything malformed as "no progress yet". + Anything malformed reads as "no progress yet". """ if not request_state: return {} @@ -606,11 +539,7 @@ def _decode_state(request_state: str | None) -> dict[str, _StateEntry]: def _encode_state(outcomes: Mapping[str, _StateEntry]) -> str: - """Encode recorded elicitation outcomes (keyed by wire key) for the next round. - - Entries already hold the client's wire-shaped data exactly as it was sent (and - validated), so encoding is pure wrapping: encode-restore is the identity. - """ + """Encode recorded elicitation outcomes (keyed by wire key) for the next round.""" return _State(v=_STATE_VERSION, outcomes=dict(outcomes)).model_dump_json() @@ -618,8 +547,7 @@ def _outcome_from_state(entry: _StateEntry, schema: type[BaseModel] | None) -> E """Rebuild an `ElicitationResult` from a decoded `request_state` entry. Raises: - ValidationError: If `schema` is known and the entry's data does not - validate against it. + ValidationError: If `schema` is known and the entry's data does not validate. """ if entry.action == "decline": return DeclinedElicitation() @@ -634,14 +562,10 @@ def _outcome_from_state(entry: _StateEntry, schema: type[BaseModel] | None) -> E def _restore_outcome(res: _Resolution, key: str, schema: type[BaseModel] | None) -> ElicitationResult[Any] | None: """Restore `key`'s recorded outcome from a prior round, or `None` when absent. - `request_state` is client-trusted, so an entry whose data fails validation gets - the `_decode_state` treatment - dropped as if no progress was recorded, so the - question is asked again - rather than surfacing a validation error. - - Carries the original decoded entry forward unchanged in `res.persist`: if a - later resolver is still pending, the next round's `request_state` is built from - `res.persist`, so an earlier answer must stay there - byte-identical, never - re-derived - or it would be dropped and re-asked. + Client-trusted state: an entry whose data fails validation is dropped as if no + progress was recorded, so the question is asked again. The original entry is + carried forward unchanged into `res.persist` - the next round's `request_state` + is built from it, so an earlier answer must stay there or it would be re-asked. """ entry = res.state.get(key) if entry is None: diff --git a/src/mcp/server/mcpserver/resources/base.py b/src/mcp/server/mcpserver/resources/base.py index f7bedf6cbe..96bd87df37 100644 --- a/src/mcp/server/mcpserver/resources/base.py +++ b/src/mcp/server/mcpserver/resources/base.py @@ -30,7 +30,6 @@ class Resource(BaseModel, abc.ABC): @field_validator("name", mode="before") @classmethod def set_default_name(cls, name: str | None, info: ValidationInfo) -> str: - """Set default name from URI if not provided.""" if name: return name if uri := info.data.get("uri"): diff --git a/src/mcp/server/mcpserver/resources/resource_manager.py b/src/mcp/server/mcpserver/resources/resource_manager.py index 41d3d7bb37..b235bb660e 100644 --- a/src/mcp/server/mcpserver/resources/resource_manager.py +++ b/src/mcp/server/mcpserver/resources/resource_manager.py @@ -37,14 +37,7 @@ def __init__(self, warn_on_duplicate_resources: bool = True, *, resources: list[ self.add_resource(resource) def add_resource(self, resource: Resource) -> Resource: - """Add a resource to the manager. - - Args: - resource: A Resource instance to add. - - Returns: - The added resource. If a resource with the same URI already exists, returns the existing resource. - """ + """Add a resource, returning the existing one if a resource with the same URI is already registered.""" logger.debug( "Adding resource", extra={"uri": resource.uri, "type": type(resource).__name__, "resource_name": resource.name}, @@ -94,22 +87,16 @@ async def get_resource(self, uri: AnyUrl | str, context: Context[LifespanContext ResourceError: If a matching template fails to create the resource. Note: - Pydantic's ``AnyUrl`` normalises percent-encoding and - resolves ``..`` segments during validation, so a value - constructed as ``AnyUrl("file:///a/%2E%2E/b")`` arrives - here as ``file:///b``. The JSON-RPC protocol layer passes - raw ``str`` values and is unaffected, but internal callers - wrapping URIs in ``AnyUrl`` should be aware that security - checks see the already-normalised form. + Pydantic's `AnyUrl` normalises percent-encoding and resolves `..` segments during + validation, so internal callers wrapping URIs in `AnyUrl` reach the security checks + with the already-normalised form. The JSON-RPC layer passes raw `str` and is unaffected. """ uri_str = str(uri) logger.debug("Getting resource", extra={"uri": uri_str}) - # First check concrete resources if resource := self._resources.get(uri_str): return resource - # Then check templates for template in self._templates.values(): try: params = template.matches(uri_str) diff --git a/src/mcp/server/mcpserver/resources/templates.py b/src/mcp/server/mcpserver/resources/templates.py index f78b5ec666..83db305be4 100644 --- a/src/mcp/server/mcpserver/resources/templates.py +++ b/src/mcp/server/mcpserver/resources/templates.py @@ -31,45 +31,33 @@ class ResourceSecurity: """Security policy applied to extracted resource template parameters. - These checks run after :meth:`~mcp.shared.uri_template.UriTemplate.match` - has extracted and decoded parameter values. They catch path-traversal - and absolute-path injection regardless of how the value was encoded in - the URI (literal, ``%2F``, ``%5C``, ``%2E%2E``). + Checks run after `UriTemplate.match` has decoded parameter values, so they catch + traversal and absolute-path injection however the URI encoded them (`%2F`, `%5C`, `%2E%2E`). - Example:: + Example (opt out for a parameter that legitimately contains `..`): - # Opt out for a parameter that legitimately contains .. - @mcp.resource( - "git://diff/{+range}", - security=ResourceSecurity(exempt_params={"range"}), - ) + @mcp.resource("git://diff/{+range}", security=ResourceSecurity(exempt_params={"range"})) def git_diff(range: str) -> str: ... """ reject_path_traversal: bool = True - """Reject values containing ``..`` as a path component.""" + """Reject values containing `..` as a path component.""" reject_absolute_paths: bool = True """Reject values that look like absolute filesystem paths.""" reject_null_bytes: bool = True - """Reject values containing NUL (``\\x00``). Null bytes defeat string - comparisons (``"..\\x00" != ".."``) and can cause truncation in C - extensions or subprocess calls.""" + """Reject values containing NUL; null bytes defeat string comparisons and can + cause truncation in C extensions or subprocess calls.""" exempt_params: Set[str] = field(default_factory=frozenset[str]) """Parameter names to skip all checks for.""" def validate(self, params: Mapping[str, str | list[str]]) -> str | None: - """Check all parameter values against the configured policy. - - Args: - params: Extracted template parameters. List values (from - explode variables) are checked element-wise. + """Check parameter values against the policy; list values are checked element-wise. Returns: - The name of the first parameter that fails, or ``None`` if - all values pass. + The name of the first failing parameter, or `None` if all values pass. """ for name, value in params.items(): if name in self.exempt_params: @@ -90,11 +78,10 @@ def validate(self, params: Mapping[str, str | list[str]]) -> str | None: class ResourceSecurityError(ValueError): - """Raised when an extracted parameter fails :class:`ResourceSecurity` checks. + """Raised when an extracted parameter fails `ResourceSecurity` checks. - Distinct from a simple ``None`` non-match so that template - iteration can stop at the first security rejection rather than - falling through to a later, possibly more permissive, template. + Distinct from a `None` non-match so template iteration stops at the first + security rejection instead of falling through to a more permissive template. """ def __init__(self, template: str, param: str) -> None: @@ -138,8 +125,7 @@ def from_function( """Create a template from a function. Raises: - InvalidUriTemplate: If ``uri_template`` is malformed or uses - unsupported RFC 6570 features. + InvalidUriTemplate: If `uri_template` is malformed or uses unsupported RFC 6570 features. """ func_name = name or fn.__name__ if func_name == "": @@ -147,18 +133,16 @@ def from_function( parsed = UriTemplate.parse(uri_template) - # Find context parameter if it exists if context_kwarg is None: # pragma: no branch context_kwarg = find_context_parameter(fn) - # Get schema from func_metadata, excluding context parameter func_arg_metadata = func_metadata( fn, skip_names=[context_kwarg] if context_kwarg is not None else [], ) parameters = func_arg_metadata.arg_model.model_json_schema() - # ensure the arguments are properly cast + # validate_call coerces arguments to their annotated types fn = validate_call(fn) return cls( @@ -180,20 +164,13 @@ def from_function( def matches(self, uri: str) -> dict[str, str | list[str]] | None: """Check if a URI matches this template and extract parameters. - Delegates to :meth:`UriTemplate.match` for RFC 6570 extraction, - then applies this template's :class:`ResourceSecurity` policy - (path traversal, absolute paths). - Returns: - Extracted parameters on success, or ``None`` if the URI - doesn't match the template. + Extracted parameters, or `None` if the URI doesn't match. Raises: - ResourceSecurityError: If the URI matches but an extracted - parameter fails security validation. Raising (rather - than returning ``None``) prevents the resource manager - from silently falling through to a later, possibly more - permissive, template. + ResourceSecurityError: If a matched parameter fails security validation. Raising + (not returning `None`) prevents the resource manager from silently falling + through to a later, possibly more permissive, template. """ params = self.parsed_template.match(uri) if params is None: @@ -215,7 +192,6 @@ async def create_resource( ResourceError: If creating the resource fails. """ try: - # Add context to params if needed params = inject_context(self.fn, params, context, self.context_kwarg) fn = self.fn @@ -233,7 +209,7 @@ async def create_resource( icons=self.icons, annotations=self.annotations, meta=self.meta, - fn=lambda: result, # Capture result in closure + fn=lambda: result, ) except ResourceError: raise diff --git a/src/mcp/server/mcpserver/resources/types.py b/src/mcp/server/mcpserver/resources/types.py index e295e21e02..b7ec95f4a3 100644 --- a/src/mcp/server/mcpserver/resources/types.py +++ b/src/mcp/server/mcpserver/resources/types.py @@ -40,16 +40,9 @@ async def read(self) -> bytes: class FunctionResource(Resource): - """A resource that defers data loading by wrapping a function. + """A resource that defers loading by calling `fn` only when read. - The function is only called when the resource is read, allowing for lazy loading - of potentially expensive data. This is particularly useful when listing resources, - as the function won't be called until the resource is actually accessed. - - The function can return: - - str for text content (default) - - bytes for binary content - - other types will be converted to JSON + `fn` may return str (text), bytes (binary), or any other type, which is serialized to JSON. """ fn: Callable[[], Any] = Field(exclude=True) @@ -92,7 +85,6 @@ def from_function( if func_name == "": # pragma: no cover raise ValueError("You must provide a name for lambda functions") - # ensure the arguments are properly cast fn = validate_call(fn) return cls( @@ -109,10 +101,7 @@ def from_function( class FileResource(Resource): - """A resource that reads from a file. - - Set is_binary=True to read the file as binary data instead of text. - """ + """A resource that reads from a file.""" path: Path = Field(description="Path to the file") is_binary: bool = Field( @@ -127,7 +116,6 @@ class FileResource(Resource): @pydantic.field_validator("path") @classmethod def validate_absolute_path(cls, path: Path) -> Path: - """Ensure path is absolute.""" if not path.is_absolute(): raise ValueError("Path must be absolute") return path @@ -135,7 +123,6 @@ def validate_absolute_path(cls, path: Path) -> Path: @pydantic.field_validator("is_binary") @classmethod def set_binary_from_mime_type(cls, is_binary: bool, info: ValidationInfo) -> bool: - """Set is_binary based on mime_type if not explicitly set.""" if is_binary: return True mime_type = info.data.get("mime_type", "text/plain") @@ -176,7 +163,6 @@ class DirectoryResource(Resource): @pydantic.field_validator("path") @classmethod def validate_absolute_path(cls, path: Path) -> Path: # pragma: no cover - """Ensure path is absolute.""" if not path.is_absolute(): raise ValueError("Path must be absolute") return path @@ -195,7 +181,7 @@ def list_files(self) -> list[Path]: # pragma: no cover except Exception as e: raise ValueError(f"Error listing directory {self.path}: {e}") - async def read(self) -> str: # Always returns JSON string # pragma: no cover + async def read(self) -> str: # pragma: no cover """Read the directory listing.""" try: files = await anyio.to_thread.run_sync(self.list_files) diff --git a/src/mcp/server/mcpserver/server.py b/src/mcp/server/mcpserver/server.py index 888eae6541..3b7b0bb265 100644 --- a/src/mcp/server/mcpserver/server.py +++ b/src/mcp/server/mcpserver/server.py @@ -97,11 +97,7 @@ class Settings(BaseSettings, Generic[LifespanResultT]): - """MCPServer settings. - - All settings can be configured via environment variables with the prefix MCP_. - For example, MCP_DEBUG=true will set debug=True. - """ + """MCPServer settings, configurable via environment variables with the `MCP_` prefix (e.g. `MCP_DEBUG=true`).""" model_config = SettingsConfigDict( env_prefix="MCP_", @@ -111,17 +107,10 @@ class Settings(BaseSettings, Generic[LifespanResultT]): extra="ignore", ) - # Server settings debug: bool log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] - - # resource settings warn_on_duplicate_resources: bool - - # tool settings warn_on_duplicate_tools: bool - - # prompt settings warn_on_duplicate_prompts: bool dependencies: list[str] @@ -206,11 +195,10 @@ def __init__( on_list_resource_templates=self._handle_list_resource_templates, on_list_prompts=self._handle_list_prompts, on_get_prompt=self._handle_get_prompt, - # TODO(Marcelo): It seems there's a type mismatch between the lifespan type from an MCPServer and Server. - # We need to create a Lifespan type that is a generic on the server type, like Starlette does. + # TODO(Marcelo): MCPServer/Server lifespan types mismatch; needs a Lifespan generic over the + # server type, like Starlette's. lifespan=(lifespan_wrapper(self, self.settings.lifespan) if self.settings.lifespan else default_lifespan), # type: ignore ) - # Validate auth configuration if self.settings.auth is not None: if auth_server_provider and token_verifier: # pragma: no cover raise ValueError("Cannot specify both auth_server_provider and token_verifier") @@ -222,12 +210,11 @@ def __init__( self._auth_server_provider = auth_server_provider self._token_verifier = token_verifier - # Create token verifier from provider if needed (backwards compatibility) + # Backwards compatibility: derive a token verifier from the auth server provider. if auth_server_provider and not token_verifier: self._token_verifier = ProviderTokenVerifier(auth_server_provider) self._custom_starlette_routes: list[Route] = [] - # Configure logging configure_logging(self.settings.log_level) self._extensions: list[Extension] = [] @@ -265,23 +252,18 @@ def version(self) -> str | None: @property def session_manager(self) -> StreamableHTTPSessionManager: - """Get the StreamableHTTP session manager. - - This is exposed to enable advanced use cases like mounting multiple - MCPServer instances in a single FastAPI application. + """The StreamableHTTP session manager, exposed for advanced uses like mounting multiple MCPServer instances. Raises: - RuntimeError: If called before streamable_http_app() has been called. + RuntimeError: If called before `streamable_http_app()`. """ return self._lowlevel_server.session_manager def _apply_extension(self, extension: Extension) -> None: - """Apply one opt-in extension's contributions through the public surface. + """Register an extension's tools/resources/methods and advertise its settings in capabilities. - Registers its tools/resources/methods and advertises its settings under - `ServerCapabilities.extensions[extension.identifier]`. Extensions are fixed - at construction, so this is private; the `tools/call` interceptor is - composed once afterwards by `_install_extension_interceptor`. + Extensions are fixed at construction, so this is private; the `tools/call` + interceptor is composed once afterwards by `_install_extension_interceptor`. """ identifier = getattr(extension, "identifier", None) validate_extension_identifier(identifier, owner=type(extension).__name__) @@ -305,11 +287,7 @@ def _apply_extension(self, extension: Extension) -> None: self._lowlevel_server.extensions[extension.identifier] = extension.settings() def _install_extension_interceptor(self) -> None: - """Compose every extension's `tools/call` interceptor into one middleware. - - Installed only when at least one extension overrides `intercept_tool_call`, - so a server with purely additive extensions adds no middleware. - """ + """Compose extension `tools/call` interceptors into one middleware; no-op when none override it.""" if any(type(e).intercept_tool_call is not Extension.intercept_tool_call for e in self._extensions): self._lowlevel_server.middleware.append(compose_tool_call_interceptor(self._extensions)) @@ -348,12 +326,7 @@ def run( transport: Literal["stdio", "sse", "streamable-http"] = "stdio", **kwargs: Any, ) -> None: - """Run the MCP server. Note this is a synchronous function. - - Args: - transport: Transport protocol to use ("stdio", "sse", or "streamable-http") - **kwargs: Transport-specific options (see overloads for details) - """ + """Run the MCP server (synchronously); transport-specific options are documented on the overloads.""" TRANSPORTS = Literal["stdio", "sse", "streamable-http"] if transport not in TRANSPORTS.__args__: # type: ignore # pragma: no cover raise ValueError(f"Unknown transport: {transport}") @@ -479,6 +452,7 @@ async def list_resources(self) -> list[MCPResource]: ] async def list_resource_templates(self) -> list[MCPResourceTemplate]: + """List all available resource templates.""" templates = self._resource_manager.list_templates() return [ MCPResourceTemplate( @@ -512,7 +486,7 @@ async def read_resource( return [ReadResourceContents(content=content, mime_type=resource.mime_type, meta=resource.meta)] except Exception as exc: logger.exception(f"Error getting resource {uri}") - # If an exception happens when reading the resource, we should not leak the exception to the client. + # Don't leak the underlying exception to the client. raise ResourceError(f"Error reading resource {uri}") from exc def add_tool( @@ -528,21 +502,13 @@ def add_tool( ) -> None: """Add a tool to the server. - The tool function can optionally request a Context object by adding a parameter - with the Context type annotation. See the @tool decorator for examples. + The function may request a `Context` object by annotating a parameter with the + Context type; see the `tool` decorator for examples. Args: - fn: The function to register as a tool - name: Optional name for the tool (defaults to function name) - title: Optional human-readable title for the tool - description: Optional description of what the tool does - annotations: Optional ToolAnnotations providing additional tool information - icons: Optional list of icons for the tool - meta: Optional metadata dictionary for the tool - structured_output: Controls whether the tool's output is structured or unstructured - - If None, auto-detects based on the function's return type annotation - - If True, creates a structured tool (return type annotation permitting) - - If False, unconditionally creates an unstructured tool + name: Defaults to the function name. + structured_output: None auto-detects from the return annotation; True forces a + structured tool (annotation permitting); False forces an unstructured tool. """ self._tool_manager.add_tool( fn, @@ -558,11 +524,8 @@ def add_tool( def remove_tool(self, name: str) -> None: """Remove a tool from the server by name. - Args: - name: The name of the tool to remove - Raises: - ToolError: If the tool does not exist + ToolError: If the tool does not exist. """ self._tool_manager.remove_tool(name) @@ -578,44 +541,26 @@ def tool( ) -> Callable[[_CallableT], _CallableT]: """Decorator to register a tool. - Tools can optionally request a Context object by adding a parameter with the - Context type annotation. The context provides access to MCP capabilities like - logging, progress reporting, and resource access. + Tools may request a `Context` object by annotating a parameter with the Context + type; it provides MCP capabilities like logging, progress reporting, and resource access. Args: - name: Optional name for the tool (defaults to function name) - title: Optional human-readable title for the tool - description: Optional description of what the tool does - annotations: Optional ToolAnnotations providing additional tool information - icons: Optional list of icons for the tool - meta: Optional metadata dictionary for the tool - structured_output: Controls whether the tool's output is structured or unstructured - - If None, auto-detects based on the function's return type annotation - - If True, creates a structured tool (return type annotation permitting) - - If False, unconditionally creates an unstructured tool + name: Defaults to the function name. + structured_output: None auto-detects from the return annotation; True forces a + structured tool (annotation permitting); False forces an unstructured tool. Example: ```python @server.tool() def my_tool(x: int) -> str: return str(x) - ``` - ```python @server.tool() async def tool_with_context(x: int, ctx: Context) -> str: await ctx.info(f"Processing {x}") return str(x) ``` - - ```python - @server.tool() - async def async_tool(x: int, context: Context) -> str: - await context.report_progress(50, 100) - return str(x) - ``` """ - # Check if user passed function directly instead of calling decorator if callable(name): raise TypeError( "The @tool decorator was used incorrectly. Did you forget to call it? Use @tool() instead of @tool" @@ -639,20 +584,9 @@ def decorator(fn: _CallableT) -> _CallableT: def completion(self): """Decorator to register a completion handler. - The completion handler receives: - - ref: PromptReference or ResourceTemplateReference - - argument: CompletionArgument with name and partial value - - context: Optional CompletionContext with previously resolved arguments - - Example: - ```python - @mcp.completion() - async def handle_completion(ref, argument, context): - if isinstance(ref, ResourceTemplateReference): - # Return completions based on ref, argument, and context - return Completion(values=["option1", "option2"]) - return None - ``` + The handler receives the reference (`PromptReference` or `ResourceTemplateReference`), + the `CompletionArgument` with name and partial value, and an optional + `CompletionContext` with previously resolved arguments; it returns a `Completion` or None. """ def decorator(func: _CallableT) -> _CallableT: @@ -670,11 +604,7 @@ async def handler( return decorator def add_resource(self, resource: Resource) -> None: - """Add a resource to the server. - - Args: - resource: A Resource instance to add - """ + """Add a resource to the server.""" self._resource_manager.add_resource(resource) def resource( @@ -692,28 +622,15 @@ def resource( ) -> Callable[[_CallableT], _CallableT]: """Decorator to register a function as a resource. - The function will be called when the resource is read to generate its content. - The function can return: - - str for text content - - bytes for binary content - - other types will be converted to JSON + The function is called when the resource is read; it may return str (text), + bytes (binary), or any other type, which is converted to JSON. - If the URI contains parameters (e.g. "resource://{param}"), it is - registered as a template resource. Otherwise it is registered as a - static resource; function parameters on a static URI raise an error. + A URI containing parameters (e.g. "resource://{param}") registers a template + resource; otherwise the resource is static and the function must take no parameters. Args: - uri: URI for the resource (e.g. "resource://my-resource" or "resource://{param}") - name: Optional name for the resource - title: Optional human-readable title for the resource - description: Optional description of the resource - mime_type: Optional MIME type for the resource - icons: Optional list of icons for the resource - annotations: Optional annotations for the resource - meta: Optional metadata dictionary for the resource - security: Path-safety policy for extracted template parameters. - Defaults to the server's ``resource_security`` setting. - Only applies to template resources. + security: Path-safety policy for extracted template parameters. Defaults to + the server's `resource_security` setting. Template resources only. Example: ```python @@ -721,15 +638,6 @@ def resource( def get_data() -> str: return "Hello, world!" - @server.resource("resource://my-resource") - async def get_data() -> str: - data = await fetch_data() - return f"Hello, world! {data}" - - @server.resource("resource://{city}/weather") - def get_weather(city: str) -> str: - return f"Weather for {city}" - @server.resource("resource://{city}/weather") async def get_weather(city: str) -> str: data = await fetch_weather(city) @@ -737,24 +645,18 @@ async def get_weather(city: str) -> str: ``` Raises: - InvalidUriTemplate: If ``uri`` is not a valid RFC 6570 template. - ValueError: If URI template parameters don't match the - function's parameters, or if a parameter bound to a - ``{?...}``/``{&...}`` query variable has no default - (the client may omit it). - TypeError: If the decorator is applied without being called - (``@resource`` instead of ``@resource("uri")``). + InvalidUriTemplate: If `uri` is not a valid RFC 6570 template. + ValueError: If URI template parameters don't match the function's parameters, + or a parameter bound to a `{?...}`/`{&...}` query variable has no default. + TypeError: If the decorator is applied without being called. """ - # Check if user passed function directly instead of calling decorator if callable(uri): raise TypeError( "The @resource decorator was used incorrectly. " "Did you forget to call it? Use @resource('uri') instead of @resource" ) - # Parse once, early β€” surfaces malformed-template errors at - # decoration time with a clear position, and gives us correct - # variable names for all RFC 6570 operators. + # Parse at decoration time: malformed templates fail early, and variable names cover all RFC 6570 operators. parsed = UriTemplate.parse(uri) uri_params = set(parsed.variable_names) @@ -763,20 +665,15 @@ def decorator(fn: _CallableT) -> _CallableT: context_param = find_context_parameter(fn) func_params = {p for p in sig.parameters.keys() if p != context_param} - # Template/static is decided purely by the URI: variables - # present means template, none means static. if uri_params: if uri_params != func_params: raise ValueError( f"Mismatch between URI parameters {uri_params} and function parameters {func_params}" ) - # A {?...}/{&...} query variable is optional on the wire: - # match() omits it from the extracted parameters when the - # client leaves it out of the URI. The handler parameter - # bound to it must therefore have a Python default; without - # one, the author only finds out on the first request that - # omits it, as an opaque internal error. + # A {?...}/{&...} query variable may be omitted by the client, so the bound handler + # parameter needs a Python default; without one, the author only finds out via an + # opaque internal error on the first request that omits it. missing_defaults = sorted( name for name in parsed.query_variable_names @@ -790,7 +687,6 @@ def decorator(fn: _CallableT) -> _CallableT: f"default." ) - # Register as template self._resource_manager.add_template( fn=fn, uri_template=uri, @@ -818,7 +714,6 @@ def decorator(fn: _CallableT) -> _CallableT: f"Add a template variable to the URI or remove the " f"Context parameter." ) - # Register as regular resource resource = FunctionResource.from_function( fn=fn, uri=uri, @@ -836,11 +731,7 @@ def decorator(fn: _CallableT) -> _CallableT: return decorator def add_prompt(self, prompt: Prompt) -> None: - """Add a prompt to the server. - - Args: - prompt: A Prompt instance to add - """ + """Add a prompt to the server.""" self._prompt_manager.add_prompt(prompt) def prompt( @@ -853,10 +744,7 @@ def prompt( """Decorator to register a prompt. Args: - name: Optional name for the prompt (defaults to function name) - title: Optional human-readable title for the prompt - description: Optional description of what the prompt does - icons: Optional list of icons for the prompt + name: Defaults to the function name. Example: ```python @@ -869,25 +757,8 @@ def analyze_table(table_name: str) -> list[Message]: "content": f"Analyze this schema:\n{schema}" } ] - - @server.prompt() - async def analyze_file(path: str) -> list[Message]: - content = await read_file(path) - return [ - { - "role": "user", - "content": { - "type": "resource", - "resource": { - "uri": f"file://{path}", - "text": content - } - } - } - ] ``` """ - # Check if user passed function directly instead of calling decorator if callable(name): raise TypeError( "The @prompt decorator was used incorrectly. " @@ -908,23 +779,14 @@ def custom_route( name: str | None = None, include_in_schema: bool = True, ): - """Decorator to register a custom HTTP route on the MCP server. - - Allows adding arbitrary HTTP endpoints outside the standard MCP protocol, - which can be useful for OAuth callbacks, health checks, or admin APIs. - The handler function must be an async function that accepts a Starlette - Request and returns a Response. + """Decorator to register a custom HTTP route (an async `Request -> Response` handler). - Routes using this decorator will not require authorization. It is intended - for uses that are either a part of authorization flows or intended to be - public such as health check endpoints. + Useful for endpoints outside the MCP protocol such as OAuth callbacks, health checks, + or admin APIs. These routes do NOT require authorization β€” they are intended for + endpoints that are part of auth flows or deliberately public. Args: - path: URL path for the route (e.g., "/oauth/callback") - methods: List of HTTP methods to support (e.g., ["GET", "POST"]) - name: Optional name for the route (to reference this route with - Starlette's reverse URL lookup feature) - include_in_schema: Whether to include in OpenAPI schema, defaults to True + name: Optional route name for Starlette's reverse URL lookup. Example: ```python @@ -1035,37 +897,30 @@ def sse_app( sse = SseServerTransport(message_path, security_settings=transport_security) async def handle_sse(scope: Scope, receive: Receive, send: Send): # pragma: no cover - # Add client ID from auth context into request context if available - async with sse.connect_sse(scope, receive, send) as streams: await self._lowlevel_server.run( streams[0], streams[1], self._lowlevel_server.create_initialization_options() ) return Response() - # Create routes routes: list[Route | Mount] = [] middleware: list[Middleware] = [] required_scopes: list[str] = [] - # Set up auth if configured if self.settings.auth: # pragma: no cover required_scopes = self.settings.auth.required_scopes or [] - # Add auth middleware if token verifier is available if self._token_verifier: middleware = [ - # extract auth info from request (but do not require it) + # Extracts auth info from the request but does not require it. Middleware( AuthenticationMiddleware, backend=BearerAuthBackend(self._token_verifier), ), - # Add the auth context middleware to store - # authenticated user in a contextvar + # Stores the authenticated user in a contextvar. Middleware(AuthContextMiddleware), ] - # Add auth endpoints if auth server provider is configured if self._auth_server_provider: from mcp.server.auth.routes import create_auth_routes @@ -1080,17 +935,14 @@ async def handle_sse(scope: Scope, receive: Receive, send: Send): # pragma: no ) ) - # When auth is configured, require authentication if self._token_verifier: # pragma: no cover - # Determine resource metadata URL resource_metadata_url = None if self.settings.auth and self.settings.auth.resource_server_url: from mcp.server.auth.routes import build_resource_metadata_url - # Build compliant metadata URL for WWW-Authenticate header + # Metadata URL for the WWW-Authenticate header. resource_metadata_url = build_resource_metadata_url(self.settings.auth.resource_server_url) - # Auth is enabled, wrap the endpoints with RequireAuthMiddleware routes.append( Route( sse_path, @@ -1105,10 +957,8 @@ async def handle_sse(scope: Scope, receive: Receive, send: Send): # pragma: no ) ) else: - # Auth is disabled, no need for RequireAuthMiddleware - # Since handle_sse is an ASGI app, we need to create a compatible endpoint + # handle_sse is an ASGI app, so wrap it in a request/response endpoint. async def sse_endpoint(request: Request) -> Response: # pragma: no cover - # Convert the Starlette request to ASGI parameters return await handle_sse(request.scope, request.receive, request._send) # type: ignore[reportPrivateUsage] routes.append( @@ -1124,7 +974,6 @@ async def sse_endpoint(request: Request) -> Response: # pragma: no cover app=sse.handle_post_message, ) ) - # Add protected resource metadata endpoint if configured as RS if self.settings.auth and self.settings.auth.resource_server_url: # pragma: no cover from mcp.server.auth.routes import create_protected_resource_routes @@ -1136,10 +985,9 @@ async def sse_endpoint(request: Request) -> Response: # pragma: no cover ) ) - # mount these routes last, so they have the lowest route matching precedence + # Mounted last so custom routes have the lowest matching precedence. routes.extend(self._custom_starlette_routes) - # Create Starlette app with routes and middleware return Starlette(debug=self.settings.debug, routes=routes, middleware=middleware) def streamable_http_app( @@ -1213,7 +1061,7 @@ async def get_prompt( def _version_gated(method: MethodBinding) -> RequestHandler: - """Wrap a method handler so a request at a disallowed protocol version is rejected. + """Wrap a method handler to reject requests at disallowed protocol versions. The low-level `_request_handlers` dict is keyed by method only, so per-version scoping is enforced here rather than at the runner's boundary table. @@ -1230,21 +1078,14 @@ async def gated(ctx: ServerRequestContext[Any, Any], params: Any) -> HandlerResu def require_client_extension(ctx: ServerRequestContext[Any, Any], identifier: str) -> None: - """Assert the connected client declared support for `identifier`. - - Call this from an extension's handler or `intercept_tool_call` before - offering extension-specific behaviour. Raises `MCPError` with the - `-32021` (missing required client capability) code and a - `requiredCapabilities` payload when the client did not declare the - extension, per SEP-2133. + """Assert the connected client declared support for extension `identifier`. - Args: - ctx: The current request context. - identifier: The extension identifier the client must have declared. + Call from an extension's handler or `intercept_tool_call` before offering + extension-specific behaviour (SEP-2133). Raises: - MCPError: With code `MISSING_REQUIRED_CLIENT_CAPABILITY` if the client - did not advertise `identifier`. + MCPError: With code `MISSING_REQUIRED_CLIENT_CAPABILITY` (-32021) and a + `requiredCapabilities` payload if the client did not declare the extension. """ client_params = ctx.session.client_params declared = client_params.capabilities.extensions if client_params else None diff --git a/src/mcp/server/mcpserver/tools/base.py b/src/mcp/server/mcpserver/tools/base.py index 50d28f574b..e69d7d51df 100644 --- a/src/mcp/server/mcpserver/tools/base.py +++ b/src/mcp/server/mcpserver/tools/base.py @@ -129,16 +129,14 @@ async def run( if self.context_kwarg is not None: pass_directly[self.context_kwarg] = context - # Resolvers see the same validated arguments the tool body receives: - # validate once and reuse it, so a `default_factory`/stateful validator - # can't hand a by-name resolver a different value than the body. + # Validate once and reuse: a `default_factory`/stateful validator must not + # hand a by-name resolver a different value than the tool body sees. pre_validated: dict[str, Any] | None = None if self.resolved_params: pre_validated = self.fn_metadata.validate_arguments(arguments) resolved = await resolve_arguments(self.resolved_params, self.resolver_plans, pre_validated, context) if isinstance(resolved, InputRequiredResult): - # A resolver still needs client input (>= 2026-07-28): surface the - # batched questions instead of running the tool body this round. + # A resolver still needs client input (>= 2026-07-28): surface it instead of running the body. return self.fn_metadata.convert_result(resolved) if convert_result else resolved pass_directly |= resolved @@ -155,11 +153,8 @@ async def run( return result except MCPError: - # `MCPError` (and subclasses such as `UrlElicitationRequiredError`) - # carries a JSON-RPC `ErrorData(code, message, data)` and means - # "respond with a protocol error" - re-raise so the kernel surfaces - # it as a top-level JSON-RPC error rather than wrapping it as a - # `CallToolResult(isError=True)` execution failure. + # `MCPError` means "respond with a protocol error": re-raise so the kernel surfaces it + # as a top-level JSON-RPC error, not a `CallToolResult(isError=True)` execution failure. raise except Exception as e: raise ToolError(f"Error executing tool {self.name}: {e}") from e diff --git a/src/mcp/server/mcpserver/tools/tool_manager.py b/src/mcp/server/mcpserver/tools/tool_manager.py index 9e7910ea93..9dfaea34d9 100644 --- a/src/mcp/server/mcpserver/tools/tool_manager.py +++ b/src/mcp/server/mcpserver/tools/tool_manager.py @@ -29,11 +29,9 @@ def __init__(self, warn_on_duplicate_tools: bool = True, *, tools: list[Tool] | self.warn_on_duplicate_tools = warn_on_duplicate_tools def get_tool(self, name: str) -> Tool | None: - """Get tool by name.""" return self._tools.get(name) def list_tools(self) -> list[Tool]: - """List all registered tools.""" return list(self._tools.values()) def add_tool( @@ -47,7 +45,7 @@ def add_tool( meta: dict[str, Any] | None = None, structured_output: bool | None = None, ) -> Tool: - """Add a tool to the server.""" + """Register a tool built from `fn`; if the name is already registered, return the existing tool unchanged.""" tool = Tool.from_function( fn, name=name, @@ -67,7 +65,6 @@ def add_tool( return tool def remove_tool(self, name: str) -> None: - """Remove a tool by name.""" if name not in self._tools: raise ToolError(f"Unknown tool: {name}") del self._tools[name] @@ -79,7 +76,6 @@ async def call_tool( context: Context[LifespanContextT, RequestT], convert_result: bool = False, ) -> Any: - """Call a tool by name with arguments.""" tool = self.get_tool(name) if not tool: raise ToolError(f"Unknown tool: {name}") diff --git a/src/mcp/server/mcpserver/utilities/context_injection.py b/src/mcp/server/mcpserver/utilities/context_injection.py index ac7ab82d05..c2592feed9 100644 --- a/src/mcp/server/mcpserver/utilities/context_injection.py +++ b/src/mcp/server/mcpserver/utilities/context_injection.py @@ -11,31 +11,18 @@ def find_context_parameter(fn: Callable[..., Any]) -> str | None: - """Find the parameter that should receive the Context object. - - Searches through the function's signature to find a parameter - with a Context type annotation. - - Args: - fn: The function to inspect - - Returns: - The name of the context parameter, or None if not found - """ - # Get type hints to properly resolve string annotations + """Find the name of the parameter annotated with a Context type, or None.""" + # get_type_hints (rather than raw annotations) so string annotations resolve try: hints = typing.get_type_hints(fn) except Exception: # pragma: lax no cover - # If we can't resolve type hints, we can't find the context parameter return None - # Check each parameter's type hint for param_name, annotation in hints.items(): - # Handle direct Context type if inspect.isclass(annotation) and issubclass(annotation, Context): return param_name - # Handle generic types like Optional[Context] + # generic annotations like Optional[Context] origin = typing.get_origin(annotation) if origin is not None: args = typing.get_args(annotation) @@ -52,17 +39,7 @@ def inject_context( context: Any | None, context_kwarg: str | None, ) -> dict[str, Any]: - """Inject context into function kwargs if needed. - - Args: - fn: The function that will be called - kwargs: The current keyword arguments - context: The context object to inject (if any) - context_kwarg: The name of the parameter to inject into - - Returns: - Updated kwargs with context injected if applicable - """ + """Return kwargs with `context` added under `context_kwarg` when both are set.""" if context_kwarg is not None and context is not None: return {**kwargs, context_kwarg: context} return kwargs diff --git a/src/mcp/server/mcpserver/utilities/func_metadata.py b/src/mcp/server/mcpserver/utilities/func_metadata.py index be4afb4e9b..86c38e2650 100644 --- a/src/mcp/server/mcpserver/utilities/func_metadata.py +++ b/src/mcp/server/mcpserver/utilities/func_metadata.py @@ -34,13 +34,9 @@ def _is_input_required_type(obj: Any) -> bool: class StrictJsonSchema(GenerateJsonSchema): - """A JSON schema generator that raises exceptions instead of emitting warnings. - - This is used to detect non-serializable types during schema generation. - """ + """JSON schema generator that raises instead of warning, to detect non-serializable types.""" def emit_warning(self, kind: JsonSchemaWarningKind, detail: str) -> None: - # Raise an exception instead of emitting a warning raise ValueError(f"JSON schema warning: {kind} - {detail}") @@ -48,14 +44,10 @@ class ArgModelBase(BaseModel): """A model representing the arguments to a function.""" def model_dump_one_level(self) -> dict[str, Any]: - """Return a dict of the model's fields, one level deep. - - That is, sub-models etc are not dumped - they are kept as Pydantic models. - """ + """Return a dict of the model's fields one level deep; sub-models stay as Pydantic models.""" kwargs: dict[str, Any] = {} for field_name, field_info in self.__class__.model_fields.items(): value = getattr(self, field_name) - # Use the alias if it exists, otherwise use the field name output_name = field_info.alias if field_info.alias else field_name kwargs[output_name] = value return kwargs @@ -70,10 +62,9 @@ class FuncMetadata(BaseModel): wrap_output: bool = False def validate_arguments(self, arguments_to_validate: dict[str, Any]) -> dict[str, Any]: - """Validate raw arguments into a one-level kwargs dict (no function call). + """Validate raw arguments into a one-level kwargs dict, without calling the function. - Used to feed resolver dependency injection the validated tool arguments - before the tool function itself runs. + Feeds resolver dependency injection the validated tool arguments before the tool runs. """ arguments_pre_parsed = self.pre_parse_json(arguments_to_validate) arguments_parsed_model = self.arg_model.model_validate(arguments_pre_parsed) @@ -89,10 +80,8 @@ async def call_fn_with_arg_validation( ) -> Any: """Call the given function with arguments validated and injected. - Arguments are first attempted to be parsed from JSON, then validated against - the argument model, before being passed to the function. Pass `pre_validated` - (the output of `validate_arguments`) to reuse an earlier validation pass - - validating twice can re-run `default_factory`/stateful validators and hand the + Pass `pre_validated` (the output of `validate_arguments`) to reuse an earlier validation + pass - validating twice can re-run `default_factory`/stateful validators and hand the function different values than a caller already observed. """ # Copy so a caller-provided `pre_validated` dict is never mutated in place. @@ -110,16 +99,11 @@ async def call_fn_with_arg_validation( def convert_result(self, result: Any) -> CallToolResult | InputRequiredResult: """Convert a function call result into a `CallToolResult`. - An `InputRequiredResult` is passed through unchanged so the multi-round - flow surfaces on the wire as `resultType: "input_required"` rather than - being JSON-dumped into a text block. - - Note: we build unstructured content here **even though the lowlevel server - tool call handler provides generic backwards compatibility serialization of - structured content**. This is for MCPServer backwards compatibility: we need to - retain MCPServer's ad hoc conversion logic for constructing unstructured output - from function return values, whereas the lowlevel server simply serializes - the structured output. + An `InputRequiredResult` passes through unchanged so the multi-round flow surfaces + on the wire as `resultType: "input_required"` instead of being JSON-dumped into a + text block. Unstructured content is built here rather than left to the lowlevel + server's generic serialization, to retain MCPServer's historical ad hoc conversion + of function return values. """ if isinstance(result, InputRequiredResult): return result @@ -144,22 +128,15 @@ def convert_result(self, result: Any) -> CallToolResult | InputRequiredResult: return CallToolResult(content=unstructured_content, structured_content=structured_content) def pre_parse_json(self, data: dict[str, Any]) -> dict[str, Any]: - """Pre-parse data from JSON. + """Return `data` with string values parsed as JSON where appropriate. - Return a dict with the same keys as input but with values parsed from JSON - if appropriate. - - This is to handle cases like `["a", "b", "c"]` being passed in as JSON inside - a string rather than an actual list. Claude Desktop is prone to this - in fact - it seems incapable of NOT doing this. For sub-models, it tends to pass - dicts (JSON objects) as JSON strings, which can be pre-parsed here. + Handles clients (notably Claude Desktop) that pass lists and sub-model dicts as + JSON inside strings, e.g. `'["a", "b", "c"]'` for a list parameter. """ - new_data = data.copy() # Shallow copy + new_data = data.copy() - # Build a mapping from input keys (including aliases) to field info key_to_field_info: dict[str, FieldInfo] = {} for field_name, field_info in self.arg_model.model_fields.items(): - # Map both the field name and its alias (if any) to the field info key_to_field_info[field_name] = field_info if field_info.alias: key_to_field_info[field_info.alias] = field_info @@ -173,11 +150,9 @@ def pre_parse_json(self, data: dict[str, Any]) -> dict[str, Any]: try: pre_parsed = json.loads(data_value) except json.JSONDecodeError: - continue # Not JSON - skip + continue if isinstance(pre_parsed, str | int | float): - # This is likely that the raw value is e.g. `"hello"` which we - # Should really be parsed as '"hello"' in Python - but if we parse - # it as JSON it'll turn into just 'hello'. So we skip it. + # A raw value like `"hello"` would lose its quotes if parsed as JSON, so skip it. continue new_data[data_key] = pre_parsed assert new_data.keys() == data.keys() @@ -193,47 +168,20 @@ def func_metadata( skip_names: Sequence[str] = (), structured_output: bool | None = None, ) -> FuncMetadata: - """Given a function, return metadata including a Pydantic model representing its signature. + """Return metadata for `func`: an argument model, plus an output model/schema when structured. - The use case for this is - ``` - meta = func_metadata(func) - validated_args = meta.arg_model.model_validate(some_raw_data_dict) - return func(**validated_args.model_dump_one_level()) - ``` - - **critically** it also provides a pre-parse helper to attempt to parse things from - JSON. + For structured output, BaseModel return types are used directly; TypedDicts, dataclasses, + and other annotated classes are converted to Pydantic models; primitives and generic types + (list, dict, Union, None, etc.) are wrapped in a model with a `result` field (`wrap_output=True`). Args: - func: The function to convert to a Pydantic model - skip_names: A list of parameter names to skip. These will not be included in - the model. - structured_output: Controls whether the tool's output is structured or unstructured - - If None, auto-detects based on the function's return type annotation - - If True, creates a structured tool (return type annotation permitting) - - If False, unconditionally creates an unstructured tool - - If structured, creates a Pydantic model for the function's result based on its annotation. - Supports various return types: - - BaseModel subclasses (used directly) - - Primitive types (str, int, float, bool, bytes, None) - wrapped in a - model with a 'result' field - - TypedDict - converted to a Pydantic model with same fields - - Dataclasses and other annotated classes - converted to Pydantic models - - Generic types (list, dict, Union, etc.) - wrapped in a model with a 'result' field - - Returns: - A FuncMetadata object containing: - - arg_model: A Pydantic model representing the function's arguments - - output_model: A Pydantic model for the return type if the output is structured - - wrap_output: Whether the function result needs to be wrapped in `{"result": ...}` for structured output. + skip_names: Parameter names to exclude from the argument model. + structured_output: None auto-detects from the return annotation, True requires a + serializable return annotation, False unconditionally disables structured output. """ try: sig = inspect.signature(func, eval_str=True) except NameError as e: # pragma: no cover - # This raise could perhaps be skipped, and we (MCPServer) just call - # model_rebuild right before using it 🀷 raise InvalidSignature(f"Unable to evaluate type annotations for callable {func.__name__!r}") from e params = sig.parameters dynamic_pydantic_model_params: dict[str, Any] = {} @@ -250,12 +198,9 @@ def func_metadata( if param.annotation is inspect.Parameter.empty: field_metadata.append(WithJsonSchema({"title": param.name, "type": "string"})) - # Check if the parameter name conflicts with BaseModel attributes - # This is necessary because Pydantic warns about shadowing parent attributes + # Alias params that shadow BaseModel attributes, to avoid Pydantic's shadowing warning. if hasattr(BaseModel, field_name) and callable(getattr(BaseModel, field_name)): - # Use an alias to avoid the shadowing warning field_kwargs["alias"] = field_name - # Use a prefixed field name field_name = f"field_{field_name}" if param.default is not inspect.Parameter.empty: @@ -275,8 +220,6 @@ def func_metadata( if structured_output is False: return FuncMetadata(arg_model=arguments_model) - # set up structured output support based on return type annotation - if sig.return_annotation is inspect.Parameter.empty and structured_output is True: raise InvalidSignature(f"Function {func.__name__}: return annotation required for structured output") @@ -287,32 +230,28 @@ def func_metadata( return_type_expr = inspected_return_ann.type - # `AnnotationSource.FUNCTION` allows no type qualifier to be used, so `return_type_expr` is guaranteed to *not* be - # unknown (i.e. a bare `Final`). + # `AnnotationSource.FUNCTION` forbids type qualifiers, so the type is never UNKNOWN (a bare `Final`). assert return_type_expr is not UNKNOWN if _is_input_required_type(return_type_expr): # A tool annotated to return only InputRequiredResult never produces structured content. return FuncMetadata(arg_model=arguments_model) - # The annotation fed to schema derivation. Starts as the raw return annotation (preserving any - # Annotated[...] wrapper) and is narrowed below if InputRequiredResult arms are stripped. + # The annotation fed to schema derivation; narrowed below if InputRequiredResult arms are stripped. effective_annotation: Any = sig.return_annotation if is_union_origin(get_origin(return_type_expr)): args = get_args(return_type_expr) - # InputRequiredResult is a control-flow signal, not data: strip it so the residual arms - # drive schema derivation. convert_result short-circuits on an InputRequiredResult instance - # before output validation, so the schema only ever sees the data arms at runtime. + # InputRequiredResult is a control-flow signal, not data: strip it so the residual arms drive + # schema derivation. convert_result short-circuits on instances before output validation. residual = tuple(a for a in args if not _is_input_required_type(a)) if not residual: return FuncMetadata(arg_model=arguments_model) if len(residual) != len(args): # PEP 604 has no syntax for "union of a runtime tuple"; Union[...] is the only spelling. effective_annotation = residual[0] if len(residual) == 1 else Union[residual] # noqa: UP007 - # Re-normalize so the residual is processed exactly as if it had been the declared - # return annotation: unwraps a top-level Annotated[...] arm and re-derives metadata, - # so the CallToolResult/BaseModel/TypedDict dispatch below sees the bare type. + # Re-inspect so the residual is processed exactly as if declared: unwraps a top-level + # Annotated[...] arm so the dispatch below sees the bare type. inspected_return_ann = inspect_annotation(effective_annotation, annotation_source=AnnotationSource.FUNCTION) return_type_expr = inspected_return_ann.type if len(residual) > 1 and any( @@ -324,21 +263,18 @@ def func_metadata( ) original_annotation: Any - # if the typehint is CallToolResult, the user either intends to return without validation - # or they provided validation as Annotated metadata + # A CallToolResult hint means return-without-validation, unless validation was provided as + # Annotated metadata. if isinstance(return_type_expr, type) and issubclass(return_type_expr, CallToolResult): if inspected_return_ann.metadata: return_type_expr = inspected_return_ann.metadata[0] if len(inspected_return_ann.metadata) >= 2: - # Reconstruct the original annotation, by preserving the remaining metadata, - # i.e. from `Annotated[CallToolResult, ReturnType, Gt(1)]` to - # `Annotated[ReturnType, Gt(1)]`: + # Preserve remaining metadata: Annotated[CallToolResult, ReturnType, Gt(1)] -> + # Annotated[ReturnType, Gt(1)]. original_annotation = Annotated[ (return_type_expr, *inspected_return_ann.metadata[1:]) ] # pragma: no cover else: - # We only had `Annotated[CallToolResult, ReturnType]`, treat the original annotation - # as being `ReturnType`: original_annotation = return_type_expr else: return FuncMetadata(arg_model=arguments_model) @@ -350,7 +286,6 @@ def func_metadata( ) if output_model is None and structured_output is True: - # Model creation failed or produced warnings - no structured output raise InvalidSignature( f"Function {func.__name__}: return type {return_type_expr} is not serializable for structured output" ) @@ -368,82 +303,61 @@ def _try_create_model_and_schema( type_expr: Any, func_name: str, ) -> tuple[type[BaseModel] | None, dict[str, Any] | None, bool]: - """Try to create a model and schema for the given annotation without warnings. + """Try to create an output model and schema for the given return annotation. - Args: - original_annotation: The original return annotation (may be wrapped in `Annotated`). - type_expr: The underlying type expression derived from the return annotation - (`Annotated` and type qualifiers were stripped). - func_name: The name of the function. + `type_expr` is `original_annotation` with `Annotated` wrappers and type qualifiers stripped. Returns: - tuple of (model or None, schema or None, wrap_output) - Model and schema are None if warnings occur or creation fails. - wrap_output is True if the result needs to be wrapped in {"result": ...} + (model, schema, wrap_output); model and schema are None if schema generation fails or + warns. wrap_output means the result must be wrapped in `{"result": ...}`. """ model = None wrap_output = False - # First handle special case: None if type_expr is None: model = _create_wrapped_model(func_name, original_annotation) wrap_output = True - # Handle GenericAlias types (list[str], dict[str, int], Union[str, int], etc.) elif isinstance(type_expr, GenericAlias): origin = get_origin(type_expr) - # Special case: dict with string keys can use RootModel if origin is dict: args = get_args(type_expr) if len(args) == 2 and args[0] is str: - # TODO: should we use the original annotation? We are losing any potential `Annotated` - # metadata for Pydantic here: + # TODO: use original_annotation? Any `Annotated` metadata for Pydantic is lost here. model = _create_dict_model(func_name, type_expr) else: - # dict with non-str keys needs wrapping model = _create_wrapped_model(func_name, original_annotation) wrap_output = True else: - # All other generic types need wrapping (list, tuple, Union, Optional, etc.) model = _create_wrapped_model(func_name, original_annotation) wrap_output = True - # Handle regular type objects elif isinstance(type_expr, type): type_annotation = cast(type[Any], type_expr) - # Case 1: BaseModel subclasses (can be used directly) if issubclass(type_annotation, BaseModel): model = type_annotation - # Case 2: TypedDicts: elif is_typeddict(type_annotation): model = _create_model_from_typeddict(type_annotation) - # Case 3: Primitive types that need wrapping elif type_annotation in (str, int, float, bool, bytes, type(None)): model = _create_wrapped_model(func_name, original_annotation) wrap_output = True - # Case 4: Other class types (dataclasses, regular classes with annotations) else: type_hints = get_type_hints(type_annotation) if type_hints: - # Classes with type hints can be converted to Pydantic models model = _create_model_from_class(type_annotation, type_hints) - # Classes without type hints are not serializable - model remains None + # Classes without type hints aren't serializable; model stays None. - # Handle any other types not covered above else: - # This includes typing constructs that aren't GenericAlias in Python 3.10 - # (e.g., Union, Optional in some Python versions) + # Typing constructs that aren't GenericAlias on Python 3.10 (e.g. Union, Optional). model = _create_wrapped_model(func_name, original_annotation) wrap_output = True if model: - # If we successfully created a model, try to get its schema - # Use StrictJsonSchema to raise exceptions instead of warnings try: schema = model.model_json_schema(schema_generator=StrictJsonSchema) except ( @@ -453,12 +367,9 @@ def _try_create_model_and_schema( pydantic_core.SchemaError, pydantic_core.ValidationError, ) as e: - # These are expected errors when a type can't be converted to a Pydantic schema - # PydanticUserError: When Pydantic can't handle the type (e.g. PydanticInvalidForJsonSchema); - # subclasses TypeError on pydantic <2.13 and RuntimeError on pydantic >=2.13 - # ValueError: When there are issues with the type definition (including our custom warnings) - # SchemaError: When Pydantic can't build a schema - # ValidationError: When validation fails + # Expected when a type can't become a Pydantic schema; ValueError includes + # StrictJsonSchema's converted warnings. PydanticUserError subclasses TypeError on + # pydantic <2.13 and RuntimeError on pydantic >=2.13. logger.info(f"Cannot create schema for type {type_expr} in {func_name}: {type(e).__name__}: {e}") return None, None, False @@ -471,14 +382,9 @@ def _try_create_model_and_schema( def _create_model_from_class(cls: type[Any], type_hints: dict[str, Any]) -> type[BaseModel]: - """Create a Pydantic model from an ordinary class. - - The created model will: - - Have the same name as the class - - Have fields with the same names and types as the class's fields - - Include all fields whose type does not include None in the set of required fields + """Create a Pydantic model mirroring an ordinary class's name and type-hinted fields. - Precondition: cls must have type hints (i.e., `type_hints` is non-empty) + Precondition: `type_hints` is non-empty. """ model_fields: dict[str, Any] = {} for field_name, field_type in type_hints.items(): @@ -495,19 +401,14 @@ def _create_model_from_class(cls: type[Any], type_hints: dict[str, Any]) -> type def _create_model_from_typeddict(td_type: type[Any]) -> type[BaseModel]: - """Create a Pydantic model from a TypedDict. - - The created model will have the same name and fields as the TypedDict. - """ + """Create a Pydantic model with the same name and fields as the TypedDict.""" type_hints = get_type_hints(td_type) required_keys = getattr(td_type, "__required_keys__", set(type_hints.keys())) model_fields: dict[str, Any] = {} for field_name, field_type in type_hints.items(): if field_name not in required_keys: - # For optional TypedDict fields, set default=None - # This makes them not required in the Pydantic model - # The model should use exclude_unset=True when dumping to get TypedDict semantics + # Non-required keys default to None; dump with exclude_unset=True for TypedDict semantics. model_fields[field_name] = (field_type, None) else: model_fields[field_name] = field_type @@ -516,10 +417,7 @@ def _create_model_from_typeddict(td_type: type[Any]) -> type[BaseModel]: def _create_wrapped_model(func_name: str, annotation: Any) -> type[BaseModel]: - """Create a model that wraps a type in a 'result' field. - - This is used for primitive types, generic types like list/dict, etc. - """ + """Create a model that wraps a type in a `result` field.""" model_name = f"{func_name}Output" return create_model(model_name, result=annotation) @@ -533,7 +431,6 @@ def _create_dict_model(func_name: str, dict_annotation: Any) -> type[BaseModel]: class DictModel(RootModel[dict_annotation]): pass - # Give it a meaningful name DictModel.__name__ = f"{func_name}DictOutput" DictModel.__qualname__ = f"{func_name}DictOutput" @@ -541,12 +438,10 @@ class DictModel(RootModel[dict_annotation]): def _convert_to_content(result: Any) -> list[ContentBlock]: - """Convert a result to a sequence of content objects. + """Convert a result to a list of content blocks. - Note: This conversion logic comes from previous versions of MCPServer and is being - retained for purposes of backwards compatibility. It produces different unstructured - output than the lowlevel server tool call handler, which just serializes structured - content verbatim. + Retained from previous MCPServer versions for backwards compatibility; produces different + unstructured output than the lowlevel server, which serializes structured content verbatim. """ if result is None: # pragma: no cover return [] diff --git a/src/mcp/server/mcpserver/utilities/logging.py b/src/mcp/server/mcpserver/utilities/logging.py index 04ca38853b..4db7e27106 100644 --- a/src/mcp/server/mcpserver/utilities/logging.py +++ b/src/mcp/server/mcpserver/utilities/logging.py @@ -5,25 +5,14 @@ def get_logger(name: str) -> logging.Logger: - """Get a logger nested under MCP namespace. - - Args: - name: The name of the logger. - - Returns: - A configured logger instance. - """ + """Get a logger nested under MCP namespace.""" return logging.getLogger(name) def configure_logging( level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO", ) -> None: - """Configure logging for MCP. - - Args: - level: The log level to use. - """ + """Configure logging for MCP.""" handlers: list[logging.Handler] = [] try: from rich.console import Console diff --git a/src/mcp/server/mcpserver/utilities/types.py b/src/mcp/server/mcpserver/utilities/types.py index 937a7fa9b6..f8b3b941c6 100644 --- a/src/mcp/server/mcpserver/utilities/types.py +++ b/src/mcp/server/mcpserver/utilities/types.py @@ -26,7 +26,6 @@ def __init__( self._mime_type = self._get_mime_type() def _get_mime_type(self) -> str: - """Get MIME type from format or guess from file extension.""" if self._format: return f"image/{self._format.lower()}" @@ -39,7 +38,7 @@ def _get_mime_type(self) -> str: ".gif": "image/gif", ".webp": "image/webp", }.get(suffix, "application/octet-stream") - return "image/png" # default for raw binary data + return "image/png" def to_image_content(self) -> ImageContent: """Convert to MCP ImageContent.""" @@ -72,7 +71,6 @@ def __init__( self._mime_type = self._get_mime_type() def _get_mime_type(self) -> str: - """Get MIME type from format or guess from file extension.""" if self._format: return f"audio/{self._format.lower()}" @@ -86,7 +84,7 @@ def _get_mime_type(self) -> str: ".aac": "audio/aac", ".m4a": "audio/mp4", }.get(suffix, "application/octet-stream") - return "audio/wav" # default for raw binary data + return "audio/wav" def to_audio_content(self) -> AudioContent: """Convert to MCP AudioContent.""" diff --git a/src/mcp/server/models.py b/src/mcp/server/models.py index 6b129165a1..be9aa4dea8 100644 --- a/src/mcp/server/models.py +++ b/src/mcp/server/models.py @@ -1,6 +1,4 @@ -"""This module provides simplified types to use with the server for managing prompts -and tools. -""" +"""Simplified types for use with the server.""" from mcp_types import Icon, ServerCapabilities from pydantic import BaseModel diff --git a/src/mcp/server/runner.py b/src/mcp/server/runner.py index 4c25a8a5bc..12fbc28b18 100644 --- a/src/mcp/server/runner.py +++ b/src/mcp/server/runner.py @@ -1,14 +1,9 @@ """`ServerRunner` - the per-connection handler kernel. -`ServerRunner` bridges the dispatch layer (`on_request` / `on_notify`, untyped -dicts) and the user's handler layer (typed `Context`, typed params). It is a -pure kernel: it holds a pre-populated `Connection` and reads -`connection.protocol_version` / `connection.outbound` as facts. Driving a -dispatcher loop and tearing down the connection live in the free-function -drivers (`serve_connection`, `serve_loop`, `serve_one`); the entry constructs -the `Connection`, the driver tears it down. - -`ServerRunner` holds a `Server` directly - `Server` is the registry. +Bridges the dispatch layer (`on_request`/`on_notify`, untyped dicts) and the +user's typed handler layer. The free-function drivers (`serve_connection`, +`serve_loop`, `serve_one`) drive the dispatcher and tear down the +`Connection`; the entry constructs it. """ from __future__ import annotations @@ -75,13 +70,11 @@ _INIT_EXEMPT: frozenset[str] = frozenset({"ping"}) _EXIT_STACK_CLOSE_TIMEOUT: float = 5 -"""Bound for `aclose_shielded`'s exit-stack unwind; a hung cleanup callback -must not wedge shutdown.""" +"""Bound for `aclose_shielded`; a hung cleanup callback must not wedge shutdown.""" def _extract_meta(params: Mapping[str, Any] | None) -> RequestParamsMeta | None: - """Lift `_meta` from raw params; `None` when absent or malformed, so - context construction is independent of params validity.""" + """Lift `_meta` from raw params; `None` when absent or malformed, so context construction never fails.""" if not params or "_meta" not in params: return None try: @@ -94,8 +87,7 @@ def _dump_result(result: Any) -> dict[str, Any]: if result is None: return {} if isinstance(result, ErrorData): - # ErrorData is a JSON-RPC error, not a success result. Handler returns - # already raise in `_inner`; this catches middleware returning one. + # Handler returns already raise in `_inner`; this catches middleware returning an ErrorData. raise MCPError.from_error_data(result) if isinstance(result, BaseModel): return result.model_dump(by_alias=True, mode="json", exclude_none=True) @@ -105,13 +97,10 @@ def _dump_result(result: Any) -> dict[str, Any]: async def aclose_shielded(connection: Connection) -> None: - """Unwind ``connection.exit_stack`` under a shielded, bounded scope. + """Unwind `connection.exit_stack` under a shielded, bounded scope. - Called from a driver's ``finally``: the shield lets per-connection cleanup - callbacks run even when the driver itself is being cancelled, the - `_EXIT_STACK_CLOSE_TIMEOUT` bound stops a hung callback wedging shutdown, - and a raising callback is logged-and-swallowed so it never masks the - driver's own exception. + For driver `finally` blocks: cleanup runs even under cancellation, a hung callback cannot + wedge shutdown, and a raising callback is logged so it never masks the driver's own exception. """ with anyio.move_on_after(_EXIT_STACK_CLOSE_TIMEOUT, shield=True) as scope: try: @@ -128,8 +117,7 @@ async def aclose_shielded(connection: Connection) -> None: def _apply_middleware( middleware: ServerMiddleware[Any], call_next: CallNext, ctx: ServerRequestContext[Any, Any] ) -> Awaitable[HandlerResult]: - """Adapt one middleware to the `CallNext` shape: bind `call_next`, take - `ctx` at call time so a rewritten context flows down the chain.""" + """Bind `call_next`; take `ctx` at call time so a rewritten context flows down the chain.""" return middleware(ctx, call_next) @@ -163,63 +151,49 @@ async def _on_request( ctx = self._make_context(dctx, method, params, meta, version) async def _inner(ctx: ServerRequestContext[LifespanT, Any]) -> HandlerResult: - # Read method/params off `ctx` so a middleware that rewrote them via - # `call_next(replace(ctx, ...))` reaches lookup and the handler. + # Read off `ctx` so a middleware rewrite via `call_next(replace(ctx, ...))` takes effect. method, params = ctx.method, ctx.params - # Pinned compat: spec methods are surface-validated before lookup, - # so malformed params are INVALID_PARAMS even with no handler - # registered. Custom methods miss the monolith map and fall through - # to `entry.params_type` exactly as before. + # Pinned compat: spec methods are surface-validated before lookup, so malformed params + # are INVALID_PARAMS even with no handler; custom methods fall through to `entry.params_type`. if method in _methods.SPEC_CLIENT_METHODS: try: _methods.validate_client_request(method, version, params) except KeyError: raise MCPError(code=METHOD_NOT_FOUND, message="Method not found", data=method) from None - # TODO(L29): the 2026-07-28 spec drops the handshake; this branch and - # the gate become a per-version legacy path then. Initialize runs inline - # (read loop parked), so awaiting the peer anywhere on this path deadlocks. + # TODO(L29): the 2026-07-28 spec drops the handshake, making this branch and the gate a + # per-version legacy path. Initialize runs inline (read loop parked); awaiting the peer here deadlocks. if method == "initialize": return self._serialize(method, version, self._handle_initialize(params)) - # Methods without a handler are METHOD_NOT_FOUND regardless of - # initialization state: JSON-RPC 2.0 reserves -32601 for "not - # available on this server", and clients probing a server before - # the handshake key off that code. The init gate below therefore - # only ever applies to methods the server actually serves. + # No handler is METHOD_NOT_FOUND regardless of init state: JSON-RPC 2.0 reserves -32601 and + # pre-handshake probes key off it, so the init gate below only applies to served methods. entry = self.server.get_request_handler(method) if entry is None: raise MCPError(code=METHOD_NOT_FOUND, message="Method not found", data=method) if not self.connection.initialize_accepted and method not in _INIT_EXEMPT: # Pinned compat: the same error shape the union validation produced. raise MCPError(code=INVALID_PARAMS, message="Invalid request parameters", data="") - # Absent params validate as {} (required fields still reject), so - # the handler receives the model with its defaults, never None. + # Absent params validate as {} (required fields still reject): the handler gets defaults, never None. typed_params = entry.params_type.model_validate({} if params is None else params, by_name=False) result = await entry.handler(ctx, typed_params) if isinstance(result, ErrorData): # Raise inside the chain so middleware observes the failure. raise MCPError.from_error_data(result) - # Fill cache hints on the typed result, before the serialize sieve - # decides whether the negotiated version carries the fields at all. - # `input_required` interim results are not `CacheableResult` models, - # so the MRTR carve-out (no hints on them) holds by shape. + # Fill cache hints before the serialize sieve drops version-absent fields; `input_required` + # interim results are not `CacheableResult`, so the MRTR carve-out (no hints) holds by shape. if isinstance(result, CacheableResult) and (hint := self.server.cache_hints.get(method)) is not None: result = apply_cache_hint(result, hint) - # Dump and serialize inside the chain so the OpenTelemetry span (the - # outermost middleware) records a failing handler return shape too. return self._serialize(method, version, result) call = self._compose_server_middleware(_inner) - # `_inner` already produced the wire dict; a middleware that short-circuited - # without `call_next` is trusted to return its own well-formed result. + # `_inner` already produced the wire dict; a middleware that short-circuited without + # `call_next` is trusted to return its own well-formed result. result = _dump_result(await call(ctx)) if method == "initialize": - # Commit only on chain success, so a middleware veto leaves no state. - # Race-free: the read loop is parked until this call returns. - # TODO: this re-reads the wire `params`, so a middleware that rewrote - # `ctx.params` (or `ctx.method`, or short-circuited without `call_next`) - # can leave `connection.protocol_version` out of step with the - # `InitializeResult` `_inner` produced. Resolve when `initialize` becomes - # a built-in handler so commit and result derive from one negotiation. + # Commit only on chain success, so a middleware veto leaves no state. Race-free: the + # read loop is parked until this call returns. + # TODO: re-reads the wire `params`, so a middleware that rewrote `ctx.params`/`ctx.method` or + # short-circuited can desync `connection.protocol_version` from the `InitializeResult`; + # resolve when `initialize` becomes a built-in handler. self.connection.client_params, self.connection.protocol_version = self._negotiate_initialize(params) return result @@ -245,9 +219,8 @@ async def _inner(ctx: ServerRequestContext[LifespanT, Any]) -> None: logger.warning("dropped %r: malformed params", method) return if method == "notifications/initialized": - # Surface validation above already rejected a malformed body, so - # commit; fall through so a registered handler observes an - # initialized connection. + # Surface validation above already rejected a malformed body, so commit; fall + # through so a registered handler observes an initialized connection. self.connection.initialized.set() elif not self.connection.initialize_accepted: logger.debug("dropped %s: received before initialization", method) @@ -268,17 +241,11 @@ async def _inner(ctx: ServerRequestContext[LifespanT, Any]) -> None: try: await call(ctx) except Exception: - # A crashing handler must not cancel the dispatcher's task group; - # middleware saw the raise out of call_next() first. + # A crashing handler must not cancel the dispatcher's task group; middleware saw the raise first. logger.exception("notification handler for %r raised", method) def _compose_server_middleware(self, inner: CallNext) -> CallNext: - """Wrap `inner` in `Server.middleware`, outermost-first. - - Shared by `_on_request` and `_on_notify` so the same middleware chain - observes every inbound message. The composed callable takes the `ctx` - at call time, so a middleware can rewrite it for the rest of the chain. - """ + """Wrap `inner` in `Server.middleware`, outermost-first; one shared chain sees every inbound message.""" call = inner for middleware in reversed(self.server.middleware): call = partial(_apply_middleware, middleware, call) @@ -292,9 +259,8 @@ def _make_context( meta: RequestParamsMeta | None, protocol_version: str, ) -> ServerRequestContext[LifespanT, Any]: - # TODO(L54): remove for Context rework. Reads the SHTTP per-request - # data off the raw `dctx.message_metadata` carrier; replace with the - # per-transport context once that lands. + # TODO(L54): reads SHTTP per-request data off the raw `dctx.message_metadata` carrier; + # replace with the per-transport context once the Context rework lands. md = dctx.message_metadata if isinstance(md, ServerMessageMetadata): request = md.request_context @@ -302,9 +268,9 @@ def _make_context( close_standalone_sse_stream = md.close_standalone_sse_stream else: request = close_sse_stream = close_standalone_sse_stream = None - # Per-request session: `dctx` is the request-scoped channel (auto-threads - # its own request_id on streamable HTTP); the standalone channel is read - # off `connection.outbound`. `related_request_id` on the public API selects. + # Per-request session: `dctx` is the request-scoped channel (auto-threads its request_id on + # streamable HTTP); the standalone channel comes off `connection.outbound`. `related_request_id` + # on the public API selects between them. session = ServerSession(dctx, self.connection) return ServerRequestContext( session=session, @@ -323,9 +289,7 @@ def _make_context( def _serialize(method: str, version: str, result: HandlerResult) -> dict[str, Any]: """Dump a handler result to the wire dict, serializing spec methods. - Runs inside the middleware chain so the OpenTelemetry span observes a - failing return shape (unsupported type, malformed spec result) as an - error rather than closing on a request that the client sees fail. + Runs inside the middleware chain so the OpenTelemetry span observes a failing return shape as an error. """ dumped = _dump_result(result) # TODO(L56): reject resultType values outside {"complete", "input_required"} unless the @@ -336,8 +300,7 @@ def _serialize(method: str, version: str, result: HandlerResult) -> dict[str, An try: return _methods.serialize_server_result(method, version, dumped) except ValidationError: - # Server bug, not client fault. Detail stays in the server log: - # pydantic messages echo the result body. + # Server bug, not client fault; pydantic detail (echoes the result body) stays in the log. logger.exception("handler for %r returned an invalid result", method) raise MCPError(code=INTERNAL_ERROR, message="Handler returned an invalid result") from None @@ -377,12 +340,10 @@ async def serve_connection( init_options: InitializationOptions | None = None, task_status: anyio.abc.TaskStatus[None] = anyio.TASK_STATUS_IGNORED, ) -> None: - """Drive ``dispatcher`` until the underlying channel closes. + """Drive `dispatcher` until the underlying channel closes. - The loop-mode driver: builds the kernel, hands `on_request`/`on_notify` - to `dispatcher.run()`, and tears down `connection.exit_stack` (shielded) - on the way out. The entry constructs the `Connection`; this only consumes - it. + The loop-mode driver: tears down `connection.exit_stack` (shielded) on the way out. + The entry constructs the `Connection`; this only consumes it. """ runner = ServerRunner(server, connection, lifespan_state, init_options=init_options) try: @@ -401,21 +362,18 @@ async def serve_loop( init_options: InitializationOptions | None = None, raise_exceptions: bool = False, ) -> None: - """Drive ``server`` in loop mode over a stream pair until the channel closes. + """Drive `server` in loop mode over a stream pair until the channel closes. - Builds the loop-mode `JSONRPCDispatcher` + `Connection` and hands them to - `serve_connection`, so loop-mode callers share one dispatcher-construction - recipe (notably the `inline_methods={"initialize"}` rule). Callers that own - a lifespan (the streamable-HTTP manager) pass it in; callers that don't - (`Server.run` for stdio/memory) enter the lifespan and then call this. + Builds the loop-mode `JSONRPCDispatcher` + `Connection` for `serve_connection` so loop-mode + callers share one dispatcher-construction recipe. Callers owning a lifespan (the + streamable-HTTP manager) pass its state in; `Server.run` (stdio/memory) enters the lifespan first. """ dispatcher: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher( read_stream, write_stream, raise_handler_exceptions=raise_exceptions, - # Handle `initialize` inline so a client that pipelines it with the - # next request (spec: SHOULD NOT, not MUST NOT) sees the initialized - # state instead of failing the init-gate. + # Handle `initialize` inline so a client that pipelines it with the next request + # (spec: SHOULD NOT, not MUST NOT) sees initialized state instead of failing the gate. inline_methods=frozenset({"initialize"}), ) connection = Connection.for_loop(dispatcher, session_id=session_id) @@ -433,15 +391,12 @@ async def serve_one( connection: Connection, lifespan_state: LifespanT, ) -> dict[str, Any]: - """Handle a single request ``(method, params)`` and return its result dict. + """Handle a single request `(method, params)` and return its result dict. - The single-exchange driver: builds the kernel, runs `on_request` once under - `dctx`, and tears down `connection.exit_stack` (shielded) on the way out. - The entry constructs the (born-ready) `Connection` and the `dctx`; this - only consumes them. + The single-exchange driver: tears down `connection.exit_stack` (shielded) on the way out. + The entry constructs the (born-ready) `Connection` and `dctx`; this only consumes them. - Raises whatever the handler chain raises (`MCPError` / `ValidationError` / - unmapped); callers own the exception-to-wire mapping. + Raises whatever the handler chain raises; callers own the exception-to-wire mapping. """ runner = ServerRunner(server, connection, lifespan_state) try: @@ -453,11 +408,9 @@ async def serve_one( def modern_on_request(server: Server[LifespanT], lifespan_state: LifespanT) -> OnRequest: """Return an `OnRequest` callback that serves each call via `serve_one` with a fresh per-request `Connection`. - Wire this into the server side of a `DirectDispatcher` peer-pair to drive an - in-process server on the modern per-request-envelope path (each request - carries protocol version, client info, and capabilities in `params._meta`; - no `initialize` handshake). Like `serve_one`, this raises whatever the - handler chain raises - the dispatcher owns the exception-to-error mapping. + For the server side of a `DirectDispatcher` peer-pair on the modern per-request-envelope + path (protocol version, client info, and capabilities ride in `params._meta`; no + `initialize` handshake). Like `serve_one`, raises whatever the handler chain raises. """ async def handle( diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index ca62fb9c8e..337d449ee7 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -1,9 +1,6 @@ """`ServerSession`: server-to-client requests and notifications. -A per-request proxy built by the kernel for each inbound request. Exposes the -request-scoped outbound channel and the connection's standalone channel. -Handlers reach it as `ctx.session` and use the typed helpers (`elicit_form`, -`send_log_message`, ...) to call back to the client. +Handlers reach it as `ctx.session` and use the typed helpers to call back to the client. """ from typing import Any, TypeVar, overload @@ -27,13 +24,10 @@ class ServerSession: """Per-request proxy for server-to-client requests and notifications. - Built once per inbound request by the kernel's `_make_context`. Holds two - `Outbound` channels: the request-scoped one (the per-request - `DispatchContext`, which on streamable HTTP routes onto the originating - POST's response stream) and the connection's standalone channel - (`connection.outbound`). `related_request_id` on the public methods is the - selector β€” present means request-scoped, absent means standalone β€” and - never crosses the `Outbound` Protocol. + Built by the kernel per inbound request. Holds two `Outbound` channels: the request-scoped + `DispatchContext` (on streamable HTTP, the originating POST's response stream) and the + connection's standalone channel. `related_request_id` on the public methods selects between + them β€” present means request-scoped, absent standalone β€” and never crosses the `Outbound` Protocol. """ def __init__(self, request_outbound: DispatchContext[Any], connection: Connection) -> None: @@ -47,11 +41,7 @@ def client_params(self) -> types.InitializeRequestParams | None: @property def protocol_version(self) -> str: - """The protocol version this connection speaks. - - Populated at `Connection` construction and overwritten once the - handshake commits on the loop path; never `None`. - """ + """The protocol version this connection speaks (set at construction, updated on handshake; never `None`).""" return self._connection.protocol_version async def send_request( @@ -66,8 +56,7 @@ async def send_request( Raises: MCPError: The peer responded with an error. - NoBackChannelError: The connection has no back-channel for - server-initiated requests (raised by the held `Outbound`). + NoBackChannelError: No back-channel for server-initiated requests. pydantic.ValidationError: The peer's result does not match `result_type`. """ related = metadata.related_request_id if metadata is not None else None @@ -185,31 +174,13 @@ async def create_message( ) -> types.CreateMessageResult | types.CreateMessageResultWithTools: """Send a sampling/create_message request. - Args: - messages: The conversation messages to send. - max_tokens: Maximum number of tokens to generate. - system_prompt: Optional system prompt. - include_context: Optional context inclusion setting. - Should only be set to "thisServer" or "allServers" - if the client has sampling.context capability. - temperature: Optional sampling temperature. - stop_sequences: Optional stop sequences. - metadata: Optional metadata to pass through to the LLM provider. - model_preferences: Optional model selection preferences. - tools: Optional list of tools the LLM can use during sampling. - Requires client to have sampling.tools capability. - tool_choice: Optional control over tool usage behavior. - Requires client to have sampling.tools capability. - related_request_id: Optional ID of a related request. - - Returns: - The sampling result from the client. + `include_context` of "thisServer"/"allServers" requires the client's `sampling.context` + capability; `tools` and `tool_choice` require `sampling.tools`. Raises: - MCPError: If tools are provided but client doesn't support them. - ValueError: If tool_use or tool_result message structure is invalid. - NoBackChannelError: The connection has no back-channel for - server-initiated requests. + MCPError: Tools were provided but the client does not support them. + ValueError: Invalid tool_use or tool_result message structure. + NoBackChannelError: No back-channel for server-initiated requests. """ client_caps = self.client_params.capabilities if self.client_params else None validate_sampling_tools(client_caps, tools, tool_choice) @@ -248,8 +219,7 @@ async def list_roots(self) -> types.ListRootsResult: """Send a roots/list request. Raises: - NoBackChannelError: The connection has no back-channel for - server-initiated requests. + NoBackChannelError: No back-channel for server-initiated requests. """ return await self.send_request( types.ListRootsRequest(), @@ -264,17 +234,7 @@ async def elicit( ) -> types.ElicitResult: """Send a form mode elicitation/create request. - Args: - message: The message to present to the user. - requested_schema: Schema defining the expected response structure. - related_request_id: Optional ID of the request that triggered this elicitation. - - Returns: - The client's response. - - Note: - This method is deprecated in favor of elicit_form(). It remains for - backward compatibility but new code should use elicit_form(). + Deprecated: use `elicit_form()`; kept for backward compatibility. """ return await self.elicit_form(message, requested_schema, related_request_id) @@ -286,17 +246,8 @@ async def elicit_form( ) -> types.ElicitResult: """Send a form mode elicitation/create request. - Args: - message: The message to present to the user. - requested_schema: Schema defining the expected response structure. - related_request_id: Optional ID of the request that triggered this elicitation. - - Returns: - The client's response with form data. - Raises: - NoBackChannelError: The connection has no back-channel for - server-initiated requests. + NoBackChannelError: No back-channel for server-initiated requests. """ return await self.send_request( types.ElicitRequest( @@ -318,21 +269,11 @@ async def elicit_url( ) -> types.ElicitResult: """Send a URL mode elicitation/create request. - This directs the user to an external URL for out-of-band interactions - like OAuth flows, credential collection, or payment processing. - - Args: - message: Human-readable explanation of why the interaction is needed. - url: The URL the user should navigate to. - elicitation_id: Unique identifier for tracking this elicitation. - related_request_id: Optional ID of the request that triggered this elicitation. - - Returns: - The client's response indicating acceptance, decline, or cancellation. + Directs the user to an external URL for out-of-band interactions like OAuth flows, + credential collection, or payment processing. Raises: - NoBackChannelError: The connection has no back-channel for - server-initiated requests. + NoBackChannelError: No back-channel for server-initiated requests. """ return await self.send_request( types.ElicitRequest( @@ -356,10 +297,9 @@ async def send_ping(self) -> types.EmptyResult: async def report_progress(self, progress: float, total: float | None = None, message: str | None = None) -> None: """Report progress for the inbound request this session is scoped to. - A no-op when the caller did not request progress. Dispatcher-agnostic: - on JSON-RPC the held `DispatchContext` emits ``notifications/progress`` - against the caller's token; on the in-process direct dispatcher it - invokes the caller's callback directly. + A no-op when the caller did not request progress. On JSON-RPC this emits + `notifications/progress` against the caller's token; the in-process direct + dispatcher invokes the caller's callback directly. """ await self._request_outbound.progress(progress, total, message) @@ -403,13 +343,8 @@ async def send_elicit_complete( ) -> None: """Send an elicitation completion notification. - This should be sent when a URL mode elicitation has been completed - out-of-band to inform the client that it may retry any requests - that were waiting for this elicitation. - - Args: - elicitation_id: The unique identifier of the completed elicitation - related_request_id: Optional ID of the request that triggered this notification + Sent when a URL mode elicitation completes out-of-band, telling the client it + may retry requests that were waiting on it. """ await self.send_notification( types.ElicitCompleteNotification( diff --git a/src/mcp/server/sse.py b/src/mcp/server/sse.py index 4d02fc4a73..7478a6a647 100644 --- a/src/mcp/server/sse.py +++ b/src/mcp/server/sse.py @@ -1,39 +1,8 @@ -"""SSE Server Transport Module - -This module implements a Server-Sent Events (SSE) transport layer for MCP servers. - -Example: - ```python - # Create an SSE transport at an endpoint - sse = SseServerTransport("/messages/") - - # Create Starlette routes for SSE and message handling - routes = [ - Route("/sse", endpoint=handle_sse, methods=["GET"]), - Mount("/messages/", app=sse.handle_post_message), - ] - - # Define handler functions - async def handle_sse(request): - async with sse.connect_sse( - request.scope, request.receive, request._send - ) as streams: - await app.run( - streams[0], streams[1], app.create_initialization_options() - ) - # Return empty response to avoid NoneType error - return Response() - - # Create and run Starlette app - starlette_app = Starlette(routes=routes) - uvicorn.run(starlette_app, host="127.0.0.1", port=port) - ``` +"""Server-Sent Events (SSE) transport for MCP servers; see `SseServerTransport`. -Note: The handle_sse function must return a Response to avoid a -"TypeError: 'NoneType' object is not callable" error when client disconnects. The example above returns -an empty Response() after the SSE connection ends to fix this. - -See SseServerTransport class documentation for more details. +Note: the route handler wrapping `connect_sse` must return a Response (e.g. an empty +`Response()`), otherwise Starlette raises "TypeError: 'NoneType' object is not +callable" when the client disconnects. """ import logging @@ -62,55 +31,41 @@ async def handle_sse(request): class SseServerTransport: - """SSE server transport for MCP. This class provides two ASGI applications, - suitable for use with a framework like Starlette and a server like Hypercorn: - - 1. connect_sse() is an ASGI application which receives incoming GET requests, - and sets up a new SSE stream to send server messages to the client. - 2. handle_post_message() is an ASGI application which receives incoming POST - requests, which should contain client messages that link to a - previously-established SSE session. + """SSE server transport for MCP, exposing two ASGI applications. + + `connect_sse` handles GET requests by opening an SSE stream for server-to-client + messages; `handle_post_message` handles POST requests carrying client messages for + a previously established session. """ _endpoint: str _read_stream_writers: dict[UUID, ContextSendStream[SessionMessage | Exception]] - # Identity of the credential that created each session; requests for a - # session must present the same credential. + # Credential that created each session; requests must present the same credential. _session_owners: dict[UUID, AuthorizationContext] _security: TransportSecurityMiddleware def __init__(self, endpoint: str, security_settings: TransportSecuritySettings | None = None) -> None: - """Creates a new SSE server transport, which will direct the client to POST - messages to the relative path given. + """Create an SSE server transport that directs clients to POST messages to `endpoint`. + + The endpoint must be a relative path (e.g. "/messages/"): relative paths keep + clients on the origin that established the SSE connection and let the server be + mounted at any path. Args: - endpoint: A relative path where messages should be posted - (e.g., "/messages/"). - security_settings: Optional security settings for DNS rebinding protection. - - Note: - We use relative paths instead of full URLs for several reasons: - 1. Security: Prevents cross-origin requests by ensuring clients only connect - to the same origin they established the SSE connection with - 2. Flexibility: The server can be mounted at any path without needing to - know its full URL - 3. Portability: The same endpoint configuration works across different - environments (development, staging, production) + security_settings: Settings for DNS rebinding protection. Raises: - ValueError: If the endpoint is a full URL instead of a relative path + ValueError: If `endpoint` is a full URL instead of a relative path. """ super().__init__() - # Validate that endpoint is a relative path and not a full URL if "://" in endpoint or endpoint.startswith("//") or "?" in endpoint or "#" in endpoint: raise ValueError( f"Given endpoint: {endpoint} is not a relative path (e.g., '/messages/'), " "expecting a relative path (e.g., '/messages/')." ) - # Ensure endpoint starts with a forward slash if not endpoint.startswith("/"): endpoint = "/" + endpoint @@ -126,7 +81,7 @@ async def connect_sse(self, scope: Scope, receive: Receive, send: Send): logger.error("connect_sse received non-HTTP request") raise ValueError("connect_sse can only handle HTTP requests") - # Validate request headers for DNS rebinding protection + # DNS rebinding protection request = Request(scope, receive) error_response = await self._security.validate_request(request, is_post=False) if error_response: @@ -145,19 +100,10 @@ async def connect_sse(self, scope: Scope, receive: Receive, send: Send): self._read_stream_writers[session_id] = read_stream_writer logger.debug(f"Created new session with ID: {session_id}") - # Determine the full path for the message endpoint to be sent to the client. - # scope['root_path'] is the prefix where the current Starlette app - # instance is mounted. - # e.g., "" if top-level, or "/api_prefix" if mounted under "/api_prefix". + # scope["root_path"] is this app's mount prefix; prepending it to self._endpoint + # gives the absolute POST path (e.g. "/api_prefix" + "/messages"). root_path = scope.get("root_path", "") - - # self._endpoint is the path *within* this app, e.g., "/messages". - # Concatenating them gives the full absolute path from the server root. - # e.g., "" + "/messages" -> "/messages" - # e.g., "/api_prefix" + "/messages" -> "/api_prefix/messages" full_message_path_for_client = root_path.rstrip("/") + self._endpoint - - # This is the URI (path + query) the client will use to POST messages. client_post_uri_data = f"{quote(full_message_path_for_client)}?session_id={session_id.hex}" sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, Any]](0) @@ -181,10 +127,7 @@ async def sse_writer(): async with anyio.create_task_group() as tg: async def response_wrapper(scope: Scope, receive: Receive, send: Send): - """The EventSourceResponse returning signals a client close / disconnect. - In this case we close our side of the streams to signal the client that - the connection has been closed. - """ + """Close our side of the streams once EventSourceResponse returns (client disconnect).""" await EventSourceResponse(content=sse_stream_reader, data_sender_callable=sse_writer)( scope, receive, send ) @@ -206,7 +149,7 @@ async def handle_post_message(self, scope: Scope, receive: Receive, send: Send) logger.debug("Handling POST message") request = Request(scope, receive) - # Validate request headers for DNS rebinding protection + # DNS rebinding protection error_response = await self._security.validate_request(request, is_post=True) if error_response: return await error_response(scope, receive, send) @@ -234,8 +177,7 @@ async def handle_post_message(self, scope: Scope, receive: Receive, send: Send) user = scope.get("user") requestor = authorization_context(user) if isinstance(user, AuthenticatedUser) else None if requestor != self._session_owners.get(session_id): - # A session can only be used with the credential that created it. - # Respond exactly as if the session did not exist. + # Wrong credential: respond exactly as if the session did not exist. logger.warning("Rejecting message for session %s: credential does not match", session_id) response = Response("Could not find session", status_code=404) return await response(scope, receive, send) @@ -253,7 +195,6 @@ async def handle_post_message(self, scope: Scope, receive: Receive, send: Send) await writer.send(err) return - # Pass the ASGI scope for framework-agnostic access to request data metadata = ServerMessageMetadata(request_context=request) session_message = SessionMessage(message, metadata=metadata) logger.debug(f"Sending session message to writer: {session_message}") diff --git a/src/mcp/server/stdio.py b/src/mcp/server/stdio.py index 876d256ddb..1a2681eafe 100644 --- a/src/mcp/server/stdio.py +++ b/src/mcp/server/stdio.py @@ -1,21 +1,4 @@ -"""Stdio Server Transport Module - -This module provides functionality for creating an stdio-based transport layer -that can be used to communicate with an MCP client through standard input/output -streams. - -Example: - ```python - async def run_server(): - async with stdio_server() as (read_stream, write_stream): - # read_stream contains incoming JSONRPCMessages from stdin - # write_stream allows sending JSONRPCMessages to stdout - server = await create_my_server() - await server.run(read_stream, write_stream, init_options) - - anyio.run(run_server) - ``` -""" +"""Stdio-based server transport for communicating with an MCP client over stdin/stdout.""" import sys from contextlib import asynccontextmanager @@ -31,13 +14,9 @@ async def run_server(): @asynccontextmanager async def stdio_server(stdin: anyio.AsyncFile[str] | None = None, stdout: anyio.AsyncFile[str] | None = None): - """Server transport for stdio: this communicates with an MCP client by reading - from the current process' stdin and writing to stdout. - """ - # Purposely not using context managers for these, as we don't want to close - # standard process handles. Encoding of stdin/stdout as text streams on - # python is platform-dependent (Windows is particularly problematic), so we - # re-wrap the underlying binary stream to ensure UTF-8. + """Yield (read_stream, write_stream) for JSON-RPC messages over the current process' stdin/stdout.""" + # Deliberately no context managers here β€” standard process handles must not be closed. Re-wrap the + # binary streams to force UTF-8, since text-mode encoding is platform-dependent (notably Windows). if not stdin: stdin = anyio.wrap_file(TextIOWrapper(sys.stdin.buffer, encoding="utf-8", errors="replace")) if not stdout: diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index d316345c7e..705e244d7e 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -1,10 +1,4 @@ -"""StreamableHTTP Server Transport Module - -This module implements an HTTP transport layer with Streamable HTTP. - -The transport handles bidirectional communication using HTTP requests and -responses, with streaming support for long-running operations. -""" +"""StreamableHTTP server transport: bidirectional HTTP communication with SSE streaming support.""" import logging import re @@ -49,15 +43,12 @@ logger = logging.getLogger(__name__) -# Header names MCP_SESSION_ID_HEADER = "mcp-session-id" LAST_EVENT_ID_HEADER = "last-event-id" -# Content types CONTENT_TYPE_JSON = "application/json" CONTENT_TYPE_SSE = "text/event-stream" -# Special key for the standalone GET stream GET_STREAM_KEY = "_GET_stream" # Buffer for the per-request `_request_streams` so the serial `message_router` @@ -65,11 +56,9 @@ # whole session on a lazily-started `sse_writer`. See #1764. REQUEST_STREAM_BUFFER_SIZE: Final = 16 -# Session ID validation pattern (visible ASCII characters ranging from 0x21 to 0x7E) -# Pattern ensures entire string contains only valid characters by using ^ and $ anchors +# Session IDs must contain only visible ASCII (0x21-0x7E) SESSION_ID_PATTERN = re.compile(r"^[\x21-\x7E]+$") -# Type aliases StreamId = str EventId = str # An SSE event-dict as accepted by sse-starlette (`event`, `data`, `id`, `retry`). @@ -77,13 +66,7 @@ def check_accept_headers(request: Request) -> tuple[bool, bool]: - """Return (has_json, has_sse) for the request's Accept header, with RFC 7231 wildcard handling. - - Supports wildcard media types per RFC 7231, section 5.3.2: - - */* matches any media type - - application/* matches any application/ subtype - - text/* matches any text/ subtype - """ + """Return (has_json, has_sse) for the request's Accept header, honoring RFC 7231 wildcards.""" accept_header = request.headers.get("accept", "") accept_types = [media_type.strip().split(";")[0].strip().lower() for media_type in accept_header.split(",")] @@ -110,14 +93,9 @@ class EventStore(ABC): @abstractmethod async def store_event(self, stream_id: StreamId, message: JSONRPCMessage | None) -> EventId: - """Stores an event for later retrieval. + """Store an event and return its generated event ID. - Args: - stream_id: ID of the stream the event belongs to - message: The JSON-RPC message to store, or None for priming events - - Returns: - The generated event ID for the stored event. + `message` is None for priming events. """ pass # pragma: no cover @@ -127,11 +105,7 @@ async def replay_events_after( last_event_id: EventId, send_callback: EventCallback, ) -> StreamId | None: - """Replays events that occurred after the specified event ID. - - Args: - last_event_id: The ID of the last event the client received - send_callback: A callback function to send events to the client + """Replay events that occurred after `last_event_id` via `send_callback`. Returns: The stream ID of the replayed events, or None if no events were found. @@ -146,7 +120,6 @@ class StreamableHTTPServerTransport: Supports optional JSON responses and session management. """ - # Server notification streams for POST requests as well as standalone SSE stream _read_stream_writer: ContextSendStream[SessionMessage | Exception] | None = None _read_stream: ContextReceiveStream[SessionMessage | Exception] | None = None _write_stream: ContextSendStream[SessionMessage] | None = None @@ -164,18 +137,13 @@ def __init__( """Initialize a new StreamableHTTP server transport. Args: - mcp_session_id: Optional session identifier for this connection. - Must contain only visible ASCII characters (0x21-0x7E). - is_json_response_enabled: If True, return JSON responses for requests - instead of SSE streams. Default is False. - event_store: Event store for resumability support. If provided, - resumability will be enabled, allowing clients to - reconnect and resume messages. - security_settings: Optional security settings for DNS rebinding protection. - retry_interval: Retry interval in milliseconds to suggest to clients in SSE - retry field. When set, the server will send a retry field in - SSE priming events to control client reconnection timing for - polling behavior. Only used when event_store is provided. + mcp_session_id: Optional session identifier; visible ASCII (0x21-0x7E) only. + is_json_response_enabled: Return JSON responses instead of SSE streams. + event_store: When provided, enables resumability so clients can reconnect + and resume messages. + security_settings: Settings for DNS rebinding protection. + retry_interval: Retry interval in milliseconds sent in SSE priming events to + control client reconnection timing. Only used when event_store is provided. Raises: ValueError: If the session ID contains invalid characters. @@ -206,28 +174,17 @@ def is_terminated(self) -> bool: return self._terminated def close_sse_stream(self, request_id: RequestId) -> None: - """Close SSE connection for a specific request without terminating the stream. - - This method closes the HTTP connection for the specified request, triggering - client reconnection. Events continue to be stored in the event store and will - be replayed when the client reconnects with Last-Event-ID. - - Use this to implement polling behavior during long-running operations - - the client will reconnect after the retry interval specified in the priming event. + """Close the SSE connection for a request without terminating its stream. - Args: - request_id: The request ID whose SSE stream should be closed. - - Note: - This is a no-op if there is no active stream for the request ID. - Requires event_store to be configured for events to be stored during - the disconnect. + Triggers client reconnection: events continue to be stored and are replayed when + the client reconnects with Last-Event-ID, so this can implement polling during + long-running operations. No-op if there is no active stream for the request ID; + requires event_store for events to survive the disconnect. """ writer = self._sse_stream_writers.pop(request_id, None) if writer: # pragma: no branch writer.close() - # Also close and remove request streams if request_id in self._request_streams: # pragma: no branch send_stream, receive_stream = self._request_streams.pop(request_id) send_stream.close() @@ -236,17 +193,9 @@ def close_sse_stream(self, request_id: RequestId) -> None: def close_standalone_sse_stream(self) -> None: """Close the standalone GET SSE stream, triggering client reconnection. - This method closes the HTTP connection for the standalone GET stream used - for unsolicited server-to-client notifications. The client SHOULD reconnect - with Last-Event-ID to resume receiving notifications. - - Use this to implement polling behavior for the notification stream - - the client will reconnect after the retry interval specified in the priming event. - - Note: - This is a no-op if there is no active standalone SSE stream. - Requires event_store to be configured for events to be stored during - the disconnect. + The client SHOULD reconnect with Last-Event-ID to resume receiving notifications. + No-op if there is no active standalone stream; requires event_store for events to + survive the disconnect. """ self.close_sse_stream(GET_STREAM_KEY) @@ -263,7 +212,6 @@ def _create_session_message( resumability (protocol version >= 2025-11-25). Old clients can't resume if the stream is closed early because they didn't receive a priming event. """ - # Only provide close callbacks when client supports resumability if self._event_store and is_version_at_least(protocol_version, "2025-11-25"): async def close_stream_callback() -> None: @@ -340,7 +288,6 @@ def _create_error_response( if self.mcp_session_id: response_headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id - # Return a properly formatted JSON error response error_response = JSONRPCError( jsonrpc="2.0", id=None, @@ -374,41 +321,35 @@ def _create_json_response( ) def _get_session_id(self, request: Request) -> str | None: - """Extract the session ID from request headers.""" return request.headers.get(MCP_SESSION_ID_HEADER) def _create_event_data(self, event_message: EventMessage) -> SSEEvent: - """Create event data dictionary from an EventMessage.""" event_data = { "event": "message", "data": event_message.message.model_dump_json(by_alias=True, exclude_unset=True), } - # If an event ID was provided, include it if event_message.event_id: event_data["id"] = event_message.event_id return event_data async def _clean_up_memory_streams(self, request_id: RequestId) -> None: - """Clean up memory streams for a given request ID.""" if request_id in self._request_streams: # pragma: no branch try: - # Close the request stream await self._request_streams[request_id][0].aclose() await self._request_streams[request_id][1].aclose() except Exception: # pragma: no cover - # During cleanup, we catch all exceptions since streams might be in various states + # Streams might be in various states during cleanup logger.debug("Error closing memory streams - may already be closed") finally: - # Remove the request stream from the mapping self._request_streams.pop(request_id, None) async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> None: """Application entry point that handles all HTTP requests.""" request = Request(scope, receive) - # Validate request headers for DNS rebinding protection + # DNS rebinding protection is_post = request.method == "POST" error_response = await self._security.validate_request(request, is_post=is_post) if error_response: @@ -416,7 +357,6 @@ async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> No return if self._terminated: - # If the session has been terminated, return 404 Not Found response = self._create_error_response( "Not Found: Session has been terminated", HTTPStatus.NOT_FOUND, @@ -444,7 +384,6 @@ async def _validate_accept_header(self, request: Request, scope: Scope, send: Se """Validate Accept header based on response mode. Returns True if valid.""" has_json, has_sse = check_accept_headers(request) if self.is_json_response_enabled: - # For JSON-only responses, only require application/json if not has_json: response = self._create_error_response( "Not Acceptable: Client must accept application/json", @@ -452,7 +391,6 @@ async def _validate_accept_header(self, request: Request, scope: Scope, send: Se ) await response(scope, request.receive, send) return False - # For SSE responses, require both content types elif not (has_json and has_sse): response = self._create_error_response( "Not Acceptable: Client must accept both application/json and text/event-stream", @@ -468,11 +406,9 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re if writer is None: # pragma: no cover raise ValueError("No read stream writer available. Ensure connect() is called first.") try: - # Validate Accept header if not await self._validate_accept_header(request, scope, send): return - # Validate Content-Type if not self._check_content_type(request): # pragma: no cover response = self._create_error_response( "Unsupported Media Type: Content-Type must be application/json", @@ -481,7 +417,6 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re await response(scope, receive, send) return - # Parse the body - only read it once body = await request.body() try: @@ -502,16 +437,12 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re await response(scope, receive, send) return - # Check if this is an initialization request is_initialization_request = isinstance(message, JSONRPCRequest) and message.method == "initialize" if is_initialization_request: - # Check if the server already has an established session if self.mcp_session_id: - # Check if request has a session ID request_session_id = self._get_session_id(request) - # If request has a session ID but doesn't match, return 404 if request_session_id and request_session_id != self.mcp_session_id: # pragma: no cover response = self._create_error_response( "Not Found: Invalid or expired session ID", @@ -522,25 +453,21 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re elif not await self._validate_request_headers(request, send): return - # For notifications and responses only, return 202 Accepted + # Notifications and responses get 202 Accepted before processing if not isinstance(message, JSONRPCRequest): - # Create response object and send it response = self._create_json_response( None, HTTPStatus.ACCEPTED, ) await response(scope, receive, send) - # Process the message after sending the response metadata = ServerMessageMetadata(request_context=request) session_message = SessionMessage(message, metadata=metadata) await writer.send(session_message) return - # Extract protocol version for priming event decision. - # For initialize requests, get from request params. - # For other requests, get from header (already validated). + # Initialize requests carry the protocol version in params; later requests in the validated header protocol_version = ( str(message.params.get("protocolVersion", DEFAULT_NEGOTIATED_VERSION)) if is_initialization_request and message.params @@ -554,32 +481,23 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re REQUEST_STREAM_BUFFER_SIZE ) request_stream_reader = self._request_streams[request_id][1] - # Process the message metadata = ServerMessageMetadata(request_context=request) session_message = SessionMessage(message, metadata=metadata) await writer.send(session_message) try: - # Process messages from the request-specific stream - # We need to collect all messages until we get a response response_message = None - # Use similar approach to SSE writer for consistency async for event_message in request_stream_reader: # pragma: no branch - # If it's a response, this is what we're waiting for if isinstance(event_message.message, JSONRPCResponse | JSONRPCError): response_message = event_message.message break - # For notifications and requests, keep waiting else: # pragma: no cover logger.debug(f"received: {event_message.message.method}") - # At this point we should have a response if response_message: - # Create JSON response response = self._create_json_response(response_message) await response(scope, receive, send) else: # pragma: no cover - # This shouldn't happen in normal operation logger.error("No response message received before stream closed") response = self._create_error_response( "Error processing request: No response received", @@ -597,10 +515,9 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re finally: await self._clean_up_memory_streams(request_id) else: - # Mint the priming event before any per-request state exists: - # `EventStore.store_event` is user code and may raise, in which - # case the outer handler returns a 500 with nothing to clean up. - # Still strictly precedes dispatch, so storage order == wire order. + # Mint the priming event before any per-request state exists: store_event is user + # code and may raise, in which case the outer handler returns a 500 with nothing to + # clean up. Minting still precedes dispatch, so storage order == wire order. priming_event = await self._mint_priming_event(request_id, protocol_version) sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[SSEEvent](0) @@ -624,12 +541,10 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re headers=headers, ) - # Start the SSE response (this will send headers immediately) try: - # First send the response to establish the SSE connection + # Establish the SSE connection (headers sent immediately) before dispatching the message async with anyio.create_task_group() as tg: tg.start_soon(response, scope, receive, send) - # Then send the message to be processed by the server session_message = self._create_session_message(message, request, request_id, protocol_version) await writer.send(session_message) except Exception: # pragma: lax no cover @@ -651,17 +566,11 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re return async def _handle_get_request(self, request: Request, send: Send) -> None: - """Handle GET request to establish SSE. - - This allows the server to communicate to the client without the client - first sending data via HTTP POST. The server can send JSON-RPC requests - and notifications on this stream. - """ + """Establish the standalone SSE stream for server-initiated requests and notifications.""" writer = self._read_stream_writer if writer is None: # pragma: no cover raise ValueError("No read stream writer available. Ensure connect() is called first.") - # Validate Accept header - must include text/event-stream _, has_sse = check_accept_headers(request) if not has_sse: @@ -675,7 +584,6 @@ async def _handle_get_request(self, request: Request, send: Send) -> None: if not await self._validate_request_headers(request, send): return - # Handle resumability: check for Last-Event-ID header if last_event_id := request.headers.get(LAST_EVENT_ID_HEADER): await self._replay_events(last_event_id, request, send) return @@ -689,7 +597,6 @@ async def _handle_get_request(self, request: Request, send: Send) -> None: if self.mcp_session_id: # pragma: no branch headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id - # Check if we already have an active GET stream if GET_STREAM_KEY in self._request_streams: response = self._create_error_response( "Conflict: Only one SSE stream is allowed per session", @@ -698,27 +605,18 @@ async def _handle_get_request(self, request: Request, send: Send) -> None: await response(request.scope, request.receive, send) return - # Create SSE stream sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[SSEEvent](0) async def standalone_sse_writer(): try: - # Create a standalone message stream for server-initiated messages - self._request_streams[GET_STREAM_KEY] = anyio.create_memory_object_stream[EventMessage]( REQUEST_STREAM_BUFFER_SIZE ) standalone_stream_reader = self._request_streams[GET_STREAM_KEY][1] async with sse_stream_writer, standalone_stream_reader: - # Process messages from the standalone stream + # Carries server-initiated requests and notifications, never responses async for event_message in standalone_stream_reader: - # For the standalone stream, we handle: - # - JSONRPCNotification (server sends notifications to client) - # - JSONRPCRequest (server sends requests to client) - # We should NOT receive JSONRPCResponse - - # Send the message via SSE event_data = self._create_event_data(event_message) await sse_stream_writer.send(event_data) except anyio.ClosedResourceError: @@ -730,7 +628,6 @@ async def standalone_sse_writer(): logger.debug("Closing standalone SSE writer") await self._clean_up_memory_streams(GET_STREAM_KEY) - # Create and start EventSourceResponse response = EventSourceResponse( content=sse_stream_reader, data_sender_callable=standalone_sse_writer, @@ -738,7 +635,6 @@ async def standalone_sse_writer(): ) try: - # This will send headers immediately and establish the SSE connection await response(request.scope, request.receive, send) except Exception: # pragma: lax no cover logger.exception("Error in standalone SSE response") @@ -749,9 +645,7 @@ async def standalone_sse_writer(): async def _handle_delete_request(self, request: Request, send: Send) -> None: """Handle DELETE requests for explicit session termination.""" - # Validate session ID if not self.mcp_session_id: # pragma: no cover - # If no session ID set, return Method Not Allowed response = self._create_error_response( "Method Not Allowed: Session termination not supported", HTTPStatus.METHOD_NOT_ALLOWED, @@ -779,14 +673,12 @@ async def terminate(self) -> None: self._terminated = True logger.info(f"Terminating session: {self.mcp_session_id}") - # We need a copy of the keys to avoid modification during iteration + # Copy the keys: cleanup mutates the dict request_stream_keys = list(self._request_streams.keys()) - # Close all request streams asynchronously for key in request_stream_keys: await self._clean_up_memory_streams(key) - # Clear the request streams dictionary immediately self._request_streams.clear() try: if self._read_stream_writer is not None: # pragma: no branch @@ -798,7 +690,7 @@ async def terminate(self) -> None: if self._write_stream is not None: # pragma: no branch await self._write_stream.aclose() except Exception as e: # pragma: no cover - # During cleanup, we catch all exceptions since streams might be in various states + # Streams might be in various states during cleanup logger.debug(f"Error closing streams: {e}") async def _handle_unsupported_request(self, request: Request, send: Send) -> None: @@ -818,21 +710,17 @@ async def _handle_unsupported_request(self, request: Request, send: Send) -> Non await response(request.scope, request.receive, send) async def _validate_request_headers(self, request: Request, send: Send) -> bool: - # Protocol-version validation lives in the manager's era-routing: only - # values in `HANDSHAKE_PROTOCOL_VERSIONS` (or no header at all) reach - # this transport, so the legacy version-gate is gone. + # No protocol-version gate here: the manager's era-routing only sends values in + # `HANDSHAKE_PROTOCOL_VERSIONS` (or no header at all) to this transport return await self._validate_session(request, send) async def _validate_session(self, request: Request, send: Send) -> bool: """Validate the session ID in the request.""" if not self.mcp_session_id: - # If we're not using session IDs, return True return True - # Get the session ID from the request headers request_session_id = self._get_session_id(request) - # If no session ID provided but required, return error if not request_session_id: response = self._create_error_response( "Bad Request: Missing session ID", @@ -841,7 +729,6 @@ async def _validate_session(self, request: Request, send: Send) -> bool: await response(request.scope, request.receive, send) return False - # If session ID doesn't match, return error if request_session_id != self.mcp_session_id: # pragma: no cover response = self._create_error_response( "Not Found: Invalid or expired session ID", @@ -853,10 +740,7 @@ async def _validate_session(self, request: Request, send: Send) -> bool: return True async def _replay_events(self, last_event_id: str, request: Request, send: Send) -> None: - """Replays events that would have been sent after the specified event ID. - - Only used when resumability is enabled. - """ + """Replay events that occurred after `last_event_id`; only used when resumability is enabled.""" event_store = self._event_store if not event_store: return # pragma: no cover @@ -874,40 +758,34 @@ async def _replay_events(self, last_event_id: str, request: Request, send: Send) # The manager only routes supported (or absent) header values to this transport replay_protocol_version = request.headers.get(MCP_PROTOCOL_VERSION_HEADER, DEFAULT_NEGOTIATED_VERSION) - # Create SSE stream for replay sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[SSEEvent](0) async def replay_sender(): try: async with sse_stream_writer: - # Define an async callback for sending events + async def send_event(event_message: EventMessage) -> None: event_data = self._create_event_data(event_message) await sse_stream_writer.send(event_data) - # Replay past events and get the stream ID stream_id = await event_store.replay_events_after(last_event_id, send_event) - # If stream ID not in mapping, create it if stream_id and stream_id not in self._request_streams: # pragma: no branch try: # Register SSE writer so close_sse_stream() can close it self._sse_stream_writers[stream_id] = sse_stream_writer - # Prime the resumed connection so the client sees the stream - # is re-registered. The replayβ†’live-tail ordering window here - # is pre-existing and tracked separately. + # Prime so the client sees the stream re-registered; the + # replayβ†’live-tail ordering window is pre-existing, tracked separately priming_event = await self._mint_priming_event(stream_id, replay_protocol_version) if priming_event is not None: await sse_stream_writer.send(priming_event) - # Create new request streams for this connection self._request_streams[stream_id] = anyio.create_memory_object_stream[EventMessage]( REQUEST_STREAM_BUFFER_SIZE ) msg_reader = self._request_streams[stream_id][1] - # Forward messages to SSE async with msg_reader: async for event_message in msg_reader: event_data = self._create_event_data(event_message) @@ -917,12 +795,10 @@ async def send_event(event_message: EventMessage) -> None: self._sse_stream_writers.pop(stream_id, None) await self._clean_up_memory_streams(stream_id) except anyio.ClosedResourceError: # pragma: lax no cover - # Expected when close_sse_stream() is called logger.debug("Replay SSE stream closed by close_sse_stream()") except Exception: # pragma: lax no cover logger.exception("Error in replay sender") - # Create and start EventSourceResponse response = EventSourceResponse( content=sse_stream_reader, data_sender_callable=replay_sender, @@ -956,38 +832,26 @@ async def connect( ], None, ]: - """Context manager that provides read and write streams for a connection. - - Yields: - Tuple of (read_stream, write_stream) for bidirectional communication - """ - - # Create the memory streams for this connection - + """Set up the connection's streams and message router, yielding (read_stream, write_stream).""" read_stream_writer, read_stream = create_context_streams[SessionMessage | Exception](0) write_stream, write_stream_reader = create_context_streams[SessionMessage](0) - # Store the streams self._read_stream_writer = read_stream_writer self._read_stream = read_stream self._write_stream_reader = write_stream_reader self._write_stream = write_stream - # Start a task group for message routing async with anyio.create_task_group() as tg: - # Create a message router that distributes messages to request streams + async def message_router(): try: async for session_message in write_stream_reader: # pragma: no branch - # Determine which request stream(s) should receive this message message = session_message.message target_request_id = None - # Check if this is a response with a known request id. - # Null-id errors (e.g., parse errors) fall through to - # the GET stream since they can't be correlated. + # Null-id errors (e.g. parse errors) can't be correlated, so they fall + # through to the GET stream if isinstance(message, JSONRPCResponse | JSONRPCError) and message.id is not None: target_request_id = str(message.id) - # Extract related_request_id from meta if it exists elif ( session_message.metadata is not None and isinstance( @@ -1000,9 +864,7 @@ async def message_router(): request_stream_id = target_request_id if target_request_id is not None else GET_STREAM_KEY - # Store the event if we have an event store, - # regardless of whether a client is connected - # messages will be replayed on the re-connect + # Store even when no client is connected; messages replay on reconnect event_id = None if self._event_store: event_id = await self._event_store.store_event(request_stream_id, message) @@ -1010,10 +872,8 @@ async def message_router(): if request_stream_id in self._request_streams: try: - # Send both the message and the event ID await self._request_streams[request_stream_id][0].send(EventMessage(message, event_id)) except (anyio.BrokenResourceError, anyio.ClosedResourceError): # pragma: no cover - # Stream might be closed, remove from registry self._request_streams.pop(request_stream_id, None) else: logger.debug( @@ -1029,23 +889,20 @@ async def message_router(): except Exception: # pragma: lax no cover logger.exception("Error in message router") - # Start the message router tg.start_soon(message_router) try: - # Yield the streams for the caller to use yield read_stream, write_stream finally: for stream_id in list(self._request_streams.keys()): await self._clean_up_memory_streams(stream_id) self._request_streams.clear() - # Clean up the read and write streams try: await read_stream_writer.aclose() await read_stream.aclose() await write_stream_reader.aclose() await write_stream.aclose() except Exception as e: # pragma: no cover - # During cleanup, we catch all exceptions since streams might be in various states + # Streams might be in various states during cleanup logger.debug(f"Error closing streams: {e}") diff --git a/src/mcp/server/streamable_http_manager.py b/src/mcp/server/streamable_http_manager.py index 60b0989611..2f81213a93 100644 --- a/src/mcp/server/streamable_http_manager.py +++ b/src/mcp/server/streamable_http_manager.py @@ -38,37 +38,22 @@ class StreamableHTTPSessionManager: - """Manages StreamableHTTP sessions with optional resumability via event store. + """Manages StreamableHTTP sessions, transports, and optional resumability via event store. - This class abstracts away the complexity of session management, event storage, - and request handling for StreamableHTTP transports. It handles: - - 1. Session tracking for clients - 2. Resumability via an optional event store - 3. Connection management and lifecycle - 4. Request handling and transport setup - 5. Idle session cleanup via optional timeout - - Important: Only one StreamableHTTPSessionManager instance should be created - per application. The instance cannot be reused after its run() context has - completed. If you need to restart the manager, create a new instance. + Create only one instance per application. An instance cannot be reused after + its `run()` context exits β€” create a new one to restart. Args: - app: The MCP server instance - event_store: Optional event store for resumability support. If provided, enables resumable connections - where clients can reconnect and receive missed events. If None, sessions are still tracked but not - resumable. - json_response: Whether to use JSON responses instead of SSE streams - stateless: If True, creates a completely fresh transport for each request with no session tracking or - state persistence between requests. + app: The MCP server instance. + event_store: Enables resumable connections (clients can reconnect and receive missed events). + If None, sessions are still tracked but not resumable. + json_response: Use JSON responses instead of SSE streams. + stateless: Create a fresh transport per request with no session tracking or state persistence. security_settings: Optional transport security settings. - retry_interval: Retry interval in milliseconds to suggest to clients in SSE retry field. Used for SSE - polling behavior. - session_idle_timeout: Optional idle timeout in seconds for stateful sessions. If set, sessions that - receive no HTTP requests for this duration will be automatically terminated and removed. When - retry_interval is also configured, ensure the idle timeout comfortably exceeds the retry interval to - avoid reaping sessions during normal SSE polling gaps. Default is None (no timeout). A value of 1800 - (30 minutes) is recommended for most deployments. + retry_interval: Retry interval in milliseconds suggested to clients in the SSE retry field. + session_idle_timeout: Seconds of HTTP inactivity after which a stateful session is terminated + and removed. When retry_interval is set, must comfortably exceed it to avoid reaping sessions + during normal SSE polling gaps. Default None (no timeout); 1800 (30 minutes) suits most deployments. """ def __init__( @@ -94,38 +79,28 @@ def __init__( self.retry_interval = retry_interval self.session_idle_timeout = session_idle_timeout - # Session tracking (only used if not stateless) self._session_creation_lock = anyio.Lock() self._server_instances: dict[str, StreamableHTTPServerTransport] = {} - # Identity of the credential that created each session; requests for a - # session must present the same credential. + # Credential that created each session; subsequent requests must present the same one. self._session_owners: dict[str, AuthorizationContext] = {} - # The task group and lifespan state are set during run() self._task_group = None self._lifespan_state: Any = None - # Thread-safe tracking of run() calls self._run_lock = anyio.Lock() self._has_started = False @contextlib.asynccontextmanager async def run(self) -> AsyncIterator[None]: - """Run the session manager with proper lifecycle management. - - This creates and manages the task group for all session operations. - - Important: This method can only be called once per instance. The same - StreamableHTTPSessionManager instance cannot be reused after this - context manager exits. Create a new instance if you need to restart. + """Run the task group that owns all session operations. - Use this in the lifespan context manager of your Starlette app: + Can only be called once per instance; create a new instance to restart. + Use it in the lifespan context manager of your Starlette app: @contextlib.asynccontextmanager async def lifespan(app: Starlette) -> AsyncIterator[None]: async with session_manager.run(): yield """ - # Thread-safe check to ensure run() is only called once async with self._run_lock: if self._has_started: raise RuntimeError( @@ -135,38 +110,30 @@ async def lifespan(app: Starlette) -> AsyncIterator[None]: self._has_started = True async with self.app.lifespan(self.app) as lifespan_state, anyio.create_task_group() as tg: - # Store for handle_request: lifespan is entered once for the - # manager's lifetime, not per request (per-connection cleanup - # belongs on `connection.exit_stack`). + # Lifespan is entered once for the manager's lifetime, not per request; + # per-connection cleanup belongs on `connection.exit_stack`. self._lifespan_state = lifespan_state self._task_group = tg logger.info("StreamableHTTP session manager started") try: - yield # Let the application run + yield finally: logger.info("StreamableHTTP session manager shutting down") - # Cancel task group to stop all spawned tasks tg.cancel_scope.cancel() self._task_group = None self._lifespan_state = None - # Clear any remaining server instances self._server_instances.clear() self._session_owners.clear() await resync_tracer() async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> None: - """Process ASGI request with proper session handling and transport setup. - - Dispatches to the appropriate handler based on stateless mode. - """ + """Process an ASGI request, dispatching by protocol era and stateless mode.""" if self._task_group is None: raise RuntimeError("Task group is not initialized. Make sure to use run().") - # TODO(L49): header-only era-routing for now; body-primary classification - # is a follow-up. The legacy paths below own only the known - # initialize-handshake versions; anything else (including unknown - # values) goes to the modern entry so the classifier can validate it - # and return a structured rejection. 2025 paths below remain unchanged. + # TODO(L49): header-only era routing; body-primary classification is a follow-up. + # Legacy paths own only the known initialize-handshake versions; anything else + # goes to the modern entry so the classifier can validate it and reject it structurally. header = MCP_PROTOCOL_VERSION_HEADER.encode("ascii") pv = next((v.decode("latin-1") for k, v in scope["headers"] if k == header), None) if pv is not None and pv not in HANDSHAKE_PROTOCOL_VERSIONS: @@ -175,7 +142,6 @@ async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> No ) return - # Dispatch to the appropriate handler if self.stateless: await self._handle_stateless_request(pv, scope, receive, send) else: @@ -184,17 +150,15 @@ async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> No async def _handle_stateless_request( self, protocol_version_hint: str | None, scope: Scope, receive: Receive, send: Send ) -> None: - """Process request in stateless mode - creating a new transport for each request.""" + """Process request in stateless mode, creating a new transport for each request.""" logger.debug("Stateless mode: Creating new transport for this request") - # No session ID needed in stateless mode http_transport = StreamableHTTPServerTransport( - mcp_session_id=None, # No session tracking in stateless mode + mcp_session_id=None, is_json_response_enabled=self.json_response, - event_store=None, # No event store in stateless mode + event_store=None, security_settings=self.security_settings, ) - # Start server in a new task async def run_stateless_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORED): async with http_transport.connect() as streams: read_stream, write_stream = streams @@ -203,17 +167,13 @@ async def run_stateless_server(*, task_status: TaskStatus[None] = anyio.TASK_STA read_stream, write_stream, inline_methods=frozenset({"initialize"}), - # No session ID means a server-to-client request can be - # written to this POST's response stream, but the client's - # reply has nowhere to land β€” `can_send_request=False` - # makes the per-request channel raise `NoBackChannelError` - # for requests while still allowing notifications. + # Without a session ID a server-to-client request could be written to this POST's + # response stream, but the client's reply has nowhere to land β€” `can_send_request=False` + # raises `NoBackChannelError` for requests while still allowing notifications. transport_builder=lambda _md: TransportContext(kind="streamable-http", can_send_request=False), ) - # Born-ready, no standalone channel: the legacy stateless path - # never opens a GET stream and need not see `initialize`. The - # header (or the spec's default-absent value) seeds - # `ctx.protocol_version`. + # Born-ready: the legacy stateless path never opens a GET stream and need not see + # `initialize`. The header (or the spec default when absent) seeds `ctx.protocol_version`. connection = Connection.from_envelope( protocol_version_hint if protocol_version_hint is not None else DEFAULT_NEGOTIATED_VERSION, None, @@ -226,31 +186,25 @@ async def run_stateless_server(*, task_status: TaskStatus[None] = anyio.TASK_STA except Exception: # pragma: lax no cover logger.exception("Stateless session crashed") - # Assert task group is not None for type checking assert self._task_group is not None - # Start the server task await self._task_group.start(run_stateless_server) - # Handle the HTTP request and return the response await http_transport.handle_request(scope, receive, send) - # Terminate the transport after the request is handled await http_transport.terminate() async def _handle_stateful_request(self, scope: Scope, receive: Receive, send: Send) -> None: - """Process request in stateful mode - maintaining session state between requests.""" + """Process request in stateful mode, maintaining session state between requests.""" request = Request(scope, receive) request_mcp_session_id = request.headers.get(MCP_SESSION_ID_HEADER) user = scope.get("user") requestor = authorization_context(user) if isinstance(user, AuthenticatedUser) else None - # Existing session case if request_mcp_session_id is not None and request_mcp_session_id in self._server_instances: transport = self._server_instances[request_mcp_session_id] if requestor != self._session_owners.get(request_mcp_session_id): - # A session can only be used with the credential that created - # it. Respond exactly as if the session did not exist. + # Sessions are bound to the creating credential; respond as if the session did not exist. logger.warning( "Rejecting request for session %s: credential does not match the one that created the session", request_mcp_session_id[:64], @@ -266,21 +220,19 @@ async def _handle_stateful_request(self, scope: Scope, receive: Receive, send: S await response(scope, receive, send) return logger.debug("Session already exists, handling request directly") - # Push back idle deadline on activity if transport.idle_scope is not None and self.session_idle_timeout is not None: transport.idle_scope.deadline = anyio.current_time() + self.session_idle_timeout # pragma: no cover await transport.handle_request(scope, receive, send) return if request_mcp_session_id is None: - # New session case logger.debug("Creating new transport") async with self._session_creation_lock: new_session_id = uuid4().hex http_transport = StreamableHTTPServerTransport( mcp_session_id=new_session_id, is_json_response_enabled=self.json_response, - event_store=self.event_store, # May be None (no resumability) + event_store=self.event_store, security_settings=self.security_settings, retry_interval=self.retry_interval, ) @@ -291,25 +243,21 @@ async def _handle_stateful_request(self, scope: Scope, receive: Receive, send: S self._server_instances[http_transport.mcp_session_id] = http_transport logger.info(f"Created new transport with session ID: {new_session_id}") - # Define the server runner async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORED) -> None: async with http_transport.connect() as streams: read_stream, write_stream = streams task_status.started() try: - # Use a cancel scope for idle timeout β€” when the - # deadline passes the scope cancels the loop and - # execution continues after the ``with`` block. - # Incoming requests push the deadline forward. + # Idle timeout: when the deadline passes the scope cancels the loop and execution + # resumes after the `with` block. Incoming requests push the deadline forward. idle_scope = anyio.CancelScope() if self.session_idle_timeout is not None: idle_scope.deadline = anyio.current_time() + self.session_idle_timeout http_transport.idle_scope = idle_scope with idle_scope: - # Drive via `serve_loop` (not `Server.run()`) so the - # manager's already-entered lifespan is reused - # rather than re-entered per session. + # `serve_loop` (not `Server.run()`) reuses the manager's already-entered + # lifespan rather than re-entering it per session. await serve_loop( self.app, read_stream, @@ -339,17 +287,13 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE del self._server_instances[http_transport.mcp_session_id] self._session_owners.pop(http_transport.mcp_session_id, None) - # Assert task group is not None for type checking assert self._task_group is not None - # Start the server task await self._task_group.start(run_server) - # Handle the HTTP request and return the response await http_transport.handle_request(scope, receive, send) else: - # Unknown or expired session ID - return 404 per MCP spec - # TODO(L62): Align error code once spec clarifies - # See: https://github.com/modelcontextprotocol/python-sdk/issues/1821 + # Unknown or expired session ID: 404 per MCP spec. TODO(L62): align error code + # once spec clarifies β€” https://github.com/modelcontextprotocol/python-sdk/issues/1821 logger.info(f"Rejected request with unknown or expired session ID: {request_mcp_session_id[:64]}") body = JSONRPCError( jsonrpc="2.0", id=None, error=ErrorData(code=INVALID_REQUEST, message="Session not found") diff --git a/src/mcp/server/transport_security.py b/src/mcp/server/transport_security.py index d9e9f965b3..350600b520 100644 --- a/src/mcp/server/transport_security.py +++ b/src/mcp/server/transport_security.py @@ -11,25 +11,16 @@ # TODO(Marcelo): We should flatten these settings. To be fair, I don't think we should even have this middleware. class TransportSecuritySettings(BaseModel): - """Settings for MCP transport security features. - - These settings help protect against DNS rebinding attacks by validating incoming request headers. - """ + """Settings for protecting MCP transports against DNS rebinding via request header validation.""" enable_dns_rebinding_protection: bool = True """Enable DNS rebinding protection (recommended for production).""" allowed_hosts: list[str] = Field(default_factory=list) - """List of allowed Host header values. - - Only applies when `enable_dns_rebinding_protection` is `True`. - """ + """Allowed Host header values; only applies when `enable_dns_rebinding_protection` is `True`.""" allowed_origins: list[str] = Field(default_factory=list) - """List of allowed Origin header values. - - Only applies when `enable_dns_rebinding_protection` is `True`. - """ + """Allowed Origin header values; only applies when `enable_dns_rebinding_protection` is `True`.""" # TODO(Marcelo): This should be a proper ASGI middleware. I'm sad to see this. @@ -37,25 +28,21 @@ class TransportSecurityMiddleware: """Middleware to enforce DNS rebinding protection for MCP transport endpoints.""" def __init__(self, settings: TransportSecuritySettings | None = None): - # If not specified, disable DNS rebinding protection by default for backwards compatibility + # Default to disabled for backwards compatibility self.settings = settings or TransportSecuritySettings(enable_dns_rebinding_protection=False) def _validate_host(self, host: str | None) -> bool: - """Validate the Host header against allowed values.""" if not host: logger.warning("Missing Host header in request") return False - # Check exact match first if host in self.settings.allowed_hosts: return True - # Check wildcard port patterns + # A "host:*" pattern allows any port on that host for allowed in self.settings.allowed_hosts: if allowed.endswith(":*"): - # Extract base host from pattern base_host = allowed[:-2] - # Check if the actual host starts with base host and has a port if host.startswith(base_host + ":"): return True @@ -63,21 +50,17 @@ def _validate_host(self, host: str | None) -> bool: return False def _validate_origin(self, origin: str | None) -> bool: - """Validate the Origin header against allowed values.""" # Origin can be absent for same-origin requests if not origin: return True - # Check exact match first if origin in self.settings.allowed_origins: return True - # Check wildcard port patterns + # An "origin:*" pattern allows any port on that origin for allowed in self.settings.allowed_origins: if allowed.endswith(":*"): - # Extract base origin from pattern base_origin = allowed[:-2] - # Check if the actual origin starts with base origin and has a port if origin.startswith(base_origin + ":"): return True @@ -85,7 +68,6 @@ def _validate_origin(self, origin: str | None) -> bool: return False def _validate_content_type(self, content_type: str | None) -> bool: - """Validate the Content-Type header for POST requests.""" return content_type is not None and content_type.lower().startswith("application/json") async def validate_request(self, request: Request, is_post: bool = False) -> Response | None: @@ -93,22 +75,19 @@ async def validate_request(self, request: Request, is_post: bool = False) -> Res Returns None if validation passes, or an error Response if validation fails. """ - # Always validate Content-Type for POST requests + # Content-Type is checked even when DNS rebinding protection is disabled if is_post: content_type = request.headers.get("content-type") if not self._validate_content_type(content_type): return Response("Invalid Content-Type header", status_code=400) - # Skip remaining validation if DNS rebinding protection is disabled if not self.settings.enable_dns_rebinding_protection: return None - # Validate Host header host = request.headers.get("host") if not self._validate_host(host): return Response("Invalid Host header", status_code=421) - # Validate Origin header origin = request.headers.get("origin") if not self._validate_origin(origin): return Response("Invalid Origin header", status_code=403) diff --git a/src/mcp/server/validation.py b/src/mcp/server/validation.py index fd16beb957..88f005a124 100644 --- a/src/mcp/server/validation.py +++ b/src/mcp/server/validation.py @@ -1,7 +1,4 @@ -"""Shared validation functions for server requests. - -This module provides validation logic for sampling and elicitation requests. -""" +"""Shared validation for sampling and elicitation requests.""" from mcp_types import INVALID_PARAMS, ClientCapabilities, SamplingMessage, Tool, ToolChoice @@ -9,14 +6,7 @@ def check_sampling_tools_capability(client_caps: ClientCapabilities | None) -> bool: - """Check if the client supports sampling tools capability. - - Args: - client_caps: The client's declared capabilities - - Returns: - True if client supports sampling.tools, False otherwise - """ + """Return True if the client declares the `sampling.tools` capability.""" if client_caps is None: return False if client_caps.sampling is None: @@ -31,15 +21,10 @@ def validate_sampling_tools( tools: list[Tool] | None, tool_choice: ToolChoice | None, ) -> None: - """Validate that the client supports sampling tools if tools are being used. - - Args: - client_caps: The client's declared capabilities - tools: The tools list, if provided - tool_choice: The tool choice setting, if provided + """Validate that the client supports sampling tools when tools are being used. Raises: - MCPError: If tools/tool_choice are provided but client doesn't support them + MCPError: If `tools` or `tool_choice` is provided but the client lacks `sampling.tools`. """ if tools is not None or tool_choice is not None: if not check_sampling_tools_capability(client_caps): @@ -49,18 +34,12 @@ def validate_sampling_tools( def validate_tool_use_result_messages(messages: list[SamplingMessage]) -> None: """Validate tool_use/tool_result message structure per SEP-1577. - This validation ensures: - 1. Messages with tool_result content contain ONLY tool_result content - 2. tool_result messages are preceded by a message with tool_use - 3. tool_result IDs match the tool_use IDs from the previous message - - See: https://github.com/modelcontextprotocol/modelcontextprotocol/issues/1577 - - Args: - messages: The list of sampling messages to validate + A message with tool_result content must contain only tool_result blocks, follow a + message containing tool_use, and reference exactly that message's tool_use IDs. + See https://github.com/modelcontextprotocol/modelcontextprotocol/issues/1577. Raises: - ValueError: If the message structure is invalid + ValueError: If the message structure is invalid. """ if not messages: return @@ -72,8 +51,7 @@ def validate_tool_use_result_messages(messages: list[SamplingMessage]) -> None: has_previous_tool_use = previous_content and any(c.type == "tool_use" for c in previous_content) if has_tool_results: - # Per spec: "SamplingMessage with tool result content blocks - # MUST NOT contain other content types." + # Spec: a SamplingMessage with tool_result blocks MUST NOT contain other content types. if any(c.type != "tool_result" for c in last_content): raise ValueError("The last message must contain only tool_result content if any is present") if previous_content is None: diff --git a/src/mcp/shared/_callable_inspection.py b/src/mcp/shared/_callable_inspection.py index 0e89e446f8..80443c3fc4 100644 --- a/src/mcp/shared/_callable_inspection.py +++ b/src/mcp/shared/_callable_inspection.py @@ -1,6 +1,5 @@ -"""Callable inspection utilities. +"""Callable inspection, adapted from Starlette's `is_async_callable`. -Adapted from Starlette's `is_async_callable` implementation. https://github.com/encode/starlette/blob/main/starlette/_utils.py """ diff --git a/src/mcp/shared/_context_streams.py b/src/mcp/shared/_context_streams.py index 04c33306d9..190ef3d0fa 100644 --- a/src/mcp/shared/_context_streams.py +++ b/src/mcp/shared/_context_streams.py @@ -1,13 +1,8 @@ """Context-aware memory stream wrappers. -anyio memory streams do not propagate ``contextvars.Context`` across task -boundaries. These thin wrappers capture the sender's context at ``send()`` -time and expose it on the receive side via ``last_context``, so consumers -can restore it with ``ctx.run(handler, item)``. - -The iteration interface is unchanged (yields ``T``, not tuples), keeping -these wrappers duck-type compatible with plain ``MemoryObjectSendStream`` -and ``MemoryObjectReceiveStream``. +anyio memory streams don't propagate `contextvars.Context` across tasks; these wrappers snapshot +the sender's context at `send()` and expose it on the receive side via `last_context`, while still +yielding plain `T` so they stay duck-type compatible with the anyio memory streams. """ from __future__ import annotations @@ -21,12 +16,11 @@ T = TypeVar("T") -# Internal payload carried through the underlying raw stream. _Envelope = tuple[contextvars.Context, T] class ContextSendStream(Generic[T]): - """Send-side wrapper that snapshots ``contextvars.copy_context()`` on every ``send()``.""" + """Send-side wrapper that snapshots `contextvars.copy_context()` on every `send()`.""" __slots__ = ("_inner",) @@ -59,7 +53,7 @@ async def __aexit__( class ContextReceiveStream(Generic[T]): - """Receive-side wrapper that yields ``T`` and stores the sender's context in ``last_context``.""" + """Receive-side wrapper that yields `T` and stores the sender's context in `last_context`.""" __slots__ = ("_inner", "last_context") @@ -108,8 +102,8 @@ class create_context_streams( ): """Create context-aware memory object streams. - Supports ``create_context_streams[T](n)`` bracket syntax, - matching anyio's ``create_memory_object_stream`` API style. + A class so `create_context_streams[T](n)` bracket syntax works, matching anyio's + `create_memory_object_stream`. """ def __new__(cls, max_buffer_size: float = 0) -> tuple[ContextSendStream[T], ContextReceiveStream[T]]: # type: ignore[type-var] diff --git a/src/mcp/shared/_httpx_utils.py b/src/mcp/shared/_httpx_utils.py index 6a121aff6d..5c46d7d94e 100644 --- a/src/mcp/shared/_httpx_utils.py +++ b/src/mcp/shared/_httpx_utils.py @@ -6,9 +6,8 @@ __all__ = ["create_mcp_http_client", "MCP_DEFAULT_TIMEOUT", "MCP_DEFAULT_SSE_READ_TIMEOUT"] -# Default MCP timeout configuration -MCP_DEFAULT_TIMEOUT = 30.0 # General operations (seconds) -MCP_DEFAULT_SSE_READ_TIMEOUT = 300.0 # SSE streams - 5 minutes (seconds) +MCP_DEFAULT_TIMEOUT = 30.0 # seconds, general operations +MCP_DEFAULT_SSE_READ_TIMEOUT = 300.0 # seconds, long-lived SSE streams class McpHttpClientFactory(Protocol): # pragma: no branch @@ -25,70 +24,22 @@ def create_mcp_http_client( timeout: httpx.Timeout | None = None, auth: httpx.Auth | None = None, ) -> httpx.AsyncClient: - """Create a standardized httpx AsyncClient with MCP defaults. + """Create an httpx AsyncClient with MCP defaults. - Always enables follow_redirects and applies an SSE-friendly default timeout. - - Args: - headers: Optional headers to include with all requests. - timeout: Request timeout as httpx.Timeout object. Defaults to 30s for - connect/write/pool and 300s for read (for long-lived SSE streams). - auth: Optional authentication handler. - - Returns: - Configured httpx.AsyncClient instance with MCP defaults. - - Note: - The returned AsyncClient must be used as a context manager to ensure - proper cleanup of connections. - - Example: - Basic usage with MCP defaults: - - ```python - async with create_mcp_http_client() as client: - response = await client.get("https://api.example.com") - ``` - - With custom headers: - - ```python - headers = {"Authorization": "Bearer token"} - async with create_mcp_http_client(headers) as client: - response = await client.get("/endpoint") - ``` - - With both custom headers and timeout: - - ```python - timeout = httpx.Timeout(60.0, read=300.0) - async with create_mcp_http_client(headers, timeout) as client: - response = await client.get("/long-request") - ``` - - With authentication: - - ```python - from httpx import BasicAuth - auth = BasicAuth(username="user", password="pass") - async with create_mcp_http_client(headers, timeout, auth) as client: - response = await client.get("/protected-endpoint") - ``` + Enables follow_redirects and, when `timeout` is omitted, defaults to 30s for + connect/write/pool and 300s for read so long-lived SSE streams stay open. + Use the returned client as a context manager to clean up connections. """ - # Set MCP defaults kwargs: dict[str, Any] = {"follow_redirects": True} - # Handle timeout if timeout is None: kwargs["timeout"] = httpx.Timeout(MCP_DEFAULT_TIMEOUT, read=MCP_DEFAULT_SSE_READ_TIMEOUT) else: kwargs["timeout"] = timeout - # Handle headers if headers is not None: kwargs["headers"] = headers - # Handle authentication if auth is not None: # pragma: no cover kwargs["auth"] = auth diff --git a/src/mcp/shared/_otel.py b/src/mcp/shared/_otel.py index b7b05b11ab..d633726d00 100644 --- a/src/mcp/shared/_otel.py +++ b/src/mcp/shared/_otel.py @@ -24,7 +24,6 @@ def otel_span( record_exception: bool = True, set_status_on_exception: bool = True, ) -> Generator[Span]: - """Create an OTel span.""" with _tracer.start_as_current_span( name, kind=kind, @@ -44,10 +43,8 @@ def inject_trace_context(meta: dict[str, Any]) -> None: def extract_trace_context(meta: Mapping[str, Any] | None) -> Context | None: """Extract W3C trace context from a `_meta` dict. - Returns `None` when the carrier is absent, malformed, or carries no - valid `traceparent`, so callers fall through to ambient parenting; an - explicit empty `Context` would orphan the span instead of nesting under - the current one. + Returns `None` when the carrier is absent, malformed, or lacks a valid `traceparent`, + so callers fall back to ambient parenting; an empty `Context` would orphan the span. """ if not meta: return None diff --git a/src/mcp/shared/_stream_protocols.py b/src/mcp/shared/_stream_protocols.py index b799751329..6fdc22406d 100644 --- a/src/mcp/shared/_stream_protocols.py +++ b/src/mcp/shared/_stream_protocols.py @@ -1,8 +1,4 @@ -"""Stream protocols for MCP transports. - -These are general-purpose protocols satisfied by both ``MemoryObjectSendStream``/ -``MemoryObjectReceiveStream`` and the context-aware wrappers in ``_context_streams``. -""" +"""Stream protocols satisfied by both anyio memory object streams and the `_context_streams` wrappers.""" from __future__ import annotations @@ -18,8 +14,7 @@ class ReadStream(Protocol[T_co]): """Protocol for reading items from a stream. - Consumers that need the sender's context should use - ``getattr(stream, 'last_context', None)``. + `getattr(stream, 'last_context', None)` gives the sender's context. """ async def receive(self) -> T_co: ... diff --git a/src/mcp/shared/auth.py b/src/mcp/shared/auth.py index 2bbf7a715a..8b8de93165 100644 --- a/src/mcp/shared/auth.py +++ b/src/mcp/shared/auth.py @@ -19,8 +19,7 @@ class OAuthToken(BaseModel): @classmethod def normalize_token_type(cls, v: str | None) -> str | None: if isinstance(v, str): - # Bearer is title-cased in the spec, so we normalize it - # https://datatracker.ietf.org/doc/html/rfc6750#section-4 + # The spec title-cases "Bearer"; normalize (https://datatracker.ietf.org/doc/html/rfc6750#section-4) return v.title() return v # pragma: no cover @@ -28,8 +27,7 @@ def normalize_token_type(cls, v: str | None) -> str | None: class AuthorizationCodeResult(BaseModel): """Authorization-code-grant redirect parameters returned by a callback handler. - `iss` carries the RFC 9207 authorization-response issuer when the authorization server - includes it in the redirect; the client validates it against the expected issuer. + `iss` is the RFC 9207 authorization-response issuer; the client validates it against the expected issuer. """ code: str @@ -48,34 +46,27 @@ def __init__(self, message: str): class OAuthClientMetadata(BaseModel): - """RFC 7591 OAuth 2.0 Dynamic Client Registration Metadata. - See https://datatracker.ietf.org/doc/html/rfc7591#section-2 - """ + """RFC 7591 Dynamic Client Registration metadata. See https://datatracker.ietf.org/doc/html/rfc7591#section-2""" model_config = ConfigDict(url_preserve_empty_path=True) redirect_uris: list[AnyUrl] | None = Field(..., min_length=1) - # supported auth methods for the token endpoint token_endpoint_auth_method: ( Literal["none", "client_secret_post", "client_secret_basic", "private_key_jwt"] | None ) = None - # supported grant_types of this implementation grant_types: list[ Literal["authorization_code", "refresh_token", "urn:ietf:params:oauth:grant-type:jwt-bearer"] | str ] = [ "authorization_code", "refresh_token", ] - # The MCP spec requires the "code" response type, but OAuth - # servers may also return additional types they support + # MCP requires the "code" response type; servers may return additional types they support response_types: list[str] = ["code"] scope: str | None = None - # SEP-837: OIDC application_type. Defaults to "native" since MCP clients typically use - # loopback redirect URIs; set "web" for remote browser-based clients on a non-local host. + # SEP-837 OIDC application_type: "native" for loopback redirects, "web" for remote browser-based clients. application_type: Literal["web", "native"] = "native" - # these fields are currently unused, but we support & store them for potential - # future use + # Currently unused; accepted and stored for potential future use client_name: str | None = None client_uri: AnyHttpUrl | None = None logo_uri: AnyHttpUrl | None = None @@ -97,10 +88,8 @@ class OAuthClientMetadata(BaseModel): ) @classmethod def _empty_string_optional_url_to_none(cls, v: object) -> object: - # RFC 7591 Β§2 marks these URL fields OPTIONAL. Some authorization servers - # echo omitted metadata back as "" instead of dropping the keys, which - # AnyHttpUrl would otherwise reject β€” throwing away an otherwise valid - # registration response. Treat "" as absent. + # RFC 7591 Β§2 marks these URL fields OPTIONAL, but some servers echo omitted metadata back + # as "" β€” which AnyHttpUrl would reject. Treat "" as absent. if v == "": return None return v @@ -117,7 +106,6 @@ def validate_scope(self, requested_scope: str | None) -> list[str] | None: def validate_redirect_uri(self, redirect_uri: AnyUrl | None) -> AnyUrl: if redirect_uri is not None: - # Validate redirect_uri against client's registered redirect URIs if self.redirect_uris is None or redirect_uri not in self.redirect_uris: raise InvalidRedirectUriError(f"Redirect URI '{redirect_uri}' not registered for client") return redirect_uri @@ -128,23 +116,18 @@ def validate_redirect_uri(self, redirect_uri: AnyUrl | None) -> AnyUrl: class OAuthClientInformationFull(OAuthClientMetadata): - """RFC 7591 OAuth 2.0 Dynamic Client Registration full response - (client information plus metadata). - """ + """RFC 7591 Dynamic Client Registration full response (client information plus metadata).""" client_id: str | None = None client_secret: str | None = None client_id_issued_at: int | None = None client_secret_expires_at: int | None = None - # SEP-2352: the issuer these credentials were registered with, recorded by the SDK (not an - # RFC 7591 field) to detect authorization-server migration and avoid cross-AS credential reuse. + # SEP-2352: issuer the credentials were registered with; SDK-recorded (not RFC 7591) to detect AS migration. issuer: str | None = None class OAuthMetadata(BaseModel): - """RFC 8414 OAuth 2.0 Authorization Server Metadata. - See https://datatracker.ietf.org/doc/html/rfc8414#section-2 - """ + """RFC 8414 Authorization Server Metadata. See https://datatracker.ietf.org/doc/html/rfc8414#section-2""" model_config = ConfigDict(url_preserve_empty_path=True) @@ -171,15 +154,12 @@ class OAuthMetadata(BaseModel): code_challenge_methods_supported: list[str] | None = None client_id_metadata_document_supported: bool | None = None authorization_response_iss_parameter_supported: bool | None = None - # SEP-990 / draft-ietf-oauth-identity-assertion-authz-grant Β§7.2: profiles whose grants the - # authorization server supports, e.g. `urn:ietf:params:oauth:grant-profile:id-jag`. + # SEP-990 / draft-ietf-oauth-identity-assertion-authz-grant Β§7.2, e.g. `urn:ietf:params:oauth:grant-profile:id-jag` authorization_grant_profiles_supported: list[str] | None = None class ProtectedResourceMetadata(BaseModel): - """RFC 9728 OAuth 2.0 Protected Resource Metadata. - See https://datatracker.ietf.org/doc/html/rfc9728#section-2 - """ + """RFC 9728 Protected Resource Metadata. See https://datatracker.ietf.org/doc/html/rfc9728#section-2""" model_config = ConfigDict(url_preserve_empty_path=True) diff --git a/src/mcp/shared/auth_utils.py b/src/mcp/shared/auth_utils.py index 3ba880f40d..033040337f 100644 --- a/src/mcp/shared/auth_utils.py +++ b/src/mcp/shared/auth_utils.py @@ -9,19 +9,11 @@ def resource_url_from_server_url(url: str | HttpUrl | AnyUrl) -> str: """Convert server URL to canonical resource URL per RFC 8707. - RFC 8707 section 2 states that resource URIs "MUST NOT include a fragment component". - Returns absolute URI with lowercase scheme/host for canonical form. - - Args: - url: Server URL to convert - - Returns: - Canonical resource URL string + Lowercases scheme/host and strips the fragment (RFC 8707 section 2: resource URIs + "MUST NOT include a fragment component"). """ - # Convert to string if needed url_str = str(url) - # Parse the URL and remove fragment, create canonical form parsed = urlsplit(url_str) canonical = urlunsplit(parsed._replace(scheme=parsed.scheme.lower(), netloc=parsed.netloc.lower(), fragment="")) @@ -31,28 +23,16 @@ def resource_url_from_server_url(url: str | HttpUrl | AnyUrl) -> str: def check_resource_allowed(requested_resource: str, configured_resource: str) -> bool: """Check if a requested resource URL matches a configured resource URL. - A requested resource matches if it has the same scheme, domain, port, - and its path starts with the configured resource's path. This allows - hierarchical matching where a token for a parent resource can be used - for child resources. - - Args: - requested_resource: The resource URL being requested - configured_resource: The resource URL that has been configured - - Returns: - True if the requested resource matches the configured resource + Matches when the origin is identical and the requested path starts with the + configured path, so a token for a parent resource covers child resources. """ - # Parse both URLs requested = urlparse(requested_resource) configured = urlparse(configured_resource) - # Compare scheme, host, and port (origin) if requested.scheme.lower() != configured.scheme.lower() or requested.netloc.lower() != configured.netloc.lower(): return False - # Normalize trailing slashes before comparison so that - # "/foo" and "/foo/" are treated as equivalent. + # Normalize trailing slashes so "/foo" == "/foo/" and "/api123/" can't prefix-match "/api/". requested_path = requested.path configured_path = configured.path if not requested_path.endswith("/"): @@ -60,21 +40,14 @@ def check_resource_allowed(requested_resource: str, configured_resource: str) -> if not configured_path.endswith("/"): configured_path += "/" - # Check hierarchical match: requested must start with configured path. - # The trailing-slash normalization ensures "/api123/" won't match "/api/". return requested_path.startswith(configured_path) def calculate_token_expiry(expires_in: int | str | None) -> float | None: - """Calculate token expiry timestamp from expires_in seconds. - - Args: - expires_in: Seconds until token expiration (may be string from some servers) + """Calculate the Unix expiry timestamp from `expires_in` seconds, or None if not specified. - Returns: - Unix timestamp when token expires, or None if no expiry specified + Accepts strings because some servers return `expires_in` as a string. """ if expires_in is None: return None # pragma: no cover - # Defensive: handle servers that return expires_in as string return time.time() + int(expires_in) diff --git a/src/mcp/shared/context.py b/src/mcp/shared/context.py index 13c145be5b..1498f4088c 100644 --- a/src/mcp/shared/context.py +++ b/src/mcp/shared/context.py @@ -1,12 +1,7 @@ -"""`BaseContext` - the user-facing per-request context. +"""`BaseContext` - the user-facing per-request context, composing a `DispatchContext`. -Composition over a `DispatchContext`: forwards the transport metadata, the -back-channel (`send_raw_request`/`notify`), progress reporting, and the cancel -event. Adds `meta` (the inbound request's `_meta` field). - -Satisfies `Outbound`, so `ClientPeer` can wrap it. Shared between client and -server: the server's `Context` extends this with `lifespan`/`connection`; -`ClientContext` is just an alias. +Satisfies `Outbound`, so `ClientPeer` can wrap it. Shared between client and server: +the server's `Context` adds `lifespan`/`connection`; `ClientContext` is just an alias. """ from collections.abc import Mapping @@ -25,11 +20,7 @@ class BaseContext(Generic[TransportT]): - """Per-request context wrapping a `DispatchContext`. - - `ServerRunner` constructs one per inbound request and passes it to the - user's handler. - """ + """Per-request context wrapping a `DispatchContext`; constructed by `ServerRunner` per inbound request.""" def __init__(self, dctx: DispatchContext[TransportT], meta: RequestParamsMeta | None = None) -> None: self._dctx = dctx @@ -49,8 +40,7 @@ def cancel_requested(self) -> anyio.Event: def can_send_request(self) -> bool: """Whether the back-channel can currently deliver server-initiated requests. - `False` when the transport has no back-channel, or when the underlying - dispatch context has been closed because the inbound request finished. + `False` when the transport has no back-channel or the inbound request has finished. """ return self._dctx.can_send_request @@ -78,8 +68,5 @@ async def notify(self, method: str, params: Mapping[str, Any] | None, opts: Call await self._dctx.notify(method, params, opts) async def report_progress(self, progress: float, total: float | None = None, message: str | None = None) -> None: - """Report progress for this request, if the peer supplied a progress token. - - A no-op when no token was supplied. - """ + """Report progress for this request; a no-op when the peer supplied no progress token.""" await self._dctx.progress(progress, total, message) diff --git a/src/mcp/shared/direct_dispatcher.py b/src/mcp/shared/direct_dispatcher.py index fd3e69d493..8673dc1c4a 100644 --- a/src/mcp/shared/direct_dispatcher.py +++ b/src/mcp/shared/direct_dispatcher.py @@ -1,18 +1,8 @@ """In-memory `Dispatcher` that wires two peers together with no transport. -`DirectDispatcher` is the simplest possible `Dispatcher` implementation: a -request on one side directly invokes the other side's `on_request`. There is no -serialization, no JSON-RPC framing, and no streams. It exists to: - -* prove the `Dispatcher` Protocol is implementable without JSON-RPC -* provide a fast substrate for testing the layers above the dispatcher - (`ServerRunner`, `Context`, `Connection`) without wire-level moving parts -* embed a server in-process when the JSON-RPC overhead is unnecessary - -Like `JSONRPCDispatcher`, this is an exception-to-error boundary: a handler -exception surfaces to the caller as `MCPError`. The `raise_handler_exceptions` -knob controls whether unmapped exceptions are sanitized (matching the wire -path) or chained as ``__cause__`` for in-process debugging. +A request on one side directly invokes the other side's `on_request` β€” no +serialization, no JSON-RPC framing, no streams. A fast substrate for testing +the layers above the dispatcher and for embedding a server in-process. """ from __future__ import annotations @@ -46,11 +36,7 @@ @dataclass class _DirectDispatchContext: - """`DispatchContext` for an inbound request on a `DirectDispatcher`. - - The back-channel callables target the *originating* side, so a handler's - `send_raw_request` reaches the peer that made the inbound request. - """ + """`DispatchContext` for an inbound request; back-channel callables target the originating peer.""" transport: TransportContext _back_request: _Request @@ -87,16 +73,11 @@ async def progress(self, progress: float, total: float | None = None, message: s class DirectDispatcher: """A `Dispatcher` that calls a peer's handlers directly, in-process. - Two instances are wired together with `create_direct_dispatcher_pair`; each - holds a reference to the other. `send_raw_request` on one awaits the peer's - `on_request`. `run` parks until `close` is called. - + Two instances are wired together with `create_direct_dispatcher_pair`. Lifecycle mirrors `JSONRPCDispatcher`: `send_raw_request` requires `run()` - to have started, and once a side has closed - via `close()` or `run()` - ending - `send_raw_request` raises `MCPError` (`CONNECTION_CLOSED`) and - inbound requests fail the peer's call the same way instead of invoking the - handler. Notifications are fire-and-forget in both directions: after close - they are silently dropped. + to have started and raises `MCPError` (`CONNECTION_CLOSED`) once either + side has closed; notifications are fire-and-forget and silently dropped + after close. """ def __init__(self, transport_ctx: TransportContext, *, raise_handler_exceptions: bool = True): @@ -123,14 +104,11 @@ async def send_raw_request( """Send a request by invoking the peer's `on_request` directly. Raises: - MCPError: The peer's handler raised; `REQUEST_TIMEOUT` if - `opts["timeout"]` elapsed; `CONNECTION_CLOSED` if either - side has closed. + MCPError: The handler raised; `REQUEST_TIMEOUT` on timeout; `CONNECTION_CLOSED` after close. RuntimeError: Called before `run()`. """ if self._peer is None: raise RuntimeError("DirectDispatcher has no peer; use create_direct_dispatcher_pair()") - # Post-close sends get the same CONNECTION_CLOSED contract as JSONRPCDispatcher. if self._closed: raise MCPError(code=CONNECTION_CLOSED, message="Connection closed") if not self._running: @@ -140,10 +118,8 @@ async def send_raw_request( async def notify(self, method: str, params: Mapping[str, Any] | None, opts: CallOptions | None = None) -> None: """Send a notification by invoking the peer's `on_notify` directly. - Fire-and-forget: usable before `run()` (delivery waits for the peer to - start), and after close it is silently dropped, matching - `JSONRPCDispatcher.notify`. `opts` is accepted for `Dispatcher` - conformance; there is no HTTP layer here so `headers` is ignored. + Fire-and-forget: delivery waits for the peer's `run()`, and after close + it is silently dropped. `opts` is accepted for `Dispatcher` conformance only. """ if self._peer is None: raise RuntimeError("DirectDispatcher has no peer; use create_direct_dispatcher_pair()") @@ -159,11 +135,7 @@ async def run( *, task_status: anyio.abc.TaskStatus[None] = anyio.TASK_STATUS_IGNORED, ) -> None: - """Mark this side ready and park until `close()` is called. - - Single-shot, like `JSONRPCDispatcher.run`: once it returns the - dispatcher stays closed and cannot be restarted. - """ + """Mark this side ready and park until `close()`; single-shot like `JSONRPCDispatcher.run`.""" try: self._on_request = on_request self._on_notify = on_notify @@ -174,9 +146,7 @@ async def run( finally: self._running = False self._closed = True - # run() may end via cancellation without close() ever being - # called; setting the event wakes `_wait_ready` waiters so they - # observe the closed state instead of parking forever. + # Cancellation can end run() without close(); set the event so `_wait_ready` waiters see closed. self._close_event.set() def close(self) -> None: @@ -197,11 +167,7 @@ def _make_context( ) async def _wait_ready(self) -> None: - """Park until `run()` has started, waking early if this side closes. - - Raises: - MCPError: `CONNECTION_CLOSED` if this side has closed. - """ + """Park until `run()` has started; raises `MCPError` (`CONNECTION_CLOSED`) if this side closes.""" if not self._ready.is_set() and not self._close_event.is_set(): async with anyio.create_task_group() as tg: @@ -223,8 +189,7 @@ async def _dispatch_request( opts = opts or {} try: with anyio.fail_after(opts.get("timeout")): - # Inside the timeout scope, so a configured timeout also bounds - # waiting on a peer whose run() has not started yet. + # Inside the timeout scope, so the timeout also bounds waiting for a peer whose run() hasn't started. await self._wait_ready() assert self._on_request is not None # Synthesize an id: the DispatchContext contract reserves None for notifications. @@ -235,14 +200,11 @@ async def _dispatch_request( except MCPError: raise except ValidationError as e: - # Same shape JSONRPCDispatcher writes, so runner-over-direct - # tests see what runner-over-JSONRPC would. + # Same shape JSONRPCDispatcher writes, so runner-over-direct tests match runner-over-JSONRPC. raise MCPError(code=INVALID_PARAMS, message="Invalid request parameters", data="") from e except Exception as e: - # Single owner of the in-proc exception-to-error policy (mirrors - # JSONRPCDispatcher / `_streamable_http_modern._to_jsonrpc_response` - # for the wire paths). True chains the original for in-process - # debugging; False sanitizes to match the wire path's leak guard. + # True chains the original for in-process debugging; False sanitizes + # to match the wire path's leak guard (JSONRPCDispatcher). if self._raise_handler_exceptions: raise MCPError(code=INTERNAL_ERROR, message=str(e)) from e logger.exception("request handler raised") @@ -259,8 +221,7 @@ async def _dispatch_notify(self, method: str, params: Mapping[str, Any] | None) try: await self._wait_ready() except MCPError: - # Notifications are fire-and-forget: a notify to a closed peer is - # dropped, not raised back into the sender's call. + # Fire-and-forget: a notify to a closed peer is dropped, not raised back to the sender. logger.debug("dropped notification %r to closed DirectDispatcher", method) return assert self._on_notify is not None @@ -277,18 +238,13 @@ def create_direct_dispatcher_pair( """Create two `DirectDispatcher` instances wired to each other. Args: - can_send_request: Sets `TransportContext.can_send_request` on both - sides. Pass `False` to simulate a transport with no back-channel. - headers: Sets `TransportContext.headers` on both sides. - raise_handler_exceptions: When `True` (the default - this is an - in-process debugging substrate), an unmapped handler exception - reaches the caller as `MCPError` with the original chained as - ``__cause__``. When `False` it is sanitized to an opaque - `INTERNAL_ERROR` so the in-process path matches the wire. + can_send_request: Pass `False` to simulate a transport with no back-channel. + raise_handler_exceptions: When `True` (default), an unmapped handler exception + reaches the caller as `MCPError` with the original chained as `__cause__`; + when `False` it is sanitized to an opaque `INTERNAL_ERROR`, matching the wire path. Returns: - A `(client, server)` pair. The wiring is symmetric, so the roles - are conventional only. + A `(client, server)` pair; the wiring is symmetric, so the roles are conventional only. """ ctx = TransportContext(kind=DIRECT_TRANSPORT_KIND, can_send_request=can_send_request, headers=headers) client = DirectDispatcher(ctx, raise_handler_exceptions=raise_handler_exceptions) diff --git a/src/mcp/shared/dispatcher.py b/src/mcp/shared/dispatcher.py index de83189f13..d56dade230 100644 --- a/src/mcp/shared/dispatcher.py +++ b/src/mcp/shared/dispatcher.py @@ -1,19 +1,10 @@ """Dispatcher Protocol - the call/return boundary between transports and handlers. -A Dispatcher turns a duplex message channel into two things: - -* an outbound API: `send_raw_request(method, params)` and `notify(method, params)` -* an inbound pump: `run(on_request, on_notify)` that drives the receive loop - and invokes the supplied handlers for each incoming request/notification - -It is deliberately *not* MCP-aware. Method names are strings, params and -results are `dict[str, Any]`. The MCP type layer (request/result models, -capability negotiation, `Context`) sits above this; the wire encoding -(JSON-RPC, gRPC, in-process direct calls) sits below it. - -See `JSONRPCDispatcher` for the production implementation and -`DirectDispatcher` for an in-memory implementation used in tests and for -embedding a server in-process. +A Dispatcher turns a duplex message channel into an outbound API +(`send_raw_request`, `notify`) and an inbound pump (`run`). It is deliberately +not MCP-aware: methods are strings, params and results are dicts; the MCP type +layer sits above, the wire encoding (JSON-RPC, in-process) below. See +`JSONRPCDispatcher` (production) and `DirectDispatcher` (in-memory). """ from collections.abc import Awaitable, Callable, Mapping @@ -48,7 +39,7 @@ async def __call__(self, progress: float, total: float | None, message: str | No class CallOptions(TypedDict, total=False): """Per-call options for `Outbound.send_raw_request`. - All keys are optional. Dispatchers ignore keys they do not understand. + Dispatchers ignore keys they do not understand. """ timeout: float @@ -67,21 +58,16 @@ class CallOptions(TypedDict, total=False): resumption_token: str """Opaque token to resume a previously interrupted request. - Client-side, streamable-HTTP only. Ignored by server dispatchers and other - transports, and also ignored (with a debug log) for requests sent from a - `DispatchContext`, where routing onto the inbound request's stream takes - precedence. Supports protocol version 2025-11-25 and earlier; SSE-stream - resumption is removed in the next protocol revision. + Client-side, streamable-HTTP only. Ignored (with a debug log) for requests + sent from a `DispatchContext`, where routing onto the inbound request's + stream takes precedence. Protocol version 2025-11-25 and earlier; + SSE-stream resumption is removed in the next protocol revision. """ on_resumption_token: Callable[[str], Awaitable[None]] """Receive a resumption token when the transport issues one for this request. - Client-side, streamable-HTTP only. Ignored by server dispatchers and other - transports, and also ignored (with a debug log) for requests sent from a - `DispatchContext`, where routing onto the inbound request's stream takes - precedence. Supports protocol version 2025-11-25 and earlier; SSE-stream - resumption is removed in the next protocol revision. + Same scope and caveats as `resumption_token`. """ headers: dict[str, str] @@ -90,13 +76,7 @@ class CallOptions(TypedDict, total=False): @runtime_checkable class Outbound(Protocol): - """Anything that can send requests and notifications to the peer. - - Both `Dispatcher` (top-level outbound) and `DispatchContext` (back-channel - during an inbound request) extend this. The MCP type layer (`ClientPeer`, - `Connection`) builds typed `send_request` / convenience methods on top of - this raw channel. - """ + """Anything that can send requests and notifications to the peer.""" async def send_raw_request( self, @@ -107,9 +87,8 @@ async def send_raw_request( """Send a request and await its raw result dict. Raises: - MCPError: If the peer responded with an error, or the handler - raised. Implementations normalize all handler exceptions to - `MCPError` so callers see a single exception type. + MCPError: If the peer responded with an error or the handler + raised; implementations normalize all handler exceptions to `MCPError`. """ ... @@ -119,13 +98,7 @@ async def notify(self, method: str, params: Mapping[str, Any] | None, opts: Call class DispatchContext(Outbound, Protocol[TransportT_co]): - """Per-request context handed to `on_request` / `on_notify`. - - Carries the transport metadata for the inbound message and provides the - back-channel for sending requests/notifications to the peer while handling - it. `send_raw_request` raises `NoBackChannelError` if `can_send_request` - is `False`. - """ + """Per-request context handed to `on_request` / `on_notify`: transport metadata plus the back-channel.""" @property def transport(self) -> TransportT_co: @@ -136,8 +109,8 @@ def transport(self) -> TransportT_co: def can_send_request(self) -> bool: """Whether the back-channel can currently deliver server-initiated requests. - `False` when the transport has no back-channel, or when this context has - been closed (the inbound request finished). `send_raw_request` raises + `False` when the transport has no back-channel or this context has closed + (the inbound request finished); `send_raw_request` raises `NoBackChannelError` exactly when this is `False`. """ ... @@ -146,9 +119,8 @@ def can_send_request(self) -> bool: def request_id(self) -> RequestId | None: """The id of the inbound request, or `None` for a notification. - For JSON-RPC this is the wire `id` field. Handlers thread it through - as `related_request_id` on outbound notifications so HTTP transports - can route them onto the originating request's response stream. + Threaded through as `related_request_id` on outbound notifications so + HTTP transports can route them onto the originating request's stream. """ ... @@ -156,11 +128,9 @@ def request_id(self) -> RequestId | None: def message_metadata(self) -> MessageMetadata: """The metadata the transport attached to this inbound message, if any. - This is `SessionMessage.metadata` passed through verbatim: HTTP - transports attach `ServerMessageMetadata` (the HTTP request, SSE - stream-close callbacks); stdio and in-memory dispatch attach nothing. - Tied to the `SessionMessage` wire format - goes away when transports - stop delivering messages that way. + `SessionMessage.metadata` passed through verbatim: HTTP transports + attach `ServerMessageMetadata`, stdio and in-memory dispatch attach + nothing. Goes away when transports stop delivering `SessionMessage`s. """ # TODO(maxisbey): remove for context rework ... @@ -171,10 +141,7 @@ def cancel_requested(self) -> anyio.Event: ... async def progress(self, progress: float, total: float | None = None, message: str | None = None) -> None: - """Report progress for the inbound request, if the peer supplied a progress token. - - A no-op when no token was supplied. - """ + """Report progress for the inbound request; a no-op when the peer supplied no progress token.""" ... @@ -188,10 +155,9 @@ async def progress(self, progress: float, total: float | None = None, message: s class Dispatcher(Outbound, Protocol[TransportT_co]): """A duplex request/notification channel with call-return semantics. - Implementations own correlation of outbound requests to inbound results, the - receive loop, per-request concurrency, and cancellation/progress wiring. - - The lifecycle surface is provisional; `run()` may change before v2 stable. + Implementations own request/result correlation, the receive loop, + per-request concurrency, and cancellation/progress wiring. The lifecycle + surface is provisional; `run()` may change before v2 stable. """ async def run( @@ -203,12 +169,9 @@ async def run( ) -> None: """Drive the receive loop until the underlying channel closes. - Each inbound request is dispatched to `on_request` in its own task; - the returned dict (or raised `MCPError`) is sent back as the response. - Inbound notifications go to `on_notify`. - - `task_status.started()` is called once the dispatcher is ready to - accept `send_request`/`notify` calls, so callers can use - `await tg.start(dispatcher.run, on_request, on_notify)`. + Each inbound request is dispatched to `on_request` in its own task; the + returned dict (or raised `MCPError`) is sent back as the response. + `task_status.started()` fires once the dispatcher accepts outbound + calls, so callers can use `await tg.start(dispatcher.run, ...)`. """ ... diff --git a/src/mcp/shared/exceptions.py b/src/mcp/shared/exceptions.py index 2f8a539dab..a751767841 100644 --- a/src/mcp/shared/exceptions.py +++ b/src/mcp/shared/exceptions.py @@ -6,12 +6,9 @@ class MCPDeprecationWarning(UserWarning): - """A custom deprecation warning for the MCP SDK. - - Unlike the built-in `DeprecationWarning`, this inherits from `UserWarning` so - it is shown by default, helping users discover deprecated features without - enabling warnings explicitly. + """Deprecation warning for the MCP SDK. + Inherits from `UserWarning` rather than `DeprecationWarning` so it is shown by default. Reference: https://sethmlarson.dev/deprecations-via-warnings-dont-work-for-python-libraries """ @@ -55,10 +52,8 @@ def __str__(self) -> str: class NoBackChannelError(MCPError): """Raised when sending a server-initiated request over a transport that cannot deliver it. - Stateless HTTP and JSON-response-mode HTTP have no channel for the server to - push requests (sampling, elicitation, roots/list) to the client. This is - raised by `DispatchContext.send_raw_request` when `can_send_request` is - `False`, and serializes to an `INVALID_REQUEST` error response. + Stateless and JSON-response-mode HTTP cannot push server requests (sampling, + elicitation, roots/list) to the client; serializes to an `INVALID_REQUEST` error. """ def __init__(self, method: str): @@ -72,25 +67,12 @@ def __init__(self, method: str): class UrlElicitationRequiredError(MCPError): - """Specialized error for when a tool requires URL mode elicitation(s) before proceeding. - - Servers can raise this error from tool handlers to indicate that the client - must complete one or more URL elicitations before the request can be processed. - - Example: - ```python - raise UrlElicitationRequiredError([ - ElicitRequestURLParams( - message="Authorization required for your files", - url="https://example.com/oauth/authorize", - elicitation_id="auth-001" - ) - ]) - ``` + """Raised by tool handlers when the client must complete URL elicitation(s) before proceeding. + + Serializes to a `URL_ELICITATION_REQUIRED` error with the elicitations in `data`. """ def __init__(self, elicitations: list[ElicitRequestURLParams], message: str | None = None): - """Initialize UrlElicitationRequiredError.""" if message is None: message = f"URL elicitation{'s' if len(elicitations) > 1 else ''} required" diff --git a/src/mcp/shared/inbound.py b/src/mcp/shared/inbound.py index 3eb16495ee..b8f175b987 100644 --- a/src/mcp/shared/inbound.py +++ b/src/mcp/shared/inbound.py @@ -1,14 +1,10 @@ """Inbound request classification for the modern per-request-envelope path. -Pure module: no I/O, no transport, no `mcp.server` imports. Runs the -validation ladder against a decoded JSON-RPC body and returns either an -:class:`InboundModernRoute` (every rung passed) or an -:class:`InboundLadderRejection` (the first rung that failed). Callers map a -rejection's `code` through :data:`ERROR_CODE_HTTP_STATUS` to pick the HTTP -status. - -Also hosts the shared header-value codec and the `x-mcp-header` schema -validator so client emit and server validate read the same source of truth. +Pure module (no I/O, no transport, no `mcp.server` imports): the validation ladder +returns `InboundModernRoute` (every rung passed) or `InboundLadderRejection` (the +first failed rung), whose `code` maps through `ERROR_CODE_HTTP_STATUS` to an HTTP +status. Also hosts the header-value codec and `x-mcp-header` schema validator +shared by client emit and server validate. """ import base64 @@ -73,28 +69,21 @@ "resources/read": "uri", } ) -"""Method β†’ params key whose value is mirrored as the `Mcp-Name` HTTP header. - -Shared by client emit (which header to send) and server validate (which body -field to compare against), so both ends agree on the field by construction. -""" +"""Method β†’ params key mirrored as the `Mcp-Name` HTTP header; shared by client emit and server validate.""" _B64_SENTINEL = re.compile(r"^=\?base64\?(?P.*)\?=$") # RFC 7230 token chars minus DEL; visible ASCII 0x20-0x7E is the practical bound for a header value. _HEADER_SAFE = re.compile(r"^[\x20-\x7E]*$") # RFC 9110 Β§5.6.2 token: the only characters permitted in an HTTP field name. _RFC9110_TOKEN = re.compile(r"^[!#$%&'*+\-.^_`|~0-9A-Za-z]+$") -# JSON-Schema types the spec permits to carry `x-mcp-header` (transports.mdx -# Β§Custom Headers). `number` is explicitly forbidden β€” floatβ†’str is not -# portable across implementations. +# Types the spec permits to carry `x-mcp-header` (transports.mdx Β§Custom Headers). +# `number` is explicitly forbidden β€” floatβ†’str is not portable across implementations. _X_MCP_HEADER_PRIMITIVE_TYPES: Final = frozenset({"string", "integer", "boolean"}) -# JSON Schema 2020-12 applicator keywords whose values are themselves schema -# positions, grouped by value shape. `properties` is handled separately as the -# only keyword that preserves the statically-reachable chain; every keyword -# here drops the chain to None. Instance-data keywords (`default`, `examples`, -# `const`, `enum`) and `$ref`/`$dynamicRef` are deliberately absent so the -# walk never mistakes data for an annotation and never dereferences. +# JSON Schema 2020-12 applicator keywords grouped by value shape; `properties` alone preserves +# the statically-reachable chain. Instance-data keywords (`default`, `examples`, `const`, `enum`) +# and `$ref`/`$dynamicRef` are deliberately absent so the walk never mistakes data for an +# annotation and never dereferences. _SUBSCHEMA_SINGLE: Final = frozenset( { "items", @@ -117,12 +106,10 @@ def _walk_schema_positions(root: Any) -> Iterator[tuple[tuple[str, ...] | None, dict[str, Any]]]: """Yield `(properties_path, schema)` for every schema position in `root`. - `properties_path` is the chain of `properties` keys from the root to the - position, or `None` once any other applicator keyword has been crossed. - The root itself yields `()`. Only the JSON Schema 2020-12 applicators - listed above are entered; instance-data keywords are not, and `$ref` is - not dereferenced, so the walk terminates on any finite JSON value. An - explicit stack keeps the function total even on pathologically deep input. + `properties_path` is the chain of `properties` keys from the root (itself `()`), + or `None` once any other applicator keyword is crossed. `$ref` is never + dereferenced and the stack is explicit, so the walk terminates on any finite + JSON value, however deep. """ stack: list[tuple[tuple[str, ...] | None, Any]] = [((), root)] while stack: @@ -146,10 +133,8 @@ def _walk_schema_positions(root: Any) -> Iterator[tuple[tuple[str, ...] | None, def encode_header_value(value: str) -> str: """Wrap `value` in the `=?base64?...?=` sentinel when it would not survive an HTTP field round-trip. - Plain printable ASCII without leading/trailing whitespace passes verbatim; - anything else (control chars, non-ASCII, edge whitespace, or a value that - already looks like the sentinel) is base64-wrapped so the receiver can - recover the exact bytes. + Printable ASCII without edge whitespace passes verbatim; control chars, non-ASCII, + edge whitespace, or a value already shaped like the sentinel is base64-wrapped. """ if _HEADER_SAFE.fullmatch(value) and value == value.strip() and not _B64_SENTINEL.fullmatch(value): return value @@ -157,13 +142,10 @@ def encode_header_value(value: str) -> str: def decode_header_value(value: str | None) -> str | None: - """Inverse of :func:`encode_header_value`. + """Inverse of `encode_header_value`; `None` in β†’ `None` out. - Returns the value verbatim unless it carries the `=?base64?...?=` sentinel, - in which case the payload is decoded as UTF-8. A malformed sentinel (bad - base64 or bad UTF-8) yields `None` so a corrupt header never matches a body - value by accident. `None` in β†’ `None` out so callers can pass - `headers.get(...)` directly. + A malformed sentinel (bad base64 or bad UTF-8) yields `None` so a corrupt + header never matches a body value by accident. """ if value is None: return None @@ -179,12 +161,9 @@ def decode_header_value(value: str | None) -> str | None: def find_invalid_x_mcp_header(input_schema: Any) -> str | None: """Return a reason string if any `x-mcp-header` annotation in `input_schema` is invalid; else `None`. - Walks every JSON Schema 2020-12 schema position. An annotation is valid - only when it sits on a property statically reachable from the root via a - chain of pure `properties` keys, names a non-empty RFC 9110 token, is on - an integer/string/boolean property, and is case-insensitively unique - across the whole schema. A `None` / non-mapping schema has no schema - positions and returns `None`. + Valid annotations sit on a property reachable from the root via pure `properties` + keys, name an RFC 9110 token, annotate an integer/string/boolean property, and + are case-insensitively unique across the whole schema. """ seen: dict[str, str] = {} for path, schema in _walk_schema_positions(input_schema): @@ -194,9 +173,8 @@ def find_invalid_x_mcp_header(input_schema: Any) -> str | None: return f"{X_MCP_HEADER_KEY} found at a schema position not reachable via a pure `properties` chain" where = ".".join(path) header = schema[X_MCP_HEADER_KEY] - # Wrong type and malformed value are distinct failures with distinct messages: the - # non-str arm returns before any interpolation, because `repr` of an arbitrary - # schema value is not total (a large `int` exceeds `sys.get_int_max_str_digits`). + # The non-str arm returns before any interpolation: `repr` of an arbitrary schema + # value is not total (a large `int` exceeds `sys.get_int_max_str_digits`). if not isinstance(header, str): return f"property {where!r}: {X_MCP_HEADER_KEY} must be a string, not {type(header).__name__}" if not _RFC9110_TOKEN.fullmatch(header): @@ -224,13 +202,10 @@ def find_invalid_x_mcp_header(input_schema: Any) -> str | None: def x_mcp_header_map(input_schema: Any) -> dict[tuple[str, ...], str]: - """Map each property carrying a valid `x-mcp-header` to its annotation token, keyed by property path. + """Map each property carrying a valid `x-mcp-header` to its token, keyed by `properties`-key path. - The key is the chain of `properties` keys from the schema root to the - annotated property; a top-level property has a one-element path, a nested - one a longer path. Call only on a schema that - :func:`find_invalid_x_mcp_header` accepts; an invalid schema yields an - undefined subset. + Call only on a schema `find_invalid_x_mcp_header` accepts; an invalid schema + yields an undefined subset. """ mapping: dict[tuple[str, ...], str] = {} for path, schema in _walk_schema_positions(input_schema): @@ -242,12 +217,9 @@ def x_mcp_header_map(input_schema: Any) -> dict[tuple[str, ...], str]: def mcp_param_headers(header_map: Mapping[tuple[str, ...], str], arguments: Mapping[str, Any]) -> dict[str, str]: """Build the `Mcp-Param-*` headers a `tools/call` mirrors from its arguments. - For each `(path, token)` in `header_map`, read the value at that property - path in `arguments` and, when it is present and not `None`, emit - `Mcp-Param-` carrying it: `bool` as `true`/`false`, other scalars via - `str`, each passed through :func:`encode_header_value` so a non-token value - is base64-wrapped. A path that hits a missing key or a non-mapping node is - skipped, matching the spec's "omit the header when no value is present". + `bool` renders as `true`/`false`, other scalars via `str`, each passed through + `encode_header_value`. A missing or `None` value skips its header, matching the + spec's "omit the header when no value is present". """ headers: dict[str, str] = {} for path, token in header_map.items(): @@ -270,7 +242,7 @@ def _value_at_path(arguments: Mapping[str, Any], path: tuple[str, ...]) -> Any: # INTERNAL_ERROR is deliberately unmapped (β†’ HTTP 200): the spec assigns no status to -# -32603, and whether handler-origin errors get 5xx is an open S4 question β€” see TODO(L66). +# -32603, and whether handler-origin errors should get 5xx is still an open question. ERROR_CODE_HTTP_STATUS: Final[Mapping[int, int]] = MappingProxyType( { PARSE_ERROR: 400, @@ -282,11 +254,9 @@ def _value_at_path(arguments: Mapping[str, Any], path: tuple[str, ...]) -> Any: METHOD_NOT_FOUND: 404, } ) -"""HTTP status to send for a JSON-RPC `error.code`. +"""HTTP status to send for a JSON-RPC `error.code`, classifier- and handler-origin alike. -Consulted for classifier-origin *and* handler-origin errors, so one table -decides the wire status regardless of where the error was produced. Unmapped -codes fall back to the caller's default (typically 200). +Unmapped codes fall back to the caller's default (typically 200). """ @@ -294,9 +264,8 @@ def _value_at_path(arguments: Mapping[str, Any], path: tuple[str, ...]) -> Any: class InboundModernRoute: """A modern-protocol request whose envelope passed every ladder rung. - `client_info` and `client_capabilities` are the raw envelope values; - the classifier checks presence only, not shape. Method existence is not a - ladder rung β€” kernel dispatch is the single source of truth for that. + `client_info` / `client_capabilities` are the raw envelope values β€” the + classifier checks presence only, not shape. """ protocol_version: str @@ -323,29 +292,21 @@ def classify_inbound_request( Rungs, in order β€” first failure wins: - 1. `params._meta` is a mapping carrying every reserved envelope key - (protocol version, client info, client capabilities) β†’ else - :data:`~mcp_types.jsonrpc.INVALID_PARAMS`. - 2. When `headers` is given, `MCP-Protocol-Version` equals the envelope's - protocol version, `Mcp-Method` equals `body.method`, and β€” for the - methods in :data:`NAME_BEARING_METHODS` β€” `Mcp-Name` equals the named - body param β†’ else :data:`~mcp_types.jsonrpc.HEADER_MISMATCH`. Runs - before the supported-version rung so a client that disagrees with itself - is told so, rather than told the body's version is unsupported. - 3. The envelope's protocol version is in `supported_modern_versions` β†’ - else :data:`~mcp_types.jsonrpc.UNSUPPORTED_PROTOCOL_VERSION` with - `data = {"supported": [...], "requested": }`. + 1. `params._meta` carries every reserved envelope key β†’ else `INVALID_PARAMS`. + 2. When `headers` is given, `MCP-Protocol-Version`, `Mcp-Method`, and (for + `NAME_BEARING_METHODS`) `Mcp-Name` match the body β†’ else `HEADER_MISMATCH`. + Runs before rung 3 so a client that disagrees with itself is told so, + rather than told the body's version is unsupported. + 3. The envelope's protocol version is in `supported_modern_versions` β†’ else + `UNSUPPORTED_PROTOCOL_VERSION` with supported/requested in `data`. Method existence is *not* a rung: kernel dispatch owns that decision so custom-registered methods route and the answer lives in one place. Args: - body: The decoded JSON-RPC request mapping. Envelope shape - (`jsonrpc` / `id`) is not checked here. - headers: Transport headers keyed by lowercase name, or `None` to - skip the header rung (non-HTTP callers). - supported_modern_versions: Modern protocol revisions this server - accepts on the per-request-envelope path. + body: Decoded JSON-RPC request mapping; `jsonrpc`/`id` shape is not checked. + headers: Transport headers keyed by lowercase name, or `None` to skip + rung 2 (non-HTTP callers). """ try: meta = body["params"]["_meta"] diff --git a/src/mcp/shared/jsonrpc_dispatcher.py b/src/mcp/shared/jsonrpc_dispatcher.py index 64fcd3298d..9dc7a766b6 100644 --- a/src/mcp/shared/jsonrpc_dispatcher.py +++ b/src/mcp/shared/jsonrpc_dispatcher.py @@ -1,8 +1,7 @@ """JSON-RPC `Dispatcher` over the `SessionMessage` stream contract all transports speak. -Owns request-id correlation, the receive loop, per-request task isolation, -cancellation/progress wiring, and the single exception-to-wire boundary; -methods and params are otherwise opaque strings and dicts. +Owns request-id correlation, the receive loop, per-request task isolation, cancellation/progress +wiring, and the single exception-to-wire boundary; methods and params stay opaque. """ from __future__ import annotations @@ -70,12 +69,10 @@ def handler_exception_to_error_data(exc: BaseException) -> ErrorData | None: """Map a handler-raised exception to its wire `ErrorData`. - The two rungs every dispatcher shares: an `MCPError` carries its own - `ErrorData`; a pydantic `ValidationError` is the spec's INVALID_PARAMS - with empty ``data`` (no pydantic text on the wire). Returns ``None`` for - any other exception so each caller applies its own catch-all - - `JSONRPCDispatcher` currently pins ``code=0`` for v1 compat, - the modern HTTP entry uses `INTERNAL_ERROR`. + An `MCPError` carries its own `ErrorData`; a pydantic `ValidationError` is the + spec's INVALID_PARAMS with empty `data` (no pydantic text on the wire). Returns + None for anything else so each caller applies its own catch-all + (`JSONRPCDispatcher` pins `code=0` for v1 compat; the HTTP entry uses INTERNAL_ERROR). """ if isinstance(exc, MCPError): return exc.error @@ -210,9 +207,9 @@ class _OutboundPlan: def _plan_outbound(related_request_id: RequestId | None, opts: CallOptions | None) -> _OutboundPlan: """Choose the outbound `SessionMessage.metadata` and the abandon-cancellation policy. - `related_request_id` wins over resumption hints (they are dropped). Only - hints that actually reach the transport suppress the courtesy cancel - a - request that is neither resumable nor cancelled would leak the peer's work. + `related_request_id` wins over resumption hints (dropped). Only hints that reach + the transport suppress the courtesy cancel - a request that is neither resumable + nor cancelled would leak the peer's work. """ opts = opts or {} cancel_on_abandon = opts.get("cancel_on_abandon", True) @@ -255,16 +252,15 @@ def __init__( """Wire a dispatcher over a transport's `SessionMessage` stream pair. Args: - transport_builder: Builds each message's `TransportContext` from - its `SessionMessage.metadata`. - raise_handler_exceptions: Re-raise handler exceptions out of - `run()` after the error response is written. - inline_methods: Methods awaited in the read loop before the next - message is dequeued (e.g. `initialize`); an inline handler - that awaits the peer deadlocks the parked loop. - on_stream_exception: Observer for `Exception` items on the read - stream; without it they are debug-logged and dropped. Awaited - inline in the read loop, so a slow observer stalls dispatch. + transport_builder: Called per inbound message to build its `TransportContext`. + raise_handler_exceptions: Re-raise handler exceptions out of `run()` + after the error response is written. + inline_methods: Methods awaited in the read loop before the next message + is dequeued (e.g. `initialize`); an inline handler that awaits the + peer deadlocks the parked loop. + on_stream_exception: Observer for `Exception` items on the read stream; + without it they are debug-logged and dropped. Awaited inline, so a + slow observer stalls dispatch. """ self._read_stream = read_stream self._write_stream = write_stream @@ -278,9 +274,11 @@ def __init__( self._raise_handler_exceptions = raise_handler_exceptions self._inline_methods = inline_methods self.on_stream_exception = on_stream_exception - """Observer for ``Exception`` items on the read stream. Mutable so a session can - bind it after the dispatcher is built (e.g. ``ClientSession`` routing into - ``message_handler``); only consulted inside ``run()`` so pre-enter assignment is safe.""" + """Mutable so a session can bind it after construction. + + E.g. `ClientSession` routes into `message_handler`; only consulted inside + `run()`, so pre-enter assignment is safe. + """ self._next_id = 0 self._pending: dict[RequestId, _Pending] = {} @@ -299,13 +297,12 @@ async def send_raw_request( ) -> dict[str, Any]: """Send a JSON-RPC request and await its response. - `_related_request_id` is set only by `_JSONRPCDispatchContext` so that - mid-handler requests route onto the inbound request's SSE stream. + `_related_request_id` is set only by `_JSONRPCDispatchContext` so mid-handler + requests route onto the inbound request's SSE stream. Raises: - MCPError: Peer error response; `REQUEST_TIMEOUT` if - `opts["timeout"]` elapsed; `CONNECTION_CLOSED` if the - transport closed or the dispatcher shut down. + MCPError: Peer error response; `REQUEST_TIMEOUT` if `opts["timeout"]` + elapsed; `CONNECTION_CLOSED` if the transport closed or the dispatcher shut down. RuntimeError: Called before `run()`. """ # Post-close sends get the same CONNECTION_CLOSED contract as in-flight waiters. @@ -330,11 +327,10 @@ async def send_raw_request( self._pending[request_id] = pending plan = _plan_outbound(_related_request_id, opts) - # Spec MUST: only previously-issued requests may be cancelled. A write - # interrupted by cancellation may still have delivered (a memory-stream - # send can hand its item to the receiver and still raise), so a started - # write counts as issued: the peer ignores a cancel for an id it never - # saw, while skipping it would leak a delivered request's handler. + # Spec MUST: only previously-issued requests may be cancelled. A cancelled write + # may still have delivered (a memory-stream send can hand over its item and still + # raise), so a started write counts as issued: the peer ignores a cancel for an + # unseen id, while skipping it would leak a delivered request's handler. request_write_started = False timeout_armed = False @@ -365,10 +361,9 @@ async def send_raw_request( outcome = await receive.receive() except TimeoutError: if not timeout_armed: - # `fail_after` arms only after the write, so this TimeoutError is the - # transport's own bounded send() failing - a transport error, not - # `opts["timeout"]` elapsing. Propagate it raw (v1 kept the write - # outside the timeout-catching try and did the same). + # `fail_after` arms only after the write, so this is the transport's own + # bounded send() failing, not `opts["timeout"]` elapsing - propagate raw + # (v1 did the same by keeping the write outside this try). raise # Courtesy cancel (spec-recommended, new vs v1) so the peer stops work; # unshielded so an outer caller cancellation can still interrupt the write. @@ -386,8 +381,7 @@ async def send_raw_request( ) raise MCPError(code=REQUEST_TIMEOUT, message=f"Request {method!r} timed out") from None except anyio.get_cancelled_exc_class(): - # Caller cancelled: bare awaits re-raise here, so the shielded helper - # lets the courtesy cancel go out before we propagate. + # Bare awaits re-raise here, so the shielded helper lets the courtesy cancel out first. if plan.cancel_on_abandon and request_write_started: await self._final_write( partial(self._cancel_outbound, request_id, "caller cancelled", _related_request_id), @@ -416,9 +410,8 @@ async def notify( ) -> None: """Send a fire-and-forget notification. - Fire-and-forget all the way: a post-close send or a write onto a - torn-down transport drops the notification with a debug log instead - of raising (same policy as the response writes and `ctx.notify`). + Post-close sends and writes onto a torn-down transport are dropped with a debug + log instead of raising (same policy as the response writes and `ctx.notify`). """ if self._closed: logger.debug("dropped %s: dispatcher closed", method) @@ -432,7 +425,6 @@ async def notify( try: await self._write(msg, _plan_outbound(_related_request_id, opts).metadata) except (anyio.BrokenResourceError, anyio.ClosedResourceError): - # Transport tore down before run() noticed EOF. logger.debug("dropped %s: write stream closed", method) async def run( @@ -458,8 +450,7 @@ async def run( async with self._read_stream: try: async for item in self._read_stream: - # Duck-typed: only `ContextReceiveStream` carries the - # sender's per-message contextvars snapshot. + # Duck-typed: only `ContextReceiveStream` carries the sender's contextvars. sender_ctx: contextvars.Context | None = getattr( self._read_stream, "last_context", None ) @@ -472,8 +463,7 @@ async def run( self._closed = True self._fan_out_closed() finally: - # Cancel in-flight handlers; otherwise the task-group join - # waits on handlers whose callers are already gone. + # Cancel in-flight handlers; otherwise the join waits on handlers whose callers are gone. tg.cancel_scope.cancel() finally: # Covers cancel/crash paths that skip the inline fan-out; idempotent. @@ -574,9 +564,8 @@ def _dispatch_notification( ) -> None: """Route one inbound notification. - `notifications/cancelled` and `notifications/progress` are intercepted - here (they correlate against the `_in_flight`/`_pending` tables this - layer owns) and still teed to `on_notify` afterwards. + Cancelled/progress notifications are intercepted here (they correlate against + the `_in_flight`/`_pending` tables this layer owns), then still teed to `on_notify`. """ if msg.method == "notifications/cancelled": match msg.params: @@ -635,10 +624,10 @@ def _spawn( *args: object, sender_ctx: contextvars.Context | None, ) -> None: - """Schedule `fn(*args)` in the run() task group, propagating the sender's contextvars. + """Schedule `fn(*args)` in the run() task group under `sender_ctx`. - ASGI middleware (auth, OTel) sets contextvars on the task that wrote the - message; `Context.run` makes the spawned handler inherit that context. + Handlers inherit contextvars that ASGI middleware (auth, OTel) set on the + writing task. """ assert self._tg is not None if sender_ctx is not None: @@ -647,9 +636,9 @@ def _spawn( self._tg.start_soon(fn, *args) def _fan_out_closed(self) -> None: - """Wake every pending `send_raw_request` waiter with `CONNECTION_CLOSED`. + """Wake every pending `send_raw_request` waiter with CONNECTION_CLOSED. - Synchronous: callers may be inside a cancelled scope. Idempotent. + Idempotent, and synchronous because callers may be inside a cancelled scope. """ closed = ErrorData(code=CONNECTION_CLOSED, message="Connection closed") for pending in self._pending.values(): @@ -666,10 +655,7 @@ async def _handle_request( scope: anyio.CancelScope, on_request: OnRequest, ) -> None: - """Run `on_request` for one inbound request and write its response. - - The single exception-to-wire boundary: handler exceptions become `JSONRPCError` here. - """ + """Run `on_request` and write the response - the single exception-to-wire boundary.""" answer_write_started = False try: with scope: @@ -683,26 +669,23 @@ async def _handle_request( key = _coerce_id(req.id) if (entry := self._in_flight.get(key)) is not None and entry.dctx is dctx: del self._in_flight[key] - # A write interrupted by cancellation may still have delivered - # (a memory-stream send can hand its item to the receiver and - # still raise), so a started answer write counts as sent below: - # peers drop late responses, while a second answer for one id - # would break JSON-RPC. + # A cancelled write may still have delivered (a memory-stream send can + # deliver and still raise), so a started answer write counts as sent: + # peers drop late responses, but a second answer for one id breaks JSON-RPC. answer_write_started = True await self._write_result(req.id, result) if scope.cancelled_caught: - # anyio absorbs the scope's own cancel at __exit__, and - # `cancelled_caught` (unlike `cancel_called`) guarantees the - # result write above did not happen - no double response. - # TODO(L38): spec says SHOULD NOT respond after cancel; - # the existing server always has, so match that for now. + # anyio absorbed the scope's own cancel at __exit__; `cancelled_caught` + # (unlike `cancel_called`) guarantees the result write above did not + # happen - no double response. + # TODO(L38): spec says SHOULD NOT respond after cancel; the existing + # server always has, so match that for now. answer_write_started = True await self._write_error(req.id, ErrorData(code=0, message="Request cancelled")) except anyio.get_cancelled_exc_class(): - # Shutdown: answer the request so the peer isn't left waiting - unless - # an answer write already started (it may have reached the transport; - # prefer possibly-zero answers over possibly-two). The shielded helper - # is needed because bare awaits re-raise here. + # Shutdown: answer so the peer isn't left waiting - unless an answer write + # already started (prefer possibly-zero answers over possibly-two). Shielded + # because bare awaits re-raise here. if not answer_write_started: await self._final_write( partial(self._write_error, req.id, ErrorData(code=CONNECTION_CLOSED, message="Connection closed")), @@ -753,9 +736,8 @@ async def _final_write( ) -> None: """Attempt one last write under the shared abandon/teardown policy. - `shield=True` is for arms already inside a cancelled scope (a bare - `await` would re-raise); the bound keeps a wedged transport write - from becoming an uncancellable hang. + `shield=True` is for arms already inside a cancelled scope (a bare `await` would + re-raise); the bound keeps a wedged transport write from hanging uncancellably. """ with anyio.move_on_after(timeout, shield=shield) as scope: await write() @@ -763,9 +745,8 @@ async def _final_write( logger.warning("%s gave up: transport write blocked", describe) async def _cancel_outbound(self, request_id: RequestId, reason: str, related_request_id: RequestId | None) -> None: - # Thread `related_request_id` so streamable HTTP routes the cancel onto - # the request's own SSE stream instead of a possibly-absent GET stream. - # `notify` swallows connection-state errors itself, so no guard here. + # Thread `related_request_id` so streamable HTTP routes the cancel onto the request's + # own SSE stream, not a possibly-absent GET stream; `notify` swallows write errors itself. await self.notify( "notifications/cancelled", {"requestId": request_id, "reason": reason}, diff --git a/src/mcp/shared/memory.py b/src/mcp/shared/memory.py index 01cab77c85..b36a2614a8 100644 --- a/src/mcp/shared/memory.py +++ b/src/mcp/shared/memory.py @@ -14,13 +14,7 @@ @asynccontextmanager async def create_client_server_memory_streams() -> AsyncGenerator[tuple[MessageStream, MessageStream], None]: - """Creates a pair of bidirectional memory streams for client-server communication. - - Yields: - A tuple of (client_streams, server_streams) where each is a tuple of - (read_stream, write_stream) - """ - # Create streams for both directions + """Yield in-memory streams as ((client_read, client_write), (server_read, server_write)).""" server_to_client_send, server_to_client_receive = create_context_streams[SessionMessage | Exception](1) client_to_server_send, client_to_server_receive = create_context_streams[SessionMessage | Exception](1) diff --git a/src/mcp/shared/message.py b/src/mcp/shared/message.py index 236569fac2..239c863a78 100644 --- a/src/mcp/shared/message.py +++ b/src/mcp/shared/message.py @@ -1,8 +1,4 @@ -"""Message wrapper with metadata support. - -This module defines a wrapper type that combines JSONRPCMessage with metadata -to support transport-specific features like resumability. -""" +"""Wrapper pairing JSONRPCMessage with metadata for transport-specific features like resumability.""" from collections.abc import Awaitable, Callable from dataclasses import dataclass @@ -14,7 +10,6 @@ ResumptionTokenUpdateCallback = Callable[[ResumptionToken], Awaitable[None]] -# Callback type for closing SSE streams without terminating CloseSSEStreamCallback = Callable[[], Awaitable[None]] @@ -33,13 +28,12 @@ class ServerMessageMetadata: """Metadata specific to server messages.""" related_request_id: RequestId | None = None - # Transport-specific request context (e.g. starlette Request for HTTP - # transports, None for stdio). Typed as Any because the server layer is - # transport-agnostic. + # Transport-specific request context (e.g. starlette Request for HTTP, None for stdio). + # Typed as Any because the server layer is transport-agnostic. request_context: Any = None - # Callback to close SSE stream for the current request without terminating + # Closes the current request's SSE connection without terminating its stream (client resumes via Last-Event-ID). close_sse_stream: CloseSSEStreamCallback | None = None - # Callback to close the standalone GET SSE stream (for unsolicited notifications) + # Closes the standalone GET SSE stream (unsolicited notifications). close_standalone_sse_stream: CloseSSEStreamCallback | None = None @@ -48,7 +42,7 @@ class ServerMessageMetadata: @dataclass class SessionMessage: - """A message with specific metadata for transport-specific features.""" + """A JSON-RPC message paired with transport metadata.""" message: JSONRPCMessage metadata: MessageMetadata = None diff --git a/src/mcp/shared/metadata_utils.py b/src/mcp/shared/metadata_utils.py index b646133477..0f1ff48d44 100644 --- a/src/mcp/shared/metadata_utils.py +++ b/src/mcp/shared/metadata_utils.py @@ -1,46 +1,21 @@ -"""Utility functions for working with metadata in MCP types. - -These utilities are primarily intended for client-side usage to properly display -human-readable names in user interfaces in a spec-compliant way. -""" +"""Client-side utilities for displaying human-readable names in a spec-compliant way.""" from mcp_types import Implementation, Prompt, Resource, ResourceTemplate, Tool def get_display_name(obj: Tool | Resource | Prompt | ResourceTemplate | Implementation) -> str: - """Get the display name for an MCP object with proper precedence. - - This is a client-side utility function designed to help MCP clients display - human-readable names in their user interfaces. When servers provide a 'title' - field, it should be preferred over the programmatic 'name' field for display. - - For tools: title > annotations.title > name - For other objects: title > name - - Example: - ```python - # In a client displaying available tools - tools = await session.list_tools() - for tool in tools.tools: - display_name = get_display_name(tool) - print(f"Available tool: {display_name}") - ``` - - Args: - obj: An MCP object with name and optional title fields + """Get the display name for an MCP object for UI presentation. - Returns: - The display name to use for UI presentation + Precedence for tools: `title` > `annotations.title` > `name`. For all other + objects: `title` > `name`. """ if isinstance(obj, Tool): - # Tools have special precedence: title > annotations.title > name if hasattr(obj, "title") and obj.title is not None: return obj.title if obj.annotations and hasattr(obj.annotations, "title") and obj.annotations.title is not None: return obj.annotations.title return obj.name else: - # All other objects: title > name if hasattr(obj, "title") and obj.title is not None: return obj.title return obj.name diff --git a/src/mcp/shared/path_security.py b/src/mcp/shared/path_security.py index 0d338eacc2..7429c91cc6 100644 --- a/src/mcp/shared/path_security.py +++ b/src/mcp/shared/path_security.py @@ -1,18 +1,7 @@ """Filesystem path safety primitives for resource handlers. -These functions help MCP servers reject paths that would resolve -outside the served root when extracted URI template parameters are -used in filesystem operations. They are standalone utilities usable from both the -high-level :class:`~mcp.server.mcpserver.MCPServer` and lowlevel server -implementations. - -The canonical safe pattern:: - - from mcp.shared.path_security import safe_join - - @mcp.resource("file://docs/{+path}") - def read_doc(path: str) -> str: - return safe_join("/data/docs", path).read_text() +Helpers to reject paths that would resolve outside the served root when +extracted URI template parameters are used in filesystem operations. """ import string @@ -22,48 +11,18 @@ def read_doc(path: str) -> str: class PathEscapeError(ValueError): - """Raised by :func:`safe_join` when the resolved path escapes the base.""" + """Raised by `safe_join` when the resolved path escapes the base.""" def contains_path_traversal(value: str) -> bool: r"""Check whether a value, treated as a relative path, escapes its origin. - This is a **base-free** check: it does not know the sandbox root, so - it detects only whether ``..`` components would move above the - starting point. Use :func:`safe_join` when you know the root β€” it - additionally catches symlink escapes and absolute-path injection. - - Note: - This is a string-level check on the value as supplied. It does - not model platform-specific filesystem normalisation (e.g. Win32 - stripping of trailing dots and spaces from the final path - component). For filesystem access, use :func:`safe_join`, which - resolves through the OS and verifies containment. - - The check is component-based: ``..`` is dangerous only as a - standalone path segment, not as a substring. Both ``/`` and ``\`` - are treated as separators. - - Example:: - - >>> contains_path_traversal("a/b/c") - False - >>> contains_path_traversal("../etc") - True - >>> contains_path_traversal("a/../../b") - True - >>> contains_path_traversal("a/../b") - False - >>> contains_path_traversal("1.0..2.0") - False - >>> contains_path_traversal("..") - True - - Args: - value: A string that may be used as a filesystem path. - - Returns: - ``True`` if the path would escape its starting directory. + Base-free, string-level check: it only detects `..` segments climbing above + the starting point. `..` counts as a whole segment, not a substring + (`a/../b` and `1.0..2.0` are safe); both `/` and `\` are separators. It does + not model platform normalisation (e.g. Win32 stripping trailing dots and + spaces). When the root is known, use `safe_join`, which resolves through the + OS and additionally catches symlink escapes and absolute-path injection. """ depth = 0 for part in value.replace("\\", "/").split("/"): @@ -77,42 +36,20 @@ def contains_path_traversal(value: str) -> bool: def is_absolute_path(value: str) -> bool: - r"""Check whether a value is an absolute filesystem path. + r"""Check whether a value is an absolute path on any common platform. Absolute paths are dangerous when joined onto a base: in Python, - ``Path("/data") / "/etc/passwd"`` yields ``/etc/passwd`` β€” the - absolute right-hand side silently discards the base. - - Detects POSIX absolute (``/foo``), Windows drive-absolute - (``C:\foo``) and drive-relative (``C:foo``), and Windows - UNC/root-relative (``\\server\share``, ``\foo``). - - Example:: - - >>> is_absolute_path("relative/path") - False - >>> is_absolute_path("/etc/passwd") - True - >>> is_absolute_path("C:\\Windows") - True - >>> is_absolute_path("") - False - - Args: - value: A string that may be used as a filesystem path. - - Returns: - ``True`` if the path is absolute on any common platform. + `Path("/data") / "/etc/passwd"` yields `/etc/passwd`. Detects POSIX absolute + (`/foo`), Windows drive-absolute (`C:\foo`) and drive-relative (`C:foo`), + and Windows UNC/root-relative (`\\server\share`, `\foo`). """ if not value: return False if value[0] in ("/", "\\"): return True - # Windows drive form: C:, C:\, C:foo (drive-relative). A drive- - # relative right-hand side discards the join base when drives - # differ, so flag it even though PureWindowsPath.is_absolute() - # is False. This means single-letter-prefixed identifiers like - # "x:y" also match β€” opt out via ResourceSecurity(exempt_params=). + # Drive-relative C:foo discards the join base when drives differ, so flag it even + # though PureWindowsPath.is_absolute() is False. Single-letter-prefixed identifiers + # like "x:y" also match β€” opt out via ResourceSecurity(exempt_params=). if len(value) >= 2 and value[1] == ":" and value[0] in string.ascii_letters: return True return False @@ -121,50 +58,28 @@ def is_absolute_path(value: str) -> bool: def safe_join(base: str | Path, *parts: str) -> Path: """Join path components onto a base, rejecting escapes. - Resolves the joined path and verifies it remains within ``base``. - This is the **gold-standard** check: it catches ``..`` traversal, - absolute-path injection, and symlink escapes that the base-free - checks cannot. - - The symlink check is point-in-time: a directory swapped for a - symlink between this call and the caller's subsequent open would not - be re-checked. Handlers serving a tree that may be modified - concurrently should additionally open with ``O_NOFOLLOW`` or use - platform path-confinement primitives. - - Example:: - - >>> safe_join("/data/docs", "readme.txt") - PosixPath('/data/docs/readme.txt') - >>> safe_join("/data/docs", "../../../etc/passwd") - Traceback (most recent call last): - ... - PathEscapeError: ... - - Args: - base: The sandbox root. May be relative; it will be resolved. - parts: Path components to join. Each is checked for null bytes - and absolute form before joining. + Resolves the joined path and verifies it stays within `base` (which may be + relative; it is resolved too), catching `..` traversal, absolute-path + injection, and symlink escapes that the base-free checks cannot. The symlink + check is point-in-time: a directory swapped for a symlink between this call + and the caller's open is not re-checked β€” handlers serving a concurrently + modified tree should also open with `O_NOFOLLOW` or use platform + path-confinement primitives. Returns: - The resolved path, verified to be within ``base`` at resolution - time. + The resolved path, verified to be within `base` at resolution time. Raises: - PathEscapeError: If any part contains a null byte, any part is - absolute, or the resolved path is not contained within the - resolved base. + PathEscapeError: If any part contains a null byte or is absolute, or the + resolved path is not contained within the resolved base. """ base_resolved = Path(base).resolve() for part in parts: - # Null bytes pass through Path construction but fail at the - # syscall boundary with a cryptic error. Reject here so callers - # get a clear PathEscapeError instead. + # Null bytes pass Path construction but fail at the syscall boundary with + # a cryptic error; reject here with a clear PathEscapeError instead. if "\0" in part: raise PathEscapeError(f"Path component contains a null byte; refusing to join onto {base_resolved}") - # Absolute parts would silently discard everything to the left - # in Path's / operator. if is_absolute_path(part): raise PathEscapeError(f"Path component {part!r} is absolute; refusing to join onto {base_resolved}") diff --git a/src/mcp/shared/peer.py b/src/mcp/shared/peer.py index ca59b56af6..813485ae20 100644 --- a/src/mcp/shared/peer.py +++ b/src/mcp/shared/peer.py @@ -1,12 +1,8 @@ -"""Typed MCP request sugar over an `Outbound`. +"""Typed server-to-client MCP request sugar over an `Outbound`. -`ClientPeer` wraps any `Outbound` (anything with `send_raw_request` and -`notify`) and exposes the server-to-client request methods (sampling, -elicitation, roots, ping) as typed methods. - -`ClientPeer` does no capability gating: it builds the params, calls -`send_raw_request(method, params)`, and parses the result into the typed -model. Gating (and `NoBackChannelError`) is the wrapped `Outbound`'s job. +`ClientPeer` gives a bare dispatcher (or any `Outbound`) typed `sample`, `elicit_form`, +`elicit_url`, `list_roots`, and `ping` methods. It does no capability gating β€” gating +(and `NoBackChannelError`) is the wrapped `Outbound`'s job. """ from collections.abc import Mapping @@ -42,17 +38,12 @@ def dump_params(model: BaseModel | None, meta: Meta | None = None) -> dict[str, Any] | None: - """Serialize a params model to a wire dict, merging `meta` into `_meta`. - - Shared by `ClientPeer` and `Connection` so every typed convenience method - gets the same `_meta` handling. `meta` keys take precedence over any - `_meta` already present on the model. + """Serialize a params model to a wire dict, merging `meta` into `_meta` (`meta` keys win). - `meta` is serialized through `RequestParams` so Python field names emit - their wire aliases: an inbound `ctx.meta` carries `progress_token` (the - key `_extract_meta` validation produces), and forwarding it outbound via - `meta=ctx.meta` must put `progressToken` back on the wire. Keys not - declared on `RequestParamsMeta` pass through unchanged. + `meta` round-trips through `RequestParams` so Python field names emit their wire + aliases: an inbound `ctx.meta` carries `progress_token`, and forwarding it outbound + must put `progressToken` back on the wire. Keys not declared on `RequestParamsMeta` + pass through unchanged. """ out = model.model_dump(by_alias=True, mode="json", exclude_none=True) if model is not None else None if meta: @@ -63,12 +54,7 @@ def dump_params(model: BaseModel | None, meta: Meta | None = None) -> dict[str, class ClientPeer: - """Typed server-to-client request methods over a wrapped `Outbound`. - - Use this when you have a bare dispatcher (or any `Outbound`) and want the - typed methods (`sample`, `elicit_form`, `elicit_url`, `list_roots`, - `ping`) without writing your own host class. - """ + """Typed server-to-client request methods over a wrapped `Outbound`.""" def __init__(self, outbound: Outbound) -> None: self._outbound = outbound @@ -142,7 +128,7 @@ async def sample( Raises: MCPError: The peer responded with an error. NoBackChannelError: No back-channel for server-initiated requests. - pydantic.ValidationError: The peer's result does not match the expected result type. + pydantic.ValidationError: The peer's result does not match the expected type. """ params = CreateMessageRequestParams( messages=messages, @@ -174,7 +160,7 @@ async def elicit_form( Raises: MCPError: The peer responded with an error. NoBackChannelError: No back-channel for server-initiated requests. - pydantic.ValidationError: The peer's result does not match the expected result type. + pydantic.ValidationError: The peer's result does not match the expected type. """ params = ElicitRequestFormParams(message=message, requested_schema=requested_schema) result = await self.send_raw_request("elicitation/create", dump_params(params, meta), opts) @@ -194,7 +180,7 @@ async def elicit_url( Raises: MCPError: The peer responded with an error. NoBackChannelError: No back-channel for server-initiated requests. - pydantic.ValidationError: The peer's result does not match the expected result type. + pydantic.ValidationError: The peer's result does not match the expected type. """ params = ElicitRequestURLParams(message=message, url=url, elicitation_id=elicitation_id) result = await self.send_raw_request("elicitation/create", dump_params(params, meta), opts) @@ -207,7 +193,7 @@ async def list_roots(self, *, meta: Meta | None = None, opts: CallOptions | None Raises: MCPError: The peer responded with an error. NoBackChannelError: No back-channel for server-initiated requests. - pydantic.ValidationError: The peer's result does not match the expected result type. + pydantic.ValidationError: The peer's result does not match the expected type. """ result = await self.send_raw_request("roots/list", dump_params(None, meta), opts) return ListRootsResult.model_validate(result, by_name=False) diff --git a/src/mcp/shared/tool_name_validation.py b/src/mcp/shared/tool_name_validation.py index f35efa5a61..c4319ba821 100644 --- a/src/mcp/shared/tool_name_validation.py +++ b/src/mcp/shared/tool_name_validation.py @@ -1,10 +1,4 @@ -"""Tool name validation utilities according to SEP-986. - -Tool names SHOULD be between 1 and 128 characters in length (inclusive). -Tool names are case-sensitive. -Allowed characters: uppercase and lowercase ASCII letters (A-Z, a-z), -digits (0-9), underscore (_), dash (-), and dot (.). -Tool names SHOULD NOT contain spaces, commas, or other special characters. +"""Tool name validation per SEP-986: 1-128 characters from A-Z, a-z, 0-9, `_`, `-`, `.`. See: https://modelcontextprotocol.io/specification/2025-11-25/server/tools#tool-names """ @@ -17,68 +11,49 @@ logger = logging.getLogger(__name__) -# Regular expression for valid tool names according to SEP-986 specification TOOL_NAME_REGEX = re.compile(r"^[A-Za-z0-9._-]{1,128}$") -# SEP reference URL for warning messages SEP_986_URL = "https://modelcontextprotocol.io/specification/2025-11-25/server/tools#tool-names" @dataclass class ToolNameValidationResult: - """Result of tool name validation. - - Attributes: - is_valid: Whether the tool name conforms to SEP-986 requirements. - warnings: List of warning messages for non-conforming aspects. - """ + """Result of tool name validation.""" is_valid: bool warnings: list[str] = field(default_factory=lambda: []) def validate_tool_name(name: str) -> ToolNameValidationResult: - """Validate a tool name according to the SEP-986 specification. - - Args: - name: The tool name to validate. - - Returns: - ToolNameValidationResult containing validation status and any warnings. - """ + """Validate a tool name against SEP-986; a valid name may still carry warnings.""" warnings: list[str] = [] - # Check for empty name if not name: return ToolNameValidationResult( is_valid=False, warnings=["Tool name cannot be empty"], ) - # Check length if len(name) > 128: return ToolNameValidationResult( is_valid=False, warnings=[f"Tool name exceeds maximum length of 128 characters (current: {len(name)})"], ) - # Check for problematic patterns (warnings, not validation failures) if " " in name: warnings.append("Tool name contains spaces, which may cause parsing issues") if "," in name: warnings.append("Tool name contains commas, which may cause parsing issues") - # Check for potentially confusing leading/trailing characters if name.startswith("-") or name.endswith("-"): warnings.append("Tool name starts or ends with a dash, which may cause parsing issues in some contexts") if name.startswith(".") or name.endswith("."): warnings.append("Tool name starts or ends with a dot, which may cause parsing issues in some contexts") - # Check for invalid characters if not TOOL_NAME_REGEX.match(name): - # Find all invalid characters (unique, preserving order) + # Collect invalid characters, unique and in order of first appearance invalid_chars: list[str] = [] seen: set[str] = set() for char in name: @@ -95,12 +70,7 @@ def validate_tool_name(name: str) -> ToolNameValidationResult: def issue_tool_name_warning(name: str, warnings: list[str]) -> None: - """Log warnings for non-conforming tool names. - - Args: - name: The tool name that triggered the warnings. - warnings: List of warning messages to log. - """ + """Log warnings for a non-conforming tool name.""" if not warnings: return @@ -113,17 +83,7 @@ def issue_tool_name_warning(name: str, warnings: list[str]) -> None: def validate_and_warn_tool_name(name: str) -> bool: - """Validate a tool name and issue warnings for non-conforming names. - - This is the primary entry point for tool name validation. It validates - the name and logs any warnings via the logging module. - - Args: - name: The tool name to validate. - - Returns: - True if the name is valid, False otherwise. - """ + """Validate a tool name, log any warnings, and return whether it is valid.""" result = validate_tool_name(name) issue_tool_name_warning(name, result.warnings) return result.is_valid diff --git a/src/mcp/shared/transport_context.py b/src/mcp/shared/transport_context.py index 55e5f6bc5f..bda6a63eb5 100644 --- a/src/mcp/shared/transport_context.py +++ b/src/mcp/shared/transport_context.py @@ -1,9 +1,7 @@ """Transport-specific metadata attached to each inbound message. -`TransportContext` is the base; each transport defines its own subclass with -whatever fields make sense (HTTP request id, ASGI scope, stdio process handle, -etc.). The dispatcher passes it through opaquely; only the layers above the -dispatcher (`ServerRunner`, `Context`, user handlers) read its concrete fields. +Each transport subclasses `TransportContext`; the dispatcher passes it through +opaquely β€” only `ServerRunner`, `Context`, and user handlers read concrete fields. """ from collections.abc import Mapping @@ -14,10 +12,7 @@ @dataclass(kw_only=True, frozen=True) class TransportContext: - """Base transport metadata for an inbound message. - - Subclass per transport and add fields as needed. Instances are immutable. - """ + """Base transport metadata for an inbound message.""" kind: str """Short identifier for the transport (e.g. `"stdio"`, `"streamable-http"`).""" @@ -25,14 +20,9 @@ class TransportContext: can_send_request: bool """Whether the transport can deliver server-initiated requests to the peer. - `False` for stateless HTTP and HTTP with JSON response mode; `True` for - stdio, SSE, and stateful streamable HTTP. When `False`, + `False` for stateless HTTP and HTTP with JSON response mode, where `DispatchContext.send_raw_request` raises `NoBackChannelError`. """ headers: Mapping[str, str] | None = None - """Request headers carried by this message, when the transport has them. - - Populated by HTTP-based transports; `None` on stdio. Handlers should - None-check before use. - """ + """Request headers carried by this message; populated by HTTP-based transports, `None` on stdio.""" diff --git a/src/mcp/shared/uri_template.py b/src/mcp/shared/uri_template.py index dc57bfa757..becc67fb8f 100644 --- a/src/mcp/shared/uri_template.py +++ b/src/mcp/shared/uri_template.py @@ -1,46 +1,26 @@ """RFC 6570 URI Templates with bidirectional support. -Provides both expansion (template + variables β†’ URI) and matching -(URI β†’ variables). RFC 6570 only specifies expansion; matching is the -inverse operation needed by MCP servers to route ``resources/read`` -requests to handlers. - -Supports Levels 1-3 fully, plus Level 4 explode modifier for path-like -operators (``{/var*}``, ``{.var*}``, ``{;var*}``). The Level 4 prefix -modifier (``{var:N}``) and query-explode (``{?var*}``) are not supported. - -Matching semantics ------------------- - -Matching is not specified by RFC 6570 (Β§1.4 explicitly defers to regex -languages). This implementation uses a two-ended scan that never -backtracks: match time is O(nΒ·v) where n is URI length and v is the -number of template variables. Realistic templates have v < 10, making -this effectively linear; there is no input that produces -superpolynomial time. - -A template may contain **at most one multi-segment variable** β€” -``{+var}``, ``{#var}``, or an explode-modified variable (``{/var*}``, -``{.var*}``, ``{;var*}``). This variable greedily consumes whatever the -surrounding bounded variables and literals do not. Two such variables -in one template are inherently ambiguous (which one gets the extra -segment?) and are rejected at parse time. So are any two variables -adjacent with no literal between them β€” including a variable adjacent -to the multi-segment variable: the scan has nothing to anchor the -boundary on. Operators that emit their own lead character supply that -literal themselves, so ``{+path}{.ext}`` and ``{a}{.b}`` are fine -while ``{+path}{ext}`` and ``{a}{b}`` are not. - -Bounded variables before the multi-segment variable match **lazily** -(first occurrence of the following literal); those after match -**greedily** (last occurrence of the preceding literal). Templates -without a multi-segment variable match greedily throughout, identical -to regex semantics. - -Reserved expansion ``{+var}`` leaves ``?`` and ``#`` unencoded, but -the scan stops at those characters so ``{+path}{?q}`` can separate path -from query. A value containing a literal ``?`` or ``#`` expands fine -but will not round-trip through ``match()``. +RFC 6570 only specifies expansion (template + variables β†’ URI); matching +(URI β†’ variables) is the inverse operation MCP servers need to route +`resources/read` requests. Levels 1-3 are supported fully, plus the Level 4 +explode modifier for path-like operators (`{/var*}`, `{.var*}`, `{;var*}`); +the prefix modifier (`{var:N}`) and query-explode (`{?var*}`) are not. + +Matching (which RFC 6570 Β§1.4 leaves to regex languages) uses a two-ended +scan that never backtracks: O(nΒ·v) in URI length and variable count, so no +input produces superpolynomial time. A template may contain at most one +multi-segment variable β€” `{+var}`, `{#var}`, or an exploded variable β€” which +greedily consumes whatever the surrounding bounded variables and literals do +not; bounded variables before it match lazily, those after it greedily +(templates without one are greedy throughout, like regex). Variables +adjacent with no literal between them are rejected at parse time; operators +that emit a lead character supply that literal, so `{+path}{.ext}` is fine +while `{+path}{ext}` is not. + +Reserved expansion `{+var}` leaves `?` and `#` unencoded, but the scan stops +at those characters so `{+path}{?q}` can separate path from query; a value +containing a literal `?` or `#` expands fine but will not round-trip through +`match()`. """ from __future__ import annotations @@ -65,9 +45,8 @@ _OPERATORS: frozenset[str] = frozenset({"+", "#", ".", "/", ";", "?", "&"}) -# RFC 6570 Β§2.3: varname = varchar *(["."] varchar), varchar = ALPHA / DIGIT / "_" -# Dots appear only between varchar groups β€” not consecutive, not trailing. -# (Percent-encoded varchars are technically allowed but unseen in practice.) +# RFC 6570 Β§2.3: varname = varchar *(["."] varchar), varchar = ALPHA / DIGIT / "_". +# Percent-encoded varchars are technically allowed but unseen in practice. _VARNAME_RE = re.compile(r"^[A-Za-z0-9_]+(?:\.[A-Za-z0-9_]+)*$") DEFAULT_MAX_TEMPLATE_LENGTH = 8_192 @@ -87,11 +66,11 @@ class _OperatorSpec: separator: str """Character between variables (and between exploded list items).""" named: bool - """Emit ``name=value`` pairs (query/path-param style) rather than bare values.""" + """Emit `name=value` pairs (query/path-param style) rather than bare values.""" allow_reserved: bool """Keep reserved characters unencoded ({+var}, {#var}).""" ifemp: str - """Suffix after a named variable whose expanded value is empty (RFC Β§A): '' for ;, '=' for ?/&.""" + """Suffix after a named variable with empty value (RFC Β§A): '' for ;, '=' for ?/&.""" _OPERATOR_SPECS: dict[Operator, _OperatorSpec] = { @@ -105,10 +84,8 @@ class _OperatorSpec: "&": _OperatorSpec(prefix="&", separator="&", named=True, allow_reserved=False, ifemp="="), } -# Per-operator stop characters for the linear scan. A bounded variable's -# value ends at the first occurrence of any character in its stop set, -# mirroring the character-class boundaries a regex would use but without -# the backtracking. +# Per-operator stop set: a bounded variable's value ends at the first stop +# character β€” the character-class boundary a regex would use, minus backtracking. _STOP_CHARS: dict[Operator, str] = { "": "/?#&,", # simple: everything structural is pct-encoded "+": "?#", # reserved: / allowed, stop at query/fragment @@ -126,8 +103,7 @@ class InvalidUriTemplate(ValueError): Attributes: template: The template string that failed to parse. - position: Character offset where the error was detected, or None - if the error is not tied to a specific position. + position: Character offset of the error, or None if not positional. """ def __init__(self, message: str, *, template: str, position: int | None = None) -> None: @@ -147,7 +123,7 @@ class Variable: @dataclass class _Expression: - """A parsed ``{...}`` expression: one operator, one or more variables.""" + """A parsed `{...}` expression: one operator, one or more variables.""" operator: Operator variables: list[Variable] @@ -167,9 +143,8 @@ class _Lit: class _Cap: """A single-variable capture in the flattened match-atom sequence. - ``ifemp`` marks the ``;`` operator's optional-equals quirk: ``{;id}`` - expands to ``;id=value`` or bare ``;id`` when the value is empty, so - the scan must accept both forms. + `ifemp` marks the `;` operator's quirk: `{;id}` expands to `;id=value` + or bare `;id` when the value is empty, so the scan must accept both. """ var: Variable @@ -180,12 +155,7 @@ class _Cap: def _is_greedy(var: Variable) -> bool: - """Return True if this variable can span multiple path segments. - - Reserved/fragment expansion and explode variables are the only - constructs whose match range is not bounded by a single structural - delimiter. A template may contain at most one such variable. - """ + """True if the variable's match range is unbounded by a single delimiter (at most one per template).""" return var.explode or var.operator in ("+", "#") @@ -203,18 +173,14 @@ def _is_str_sequence(value: object) -> bool: def _encode(value: str, *, allow_reserved: bool) -> str: """Percent-encode a value per RFC 6570 Β§3.2.1. - Simple expansion encodes everything except unreserved characters. - Reserved expansion (``{+var}``, ``{#var}``) additionally keeps - RFC 3986 reserved characters intact and passes through existing - ``%XX`` pct-triplets unchanged (RFC 6570 Β§3.2.3). A bare ``%`` not - followed by two hex digits is still encoded to ``%25``. + Reserved expansion ({+var}, {#var}) keeps RFC 3986 reserved characters + and existing `%XX` triplets intact (Β§3.2.3); a bare `%` not followed by + two hex digits is still encoded to `%25`. """ if not allow_reserved: return quote(value, safe="") - # Reserved expansion: walk the string, pass through triplets as-is, - # quote the gaps between them. A bare % with no triplet lands in a - # gap and gets encoded normally. + # Pass triplets through as-is, quote the gaps; a bare % lands in a gap. out: list[str] = [] last = 0 for m in _PCT_TRIPLET_RE.finditer(value): @@ -226,25 +192,22 @@ def _encode(value: str, *, allow_reserved: bool) -> str: def _expand_expression(expr: _Expression, variables: Mapping[str, str | Sequence[str]]) -> str: - """Expand a single ``{...}`` expression into its URI fragment. + """Expand a single `{...}` expression into its URI fragment. - Walks the expression's variables, encoding and joining defined ones - according to the operator's spec. Undefined variables are skipped - (RFC 6570 Β§2.3); if all are undefined, the expression contributes - nothing (no prefix is emitted). + Undefined variables are skipped (RFC 6570 Β§2.3); if all are undefined, + the expression contributes nothing (no prefix is emitted). """ spec = _OPERATOR_SPECS[expr.operator] rendered: list[str] = [] for var in expr.variables: if var.name not in variables: - # Undefined: skip entirely, no placeholder. continue value = variables[var.name] - # Explicit type guard: reject non-str scalars with a clear message - # rather than a confusing "not iterable" from the sequence branch. + # Reject non-str scalars here for a clear message rather than a + # confusing "not iterable" from the sequence branch. if not isinstance(value, str) and not _is_str_sequence(value): raise TypeError(f"Variable {var.name!r} must be str or a sequence of str, got {type(value).__name__}") @@ -255,12 +218,10 @@ def _expand_expression(expr: _Expression, variables: Mapping[str, str | Sequence else: rendered.append(encoded) else: - # Sequence value. items = [_encode(v, allow_reserved=spec.allow_reserved) for v in value] if not items: continue if var.explode: - # Each item gets the operator's separator; named ops repeat the key. if spec.named: rendered.append( spec.separator.join(f"{var.name}{spec.ifemp}" if v == "" else f"{var.name}={v}" for v in items) @@ -268,9 +229,7 @@ def _expand_expression(expr: _Expression, variables: Mapping[str, str | Sequence else: rendered.append(spec.separator.join(items)) else: - # Non-explode: comma-join into a single value, then apply - # ifemp to the joined result (RFC Β§3.2.1: behaves as if the - # value were the joined string). + # Non-explode: comma-join, then ifemp applies to the joined value (RFC 6570 Β§3.2.1). joined = ",".join(items) if spec.named: rendered.append(f"{var.name}{spec.ifemp}" if joined == "" else f"{var.name}={joined}") @@ -286,8 +245,8 @@ def _expand_expression(expr: _Expression, variables: Mapping[str, str | Sequence class UriTemplate: """A parsed RFC 6570 URI template. - Construct via :meth:`parse`. Instances are immutable and hashable; - equality is based on the template string alone. + Construct via :meth:`parse`. Immutable and hashable; equality is based + on the template string alone. """ template: str @@ -300,22 +259,10 @@ class UriTemplate: @staticmethod def is_template(value: str) -> bool: - """Check whether a string contains URI template expressions. + """Check whether a string contains at least one `{...}` pair. - A cheap heuristic for distinguishing concrete URIs from templates - without the cost of full parsing. Returns ``True`` if the string - contains at least one ``{...}`` pair. - - Example:: - - >>> UriTemplate.is_template("file://docs/{name}") - True - >>> UriTemplate.is_template("file://docs/readme.txt") - False - - Note: - This does not validate the template. A ``True`` result does - not guarantee :meth:`parse` will succeed. + A cheap heuristic for distinguishing concrete URIs from templates; + `True` does not guarantee :meth:`parse` will succeed. """ open_i = value.find("{") return open_i != -1 and value.find("}", open_i) != -1 @@ -331,14 +278,10 @@ def parse( """Parse a URI template string. Args: - template: An RFC 6570 URI template. - max_length: Maximum permitted length of the template string. - Guards against resource exhaustion. - max_variables: Maximum number of variables permitted across - all expressions. Counting variables rather than - ``{...}`` expressions closes the gap where a single - ``{v0,v1,...,vN}`` expression packs arbitrarily many - variables under one expression count. + max_length: Maximum template length; guards resource exhaustion. + max_variables: Maximum variables across all expressions β€” + counted per variable, not per expression, so `{v0,...,vN}` + cannot pack arbitrarily many under one expression. Raises: InvalidUriTemplate: If the template is malformed, exceeds the @@ -352,9 +295,7 @@ def parse( parts, variables = _parse(template, max_variables=max_variables) - # Trailing {?...}/{&...} expressions are split off and matched as - # a query string (order-agnostic, partial, extras ignored) rather - # than via the linear scan. + # Trailing {?...}/{&...} runs are matched as a lenient query string, not via the linear scan. path_parts, query_vars = _split_query_tail(parts) atoms = _flatten(path_parts) prefix, greedy, suffix = _partition_greedy(atoms, template) @@ -383,28 +324,22 @@ def variable_names(self) -> list[str]: def query_variable_names(self) -> frozenset[str]: """Names of variables that :meth:`match` treats as optional query parameters. - These are the variables in a trailing run of ``{?...}``/``{&...}`` - expressions, which are matched leniently: a URI that omits some - (or all) of them still matches, and the omitted names are simply - absent from the result. Any value bound to such a name therefore - needs a fallback for the omitted case. - - Every other variable is bound on every successful :meth:`match` - (possibly to an empty string) and is *not* in this set. That - includes a ``{&...}`` expression with no preceding ``{?...}``: it - never emits the ``?`` the lenient query split keys on, so it is - matched strictly. + Variables in a trailing run of `{?...}`/`{&...}` expressions are + matched leniently: a URI may omit any of them, and omitted names are + absent from the result. Every other variable is bound on every + successful match (possibly to an empty string) β€” including a + `{&...}` with no preceding `{?...}`, which never emits the `?` the + lenient split keys on and is therefore matched strictly. """ return frozenset(v.name for v in self._query_variables) def expand(self, variables: Mapping[str, str | Sequence[str]]) -> str: """Expand the template by substituting variable values. - String values are percent-encoded according to their operator: - simple ``{var}`` encodes reserved characters; ``{+var}`` and - ``{#var}`` leave them intact. Sequence values are joined with - commas for non-explode variables, or with the operator's - separator for explode variables. + String values are percent-encoded per their operator: simple `{var}` + encodes reserved characters; `{+var}` and `{#var}` leave them + intact. Sequence values are comma-joined, or joined with the + operator's separator for explode variables. Example:: @@ -412,41 +347,20 @@ def expand(self, variables: Mapping[str, str | Sequence[str]]) -> str: >>> t.expand({"name": "hello world.txt"}) 'file://docs/hello%20world.txt' - >>> t = UriTemplate.parse("file://docs/{+path}") - >>> t.expand({"path": "src/main.py"}) - 'file://docs/src/main.py' - >>> t = UriTemplate.parse("/search{?q,lang}") >>> t.expand({"q": "mcp", "lang": "en"}) '/search?q=mcp&lang=en' - >>> t = UriTemplate.parse("/files{/path*}") - >>> t.expand({"path": ["a", "b", "c"]}) - '/files/a/b/c' - - Args: - variables: Values for each template variable. Keys must be - strings; values must be ``str`` or a sequence of ``str``. - - Returns: - The expanded URI string. - Note: - Per RFC 6570, variables absent from the mapping are - **silently omitted**. This is the correct behavior for - optional query parameters (``{?page}`` with no page yields - no ``?page=``), but for required path segments it produces - a structurally incomplete URI. If you need all variables - present, validate before calling:: - - missing = set(t.variable_names) - variables.keys() - if missing: - raise ValueError(f"Missing: {missing}") + Per RFC 6570, variables absent from the mapping are silently + omitted β€” correct for optional query parameters, but a missing + path variable yields a structurally incomplete URI. Check + `set(t.variable_names) - variables.keys()` first if you need + all variables present. Raises: - TypeError: If a value is neither ``str`` nor an iterable of - ``str``. Non-string scalars (``int``, ``None``) are not - coerced. + TypeError: If a value is neither `str` nor an iterable of + `str`; non-string scalars (`int`, `None`) are not coerced. """ out: list[str] = [] for part in self._parts: @@ -459,71 +373,47 @@ def expand(self, variables: Mapping[str, str | Sequence[str]]) -> str: def match(self, uri: str, *, max_uri_length: int = DEFAULT_MAX_URI_LENGTH) -> dict[str, str | list[str]] | None: """Match a concrete URI against this template and extract variables. - This is the inverse of :meth:`expand`. The URI is matched via a - linear scan of the template and captured values are - percent-decoded. The round-trip ``match(expand({k: v})) == {k: v}`` - holds when ``v`` does not contain its operator's separator - unencoded: ``{.ext}`` with ``ext="tar.gz"`` expands to - ``.tar.gz`` but does not match β€” the scan stops ``ext`` at the - first ``.`` and the trailing ``.gz`` has nothing to consume it. - RFC 6570 Β§1.4 notes this is an inherent reversal limitation. - - Matching is structural at the URI level only: a simple ``{name}`` - will not match across a literal ``/`` in the URI (the scan stops - there), but a percent-encoded ``%2F`` that decodes to ``/`` is - accepted as part of the value. Path-safety validation belongs at - a higher layer; see :mod:`mcp.shared.path_security`. + The inverse of :meth:`expand`; captured values are percent-decoded. + The round-trip `match(expand({k: v})) == {k: v}` holds when `v` does + not contain its operator's separator unencoded: `{.ext}` with + `ext="tar.gz"` expands but does not match back β€” an inherent + reversal limitation noted by RFC 6570 Β§1.4. + + Matching is structural at the URI level only: a simple `{name}` + will not match across a literal `/`, but a percent-encoded `%2F` + that decodes to `/` is accepted as part of the value. Path-safety + validation belongs at a higher layer; see + :mod:`mcp.shared.path_security`. + + Trailing query expressions (`{?q,lang}`) match leniently: + order-agnostic, partial, unrecognized params ignored, and absent + params omitted from the result so downstream defaults can apply. Example:: >>> t = UriTemplate.parse("file://docs/{name}") - >>> t.match("file://docs/readme.txt") - {'name': 'readme.txt'} >>> t.match("file://docs/hello%20world.txt") {'name': 'hello world.txt'} - >>> t = UriTemplate.parse("file://docs/{+path}") - >>> t.match("file://docs/src/main.py") - {'path': 'src/main.py'} - - >>> t = UriTemplate.parse("/files{/path*}") - >>> t.match("/files/a/b/c") - {'path': ['a', 'b', 'c']} - - **Query parameters** (``{?q,lang}`` at the end of a template) - are matched leniently: order-agnostic, partial, and unrecognized - params are ignored. Absent params are omitted from the result so - downstream function defaults can apply:: - >>> t = UriTemplate.parse("logs://{service}{?since,level}") - >>> t.match("logs://api") - {'service': 'api'} >>> t.match("logs://api?level=error") {'service': 'api', 'level': 'error'} - >>> t.match("logs://api?level=error&since=5m&utm=x") - {'service': 'api', 'since': '5m', 'level': 'error'} Args: - uri: A concrete URI string. - max_uri_length: Maximum permitted length of the input URI. - Oversized inputs return ``None`` without scanning, + max_uri_length: Oversized inputs return None without scanning, guarding against resource exhaustion. Returns: - A mapping from variable names to decoded values (``str`` for - scalar variables, ``list[str]`` for explode variables), or - ``None`` if the URI does not match the template or exceeds - ``max_uri_length``. + Variable names mapped to decoded values (`str`, or `list[str]` + for explode variables), or None if the URI does not match or + exceeds `max_uri_length`. """ if len(uri) > max_uri_length: return None if self._query_variables: - # Two-phase: scan matches the path, the query is split and - # decoded manually. Query params may be partial, reordered, - # or include extras; absent params stay absent so downstream - # defaults can apply. Fragment is stripped first since the - # template's {?...} tail never describes a fragment. + # Scan the path, then decode the query separately. Fragment is + # stripped first: the template's {?...} tail never describes one. before_fragment, _, _ = uri.partition("#") path, _, query = before_fragment.partition("?") result = self._scan(path) @@ -551,11 +441,9 @@ def _scan(self, uri: str) -> dict[str, str | list[str]] | None: suffix_result, suffix_start = suffix return suffix_result if suffix_start == 0 else None - # Greedy var present. The parser rejects a capture adjacent to - # the greedy slot, so a non-empty suffix begins with a _Lit whose - # rfind-derived anchor does not depend on how far the prefix - # scans. Scan the suffix first, then give the prefix that exact - # position as its ceiling so it cannot consume past the anchor. + # The parser rejects a capture adjacent to the greedy slot, so a + # non-empty suffix begins with a _Lit whose rfind anchor is independent + # of the prefix scan: scan the suffix first, then cap the prefix at it. suffix = _scan_suffix(self._suffix, uri, n, anchored=False) if suffix is None: return None @@ -565,11 +453,9 @@ def _scan(self, uri: str) -> dict[str, str | list[str]] | None: return None prefix_result, prefix_end = prefix - # Prefix consumed [0, prefix_end); suffix consumed [suffix_start, n); - # the greedy var takes the gap. The prefix scan is bounded by - # suffix_start, so this holds by construction; guard explicitly - # rather than asserting so a future regression surfaces as a - # non-match, not an exception. + # The greedy var takes [prefix_end, suffix_start). The prefix scan is + # bounded by suffix_start, so this holds by construction; guard rather + # than assert so a future regression surfaces as a non-match. if suffix_start < prefix_end: return None # pragma: no cover - unreachable while bounds hold middle = uri[prefix_end:suffix_start] @@ -586,19 +472,12 @@ def __str__(self) -> str: def _parse_query(query: str) -> dict[str, str]: """Parse a query string into a nameβ†’value mapping. - Unlike ``urllib.parse.parse_qs``, this follows RFC 3986 semantics: - ``+`` is a literal sub-delim, not a space. Form-urlencoding treats - ``+`` as space for HTML form submissions, but RFC 6570 and MCP - resource URIs follow RFC 3986 where only ``%20`` encodes a space. - - Parameter names are **not** percent-decoded. RFC 6570 expansion - never encodes variable names, so a legitimate match will always - have the name in literal form. Decoding names would let - ``%74oken=evil&token=real`` shadow the real ``token`` parameter - via first-wins. - - Duplicate keys keep the first value. Pairs without ``=`` are - treated as empty-valued. + Unlike `urllib.parse.parse_qs`, this follows RFC 3986 rather than + form-urlencoding: `+` is a literal sub-delim, only `%20` is a space. + Names are not percent-decoded β€” RFC 6570 expansion never encodes them, + and decoding would let `%74oken=evil&token=real` shadow the real `token` + via first-wins. Duplicate keys keep the first value; pairs without `=` + are empty-valued. """ result: dict[str, str] = {} for pair in query.split("&"): @@ -611,10 +490,9 @@ def _parse_query(query: str) -> dict[str, str]: def _extract_greedy(var: Variable, raw: str) -> str | list[str] | None: """Decode the greedy variable's isolated middle span. - For scalar greedy (``{+var}``, ``{#var}``) this is a stop-char - validation and a single ``unquote``. For explode variables the span - is a run of separator-delimited segments (``/a/b/c`` or - ``;keys=a;keys=b``) that is split, validated, and decoded per item. + Scalar greedy ({+var}, {#var}): stop-char validation plus one unquote. + Explode: split the separator-delimited run (`/a/b/c`, `;keys=a;keys=b`), + validate, and decode per item. """ spec = _OPERATOR_SPECS[var.operator] stops = _STOP_CHARS[var.operator] @@ -627,27 +505,22 @@ def _extract_greedy(var: Variable, raw: str) -> str | list[str] | None: sep = spec.separator if not raw: return [] - # A non-empty explode span must begin with the separator: {/a*} - # expands to "/x/y", never "x/y". The scan does not consume the - # separator itself, so it must be the first character here. + # A non-empty explode span must begin with the separator: {/a*} expands + # to "/x/y", never "x/y", and the scan leaves the separator in the span. if raw[0] != sep: return None - # Segments must not contain the operator's non-separator stop - # characters (e.g. {/path*} segments may contain neither ? nor #). + # Segments must not contain the operator's other stop chars (e.g. ?/# under {/path*}). body_stops = set(stops) - {sep} if any(c in body_stops for c in raw): return None segments: list[str] = [] prefix = f"{var.name}=" - # split()[0] is always "" because raw starts with the separator; - # subsequent empties are legitimate values ({/path*} with - # ["a","","c"] expands to /a//c). + # split()[0] is always "" (raw starts with the separator); later + # empties are legitimate values (/a//c). for seg in raw.split(sep)[1:]: if spec.named: - # Named explode emits name=value per item (or bare name - # under ; with empty value). Validate the name and strip - # the prefix before decoding. + # Named explode emits name=value per item (bare name under ; when empty). if seg.startswith(prefix): seg = seg[len(prefix) :] elif seg == var.name: @@ -659,19 +532,11 @@ def _extract_greedy(var: Variable, raw: str) -> str | list[str] | None: def _split_query_tail(parts: list[_Part]) -> tuple[list[_Part], list[Variable]]: - """Separate trailing ``?``/``&`` expressions from the path portion. - - Lenient query matching (order-agnostic, partial, ignores extras) - applies when a template ends with one or more consecutive ``?``/``&`` - expressions and the preceding path portion contains no literal - ``?``. If the path has a literal ``?`` (e.g., ``?fixed=1{&page}``), - the URI's ``?`` split won't align with the template's expression - boundary, so the strict scan is used instead. - - Returns: - A pair ``(path_parts, query_vars)``. If lenient matching does - not apply, ``query_vars`` is empty and ``path_parts`` is the - full input. + """Separate trailing `?`/`&` expressions from the path portion. + + Lenient query matching applies when the template ends with consecutive + `?`/`&` expressions; when it does not apply, `query_vars` is empty and + `path_parts` is the full input. """ split = len(parts) for i in range(len(parts) - 1, -1, -1): @@ -684,18 +549,16 @@ def _split_query_tail(parts: list[_Part]) -> tuple[list[_Part], list[Variable]]: if split == len(parts): return parts, [] - # The tail must start with a {?...} expression so that expand() - # emits a ? the URI can split on. A standalone {&page} expands - # with an & prefix, which partition("?") won't find. + # The tail must start with {?...} so expand() emits the ? the URI is + # split on; a leading {&page} expands with & and partition("?") misses it. first = parts[split] assert isinstance(first, _Expression) if first.operator != "?": return parts, [] - # If the path portion contains a literal ?/# or a {?...}/{#...} - # expression, lenient matching's partition("#") then partition("?") - # would strip content the path scan expects to see. Fall back to - # the strict scan. + # A literal ?/# or a {?...}/{#...} expression in the path would be + # stripped by the lenient partitions before the path scan sees it; + # fall back to the strict scan. for part in parts[:split]: if isinstance(part, str): if "?" in part or "#" in part: @@ -714,14 +577,9 @@ def _split_query_tail(parts: list[_Part]) -> tuple[list[_Part], list[Variable]]: def _parse(template: str, *, max_variables: int) -> tuple[list[_Part], list[Variable]]: """Split a template into an ordered sequence of literals and expressions. - Walks the string, alternating between collecting literal runs and - parsing ``{...}`` expressions. The resulting ``parts`` sequence - preserves positional interleaving so ``match()`` and ``expand()`` can - walk it in order. - Raises: - InvalidUriTemplate: On unclosed braces, too many expressions, or - any error surfaced by :func:`_parse_expression`. + InvalidUriTemplate: On unclosed braces, too many variables, or any + error surfaced by :func:`_parse_expression`. """ parts: list[_Part] = [] variables: list[Variable] = [] @@ -729,16 +587,13 @@ def _parse(template: str, *, max_variables: int) -> tuple[list[_Part], list[Vari n = len(template) while i < n: - # Find the next expression opener from the current cursor. brace = template.find("{", i) if brace == -1: - # No more expressions; everything left is a trailing literal. parts.append(template[i:]) break if brace > i: - # Literal text between cursor and the brace. parts.append(template[i:brace]) end = template.find("}", brace) @@ -749,7 +604,6 @@ def _parse(template: str, *, max_variables: int) -> tuple[list[_Part], list[Vari position=brace, ) - # Delegate body (between braces, exclusive) to the expression parser. expr = _parse_expression(template, template[brace + 1 : end], brace) parts.append(expr) variables.extend(expr.variables) @@ -760,7 +614,6 @@ def _parse(template: str, *, max_variables: int) -> tuple[list[_Part], list[Vari template=template, ) - # Advance past the closing brace. i = end + 1 _check_duplicate_variables(template, variables) @@ -769,17 +622,11 @@ def _parse(template: str, *, max_variables: int) -> tuple[list[_Part], list[Vari def _parse_expression(template: str, body: str, pos: int) -> _Expression: - """Parse the body of a single ``{...}`` expression. + """Parse the body (between braces) of a single `{...}` expression. - The body is everything between the braces. It consists of an optional - leading operator character followed by one or more comma-separated - variable specifiers. Each specifier is a name with an optional - trailing ``*`` (explode modifier). - - Args: - template: The full template string, for error reporting. - body: The expression body, braces excluded. - pos: Character offset of the opening brace, for error reporting. + The body is an optional leading operator followed by comma-separated + `name[*]` variable specifiers. `template` and `pos` (offset of the + opening brace) are for error reporting only. Raises: InvalidUriTemplate: On empty body, invalid variable names, or @@ -800,7 +647,6 @@ def _parse_expression(template: str, body: str, pos: int) -> _Expression: position=pos, ) - # Remaining body is comma-separated variable specs: name[*] variables: list[Variable] = [] for spec in body.split(","): if ":" in spec: @@ -820,9 +666,8 @@ def _parse_expression(template: str, body: str, pos: int) -> _Expression: position=pos, ) - # Explode only makes sense for operators that repeat a separator. - # Simple/reserved/fragment have no per-item separator; query-explode - # needs order-agnostic dict matching which we don't support yet. + # Simple/reserved/fragment have no per-item separator to explode on; + # query-explode needs order-agnostic dict matching, unsupported so far. if explode and operator in ("", "+", "#", "?", "&"): raise InvalidUriTemplate( f"Explode modifier on {{{operator}{name}*}} is not supported for matching", @@ -838,10 +683,9 @@ def _parse_expression(template: str, body: str, pos: int) -> _Expression: def _check_duplicate_variables(template: str, variables: list[Variable]) -> None: """Reject templates that use the same variable name more than once. - RFC 6570 requires repeated variables to expand to the same value, - which would require backreference matching with potentially - exponential cost. Rather than silently returning only the last - captured value, we reject at parse time. + RFC 6570 requires repeated variables to expand to the same value, which + matching would need potentially-exponential backreference support for; + reject at parse time rather than silently keeping the last capture. Raises: InvalidUriTemplate: If any variable name appears more than once. @@ -857,11 +701,10 @@ def _check_duplicate_variables(template: str, variables: list[Variable]) -> None def _check_single_query_expression(template: str, parts: list[_Part]) -> None: - """Reject templates with more than one ``{?...}`` expression. + """Reject templates with more than one `{?...}` expression. - The ``?`` operator emits a leading ``?``, so two such expressions - expand to a URI with two ``?`` characters β€” malformed per RFC 3986 - Β§3.4. Use ``{?a,b}`` or ``{?a}{&b}`` for multiple query parameters. + Two would expand to a URI with two `?` characters β€” malformed per + RFC 3986 Β§3.4. Use `{?a,b}` or `{?a}{&b}` instead. """ seen = False for part in parts: @@ -878,14 +721,10 @@ def _check_single_query_expression(template: str, parts: list[_Part]) -> None: def _flatten(parts: list[_Part]) -> list[_Atom]: """Lower expressions into a flat sequence of literals and single-variable captures. - Operator prefixes and separators become explicit ``_Lit`` atoms so - the scan only ever sees two atom kinds. Adjacent literals are - coalesced so that anchor-finding (``find``/``rfind``) operates on - the longest possible literal, reducing false matches. - - Explode variables emit no lead literal: the explode capture - includes its own separator-prefixed repetitions (``{/a*}`` β†’ - ``/x/y/z``, not ``/`` then ``x/y/z``). + Operator prefixes and separators become explicit `_Lit` atoms; adjacent + literals are coalesced so find/rfind anchors on the longest run. Explode + variables emit no lead literal β€” the capture includes its own + separator-prefixed repetitions (`{/a*}` β†’ `/x/y/z`). """ atoms: list[_Atom] = [] @@ -907,8 +746,7 @@ def push_lit(text: str) -> None: if var.explode: atoms.append(_Cap(var)) elif spec.named: - # ; uses ifemp (bare name when empty); ? and & always - # emit name= so the equals is part of the literal. + # ; uses ifemp (bare name when empty); ?/& always emit name= as literal. if part.operator == ";": push_lit(f"{lead}{var.name}") atoms.append(_Cap(var, ifemp=True)) @@ -924,19 +762,14 @@ def push_lit(text: str) -> None: def _partition_greedy(atoms: list[_Atom], template: str) -> tuple[list[_Atom], Variable | None, list[_Atom]]: """Split atoms at the single greedy variable, if any. - Returns ``(prefix, greedy_var, suffix)``. If there is no greedy - variable the entire atom list is returned as the suffix so that - the right-to-left scan (which matches regex-greedy semantics) - handles it. + With no greedy variable the entire atom list is returned as the suffix + so the right-to-left scan (regex-greedy semantics) handles it. Raises: - InvalidUriTemplate: If two variables are adjacent with no - literal between them β€” whether or not one is the - multi-segment variable, the scan has nothing to anchor the - boundary on β€” or if more than one multi-segment variable - is present (two are inherently ambiguous: there is no - principled way to decide which one absorbs an extra - segment). + InvalidUriTemplate: If two variables are adjacent with no literal + between them (the scan has nothing to anchor the boundary on), + or if more than one multi-segment variable is present (which + one absorbs an extra segment is inherently ambiguous). """ greedy_idx: int | None = None prev: _Atom | None = None @@ -969,16 +802,13 @@ def _partition_greedy(atoms: list[_Atom], template: str) -> tuple[list[_Atom], V def _scan_suffix( atoms: Sequence[_Atom], uri: str, end: int, *, anchored: bool ) -> tuple[dict[str, str | list[str]], int] | None: - """Scan atoms right-to-left from ``end``, returning captures and start position. - - Each bounded variable takes the minimum span that lets its - preceding literal match (found via ``rfind``), which makes the - *first* variable in template order greedy β€” identical to Python - regex semantics for a sequence of greedy groups. + """Scan atoms right-to-left from `end`, returning captures and start position. - When ``anchored`` is true the atom sequence is the entire template - (no greedy variable), so ``atoms[0]`` must match at URI position 0 - rather than at its rightmost occurrence. + Each bounded variable takes the minimum span that lets its preceding + literal match (rfind), making the first variable in template order + greedy β€” identical to Python regex semantics for greedy groups. When + `anchored`, the atoms are the entire template (no greedy variable) and + `atoms[0]` must match at position 0, not at its rightmost occurrence. """ result: dict[str, str | list[str]] = {} pos = end @@ -998,9 +828,8 @@ def _scan_suffix( prev = atoms[i - 1] if i > 0 else None if atom.ifemp: - # ;name or ;name=value. The preceding _Lit is ";name". - # Try empty first: if the lit ends at pos the value is - # absent (RFC ifemp). Otherwise require =value. + # ;name or ;name=value, preceding _Lit is ";name": try the + # empty (bare-name) form first, else require =value. assert isinstance(prev, _Lit) if uri.endswith(prev.text, 0, pos): result[var.name] = "" @@ -1017,8 +846,7 @@ def _scan_suffix( i -= 1 continue - # Earliest valid start: the var cannot extend left past any - # stop-char, so scan backward to find that boundary. + # Earliest valid start: the var cannot extend left past a stop-char. earliest = pos while earliest > 0 and uri[earliest - 1] not in stops: earliest -= 1 @@ -1026,20 +854,17 @@ def _scan_suffix( if prev is None: start = earliest else: - # prev is a _Lit: the parser rejects two adjacent captures, - # so the only possible neighbour kind is a literal. + # The parser rejects adjacent captures, so prev can only be a _Lit. assert isinstance(prev, _Lit) if anchored and i - 1 == 0: - # First atom of the whole template: positionally fixed at - # 0, not rightmost occurrence. rfind would land inside the - # value when the literal repeats there (e.g. "prefix-{id}" - # against "prefix-prefix-123"). + # First atom of the template is positionally fixed at 0; + # rfind would land inside the value when the literal repeats + # ("prefix-{id}" against "prefix-prefix-123"). start = len(prev.text) if start < earliest or start > pos: return None else: - # Rightmost occurrence of the preceding literal whose end - # falls within the var's valid range. + # Rightmost occurrence of the preceding literal ending within the var's range. idx = uri.rfind(prev.text, 0, pos) if idx == -1 or idx + len(prev.text) < earliest: return None @@ -1054,11 +879,10 @@ def _scan_suffix( def _scan_prefix( atoms: Sequence[_Atom], uri: str, start: int, limit: int ) -> tuple[dict[str, str | list[str]], int] | None: - """Scan atoms left-to-right from ``start``, not exceeding ``limit``. + """Scan atoms left-to-right from `start`, not exceeding `limit`. - Each bounded variable takes the minimum span that lets its - following literal match (found via ``find``), leaving the - greedy variable as much of the URI as possible. + Each bounded variable takes the minimum span that lets its following + literal match (find), leaving the greedy variable as much as possible. """ result: dict[str, str | list[str]] = {} pos = start @@ -1072,41 +896,33 @@ def _scan_prefix( var = atom.var stops = _STOP_CHARS[var.operator] - # Every capture here is followed by a literal: the parser rejects - # two adjacent captures, and a capture at the END of the prefix - # would be adjacent to the greedy variable. + # A literal always follows: the parser rejects adjacent captures, and a + # capture ending the prefix would be adjacent to the greedy variable. nxt = atoms[i + 1] assert isinstance(nxt, _Lit) if atom.ifemp: - # RFC Β§3.2.7 ifemp: ;name=val for non-empty, bare ;name for - # empty. Decide which form is present without falling through - # to the stop-char scan when the value is empty. + # RFC 6570 Β§3.2.7 ifemp: bare ;name when empty, ;name=val otherwise. if uri.startswith(nxt.text, pos): - # Following literal begins immediately: value is empty. - # Checked before '=' so a literal that itself starts - # with '=' is not mistaken for the ifemp separator. + # Empty value. Checked before '=' so a literal that itself + # starts with '=' is not mistaken for the ifemp separator. result[var.name] = "" continue if pos < limit and uri[pos] == "=": pos += 1 # value follows; fall through to the scan else: - # The following literal does not start here and there is - # no '=': the URI's name continued past the template's - # (e.g. ;keys vs ;key) β€” no parse. + # No following literal and no '=': the URI's name continued + # past the template's (e.g. ;keys vs ;key) β€” no parse. return None - # Latest valid end: the var stops at the first stop-char or - # the scan limit, whichever comes first. + # Latest valid end: first stop-char or the scan limit. latest = pos while latest < limit and uri[latest] not in stops: latest += 1 - # First occurrence of the following literal: the capture takes - # the minimum span, leaving the greedy variable as much of the - # URI as possible. The search window's upper bound already - # forces any hit to start at or before ``latest``, so the var - # never extends past a stop-char. + # First occurrence of the following literal = minimum span. The search + # window's upper bound forces any hit to start at or before `latest`, + # so the var never extends past a stop-char. end = uri.find(nxt.text, pos, latest + len(nxt.text)) if end == -1: return None diff --git a/tests/cli/test_claude.py b/tests/cli/test_claude.py index d0a74e0d00..ca7c2f157d 100644 --- a/tests/cli/test_claude.py +++ b/tests/cli/test_claude.py @@ -1,5 +1,3 @@ -"""Tests for mcp.cli.claude β€” Claude Desktop config file generation.""" - import importlib.metadata import json from pathlib import Path @@ -21,40 +19,35 @@ def fake_version(distribution_name: str) -> str: @pytest.fixture def config_dir(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> Path: - """Temp Claude config dir with the config path, uv path, and SDK version mocked.""" claude_dir = tmp_path / "Claude" claude_dir.mkdir() monkeypatch.setattr("mcp.cli.claude.get_claude_config_path", lambda: claude_dir) monkeypatch.setattr("mcp.cli.claude.get_uv_path", lambda: "/fake/bin/uv") - # The ambient version is a dev build in the repo venv but varies by - # environment; pin it so the generated --with requirement is stable. + # Pin the SDK version (a dev build in the repo venv) so the generated --with requirement is stable. _set_mcp_version(monkeypatch, "1.2.3") return claude_dir def test_mcp_requirement_pins_release_versions(monkeypatch: pytest.MonkeyPatch): - """Release versions produce an exact pin so spawned environments run the installed SDK version.""" _set_mcp_version(monkeypatch, "2.0.0a1") assert mcp_requirement() == "mcp==2.0.0a1" assert mcp_requirement("mcp[cli]") == "mcp[cli]==2.0.0a1" def test_mcp_requirement_leaves_dev_versions_unpinned(monkeypatch: pytest.MonkeyPatch): - """Dev versions are not published to PyPI, so the requirement falls back to the unpinned package.""" + """Dev versions are not on PyPI, so no pin is emitted.""" _set_mcp_version(monkeypatch, "2.0.0a2.dev3") assert mcp_requirement() == "mcp" assert mcp_requirement("mcp[cli]") == "mcp[cli]" def test_mcp_requirement_leaves_local_versions_unpinned(monkeypatch: pytest.MonkeyPatch): - """Local version segments (source builds) are not published to PyPI, so no pin is emitted.""" + """Local version segments (source builds) are not on PyPI, so no pin is emitted.""" _set_mcp_version(monkeypatch, "1.2.3+g0123abc") assert mcp_requirement() == "mcp" def test_mcp_requirement_falls_back_when_mcp_is_not_installed(monkeypatch: pytest.MonkeyPatch): - """Without distribution metadata there is no version to pin, so the requirement stays unpinned.""" - def raise_not_found(distribution_name: str) -> str: raise importlib.metadata.PackageNotFoundError(distribution_name) @@ -69,7 +62,6 @@ def _read_server(config_dir: Path, name: str) -> dict[str, Any]: def test_generates_uv_run_command(config_dir: Path): - """Should write a uv run command that invokes mcp run on the resolved file spec.""" assert update_claude_config(file_spec="server.py:app", server_name="my_server") resolved = Path("server.py").resolve() @@ -80,14 +72,12 @@ def test_generates_uv_run_command(config_dir: Path): def test_file_spec_without_object_suffix(config_dir: Path): - """File specs without :object should still resolve to an absolute path.""" assert update_claude_config(file_spec="server.py", server_name="s") assert _read_server(config_dir, "s")["args"][-1] == str(Path("server.py").resolve()) def test_with_packages_sorted_and_deduplicated(config_dir: Path): - """Extra packages should appear as sorted --with flags with duplicates removed.""" assert update_claude_config(file_spec="s.py:app", server_name="s", with_packages=["zebra", "aardvark", "zebra"]) args = _read_server(config_dir, "s")["args"] @@ -95,7 +85,7 @@ def test_with_packages_sorted_and_deduplicated(config_dir: Path): def test_explicit_mcp_cli_kept_alongside_pinned_requirement(config_dir: Path): - """A user-supplied mcp[cli] no longer collapses into the pinned requirement; uv resolves both to the pin.""" + """Both requirements are emitted; uv resolves them to the pinned version.""" assert update_claude_config(file_spec="s.py:app", server_name="s", with_packages=["mcp[cli]"]) args = _read_server(config_dir, "s")["args"] @@ -103,7 +93,6 @@ def test_explicit_mcp_cli_kept_alongside_pinned_requirement(config_dir: Path): def test_with_editable_adds_flag(config_dir: Path, tmp_path: Path): - """with_editable should add --with-editable after the --with flags.""" editable = tmp_path / "project" assert update_claude_config(file_spec="s.py:app", server_name="s", with_editable=editable) @@ -112,14 +101,12 @@ def test_with_editable_adds_flag(config_dir: Path, tmp_path: Path): def test_env_vars_written(config_dir: Path): - """env_vars should be written under the server's env key.""" assert update_claude_config(file_spec="s.py:app", server_name="s", env_vars={"KEY": "val"}) assert _read_server(config_dir, "s")["env"] == {"KEY": "val"} def test_existing_env_vars_merged_new_wins(config_dir: Path): - """Re-installing should merge env vars, with new values overriding existing ones.""" (config_dir / "claude_desktop_config.json").write_text( json.dumps({"mcpServers": {"s": {"env": {"OLD": "keep", "KEY": "old"}}}}) ) @@ -130,7 +117,6 @@ def test_existing_env_vars_merged_new_wins(config_dir: Path): def test_existing_env_vars_preserved_without_new(config_dir: Path): - """Re-installing without env_vars should keep the existing env block intact.""" (config_dir / "claude_desktop_config.json").write_text(json.dumps({"mcpServers": {"s": {"env": {"KEEP": "me"}}}})) assert update_claude_config(file_spec="s.py:app", server_name="s") @@ -139,7 +125,6 @@ def test_existing_env_vars_preserved_without_new(config_dir: Path): def test_other_servers_preserved(config_dir: Path): - """Installing a new server should not clobber existing mcpServers entries.""" (config_dir / "claude_desktop_config.json").write_text(json.dumps({"mcpServers": {"other": {"command": "x"}}})) assert update_claude_config(file_spec="s.py:app", server_name="s") @@ -150,7 +135,6 @@ def test_other_servers_preserved(config_dir: Path): def test_raises_when_config_dir_missing(monkeypatch: pytest.MonkeyPatch): - """Should raise RuntimeError when Claude Desktop config dir can't be found.""" monkeypatch.setattr("mcp.cli.claude.get_claude_config_path", lambda: None) monkeypatch.setattr("mcp.cli.claude.get_uv_path", lambda: "/fake/bin/uv") @@ -160,8 +144,6 @@ def test_raises_when_config_dir_missing(monkeypatch: pytest.MonkeyPatch): @pytest.mark.parametrize("which_result, expected", [("/usr/local/bin/uv", "/usr/local/bin/uv"), (None, "uv")]) def test_get_uv_path(monkeypatch: pytest.MonkeyPatch, which_result: str | None, expected: str): - """Should return shutil.which's result, or fall back to bare 'uv' when not on PATH.""" - def fake_which(cmd: str) -> str | None: return which_result @@ -179,11 +161,7 @@ def fake_which(cmd: str) -> str | None: def test_windows_drive_letter_not_split( config_dir: Path, monkeypatch: pytest.MonkeyPatch, file_spec: str, expected_last_arg: str ): - """Drive-letter paths like 'C:\\server.py' must not be split on the drive colon. - - Before the fix, a bare 'C:\\path\\server.py' would hit rsplit(":", 1) and yield - ("C", "\\path\\server.py"), calling resolve() on Path("C") instead of the full path. - """ + """Regression: 'C:\\Users\\server.py' once hit rsplit(":", 1), resolving Path("C") instead of the full path.""" seen: list[str] = [] def fake_resolve(self: Path) -> Path: diff --git a/tests/cli/test_utils.py b/tests/cli/test_utils.py index d217d82fc7..dd7d224c54 100644 --- a/tests/cli/test_utils.py +++ b/tests/cli/test_utils.py @@ -26,7 +26,6 @@ def fake_version(distribution_name: str) -> str: ], ) def test_parse_file_path_accepts_valid_specs(tmp_path: Path, spec: str, expected_obj: str | None): - """Should accept valid file specs.""" file = tmp_path / spec.split(":")[0] file.write_text("x = 1") path, obj = _parse_file_path(f"{file}:{expected_obj}" if ":" in spec else str(file)) @@ -35,13 +34,11 @@ def test_parse_file_path_accepts_valid_specs(tmp_path: Path, spec: str, expected def test_parse_file_path_missing(tmp_path: Path): - """Should system exit if a file is missing.""" with pytest.raises(SystemExit): _parse_file_path(str(tmp_path / "missing.py")) def test_parse_file_exit_on_dir(tmp_path: Path): - """Should system exit if a directory is passed""" dir_path = tmp_path / "dir" dir_path.mkdir() with pytest.raises(SystemExit): @@ -49,7 +46,6 @@ def test_parse_file_exit_on_dir(tmp_path: Path): def test_build_uv_command_pins_the_running_mcp_version(monkeypatch: pytest.MonkeyPatch): - """The spawned environment installs the same SDK version that is running, not the latest stable.""" _set_mcp_version(monkeypatch, "1.2.3") cmd = _build_uv_command("foo.py") assert cmd == ["uv", "run", "--with", "mcp==1.2.3", "mcp", "run", "foo.py"] @@ -63,7 +59,6 @@ def test_build_uv_command_leaves_source_builds_unpinned(monkeypatch: pytest.Monk def test_build_uv_command_adds_editable_and_packages(monkeypatch: pytest.MonkeyPatch): - """Should include --with-editable and every --with pkg in correct order.""" _set_mcp_version(monkeypatch, "1.2.3") test_path = Path("/pkg") cmd = _build_uv_command( @@ -77,7 +72,7 @@ def test_build_uv_command_adds_editable_and_packages(monkeypatch: pytest.MonkeyP "--with", "mcp==1.2.3", "--with-editable", - str(test_path), # Use str() to match what the function does + str(test_path), "--with", "package1", "--with", @@ -89,13 +84,11 @@ def test_build_uv_command_adds_editable_and_packages(monkeypatch: pytest.MonkeyP def test_get_npx_unix_like(monkeypatch: pytest.MonkeyPatch): - """Should return "npx" on unix-like systems.""" monkeypatch.setattr(sys, "platform", "linux") assert _get_npx_command() == "npx" def test_get_npx_windows(monkeypatch: pytest.MonkeyPatch): - """Should return one of the npx candidates on Windows.""" candidates = ["npx.cmd", "npx.exe", "npx"] def fake_run(cmd: list[str], **kw: Any) -> subprocess.CompletedProcess[bytes]: @@ -110,7 +103,6 @@ def fake_run(cmd: list[str], **kw: Any) -> subprocess.CompletedProcess[bytes]: def test_get_npx_returns_none_when_npx_missing(monkeypatch: pytest.MonkeyPatch): - """Should give None if every candidate fails.""" monkeypatch.setattr(sys, "platform", "win32", raising=False) def always_fail(*args: Any, **kwargs: Any) -> subprocess.CompletedProcess[bytes]: diff --git a/tests/client/auth/extensions/test_client_credentials.py b/tests/client/auth/extensions/test_client_credentials.py index 3ad649d1f2..869e818efa 100644 --- a/tests/client/auth/extensions/test_client_credentials.py +++ b/tests/client/auth/extensions/test_client_credentials.py @@ -24,8 +24,6 @@ class MockTokenStorage: - """Mock token storage for testing.""" - def __init__(self): self._tokens: OAuthToken | None = None self._client_info: OAuthClientInformationFull | None = None @@ -61,11 +59,9 @@ def client_metadata(): @pytest.fixture def rfc7523_oauth_provider(client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage): async def redirect_handler(url: str) -> None: # pragma: no cover - """Mock redirect handler.""" pass async def callback_handler() -> AuthorizationCodeResult: # pragma: no cover - """Mock callback handler.""" return AuthorizationCodeResult(code="test_auth_code", state="test_state") with warnings.catch_warnings(): @@ -80,12 +76,8 @@ async def callback_handler() -> AuthorizationCodeResult: # pragma: no cover class TestOAuthFlowClientCredentials: - """Test OAuth flow behavior for client credentials flows.""" - @pytest.mark.anyio async def test_token_exchange_request_jwt_predefined(self, rfc7523_oauth_provider: RFC7523OAuthClientProvider): - """Test token exchange request building with a predefined JWT assertion.""" - # Set up required context rfc7523_oauth_provider.context.client_info = OAuthClientInformationFull( grant_types=["urn:ietf:params:oauth:grant-type:jwt-bearer"], token_endpoint_auth_method="private_key_jwt", @@ -111,7 +103,6 @@ async def test_token_exchange_request_jwt_predefined(self, rfc7523_oauth_provide assert str(request.url) == "https://api.example.com/token" assert request.headers["Content-Type"] == "application/x-www-form-urlencoded" - # Check form data content = urllib.parse.unquote_plus(request.content.decode()) assert "grant_type=urn:ietf:params:oauth:grant-type:jwt-bearer" in content assert "scope=read write" in content @@ -123,8 +114,6 @@ async def test_token_exchange_request_jwt_predefined(self, rfc7523_oauth_provide @pytest.mark.anyio async def test_token_exchange_request_jwt(self, rfc7523_oauth_provider: RFC7523OAuthClientProvider): - """Test token exchange request building wiith a generated JWT assertion.""" - # Set up required context rfc7523_oauth_provider.context.client_info = OAuthClientInformationFull( grant_types=["urn:ietf:params:oauth:grant-type:jwt-bearer"], token_endpoint_auth_method="private_key_jwt", @@ -158,13 +147,11 @@ async def test_token_exchange_request_jwt(self, rfc7523_oauth_provider: RFC7523O assert str(request.url) == "https://api.example.com/token" assert request.headers["Content-Type"] == "application/x-www-form-urlencoded" - # Check form data content = urllib.parse.unquote_plus(request.content.decode()).split("&") assert "grant_type=urn:ietf:params:oauth:grant-type:jwt-bearer" in content assert "scope=read write" in content assert "resource=https://api.example.com/v1/mcp" in content - # Check assertion assertion = next(param for param in content if param.startswith("assertion="))[len("assertion=") :] claims = jwt.decode( assertion, @@ -181,11 +168,8 @@ async def test_token_exchange_request_jwt(self, rfc7523_oauth_provider: RFC7523O class TestClientCredentialsOAuthProvider: - """Test ClientCredentialsOAuthProvider.""" - @pytest.mark.anyio async def test_init_sets_client_info(self, mock_storage: MockTokenStorage): - """Test that _initialize sets client_info.""" provider = ClientCredentialsOAuthProvider( server_url="https://api.example.com", storage=mock_storage, @@ -193,7 +177,6 @@ async def test_init_sets_client_info(self, mock_storage: MockTokenStorage): client_secret="test-client-secret", ) - # client_info is set during _initialize await provider._initialize() assert provider.context.client_info is not None @@ -204,7 +187,6 @@ async def test_init_sets_client_info(self, mock_storage: MockTokenStorage): @pytest.mark.anyio async def test_init_with_scopes(self, mock_storage: MockTokenStorage): - """Test that constructor accepts scopes.""" provider = ClientCredentialsOAuthProvider( server_url="https://api.example.com", storage=mock_storage, @@ -219,7 +201,6 @@ async def test_init_with_scopes(self, mock_storage: MockTokenStorage): @pytest.mark.anyio async def test_init_with_client_secret_post(self, mock_storage: MockTokenStorage): - """Test that constructor accepts client_secret_post auth method.""" provider = ClientCredentialsOAuthProvider( server_url="https://api.example.com", storage=mock_storage, @@ -234,7 +215,6 @@ async def test_init_with_client_secret_post(self, mock_storage: MockTokenStorage @pytest.mark.anyio async def test_exchange_token_client_credentials(self, mock_storage: MockTokenStorage): - """Test token exchange request building.""" provider = ClientCredentialsOAuthProvider( server_url="https://api.example.com/v1/mcp", storage=mock_storage, @@ -284,12 +264,10 @@ async def test_exchange_token_client_secret_post_includes_client_id(self, mock_s assert "grant_type=client_credentials" in content assert "client_id=test-client-id" in content assert "client_secret=test-client-secret" in content - # Should NOT have Basic auth header assert "Authorization" not in request.headers @pytest.mark.anyio async def test_exchange_token_client_secret_post_without_client_id(self, mock_storage: MockTokenStorage): - """Test client_secret_post skips body credentials when client_id is None.""" provider = ClientCredentialsOAuthProvider( server_url="https://api.example.com/v1/mcp", storage=mock_storage, @@ -305,7 +283,7 @@ async def test_exchange_token_client_secret_post_without_client_id(self, mock_st token_endpoint=AnyHttpUrl("https://api.example.com/token"), ) provider.context.protocol_version = "2025-06-18" - # Override client_info to have client_id=None (edge case) + # Replace the client_info set by _initialize to hit the client_id=None edge case provider.context.client_info = OAuthClientInformationFull( redirect_uris=None, client_id=None, @@ -319,15 +297,13 @@ async def test_exchange_token_client_secret_post_without_client_id(self, mock_st content = urllib.parse.unquote_plus(request.content.decode()) assert "grant_type=client_credentials" in content - # Neither client_id nor client_secret should be in body since client_id is None - # (RFC 6749 Β§2.3.1 requires both for client_secret_post) + # RFC 6749 Β§2.3.1 requires both credentials for client_secret_post, so neither is sent assert "client_id=" not in content assert "client_secret=" not in content assert "Authorization" not in request.headers @pytest.mark.anyio async def test_exchange_token_without_scopes(self, mock_storage: MockTokenStorage): - """Test token exchange without scopes.""" provider = ClientCredentialsOAuthProvider( server_url="https://api.example.com/v1/mcp", storage=mock_storage, @@ -350,12 +326,8 @@ async def test_exchange_token_without_scopes(self, mock_storage: MockTokenStorag class TestPrivateKeyJWTOAuthProvider: - """Test PrivateKeyJWTOAuthProvider.""" - @pytest.mark.anyio async def test_init_sets_client_info(self, mock_storage: MockTokenStorage): - """Test that _initialize sets client_info.""" - async def mock_assertion_provider(audience: str) -> str: # pragma: no cover return "mock-jwt" @@ -366,7 +338,6 @@ async def mock_assertion_provider(audience: str) -> str: # pragma: no cover assertion_provider=mock_assertion_provider, ) - # client_info is set during _initialize await provider._initialize() assert provider.context.client_info is not None @@ -376,8 +347,6 @@ async def mock_assertion_provider(audience: str) -> str: # pragma: no cover @pytest.mark.anyio async def test_exchange_token_client_credentials(self, mock_storage: MockTokenStorage): - """Test token exchange request building with assertion provider.""" - async def mock_assertion_provider(audience: str) -> str: return f"jwt-for-{audience}" @@ -408,8 +377,6 @@ async def mock_assertion_provider(audience: str) -> str: @pytest.mark.anyio async def test_exchange_token_without_scopes(self, mock_storage: MockTokenStorage): - """Test token exchange without scopes.""" - async def mock_assertion_provider(audience: str) -> str: return f"jwt-for-{audience}" @@ -435,11 +402,8 @@ async def mock_assertion_provider(audience: str) -> str: class TestSignedJWTParameters: - """Test SignedJWTParameters.""" - @pytest.mark.anyio async def test_create_assertion_provider(self): - """Test that create_assertion_provider creates valid JWTs.""" params = SignedJWTParameters( issuer="test-issuer", subject="test-subject", @@ -466,7 +430,6 @@ async def test_create_assertion_provider(self): @pytest.mark.anyio async def test_create_assertion_provider_with_additional_claims(self): - """Test that additional_claims are included in the JWT.""" params = SignedJWTParameters( issuer="test-issuer", subject="test-subject", @@ -488,11 +451,8 @@ async def test_create_assertion_provider_with_additional_claims(self): class TestStaticAssertionProvider: - """Test static_assertion_provider helper.""" - @pytest.mark.anyio async def test_returns_static_token(self): - """Test that static_assertion_provider returns the same token regardless of audience.""" token = "my-static-jwt-token" provider = static_assertion_provider(token) diff --git a/tests/client/auth/extensions/test_identity_assertion.py b/tests/client/auth/extensions/test_identity_assertion.py index 1bc63a1173..8d3efee051 100644 --- a/tests/client/auth/extensions/test_identity_assertion.py +++ b/tests/client/auth/extensions/test_identity_assertion.py @@ -1,8 +1,7 @@ -"""Unit tests for the standalone SEP-990 jwt-bearer `httpx.Auth`. +"""Tests for the standalone SEP-990 jwt-bearer `httpx.Auth`. -The provider's authorization server is configuration; these tests assert that authorization-server -metadata is fetched only from the configured issuer, that the resource server is never consulted for -AS selection, and that the ID-JAG and client secret reach only the issuer's token endpoint. +The AS is configuration: metadata is fetched only from the configured issuer, the resource server is +never consulted for AS selection, and the ID-JAG and client secret reach only the issuer's token endpoint. """ import base64 @@ -88,11 +87,9 @@ def mock_transport( rs_first_status: int = 401, rs_first_headers: dict[str, str] | None = None, ) -> httpx.MockTransport: - """Build a `MockTransport` that records every request and serves the configured ASM and token. + """Record every request; `asm`/`token` are a body (served as 200 JSON) or an int status (no body). - `asm` / `token` are either a body (served as 200 JSON) or an int status (served with no body). - The MCP resource server's first response is `rs_first_status` (default 401) with optional - headers; subsequent RS requests return 200. + The RS's first response is `rs_first_status` (default 401) with optional headers; later RS requests get 200. """ rs_hits = 0 @@ -124,7 +121,6 @@ def form(request: httpx.Request) -> dict[str, str]: @pytest.mark.anyio async def test_on_401_exchanges_assertion_at_configured_issuer_and_retries() -> None: - """A 401 fetches ASM from the configured issuer, posts the jwt-bearer grant, and retries.""" requests: list[httpx.Request] = [] record: list[tuple[str, str]] = [] storage = InMemoryStorage() @@ -160,11 +156,6 @@ async def test_on_401_exchanges_assertion_at_configured_issuer_and_retries() -> @pytest.mark.anyio async def test_resource_server_metadata_is_never_consulted() -> None: - """No PRM well-known and no RS-origin ASM well-known is ever fetched. - - This is the by-construction property: the AS is configuration, so the resource server has no - input into where the ID-JAG or client secret go. Any GET to the RS host fails the test. - """ requests: list[httpx.Request] = [] auth = make_provider() @@ -182,7 +173,6 @@ async def test_resource_server_metadata_is_never_consulted() -> None: @pytest.mark.anyio async def test_asm_404_at_configured_issuer_raises_before_minting_assertion() -> None: - """If the issuer's well-knowns 404, the flow fails closed and the assertion is never minted.""" requests: list[httpx.Request] = [] record: list[tuple[str, str]] = [] auth = make_provider(record=record) @@ -199,7 +189,6 @@ async def test_asm_404_at_configured_issuer_raises_before_minting_assertion() -> @pytest.mark.anyio async def test_asm_5xx_stops_discovery_and_raises() -> None: - """A 5xx at the issuer's well-known stops discovery without trying further URLs.""" requests: list[httpx.Request] = [] auth = make_provider() @@ -229,7 +218,6 @@ async def test_asm_with_wrong_issuer_is_rejected_before_minting_assertion() -> N @pytest.mark.anyio async def test_asm_with_off_origin_token_endpoint_is_rejected_before_minting_assertion() -> None: - """A `token_endpoint` off the configured issuer's origin is refused before any credential is sent.""" requests: list[httpx.Request] = [] record: list[tuple[str, str]] = [] auth = make_provider(record=record) @@ -246,7 +234,6 @@ async def test_asm_with_off_origin_token_endpoint_is_rejected_before_minting_ass @pytest.mark.anyio async def test_403_insufficient_scope_unions_challenged_scope_with_configured() -> None: - """A 403 `insufficient_scope` re-exchanges with the union of configured and challenged scopes.""" requests: list[httpx.Request] = [] auth = make_provider(scope="mcp") @@ -267,7 +254,6 @@ async def test_403_insufficient_scope_unions_challenged_scope_with_configured() @pytest.mark.anyio async def test_403_without_insufficient_scope_does_not_reauthorize() -> None: - """A plain 403 (not `insufficient_scope`) is returned to the caller without re-exchanging.""" requests: list[httpx.Request] = [] record: list[tuple[str, str]] = [] auth = make_provider(record=record) @@ -309,7 +295,6 @@ async def test_client_secret_basic_sends_basic_header_not_body_secret() -> None: @pytest.mark.anyio async def test_stored_token_is_reused_without_reauthorizing() -> None: - """A valid stored token is sent on the first request; on success no ASM or /token is fetched.""" requests: list[httpx.Request] = [] storage = InMemoryStorage(tokens=OAuthToken(access_token="cached", token_type="Bearer", expires_in=3600)) auth = make_provider(storage) @@ -325,7 +310,6 @@ async def test_stored_token_is_reused_without_reauthorizing() -> None: @pytest.mark.anyio async def test_second_401_re_exchanges_without_refetching_asm() -> None: - """ASM is discovered once; a later 401 mints a fresh assertion against the cached token endpoint.""" requests: list[httpx.Request] = [] record: list[tuple[str, str]] = [] auth = make_provider(record=record) @@ -337,7 +321,6 @@ def handle(request: httpx.Request) -> httpx.Response: host, path = request.url.host, request.url.path if host == "mcp.example.com": rs_hits += 1 - # First and third RS hits draw a 401; second and fourth succeed. return httpx.Response(401 if rs_hits in (1, 3) else 200) if host == "auth.example.com" and path == ASM_PATH: return httpx.Response(200, content=asm_body(), headers={"content-type": "application/json"}) @@ -404,7 +387,6 @@ async def assertion_provider(audience: str, resource: str) -> str: def test_origin_normalizes_default_ports() -> None: - """`_origin` treats an explicit scheme-default port as equal to the port-less form.""" assert _origin("https://host") == _origin("https://host:443") assert _origin("http://host") == _origin("http://host:80") assert _origin("https://host") != _origin("https://host:8443") diff --git a/tests/client/conftest.py b/tests/client/conftest.py index 1fbee13c18..5bae910648 100644 --- a/tests/client/conftest.py +++ b/tests/client/conftest.py @@ -36,12 +36,10 @@ def __init__(self, client_spy: SpyMemoryObjectSendStream, server_spy: SpyMemoryO self.server = server_spy def clear(self) -> None: - """Clear all captured messages.""" self.client.sent_messages.clear() self.server.sent_messages.clear() def get_client_requests(self, method: str | None = None) -> list[JSONRPCRequest]: - """Get client-sent requests, optionally filtered by method.""" return [ req.message for req in self.client.sent_messages @@ -49,7 +47,6 @@ def get_client_requests(self, method: str | None = None) -> list[JSONRPCRequest] ] def get_server_requests(self, method: str | None = None) -> list[JSONRPCRequest]: # pragma: no cover - """Get server-sent requests, optionally filtered by method.""" return [ # pragma: no cover req.message for req in self.server.sent_messages @@ -57,7 +54,6 @@ def get_server_requests(self, method: str | None = None) -> list[JSONRPCRequest] ] def get_client_notifications(self, method: str | None = None) -> list[JSONRPCNotification]: # pragma: no cover - """Get client-sent notifications, optionally filtered by method.""" return [ notif.message for notif in self.client.sent_messages @@ -65,7 +61,6 @@ def get_client_notifications(self, method: str | None = None) -> list[JSONRPCNot ] def get_server_notifications(self, method: str | None = None) -> list[JSONRPCNotification]: # pragma: no cover - """Get server-sent notifications, optionally filtered by method.""" return [ notif.message for notif in self.server.sent_messages @@ -75,36 +70,19 @@ def get_server_notifications(self, method: str | None = None) -> list[JSONRPCNot @pytest.fixture def stream_spy() -> Generator[Callable[[], StreamSpyCollection], None, None]: - """Fixture that provides spies for both client and server write streams. + """Patch memory stream creation so tests can inspect client- and server-sent messages. - Example: - ```python - async def test_something(stream_spy): - # ... set up server and client ... - - spies = stream_spy() - - # Run some operation that sends messages - await client.some_operation() - - # Check the messages - requests = spies.get_client_requests(method="some/method") - assert len(requests) == 1 - - # Clear for the next operation - spies.clear() - ``` + Call the yielded factory after the streams exist (i.e. once client/server are set up) + to get a `StreamSpyCollection`. """ client_spy = None server_spy = None - # Store references to our spy objects def capture_spies(c_spy: SpyMemoryObjectSendStream, s_spy: SpyMemoryObjectSendStream): nonlocal client_spy, server_spy client_spy = c_spy server_spy = s_spy - # Create patched version of stream creation original_create_streams = mcp.shared.memory.create_client_server_memory_streams @asynccontextmanager @@ -113,20 +91,17 @@ async def patched_create_streams(): client_read, client_write = client_streams server_read, server_write = server_streams - # Create spy wrappers spy_client_write = SpyMemoryObjectSendStream(client_write) spy_server_write = SpyMemoryObjectSendStream(server_write) - # Capture references for the test to use capture_spies(spy_client_write, spy_server_write) yield (client_read, spy_client_write), (server_read, spy_server_write) - # Apply the patch for the duration of the test # Patch both locations since InMemoryTransport imports it directly with patch("mcp.shared.memory.create_client_server_memory_streams", patched_create_streams): with patch("mcp.client._memory.create_client_server_memory_streams", patched_create_streams): - # Return a collection with helper methods + def get_spy_collection() -> StreamSpyCollection: assert client_spy is not None, "client_spy was not initialized" assert server_spy is not None, "server_spy was not initialized" diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index 1ec38ccf6f..42823fd398 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -1,4 +1,4 @@ -"""Tests for refactored OAuth client authentication implementation.""" +"""Tests for OAuth client authentication.""" import base64 import json @@ -44,8 +44,6 @@ class MockTokenStorage: - """Mock token storage for testing.""" - def __init__(self): self._tokens: OAuthToken | None = None self._client_info: OAuthClientInformationFull | None = None @@ -92,11 +90,9 @@ def valid_tokens(): @pytest.fixture def oauth_provider(client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage): async def redirect_handler(url: str) -> None: - """Mock redirect handler.""" pass # pragma: no cover async def callback_handler() -> AuthorizationCodeResult: - """Mock callback handler.""" return AuthorizationCodeResult(code="test_auth_code", state="test_state") # pragma: no cover return OAuthClientProvider( @@ -110,7 +106,6 @@ async def callback_handler() -> AuthorizationCodeResult: @pytest.fixture def prm_metadata_response(): - """PRM metadata response with scopes.""" return httpx.Response( 200, content=( @@ -123,7 +118,6 @@ def prm_metadata_response(): @pytest.fixture def prm_metadata_without_scopes_response(): - """PRM metadata response without scopes.""" return httpx.Response( 200, content=( @@ -136,7 +130,6 @@ def prm_metadata_without_scopes_response(): @pytest.fixture def init_response_with_www_auth_scope(): - """Initial 401 response with WWW-Authenticate header containing scope.""" return httpx.Response( 401, headers={"WWW-Authenticate": 'Bearer scope="special:scope from:www-authenticate"'}, @@ -146,7 +139,6 @@ def init_response_with_www_auth_scope(): @pytest.fixture def init_response_without_www_auth_scope(): - """Initial 401 response without WWW-Authenticate scope.""" return httpx.Response( 401, headers={}, @@ -155,25 +147,19 @@ def init_response_without_www_auth_scope(): class TestPKCEParameters: - """Test PKCE parameter generation.""" - def test_pkce_generation(self): - """Test PKCE parameter generation creates valid values.""" pkce = PKCEParameters.generate() - # Verify lengths assert len(pkce.code_verifier) == 128 assert 43 <= len(pkce.code_challenge) <= 128 - # Verify characters used in verifier allowed_chars = set("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~") assert all(c in allowed_chars for c in pkce.code_verifier) - # Verify base64url encoding in challenge (no padding) + # base64url challenge must be unpadded assert "=" not in pkce.code_challenge def test_pkce_uniqueness(self): - """Test PKCE generates unique values each time.""" pkce1 = PKCEParameters.generate() pkce2 = PKCEParameters.generate() @@ -182,13 +168,10 @@ def test_pkce_uniqueness(self): class TestOAuthContext: - """Test OAuth context functionality.""" - @pytest.mark.anyio async def test_oauth_provider_initialization( self, oauth_provider: OAuthClientProvider, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage ): - """Test OAuthClientProvider basic setup.""" assert oauth_provider.context.server_url == "https://api.example.com/v1/mcp" assert oauth_provider.context.client_metadata == client_metadata assert oauth_provider.context.storage == mock_storage @@ -199,19 +182,12 @@ def test_context_url_parsing(self, oauth_provider: OAuthClientProvider): """Test get_authorization_base_url() extracts base URLs correctly.""" context = oauth_provider.context - # Test with path assert context.get_authorization_base_url("https://api.example.com/v1/mcp") == "https://api.example.com" - - # Test with no path assert context.get_authorization_base_url("https://api.example.com") == "https://api.example.com" - - # Test with port assert ( context.get_authorization_base_url("https://api.example.com:8080/path/to/mcp") == "https://api.example.com:8080" ) - - # Test with query params assert ( context.get_authorization_base_url("https://api.example.com/path?param=value") == "https://api.example.com" ) @@ -221,60 +197,47 @@ async def test_token_validity_checking(self, oauth_provider: OAuthClientProvider """Test is_token_valid() and can_refresh_token() logic.""" context = oauth_provider.context - # No tokens - should be invalid assert not context.is_token_valid() assert not context.can_refresh_token() - # Set valid tokens and client info context.current_tokens = valid_tokens - context.token_expiry_time = time.time() + 1800 # 30 minutes from now + context.token_expiry_time = time.time() + 1800 context.client_info = OAuthClientInformationFull( client_id="test_client_id", client_secret="test_client_secret", redirect_uris=[AnyUrl("http://localhost:3030/callback")], ) - # Should be valid assert context.is_token_valid() - assert context.can_refresh_token() # Has refresh token and client info + assert context.can_refresh_token() - # Expire the token - context.token_expiry_time = time.time() - 100 # Expired 100 seconds ago + context.token_expiry_time = time.time() - 100 assert not context.is_token_valid() - assert context.can_refresh_token() # Can still refresh + assert context.can_refresh_token() # Expired tokens can still be refreshed - # Remove refresh token context.current_tokens.refresh_token = None assert not context.can_refresh_token() - # Remove client info context.current_tokens.refresh_token = "test_refresh_token" context.client_info = None assert not context.can_refresh_token() def test_clear_tokens(self, oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken): - """Test clear_tokens() removes token data.""" context = oauth_provider.context context.current_tokens = valid_tokens context.token_expiry_time = time.time() + 1800 - # Clear tokens context.clear_tokens() - # Verify cleared assert context.current_tokens is None assert context.token_expiry_time is None class TestOAuthFlow: - """Test OAuth flow methods.""" - @pytest.mark.anyio async def test_build_protected_resource_discovery_urls( self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage ): - """Test protected resource metadata discovery URL building with fallback.""" - async def redirect_handler(url: str) -> None: pass # pragma: no cover @@ -289,7 +252,6 @@ async def callback_handler() -> AuthorizationCodeResult: callback_handler=callback_handler, ) - # Test without WWW-Authenticate (fallback) init_response = httpx.Response( status_code=401, headers={}, request=httpx.Request("GET", "https://request-api.example.com") ) @@ -300,7 +262,6 @@ async def callback_handler() -> AuthorizationCodeResult: assert len(urls) == 1 assert urls[0] == "https://api.example.com/.well-known/oauth-protected-resource" - # Test with WWW-Authenticate header init_response.headers["WWW-Authenticate"] = ( 'Bearer resource_metadata="https://prm.example.com/.well-known/oauth-protected-resource/path"' ) @@ -314,10 +275,8 @@ async def callback_handler() -> AuthorizationCodeResult: @pytest.mark.anyio def test_create_oauth_metadata_request(self, oauth_provider: OAuthClientProvider): - """Test OAuth metadata discovery request building.""" request = create_oauth_metadata_request("https://example.com") - # Ensure correct method and headers, and that the URL is unmodified assert request.method == "GET" assert str(request.url) == "https://example.com" assert "mcp-protocol-version" in request.headers @@ -328,23 +287,19 @@ class TestOAuthFallback: @pytest.mark.anyio async def test_oauth_discovery_legacy_fallback_when_no_prm(self): - """Test that when PRM discovery fails, only root OAuth URL is tried (March 2025 spec).""" - # When auth_server_url is None (PRM failed), we use server_url and only try root + """When PRM discovery fails, only the root OAuth URL is tried (March 2025 spec).""" discovery_urls = build_oauth_authorization_server_metadata_discovery_urls(None, "https://mcp.linear.app/sse") - # Should only try the root URL (legacy behavior) assert discovery_urls == [ "https://mcp.linear.app/.well-known/oauth-authorization-server", ] @pytest.mark.anyio async def test_oauth_discovery_path_aware_when_auth_server_has_path(self): - """Test that when auth server URL has a path, only path-based URLs are tried.""" discovery_urls = build_oauth_authorization_server_metadata_discovery_urls( "https://auth.example.com/tenant1", "https://api.example.com/mcp" ) - # Should try path-based URLs only (no root URLs) assert discovery_urls == [ "https://auth.example.com/.well-known/oauth-authorization-server/tenant1", "https://auth.example.com/.well-known/openid-configuration/tenant1", @@ -353,12 +308,10 @@ async def test_oauth_discovery_path_aware_when_auth_server_has_path(self): @pytest.mark.anyio async def test_oauth_discovery_root_when_auth_server_has_no_path(self): - """Test that when auth server URL has no path, only root URLs are tried.""" discovery_urls = build_oauth_authorization_server_metadata_discovery_urls( "https://auth.example.com", "https://api.example.com/mcp" ) - # Should try root URLs only assert discovery_urls == [ "https://auth.example.com/.well-known/oauth-authorization-server", "https://auth.example.com/.well-known/openid-configuration", @@ -366,12 +319,10 @@ async def test_oauth_discovery_root_when_auth_server_has_no_path(self): @pytest.mark.anyio async def test_oauth_discovery_root_when_auth_server_has_only_slash(self): - """Test that when auth server URL has only trailing slash, treated as root.""" discovery_urls = build_oauth_authorization_server_metadata_discovery_urls( "https://auth.example.com/", "https://api.example.com/mcp" ) - # Should try root URLs only assert discovery_urls == [ "https://auth.example.com/.well-known/oauth-authorization-server", "https://auth.example.com/.well-known/openid-configuration", @@ -379,7 +330,6 @@ async def test_oauth_discovery_root_when_auth_server_has_only_slash(self): @pytest.mark.anyio async def test_oauth_discovery_fallback_order(self, oauth_provider: OAuthClientProvider): - """Test fallback URL construction order when auth server URL has a path.""" # Simulate PRM discovery returning an auth server URL with a path oauth_provider.context.auth_server_url = oauth_provider.context.server_url @@ -395,8 +345,6 @@ async def test_oauth_discovery_fallback_order(self, oauth_provider: OAuthClientP @pytest.mark.anyio async def test_oauth_discovery_fallback_conditions(self, oauth_provider: OAuthClientProvider): - """Test the conditions during which an AS metadata discovery fallback will be attempted.""" - # Ensure no tokens are stored oauth_provider.context.current_tokens = None oauth_provider.context.token_expiry_time = None oauth_provider._initialized = True @@ -407,17 +355,13 @@ async def test_oauth_discovery_fallback_conditions(self, oauth_provider: OAuthCl redirect_uris=[AnyUrl("http://localhost:3030/callback")], ) - # Create a test request test_request = httpx.Request("GET", "https://api.example.com/v1/mcp") - # Mock the auth flow auth_flow = oauth_provider.async_auth_flow(test_request) - # First request should be the original request without auth header request = await auth_flow.__anext__() assert "Authorization" not in request.headers - # Send a 401 response to trigger the OAuth flow response = httpx.Response( 401, headers={ @@ -426,20 +370,17 @@ async def test_oauth_discovery_fallback_conditions(self, oauth_provider: OAuthCl request=test_request, ) - # Next request should be to discover protected resource metadata discovery_request = await auth_flow.asend(response) assert str(discovery_request.url) == "https://api.example.com/.well-known/oauth-protected-resource" assert discovery_request.method == "GET" - # Send a successful discovery response with minimal protected resource metadata - # Note: auth server URL has a path (/v1/mcp), so only path-based URLs will be tried + # The auth server URL has a path (/v1/mcp), so only path-based discovery URLs are tried discovery_response = httpx.Response( 200, content=b'{"resource": "https://api.example.com/v1/mcp", "authorization_servers": ["https://auth.example.com/v1/mcp"]}', request=discovery_request, ) - # Next request should be to discover OAuth metadata at path-aware OAuth URL oauth_metadata_request_1 = await auth_flow.asend(discovery_response) assert ( str(oauth_metadata_request_1.url) @@ -447,49 +388,41 @@ async def test_oauth_discovery_fallback_conditions(self, oauth_provider: OAuthCl ) assert oauth_metadata_request_1.method == "GET" - # Send a 404 response oauth_metadata_response_1 = httpx.Response( 404, content=b"Not Found", request=oauth_metadata_request_1, ) - # Next request should be path-aware OIDC URL (not root URL since auth server has path) oauth_metadata_request_2 = await auth_flow.asend(oauth_metadata_response_1) assert str(oauth_metadata_request_2.url) == "https://auth.example.com/.well-known/openid-configuration/v1/mcp" assert oauth_metadata_request_2.method == "GET" - # Send a 400 response oauth_metadata_response_2 = httpx.Response( 400, content=b"Bad Request", request=oauth_metadata_request_2, ) - # Next request should be OIDC path-appended URL oauth_metadata_request_3 = await auth_flow.asend(oauth_metadata_response_2) assert str(oauth_metadata_request_3.url) == "https://auth.example.com/v1/mcp/.well-known/openid-configuration" assert oauth_metadata_request_3.method == "GET" - # Send a 500 response oauth_metadata_response_3 = httpx.Response( 500, content=b"Internal Server Error", request=oauth_metadata_request_3, ) - # Mock the authorization process to minimize unnecessary state in this test oauth_provider._perform_authorization_code_grant = mock.AsyncMock( return_value=("test_auth_code", "test_code_verifier") ) - # All path-based URLs failed, flow continues with default endpoints - # Next request should be token exchange using MCP server base URL (fallback when OAuth metadata not found) + # All discovery URLs failed: the token endpoint falls back to the MCP server base URL token_request = await auth_flow.asend(oauth_metadata_response_3) assert str(token_request.url) == "https://api.example.com/token" assert token_request.method == "POST" - # Send a successful token response token_response = httpx.Response( 200, content=( @@ -499,23 +432,19 @@ async def test_oauth_discovery_fallback_conditions(self, oauth_provider: OAuthCl request=token_request, ) - # After OAuth flow completes, the original request is retried with auth header final_request = await auth_flow.asend(token_response) assert final_request.headers["Authorization"] == "Bearer new_access_token" assert final_request.method == "GET" assert str(final_request.url) == "https://api.example.com/v1/mcp" - # Send final success response to properly close the generator final_response = httpx.Response(200, request=final_request) try: await auth_flow.asend(final_response) except StopAsyncIteration: - pass # Expected - generator should complete + pass @pytest.mark.anyio async def test_handle_metadata_response_success(self, oauth_provider: OAuthClientProvider): - """Test successful metadata response handling.""" - # Create minimal valid OAuth metadata content = b"""{ "issuer": "https://auth.example.com", "authorization_endpoint": "https://auth.example.com/authorize", @@ -523,7 +452,7 @@ async def test_handle_metadata_response_success(self, oauth_provider: OAuthClien }""" response = httpx.Response(200, content=content) - # Should set metadata; the empty path is preserved (no trailing slash added) + # The issuer's empty path is preserved (no trailing slash added) await oauth_provider._handle_oauth_metadata_response(response) assert oauth_provider.context.oauth_metadata is not None assert str(oauth_provider.context.oauth_metadata.issuer) == "https://auth.example.com" @@ -535,17 +464,13 @@ async def test_prioritize_www_auth_scope_over_prm( prm_metadata_response: httpx.Response, init_response_with_www_auth_scope: httpx.Response, ): - """Test that WWW-Authenticate scope is prioritized over PRM scopes.""" - # First, process PRM metadata to set protected_resource_metadata with scopes await oauth_provider._handle_protected_resource_response(prm_metadata_response) - # Process the scope selection with WWW-Authenticate header scopes = get_client_metadata_scopes( extract_scope_from_www_auth(init_response_with_www_auth_scope), oauth_provider.context.protected_resource_metadata, ) - # Verify that WWW-Authenticate scope is used (not PRM scopes) assert scopes == "special:scope from:www-authenticate" @pytest.mark.anyio @@ -555,17 +480,13 @@ async def test_prioritize_prm_scopes_when_no_www_auth_scope( prm_metadata_response: httpx.Response, init_response_without_www_auth_scope: httpx.Response, ): - """Test that PRM scopes are prioritized when WWW-Authenticate header has no scopes.""" - # Process the PRM metadata to set protected_resource_metadata with scopes await oauth_provider._handle_protected_resource_response(prm_metadata_response) - # Process the scope selection without WWW-Authenticate scope scopes = get_client_metadata_scopes( extract_scope_from_www_auth(init_response_without_www_auth_scope), oauth_provider.context.protected_resource_metadata, ) - # Verify that PRM scopes are used assert scopes == "resource:read resource:write" @pytest.mark.anyio @@ -575,22 +496,16 @@ async def test_omit_scope_when_no_prm_scopes_or_www_auth( prm_metadata_without_scopes_response: httpx.Response, init_response_without_www_auth_scope: httpx.Response, ): - """Test that scope is omitted when PRM has no scopes and WWW-Authenticate doesn't specify scope.""" - # Process the PRM metadata without scopes await oauth_provider._handle_protected_resource_response(prm_metadata_without_scopes_response) - # Process the scope selection without WWW-Authenticate scope scopes = get_client_metadata_scopes( extract_scope_from_www_auth(init_response_without_www_auth_scope), oauth_provider.context.protected_resource_metadata, ) - # Verify that scope is omitted assert scopes is None @pytest.mark.anyio async def test_token_exchange_request_authorization_code(self, oauth_provider: OAuthClientProvider): - """Test token exchange request building.""" - # Set up required context oauth_provider.context.client_info = OAuthClientInformationFull( client_id="test_client", client_secret="test_secret", @@ -604,7 +519,6 @@ async def test_token_exchange_request_authorization_code(self, oauth_provider: O assert str(request.url) == "https://api.example.com/token" assert request.headers["Content-Type"] == "application/x-www-form-urlencoded" - # Check form data content = request.content.decode() assert "grant_type=authorization_code" in content assert "code=test_auth_code" in content @@ -614,8 +528,6 @@ async def test_token_exchange_request_authorization_code(self, oauth_provider: O @pytest.mark.anyio async def test_refresh_token_request(self, oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken): - """Test refresh token request building.""" - # Set up required context oauth_provider.context.current_tokens = valid_tokens oauth_provider.context.client_info = OAuthClientInformationFull( client_id="test_client", @@ -630,7 +542,6 @@ async def test_refresh_token_request(self, oauth_provider: OAuthClientProvider, assert str(request.url) == "https://api.example.com/token" assert request.headers["Content-Type"] == "application/x-www-form-urlencoded" - # Check form data content = request.content.decode() assert "grant_type=refresh_token" in content assert "refresh_token=test_refresh_token" in content @@ -639,8 +550,6 @@ async def test_refresh_token_request(self, oauth_provider: OAuthClientProvider, @pytest.mark.anyio async def test_basic_auth_token_exchange(self, oauth_provider: OAuthClientProvider): - """Test token exchange with client_secret_basic authentication.""" - # Set up OAuth metadata to support basic auth oauth_provider.context.oauth_metadata = OAuthMetadata( issuer=AnyHttpUrl("https://auth.example.com"), authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize"), @@ -660,20 +569,16 @@ async def test_basic_auth_token_exchange(self, oauth_provider: OAuthClientProvid request = await oauth_provider._exchange_token_authorization_code("test_auth_code", "test_verifier") - # Should use basic auth (registered method) assert "Authorization" in request.headers assert request.headers["Authorization"].startswith("Basic ") - # Decode and verify credentials are properly URL-encoded encoded_creds = request.headers["Authorization"][6:] # Remove "Basic " prefix decoded = base64.b64decode(encoded_creds).decode() client_id, client_secret = decoded.split(":", 1) - # Check URL encoding was applied - assert client_id == "test%40client" # @ should be encoded as %40 - assert client_secret == "test%3Asecret" # : should be encoded as %3A + assert client_id == "test%40client" + assert client_secret == "test%3Asecret" - # Verify decoded values match original assert unquote(client_id) == client_id_raw assert unquote(client_secret) == client_secret_raw @@ -684,10 +589,8 @@ async def test_basic_auth_token_exchange(self, oauth_provider: OAuthClientProvid @pytest.mark.anyio async def test_basic_auth_refresh_token(self, oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken): - """Test token refresh with client_secret_basic authentication.""" oauth_provider.context.current_tokens = valid_tokens - # Set up OAuth metadata to only support basic auth oauth_provider.context.oauth_metadata = OAuthMetadata( issuer=AnyHttpUrl("https://auth.example.com"), authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize"), @@ -730,29 +633,24 @@ async def test_none_auth_method(self, oauth_provider: OAuthClientProvider): client_id = "public_client" oauth_provider.context.client_info = OAuthClientInformationFull( client_id=client_id, - client_secret=None, # No secret for public client + client_secret=None, redirect_uris=[AnyUrl("http://localhost:3030/callback")], token_endpoint_auth_method="none", ) request = await oauth_provider._exchange_token_authorization_code("test_auth_code", "test_verifier") - # Should NOT have Authorization header assert "Authorization" not in request.headers - # Should NOT have client_secret in body content = request.content.decode() assert "client_secret=" not in content assert "client_id=public_client" in content class TestProtectedResourceMetadata: - """Test protected resource handling.""" - @pytest.mark.anyio async def test_resource_param_included_with_recent_protocol_version(self, oauth_provider: OAuthClientProvider): """Test resource parameter is included for protocol version >= 2025-06-18.""" - # Set protocol version to 2025-06-18 oauth_provider.context.protocol_version = "2025-06-18" oauth_provider.context.client_info = OAuthClientInformationFull( client_id="test_client", @@ -760,15 +658,12 @@ async def test_resource_param_included_with_recent_protocol_version(self, oauth_ redirect_uris=[AnyUrl("http://localhost:3030/callback")], ) - # Test in token exchange request = await oauth_provider._exchange_token_authorization_code("test_code", "test_verifier") content = request.content.decode() assert "resource=" in content - # Check URL-encoded resource parameter expected_resource = quote(oauth_provider.context.get_resource_url(), safe="") assert f"resource={expected_resource}" in content - # Test in refresh token oauth_provider.context.current_tokens = OAuthToken( access_token="test_access", token_type="Bearer", @@ -781,7 +676,6 @@ async def test_resource_param_included_with_recent_protocol_version(self, oauth_ @pytest.mark.anyio async def test_resource_param_excluded_with_old_protocol_version(self, oauth_provider: OAuthClientProvider): """Test resource parameter is excluded for protocol version < 2025-06-18.""" - # Set protocol version to older version oauth_provider.context.protocol_version = "2025-03-26" oauth_provider.context.client_info = OAuthClientInformationFull( client_id="test_client", @@ -789,12 +683,10 @@ async def test_resource_param_excluded_with_old_protocol_version(self, oauth_pro redirect_uris=[AnyUrl("http://localhost:3030/callback")], ) - # Test in token exchange request = await oauth_provider._exchange_token_authorization_code("test_code", "test_verifier") content = request.content.decode() assert "resource=" not in content - # Test in refresh token oauth_provider.context.current_tokens = OAuthToken( access_token="test_access", token_type="Bearer", @@ -806,8 +698,7 @@ async def test_resource_param_excluded_with_old_protocol_version(self, oauth_pro @pytest.mark.anyio async def test_resource_param_included_with_protected_resource_metadata(self, oauth_provider: OAuthClientProvider): - """Test resource parameter is always included when protected resource metadata exists.""" - # Set old protocol version but with protected resource metadata + """PRM presence forces the resource param even on protocol versions < 2025-06-18.""" oauth_provider.context.protocol_version = "2025-03-26" oauth_provider.context.protected_resource_metadata = ProtectedResourceMetadata( resource=AnyHttpUrl("https://api.example.com/v1/mcp"), @@ -819,7 +710,6 @@ async def test_resource_param_included_with_protected_resource_metadata(self, oa redirect_uris=[AnyUrl("http://localhost:3030/callback")], ) - # Test in token exchange request = await oauth_provider._exchange_token_authorization_code("test_code", "test_verifier") content = request.content.decode() assert "resource=" in content @@ -847,7 +737,6 @@ def test_should_include_resource_param_by_protocol_version( async def test_validate_resource_rejects_mismatched_resource( client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage ) -> None: - """Client must reject PRM resource that doesn't match server URL.""" provider = OAuthClientProvider( server_url="https://api.example.com/v1/mcp", client_metadata=client_metadata, @@ -867,7 +756,6 @@ async def test_validate_resource_rejects_mismatched_resource( async def test_validate_resource_accepts_matching_resource( client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage ) -> None: - """Client must accept PRM resource that matches server URL.""" provider = OAuthClientProvider( server_url="https://api.example.com/v1/mcp", client_metadata=client_metadata, @@ -887,7 +775,6 @@ async def test_validate_resource_accepts_matching_resource( async def test_validate_resource_custom_callback( client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage ) -> None: - """Custom callback overrides default validation.""" callback_called_with: list[tuple[str, str | None]] = [] async def custom_validate(server_url: str, prm_resource: str | None) -> None: @@ -901,8 +788,7 @@ async def custom_validate(server_url: str, prm_resource: str | None) -> None: ) provider._initialized = True - # This would normally fail default validation (different origin), - # but custom callback accepts it + # Would fail default validation (different origin); the custom callback accepts it prm = ProtectedResourceMetadata( resource=AnyHttpUrl("https://evil.example.com/mcp"), authorization_servers=[AnyHttpUrl("https://auth.example.com")], @@ -915,7 +801,6 @@ async def custom_validate(server_url: str, prm_resource: str | None) -> None: async def test_validate_resource_accepts_root_url_with_trailing_slash( client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage ) -> None: - """Root URLs with trailing slash normalization should match.""" provider = OAuthClientProvider( server_url="https://api.example.com", client_metadata=client_metadata, @@ -935,7 +820,6 @@ async def test_validate_resource_accepts_root_url_with_trailing_slash( async def test_validate_resource_accepts_server_url_with_trailing_slash( client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage ) -> None: - """Server URL with trailing slash should match PRM resource.""" provider = OAuthClientProvider( server_url="https://api.example.com/v1/mcp/", client_metadata=client_metadata, @@ -955,7 +839,6 @@ async def test_validate_resource_accepts_server_url_with_trailing_slash( async def test_get_resource_url_uses_canonical_when_prm_mismatches( client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage ) -> None: - """get_resource_url falls back to canonical URL when PRM resource doesn't match.""" provider = OAuthClientProvider( server_url="https://api.example.com/v1/mcp", client_metadata=client_metadata, @@ -969,18 +852,12 @@ async def test_get_resource_url_uses_canonical_when_prm_mismatches( authorization_servers=[AnyHttpUrl("https://auth.example.com")], ) - # get_resource_url should return the canonical server URL, not the PRM resource assert provider.context.get_resource_url() == snapshot("https://api.example.com/v1/mcp") class TestRegistrationResponse: - """Test client registration response handling.""" - @pytest.mark.anyio async def test_handle_registration_response_reads_before_accessing_text(self): - """Test that response.aread() is called before accessing response.text.""" - - # Track if aread() was called class MockResponse(httpx.Response): def __init__(self): self.status_code = 400 @@ -999,21 +876,15 @@ def text(self): mock_response = MockResponse() - # This should call aread() before accessing text with pytest.raises(Exception) as exc_info: await handle_registration_response(mock_response) - # Verify aread() was called assert mock_response._aread_called - # Verify the error message includes the response text assert "Registration failed: 400" in str(exc_info.value) class TestCreateClientRegistrationRequest: - """Test client registration request creation.""" - def test_uses_registration_endpoint_from_metadata(self): - """Test that registration URL comes from metadata when available.""" oauth_metadata = OAuthMetadata( issuer=AnyHttpUrl("https://auth.example.com"), authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize"), @@ -1028,7 +899,6 @@ def test_uses_registration_endpoint_from_metadata(self): assert request.method == "POST" def test_falls_back_to_default_register_endpoint_when_no_metadata(self): - """Test that registration uses fallback URL when auth_server_metadata is None.""" client_metadata = OAuthClientMetadata(redirect_uris=[AnyHttpUrl("http://localhost:3000/callback")]) request = create_client_registration_request(None, client_metadata, "https://auth.example.com") @@ -1037,7 +907,6 @@ def test_falls_back_to_default_register_endpoint_when_no_metadata(self): assert request.method == "POST" def test_falls_back_when_metadata_has_no_registration_endpoint(self): - """Test fallback when metadata exists but lacks registration_endpoint.""" oauth_metadata = OAuthMetadata( issuer=AnyHttpUrl("https://auth.example.com"), authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize"), @@ -1068,55 +937,41 @@ def test_registration_request_sends_application_type(): class TestAuthFlow: - """Test the auth flow in httpx.""" - @pytest.mark.anyio async def test_auth_flow_with_valid_tokens( self, oauth_provider: OAuthClientProvider, mock_storage: MockTokenStorage, valid_tokens: OAuthToken ): - """Test auth flow when tokens are already valid.""" - # Pre-store valid tokens await mock_storage.set_tokens(valid_tokens) oauth_provider.context.current_tokens = valid_tokens oauth_provider.context.token_expiry_time = time.time() + 1800 oauth_provider._initialized = True - # Create a test request test_request = httpx.Request("GET", "https://api.example.com/test") - # Mock the auth flow auth_flow = oauth_provider.async_auth_flow(test_request) - # Should get the request with auth header added request = await auth_flow.__anext__() assert request.headers["Authorization"] == "Bearer test_access_token" - # Send a successful response response = httpx.Response(200) try: await auth_flow.asend(response) except StopAsyncIteration: - pass # Expected + pass @pytest.mark.anyio async def test_auth_flow_with_no_tokens(self, oauth_provider: OAuthClientProvider, mock_storage: MockTokenStorage): - """Test auth flow when no tokens are available, triggering the full OAuth flow.""" - # Ensure no tokens are stored oauth_provider.context.current_tokens = None oauth_provider.context.token_expiry_time = None oauth_provider._initialized = True - # Create a test request test_request = httpx.Request("GET", "https://api.example.com/mcp") - # Mock the auth flow auth_flow = oauth_provider.async_auth_flow(test_request) - # First request should be the original request without auth header request = await auth_flow.__anext__() assert "Authorization" not in request.headers - # Send a 401 response to trigger the OAuth flow response = httpx.Response( 401, headers={ @@ -1125,25 +980,21 @@ async def test_auth_flow_with_no_tokens(self, oauth_provider: OAuthClientProvide request=test_request, ) - # Next request should be to discover protected resource metadata discovery_request = await auth_flow.asend(response) assert discovery_request.method == "GET" assert str(discovery_request.url) == "https://api.example.com/.well-known/oauth-protected-resource" - # Send a successful discovery response with minimal protected resource metadata discovery_response = httpx.Response( 200, content=b'{"resource": "https://api.example.com/v1/mcp", "authorization_servers": ["https://auth.example.com"]}', request=discovery_request, ) - # Next request should be to discover OAuth metadata oauth_metadata_request = await auth_flow.asend(discovery_response) assert oauth_metadata_request.method == "GET" assert str(oauth_metadata_request.url).startswith("https://auth.example.com/") assert "mcp-protocol-version" in oauth_metadata_request.headers - # Send a successful OAuth metadata response oauth_metadata_response = httpx.Response( 200, content=( @@ -1155,30 +1006,25 @@ async def test_auth_flow_with_no_tokens(self, oauth_provider: OAuthClientProvide request=oauth_metadata_request, ) - # Next request should be to register client registration_request = await auth_flow.asend(oauth_metadata_response) assert registration_request.method == "POST" assert str(registration_request.url) == "https://auth.example.com/register" - # Send a successful registration response registration_response = httpx.Response( 201, content=b'{"client_id": "test_client_id", "client_secret": "test_client_secret", "redirect_uris": ["http://localhost:3030/callback"]}', request=registration_request, ) - # Mock the authorization process oauth_provider._perform_authorization_code_grant = mock.AsyncMock( return_value=("test_auth_code", "test_code_verifier") ) - # Next request should be to exchange token token_request = await auth_flow.asend(registration_response) assert token_request.method == "POST" assert str(token_request.url) == "https://auth.example.com/token" assert "code=test_auth_code" in token_request.content.decode() - # Send a successful token response token_response = httpx.Response( 200, content=( @@ -1188,20 +1034,17 @@ async def test_auth_flow_with_no_tokens(self, oauth_provider: OAuthClientProvide request=token_request, ) - # Final request should be the original request with auth header final_request = await auth_flow.asend(token_response) assert final_request.headers["Authorization"] == "Bearer new_access_token" assert final_request.method == "GET" assert str(final_request.url) == "https://api.example.com/mcp" - # Send final success response to properly close the generator final_response = httpx.Response(200, request=final_request) try: await auth_flow.asend(final_response) except StopAsyncIteration: - pass # Expected - generator should complete + pass - # Verify tokens were stored assert oauth_provider.context.current_tokens is not None assert oauth_provider.context.current_tokens.access_token == "new_access_token" assert oauth_provider.context.token_expiry_time is not None @@ -1210,8 +1053,7 @@ async def test_auth_flow_with_no_tokens(self, oauth_provider: OAuthClientProvide async def test_auth_flow_no_unnecessary_retry_after_oauth( self, oauth_provider: OAuthClientProvider, mock_storage: MockTokenStorage, valid_tokens: OAuthToken ): - """Test that requests are not retried unnecessarily - the core bug that caused 2x performance degradation.""" - # Pre-store valid tokens so no OAuth flow is needed + """Successful responses end the flow without a retry (regression: 2x performance degradation).""" await mock_storage.set_tokens(valid_tokens) oauth_provider.context.current_tokens = valid_tokens oauth_provider.context.token_expiry_time = time.time() + 1800 @@ -1220,56 +1062,43 @@ async def test_auth_flow_no_unnecessary_retry_after_oauth( test_request = httpx.Request("GET", "https://api.example.com/mcp") auth_flow = oauth_provider.async_auth_flow(test_request) - # Count how many times the request is yielded request_yields = 0 - # First request - should have auth header already request = await auth_flow.__anext__() request_yields += 1 assert request.headers["Authorization"] == "Bearer test_access_token" - # Send a successful 200 response response = httpx.Response(200, request=request) - # In the buggy version, this would yield the request AGAIN unconditionally - # In the fixed version, this should end the generator + # The buggy version yielded the request again here instead of ending the generator try: - await auth_flow.asend(response) # extra request + await auth_flow.asend(response) request_yields += 1 # pragma: no cover - # If we reach here, the bug is present pytest.fail( f"Unnecessary retry detected! Request was yielded {request_yields} times. " f"This indicates the retry logic bug that caused 2x performance degradation. " f"The request should only be yielded once for successful responses." ) # pragma: no cover except StopAsyncIteration: - # This is the expected behavior - no unnecessary retry pass - # Verify exactly one request was yielded (no double-sending) assert request_yields == 1, f"Expected 1 request yield, got {request_yields}" @pytest.mark.anyio async def test_token_exchange_accepts_201_status( self, oauth_provider: OAuthClientProvider, mock_storage: MockTokenStorage ): - """Test that token exchange accepts both 200 and 201 status codes.""" - # Ensure no tokens are stored oauth_provider.context.current_tokens = None oauth_provider.context.token_expiry_time = None oauth_provider._initialized = True - # Create a test request test_request = httpx.Request("GET", "https://api.example.com/mcp") - # Mock the auth flow auth_flow = oauth_provider.async_auth_flow(test_request) - # First request should be the original request without auth header request = await auth_flow.__anext__() assert "Authorization" not in request.headers - # Send a 401 response to trigger the OAuth flow response = httpx.Response( 401, headers={ @@ -1278,25 +1107,21 @@ async def test_token_exchange_accepts_201_status( request=test_request, ) - # Next request should be to discover protected resource metadata discovery_request = await auth_flow.asend(response) assert discovery_request.method == "GET" assert str(discovery_request.url) == "https://api.example.com/.well-known/oauth-protected-resource" - # Send a successful discovery response with minimal protected resource metadata discovery_response = httpx.Response( 200, content=b'{"resource": "https://api.example.com/v1/mcp", "authorization_servers": ["https://auth.example.com"]}', request=discovery_request, ) - # Next request should be to discover OAuth metadata oauth_metadata_request = await auth_flow.asend(discovery_response) assert oauth_metadata_request.method == "GET" assert str(oauth_metadata_request.url).startswith("https://auth.example.com/") assert "mcp-protocol-version" in oauth_metadata_request.headers - # Send a successful OAuth metadata response oauth_metadata_response = httpx.Response( 200, content=( @@ -1308,30 +1133,25 @@ async def test_token_exchange_accepts_201_status( request=oauth_metadata_request, ) - # Next request should be to register client registration_request = await auth_flow.asend(oauth_metadata_response) assert registration_request.method == "POST" assert str(registration_request.url) == "https://auth.example.com/register" - # Send a successful registration response with 201 status registration_response = httpx.Response( 201, content=b'{"client_id": "test_client_id", "client_secret": "test_client_secret", "redirect_uris": ["http://localhost:3030/callback"]}', request=registration_request, ) - # Mock the authorization process oauth_provider._perform_authorization_code_grant = mock.AsyncMock( return_value=("test_auth_code", "test_code_verifier") ) - # Next request should be to exchange token token_request = await auth_flow.asend(registration_response) assert token_request.method == "POST" assert str(token_request.url) == "https://auth.example.com/token" assert "code=test_auth_code" in token_request.content.decode() - # Send a successful token response with 201 status code (test both 200 and 201 are accepted) token_response = httpx.Response( 201, content=( @@ -1341,20 +1161,17 @@ async def test_token_exchange_accepts_201_status( request=token_request, ) - # Final request should be the original request with auth header final_request = await auth_flow.asend(token_response) assert final_request.headers["Authorization"] == "Bearer new_access_token" assert final_request.method == "GET" assert str(final_request.url) == "https://api.example.com/mcp" - # Send final success response to properly close the generator final_response = httpx.Response(200, request=final_request) try: await auth_flow.asend(final_response) except StopAsyncIteration: - pass # Expected - generator should complete + pass - # Verify tokens were stored assert oauth_provider.context.current_tokens is not None assert oauth_provider.context.current_tokens.access_token == "new_access_token" assert oauth_provider.context.token_expiry_time is not None @@ -1366,8 +1183,6 @@ async def test_403_insufficient_scope_updates_scope_from_header( mock_storage: MockTokenStorage, valid_tokens: OAuthToken, ): - """Test that 403 response correctly updates scope from WWW-Authenticate header.""" - # Pre-store valid tokens and client info client_info = OAuthClientInformationFull( client_id="test_client_id", client_secret="test_client_secret", @@ -1380,7 +1195,6 @@ async def test_403_insufficient_scope_updates_scope_from_header( oauth_provider.context.client_info = client_info oauth_provider._initialized = True - # Original scope assert oauth_provider.context.client_metadata.scope == "read write" redirect_captured = False @@ -1392,14 +1206,12 @@ async def capture_redirect(url: str) -> None: # SEP-2350: the authorization URL carries the union of the prior and challenged scopes scope = parse_qs(urlparse(url).query)["scope"][0] assert scope == "read write admin:write admin:delete" - # Extract state from redirect URL parsed = urlparse(url) params = parse_qs(parsed.query) captured_state = params.get("state", [None])[0] oauth_provider.context.redirect_handler = capture_redirect - # Mock callback async def mock_callback() -> AuthorizationCodeResult: return AuthorizationCodeResult(code="auth_code", state=captured_state) @@ -1408,24 +1220,19 @@ async def mock_callback() -> AuthorizationCodeResult: test_request = httpx.Request("GET", "https://api.example.com/mcp") auth_flow = oauth_provider.async_auth_flow(test_request) - # First request request = await auth_flow.__anext__() - # Send 403 with new scope requirement response_403 = httpx.Response( 403, headers={"WWW-Authenticate": 'Bearer error="insufficient_scope", scope="admin:write admin:delete"'}, request=request, ) - # Trigger step-up - should get token exchange request token_exchange_request = await auth_flow.asend(response_403) - # Verify scope was updated to the union of prior and challenged scopes (SEP-2350) assert oauth_provider.context.client_metadata.scope == "read write admin:write admin:delete" assert redirect_captured - # Complete the flow with successful token response token_response = httpx.Response( 200, json={ @@ -1437,16 +1244,14 @@ async def mock_callback() -> AuthorizationCodeResult: request=token_exchange_request, ) - # Should get final retry request final_request = await auth_flow.asend(token_response) - # Send success response - flow should complete success_response = httpx.Response(200, request=final_request) try: await auth_flow.asend(success_response) pytest.fail("Should have stopped after successful response") # pragma: no cover except StopAsyncIteration: - pass # Expected + pass @pytest.mark.anyio @@ -1519,8 +1324,7 @@ async def mock_callback() -> AuthorizationCodeResult: "revocation_endpoint", ), ( - # Pydantic's AnyUrl incorrectly adds trailing slash to base URLs - # This is being fixed in https://github.com/pydantic/pydantic-core/pull/1719 (Pydantic 2.12+) + # Pydantic AnyUrl adds a trailing slash to base URLs; fix: https://github.com/pydantic/pydantic-core/pull/1719 pytest.param( "https://auth.example.com", "https://auth.example.com/docs", @@ -1592,7 +1396,7 @@ class TestLegacyServerFallback: async def test_legacy_server_no_prm_falls_back_to_root_oauth_discovery( self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage ): - """Test that when PRM discovery fails completely, we fall back to root OAuth discovery (March 2025 spec).""" + """When all PRM URLs fail, fall back to root OAuth discovery (March 2025 spec).""" async def redirect_handler(url: str) -> None: pass # pragma: no cover @@ -1622,33 +1426,25 @@ async def callback_handler() -> AuthorizationCodeResult: test_request = httpx.Request("GET", "https://mcp.linear.app/sse") auth_flow = provider.async_auth_flow(test_request) - # First request request = await auth_flow.__anext__() assert "Authorization" not in request.headers - # Send 401 without WWW-Authenticate header (typical legacy server) response = httpx.Response(401, headers={}, request=test_request) - # Should try path-based PRM first prm_request_1 = await auth_flow.asend(response) assert str(prm_request_1.url) == "https://mcp.linear.app/.well-known/oauth-protected-resource/sse" - # PRM returns 404 prm_response_1 = httpx.Response(404, request=prm_request_1) - # Should try root-based PRM prm_request_2 = await auth_flow.asend(prm_response_1) assert str(prm_request_2.url) == "https://mcp.linear.app/.well-known/oauth-protected-resource" - # PRM returns 404 again - all PRM URLs failed prm_response_2 = httpx.Response(404, request=prm_request_2) - # Should fall back to root OAuth discovery (March 2025 spec behavior) oauth_metadata_request = await auth_flow.asend(prm_response_2) assert str(oauth_metadata_request.url) == "https://mcp.linear.app/.well-known/oauth-authorization-server" assert oauth_metadata_request.method == "GET" - # Send successful OAuth metadata response oauth_metadata_response = httpx.Response( 200, content=( @@ -1659,28 +1455,23 @@ async def callback_handler() -> AuthorizationCodeResult: request=oauth_metadata_request, ) - # Mock authorization provider._perform_authorization_code_grant = mock.AsyncMock( return_value=("test_auth_code", "test_code_verifier") ) - # Next should be token exchange token_request = await auth_flow.asend(oauth_metadata_response) assert str(token_request.url) == "https://mcp.linear.app/token" - # Send successful token response token_response = httpx.Response( 200, content=b'{"access_token": "linear_token", "token_type": "Bearer", "expires_in": 3600}', request=token_request, ) - # Final request with auth header final_request = await auth_flow.asend(token_response) assert final_request.headers["Authorization"] == "Bearer linear_token" assert str(final_request.url) == "https://mcp.linear.app/sse" - # Complete flow final_response = httpx.Response(200, request=final_request) try: await auth_flow.asend(final_response) @@ -1691,8 +1482,6 @@ async def callback_handler() -> AuthorizationCodeResult: async def test_legacy_server_with_different_prm_and_root_urls( self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage ): - """Test PRM fallback with different WWW-Authenticate and root URLs.""" - async def redirect_handler(url: str) -> None: pass # pragma: no cover @@ -1721,7 +1510,6 @@ async def callback_handler() -> AuthorizationCodeResult: await auth_flow.__anext__() - # 401 with custom WWW-Authenticate PRM URL response = httpx.Response( 401, headers={ @@ -1730,32 +1518,24 @@ async def callback_handler() -> AuthorizationCodeResult: request=test_request, ) - # Try custom PRM URL first prm_request_1 = await auth_flow.asend(response) assert str(prm_request_1.url) == "https://custom.prm.com/.well-known/oauth-protected-resource" - # Returns 500 prm_response_1 = httpx.Response(500, request=prm_request_1) - # Try path-based fallback prm_request_2 = await auth_flow.asend(prm_response_1) assert str(prm_request_2.url) == "https://api.example.com/.well-known/oauth-protected-resource/v1/mcp" - # Returns 404 prm_response_2 = httpx.Response(404, request=prm_request_2) - # Try root fallback prm_request_3 = await auth_flow.asend(prm_response_2) assert str(prm_request_3.url) == "https://api.example.com/.well-known/oauth-protected-resource" - # Also returns 404 - all PRM URLs failed prm_response_3 = httpx.Response(404, request=prm_request_3) - # Should fall back to root OAuth discovery oauth_metadata_request = await auth_flow.asend(prm_response_3) assert str(oauth_metadata_request.url) == "https://api.example.com/.well-known/oauth-authorization-server" - # Complete the flow oauth_metadata_response = httpx.Response( 200, content=( @@ -1796,8 +1576,6 @@ class TestSEP985Discovery: async def test_path_based_fallback_when_no_www_authenticate( self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage ): - """Test that client falls back to path-based well-known URI when WWW-Authenticate is absent.""" - async def redirect_handler(url: str) -> None: pass # pragma: no cover @@ -1812,17 +1590,14 @@ async def callback_handler() -> AuthorizationCodeResult: callback_handler=callback_handler, ) - # Test with 401 response without WWW-Authenticate header init_response = httpx.Response( status_code=401, headers={}, request=httpx.Request("GET", "https://api.example.com/v1/mcp") ) - # Build discovery URLs discovery_urls = build_protected_resource_metadata_discovery_urls( extract_resource_metadata_from_www_auth(init_response), provider.context.server_url ) - # Should have path-based URL first, then root-based URL assert len(discovery_urls) == 2 assert discovery_urls[0] == "https://api.example.com/.well-known/oauth-protected-resource/v1/mcp" assert discovery_urls[1] == "https://api.example.com/.well-known/oauth-protected-resource" @@ -1831,8 +1606,6 @@ async def callback_handler() -> AuthorizationCodeResult: async def test_root_based_fallback_after_path_based_404( self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage ): - """Test that client falls back to root-based URI when path-based returns 404.""" - async def redirect_handler(url: str) -> None: pass # pragma: no cover @@ -1847,7 +1620,6 @@ async def callback_handler() -> AuthorizationCodeResult: callback_handler=callback_handler, ) - # Ensure no tokens are stored provider.context.current_tokens = None provider.context.token_expiry_time = None provider._initialized = True @@ -1858,33 +1630,25 @@ async def callback_handler() -> AuthorizationCodeResult: redirect_uris=[AnyUrl("http://localhost:3030/callback")], ) - # Create a test request test_request = httpx.Request("GET", "https://api.example.com/v1/mcp") - # Mock the auth flow auth_flow = provider.async_auth_flow(test_request) - # First request should be the original request without auth header request = await auth_flow.__anext__() assert "Authorization" not in request.headers - # Send a 401 response without WWW-Authenticate header response = httpx.Response(401, headers={}, request=test_request) - # Next request should be to discover protected resource metadata (path-based) discovery_request_1 = await auth_flow.asend(response) assert str(discovery_request_1.url) == "https://api.example.com/.well-known/oauth-protected-resource/v1/mcp" assert discovery_request_1.method == "GET" - # Send 404 response for path-based discovery discovery_response_1 = httpx.Response(404, request=discovery_request_1) - # Next request should be to root-based well-known URI discovery_request_2 = await auth_flow.asend(discovery_response_1) assert str(discovery_request_2.url) == "https://api.example.com/.well-known/oauth-protected-resource" assert discovery_request_2.method == "GET" - # Send successful discovery response discovery_response_2 = httpx.Response( 200, content=( @@ -1893,14 +1657,11 @@ async def callback_handler() -> AuthorizationCodeResult: request=discovery_request_2, ) - # Mock the rest of the OAuth flow provider._perform_authorization = mock.AsyncMock(return_value=("test_auth_code", "test_code_verifier")) - # Next should be OAuth metadata discovery oauth_metadata_request = await auth_flow.asend(discovery_response_2) assert oauth_metadata_request.method == "GET" - # Complete the flow oauth_metadata_response = httpx.Response( 200, content=( @@ -1932,8 +1693,6 @@ async def callback_handler() -> AuthorizationCodeResult: async def test_www_authenticate_takes_priority_over_well_known( self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage ): - """Test that WWW-Authenticate header resource_metadata takes priority over well-known URIs.""" - async def redirect_handler(url: str) -> None: pass # pragma: no cover @@ -1948,7 +1707,6 @@ async def callback_handler() -> AuthorizationCodeResult: callback_handler=callback_handler, ) - # Test with 401 response with WWW-Authenticate header init_response = httpx.Response( status_code=401, headers={ @@ -1957,12 +1715,10 @@ async def callback_handler() -> AuthorizationCodeResult: request=httpx.Request("GET", "https://api.example.com/v1/mcp"), ) - # Build discovery URLs discovery_urls = build_protected_resource_metadata_discovery_urls( extract_resource_metadata_from_www_auth(init_response), provider.context.server_url ) - # Should have WWW-Authenticate URL first, then fallback URLs assert len(discovery_urls) == 3 assert discovery_urls[0] == "https://custom.example.com/.well-known/oauth-protected-resource" assert discovery_urls[1] == "https://api.example.com/.well-known/oauth-protected-resource/v1/mcp" @@ -1970,8 +1726,6 @@ async def callback_handler() -> AuthorizationCodeResult: class TestWWWAuthenticate: - """Test WWW-Authenticate header parsing functionality.""" - @pytest.mark.parametrize( "www_auth_header,field_name,expected_value", [ @@ -2026,8 +1780,6 @@ def test_extract_field_from_www_auth_valid_cases( field_name: str, expected_value: str, ): - """Test extraction of various fields from valid WWW-Authenticate headers.""" - init_response = httpx.Response( status_code=401, headers={"WWW-Authenticate": www_auth_header}, @@ -2040,14 +1792,10 @@ def test_extract_field_from_www_auth_valid_cases( @pytest.mark.parametrize( "www_auth_header,field_name,description", [ - # No header (None, "scope", "no WWW-Authenticate header"), - # Empty header ("", "scope", "empty WWW-Authenticate header"), - # Header without requested field ('Bearer realm="api", error="insufficient_scope"', "scope", "no scope parameter"), ('Bearer realm="api", scope="read write"', "resource_metadata", "no resource_metadata parameter"), - # Malformed field (empty value) ("Bearer scope=", "scope", "malformed scope parameter"), ("Bearer resource_metadata=", "resource_metadata", "malformed resource_metadata parameter"), ], @@ -2060,8 +1808,6 @@ def test_extract_field_from_www_auth_invalid_cases( field_name: str, description: str, ): - """Test extraction returns None for invalid cases.""" - headers = {"WWW-Authenticate": www_auth_header} if www_auth_header is not None else {} init_response = httpx.Response( status_code=401, headers=headers, request=httpx.Request("GET", "https://api.example.com/test") @@ -2095,11 +1841,9 @@ class TestCIMD: ], ) def test_is_valid_client_metadata_url(self, url: str | None, expected: bool): - """Test CIMD URL validation.""" assert is_valid_client_metadata_url(url) == expected def test_should_use_client_metadata_url_when_server_supports(self): - """Test that CIMD is used when server supports it and URL is provided.""" oauth_metadata = OAuthMetadata( issuer=AnyHttpUrl("https://auth.example.com"), authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize"), @@ -2109,7 +1853,6 @@ def test_should_use_client_metadata_url_when_server_supports(self): assert should_use_client_metadata_url(oauth_metadata, "https://example.com/client") is True def test_should_not_use_client_metadata_url_when_server_does_not_support(self): - """Test that CIMD is not used when server doesn't support it.""" oauth_metadata = OAuthMetadata( issuer=AnyHttpUrl("https://auth.example.com"), authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize"), @@ -2119,7 +1862,6 @@ def test_should_not_use_client_metadata_url_when_server_does_not_support(self): assert should_use_client_metadata_url(oauth_metadata, "https://example.com/client") is False def test_should_not_use_client_metadata_url_when_not_provided(self): - """Test that CIMD is not used when no URL is provided.""" oauth_metadata = OAuthMetadata( issuer=AnyHttpUrl("https://auth.example.com"), authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize"), @@ -2129,11 +1871,9 @@ def test_should_not_use_client_metadata_url_when_not_provided(self): assert should_use_client_metadata_url(oauth_metadata, None) is False def test_should_not_use_client_metadata_url_when_no_metadata(self): - """Test that CIMD is not used when OAuth metadata is None.""" assert should_use_client_metadata_url(None, "https://example.com/client") is False def test_create_client_info_from_metadata_url(self): - """Test creating client info from CIMD URL.""" client_info = create_client_info_from_metadata_url( "https://example.com/client", redirect_uris=[AnyUrl("http://localhost:3030/callback")], @@ -2146,8 +1886,6 @@ def test_create_client_info_from_metadata_url(self): def test_oauth_provider_with_valid_client_metadata_url( self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage ): - """Test OAuthClientProvider initialization with valid client_metadata_url.""" - async def redirect_handler(url: str) -> None: pass # pragma: no cover @@ -2167,8 +1905,6 @@ async def callback_handler() -> AuthorizationCodeResult: def test_oauth_provider_with_invalid_client_metadata_url_raises_error( self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage ): - """Test OAuthClientProvider raises error for invalid client_metadata_url.""" - async def redirect_handler(url: str) -> None: pass # pragma: no cover @@ -2190,8 +1926,6 @@ async def callback_handler() -> AuthorizationCodeResult: async def test_auth_flow_uses_cimd_when_server_supports( self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage ): - """Test that auth flow uses CIMD URL as client_id when server supports it.""" - async def redirect_handler(url: str) -> None: pass # pragma: no cover @@ -2214,14 +1948,11 @@ async def callback_handler() -> AuthorizationCodeResult: test_request = httpx.Request("GET", "https://api.example.com/v1/mcp") auth_flow = provider.async_auth_flow(test_request) - # First request request = await auth_flow.__anext__() assert "Authorization" not in request.headers - # Send 401 response response = httpx.Response(401, headers={}, request=test_request) - # PRM discovery prm_request = await auth_flow.asend(response) prm_response = httpx.Response( 200, @@ -2229,7 +1960,6 @@ async def callback_handler() -> AuthorizationCodeResult: request=prm_request, ) - # OAuth metadata discovery oauth_request = await auth_flow.asend(prm_response) oauth_response = httpx.Response( 200, @@ -2242,7 +1972,6 @@ async def callback_handler() -> AuthorizationCodeResult: request=oauth_request, ) - # Mock authorization provider._perform_authorization_code_grant = mock.AsyncMock( return_value=("test_auth_code", "test_code_verifier") ) @@ -2252,16 +1981,13 @@ async def callback_handler() -> AuthorizationCodeResult: assert token_request.method == "POST" assert str(token_request.url) == "https://auth.example.com/token" - # Verify client_id is the CIMD URL content = token_request.content.decode() assert "client_id=https%3A%2F%2Fexample.com%2Fclient" in content - # Verify client info was set correctly assert provider.context.client_info is not None assert provider.context.client_info.client_id == "https://example.com/client" assert provider.context.client_info.token_endpoint_auth_method == "none" - # Complete the flow token_response = httpx.Response( 200, content=b'{"access_token": "test_token", "token_type": "Bearer", "expires_in": 3600}', @@ -2281,8 +2007,6 @@ async def callback_handler() -> AuthorizationCodeResult: async def test_auth_flow_falls_back_to_dcr_when_no_cimd_support( self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage ): - """Test that auth flow falls back to DCR when server doesn't support CIMD.""" - async def redirect_handler(url: str) -> None: pass # pragma: no cover @@ -2305,13 +2029,10 @@ async def callback_handler() -> AuthorizationCodeResult: test_request = httpx.Request("GET", "https://api.example.com/v1/mcp") auth_flow = provider.async_auth_flow(test_request) - # First request await auth_flow.__anext__() - # Send 401 response response = httpx.Response(401, headers={}, request=test_request) - # PRM discovery prm_request = await auth_flow.asend(response) prm_response = httpx.Response( 200, @@ -2332,7 +2053,6 @@ async def callback_handler() -> AuthorizationCodeResult: request=oauth_request, ) - # Should proceed to DCR instead of skipping it registration_request = await auth_flow.asend(oauth_response) assert registration_request.method == "POST" assert str(registration_request.url) == "https://auth.example.com/register" @@ -2344,7 +2064,6 @@ async def callback_handler() -> AuthorizationCodeResult: request=registration_request, ) - # Mock authorization provider._perform_authorization_code_grant = mock.AsyncMock( return_value=("test_auth_code", "test_code_verifier") ) @@ -2383,7 +2102,6 @@ def _make_prm(self, scopes_supported: list[str] | None = None) -> ProtectedResou ) def test_offline_access_added_when_as_supports_and_client_has_refresh_token(self): - """offline_access is appended when AS advertises it and client supports refresh_token grant.""" prm = self._make_prm(scopes_supported=["read", "write"]) asm = self._make_as_metadata(scopes_supported=["read", "write", "offline_access"]) @@ -2396,7 +2114,6 @@ def test_offline_access_added_when_as_supports_and_client_has_refresh_token(self assert scopes == "read write offline_access" def test_offline_access_added_with_www_authenticate_scope(self): - """offline_access is appended even when scopes come from WWW-Authenticate header.""" asm = self._make_as_metadata(scopes_supported=["read", "write", "offline_access"]) scopes = get_client_metadata_scopes( @@ -2408,7 +2125,6 @@ def test_offline_access_added_with_www_authenticate_scope(self): assert scopes == "read write offline_access" def test_offline_access_not_added_when_as_does_not_support(self): - """offline_access is not added when AS does not advertise it in scopes_supported.""" prm = self._make_prm(scopes_supported=["read", "write"]) asm = self._make_as_metadata(scopes_supported=["read", "write"]) @@ -2421,7 +2137,6 @@ def test_offline_access_not_added_when_as_does_not_support(self): assert scopes == "read write" def test_offline_access_not_added_when_client_has_no_refresh_token_grant(self): - """offline_access is not added when client does not support refresh_token grant.""" prm = self._make_prm(scopes_supported=["read", "write"]) asm = self._make_as_metadata(scopes_supported=["read", "write", "offline_access"]) @@ -2434,7 +2149,6 @@ def test_offline_access_not_added_when_client_has_no_refresh_token_grant(self): assert scopes == "read write" def test_offline_access_not_duplicated_when_already_present(self): - """offline_access is not added again if it already appears in the selected scopes.""" prm = self._make_prm(scopes_supported=["read", "offline_access", "write"]) asm = self._make_as_metadata(scopes_supported=["read", "write", "offline_access"]) @@ -2447,7 +2161,6 @@ def test_offline_access_not_duplicated_when_already_present(self): assert scopes == "read offline_access write" def test_offline_access_not_added_when_no_scopes_selected(self): - """offline_access is not added when no base scopes are available (None).""" asm = self._make_as_metadata(scopes_supported=["offline_access"]) scopes = get_client_metadata_scopes( @@ -2456,12 +2169,10 @@ def test_offline_access_not_added_when_no_scopes_selected(self): authorization_server_metadata=asm, client_grant_types=["authorization_code", "refresh_token"], ) - # When AS scopes are the only source and include offline_access, - # the base scope is "offline_access" and no duplication happens + # AS scopes are the only base source here, so offline_access is already present β€” no duplication assert scopes == "offline_access" def test_offline_access_not_added_when_as_scopes_supported_is_none(self): - """offline_access is not added when AS scopes_supported is None.""" prm = self._make_prm(scopes_supported=["read", "write"]) asm = self._make_as_metadata(scopes_supported=None) @@ -2474,7 +2185,6 @@ def test_offline_access_not_added_when_as_scopes_supported_is_none(self): assert scopes == "read write" def test_offline_access_not_added_when_no_as_metadata(self): - """offline_access is not added when AS metadata is not available.""" prm = self._make_prm(scopes_supported=["read", "write"]) scopes = get_client_metadata_scopes( @@ -2486,7 +2196,6 @@ def test_offline_access_not_added_when_no_as_metadata(self): assert scopes == "read write" def test_offline_access_not_added_when_no_grant_types_provided(self): - """offline_access is not added when client_grant_types is None.""" prm = self._make_prm(scopes_supported=["read", "write"]) asm = self._make_as_metadata(scopes_supported=["read", "write", "offline_access"]) @@ -2499,7 +2208,6 @@ def test_offline_access_not_added_when_no_grant_types_provided(self): assert scopes == "read write" def test_default_client_metadata_includes_refresh_token_grant(self): - """Default OAuthClientMetadata includes refresh_token in grant_types (SEP-2207 guidance).""" metadata = OAuthClientMetadata(redirect_uris=[AnyUrl("http://localhost:3030/callback")]) assert "refresh_token" in metadata.grant_types @@ -2507,8 +2215,6 @@ def test_default_client_metadata_includes_refresh_token_grant(self): async def test_auth_flow_adds_offline_access_when_as_advertises( self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage ): - """E2E: auth flow includes offline_access in authorization request when AS advertises it.""" - captured_auth_url: str | None = None captured_state: str | None = None @@ -2544,14 +2250,11 @@ async def callback_handler() -> AuthorizationCodeResult: test_request = httpx.Request("GET", "https://api.example.com/v1/mcp") auth_flow = provider.async_auth_flow(test_request) - # First request request = await auth_flow.__anext__() assert "Authorization" not in request.headers - # Send 401 response = httpx.Response(401, headers={}, request=test_request) - # PRM discovery prm_request = await auth_flow.asend(response) prm_response = httpx.Response( 200, @@ -2563,7 +2266,6 @@ async def callback_handler() -> AuthorizationCodeResult: request=prm_request, ) - # OAuth metadata discovery - AS advertises offline_access oauth_request = await auth_flow.asend(prm_response) oauth_response = httpx.Response( 200, @@ -2579,7 +2281,6 @@ async def callback_handler() -> AuthorizationCodeResult: # This triggers authorization, which calls redirect_handler token_request = await auth_flow.asend(oauth_response) - # Verify the authorization URL included offline_access in the scope assert captured_auth_url is not None parsed = urlparse(captured_auth_url) params = parse_qs(parsed.query) @@ -2592,7 +2293,6 @@ async def callback_handler() -> AuthorizationCodeResult: # OIDC requires prompt=consent when offline_access is requested assert params["prompt"][0] == "consent" - # Complete the token exchange token_response = httpx.Response( 200, content=( @@ -2605,7 +2305,6 @@ async def callback_handler() -> AuthorizationCodeResult: final_request = await auth_flow.asend(token_response) assert final_request.headers["Authorization"] == "Bearer new_access_token" - # Close the generator final_response = httpx.Response(200, request=final_request) try: await auth_flow.asend(final_response) @@ -2616,8 +2315,6 @@ async def callback_handler() -> AuthorizationCodeResult: async def test_auth_flow_no_offline_access_when_as_does_not_advertise( self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage ): - """E2E: auth flow does NOT include offline_access when AS doesn't advertise it.""" - captured_auth_url: str | None = None captured_state: str | None = None @@ -2653,13 +2350,10 @@ async def callback_handler() -> AuthorizationCodeResult: test_request = httpx.Request("GET", "https://api.example.com/v1/mcp") auth_flow = provider.async_auth_flow(test_request) - # First request await auth_flow.__anext__() - # Send 401 response = httpx.Response(401, headers={}, request=test_request) - # PRM discovery prm_request = await auth_flow.asend(response) prm_response = httpx.Response( 200, @@ -2687,7 +2381,6 @@ async def callback_handler() -> AuthorizationCodeResult: # This triggers authorization, which calls redirect_handler token_request = await auth_flow.asend(oauth_response) - # Verify the authorization URL does NOT include offline_access assert captured_auth_url is not None parsed = urlparse(captured_auth_url) params = parse_qs(parsed.query) @@ -2700,7 +2393,6 @@ async def callback_handler() -> AuthorizationCodeResult: # prompt=consent should NOT be present without offline_access assert "prompt" not in params - # Complete the token exchange token_response = httpx.Response( 200, content=b'{"access_token": "new_access_token", "token_type": "Bearer", "expires_in": 3600}', @@ -2710,7 +2402,6 @@ async def callback_handler() -> AuthorizationCodeResult: final_request = await auth_flow.asend(token_response) assert final_request.headers["Authorization"] == "Bearer new_access_token" - # Close the generator final_response = httpx.Response(200, request=final_request) try: await auth_flow.asend(final_response) @@ -2842,7 +2533,7 @@ async def test_handle_token_response_backfills_omitted_scope_from_request( """RFC 6749 Β§5.1: an omitted token-response scope means granted == requested. The token is stored with the requested scope filled in so it remains self-describing - after a restart, when the SEP-2350 step-up union reads it but ``client_metadata.scope`` + after a restart, when the SEP-2350 step-up union reads it but `client_metadata.scope` has reverted to its constructor value. """ oauth_provider.context.client_metadata.scope = "read admin" @@ -2906,7 +2597,7 @@ async def test_handle_refresh_response_carries_prior_scope_and_refresh_token_whe async def test_handle_refresh_response_adopts_rotated_refresh_token_when_returned( oauth_provider: OAuthClientProvider, mock_storage: MockTokenStorage ): - """A refresh response that includes ``refresh_token`` replaces the prior one (rotation).""" + """A refresh response that includes `refresh_token` replaces the prior one (rotation).""" oauth_provider.context.current_tokens = OAuthToken( access_token="old", scope="read write", refresh_token="prior-refresh" ) @@ -2929,7 +2620,7 @@ async def test_issuer_binding_re_evaluated_after_asm_when_prm_discovery_failed( ): """SEP-2352: on the legacy no-PRM path the binding check uses the ASM-discovered issuer. - PRM discovery fails (404) so ``auth_server_url`` stays ``None`` and the post-PRM check is + PRM discovery fails (404) so `auth_server_url` stays `None` and the post-PRM check is skipped; when ASM discovery then succeeds via the root well-known fallback, the discovered metadata's issuer is compared against the stored credentials' bound issuer and a mismatch triggers re-registration. @@ -2947,7 +2638,6 @@ async def test_issuer_binding_re_evaluated_after_asm_when_prm_discovery_failed( request = await auth_flow.__anext__() response_401 = httpx.Response(401, request=request) - # PRM discovery: path-based then root, both 404. prm_req = await auth_flow.asend(response_401) assert str(prm_req.url) == "https://api.example.com/.well-known/oauth-protected-resource/v1/mcp" prm_req = await auth_flow.asend(httpx.Response(404, request=prm_req)) @@ -3005,9 +2695,9 @@ async def test_issuer_is_not_stamped_when_registration_falls_back_to_the_resourc """SEP-2352: a fallback registration is not recorded as bound to the PRM-advertised AS. PRM advertises a new authorization server, so the stored credentials (bound to the old - issuer) are discarded. DCR then falls back to the resource-server origin's ``/register`` + issuer) are discarded. DCR then falls back to the resource-server origin's `/register` because the new AS's metadata either could not be discovered or omits - ``registration_endpoint``. That registration was not derived from the new AS's metadata, + `registration_endpoint`. That registration was not derived from the new AS's metadata, so persisting it as bound to the new AS would wedge the binding check on later flows; instead the issuer is left unset. """ @@ -3065,7 +2755,7 @@ async def echo_callback() -> AuthorizationCodeResult: asm_response.request = next_req next_req = await auth_flow.asend(asm_response) - # Step 4 falls back to the resource-server origin's /register. + # DCR falls back to the resource-server origin's /register. dcr_req = next_req assert dcr_req.method == "POST" assert str(dcr_req.url) == "https://api.example.com/register" @@ -3100,8 +2790,8 @@ async def test_issuer_is_stamped_when_same_origin_fallback_register_is_on_the_di """SEP-2352: a fallback registration on the discovered issuer's own host is still bound. Legacy same-origin embedded AS: PRM is absent, root ASM discovery succeeds with - ``issuer`` equal to the resource origin and no ``registration_endpoint``. DCR falls - back to ``/register`` β€” the issuer's own host β€” so the binding was + `issuer` equal to the resource origin and no `registration_endpoint`. DCR falls + back to `/register` β€” the issuer's own host β€” so the binding was established and is recorded, preserving auto-recovery on a later AS migration. """ oauth_provider.context.current_tokens = None @@ -3124,7 +2814,6 @@ async def echo_callback() -> AuthorizationCodeResult: auth_flow = oauth_provider.async_auth_flow(httpx.Request("GET", "https://api.example.com/v1/mcp")) request = await auth_flow.__anext__() - # PRM discovery 404s on both well-known URLs. prm_req = await auth_flow.asend(httpx.Response(401, request=request)) assert str(prm_req.url) == "https://api.example.com/.well-known/oauth-protected-resource/v1/mcp" prm_req = await auth_flow.asend(httpx.Response(404, request=prm_req)) diff --git a/tests/client/test_client.py b/tests/client/test_client.py index a6a9ac6ea8..c0aacb91f4 100644 --- a/tests/client/test_client.py +++ b/tests/client/test_client.py @@ -52,8 +52,6 @@ @pytest.fixture def simple_server() -> Server: - """Create a simple MCP server for testing.""" - async def handle_list_resources( ctx: ServerRequestContext, params: types.PaginatedRequestParams | None ) -> ListResourcesResult: @@ -87,7 +85,6 @@ async def handle_completion(ctx: ServerRequestContext, params: types.CompleteReq @pytest.fixture def app() -> MCPServer: - """Create an MCPServer server for testing.""" server = MCPServer("test") @server.tool() @@ -109,7 +106,6 @@ def greeting_prompt(name: str) -> str: async def test_client_is_initialized(app: MCPServer): - """Test that the client is initialized after entering context.""" async with Client(app, mode="legacy") as client: assert client.server_capabilities == snapshot( ServerCapabilities( @@ -123,13 +119,11 @@ async def test_client_is_initialized(app: MCPServer): async def test_client_exposes_negotiated_protocol_version(app: MCPServer): - """The negotiated protocol version is readable after initialization.""" async with Client(app, mode="legacy") as client: assert client.protocol_version == LATEST_HANDSHAKE_VERSION async def test_client_with_simple_server(simple_server: Server): - """Test that from_server works with a basic Server instance.""" async with Client(simple_server) as client: resources = await client.list_resources() assert resources == snapshot( @@ -184,7 +178,6 @@ async def test_client_call_tool(app: MCPServer): async def test_read_resource(app: MCPServer): - """Test reading a resource.""" async with Client(app) as client: result = await client.read_resource("test://resource") assert result == snapshot( @@ -195,8 +188,6 @@ async def test_read_resource(app: MCPServer): async def test_read_resource_error_propagates(): - """MCPError raised by a server handler propagates to the client with its code intact.""" - async def handle_read_resource( ctx: ServerRequestContext, params: types.ReadResourceRequestParams ) -> ReadResourceResult: @@ -210,10 +201,6 @@ async def handle_read_resource( async def test_raise_exceptions_propagates_handler_error_on_modern_inproc_path(): - """`raise_exceptions=True` on the modern in-process path: an unmapped handler - exception reaches the client with its original type chained, instead of being - sanitized to an opaque `INTERNAL_ERROR`.""" - async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: raise ValueError("boom") @@ -221,15 +208,12 @@ async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequ async with Client(server, mode="2026-07-28", raise_exceptions=True) as client: with pytest.raises(MCPError) as exc_info: await client.call_tool("explode", {}) - # The original exception is chained β€” not swallowed into a generic "Internal server error". assert isinstance(exc_info.value.__cause__, ValueError) assert str(exc_info.value.__cause__) == "boom" async def test_raise_exceptions_false_sanitizes_handler_error_on_modern_inproc_path(): - """`raise_exceptions=False` (the default) on the modern in-process path: an - unmapped handler exception is sanitized to an opaque `INTERNAL_ERROR` so the - in-process path matches the wire path's leak guard.""" + """Sanitized to opaque `INTERNAL_ERROR` by default so the in-process path matches the wire path's leak guard.""" async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: raise ValueError("boom") @@ -244,7 +228,6 @@ async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequ async def test_get_prompt(app: MCPServer): - """Test getting a prompt.""" async with Client(app) as client: result = await client.get_prompt("greeting_prompt", {"name": "Alice"}) assert result == snapshot( @@ -256,21 +239,18 @@ async def test_get_prompt(app: MCPServer): def test_client_session_property_before_enter(app: MCPServer): - """Test that accessing session before context manager raises RuntimeError.""" client = Client(app) with pytest.raises(RuntimeError, match="Client must be used within an async context manager"): client.session async def test_client_reentry_raises_runtime_error(app: MCPServer): - """Test that reentering a client raises RuntimeError.""" async with Client(app) as client: with pytest.raises(RuntimeError, match="Client is already entered"): await client.__aenter__() async def test_client_send_progress_notification(): - """Test sending progress notification.""" received_from_client = None event = anyio.Event() @@ -301,14 +281,12 @@ async def test_client_unsubscribe_resource(simple_server: Server): async def test_client_set_logging_level(simple_server: Server): - """Test setting logging level.""" async with Client(simple_server, mode="legacy") as client: result = await client.set_logging_level("debug") # pyright: ignore[reportDeprecated] assert result == snapshot(EmptyResult()) async def test_client_list_resources_with_params(app: MCPServer): - """Test listing resources with params parameter.""" async with Client(app) as client: result = await client.list_resources() assert result == snapshot( @@ -326,14 +304,12 @@ async def test_client_list_resources_with_params(app: MCPServer): async def test_client_list_resource_templates(app: MCPServer): - """Test listing resource templates with params parameter.""" async with Client(app) as client: result = await client.list_resource_templates() assert result == snapshot(ListResourceTemplatesResult(resource_templates=[])) async def test_list_prompts(app: MCPServer): - """Test listing prompts with params parameter.""" async with Client(app) as client: result = await client.list_prompts() assert result == snapshot( @@ -350,7 +326,6 @@ async def test_list_prompts(app: MCPServer): async def test_complete_with_prompt_reference(simple_server: Server): - """Test getting completions for a prompt argument.""" async with Client(simple_server) as client: ref = types.PromptReference(type="ref/prompt", name="test_prompt") result = await client.complete(ref=ref, argument={"name": "arg", "value": "test"}) @@ -388,7 +363,6 @@ def _set_test_contextvar(value: str) -> Iterator[None]: async def test_context_propagation(): - """Sender's contextvars.Context is propagated to the server handler.""" server = MCPServer("test") @server.tool() @@ -406,10 +380,7 @@ async def check_context() -> str: async def test_client_auto_mode_probes_discover_then_adopts(simple_server: Server) -> None: - """`mode='auto'` over an in-process HTTP transport: the `server/discover` probe - reaches the modern entry and the negotiated protocol version is adopted without - an `initialize` handshake. Runs over HTTP because the in-memory runner gates - `server/discover` behind the init handshake.""" + """Runs over HTTP because the in-memory runner gates `server/discover` behind the init handshake.""" with anyio.fail_after(5): async with ( mounted_app(simple_server) as (http, _), @@ -421,12 +392,8 @@ async def test_client_auto_mode_probes_discover_then_adopts(simple_server: Serve @pytest.mark.parametrize("code", [types.METHOD_NOT_FOUND, types.REQUEST_TIMEOUT, types.INTERNAL_ERROR]) async def test_client_auto_mode_falls_back_to_initialize_on_legacy_signal(code: int) -> None: - """`mode='auto'`: any JSON-RPC error from `server/discover` makes - `Client.__aenter__` run the legacy `initialize()` handshake and land at a - handshake-era protocol version. The denylist policy treats every server-sent - rpc-error as "not modern" β€” including INTERNAL_ERROR, since a legacy server - may crash on the unknown method before reaching its router. A real `Server` - always implements `server/discover`, so the server side is hand-played.""" + """Any rpc-error from `server/discover` reads as "not modern" β€” even INTERNAL_ERROR, since a legacy server + may crash on the unknown method. A real `Server` always implements `server/discover`, so it's hand-played.""" methods_seen: list[str] = [] async def scripted_server(streams: MessageStream) -> None: @@ -476,10 +443,7 @@ async def scripted_transport() -> AsyncIterator[TransportStreams]: @pytest.mark.anyio async def test_modern_list_tools_drops_tools_with_invalid_x_mcp_header_but_legacy_does_not() -> None: - """At 2026-07-28 the spec requires clients to exclude tools whose `x-mcp-header` - annotation is malformed; handshake-era sessions surface them unchanged. Two - tools are advertised β€” one valid, one with a non-RFC-9110-token header name β€” - and the modern client sees only the valid one.""" + """The 2026-07-28 spec requires excluding tools with a malformed `x-mcp-header`; handshake-era sessions don't.""" valid = types.Tool( name="ok", input_schema={"type": "object", "properties": {"a": {"type": "string", "x-mcp-header": "Region"}}}, @@ -507,9 +471,6 @@ async def on_list_tools( def test_client_rejects_handshake_era_mode_at_construction() -> None: - """A handshake-era protocol-version string passed as `mode=` is rejected by - `__post_init__` with a hint to use `mode='legacy'` β€” the version-pin path is - modern-only.""" server = MCPServer("test") with pytest.raises(ValueError, match=r"handshake-era version; use mode='legacy'"): Client(server, mode="2025-06-18") @@ -517,9 +478,6 @@ def test_client_rejects_handshake_era_mode_at_construction() -> None: Client(server, mode="not-a-version") -# ── SEP-2322 multi-round-trip auto-loop ──────────────────────────────────────── - - _NAME_SCHEMA = {"type": "object", "properties": {"name": {"type": "string"}}, "required": ["name"]} @@ -528,9 +486,7 @@ def _name_elicitation(message: str = "What is your name?") -> types.ElicitReques async def test_call_tool_auto_loop_dispatches_elicitation_then_returns_final_result() -> None: - """When the server returns `InputRequiredResult` carrying an elicitation, - `Client.call_tool` routes it to `elicitation_callback` and retries - automatically β€” the caller sees only the terminal `CallToolResult`.""" + """SEP-2322 auto-loop: `call_tool` routes the `InputRequiredResult` to `elicitation_callback` and retries.""" server = MCPServer("test") @server.tool() @@ -566,8 +522,6 @@ async def elicitation_callback( async def test_call_tool_auto_loop_dispatches_sampling_then_returns_final_result() -> None: - """`InputRequiredResult` with an embedded `CreateMessageRequest` is routed - to `sampling_callback` and the call retried with the model's reply.""" server = MCPServer("test") @server.tool() @@ -611,8 +565,6 @@ async def sampling_callback( async def test_call_tool_auto_loop_dispatches_list_roots_then_returns_final_result() -> None: - """`InputRequiredResult` with an embedded `ListRootsRequest` is routed to - `list_roots_callback` and the call retried with the returned roots.""" server = MCPServer("test") @server.tool() @@ -645,14 +597,11 @@ async def list_roots_callback(context: ClientRequestContext) -> types.ListRootsR async def test_call_tool_auto_loop_round_trips_evolving_request_state_across_three_rounds() -> None: - """A three-round flow where each `InputRequiredResult.request_state` - encodes the round number: the driver echoes it back byte-exact, the server - advances per round, and the elicitation callback runs once per round.""" + """The driver must echo each round's `request_state` back to the server byte-exact.""" server = MCPServer("test") @server.tool() async def multi(ctx: Context) -> str | types.InputRequiredResult: - # Round number is the integer the server stashed in `request_state` last leg. round_num = int(ctx.request_state) if ctx.request_state else 0 if round_num == 3: return "done after 3 rounds" @@ -680,9 +629,7 @@ async def elicitation_callback( async def test_call_tool_auto_loop_raises_mcp_error_when_no_callback_registered() -> None: - """SDK-defined: with no `elicitation_callback`, the default returns - `ErrorData(INVALID_REQUEST, ...)` and the driver raises it as `MCPError` - rather than retrying.""" + """SDK-defined: the default callback returns `ErrorData(INVALID_REQUEST)`, raised as `MCPError`, no retry.""" server = MCPServer("test") @server.tool() @@ -698,9 +645,6 @@ async def needs_input(ctx: Context) -> str | types.InputRequiredResult: async def test_get_prompt_auto_loop_resolves_input_required_via_callbacks() -> None: - """`Client.get_prompt` runs the same driver as `call_tool`: an - `InputRequiredResult` from `prompts/get` is fulfilled and retried.""" - async def handler( ctx: ServerRequestContext, params: types.GetPromptRequestParams ) -> types.GetPromptResult | types.InputRequiredResult: @@ -724,9 +668,6 @@ async def elicitation_callback( async def test_read_resource_auto_loop_resolves_input_required_via_callbacks() -> None: - """`Client.read_resource` runs the same driver as `call_tool`: an - `InputRequiredResult` from `resources/read` is fulfilled and retried.""" - async def handler( ctx: ServerRequestContext, params: types.ReadResourceRequestParams ) -> types.ReadResourceResult | types.InputRequiredResult: diff --git a/tests/client/test_http_unicode.py b/tests/client/test_http_unicode.py index 387fa4b48e..c08d7f4608 100644 --- a/tests/client/test_http_unicode.py +++ b/tests/client/test_http_unicode.py @@ -1,8 +1,4 @@ -"""Tests for Unicode handling in streamable HTTP transport. - -Verifies that Unicode text is correctly transmitted and received in both directions -(serverβ†’client and clientβ†’server) using the streamable HTTP transport. -""" +"""Unicode round-trip tests (serverβ†’client and clientβ†’server) for the streamable HTTP transport.""" from collections.abc import AsyncIterator from contextlib import asynccontextmanager @@ -23,7 +19,6 @@ # The in-process app is mounted at this origin purely so URLs are well-formed; nothing listens here. BASE_URL = "http://127.0.0.1:8000" -# Test constants with various Unicode characters UNICODE_TEST_STRINGS = { "cyrillic": "Π‘Π»ΠΎΠΉ Ρ…Ρ€Π°Π½ΠΈΠ»ΠΈΡ‰Π°, Π³Π΄Π΅ Ρ€Π°ΡΠΏΠΎΠ»Π°Π³Π°ΡŽΡ‚ΡΡ", "cyrillic_short": "ΠŸΡ€ΠΈΠ²Π΅Ρ‚ ΠΌΠΈΡ€", @@ -97,8 +92,7 @@ async def handle_get_prompt(ctx: ServerRequestContext, params: types.GetPromptRe @asynccontextmanager async def unicode_session() -> AsyncIterator[ClientSession]: - """Yield an initialized ClientSession speaking streamable HTTP (SSE responses) to the - Unicode test server, entirely in process.""" + """Yield an initialized in-process ClientSession speaking streamable HTTP to the Unicode test server.""" server = Server( name="unicode_test_server", on_list_tools=handle_list_tools, @@ -112,8 +106,7 @@ async def unicode_session() -> AsyncIterator[ClientSession]: async with ( session_manager.run(), - # follow_redirects matches the SDK's own client factory; Starlette's Mount 307-redirects - # the bare /mcp path to /mcp/. + # follow_redirects matches the SDK's own client factory; Starlette's Mount 307-redirects /mcp to /mcp/. httpx.AsyncClient( transport=StreamingASGITransport(app), base_url=BASE_URL, follow_redirects=True ) as http_client, @@ -126,24 +119,19 @@ async def unicode_session() -> AsyncIterator[ClientSession]: @pytest.mark.anyio async def test_streamable_http_client_unicode_tool_call() -> None: - """Test that Unicode text is correctly handled in tool calls via streamable HTTP.""" async with unicode_session() as session: - # Test 1: List tools (serverβ†’client Unicode in descriptions) tools = await session.list_tools() assert len(tools.tools) == 1 - # Check Unicode in tool descriptions echo_tool = tools.tools[0] assert echo_tool.name == "echo_unicode" assert echo_tool.description is not None assert "πŸ”€" in echo_tool.description assert "πŸ‘‹" in echo_tool.description - # Test 2: Send Unicode text in tool call (clientβ†’serverβ†’client) for test_name, test_string in UNICODE_TEST_STRINGS.items(): result = await session.call_tool("echo_unicode", arguments={"text": test_string}) - # Verify server correctly received and echoed back Unicode assert len(result.content) == 1 content = result.content[0] assert content.type == "text" @@ -152,9 +140,7 @@ async def test_streamable_http_client_unicode_tool_call() -> None: @pytest.mark.anyio async def test_streamable_http_client_unicode_prompts() -> None: - """Test that Unicode text is correctly handled in prompts via streamable HTTP.""" async with unicode_session() as session: - # Test 1: List prompts (serverβ†’client Unicode in descriptions) prompts = await session.list_prompts() assert len(prompts.prompts) == 1 @@ -163,7 +149,6 @@ async def test_streamable_http_client_unicode_prompts() -> None: assert prompt.description is not None assert "Π‘Π»ΠΎΠΉ Ρ…Ρ€Π°Π½ΠΈΠ»ΠΈΡ‰Π°, Π³Π΄Π΅ Ρ€Π°ΡΠΏΠΎΠ»Π°Π³Π°ΡŽΡ‚ΡΡ" in prompt.description - # Test 2: Get prompt with Unicode content (serverβ†’client) result = await session.get_prompt("unicode_prompt", arguments={}) assert len(result.messages) == 1 diff --git a/tests/client/test_input_required.py b/tests/client/test_input_required.py index cc58cf8dbe..0c82d53ecf 100644 --- a/tests/client/test_input_required.py +++ b/tests/client/test_input_required.py @@ -1,10 +1,7 @@ """Unit tests for the SEP-2322 client-side multi-round-trip driver. -`run_input_required_driver` is pure: it takes the first `InputRequiredResult` -plus `dispatch` / `retry` closures and loops until a terminal result. These -tests build those closures by hand (scripted lists, recording lists) so the -driver is exercised without a `ClientSession`. Integration against a real -server lives in `test_client.py`. +`run_input_required_driver` is pure, so these tests hand-build its `dispatch`/`retry` +closures and never touch a `ClientSession`; integration lives in `test_client.py`. """ import anyio @@ -48,8 +45,6 @@ async def _never_dispatch(key: str, req: InputRequest) -> InputResponse | ErrorD async def test_single_round_dispatches_then_retries_to_terminal_result() -> None: - """One `InputRequiredResult` with one elicit request: dispatch runs once, - retry runs once with the collected response, and the terminal result is returned.""" first = InputRequiredResult(input_requests={"ask": _elicit()}) terminal = CallToolResult(content=[TextContent(text="done")]) dispatched: list[tuple[str, InputRequest]] = [] @@ -73,8 +68,6 @@ async def retry(responses: InputResponses | None, state: str | None) -> CallTool async def test_multi_round_loops_until_retry_returns_non_input_required() -> None: - """Two consecutive `InputRequiredResult` legs followed by a terminal result: - the driver dispatches and retries each leg in order.""" terminal = CallToolResult(content=[TextContent(text="done")]) script: list[CallToolResult | InputRequiredResult] = [ InputRequiredResult(input_requests={"b": _elicit("second?")}), @@ -106,8 +99,6 @@ async def retry(responses: InputResponses | None, state: str | None) -> CallTool async def test_exceeding_max_rounds_raises_with_the_configured_cap() -> None: - """When every retry returns another `InputRequiredResult`, the driver gives - up after `max_rounds` retries with `InputRequiredRoundsExceededError`.""" rounds: list[int] = [] async def dispatch(key: str, req: InputRequest) -> InputResponse | ErrorData: @@ -128,9 +119,6 @@ async def retry(responses: InputResponses | None, state: str | None) -> CallTool async def test_dispatch_returning_error_data_aborts_the_loop_as_mcp_error() -> None: - """SDK-defined: a callback that refuses an embedded request returns - `ErrorData`; the driver surfaces it as `MCPError` rather than retrying.""" - async def dispatch(key: str, req: InputRequest) -> InputResponse | ErrorData: return ErrorData(code=INVALID_REQUEST, message="not supported") @@ -145,8 +133,6 @@ async def retry(responses: InputResponses | None, state: str | None) -> CallTool async def test_request_state_passes_through_byte_identical() -> None: - """`request_state` is opaque to the driver: each leg's value reaches `retry` - as the same object the server sent, never parsed or rebuilt.""" states = ['{"round": 1, "tag": "hΓ©llo"}', '{"round": 2, "tag": "wΓΆrld"}'] received_states: list[str | None] = [] @@ -167,16 +153,12 @@ async def retry(responses: InputResponses | None, state: str | None) -> CallTool assert received_states[1] is states[1] -# Runs on trio's autojumping virtual clock so the backoff sleeps add zero -# wall-clock and the recorded deltas are exact: `anyio.sleep` advances the -# MockClock by precisely the requested duration once every task is idle. +# Trio's autojumping MockClock makes the backoff sleeps instant and the recorded deltas exact. @pytest.mark.parametrize( "anyio_backend", [pytest.param(("trio", {"clock": MockClock(autojump_threshold=0)}), id="trio-mockclock")], ) async def test_state_only_legs_back_off_exponentially_to_the_cap() -> None: - """SDK-defined pacing: state-only legs sleep 50ms, 100ms, 200ms, then cap at - 250ms. Six state-only rounds β†’ deltas `[0.05, 0.1, 0.2, 0.25, 0.25, 0.25]`.""" retry_times: list[float] = [] async def retry(responses: InputResponses | None, state: str | None) -> CallToolResult | InputRequiredResult: @@ -203,9 +185,6 @@ async def retry(responses: InputResponses | None, state: str | None) -> CallTool [pytest.param(("trio", {"clock": MockClock(autojump_threshold=0)}), id="trio-mockclock")], ) async def test_backoff_counter_resets_after_a_leg_with_input_requests() -> None: - """A leg carrying `input_requests` resets `consecutive_state_only`: the - next state-only leg sleeps the initial 50ms again, not the prior position.""" - # state-only, state-only, dispatch leg (no sleep), state-only, terminal. script: list[CallToolResult | InputRequiredResult] = [ InputRequiredResult(request_state="s"), InputRequiredResult(input_requests={"k": _elicit()}), @@ -233,9 +212,6 @@ async def retry(responses: InputResponses | None, state: str | None) -> CallTool async def test_input_requests_are_dispatched_concurrently() -> None: - """All `input_requests` in a round are dispatched together: each dispatch - blocks on a shared gate that only opens once every key has started, so a - sequential implementation would deadlock under the `fail_after`.""" keys = ["a", "b", "c"] started: set[str] = set() all_started = anyio.Event() @@ -244,7 +220,7 @@ async def dispatch(key: str, req: InputRequest) -> InputResponse | ErrorData: started.add(key) if started == set(keys): all_started.set() - await all_started.wait() # blocks until every sibling is in-flight + await all_started.wait() # gate opens only when every key has started; sequential dispatch deadlocks here return ElicitResult(action="accept", content={"name": key}) received: list[InputResponses | None] = [] diff --git a/tests/client/test_list_methods_cursor.py b/tests/client/test_list_methods_cursor.py index 043b630d5c..efa211c61b 100644 --- a/tests/client/test_list_methods_cursor.py +++ b/tests/client/test_list_methods_cursor.py @@ -15,12 +15,9 @@ @pytest.fixture async def full_featured_server(): - """Create a server with tools, resources, prompts, and templates.""" server = MCPServer("test") - # pragma: no cover on handlers below - these exist only to register items with the - # server so list_* methods return results. The handlers themselves are never called - # because these tests only verify pagination/cursor behavior, not tool/resource invocation. + # The no-cover handlers exist only so list_* methods return results; these tests never invoke them. @server.tool() def greet(name: str) -> str: # pragma: no cover """Greet someone by name.""" @@ -59,16 +56,10 @@ async def test_list_methods_params_parameter( method_name: str, request_method: str, ): - """Test that the params parameter is accepted and correctly passed to the server. - - Covers: list_tools, list_resources, list_prompts, list_resource_templates - - See: https://modelcontextprotocol.io/specification/2025-03-26/server/utilities/pagination#request-format - """ + """See https://modelcontextprotocol.io/specification/2025-03-26/server/utilities/pagination#request-format""" async with Client(full_featured_server, mode="legacy") as client: spies = stream_spy() - # Test without params (omitted) method = getattr(client, method_name) _ = await method() requests = spies.get_client_requests(method=request_method) @@ -77,7 +68,6 @@ async def test_list_methods_params_parameter( spies.clear() - # Test with params containing cursor _ = await method(cursor="from_params") requests = spies.get_client_requests(method=request_method) assert len(requests) == 1 @@ -86,18 +76,16 @@ async def test_list_methods_params_parameter( spies.clear() - # Test with empty params + # A plain call after a cursor call must again omit the cursor _ = await method() requests = spies.get_client_requests(method=request_method) assert len(requests) == 1 - # Empty params means no cursor assert requests[0].params is None or "cursor" not in requests[0].params async def test_list_tools_with_strict_server_validation( full_featured_server: MCPServer, ): - """Test pagination with a server that validates request format strictly.""" async with Client(full_featured_server) as client: result = await client.list_tools() assert isinstance(result, ListToolsResult) @@ -105,12 +93,10 @@ async def test_list_tools_with_strict_server_validation( async def test_list_tools_with_lowlevel_server(): - """Test that list_tools works with a lowlevel Server using params.""" - async def handle_list_tools( ctx: ServerRequestContext, params: types.PaginatedRequestParams | None ) -> ListToolsResult: - # Echo back what cursor we received in the tool description + # Echo the received cursor through the tool description so the client side can assert on it cursor = params.cursor if params else None return ListToolsResult( tools=[types.Tool(name="test_tool", description=f"cursor={cursor}", input_schema={"type": "object"})] diff --git a/tests/client/test_list_roots_callback.py b/tests/client/test_list_roots_callback.py index 14ca1577d1..3d369a4ec5 100644 --- a/tests/client/test_list_roots_callback.py +++ b/tests/client/test_list_roots_callback.py @@ -30,17 +30,14 @@ async def test_list_roots(context: Context, message: str): assert roots == callback_return return True - # Test with list_roots callback async with Client(server, list_roots_callback=list_roots_callback, mode="legacy") as client: - # Make a request to trigger sampling callback result = await client.call_tool("test_list_roots", {"message": "test message"}) assert result.is_error is False assert isinstance(result.content[0], TextContent) assert result.content[0].text == "true" - # Without a list_roots callback the client responds with an MCPError, which the - # tool body doesn't catch β€” the wrapper re-raises it as a top-level JSON-RPC - # error rather than wrapping it as an isError result. + # Without a callback the client responds with an MCPError; the tool body doesn't catch it, + # so the wrapper re-raises it as a top-level JSON-RPC error rather than an isError result. async with Client(server, mode="legacy") as client: with pytest.raises(MCPError) as exc_info: await client.call_tool("test_list_roots", {"message": "test message"}) diff --git a/tests/client/test_logging_callback.py b/tests/client/test_logging_callback.py index d62b7e19b3..56e30def6d 100644 --- a/tests/client/test_logging_callback.py +++ b/tests/client/test_logging_callback.py @@ -25,13 +25,10 @@ async def test_logging_callback(): server = MCPServer("test") logging_collector = LoggingCollector() - # Create a simple test tool @server.tool("test_tool") async def test_tool() -> bool: - # The actual tool is very simple and just returns True return True - # Create a function that can send a log notification @server.tool("test_tool_with_log") async def test_tool_with_log( message: str, level: Literal["debug", "info", "warning", "error"], logger: str, ctx: Context @@ -54,7 +51,6 @@ async def test_tool_with_log_dict( ) return True - # Create a message handler to catch exceptions async def message_handler( message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, ) -> None: @@ -67,13 +63,11 @@ async def message_handler( message_handler=message_handler, mode="legacy", ) as client: - # First verify our test tool works result = await client.call_tool("test_tool", {}) assert result.is_error is False assert isinstance(result.content[0], TextContent) assert result.content[0].text == "true" - # Now send a log message via our tool log_result = await client.call_tool( "test_tool_with_log", { @@ -92,7 +86,6 @@ async def message_handler( assert log_result.is_error is False assert log_result_with_dict.is_error is False assert len(logging_collector.log_messages) == 2 - # Create meta object with related_request_id added dynamically log = logging_collector.log_messages[0] assert log.level == "info" assert log.logger == "test_logger" diff --git a/tests/client/test_notification_response.py b/tests/client/test_notification_response.py index 418a6bc54b..b41c421fb1 100644 --- a/tests/client/test_notification_response.py +++ b/tests/client/test_notification_response.py @@ -1,8 +1,4 @@ -"""Tests for StreamableHTTP client transport with non-SDK servers. - -These tests verify client behavior when interacting with servers -that don't follow SDK conventions. -""" +"""StreamableHTTP client behavior against servers that don't follow SDK conventions.""" import json @@ -33,8 +29,6 @@ def _init_json_response(data: dict[str, object]) -> JSONResponse: def _create_non_sdk_server_app() -> Starlette: - """Create a minimal server that doesn't follow SDK conventions.""" - async def handle_mcp_request(request: Request) -> Response: body = await request.body() data = json.loads(body) @@ -42,7 +36,7 @@ async def handle_mcp_request(request: Request) -> Response: if data.get("method") == "initialize": return _init_json_response(data) - # For notifications, return 204 No Content (non-SDK behavior) + # Notifications get 204 instead of the spec's 202 if "id" not in data: return Response(status_code=204, headers={"Content-Type": "application/json"}) @@ -54,8 +48,6 @@ async def handle_mcp_request(request: Request) -> Response: def _create_unexpected_content_type_app() -> Starlette: - """Create a server that returns an unexpected content type for requests.""" - async def handle_mcp_request(request: Request) -> Response: body = await request.body() data = json.loads(body) @@ -66,19 +58,15 @@ async def handle_mcp_request(request: Request) -> Response: if "id" not in data: return Response(status_code=202) - # Return text/plain for all other requests β€” an unexpected content type. return Response(content="this is plain text, not json or sse", status_code=200, media_type="text/plain") return Starlette(debug=True, routes=[Route("/mcp", handle_mcp_request, methods=["POST"])]) async def test_non_compliant_notification_response() -> None: - """Verify the client ignores unexpected responses to notifications. + """Non-202 2xx notification responses (e.g. 204) are ignored, matching the TS SDK. - The spec states notifications should get either 202 + no response body, or 4xx + optional error body - (https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#sending-messages-to-the-server), - but some servers wrongly return other 2xx codes (e.g. 204). For now we simply ignore unexpected responses - (aligning behaviour w/ the TS SDK). + Spec: https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#sending-messages-to-the-server """ returned_exception = None @@ -94,7 +82,6 @@ async def message_handler( # pragma: no cover async with ClientSession(read_stream, write_stream, message_handler=message_handler) as session: await session.initialize() - # The test server returns a 204 instead of the expected 202 await session.send_notification(RootsListChangedNotification(method="notifications/roots/list_changed")) if returned_exception: # pragma: no cover @@ -102,12 +89,7 @@ async def message_handler( # pragma: no cover async def test_unexpected_content_type_sends_jsonrpc_error() -> None: - """Verify unexpected content types unblock the pending request with an MCPError. - - When a server returns a content type that is neither application/json nor text/event-stream, - the client should send a JSONRPCError so the pending request resolves immediately - instead of hanging until timeout. - """ + """The synthesized JSONRPCError resolves the pending request immediately instead of hanging until timeout.""" async with httpx.AsyncClient(transport=httpx.ASGITransport(app=_create_unexpected_content_type_app())) as client: async with streamable_http_client("http://localhost/mcp", http_client=client) as (read_stream, write_stream): async with ClientSession(read_stream, write_stream) as session: # pragma: no branch @@ -118,8 +100,6 @@ async def test_unexpected_content_type_sends_jsonrpc_error() -> None: def _create_http_error_app(error_status: int, *, error_on_notifications: bool = False) -> Starlette: - """Create a server that returns an HTTP error for non-init requests.""" - async def handle_mcp_request(request: Request) -> Response: body = await request.body() data = json.loads(body) @@ -138,12 +118,7 @@ async def handle_mcp_request(request: Request) -> Response: async def test_http_error_status_sends_jsonrpc_error() -> None: - """Verify HTTP 5xx errors unblock the pending request with an MCPError. - - When a server returns a non-2xx status code (e.g. 500), the client should - send a JSONRPCError so the pending request resolves immediately instead of - raising an unhandled httpx.HTTPStatusError that causes the caller to hang. - """ + """The HTTP error becomes a JSONRPCError rather than an unhandled httpx.HTTPStatusError that hangs the caller.""" async with httpx.AsyncClient(transport=httpx.ASGITransport(app=_create_http_error_app(500))) as client: async with streamable_http_client("http://localhost/mcp", http_client=client) as (read_stream, write_stream): async with ClientSession(read_stream, write_stream) as session: # pragma: no branch @@ -154,24 +129,17 @@ async def test_http_error_status_sends_jsonrpc_error() -> None: async def test_http_error_on_notification_does_not_hang() -> None: - """Verify HTTP errors on notifications are silently ignored. - - When a notification gets an HTTP error, there is no pending request to - unblock, so the client should just return without sending a JSONRPCError. - """ + """With no pending request to unblock, the client silently ignores the error.""" app = _create_http_error_app(500, error_on_notifications=True) async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app)) as client: async with streamable_http_client("http://localhost/mcp", http_client=client) as (read_stream, write_stream): async with ClientSession(read_stream, write_stream) as session: # pragma: no branch await session.initialize() - # Should not raise or hang β€” the error is silently ignored for notifications await session.send_notification(RootsListChangedNotification(method="notifications/roots/list_changed")) def _create_invalid_json_response_app() -> Starlette: - """Create a server that returns invalid JSON for requests.""" - async def handle_mcp_request(request: Request) -> Response: body = await request.body() data = json.loads(body) @@ -182,19 +150,12 @@ async def handle_mcp_request(request: Request) -> Response: if "id" not in data: return Response(status_code=202) - # Return application/json content type but with invalid JSON body. return Response(content="not valid json{{{", status_code=200, media_type="application/json") return Starlette(debug=True, routes=[Route("/mcp", handle_mcp_request, methods=["POST"])]) async def test_invalid_json_response_sends_jsonrpc_error() -> None: - """Verify invalid JSON responses unblock the pending request with an MCPError. - - When a server returns application/json with an unparseable body, the client - should send a JSONRPCError so the pending request resolves immediately - instead of hanging until timeout. - """ async with httpx.AsyncClient(transport=httpx.ASGITransport(app=_create_invalid_json_response_app())) as client: async with streamable_http_client("http://localhost/mcp", http_client=client) as (read_stream, write_stream): async with ClientSession(read_stream, write_stream) as session: # pragma: no branch @@ -205,11 +166,8 @@ async def test_invalid_json_response_sends_jsonrpc_error() -> None: def _create_non_2xx_json_body_app(status: int, body: bytes) -> Starlette: - """Server that returns a fixed non-2xx status + ``application/json`` body for non-init requests. - - The initialize response carries an ``mcp-session-id`` so the client treats subsequent - requests as part of an established session (needed for the 404 → session-terminated mapping). - """ + """Server returning a fixed non-2xx status + JSON body; init sets `mcp-session-id` so later + requests count as in-session (needed for the 404 → session-terminated mapping).""" async def handle_mcp_request(request: Request) -> Response: data = json.loads(await request.body()) @@ -226,9 +184,8 @@ async def handle_mcp_request(request: Request) -> Response: async def test_client_surfaces_jsonrpc_error_from_non_2xx_body_with_correlated_id() -> None: - """SDK-defined: a JSON-RPC error in a non-2xx body is surfaced verbatim even when the - server set ``id: null`` — the client rewraps it under the pending request's id, so - the awaiting call resolves with the server's error code instead of the generic fallback.""" + """SDK-defined: a JSON-RPC error in a non-2xx body with `id: null` is rewrapped under the + pending request's id, so the caller sees the server's error code, not the generic fallback.""" body = json.dumps( {"jsonrpc": "2.0", "id": None, "error": {"code": types.METHOD_NOT_FOUND, "message": "nope"}} ).encode() @@ -243,9 +200,8 @@ async def test_client_surfaces_jsonrpc_error_from_non_2xx_body_with_correlated_i async def test_client_falls_back_to_generic_error_when_non_2xx_body_is_a_jsonrpc_result() -> None: - """SDK-defined: a non-2xx response whose JSON body parses as a JSON-RPC *result* (not an - error) falls through to the generic ``INTERNAL_ERROR`` fallback rather than being - treated as the request's reply.""" + """SDK-defined: a non-2xx body that parses as a JSON-RPC result (not an error) falls through + to the generic INTERNAL_ERROR fallback rather than being treated as the request's reply.""" app = _create_non_2xx_json_body_app(400, b'{"jsonrpc":"2.0","id":1,"result":{}}') async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app)) as client: async with streamable_http_client("http://localhost/mcp", http_client=client) as (read_stream, write_stream): @@ -257,9 +213,8 @@ async def test_client_falls_back_to_generic_error_when_non_2xx_body_is_a_jsonrpc async def test_client_falls_back_to_session_terminated_when_404_body_is_malformed_json() -> None: - """SDK-defined: an unparseable ``application/json`` body on a 404 response is swallowed - and the status-derived ``INVALID_REQUEST`` (session-terminated) fallback resolves the - pending request — the parse failure never propagates.""" + """SDK-defined: a malformed 404 body is swallowed; the status-derived session-terminated + fallback resolves the pending request rather than the parse failure propagating.""" app = _create_non_2xx_json_body_app(404, b"not valid json{{{") async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app)) as client: async with streamable_http_client("http://localhost/mcp", http_client=client) as (read_stream, write_stream): diff --git a/tests/client/test_output_schema_validation.py b/tests/client/test_output_schema_validation.py index 60f6fadc0f..994ab4dcff 100644 --- a/tests/client/test_output_schema_validation.py +++ b/tests/client/test_output_schema_validation.py @@ -19,7 +19,7 @@ def _make_server( tools: list[Tool], structured_content: dict[str, Any], ) -> Server: - """Create a low-level server that returns the given structured_content for any tool call.""" + """Server that returns `structured_content` for every tool call.""" async def on_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: return ListToolsResult(tools=tools) @@ -35,7 +35,6 @@ async def on_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) @pytest.mark.anyio async def test_tool_structured_output_client_side_validation_basemodel(): - """Test that client validates structured content against schema for BaseModel outputs""" output_schema = { "type": "object", "properties": {"name": {"type": "string", "title": "Name"}, "age": {"type": "integer", "title": "Age"}}, @@ -63,7 +62,6 @@ async def test_tool_structured_output_client_side_validation_basemodel(): @pytest.mark.anyio async def test_tool_structured_output_client_side_validation_primitive(): - """Test that client validates structured content for primitive outputs""" output_schema = { "type": "object", "properties": {"result": {"type": "integer", "title": "Result"}}, @@ -91,7 +89,6 @@ async def test_tool_structured_output_client_side_validation_primitive(): @pytest.mark.anyio async def test_tool_structured_output_client_side_validation_dict_typed(): - """Test that client validates dict[str, T] structured content""" output_schema = {"type": "object", "additionalProperties": {"type": "integer"}, "title": "get_scores_Output"} server = _make_server( @@ -114,7 +111,6 @@ async def test_tool_structured_output_client_side_validation_dict_typed(): @pytest.mark.anyio async def test_tool_structured_output_client_side_validation_missing_required(): - """Test that client validates missing required fields""" output_schema = { "type": "object", "properties": {"name": {"type": "string"}, "age": {"type": "integer"}, "email": {"type": "string"}}, @@ -142,8 +138,6 @@ async def test_tool_structured_output_client_side_validation_missing_required(): @pytest.mark.anyio async def test_tool_not_listed_warning(caplog: pytest.LogCaptureFixture): - """Test that client logs warning when tool is not in list_tools but has output_schema""" - async def on_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: return ListToolsResult(tools=[]) diff --git a/tests/client/test_probe.py b/tests/client/test_probe.py index 34a347fa7d..3677af7d5f 100644 --- a/tests/client/test_probe.py +++ b/tests/client/test_probe.py @@ -1,14 +1,9 @@ -"""Unit tests for the connect-time auto-negotiation policy (`mcp.client._probe.negotiate_auto`). +"""Tests for the connect-time auto-negotiation policy (`mcp.client._probe.negotiate_auto`). -`negotiate_auto` is a small policy function that drives a `ClientSession` through the -``server/discover`` probe and decides between ``adopt()`` (modern), ``initialize()`` (legacy -fallback), or letting the probe's exception propagate. The policy is a *denylist*: every -``MCPError`` falls back to ``initialize()``, the sole exception being -32022 with a disjoint -modern-only ``supported`` list. Any non-``MCPError`` exception (network errors, anyio -resource errors) propagates untouched — an outage is never an era verdict. - -These tests pin the classifier in isolation with a stub session; the end-to-end wire shape is -covered by ``tests/interaction/lowlevel/test_client_connect.py``. +The policy is a denylist: every `MCPError` from the `server/discover` probe falls back to +`initialize()`, except -32022 with a disjoint modern-only `supported` list, which re-raises. +Non-`MCPError` exceptions propagate untouched — an outage is never an era verdict. +Wire-level coverage lives in `tests/interaction/lowlevel/test_client_connect.py`. """ from __future__ import annotations @@ -42,11 +37,7 @@ class _StubSession: - """Minimal stand-in for `ClientSession` exposing only what `negotiate_auto` touches. - - `send_discover` plays back a script (raise an exception, or return a dict); - `initialize` and `adopt` just record that they were called. - """ + """Stand-in for `ClientSession`: `send_discover` plays back a script; `initialize`/`adopt` record calls.""" def __init__(self, *script: dict[str, Any] | Exception) -> None: self._script: list[dict[str, Any] | Exception] = list(script) @@ -69,7 +60,7 @@ def adopt(self, result: types.DiscoverResult) -> None: async def _negotiate(session: _StubSession) -> None: - """Drive `negotiate_auto` against the stub; cast at one seam so the tests stay suppression-free.""" + """Cast at one seam so the tests stay suppression-free.""" await negotiate_auto(cast("ClientSession", session)) @@ -89,11 +80,7 @@ def _err_32022(supported: Any) -> MCPError: ) -# --- happy path: modern server --- - - async def test_a_valid_discover_result_is_adopted_without_initializing() -> None: - """A parseable `DiscoverResult` from the probe is adopted; `initialize()` is never called.""" session = _StubSession(_discover_dict()) await _negotiate(session) assert session.adopted is not None @@ -103,17 +90,12 @@ async def test_a_valid_discover_result_is_adopted_without_initializing() -> None async def test_an_unparseable_discover_result_falls_back_to_initialize() -> None: - """A probe response that does not validate as `DiscoverResult` is not modern evidence, - so the policy falls back to the legacy handshake instead of adopting garbage.""" session = _StubSession({"not": "a discover result"}) await _negotiate(session) assert session.initialized assert session.adopted is None -# --- the denylist: every JSON-RPC error code falls back --- - - @pytest.mark.parametrize( "code", [ @@ -124,10 +106,7 @@ async def test_an_unparseable_discover_result_falls_back_to_initialize() -> None ], ) async def test_any_jsonrpc_error_from_the_probe_falls_back_to_initialize(code: int) -> None: - """The denylist: every server-sent JSON-RPC error code is treated as "not modern" and - triggers the legacy `initialize()` handshake. Legacy servers reject the unknown - ``server/discover`` method with various codes (-32601, -32600, -32603, -32700) depending - on where in their pipeline the request bounces.""" + """Legacy servers reject the unknown `server/discover` method with varying codes; every one means "not modern".""" session = _StubSession(MCPError(code=code, message="nope")) await _negotiate(session) assert session.initialized @@ -135,12 +114,7 @@ async def test_any_jsonrpc_error_from_the_probe_falls_back_to_initialize(code: i assert session.probed_at == [LATEST_MODERN_VERSION] -# --- -32022 corrective retry --- - - async def test_unsupported_version_with_a_mutual_modern_version_retries_once_then_adopts() -> None: - """-32022 with a `supported` list naming a modern version we speak: re-probe once at - the highest mutual version, then adopt the second response.""" session = _StubSession(_err_32022(list(MODERN_PROTOCOL_VERSIONS)), _discover_dict()) await _negotiate(session) assert session.probed_at == [LATEST_MODERN_VERSION, MODERN_PROTOCOL_VERSIONS[-1]] @@ -149,8 +123,6 @@ async def test_unsupported_version_with_a_mutual_modern_version_retries_once_the async def test_unsupported_version_naming_only_handshake_versions_falls_back_to_initialize() -> None: - """-32022 with `supported` naming only handshake-era versions: the server is reachable - via the legacy handshake, so fall back rather than raise.""" session = _StubSession(_err_32022(list(HANDSHAKE_PROTOCOL_VERSIONS))) await _negotiate(session) assert session.initialized @@ -159,9 +131,7 @@ async def test_unsupported_version_naming_only_handshake_versions_falls_back_to_ async def test_unsupported_version_with_disjoint_modern_only_supported_reraises() -> None: - """-32022 with `supported` naming only modern versions we *don't* speak: this is the - one denylist exception — the server is modern-only and there is no mutual version, so - falling back to `initialize()` would also fail. The original `MCPError` re-raises.""" + """The sole denylist exception: no mutual version exists, so `initialize()` would also fail.""" session = _StubSession(_err_32022(["2099-01-01"])) with pytest.raises(MCPError) as exc_info: await _negotiate(session) @@ -179,8 +149,6 @@ async def test_unsupported_version_with_disjoint_modern_only_supported_reraises( ], ) async def test_unsupported_version_with_unparseable_data_falls_back_to_initialize(data: Any) -> None: - """-32022 with no/malformed `error.data`: nothing actionable, so fall through to the - denylist's `initialize()` fallback rather than guess or raise.""" session = _StubSession(MCPError(code=UNSUPPORTED_PROTOCOL_VERSION, message="bad version", data=data)) await _negotiate(session) assert session.initialized @@ -189,9 +157,7 @@ async def test_unsupported_version_with_unparseable_data_falls_back_to_initializ async def test_a_second_unsupported_version_after_the_corrective_retry_does_not_loop() -> None: - """The corrective -32022 retry happens at most once; a second -32022 naming a - modern-only `supported` list re-raises rather than re-probing forever (the loop - guard makes this the disjoint-modern case on attempt two).""" + """The retry happens at most once; the loop guard makes attempt two the disjoint-modern re-raise case.""" session = _StubSession(_err_32022(list(MODERN_PROTOCOL_VERSIONS)), _err_32022(list(MODERN_PROTOCOL_VERSIONS))) with pytest.raises(MCPError) as exc_info: await _negotiate(session) @@ -201,9 +167,6 @@ async def test_a_second_unsupported_version_after_the_corrective_retry_does_not_ assert session.adopted is None -# --- non-MCP errors propagate --- - - @pytest.mark.parametrize( "exc", [ @@ -212,8 +175,6 @@ async def test_a_second_unsupported_version_after_the_corrective_retry_does_not_ ], ) async def test_a_network_or_resource_error_from_the_probe_propagates_unchanged(exc: Exception) -> None: - """Anything that is not an `MCPError` propagates as-is; an outage or in-process bug - is never an era verdict, and `initialize()` is not called.""" session = _StubSession(exc) with pytest.raises(type(exc)): await _negotiate(session) @@ -221,9 +182,6 @@ async def test_a_network_or_resource_error_from_the_probe_propagates_unchanged(e assert session.adopted is None -# --- helper --- - - @pytest.mark.parametrize( ("data", "expected"), [ @@ -237,6 +195,4 @@ async def test_a_network_or_resource_error_from_the_probe_propagates_unchanged(e def test_parse_supported_returns_none_for_anything_not_shaped_like_the_spec_error_data( data: Any, expected: list[str] | None ) -> None: - """`_parse_supported` returns the `supported` list when `error.data` validates as - `UnsupportedProtocolVersionErrorData`, and `None` otherwise — never raises.""" assert _parse_supported(data) == expected diff --git a/tests/client/test_sampling_callback.py b/tests/client/test_sampling_callback.py index d9bfddaf27..eec3b0d392 100644 --- a/tests/client/test_sampling_callback.py +++ b/tests/client/test_sampling_callback.py @@ -41,17 +41,14 @@ async def test_sampling_tool(message: str, ctx: Context) -> bool: assert value == callback_return return True - # Test with sampling callback async with Client(server, sampling_callback=sampling_callback, mode="legacy") as client: - # Make a request to trigger sampling callback result = await client.call_tool("test_sampling", {"message": "Test message for sampling"}) assert result.is_error is False assert isinstance(result.content[0], TextContent) assert result.content[0].text == "true" - # Without a sampling callback the client responds with an MCPError, which the - # tool body doesn't catch — the wrapper re-raises it as a top-level JSON-RPC - # error rather than wrapping it as an isError result. + # Without a sampling callback the client responds with an MCPError the tool body + # doesn't catch, so it surfaces as a top-level JSON-RPC error, not an isError result. async with Client(server, mode="legacy") as client: with pytest.raises(MCPError) as exc_info: await client.call_tool("test_sampling", {"message": "Test message for sampling"}) @@ -60,10 +57,8 @@ async def test_sampling_tool(message: str, ctx: Context) -> bool: @pytest.mark.anyio async def test_create_message_backwards_compat_single_content(): - """Test backwards compatibility: create_message without tools returns single content.""" server = MCPServer("test") - # Callback returns single content (text) callback_return = CreateMessageResult( role="assistant", content=TextContent(type="text", text="Hello from LLM"), @@ -79,17 +74,14 @@ async def sampling_callback( @server.tool("test_backwards_compat") async def test_tool(message: str, ctx: Context) -> bool: - # Call create_message WITHOUT tools result = await ctx.session.create_message( # pyright: ignore[reportDeprecated] messages=[SamplingMessage(role="user", content=TextContent(type="text", text=message))], max_tokens=100, ) - # Backwards compat: result should be CreateMessageResult assert isinstance(result, CreateMessageResult) - # Content should be single (not a list) - this is the key backwards compat check assert isinstance(result.content, TextContent) assert result.content.text == "Hello from LLM" - # CreateMessageResult should NOT have content_as_list (that's on WithTools) + # content_as_list exists only on CreateMessageResultWithTools assert not hasattr(result, "content_as_list") or not callable(getattr(result, "content_as_list", None)) return True @@ -102,8 +94,7 @@ async def test_tool(message: str, ctx: Context) -> bool: @pytest.mark.anyio async def test_create_message_result_with_tools_type(): - """Test that CreateMessageResultWithTools supports content_as_list.""" - # Test the type itself, not the overload (overload requires client capability setup) + # Tests the type directly, not the create_message overload (which requires client capability setup) result = CreateMessageResultWithTools( role="assistant", content=ToolUseContent(type="tool_use", id="call_123", name="get_weather", input={"city": "SF"}), @@ -111,12 +102,10 @@ async def test_create_message_result_with_tools_type(): stop_reason="toolUse", ) - # CreateMessageResultWithTools should have content_as_list content_list = result.content_as_list assert len(content_list) == 1 assert content_list[0].type == "tool_use" - # It should also work with array content result_array = CreateMessageResultWithTools( role="assistant", content=[ diff --git a/tests/client/test_scope_bug_1630.py b/tests/client/test_scope_bug_1630.py index 338755dc68..653829561f 100644 --- a/tests/client/test_scope_bug_1630.py +++ b/tests/client/test_scope_bug_1630.py @@ -1,8 +1,4 @@ -"""Regression test for issue #1630: OAuth2 scope incorrectly set to resource_metadata URL. - -This test verifies that when a 401 response contains both resource_metadata and scope -in the WWW-Authenticate header, the actual scope is used (not the resource_metadata URL). -""" +"""Regression test for #1630: OAuth2 scope was incorrectly set to the resource_metadata URL from WWW-Authenticate.""" from unittest import mock @@ -20,8 +16,6 @@ class MockTokenStorage: - """Mock token storage for testing.""" - def __init__(self) -> None: self._tokens: OAuthToken | None = None self._client_info: OAuthClientInformationFull | None = None @@ -41,15 +35,6 @@ async def set_client_info(self, client_info: OAuthClientInformationFull) -> None @pytest.mark.anyio async def test_401_uses_www_auth_scope_not_resource_metadata_url(): - """Regression test for #1630: Ensure scope is extracted from WWW-Authenticate header, - not the resource_metadata URL. - - When a 401 response contains: - WWW-Authenticate: Bearer resource_metadata="https://...", scope="read write" - - The client should use "read write" as the scope, NOT the resource_metadata URL. - """ - async def redirect_handler(url: str) -> None: pass # pragma: no cover @@ -82,11 +67,9 @@ async def callback_handler() -> AuthorizationCodeResult: test_request = httpx.Request("GET", "https://api.example.com/mcp") auth_flow = provider.async_auth_flow(test_request) - # First request (no auth header yet) await auth_flow.__anext__() - # 401 response with BOTH resource_metadata URL and scope in WWW-Authenticate - # This is the key: the bug would use the URL as scope instead of "read write" + # WWW-Authenticate carries both resource_metadata and scope; the bug used the URL as the scope resource_metadata_url = "https://api.example.com/.well-known/oauth-protected-resource" expected_scope = "read write" @@ -96,11 +79,10 @@ async def callback_handler() -> AuthorizationCodeResult: request=test_request, ) - # Send 401, expect PRM discovery request prm_request = await auth_flow.asend(response_401) assert ".well-known/oauth-protected-resource" in str(prm_request.url) - # PRM response with scopes_supported (these should be overridden by WWW-Auth scope) + # scopes_supported must lose to the WWW-Authenticate scope prm_response = httpx.Response( 200, content=( @@ -111,11 +93,9 @@ async def callback_handler() -> AuthorizationCodeResult: request=prm_request, ) - # Send PRM response, expect OAuth metadata discovery oauth_metadata_request = await auth_flow.asend(prm_response) assert ".well-known/oauth-authorization-server" in str(oauth_metadata_request.url) - # OAuth metadata response oauth_metadata_response = httpx.Response( 200, content=( @@ -129,23 +109,17 @@ async def callback_handler() -> AuthorizationCodeResult: # Mock authorization to skip interactive flow provider._perform_authorization_code_grant = mock.AsyncMock(return_value=("test_auth_code", "test_code_verifier")) - # Send OAuth metadata response, expect token request token_request = await auth_flow.asend(oauth_metadata_response) assert "token" in str(token_request.url) - # NOW CHECK: The scope should be the WWW-Authenticate scope, NOT the URL - # This is where the bug manifested - scope was set to resource_metadata_url actual_scope = provider.context.client_metadata.scope - # This assertion would FAIL on main (scope would be the URL) - # but PASS on the fix branch (scope is "read write") assert actual_scope == expected_scope, ( f"Expected scope to be '{expected_scope}' from WWW-Authenticate header, " f"but got '{actual_scope}'. " f"If scope is '{resource_metadata_url}', the bug from #1630 is present." ) - # Verify it's definitely not the URL (explicit check for the bug) assert actual_scope != resource_metadata_url, ( f"BUG #1630: Scope was incorrectly set to resource_metadata URL '{resource_metadata_url}' " f"instead of the actual scope '{expected_scope}'" @@ -161,7 +135,6 @@ async def callback_handler() -> AuthorizationCodeResult: final_request = await auth_flow.asend(token_response) assert final_request.headers["Authorization"] == "Bearer test_token" - # Finish the flow final_response = httpx.Response(200, request=final_request) try: await auth_flow.asend(final_response) diff --git a/tests/client/test_session.py b/tests/client/test_session.py index 83893e36f9..6c03b60f89 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -54,10 +54,9 @@ async def raw_client_session( **kwargs: Any, ) -> AsyncIterator[tuple[ClientSession, _SendToClient, _RecvFromClient]]: - """Yield `(session, send_to_client, recv_from_client)` with the receive loop running. + """Yield `(session, send_to_client, recv_from_client)` with the receive loop running, no handshake. - `send_to_client` accepts `SessionMessage | Exception` so tests can inject - transport-level exceptions. No initialize handshake is performed. + `send_to_client` accepts `SessionMessage | Exception` so tests can inject transport-level exceptions. """ s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage](32) @@ -119,7 +118,6 @@ async def mock_server(): jsonrpc_notification.model_dump(by_alias=True, mode="json", exclude_none=True) ) - # Create a message handler to catch exceptions async def message_handler( # pragma: no cover message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, ) -> None: @@ -141,14 +139,12 @@ async def message_handler( # pragma: no cover tg.start_soon(mock_server) result = await session.initialize() - # Assert the result assert isinstance(result, InitializeResult) assert result.protocol_version == LATEST_HANDSHAKE_VERSION assert isinstance(result.capabilities, ServerCapabilities) assert result.server_info == Implementation(name="mock-server", version="0.1.0") assert result.instructions == "The server instructions." - # Check that the client sent the initialized notification assert initialized_notification assert isinstance(initialized_notification, InitializedNotification) @@ -207,7 +203,6 @@ async def mock_server(): tg.start_soon(mock_server) await session.initialize() - # Assert that the custom client info was sent assert received_client_info == custom_client_info @@ -260,13 +255,11 @@ async def mock_server(): tg.start_soon(mock_server) await session.initialize() - # Assert that the default client info was sent assert received_client_info == DEFAULT_CLIENT_INFO @pytest.mark.anyio async def test_client_session_version_negotiation_success(): - """Test successful version negotiation with supported version""" client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) result = None @@ -280,10 +273,8 @@ async def mock_server(): ) assert isinstance(request, InitializeRequest) - # Verify client offers the newest handshake protocol version assert request.params.protocol_version == LATEST_HANDSHAKE_VERSION - # Server responds with a supported older version result = InitializeResult( protocol_version="2024-11-05", capabilities=ServerCapabilities(), @@ -314,7 +305,6 @@ async def mock_server(): tg.start_soon(mock_server) result = await session.initialize() - # Assert the result with negotiated version assert isinstance(result, InitializeResult) assert result.protocol_version == "2024-11-05" assert result.protocol_version in HANDSHAKE_PROTOCOL_VERSIONS @@ -322,7 +312,6 @@ async def mock_server(): @pytest.mark.anyio async def test_client_session_version_negotiation_failure(): - """Test version negotiation failure with unsupported version""" client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) @@ -335,7 +324,6 @@ async def mock_server(): ) assert isinstance(request, InitializeRequest) - # Server responds with an unsupported version result = InitializeResult( protocol_version="2020-01-01", # Unsupported old version capabilities=ServerCapabilities(), @@ -363,14 +351,12 @@ async def mock_server(): ): tg.start_soon(mock_server) - # Should raise RuntimeError for unsupported version with pytest.raises(RuntimeError, match="Unsupported protocol version"): await session.initialize() @pytest.mark.anyio async def test_client_capabilities_default(): - """Test that client capabilities are properly set with default callbacks""" client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) @@ -418,7 +404,6 @@ async def mock_server(): tg.start_soon(mock_server) await session.initialize() - # Assert that capabilities are properly set with defaults assert received_capabilities is not None assert received_capabilities.sampling is None # No custom sampling callback assert received_capabilities.roots is None # No custom list_roots callback @@ -426,7 +411,6 @@ async def mock_server(): @pytest.mark.anyio async def test_client_capabilities_with_custom_callbacks(): - """Test that client capabilities are properly set with custom callbacks""" client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) @@ -494,23 +478,17 @@ async def mock_server(): tg.start_soon(mock_server) await session.initialize() - # Assert that capabilities are properly set with custom callbacks assert received_capabilities is not None - # Custom sampling callback provided assert received_capabilities.sampling is not None assert isinstance(received_capabilities.sampling, types.SamplingCapability) - # Default sampling capabilities (no tools) assert received_capabilities.sampling.tools is None - # Custom list_roots callback provided assert received_capabilities.roots is not None assert isinstance(received_capabilities.roots, types.RootsCapability) - # Should be True for custom callback assert received_capabilities.roots.list_changed is True @pytest.mark.anyio async def test_client_capabilities_with_sampling_tools(): - """Test that sampling capabilities with tools are properly advertised""" client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) @@ -573,11 +551,9 @@ async def mock_server(): tg.start_soon(mock_server) await session.initialize() - # Assert that sampling capabilities with tools are properly advertised assert received_capabilities is not None assert received_capabilities.sampling is not None assert isinstance(received_capabilities.sampling, types.SamplingCapability) - # Tools capability should be present assert received_capabilities.sampling.tools is not None assert isinstance(received_capabilities.sampling.tools, types.SamplingToolsCapability) @@ -657,14 +633,12 @@ async def mock_server(): @pytest.mark.anyio @pytest.mark.parametrize(argnames="meta", argvalues=[None, {"toolMeta": "value"}]) async def test_client_tool_call_with_meta(meta: RequestParamsMeta | None): - """Test that client tool call requests can include metadata""" client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) mocked_tool = types.Tool(name="sample_tool", input_schema={"type": "object"}) async def mock_server(): - # Receive initialization request from client session_message = await client_to_server_receive.receive() jsonrpc_request = session_message.message assert isinstance(jsonrpc_request, JSONRPCRequest) @@ -679,7 +653,6 @@ async def mock_server(): server_info=Implementation(name="mock-server", version="0.1.0"), ) - # Answer initialization request await server_to_client_send.send( SessionMessage( JSONRPCResponse( @@ -693,7 +666,6 @@ async def mock_server(): # Receive initialized notification await client_to_server_receive.receive() - # Wait for the client to send a 'tools/call' request session_message = await client_to_server_receive.receive() jsonrpc_request = session_message.message assert isinstance(jsonrpc_request, JSONRPCRequest) @@ -707,7 +679,6 @@ async def mock_server(): result = CallToolResult(content=[TextContent(type="text", text="Called successfully")], is_error=False) - # Send the tools/call result await server_to_client_send.send( SessionMessage( JSONRPCResponse( @@ -718,8 +689,7 @@ async def mock_server(): ) ) - # Wait for the tools/list request from the client - # The client requires this step to validate the tool output schema + # The client follows up with tools/list to validate the tool output schema. session_message = await client_to_server_receive.receive() jsonrpc_request = session_message.message assert isinstance(jsonrpc_request, JSONRPCRequest) @@ -757,7 +727,6 @@ async def mock_server(): @pytest.mark.anyio async def test_receive_loop_answers_malformed_inbound_request_with_invalid_params(): - """A request that fails ServerRequest validation gets an INVALID_PARAMS error response.""" async with raw_client_session() as (_session, to_client, from_client): await to_client.send( SessionMessage(JSONRPCRequest(jsonrpc="2.0", id=7, method="sampling/createMessage", params={"broken": 1})) @@ -804,8 +773,7 @@ def _set_negotiated_version(session: ClientSession, version: str) -> None: @pytest.mark.anyio async def test_on_request_rejects_a_server_request_absent_at_the_negotiated_version(): - """`elicitation/create` does not exist at 2025-03-26: the version gate answers - METHOD_NOT_FOUND instead of reaching the elicitation callback.""" + """`elicitation/create` does not exist at 2025-03-26, so the version gate answers METHOD_NOT_FOUND.""" async with raw_client_session() as (session, to_client, from_client): _set_negotiated_version(session, "2025-03-26") await to_client.send( @@ -843,9 +811,7 @@ async def sampling( async def test_on_request_callback_returning_a_surface_invalid_result_is_internal_error( caplog: pytest.LogCaptureFixture, ): - """A callback result the surface schema rejects is answered with INTERNAL_ERROR. - `EmptyResult` is a `ClientResult` arm so the union accepts it, but `roots/list` - requires a `roots` array.""" + """`EmptyResult` is a `ClientResult` arm so the union accepts it, but `roots/list` requires a `roots` array.""" async def list_roots(ctx: ClientRequestContext) -> types.ListRootsResult | types.ErrorData: return cast("types.ListRootsResult", types.EmptyResult()) @@ -863,8 +829,7 @@ async def list_roots(ctx: ClientRequestContext) -> types.ListRootsResult | types async def test_on_notify_drops_a_server_notification_absent_at_the_negotiated_version( caplog: pytest.LogCaptureFixture, ): - """`notifications/elicitation/complete` does not exist at 2025-06-18: it is - debug-log-dropped without reaching `message_handler`.""" + """`notifications/elicitation/complete` does not exist at 2025-06-18: dropped before `message_handler`.""" seen: list[object] = [] delivered = anyio.Event() @@ -939,8 +904,7 @@ async def call() -> None: @pytest.mark.anyio async def test_send_request_skips_the_surface_gate_when_method_absent_at_version(): - """Surface row absent for the negotiated version: gate is bypassed and only - `result_type` validates.""" + """Surface row absent at the negotiated version: gate bypassed, only `result_type` validates.""" async with raw_client_session() as (session, to_client, from_client): _set_negotiated_version(session, "2026-07-28") async with anyio.create_task_group() as tg: @@ -1004,9 +968,8 @@ async def handler(msg: object) -> None: async def test_raising_message_handler_on_transport_exception_costs_the_delivery_not_the_connection( caplog: pytest.LogCaptureFixture, ): - """A `message_handler` that raises on a transport-level `Exception` item is contained: the - failure is logged and the receive loop keeps serving (SDK-defined). Raw streams because - only a transport can put an `Exception` item on the read stream.""" + """SDK-defined containment: the failure is logged and the receive loop keeps serving. Raw streams + because only a transport can put an `Exception` item on the read stream.""" seen: list[object] = [] delivered = anyio.Event() @@ -1030,8 +993,7 @@ async def handler(msg: object) -> None: @pytest.mark.anyio async def test_message_handler_awaiting_session_traffic_on_transport_exception_completes(): - """A `message_handler` that awaits session traffic on a transport `Exception` item completes: - fault deliveries are spawned into the task group, not run inline in the read loop (SDK-defined). + """SDK-defined: fault deliveries are spawned into the task group, not run inline in the read loop. Raw streams because only a transport can put an `Exception` item on the read stream.""" ponged = anyio.Event() @@ -1053,13 +1015,9 @@ async def handler(msg: object) -> None: @pytest.mark.anyio async def test_receive_loop_consumes_server_cancelled_without_reaching_message_handler(): - """A server-sent notifications/cancelled is swallowed, matching the pre-swap contract. - - The server dispatcher now emits this on sampling/elicitation timeout, but - ClientSession has no in-flight tracking to act on it, so surfacing it would - only break user handlers that exhaustively match ServerNotification. - Scripted peer: the typed server API cannot emit a bare `notifications/cancelled`. - """ + """The server dispatcher emits this on sampling/elicitation timeout, but ClientSession has no + in-flight tracking to act on it, and surfacing it would break handlers that exhaustively match + ServerNotification. Scripted peer: the typed server API cannot emit it bare.""" seen: list[object] = [] delivered = anyio.Event() @@ -1075,8 +1033,7 @@ async def handler(msg: object) -> None: ) ) ) - # Follow with a notification that does reach the handler so we can - # assert ordering deterministically. + # A follow-up that does reach the handler makes the ordering assertion deterministic. await to_client.send( SessionMessage(JSONRPCNotification(jsonrpc="2.0", method="notifications/tools/list_changed")) ) @@ -1087,10 +1044,8 @@ async def handler(msg: object) -> None: @pytest.mark.anyio async def test_request_timeout_zero_overrides_session_timeout(): - """`request_read_timeout_seconds=0` is a real per-request timeout (fail at the - first checkpoint, `anyio.fail_after(0)` semantics), not a fall-through to the - session-level timeout. The request is never answered, so falling back to the - 30s session timeout would trip the harness's 5s guard instead.""" + """Zero means `anyio.fail_after(0)` semantics, not a fall-through to the session timeout — the + request is never answered, so falling back to the 30s session timeout would trip the harness's 5s guard.""" async with raw_client_session(read_timeout_seconds=30) as (session, _to_client, _from_client): with pytest.raises(MCPError) as exc_info: await session.send_request(types.PingRequest(), types.EmptyResult, request_read_timeout_seconds=0.0) @@ -1099,8 +1054,7 @@ async def test_request_timeout_zero_overrides_session_timeout(): @pytest.mark.anyio async def test_progress_notification_reaches_request_callback_and_message_handler(): - """A `notifications/progress` for an in-flight request reaches both the `progress_callback` and - `message_handler` (SDK-defined). Scripted peer: the progress token must echo the wire request id.""" + """Scripted peer: the progress token must echo the wire request id.""" updates: list[tuple[float, float | None, str | None]] = [] teed: list[types.ProgressNotification] = [] request_id: types.RequestId | None = None @@ -1178,9 +1132,7 @@ async def server_on_notify( @pytest.mark.anyio async def test_direct_dispatch_roots_list_reaches_callback_with_synthesized_request_id(): - """A server-initiated roots/list over dispatcher= reaches the registered callback and round-trips - the result; the callback context carries an int request_id (SDK-defined: DirectDispatcher - synthesizes ids).""" + """SDK-defined: DirectDispatcher synthesizes int request ids for server-initiated requests.""" client_side, server_side = create_direct_dispatcher_pair() contexts: list[ClientRequestContext] = [] @@ -1215,10 +1167,8 @@ async def server_on_notify( async def test_raising_notification_callbacks_over_direct_dispatch_cost_only_that_delivery( caplog: pytest.LogCaptureFixture, ): - """A raising `logging_callback` or `message_handler` is contained in the session, so the - in-process peer's notify() returns normally and the session keeps serving requests - (SDK-defined: DirectDispatcher awaits notification handlers inline in the peer's call). - A raising `logging_callback` skips the `message_handler` tee for that notification.""" + """DirectDispatcher awaits notification handlers inline in the peer's notify(), so containment is + what lets notify() return. A raising `logging_callback` skips the `message_handler` tee.""" client_side, server_side = create_direct_dispatcher_pair() teed: list[types.ServerNotification] = [] @@ -1252,7 +1202,6 @@ async def server_on_notify( await server_side.notify("notifications/message", {"level": "info", "data": "hello"}) # message_handler raises: notify() must return. await server_side.notify("notifications/tools/list_changed", None) - # The session still serves requests afterwards. assert await session.send_ping() == types.EmptyResult() server_side.close() assert [type(n) for n in teed] == [types.ToolListChangedNotification] @@ -1263,7 +1212,6 @@ async def server_on_notify( @pytest.mark.anyio async def test_dispatcher_keyword_send_request_before_enter_raises_runtimeerror(): - """The documented pre-enter RuntimeError holds for dispatcher= sessions too.""" client_side, _server_side = create_direct_dispatcher_pair() session = ClientSession(dispatcher=client_side) with anyio.fail_after(5), pytest.raises(RuntimeError) as exc: @@ -1273,7 +1221,6 @@ async def test_dispatcher_keyword_send_request_before_enter_raises_runtimeerror( @pytest.mark.anyio async def test_dispatcher_keyword_send_request_after_exit_raises_connection_closed(): - """After __aexit__ a dispatcher= session raises MCPError(CONNECTION_CLOSED), matching the JSONRPC path.""" client_side, server_side = create_direct_dispatcher_pair() async def server_on_request( @@ -1301,7 +1248,6 @@ async def server_on_notify( @pytest.mark.anyio async def test_dispatcher_keyword_request_timeout_bounds_wait_for_never_run_peer(): - """request_read_timeout_seconds fires even when the peer dispatcher never started running.""" client_side, _server_side = create_direct_dispatcher_pair() session = ClientSession(dispatcher=client_side) with anyio.fail_after(5): @@ -1312,8 +1258,7 @@ async def test_dispatcher_keyword_request_timeout_bounds_wait_for_never_run_peer def test_adopt_raises_when_no_mutual_modern_version_is_supported() -> None: - """SDK-defined: ``adopt(DiscoverResult)`` picks the newest version both sides support; an - empty intersection is unrecoverable and raises rather than installing a stamp.""" + """SDK-defined: `adopt` picks the newest mutual version; an empty intersection raises, no stamp installed.""" client_d, _ = create_direct_dispatcher_pair() session = ClientSession(dispatcher=client_d) with pytest.raises(RuntimeError, match="No mutually supported modern protocol version"): @@ -1336,8 +1281,6 @@ async def test_initialize_opts_out_of_cancel_on_abandon_while_other_requests_lea cancelling it — and leaves the option unset for every other method.""" class RecordingDispatcher: - """Records `send_raw_request` opts and answers with canned results.""" - def __init__(self) -> None: self.calls: list[tuple[str, CallOptions]] = [] @@ -1397,8 +1340,7 @@ def test_constructor_requires_both_streams_without_dispatcher(): @pytest.mark.anyio async def test_aenter_cancelled_while_dispatcher_starts_unwinds_cleanly(): - """Cancellation while `__aenter__` waits for the dispatcher to start unwinds the half-entered - task group cleanly, not via anyio's "exited non-innermost cancel scope" RuntimeError (SDK-defined).""" + """Must not die via anyio's "exited non-innermost cancel scope" RuntimeError on the half-entered task group.""" class NeverStartsDispatcher: """`run()` parks without ever signalling `task_status.started()`.""" @@ -1448,13 +1390,8 @@ async def test_send_notification_after_close_is_dropped_silently(): s.close() -# --- discover() ladder --- - - class _ScriptedDispatcher: - """Records every `send_raw_request` and plays back scripted answers in order. - - A script entry that is an `Exception` is raised; a dict is returned.""" + """Plays back scripted answers in order: an `Exception` entry is raised, a dict entry is returned.""" def __init__(self, *script: dict[str, Any] | Exception) -> None: self.calls: list[tuple[str, Mapping[str, Any] | None]] = [] @@ -1494,8 +1431,6 @@ def _discover_result_dict() -> dict[str, Any]: @pytest.mark.anyio async def test_initialize_is_idempotent_and_returns_the_cached_result() -> None: - """A second `initialize()` returns the first call's result by identity and sends nothing - over the wire — the early-return guard short-circuits before the dispatcher is touched.""" init_result = InitializeResult( protocol_version=LATEST_HANDSHAKE_VERSION, capabilities=ServerCapabilities(), @@ -1513,8 +1448,7 @@ async def test_initialize_is_idempotent_and_returns_the_cached_result() -> None: @pytest.mark.anyio async def test_discover_adopts_the_returned_result_and_installs_the_modern_stamp() -> None: - """SDK-defined: a successful `server/discover` is adopted and subsequent requests - carry the modern `_meta` envelope (protocol version + client info + capabilities).""" + """The "modern stamp" is the `_meta` envelope (protocol version + client info + capabilities).""" dispatcher = _ScriptedDispatcher(_discover_result_dict(), {}) with anyio.fail_after(5): async with ClientSession(dispatcher=dispatcher) as session: @@ -1549,9 +1483,8 @@ async def test_discover_retries_once_on_unsupported_version_then_adopts() -> Non @pytest.mark.anyio async def test_discover_raises_when_retry_intersection_is_empty() -> None: - """Spec SHOULD: a -32022 reply whose `supported` list shares nothing with the - client's modern versions is unrecoverable — the original error is re-raised - without a second probe.""" + """Spec SHOULD: a -32022 reply whose `supported` list shares nothing with the client's versions + re-raises the original error without a second probe.""" dispatcher = _ScriptedDispatcher( MCPError( UNSUPPORTED_PROTOCOL_VERSION, @@ -1570,9 +1503,8 @@ async def test_discover_raises_when_retry_intersection_is_empty() -> None: @pytest.mark.anyio @pytest.mark.parametrize("code", [METHOD_NOT_FOUND, REQUEST_TIMEOUT, INTERNAL_ERROR]) async def test_discover_reraises_non_retry_errors_without_falling_back(code: int) -> None: - """SDK-defined: any error outside the -32022 retry rung propagates verbatim - — `discover()` does not fall back to `initialize()` itself; that is the - caller's policy (`Client.__aenter__`).""" + """SDK-defined: `discover()` does not fall back to `initialize()` itself — that is the caller's + policy (`Client.__aenter__`).""" dispatcher = _ScriptedDispatcher(MCPError(code, "nope")) with anyio.fail_after(5): async with ClientSession(dispatcher=dispatcher) as session: @@ -1586,9 +1518,6 @@ async def test_discover_reraises_non_retry_errors_without_falling_back(code: int @pytest.mark.anyio async def test_discover_validates_the_response_shape_before_adopting() -> None: - """SDK-defined: the raw response is run through `DiscoverResult` validation - before any state is installed, so a malformed reply leaves the session - un-adopted rather than half-configured.""" dispatcher = _ScriptedDispatcher({"supportedVersions": ["2026-07-28"]}) session = ClientSession(dispatcher=dispatcher) with anyio.fail_after(5): @@ -1600,9 +1529,7 @@ async def test_discover_validates_the_response_shape_before_adopting() -> None: @pytest.mark.anyio async def test_discover_is_idempotent_and_returns_the_cached_result() -> None: - """SDK-defined: a second `discover()` returns the already-adopted result without - re-probing — the script holds exactly one entry, so a second wire call would - `IndexError` on the empty script.""" + """The script holds exactly one entry, so a second wire call would `IndexError`.""" dispatcher = _ScriptedDispatcher(_discover_result_dict()) with anyio.fail_after(5): async with ClientSession(dispatcher=dispatcher) as session: @@ -1614,7 +1541,6 @@ async def test_discover_is_idempotent_and_returns_the_cached_result() -> None: def test_era_neutral_properties_are_none_before_any_handshake() -> None: - """SDK-defined: the era-neutral accessors all read as None on a fresh session.""" client_d, _ = create_direct_dispatcher_pair() session = ClientSession(dispatcher=client_d) assert session.protocol_version is None @@ -1627,8 +1553,6 @@ def test_era_neutral_properties_are_none_before_any_handshake() -> None: @pytest.mark.anyio async def test_era_neutral_properties_after_discover() -> None: - """SDK-defined: after `discover()` the era-neutral accessors read from the - DiscoverResult; `initialize_result` stays None.""" raw = types.DiscoverResult( supported_versions=["2026-07-28"], capabilities=ServerCapabilities(tools=types.ToolsCapability(list_changed=True)), @@ -1649,9 +1573,6 @@ async def test_era_neutral_properties_after_discover() -> None: @pytest.mark.anyio async def test_discover_reraises_unsupported_version_with_malformed_error_data() -> None: - """SDK-defined: a -32022 reply whose `data` is not a valid - `UnsupportedProtocolVersionErrorData` payload is unrecoverable — the original - error is re-raised without a retry probe.""" dispatcher = _ScriptedDispatcher(MCPError(UNSUPPORTED_PROTOCOL_VERSION, "unsupported", data="not-an-object")) with anyio.fail_after(5): async with ClientSession(dispatcher=dispatcher) as session: @@ -1663,9 +1584,6 @@ async def test_discover_reraises_unsupported_version_with_malformed_error_data() @pytest.mark.anyio async def test_session_call_tool_returns_input_required_result_when_opted_in() -> None: - """`ClientSession.call_tool(..., allow_input_required=True)` surfaces the - raw `InputRequiredResult` so the caller can drive the loop manually.""" - # `on_call_tool` is still typed `-> CallToolResult` on this branch (#2967 widens it later); # `add_request_handler` is `HandlerResult`-typed and accepts `InputRequiredResult` cleanly. async def handler(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.InputRequiredResult: @@ -1707,9 +1625,7 @@ async def on_list_tools( @pytest.mark.anyio async def test_session_call_tool_raises_on_input_required_without_opt_in() -> None: - """SDK-defined: `ClientSession.call_tool` is mechanics-only; an - `InputRequiredResult` with the default `allow_input_required=False` raises - `RuntimeError` (the auto-loop policy lives on `Client`, not here).""" + """`ClientSession.call_tool` is mechanics-only: the auto-loop policy lives on `Client`, not here.""" async def handler(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.InputRequiredResult: return types.InputRequiredResult(request_state="s") @@ -1726,9 +1642,6 @@ async def handler(ctx: ServerRequestContext, params: types.CallToolRequestParams @pytest.mark.anyio async def test_session_get_prompt_returns_input_required_result_when_opted_in() -> None: - """`ClientSession.get_prompt` mirrors `call_tool`: opting in returns the - raw `InputRequiredResult`; the default raises `RuntimeError`.""" - async def handler(ctx: ServerRequestContext, params: types.GetPromptRequestParams) -> types.InputRequiredResult: return types.InputRequiredResult(request_state="prompt-state") @@ -1745,9 +1658,6 @@ async def handler(ctx: ServerRequestContext, params: types.GetPromptRequestParam @pytest.mark.anyio async def test_session_read_resource_returns_input_required_result_when_opted_in() -> None: - """`ClientSession.read_resource` mirrors `call_tool`: opting in returns the - raw `InputRequiredResult`; the default raises `RuntimeError`.""" - async def handler(ctx: ServerRequestContext, params: types.ReadResourceRequestParams) -> types.InputRequiredResult: return types.InputRequiredResult(request_state="resource-state") diff --git a/tests/client/test_session_concurrency.py b/tests/client/test_session_concurrency.py index 0a0ae62dde..123ceb5e89 100644 --- a/tests/client/test_session_concurrency.py +++ b/tests/client/test_session_concurrency.py @@ -19,14 +19,8 @@ async def test_concurrent_tool_calls_resolve_out_of_order_to_their_own_callers() -> None: - """Three tool calls in flight at once on one session each receive their own result, even though - the responses come back in the reverse of the order the requests were sent. - - SDK-defined contract: pins the client request machinery's support for concurrent in-flight - calls with out-of-order response correlation. Each handler parks on its own release event - after signalling it started; a session that serialized requests would never start the later - handlers and the test would time out instead. - """ + """Pins concurrent in-flight calls with out-of-order response correlation; a serializing session + would never start the later handlers, so the test would time out instead.""" send_order = ["a", "b", "c"] started = {tag: anyio.Event() for tag in send_order} release = {tag: anyio.Event() for tag in send_order} @@ -51,8 +45,7 @@ async def call_and_record(tag: str) -> None: with anyio.fail_after(5): async with anyio.create_task_group() as task_group: # pragma: no branch - # Waiting for each handler to start before issuing the next call fixes the send - # order, and leaves all three parked in flight together once the loop finishes. + # Awaiting each handler's start fixes the send order and leaves all three parked in flight. for tag in send_order: task_group.start_soon(call_and_record, tag) await started[tag].wait() @@ -60,7 +53,7 @@ async def call_and_record(tag: str) -> None: # Nothing completed yet: all three calls are genuinely concurrent. assert completion_order == [] - # Release in reverse, awaiting each completion so the finish order is forced. + # Awaiting each completion forces the finish order. for tag in reversed(send_order): release[tag].set() await done[tag].wait() @@ -76,14 +69,9 @@ async def call_and_record(tag: str) -> None: async def test_overlapping_sampling_requests_are_serviced_concurrently_by_the_client() -> None: - """A server tool that fans out two sampling requests at once gets both echoes back: the client - runs overlapping inbound `create_message` requests concurrently instead of serializing them in - its receive loop. - - Regression pin for https://github.com/modelcontextprotocol/python-sdk/issues/2489 -- v1's - `BaseSession` awaited each inbound request handler inline, so the second sampling callback - could not start until the first returned; here both rendezvous before either is released. - """ + """Regression pin for https://github.com/modelcontextprotocol/python-sdk/issues/2489: v1's `BaseSession` + awaited each inbound request handler inline, so the second sampling callback could not + start until the first returned.""" sampling_started = {"x": anyio.Event(), "y": anyio.Event()} sampling_release = anyio.Event() tool_results: list[CallToolResult] = [] @@ -130,8 +118,7 @@ async def invoke_fan_out() -> None: task_group.start_soon(invoke_fan_out) - # Both sampling callbacks are mid-flight before either may answer -- a client that - # serialized inbound requests would never start the second one. + # Both callbacks are mid-flight before either answers; a serializing client never starts the second. await sampling_started["x"].wait() await sampling_started["y"].wait() sampling_release.set() diff --git a/tests/client/test_session_group.py b/tests/client/test_session_group.py index dae0766168..469061f1ba 100644 --- a/tests/client/test_session_group.py +++ b/tests/client/test_session_group.py @@ -18,10 +18,6 @@ @pytest.fixture def mock_exit_stack(): - """Fixture for a mocked AsyncExitStack.""" - # Use unittest.mock.Mock directly if needed, or just a plain object - # if only attribute access/existence is needed. - # For AsyncExitStack, Mock or MagicMock is usually fine. return mock.MagicMock(spec=contextlib.AsyncExitStack) @@ -34,18 +30,15 @@ def test_client_session_group_init(): def test_client_session_group_component_properties(): - # --- Mock Dependencies --- mock_prompt = mock.Mock() mock_resource = mock.Mock() mock_tool = mock.Mock() - # --- Prepare Session Group --- mcp_session_group = ClientSessionGroup() mcp_session_group._prompts = {"my_prompt": mock_prompt} mcp_session_group._resources = {"my_resource": mock_resource} mcp_session_group._tools = {"my_tool": mock_tool} - # --- Assertions --- assert mcp_session_group.prompts == {"my_prompt": mock_prompt} assert mcp_session_group.resources == {"my_resource": mock_resource} assert mcp_session_group.tools == {"my_tool": mock_tool} @@ -53,10 +46,8 @@ def test_client_session_group_component_properties(): @pytest.mark.anyio async def test_client_session_group_call_tool(): - # --- Mock Dependencies --- mock_session = mock.AsyncMock() - # --- Prepare Session Group --- def hook(name: str, server_info: types.Implementation) -> str: # pragma: no cover return f"{(server_info.name)}-{name}" @@ -66,7 +57,6 @@ def hook(name: str, server_info: types.Implementation) -> str: # pragma: no cov text_content = types.TextContent(type="text", text="OK") mock_session.call_tool.return_value = types.CallToolResult(content=[text_content]) - # --- Test Execution --- result = await mcp_session_group.call_tool( name="server1-my_tool", arguments={ @@ -75,7 +65,6 @@ def hook(name: str, server_info: types.Implementation) -> str: # pragma: no cov }, ) - # --- Assertions --- assert result.content == [text_content] mock_session.call_tool.assert_called_once_with( "my_tool", @@ -105,8 +94,6 @@ async def test_client_session_group_call_tool_forwards_allow_input_required(): @pytest.mark.anyio async def test_client_session_group_connect_to_server(mock_exit_stack: contextlib.AsyncExitStack): - """Test connecting to a server and aggregating components.""" - # --- Mock Dependencies --- mock_server_info = mock.Mock(spec=types.Implementation) mock_server_info.name = "TestServer1" mock_session = mock.AsyncMock(spec=mcp.ClientSession) @@ -120,12 +107,10 @@ async def test_client_session_group_connect_to_server(mock_exit_stack: contextli mock_session.list_resources.return_value = mock.AsyncMock(resources=[mock_resource1]) mock_session.list_prompts.return_value = mock.AsyncMock(prompts=[mock_prompt1]) - # --- Test Execution --- group = ClientSessionGroup(exit_stack=mock_exit_stack) with mock.patch.object(group, "_establish_session", return_value=(mock_server_info, mock_session)): await group.connect_to_server(StdioServerParameters(command="test")) - # --- Assertions --- assert mock_session in group._sessions assert len(group.tools) == 1 assert "tool_a" in group.tools @@ -144,8 +129,6 @@ async def test_client_session_group_connect_to_server(mock_exit_stack: contextli @pytest.mark.anyio async def test_client_session_group_connect_to_server_with_name_hook(mock_exit_stack: contextlib.AsyncExitStack): - """Test connecting with a component name hook.""" - # --- Mock Dependencies --- mock_server_info = mock.Mock(spec=types.Implementation) mock_server_info.name = "HookServer" mock_session = mock.AsyncMock(spec=mcp.ClientSession) @@ -155,16 +138,13 @@ async def test_client_session_group_connect_to_server_with_name_hook(mock_exit_s mock_session.list_resources.return_value = mock.AsyncMock(resources=[]) mock_session.list_prompts.return_value = mock.AsyncMock(prompts=[]) - # --- Test Setup --- def name_hook(name: str, server_info: types.Implementation) -> str: return f"{server_info.name}.{name}" - # --- Test Execution --- group = ClientSessionGroup(exit_stack=mock_exit_stack, component_name_hook=name_hook) with mock.patch.object(group, "_establish_session", return_value=(mock_server_info, mock_session)): await group.connect_to_server(StdioServerParameters(command="test")) - # --- Assertions --- assert mock_session in group._sessions assert len(group.tools) == 1 expected_tool_name = "HookServer.base_tool" @@ -175,12 +155,9 @@ def name_hook(name: str, server_info: types.Implementation) -> str: @pytest.mark.anyio async def test_client_session_group_disconnect_from_server(): - """Test disconnecting from a server.""" - # --- Test Setup --- group = ClientSessionGroup() server_name = "ServerToDisconnect" - # Manually populate state using standard mocks mock_session1 = mock.MagicMock(spec=mcp.ClientSession) mock_session2 = mock.MagicMock(spec=mcp.ClientSession) mock_tool1 = mock.Mock(spec=types.Tool) @@ -220,17 +197,14 @@ async def test_client_session_group_disconnect_from_server(): ) } - # --- Assertions --- assert mock_session in group._sessions assert "tool1" in group._tools assert "tool2" in group._tools assert "res1" in group._resources assert "prm1" in group._prompts - # --- Test Execution --- await group.disconnect_from_server(mock_session) - # --- Assertions --- assert mock_session not in group._sessions assert "tool1" not in group._tools assert "tool2" not in group._tools @@ -242,32 +216,24 @@ async def test_client_session_group_disconnect_from_server(): async def test_client_session_group_connect_to_server_duplicate_tool_raises_error( mock_exit_stack: contextlib.AsyncExitStack, ): - """Test MCPError raised when connecting a server with a dup name.""" - # --- Setup Pre-existing State --- group = ClientSessionGroup(exit_stack=mock_exit_stack) existing_tool_name = "shared_tool" - # Manually add a tool to simulate a previous connection group._tools[existing_tool_name] = mock.Mock(spec=types.Tool) group._tools[existing_tool_name].name = existing_tool_name - # Need a dummy session associated with the existing tool mock_session = mock.MagicMock(spec=mcp.ClientSession) group._tool_to_session[existing_tool_name] = mock_session group._session_exit_stacks[mock_session] = mock.Mock(spec=contextlib.AsyncExitStack) - # --- Mock New Connection Attempt --- mock_server_info_new = mock.Mock(spec=types.Implementation) mock_server_info_new.name = "ServerWithDuplicate" mock_session_new = mock.AsyncMock(spec=mcp.ClientSession) - # Configure the new session to return a tool with the *same name* duplicate_tool = mock.Mock(spec=types.Tool) duplicate_tool.name = existing_tool_name mock_session_new.list_tools.return_value = mock.AsyncMock(tools=[duplicate_tool]) - # Keep other lists empty for simplicity mock_session_new.list_resources.return_value = mock.AsyncMock(resources=[]) mock_session_new.list_prompts.return_value = mock.AsyncMock(prompts=[]) - # --- Test Execution and Assertion --- with pytest.raises(MCPError) as excinfo: with mock.patch.object( group, @@ -276,19 +242,17 @@ async def test_client_session_group_connect_to_server_duplicate_tool_raises_erro ): await group.connect_to_server(StdioServerParameters(command="test")) - # Assert details about the raised error assert excinfo.value.error.code == types.INVALID_PARAMS assert existing_tool_name in excinfo.value.error.message assert "already exist " in excinfo.value.error.message - # Verify the duplicate tool was *not* added again (state should be unchanged) - assert len(group._tools) == 1 # Should still only have the original - assert group._tools[existing_tool_name] is not duplicate_tool # Ensure it's the original mock + # Failed connect must leave the pre-existing tool state untouched + assert len(group._tools) == 1 + assert group._tools[existing_tool_name] is not duplicate_tool @pytest.mark.anyio async def test_client_session_group_disconnect_non_existent_server(): - """Test disconnecting a server that isn't connected.""" session = mock.Mock(spec=mcp.ClientSession) group = ClientSessionGroup() with pytest.raises(MCPError): @@ -309,17 +273,17 @@ async def test_client_session_group_disconnect_non_existent_server(): SseServerParameters(url="http://test.com/sse", timeout=10.0), "sse", "mcp.client.session_group.sse_client", - ), # url, headers, timeout, sse_read_timeout + ), ( StreamableHttpParameters(url="http://test.com/stream", terminate_on_close=False), "streamablehttp", "mcp.client.session_group.streamable_http_client", - ), # url, headers, timeout, sse_read_timeout, terminate_on_close + ), ], ) async def test_client_session_group_establish_session_parameterized( server_params_instance: StdioServerParameters | SseServerParameters | StreamableHttpParameters, - client_type_name: str, # Just for clarity or conditional logic if needed + client_type_name: str, patch_target_for_client_func: str, ): with mock.patch("mcp.client.session_group.mcp.ClientSession") as mock_ClientSession_class: @@ -328,14 +292,11 @@ async def test_client_session_group_establish_session_parameterized( mock_read_stream = mock.AsyncMock(name=f"{client_type_name}Read") mock_write_stream = mock.AsyncMock(name=f"{client_type_name}Write") - # All client context managers return (read_stream, write_stream) mock_client_cm_instance.__aenter__.return_value = (mock_read_stream, mock_write_stream) mock_client_cm_instance.__aexit__ = mock.AsyncMock(return_value=None) mock_specific_client_func.return_value = mock_client_cm_instance - # --- Mock mcp.ClientSession (class) --- - # mock_ClientSession_class is already provided by the outer patch mock_raw_session_cm = mock.AsyncMock(name="RawSessionCM") mock_ClientSession_class.return_value = mock_raw_session_cm @@ -343,12 +304,10 @@ async def test_client_session_group_establish_session_parameterized( mock_raw_session_cm.__aenter__.return_value = mock_entered_session mock_raw_session_cm.__aexit__ = mock.AsyncMock(return_value=None) - # Mock session.initialize() mock_initialize_result = mock.AsyncMock(name="InitializeResult") mock_initialize_result.server_info = types.Implementation(name="foo", version="1") mock_entered_session.initialize.return_value = mock_initialize_result - # --- Test Execution --- group = ClientSessionGroup() returned_server_info = None returned_session = None @@ -360,8 +319,6 @@ async def test_client_session_group_establish_session_parameterized( returned_session, ) = await group._establish_session(server_params_instance, ClientSessionParameters()) - # --- Assertions --- - # 1. Assert the correct specific client function was called if client_type_name == "stdio": assert isinstance(server_params_instance, StdioServerParameters) mock_specific_client_func.assert_called_once_with(server_params_instance) @@ -375,8 +332,7 @@ async def test_client_session_group_establish_session_parameterized( ) elif client_type_name == "streamablehttp": # pragma: no branch assert isinstance(server_params_instance, StreamableHttpParameters) - # Verify streamable_http_client was called with url, httpx_client, and terminate_on_close - # The http_client is created by the real create_mcp_http_client + # http_client is built internally by the real create_mcp_http_client, so only its type is checked call_args = mock_specific_client_func.call_args assert call_args.kwargs["url"] == server_params_instance.url assert call_args.kwargs["terminate_on_close"] == server_params_instance.terminate_on_close @@ -384,7 +340,6 @@ async def test_client_session_group_establish_session_parameterized( mock_client_cm_instance.__aenter__.assert_awaited_once() - # 2. Assert ClientSession was called correctly mock_ClientSession_class.assert_called_once_with( mock_read_stream, mock_write_stream, @@ -399,6 +354,5 @@ async def test_client_session_group_establish_session_parameterized( mock_raw_session_cm.__aenter__.assert_awaited_once() mock_entered_session.initialize.assert_awaited_once() - # 3. Assert returned values assert returned_server_info is mock_initialize_result.server_info assert returned_session is mock_entered_session diff --git a/tests/client/test_stdio.py b/tests/client/test_stdio.py index 0b0695378b..d28868d9ee 100644 --- a/tests/client/test_stdio.py +++ b/tests/client/test_stdio.py @@ -1,10 +1,9 @@ """Tests for the stdio client transport. -Transport logic (framing, parse errors, shutdown escalation decisions) is tested in -process against a fake process injected through the spawn seam; only real OS behaviour -(process-group kill semantics, SIGKILL after an ignored SIGTERM, exec failure) uses -real subprocesses, synchronized only by kernel-level liveness sockets. The full -client<->server round trip is pinned by tests/interaction/transports/test_stdio.py. +Transport logic is tested in process against a fake process injected through the spawn +seam; only real OS behaviour (process-group kills, SIGKILL delivery, exec failure) uses +real subprocesses, synchronized by kernel-level liveness sockets. The full round trip is +pinned by tests/interaction/transports/test_stdio.py. """ import errno @@ -44,14 +43,6 @@ from mcp.shared.exceptions import MCPError from mcp.shared.message import SessionMessage -# --------------------------------------------------------------------------- -# In-process fake of the spawned server process -# --------------------------------------------------------------------------- -# -# Everything between the spawn and the OS kill is pure SDK logic, so it is tested -# against this fake by monkeypatching the spawn and terminate seams. The OS half -# is tested separately below with real processes. - class _FakeStdin: """The fake process's stdin: records what the client writes, signals closure.""" @@ -61,8 +52,7 @@ def __init__(self, process: "FakeProcess") -> None: async def send(self, data: bytes) -> None: if self._process.stdin_send_gate is not None: - # A full pipe whose reader is busy elsewhere: the write completes - # only once the test's gate opens. + # A full pipe whose reader is busy elsewhere: completes once the gate opens. await self._process.stdin_send_gate.wait() if self._process.stdin_send_blocks: # A pipe whose reader stopped reading: the write never completes. @@ -83,10 +73,8 @@ async def aclose(self) -> None: class _FakeStdout: - """The fake process's stdout: delegates to the in-memory stream. - - Optionally surfaces the abrupt-death or close-time errors a real pipe can. - """ + """The fake process's stdout: delegates to the in-memory stream, optionally + surfacing the abrupt-death or close-time errors a real pipe can.""" def __init__( self, @@ -185,9 +173,8 @@ def install_fake_process( ) -> list[FakeProcess]: """Route stdio_client's spawn and terminate seams to `process`. - Returns the list of processes the (fake) tree termination was invoked on. - `grace_period=None` keeps the production stdin-close grace (affordable only on a - virtual clock). + Returns the processes the fake tree termination was invoked on. `grace_period=None` + keeps the production stdin-close grace (affordable only on a virtual clock). """ terminated: list[FakeProcess] = [] @@ -227,11 +214,6 @@ async def _next_message(read_stream: ReadStream[SessionMessage | Exception]) -> @pytest.mark.anyio async def test_messages_split_and_packed_across_chunks_are_reframed(monkeypatch: pytest.MonkeyPatch) -> None: - """Framing survives arbitrary chunk boundaries. - - Split, packed, and CRLF-terminated messages are each delivered exactly once, and a - trailing line without a newline is not delivered. - """ ping = JSONRPCRequest(jsonrpc="2.0", id=1, method="ping") pong = JSONRPCResponse(jsonrpc="2.0", id=1, result={}) ping2 = JSONRPCRequest(jsonrpc="2.0", id=2, method="ping") @@ -253,9 +235,8 @@ async def test_messages_split_and_packed_across_chunks_are_reframed(monkeypatch: assert await _next_message(read_stream) == pong assert await _next_message(read_stream) == ping2 - # The partial trailing message is dropped at EOF, not delivered broken. - # (no branch: coverage mis-traces the exit arc of a `with` whose body - # raises inside a nested async context.) + # The partial trailing message is dropped at EOF, not delivered broken. (no + # branch: coverage mis-traces a `with` whose body raises in a nested async context.) with pytest.raises(anyio.EndOfStream): # pragma: no branch process.close_stdout() await read_stream.receive() @@ -263,11 +244,6 @@ async def test_messages_split_and_packed_across_chunks_are_reframed(monkeypatch: @pytest.mark.anyio async def test_each_outgoing_message_is_written_as_exactly_one_line(monkeypatch: pytest.MonkeyPatch) -> None: - """Client -> server framing writes one line per message. - - Every sent message reaches the server's stdin as exactly one newline-terminated - JSON document. - """ ping = JSONRPCRequest(jsonrpc="2.0", id=1, method="ping") pong = JSONRPCResponse(jsonrpc="2.0", id=1, result={}) process = FakeProcess(on_stdin_close=lambda: process.exit(0)) @@ -288,10 +264,6 @@ async def test_each_outgoing_message_is_written_as_exactly_one_line(monkeypatch: async def test_invalid_json_from_the_server_surfaces_as_an_in_stream_exception( monkeypatch: pytest.MonkeyPatch, ) -> None: - """A line failing JSON-RPC validation is delivered as an Exception on the read stream. - - The messages after it still come through. - """ ping = JSONRPCRequest(jsonrpc="2.0", id=1, method="ping") process = FakeProcess(on_stdin_close=lambda: process.exit(0)) @@ -302,7 +274,6 @@ async def test_invalid_json_from_the_server_surfaces_as_an_in_stream_exception( await process.feed(b"not json\n" + _line(ping)) error = await read_stream.receive() - # The transport surfaces parse failures as the underlying validation error. assert isinstance(error, ValueError) assert await _next_message(read_stream) == ping @@ -311,10 +282,6 @@ async def test_invalid_json_from_the_server_surfaces_as_an_in_stream_exception( async def test_a_server_that_dies_before_responding_fails_initialize_with_connection_closed( monkeypatch: pytest.MonkeyPatch, ) -> None: - """Server death (stdout EOF) is reported to the session as a closed connection. - - The in-flight initialize fails instead of hanging. - """ process = FakeProcess(on_stdin_close=lambda: process.exit(0)) process.exit(1) @@ -334,12 +301,7 @@ async def test_a_server_that_dies_before_responding_fails_initialize_with_connec @pytest.mark.anyio async def test_a_server_that_exits_on_stdin_close_is_never_terminated(monkeypatch: pytest.MonkeyPatch) -> None: - """Closing stdin (shutdown's first step) suffices for a well-behaved server. - - The escalation is never invoked. The fake's stdin also raises on close, which the - shutdown must tolerate. - """ - + """Shutdown tolerates stdin raising on close; the escalation is never invoked.""" process = FakeProcess( on_stdin_close=lambda: process.exit(0), stdin_aclose_error=anyio.ClosedResourceError(), @@ -355,25 +317,19 @@ async def test_a_server_that_exits_on_stdin_close_is_never_terminated(monkeypatc def test_escalation_fires_once_and_only_after_the_grace_period(monkeypatch: pytest.MonkeyPatch) -> None: - """A server that ignores stdin closure is terminated at the grace deadline exactly. - - The kill lands no earlier than the production `PROCESS_TERMINATION_TIMEOUT` on the - runtime clock, and by the first `returncode` poll after it. + """The kill lands no earlier than `PROCESS_TERMINATION_TIMEOUT` on the runtime clock, + and by the first `returncode` poll after it. The suite's only direct trio use: anyio's pytest plugin cannot hand the backend a - clock, so the test calls `trio.run` itself with an autojumping `MockClock`. Every - time primitive rides that one virtual clock, so the production grace elapses - instantly and the bound can be two-sided (a wall-clock upper bound flakes under - load). That virtual seconds match wall seconds is the runtime clock's contract, - deliberately not re-tested here. + clock, so the test calls `trio.run` itself with an autojumping `MockClock`. Every time + primitive rides that one virtual clock, so the production grace elapses instantly and + the bound can be two-sided (a wall-clock upper bound flakes under load); that virtual + seconds match wall seconds is the runtime clock's contract, not re-tested here. """ class ClockedFakeProcess(FakeProcess): - """Records the virtual time of each death. - - Only the (fake) tree termination calls `exit` here, so these are the - escalation timestamps. - """ + """Records the virtual time of each death; only the fake tree termination calls + `exit`, so these are the escalation timestamps.""" def __init__(self) -> None: super().__init__() @@ -406,15 +362,11 @@ async def run_client() -> float: def test_a_server_dying_in_the_final_poll_interval_is_not_escalated(monkeypatch: pytest.MonkeyPatch) -> None: - """A server exiting in the poll interval the grace deadline cuts short is not escalated. - - Such a server is dead, not hung: the timed-out grace wait must re-check `returncode` - before deciding to escalate, so this server is never terminated. - - Runs on trio's MockClock (see the escalation-bound test above). The grace is - set to end mid-interval (0.105 with 0.01 polls) and the fake dies at 0.102 - after its stdin closes, strictly between the last in-window poll (0.10) and - the deadline (0.105), so no two timers collide. + """Dead, not hung: the timed-out grace wait must re-check `returncode` before + deciding to escalate. Runs on trio's MockClock (see the escalation-bound test + above); the grace ends mid-interval (0.105 with 0.01 polls) and the fake dies at + 0.102, strictly between the last in-window poll and the deadline, so no two timers + collide. """ process = FakeProcess() terminated = install_fake_process(monkeypatch, process, grace_period=0.105) @@ -429,8 +381,8 @@ async def die_late() -> None: # The grace wait starts when stdin closes; anchor the death there. process.on_stdin_close = lambda: tg.start_soon(die_late) - # no branch: the tracer drops this nested async-with's arcs under - # trio's MockClock even though the body runs. + # no branch: the tracer drops this nested async-with's arcs under MockClock + # though the body runs. async with stdio_client(FAKE_PARAMS): # pragma: no branch pass @@ -442,18 +394,14 @@ async def die_late() -> None: @pytest.mark.anyio async def test_cancelling_the_client_still_runs_the_full_shutdown(monkeypatch: pytest.MonkeyPatch) -> None: - """Cancellation (a client timeout, app shutdown) must not skip the shutdown sequence. - - Stdin is still closed and a server ignoring it is still terminated. Without the - shielded shutdown this leaks the process and can deadlock. - """ + """Without the shielded shutdown, cancellation (a client timeout, app shutdown) + skips cleanup, leaking the process and potentially deadlocking.""" process = FakeProcess() terminated = install_fake_process(monkeypatch, process, grace_period=0.05) entered = anyio.Event() # Cancel a scope owned by the client's task, not the test's task group: a host - # self-cancel is delivered by throwing through this test function's suspended - # frames, and Python 3.11's tracer loses coverage events after such a throw() - # traversal (python/cpython#106749). + # self-cancel throws through this test's suspended frames, and Python 3.11's tracer + # loses coverage events after such a throw() traversal (python/cpython#106749). cancel_scope = anyio.CancelScope() async def run_client_until_cancelled() -> None: @@ -474,11 +422,6 @@ async def run_client_until_cancelled() -> None: @pytest.mark.anyio async def test_writing_after_the_server_dies_reports_clean_closure(monkeypatch: pytest.MonkeyPatch) -> None: - """A send racing the server's death must not surface a raw backend exception. - - The exception (ConnectionResetError in an exception group) must not escape the - context manager; the transport still shuts down cleanly. - """ ping = JSONRPCRequest(jsonrpc="2.0", id=1, method="ping") process = FakeProcess(on_stdin_close=lambda: process.exit(0)) @@ -495,11 +438,6 @@ async def test_writing_after_the_server_dies_reports_clean_closure(monkeypatch: @pytest.mark.anyio async def test_exiting_with_an_unconsumed_server_message_does_not_raise(monkeypatch: pytest.MonkeyPatch) -> None: - """Exiting while a server message is still undelivered must be a clean exit. - - Shutdown closes the read stream under the blocked reader task, and that closure - must not escape the caller as a BrokenResourceError in an exception group. - """ ping = JSONRPCRequest(jsonrpc="2.0", id=1, method="ping") process = FakeProcess(on_stdin_close=lambda: process.exit(0)) @@ -507,21 +445,18 @@ async def test_exiting_with_an_unconsumed_server_message_does_not_raise(monkeypa with anyio.fail_after(5): async with stdio_client(FAKE_PARAMS): - # Feed a message and never receive it: the reader parses it and blocks - # delivering into the zero-buffer read stream until shutdown breaks the send. + # Feed a message and never receive it: the reader parks sending into the + # zero-buffer read stream; shutdown's closing of that stream must not escape + # as a BrokenResourceError in an exception group. await process.feed(_line(ping)) - # Wait until the reader task is genuinely parked on its blocked send - # before shutdown closes the stream out from under it. + # Ensure the reader is genuinely parked before shutdown closes the stream under it. await anyio.wait_all_tasks_blocked() @pytest.mark.anyio async def test_spawn_failure_propagates_the_error_and_leaks_no_streams(monkeypatch: pytest.MonkeyPatch) -> None: - """When the spawn itself fails, the OSError reaches the caller and no streams leak. - - The transport's internal streams are all closed; an unclosed stream would fail the - test through its GC-time ResourceWarning under filterwarnings=error. - """ + """An unclosed internal stream would fail the test through its GC-time + ResourceWarning under filterwarnings=error.""" async def failing_spawn( command: str, @@ -547,7 +482,6 @@ async def failing_spawn( @pytest.mark.anyio async def test_a_command_that_cannot_be_execed_raises_enoent() -> None: - """A command that cannot be exec'd raises OSError(ENOENT) out of stdio_client.""" server_params = StdioServerParameters( command="/path/to/nonexistent/command", args=["--help"], @@ -562,11 +496,8 @@ async def test_a_command_that_cannot_be_execed_raises_enoent() -> None: @pytest.mark.anyio async def test_cancellation_during_spawn_leaks_no_streams(monkeypatch: pytest.MonkeyPatch) -> None: - """Cancellation while the spawn is still in flight must not leak the internal streams. - - A caller timeout can fire mid-spawn (interpreter cold start); an unclosed stream - would fail the test through its GC-time ResourceWarning under filterwarnings=error. - """ + """A caller timeout can fire mid-spawn (interpreter cold start); the internal streams + must still all be closed (GC-time ResourceWarnings catch leaks).""" spawn_started = anyio.Event() async def hanging_spawn( @@ -582,10 +513,8 @@ async def hanging_spawn( monkeypatch.setattr(stdio, "_create_platform_compatible_process", hanging_spawn) - # Cancel a scope owned by the client's task, not the test's task group: a host - # self-cancel is delivered by throwing through this test function's suspended - # frames, and Python 3.11's tracer loses coverage events after such a throw() - # traversal (python/cpython#106749). + # Cancel a scope owned by the client's task, not the test's task group (see + # test_cancelling_the_client_still_runs_the_full_shutdown). cancel_scope = anyio.CancelScope() async def run_client() -> None: @@ -604,12 +533,8 @@ async def run_client() -> None: @pytest.mark.anyio async def test_a_non_oserror_spawn_failure_propagates_and_leaks_no_streams(monkeypatch: pytest.MonkeyPatch) -> None: - """A non-OSError spawn failure also propagates and leaks no streams. - - Spawning can fail with more than OSError (e.g. ValueError for a NUL byte in the - command); the error reaches the caller and the transport's internal streams are - still all closed (checked through GC-time ResourceWarnings, as above). - """ + """Spawning can fail with more than OSError (e.g. ValueError for a NUL byte in the + command); the error propagates and the internal streams are still all closed.""" async def failing_spawn( command: str, @@ -631,12 +556,9 @@ async def failing_spawn( @pytest.mark.anyio async def test_a_message_sent_just_before_exit_is_flushed_to_the_server(monkeypatch: pytest.MonkeyPatch) -> None: - """A message the transport accepted must reach the server even on immediate exit. - - The caller exits right after sending. Once the writer is parked waiting, a send is - a pure handoff that returns before the write lands, so the second message here is - the one shutdown must let the writer flush before closing the server's stdin. - """ + """Once the writer is parked, a send is a pure handoff returning before the write + lands; the second message is the one shutdown must let the writer flush before + closing the server's stdin.""" ping = JSONRPCRequest(jsonrpc="2.0", id=1, method="ping") pong = JSONRPCResponse(jsonrpc="2.0", id=1, result={}) process = FakeProcess(on_stdin_close=lambda: process.exit(0)) @@ -655,17 +577,10 @@ async def test_a_message_sent_just_before_exit_is_flushed_to_the_server(monkeypa async def test_a_failed_write_to_a_live_server_closes_the_read_stream_instead_of_hanging( monkeypatch: pytest.MonkeyPatch, ) -> None: - """A failed write to a live server ends the read stream instead of hanging the session. - - When a write fails but the server is still alive (stdout never EOFs), the transport - must end the read stream so a session maps the loss to CONNECTION_CLOSED instead of - waiting forever. EIO pins that plain OSError, not just ConnectionError, is handled. - - Steps: - 1. A send fails with EIO while the server is alive; the read stream ends. - 2. Output the server produces afterwards is still drained, so it cannot wedge - on a full pipe. - """ + """When a write fails but the server is alive (stdout never EOFs), the transport must + end the read stream so the session maps the loss to CONNECTION_CLOSED instead of + waiting forever. EIO pins that plain OSError, not just ConnectionError, is handled; + later server output is still drained so it cannot wedge on a full pipe.""" ping = JSONRPCRequest(jsonrpc="2.0", id=1, method="ping") pong = JSONRPCResponse(jsonrpc="2.0", id=1, result={}) process = FakeProcess( @@ -693,11 +608,8 @@ async def test_a_failed_write_to_a_live_server_closes_the_read_stream_instead_of async def test_exit_completes_when_a_write_is_wedged_in_a_pipe_no_one_reads( monkeypatch: pytest.MonkeyPatch, ) -> None: - """Exiting stays bounded even when the writer is parked in a write that cannot complete. - - A kill-surviving descendant can hold the read end without reading; the flush window - expires and the post-shutdown cancellation unparks the writer. - """ + """A kill-surviving descendant can hold the read end without reading; the flush + window expires and the post-shutdown cancellation unparks the wedged writer.""" ping = JSONRPCRequest(jsonrpc="2.0", id=1, method="ping") process = FakeProcess(on_stdin_close=lambda: process.exit(0), stdin_send_blocks=True) terminated = install_fake_process(monkeypatch, process) @@ -718,13 +630,10 @@ async def test_exit_completes_when_a_write_is_wedged_in_a_pipe_no_one_reads( async def test_undelivered_server_output_is_drained_at_shutdown_so_the_server_can_exit( monkeypatch: pytest.MonkeyPatch, ) -> None: - """Output the caller never received is consumed during the stdin-close grace period. - - A real server flushing its remaining output on the way out would otherwise block on - a full pipe, never reach its stdin read, and be killed despite being well-behaved. - The fake ignores stdin closure (so it is ultimately terminated); the pin is that its - backlog was drained during the grace window. - """ + """A real server flushing remaining output on the way out would block on a full pipe, + never reach its stdin read, and be killed despite being well-behaved. The fake + ignores stdin closure (so it is terminated); the pin is that its backlog was drained + during the grace window.""" ping = JSONRPCRequest(jsonrpc="2.0", id=1, method="ping") pong = JSONRPCResponse(jsonrpc="2.0", id=1, result={}) process = FakeProcess() @@ -748,12 +657,9 @@ async def test_undelivered_server_output_is_drained_at_shutdown_so_the_server_ca async def test_shutdown_drains_stdout_first_so_a_wedged_writers_flush_can_complete( monkeypatch: pytest.MonkeyPatch, ) -> None: - """Shutdown unblocks the reader's drain before waiting out the writer flush. - - A server wedged writing its stdout cannot get to reading its stdin, so a client - write can sit in a full pipe; the drain is what unwedges the server and lets the - flush complete. - """ + """A server wedged writing its stdout cannot get to reading its stdin, so a client + write can sit in a full pipe; shutdown's drain must unwedge the server before the + writer flush can complete.""" ping = JSONRPCRequest(jsonrpc="2.0", id=1, method="ping") pong = JSONRPCResponse(jsonrpc="2.0", id=1, result={}) @@ -796,12 +702,8 @@ def unwedge_once_drained() -> None: async def test_cancellation_with_undelivered_backlog_still_drains_and_spares_the_server( monkeypatch: pytest.MonkeyPatch, ) -> None: - """Cancellation must not skip the shutdown drain. - - A well-behaved server that can only exit once its remaining output is consumed (a - real one blocks on a full stdout pipe) still exits within the grace period and is - never terminated. - """ + """A server that can only exit once its output is consumed (a real one blocks on a + full stdout pipe) still exits within the grace period under cancellation.""" ping = JSONRPCRequest(jsonrpc="2.0", id=1, method="ping") pong = JSONRPCResponse(jsonrpc="2.0", id=1, result={}) process = FakeProcess() @@ -844,19 +746,15 @@ async def run_client_until_cancelled() -> None: async def test_invalid_utf8_flushed_by_a_dying_server_does_not_break_shutdown( monkeypatch: pytest.MonkeyPatch, ) -> None: - """The shutdown drain consumes raw bytes. - - A server flushing non-UTF-8 output (a crash dump, say) on its way out must not - abort the drain or surface a UnicodeDecodeError out of the context manager. - """ + """A server flushing non-UTF-8 output (a crash dump, say) on its way out must not + abort the raw-bytes drain or surface a UnicodeDecodeError out of the context manager.""" ping = JSONRPCRequest(jsonrpc="2.0", id=1, method="ping") process = FakeProcess(on_stdin_close=lambda: process.exit(0)) terminated = install_fake_process(monkeypatch, process) with anyio.fail_after(5): async with stdio_client(FAKE_PARAMS): - # Park the reader delivering a message nobody receives, then queue - # bytes that are not valid UTF-8 behind it. + # Park the reader on an undelivered message, then queue invalid UTF-8 behind it. await process.feed(_line(ping)) await anyio.wait_all_tasks_blocked() await process.feed(b"\xff\xfe not utf-8\n") @@ -870,11 +768,8 @@ async def test_a_kill_racing_a_pending_stdout_read_is_swallowed_during_shutdown( monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture, ) -> None: - """A hard kill during a pending stdout read must not escape the context manager. - - The read surfaces ConnectionResetError on the proactor backend; being expected - teardown noise, it is not logged as an error either. - """ + """The read surfaces ConnectionResetError on the proactor backend; being expected + teardown noise, it must not escape the context manager or be logged as an error.""" process = FakeProcess(stdout_eof_error=ConnectionResetError("read torn down by kill")) terminated = install_fake_process(monkeypatch, process) @@ -891,11 +786,8 @@ async def test_a_mid_session_stdout_failure_is_logged_and_surfaces_as_clean_clos monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture, ) -> None: - """A mid-session stdout read failure ends the read stream cleanly and is logged. - - A failure outside shutdown surfaces no raw exception out of the context manager and - leaves an error log identifying the failure, unlike the silent shutdown case. - """ + """Unlike the silent shutdown case, a failure outside shutdown surfaces no raw + exception out of the context manager and leaves an error log identifying it.""" process = FakeProcess( on_stdin_close=lambda: process.exit(0), stdout_eof_error=ConnectionResetError("pipe failed mid-session"), @@ -905,8 +797,7 @@ async def test_a_mid_session_stdout_failure_is_logged_and_surfaces_as_clean_clos with anyio.fail_after(5): async with stdio_client(FAKE_PARAMS) as (read_stream, _): process.exit(1) - # (no branch: coverage mis-traces the exit arc of a `with` whose body - # raises inside a nested async context.) + # (no branch: coverage mis-traces a `with` whose body raises in a nested async context.) with pytest.raises(anyio.EndOfStream): # pragma: no branch await read_stream.receive() @@ -915,12 +806,8 @@ async def test_a_mid_session_stdout_failure_is_logged_and_surfaces_as_clean_clos @pytest.mark.anyio async def test_a_failing_stdout_close_still_closes_the_transport_streams(monkeypatch: pytest.MonkeyPatch) -> None: - """A close-time error on the process's stdout must not abort the rest of the shutdown. - - Such an error (a contended pipe handle on the Windows fallback) still leaves the - context exiting cleanly and the internal streams all closed (checked via GC-time - ResourceWarnings). - """ + """A close-time stdout error (a contended pipe handle on the Windows fallback) must + not abort the rest of shutdown; the internal streams are still all closed.""" process = FakeProcess( on_stdin_close=lambda: process.exit(0), stdout_aclose_error=OSError(errno.EBADF, "Bad file descriptor"), @@ -940,12 +827,8 @@ async def test_a_process_surviving_the_kill_escalation_is_logged_and_abandoned( monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture, ) -> None: - """A process surviving the whole kill escalation is logged and abandoned. - - If the process is still alive after the escalation (D-state, an unsignalable - survivor), shutdown still completes, bounded, and leaves a warning instead of - silently leaking a live process. - """ + """A survivor (D-state, unsignalable) must not hang shutdown: it still completes, + bounded, and leaves a warning instead of silently leaking a live process.""" process = FakeProcess() # ignores stdin closure and survives "termination" install_fake_process(monkeypatch, process, grace_period=0.05) @@ -969,20 +852,13 @@ async def stubborn_terminate(proc: FakeProcess) -> None: process.close_stdout() -# --------------------------------------------------------------------------- -# POSIX tree-termination policy, tested through the sanctioned killpg seam -# --------------------------------------------------------------------------- -# -# `mcp.os.posix.utilities` is coverage-omitted and the sanctioned place to monkeypatch -# OS calls. These pin the EPERM policy without a foreign-euid process: macOS killpg -# raises EPERM when *any* group member cannot be signalled, even if others were. +# `mcp.os.posix.utilities` is coverage-omitted and the sanctioned seam for monkeypatching +# OS calls. The next tests pin the EPERM policy without a foreign-euid process: macOS +# killpg raises EPERM when *any* group member cannot be signalled, even if others were. class _StubPosixProcess: - """The two attributes `terminate_posix_process_tree` touches. - - They are the pgid source and the reap-progress probe. - """ + """The two attributes `terminate_posix_process_tree` touches: pgid source and reap probe.""" pid = 54321 returncode: int | None = None @@ -994,11 +870,8 @@ class _StubPosixProcess: async def test_an_eperm_group_that_dies_during_the_grace_period_is_not_sigkilled( # pragma: lax no cover monkeypatch: pytest.MonkeyPatch, ) -> None: - """EPERM from the SIGTERM killpg no longer short-circuits termination. - - The grace wait still runs, and a group observed to be gone during it is never - SIGKILLed. - """ + """EPERM from the SIGTERM killpg no longer short-circuits: the grace wait still runs, + and a group observed to be gone during it is never SIGKILLed.""" calls: list[tuple[int, int]] = [] probes = 0 @@ -1029,12 +902,9 @@ def fake_killpg(pgid: int, sig: int) -> None: async def test_an_eperm_group_that_outlives_the_grace_period_is_still_sigkilled( # pragma: lax no cover monkeypatch: pytest.MonkeyPatch, ) -> None: - """Even when every probe reports EPERM, the SIGKILL escalation still fires. - - It fires after the grace period, and its own EPERM is tolerated. Pre-fix, EPERM at - SIGTERM abandoned the group escalation for a leader-only kill, leaking every other - group member. The tiny timeout is the time-based grace period under test. - """ + """The SIGKILL fires after the grace period and its own EPERM is tolerated. Pre-fix, + EPERM at SIGTERM abandoned the group escalation for a leader-only kill, leaking + every other group member. The tiny timeout is the grace period under test.""" calls: list[tuple[int, int]] = [] def fake_killpg(pgid: int, sig: int) -> None: @@ -1061,19 +931,16 @@ def fake_killpg(pgid: int, sig: int) -> None: async def test_the_grace_wait_reads_returncode_so_trio_can_reap_the_leaders_zombie( # pragma: lax no cover monkeypatch: pytest.MonkeyPatch, ) -> None: - """The wait between SIGTERM and SIGKILL reads `process.returncode` while it polls. - - On trio that property calls `Popen.poll()`, whose reap stops the leader's zombie - from keeping the group alive for the full timeout (see terminate_posix_process_tree). - Regression pin for the read itself, on both backends; the reaping side effect is - trio's documented behaviour, deliberately not re-tested here. - """ + """On trio, reading `returncode` calls `Popen.poll()`, whose reap stops the leader's + zombie keeping the group alive for the full timeout (see terminate_posix_process_tree). + Pins the read itself, on both backends; the reaping side effect is trio's documented + behaviour, deliberately not re-tested here.""" calls: list[tuple[int, int]] = [] def fake_killpg(pgid: int, sig: int) -> None: - # SIGTERM is accepted and every liveness probe reports survivors, so the - # grace wait runs to its (tiny) timeout and the SIGKILL escalation fires. + # SIGTERM is accepted and every probe reports survivors: the grace wait runs to + # its (tiny) timeout and SIGKILL fires. calls.append((pgid, sig)) class _ReadCountingProcess: @@ -1102,19 +969,14 @@ def returncode(self) -> int | None: assert stub.returncode_reads >= 1 -# --------------------------------------------------------------------------- -# Real-process tests: the OS facts no fake can certify -# --------------------------------------------------------------------------- -# -# These pin kernel behaviour (process-group kill semantics, SIGKILL delivery) via a -# socket liveness probe, no sleeps or polls: `accept()` blocks until the subprocess -# connects, proving it runs (and its pre-connect setup ran); after cleanup, `receive(1)` -# raises EndOfStream (FIN) or BrokenResourceError (RST, typical of SIGKILL and Windows -# job termination) because the kernel closes a dead process's file descriptors. +# The real-process tests pin kernel behaviour via a socket liveness probe, no sleeps or +# polls: `accept()` blocks until the subprocess connects, proving it runs; after cleanup, +# `receive(1)` raises EndOfStream (FIN) or BrokenResourceError (RST, typical of SIGKILL +# and Windows job termination) because the kernel closes a dead process's descriptors. def _connect_back_script(port: int) -> str: - """Return a ``python -c`` liveness-probe body: connect to `port`, send `b'alive'`, block forever.""" + """Liveness-probe body for `python -c`: connect to `port`, send `b'alive'`, block forever.""" return ( f"import socket, time\n" f"s = socket.create_connection(('127.0.0.1', {port}))\n" @@ -1124,7 +986,6 @@ def _connect_back_script(port: int) -> str: async def _open_liveness_listener() -> tuple[anyio.abc.SocketListener, int]: - """Open a TCP listener on localhost and return it along with its port.""" multi = await anyio.create_tcp_listener(local_host="127.0.0.1") sock = multi.listeners[0] assert isinstance(sock, anyio.abc.SocketListener) @@ -1135,11 +996,8 @@ async def _open_liveness_listener() -> tuple[anyio.abc.SocketListener, int]: async def _accept_alive(sock: anyio.abc.SocketListener) -> anyio.abc.SocketStream: - """Accept one connection and assert the peer sent ``b'alive'``. - - Blocks until a subprocess connects (the outer test bounds this with - ``anyio.fail_after``). - """ + """Accept one connection and assert the peer sent `b'alive'`; blocks until the + subprocess connects (the outer test bounds this with `anyio.fail_after`).""" stream = await sock.accept() msg = await stream.receive(5) assert msg == b"alive", f"expected b'alive', got {msg!r}" @@ -1152,25 +1010,18 @@ async def _assert_stream_closed(stream: anyio.abc.SocketStream) -> None: await stream.receive(1) -# lax no cover: only called by win32-skipped tests; Windows CI jobs enforce 100% -# coverage per job, where these helpers never execute. +# lax no cover: only called by win32-skipped tests; Windows CI enforces 100% per job. async def _wait_until_exited(proc: anyio.abc.Process) -> None: # pragma: lax no cover - """Poll `returncode` until the process itself dies. - - Not `proc.wait()`: on asyncio that also waits for the pipes to close, conflating - process death with pipe state. - """ + """Poll `returncode` until the process dies; `proc.wait()` on asyncio also waits for + the pipes to close, conflating process death with pipe state.""" while proc.returncode is None: await anyio.sleep(0.01) async def _reap(proc: anyio.abc.Process) -> None: # pragma: lax no cover - """Reap an already-killed process and release its pipe transports. - - Draining stdout to EOF lets the asyncio pipe transport observe the closure instead - of warning at GC. The bound swallows a hung cleanup on purpose; reaping is just a - safety net. - """ + """Reap a killed process: draining stdout to EOF lets the asyncio pipe transport + observe the closure instead of warning at GC. The bound deliberately swallows a + hung cleanup; reaping is just a safety net.""" with anyio.move_on_after(5.0): await proc.wait() assert proc.stdin is not None @@ -1182,11 +1033,7 @@ async def _reap(proc: anyio.abc.Process) -> None: # pragma: lax no cover def _record_spawned_processes(monkeypatch: pytest.MonkeyPatch) -> list[anyio.abc.Process | FallbackProcess]: - """Record every process `stdio_client` spawns (the real spawn still runs). - - A test can inspect each process afterwards and tear its process group down on - failure. - """ + """Record every process `stdio_client` spawns; the real spawn still runs.""" spawned: list[anyio.abc.Process | FallbackProcess] = [] async def recording_spawn( @@ -1204,15 +1051,11 @@ async def recording_spawn( return spawned -# lax no cover: registered on every platform but a no-op on Windows, whose runners -# enforce 100% coverage per job. +# lax no cover: registered on every platform but a no-op on Windows (100% per-job coverage). def _kill_spawn_groups(spawned: list[anyio.abc.Process | FallbackProcess]) -> None: # pragma: lax no cover - """Failure-path safety net: SIGKILL each spawn-time process group. - - This stops a test failing mid-body from orphaning its sleep-forever descendants. - A no-op when the test passed, and on Windows (no process group to signal; the Job - Object covers strays). - """ + """Failure-path safety net: SIGKILL each spawn-time process group so a test failing + mid-body cannot orphan its sleep-forever descendants. A no-op when the test passed, + and on Windows (no process group to signal; the Job Object covers strays).""" if sys.platform == "win32": return for process in spawned: @@ -1223,14 +1066,10 @@ def _kill_spawn_groups(spawned: list[anyio.abc.Process | FallbackProcess]) -> No @pytest.mark.anyio async def test_exiting_the_context_terminates_the_entire_process_tree(monkeypatch: pytest.MonkeyPatch) -> None: - """Exiting `stdio_client` kills the server's whole process tree. - - The tree is a parent that exits instantly on SIGTERM (so the group must outlive its - leader), a child, and a grandchild, each death observed through its liveness socket - closing. The escalation timing is pinned in process by - test_escalation_fires_once_and_only_after_the_grace_period; the production grace - constant's value is deliberately unpinned. - """ + """The tree is a parent that exits instantly on SIGTERM (so the group must outlive + its leader), a child, and a grandchild, each death observed through its liveness + socket closing. Escalation timing is pinned in process by the escalation-bound test; + the production grace constant's value is deliberately unpinned.""" monkeypatch.setattr(stdio, "PROCESS_TERMINATION_TIMEOUT", 0.2) spawned = _record_spawned_processes(monkeypatch) @@ -1253,8 +1092,7 @@ async def test_exiting_the_context_terminates_the_entire_process_tree(monkeypatc ) server_params = StdioServerParameters(command=sys.executable, args=["-c", parent]) - # The bound covers three Python interpreter cold starts on a loaded runner; - # a healthy run takes well under a second. + # Covers three interpreter cold starts on a loaded runner; healthy runs are <1s. with anyio.fail_after(15.0): async with stdio_client(server_params): streams = [await _accept_alive(sock) for _ in range(3)] @@ -1269,12 +1107,8 @@ async def test_exiting_the_context_terminates_the_entire_process_tree(monkeypatc @pytest.mark.skipif(sys.platform == "win32", reason="POSIX process-group semantics") # lax no cover: Windows CI jobs enforce 100% coverage per job and skip this test. async def test_tree_kill_reaches_children_after_the_leader_has_already_exited() -> None: # pragma: lax no cover - """Killing the tree of an already-exited process still reaches its surviving children. - - The process group outlives its leader, and the group ID is the leader's pid by - construction (start_new_session), not something to look up from the (reaped) - leader. - """ + """The process group outlives its leader, and the group ID is the leader's pid by + construction (start_new_session), not something looked up from the (reaped) leader.""" async with AsyncExitStack() as stack: sock, port = await _open_liveness_listener() stack.push_async_callback(sock.aclose) @@ -1305,15 +1139,12 @@ async def test_tree_kill_reaches_children_after_the_leader_has_already_exited() @pytest.mark.skipif(sys.platform == "win32", reason="POSIX process-group semantics") # lax no cover: same Windows-runner coverage reason as above. async def test_terminating_an_already_exited_process_is_a_no_op() -> None: # pragma: lax no cover - """Once the whole group is gone, tree termination returns without error. - - It does not fall back to signalling a reaped pid. - """ + """Once the whole group is gone, tree termination returns without error rather than + falling back to signalling a reaped pid.""" proc = await _create_platform_compatible_process(sys.executable, ["-c", "pass"]) assert isinstance(proc, anyio.abc.Process) - # The bound covers one interpreter cold start on a loaded runner; a healthy run - # takes well under a second. + # Covers one interpreter cold start on a loaded runner; healthy runs are <1s. with anyio.fail_after(10.0): await _wait_until_exited(proc) await _terminate_process_tree(proc) @@ -1326,13 +1157,10 @@ async def test_terminating_an_already_exited_process_is_a_no_op() -> None: # pr async def test_escalation_kills_a_process_that_ignores_sigterm( # pragma: lax no cover monkeypatch: pytest.MonkeyPatch, ) -> None: - """Cleanup escalates past SIGTERM and kills a process that ignores it. - - The child installs SIG_IGN *before* connecting to the liveness socket, so the - ignore is guaranteed in place; SIGKILL delivery is proven by the kernel closing - the socket. The only test of the SIGTERM-then-SIGKILL escalation itself; the - production constants' values are deliberately unpinned. - """ + """The child installs SIG_IGN *before* connecting to the liveness socket, so the + ignore is guaranteed in place; SIGKILL delivery is proven by the kernel closing the + socket. The only test of the SIGTERM-then-SIGKILL escalation itself; the production + constants' values are deliberately unpinned.""" monkeypatch.setattr(stdio, "PROCESS_TERMINATION_TIMEOUT", 0.2) monkeypatch.setattr(stdio, "FORCE_KILL_TIMEOUT", 0.2) spawned = _record_spawned_processes(monkeypatch) @@ -1345,8 +1173,7 @@ async def test_escalation_kills_a_process_that_ignores_sigterm( # pragma: lax n script = "import signal\nsignal.signal(signal.SIGTERM, signal.SIG_IGN)\n" + _connect_back_script(port) server_params = StdioServerParameters(command=sys.executable, args=["-c", script]) - # The bound covers an interpreter cold start on a loaded runner plus the two - # shortened escalation waits; a healthy run takes well under a second. + # Covers a cold start plus the two shortened escalation waits; healthy runs are <1s. with anyio.fail_after(15.0): async with stdio_client(server_params): stream = await _accept_alive(sock) @@ -1361,14 +1188,10 @@ async def test_escalation_kills_a_process_that_ignores_sigterm( # pragma: lax n async def test_a_graceful_exit_with_a_surviving_child_leaks_no_pipe_fds( # pragma: lax no cover monkeypatch: pytest.MonkeyPatch, ) -> None: - """A graceful exit with a surviving child must not leak the client's pipe fds. - - A server may exit cleanly on stdin closure while leaving a child holding the + """A server may exit cleanly on stdin closure while leaving a child holding the inherited pipe ends (the POSIX policy: survivors are the server's business). The - client must still release its own pipe fds and subprocess transport at shutdown - (on asyncio nothing else ever closes them while the orphan holds the pipe) instead - of leaking them for the orphan's lifetime. - """ + client must still release its own pipe fds and subprocess transport at shutdown; on + asyncio nothing else ever closes them while the orphan holds the pipe.""" spawned = _record_spawned_processes(monkeypatch) async with AsyncExitStack() as stack: diff --git a/tests/client/test_streamable_http.py b/tests/client/test_streamable_http.py index 99ff6f03e5..54ee296172 100644 --- a/tests/client/test_streamable_http.py +++ b/tests/client/test_streamable_http.py @@ -1,9 +1,7 @@ """Unit tests for the streamable-HTTP client transport. -The full client<->server round trip is pinned by the interaction suite under -tests/interaction/transports/; these tests cover the transport's header encoding and the -per-message metadata-headers merge directly because the headers are an HTTP-seam observation -the public client never exposes. +Covers header encoding and the per-message metadata-headers merge — HTTP-seam observations the +public client never exposes. The full round trip is pinned by tests/interaction/transports/. """ import base64 @@ -36,13 +34,9 @@ def test_mcp_name_header_values_are_base64_wrapped_when_unsafe_for_an_http_field( raw: str, expected: str, wrapped: bool ) -> None: - """Printable-ASCII names pass verbatim; CR/LF, non-ASCII, edge-whitespace, and sentinel-shaped names are wrapped. - - The ``=?base64?...?=`` sentinel is the spec's RFC 7230 safety gate for the ``Mcp-Name`` header. - Wrapped values round-trip through base64 so the server can recover the original name. A leading - or trailing space is wrapped because RFC 7230 forbids it in field-values (h11 rejects on real - transports); an empty value is allowed and passes verbatim. - """ + """The `=?base64?...?=` sentinel is the spec's RFC 7230 safety gate for `Mcp-Name`: CR/LF, non-ASCII, + edge whitespace (forbidden in field-values; h11 rejects it), and sentinel-shaped names are wrapped so + the server can base64-decode the original; other printable ASCII (including empty) passes verbatim.""" encoded = encode_header_value(raw) assert encoded == expected if wrapped: @@ -54,8 +48,6 @@ def test_mcp_name_header_values_are_base64_wrapped_when_unsafe_for_an_http_field @pytest.mark.anyio async def test_post_request_merges_per_message_metadata_headers() -> None: - """`ClientMessageMetadata.headers` on a `SessionMessage` are merged into the outgoing POST headers - (SDK-defined: the headers sidecar is the path the session uses to reach the transport).""" recorded: list[httpx.Request] = [] def handler(request: httpx.Request) -> httpx.Response: @@ -82,11 +74,8 @@ def handler(request: httpx.Request) -> httpx.Response: @pytest.mark.anyio async def test_pre_session_bare_404_maps_to_method_not_found() -> None: - """A bare HTTP 404 (no JSON-RPC body) before any session-id is held maps to METHOD_NOT_FOUND. - - Gateways and legacy servers 404 at the HTTP layer for unknown methods; with no session yet, - "Session terminated" is meaningless, and the discover→initialize fallback ladder keys on -32601. - """ + """Gateways and legacy servers 404 at the HTTP layer for unknown methods; with no session-id held, + "Session terminated" is meaningless, and the discover→initialize fallback ladder keys on -32601.""" def handler(request: httpx.Request) -> httpx.Response: return httpx.Response(404) @@ -105,17 +94,9 @@ def handler(request: httpx.Request) -> httpx.Response: @pytest.mark.anyio async def test_initialize_post_clears_cached_pv_header_and_unstamped_posts_read_it() -> None: - """``initialize`` discards the cached protocol-version header; every other POST reads it. - - Steps: - 1. A stamped probe POST caches its ``MCP-Protocol-Version`` header. - 2. An ``initialize`` POST clears that cache before building headers, so the fallback - handshake never carries a probe-stamped value. - 3. A subsequent stamped POST re-seeds the cache with the negotiated version. - 4. An unstamped POST (a JSON-RPC response written by the dispatcher, which never - passes through the session's stamp) then reads the cache and carries the - negotiated version — the spec MUST for all post-initialization HTTP requests. - """ + """`initialize` clears the cached header so the fallback handshake never carries a probe-stamped + value; stamped POSTs (re-)seed the cache; unstamped POSTs read it — the spec MUST for carrying + the negotiated version on every post-initialization HTTP request.""" recorded: list[httpx.Request] = [] def handler(request: httpx.Request) -> httpx.Response: @@ -145,8 +126,7 @@ def handler(request: httpx.Request) -> httpx.Response: metadata=ClientMessageMetadata(headers={MCP_PROTOCOL_VERSION_HEADER: "2025-11-25"}), ) ) - # An unstamped JSON-RPC response — what the dispatcher writes when answering - # a server-initiated request (sampling/elicitation/roots). + # Unstamped JSON-RPC response — what the dispatcher writes when answering a server-initiated request. await write.send(SessionMessage(JSONRPCResponse(jsonrpc="2.0", id=99, result={}))) assert [r.method for r in recorded] == ["POST", "POST", "POST", "POST"] diff --git a/tests/client/test_transport_stream_cleanup.py b/tests/client/test_transport_stream_cleanup.py index 40d3b2439d..715d694be4 100644 --- a/tests/client/test_transport_stream_cleanup.py +++ b/tests/client/test_transport_stream_cleanup.py @@ -1,13 +1,9 @@ """Regression tests for memory stream leaks in client transports. -When a connection error occurs (404, 403, ConnectError), transport context -managers must close ALL 4 memory stream ends they created. anyio memory streams -are paired but independent — closing the writer does NOT close the reader. -Unclosed stream ends emit ResourceWarning on GC, which pytest promotes to a -test failure in whatever test happens to be running when GC triggers. - -These tests force GC after the transport context exits, so any leaked stream -triggers a ResourceWarning immediately and deterministically here, rather than +On connection errors (404, 403, ConnectError) transports must close all 4 memory stream +ends they created — anyio streams are paired but independent, so closing the writer does +NOT close the reader. Leaked ends emit ResourceWarning on GC (promoted to a test failure +by pytest); forcing gc.collect() here surfaces the leak deterministically instead of nondeterministically in an unrelated later test. """ @@ -27,22 +23,15 @@ def _assert_no_memory_stream_leak() -> Iterator[None]: """Fail if any anyio MemoryObject stream emits ResourceWarning during the block. - Uses a custom sys.unraisablehook to capture ONLY MemoryObject stream leaks, - ignoring unrelated resources (e.g. PipeHandle from flaky stdio tests on the - same xdist worker). gc.collect() is forced after the block to make leaks - deterministic. + Unrelated unraisables (e.g. PipeHandle from flaky stdio tests on the same xdist worker) + are deliberately ignored. """ leaked: list[str] = [] old_hook = sys.unraisablehook def hook(args: "sys.UnraisableHookArgs") -> None: # pragma: no cover - # Only executes if a leak occurs (i.e. the bug is present). - # args.object is the __del__ function (not the stream instance) when - # unraisablehook fires from a finalizer, so check exc_value — the - # actual ResourceWarning("Unclosed "). - # Non-MemoryObject unraisables (e.g. PipeHandle leaked by an earlier - # flaky test on the same xdist worker) are deliberately ignored — - # this test should not fail for another test's resource leak. + # Runs only when a leak occurs (hence the pragma). For finalizer unraisables, + # args.object is the __del__ function, not the stream — match on exc_value instead. if "MemoryObject" in str(args.exc_value): leaked.append(str(args.exc_value)) @@ -57,12 +46,8 @@ def hook(args: "sys.UnraisableHookArgs") -> None: # pragma: no cover @pytest.mark.anyio async def test_sse_client_closes_all_streams_on_connection_error(free_tcp_port: int) -> None: - """sse_client creates streams only after the SSE connection succeeds, so a - ConnectError propagates directly with nothing to leak. - - Before the fix, streams were created before connecting and only 2 of 4 were - closed in the finally block. - """ + """Streams are created only after the SSE connection succeeds, so ConnectError leaks + nothing. Before the fix, streams were created pre-connect and only 2 of 4 were closed.""" with _assert_no_memory_stream_leak(): with pytest.raises(httpx.ConnectError): async with sse_client(f"http://127.0.0.1:{free_tcp_port}/sse"): @@ -71,10 +56,8 @@ async def test_sse_client_closes_all_streams_on_connection_error(free_tcp_port: @pytest.mark.anyio async def test_sse_client_closes_all_streams_on_http_error() -> None: - """sse_client creates streams only after raise_for_status() passes, so an - HTTPStatusError from a 4xx/5xx response propagates bare (not wrapped in an - ExceptionGroup) with nothing to leak — the task group is never entered. - """ + """Streams are created only after raise_for_status() passes, so HTTPStatusError + propagates bare (not wrapped in an ExceptionGroup) — the task group is never entered.""" def return_403(request: httpx.Request) -> httpx.Response: return httpx.Response(403) @@ -94,12 +77,8 @@ def mock_factory( @pytest.mark.anyio async def test_streamable_http_client_closes_all_streams_on_exit() -> None: - """streamable_http_client must close all 4 stream ends on exit. - - Before the fix, read_stream was never closed — not even on the happy path. - This test enters and exits the context without sending any messages, so no - network connection is ever attempted (streamable_http connects lazily). - """ + """Before the fix, read_stream was never closed — not even on the happy path. No messages + are sent, so no network connection is attempted (streamable_http connects lazily).""" with _assert_no_memory_stream_leak(): async with streamable_http_client("http://127.0.0.1:1/mcp"): pass diff --git a/tests/client/transports/test_memory.py b/tests/client/transports/test_memory.py index 375fe972a6..7c1d20c82a 100644 --- a/tests/client/transports/test_memory.py +++ b/tests/client/transports/test_memory.py @@ -1,5 +1,3 @@ -"""Tests for InMemoryTransport.""" - from collections.abc import AsyncIterator from contextlib import asynccontextmanager from typing import Any @@ -19,8 +17,6 @@ @pytest.fixture def simple_server() -> Server: - """Create a simple MCP server for testing.""" - async def handle_list_resources( ctx: ServerRequestContext, params: types.PaginatedRequestParams | None ) -> ListResourcesResult: # pragma: no cover @@ -39,7 +35,6 @@ async def handle_list_resources( @pytest.fixture def mcpserver_server() -> MCPServer: - """Create an MCPServer server for testing.""" server = MCPServer("test") @server.tool() @@ -59,7 +54,6 @@ def test_resource() -> str: # pragma: no cover async def test_with_server(simple_server: Server): - """Test creating transport with a Server instance.""" transport = InMemoryTransport(simple_server) async with transport as (read_stream, write_stream): assert read_stream is not None @@ -67,7 +61,6 @@ async def test_with_server(simple_server: Server): async def test_with_mcpserver(mcpserver_server: MCPServer): - """Test creating transport with an MCPServer instance.""" transport = InMemoryTransport(mcpserver_server) async with transport as (read_stream, write_stream): assert read_stream is not None @@ -75,13 +68,11 @@ async def test_with_mcpserver(mcpserver_server: MCPServer): async def test_server_is_running(mcpserver_server: MCPServer): - """Test that the server is running and responding to requests.""" async with Client(mcpserver_server, mode="legacy") as client: assert client.server_capabilities.tools is not None async def test_list_tools(mcpserver_server: MCPServer): - """Test listing tools through the transport.""" async with Client(mcpserver_server, mode="legacy") as client: tools_result = await client.list_tools() assert len(tools_result.tools) > 0 @@ -90,7 +81,6 @@ async def test_list_tools(mcpserver_server: MCPServer): async def test_call_tool(mcpserver_server: MCPServer): - """Test calling a tool through the transport.""" async with Client(mcpserver_server, mode="legacy") as client: result = await client.call_tool("greet", {"name": "World"}) assert result is not None @@ -99,18 +89,13 @@ async def test_call_tool(mcpserver_server: MCPServer): async def test_raise_exceptions(mcpserver_server: MCPServer): - """Test that raise_exceptions parameter is passed through.""" transport = InMemoryTransport(mcpserver_server, raise_exceptions=True) async with transport as (read_stream, _write_stream): assert read_stream is not None async def test_aexit_with_well_behaved_lifespan_runs_teardown_without_cancel(): - """A lifespan that finishes promptly on EOF should run to completion. - - The transport closes the streams first and waits for the server to exit - naturally, so teardown observes no cancellation. - """ + """The transport closes the streams and waits for a natural server exit, so teardown sees no cancellation.""" teardown_ran = anyio.Event() @asynccontextmanager @@ -127,12 +112,7 @@ async def lifespan(_: Server[Any]) -> AsyncIterator[dict[str, Any]]: async def test_aexit_with_blocking_lifespan_is_bounded(monkeypatch: pytest.MonkeyPatch): - """A lifespan that never returns must not hang `__aexit__` forever. - - After EOFing the server the transport waits `SERVER_SHUTDOWN_GRACE` for a - natural exit, then cancels the server task as a backstop so the - task-group join completes. - """ + """After EOF the transport waits `SERVER_SHUTDOWN_GRACE` for a natural exit, then cancels as a backstop.""" monkeypatch.setattr(_memory, "SERVER_SHUTDOWN_GRACE", 0.05) teardown_started = anyio.Event() diff --git a/tests/conftest.py b/tests/conftest.py index 2278c9939e..a6c4b3da5f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,13 +3,9 @@ import pytest -# OpenTelemetry's `set_tracer_provider` is set-once per process, so the suite -# uses a single span-capture mechanism: logfire's `capfire` fixture (its -# `configure()` swaps span processors on repeat calls rather than re-setting -# the provider). Logfire's default `distributed_tracing=None` emits a -# RuntimeWarning + diagnostic span when incoming W3C trace context is -# extracted; several tests exercise that propagation deliberately, so opt in -# suite-wide. Set before logfire is imported anywhere. +# OTel's `set_tracer_provider` is set-once per process, so all span capture goes through logfire's `capfire` +# fixture. Logfire's default `distributed_tracing=None` emits a RuntimeWarning when incoming W3C trace context +# is extracted; tests exercise that propagation deliberately, so opt in suite-wide before logfire is imported. os.environ.setdefault("LOGFIRE_DISTRIBUTED_TRACING", "true") import opentelemetry.trace # noqa: E402 (env var must be set before logfire import below) @@ -27,15 +23,10 @@ def anyio_backend(): def _capfire_isolated(capfire: CaptureLogfire) -> Iterator[CaptureLogfire]: """Override of logfire's `capfire` that scopes the MCP tracer to the test. - `capfire` installs a real tracer provider, and logfire's proxy machinery - mutates the cached `mcp.shared._otel._tracer` to delegate to it for the - rest of the process. Without isolation, every subsequent test in the same - worker would emit real spans, and `send_raw_request` would inject a real - `traceparent` into outbound `_meta`, breaking the interaction-suite - snapshots that pin `_meta={}` under a no-op tracer. - - Setup points `_tracer` at the now-live provider so MCP spans record; - teardown replaces it with a `NoOpTracer`. + Logfire's proxy machinery mutates the cached `mcp.shared._otel._tracer` to delegate to + `capfire`'s provider for the rest of the process. Without the `NoOpTracer` teardown, later + tests would emit real spans and `send_raw_request` would inject a real `traceparent` into + outbound `_meta`, breaking interaction-suite snapshots that pin `_meta={}`. """ mcp.shared._otel._tracer = opentelemetry.trace.get_tracer_provider().get_tracer("mcp-python-sdk") try: diff --git a/tests/docs_src/test_apps.py b/tests/docs_src/test_apps.py index 02375f97a3..ba92f36703 100644 --- a/tests/docs_src/test_apps.py +++ b/tests/docs_src/test_apps.py @@ -14,15 +14,12 @@ async def test_the_tool_carries_the_ui_resource_reference() -> None: - """tutorial001: `@apps.tool(resource_uri=...)` stamps `_meta.ui.resourceUri` on the tool.""" async with Client(tutorial001.mcp) as client: listed = await client.list_tools() assert listed.tools[0].meta == {"ui": {"resourceUri": "ui://clock/app.html"}} async def test_the_ui_resource_is_served_as_the_app_mime_type() -> None: - """tutorial001: `add_html_resource` serves the HTML at `text/html;profile=mcp-app`, - the MIME type that tells a host "this is an app, render it".""" async with Client(tutorial001.mcp) as client: result = await client.read_resource("ui://clock/app.html") contents = result.contents[0] @@ -32,8 +29,7 @@ async def test_the_ui_resource_is_served_as_the_app_mime_type() -> None: async def test_one_tool_two_answers() -> None: - """tutorial001: the canonical degradation pattern: raw data for a client that - negotiated Apps, a human sentence for one that did not.""" + """Degradation pattern: raw data for a client that negotiated Apps, a human sentence for one that did not.""" async with Client(tutorial001.mcp, extensions={EXTENSION_ID: {"mimeTypes": [APP_MIME_TYPE]}}) as ui_client: rich = await ui_client.call_tool("get_time", {}) async with Client(tutorial001.mcp) as text_client: @@ -43,21 +39,17 @@ async def test_one_tool_two_answers() -> None: async def test_the_clock_client_program_runs_as_shown(capsys: pytest.CaptureFixture[str]) -> None: - """tutorial001: `main()` declares Apps support with the required `mimeTypes` and - receives the rich answer the page promises.""" await tutorial001.main() assert "2026-06-26T12:00:00Z" in capsys.readouterr().out async def test_capability_advertised_under_server_extensions() -> None: - """tutorial001: passing `extensions=[apps]` advertises `io.modelcontextprotocol/ui`.""" async with Client(tutorial001.mcp) as client: assert client.server_capabilities.extensions == {EXTENSION_ID: {}} async def test_csp_permissions_domain_and_border_ride_the_resource_meta() -> None: - """tutorial002: the iframe lockdown fields land under `_meta.ui` on both the list - entry and the read content item, with the spec's camelCase wire keys.""" + """The fields land under `_meta.ui` on both the list entry and the read content, with camelCase wire keys.""" expected: dict[str, Any] = { "ui": { "csp": {"connectDomains": ["https://api.example.com"]}, @@ -76,8 +68,7 @@ async def test_csp_permissions_domain_and_border_ride_the_resource_meta() -> Non async def test_an_app_only_tool_is_still_listed_and_callable() -> None: - """tutorial002: `visibility=["app"]` is metadata for the host; the server lists the - tool like any other and serves its calls. Filtering is the host's job.""" + """`visibility` is metadata for the host — filtering is the host's job, not the server's.""" async with Client(tutorial002.mcp) as client: listed = await client.list_tools() result = await client.call_tool("refresh_dashboard", {}) @@ -86,8 +77,6 @@ async def test_an_app_only_tool_is_still_listed_and_callable() -> None: async def test_a_file_resource_is_served_with_the_app_mime_type_filled_in() -> None: - """tutorial003: `add_resource` accepts a pre-built `FileResource` and fills in the - `text/html;profile=mcp-app` MIME type the resource didn't set explicitly.""" async with Client(tutorial003.mcp) as client: listed = await client.list_tools() called = await client.call_tool("refresh_report", {}) diff --git a/tests/docs_src/test_asgi.py b/tests/docs_src/test_asgi.py index 93aa502428..52c0aa94d0 100644 --- a/tests/docs_src/test_asgi.py +++ b/tests/docs_src/test_asgi.py @@ -20,14 +20,12 @@ async def test_streamable_http_app_is_a_starlette_app_with_one_route() -> None: - """tutorial001: the factory returns a Starlette application with a single route at `/mcp`.""" (route,) = tutorial001.app.routes assert isinstance(route, Route) assert route.path == "/mcp" async def test_the_server_behind_the_app_is_unchanged() -> None: - """tutorial001: wrapping the server in an ASGI app changes nothing about its tools.""" async with Client(tutorial001.mcp) as client: result = await client.call_tool("add_note", {"text": "milk"}) assert result.content == [TextContent(type="text", text="Saved: milk")] @@ -35,7 +33,7 @@ async def test_the_server_behind_the_app_is_unchanged() -> None: async def test_streamable_http_app_takes_runs_options_except_port() -> None: - """The tip: every `run("streamable-http", ...)` option is here except `port`. `host` is one of them.""" + """The page's tip: every `run("streamable-http", ...)` option is here except `port`.""" parameters = set(inspect.signature(MCPServer.streamable_http_app).parameters) - {"self"} assert parameters == { "streamable_http_path", @@ -57,7 +55,6 @@ async def test_a_request_before_the_session_manager_runs_is_rejected() -> None: async def test_mounting_at_the_root_keeps_the_default_path() -> None: - """tutorial002: `Mount("/")` plus the default `streamable_http_path` leaves the endpoint at `/mcp`.""" (mount,) = tutorial002.app.routes assert isinstance(mount, Mount) assert mount.path == "" @@ -86,7 +83,6 @@ async def about(request: Request) -> Response: async def test_the_host_lifespan_enters_the_session_manager() -> None: - """tutorial002: the host app's lifespan owns `session_manager.run()` and starts and stops cleanly.""" async with tutorial002.lifespan(tutorial002.app): async with Client(tutorial002.mcp) as client: result = await client.call_tool("add_note", {"text": "milk"}) @@ -94,7 +90,6 @@ async def test_the_host_lifespan_enters_the_session_manager() -> None: async def test_two_servers_get_two_mounts() -> None: - """tutorial003: each server is mounted under its own prefix, each still ending in `/mcp`.""" notes_mount, tasks_mount = tutorial003.app.routes assert isinstance(notes_mount, Mount) assert isinstance(tasks_mount, Mount) @@ -103,7 +98,6 @@ async def test_two_servers_get_two_mounts() -> None: async def test_one_lifespan_starts_both_session_managers() -> None: - """tutorial003: a single `AsyncExitStack` lifespan runs both managers; both servers answer.""" async with tutorial003.lifespan(tutorial003.app): async with Client(tutorial003.notes) as client: notes_result = await client.call_tool("add_note", {"text": "milk"}) @@ -114,7 +108,6 @@ async def test_one_lifespan_starts_both_session_managers() -> None: async def test_streamable_http_path_moves_the_endpoint_to_the_mount_prefix() -> None: - """tutorial004: `streamable_http_path="/"` makes the `Mount` prefix the whole public path.""" (mount,) = tutorial004.app.routes assert isinstance(mount, Mount) assert mount.path == "/notes" @@ -124,7 +117,6 @@ async def test_streamable_http_path_moves_the_endpoint_to_the_mount_prefix() -> async def test_cors_exposes_the_session_id_header() -> None: - """tutorial005: the browser origin gets the three MCP methods and can read `Mcp-Session-Id`.""" (middleware,) = tutorial005.app.user_middleware assert middleware.cls is CORSMiddleware transport = httpx.ASGITransport(app=tutorial005.app) @@ -142,7 +134,6 @@ async def test_cors_exposes_the_session_id_header() -> None: async def test_custom_route_lands_next_to_the_mcp_endpoint() -> None: - """tutorial006: `@mcp.custom_route()` adds a plain Starlette route to the returned app.""" mcp_route, health_route = tutorial006.app.routes assert isinstance(mcp_route, Route) assert isinstance(health_route, Route) @@ -151,7 +142,6 @@ async def test_custom_route_lands_next_to_the_mcp_endpoint() -> None: async def test_the_health_check_answers_outside_the_protocol() -> None: - """tutorial006: `GET /health` is ordinary HTTP, with no session manager and no MCP.""" transport = httpx.ASGITransport(app=tutorial006.app) async with httpx.AsyncClient(transport=transport, base_url="http://127.0.0.1") as http: response = await http.get("/health") @@ -169,9 +159,7 @@ async def test_the_health_check_answers_outside_the_protocol() -> None: async def test_the_default_app_is_localhost_only() -> None: - """The "Localhost only" section: with no `transport_security=`, the app answers a real hostname - with the page's `421 Invalid Host header` and a foreign Origin with `403 Invalid Origin header`, - before any MCP code runs.""" + """The "Localhost only" section: with no `transport_security=`, host/origin rejections fire before any MCP code.""" bare = MCPServer("Notes") app = bare.streamable_http_app() transport = httpx.ASGITransport(app=app) @@ -187,8 +175,6 @@ async def test_the_default_app_is_localhost_only() -> None: async def test_the_documented_browser_origin_works_end_to_end() -> None: - """tutorial005: the page's scenario for real. The public hostname, the browser origin, a - realistic preflight naming the `Mcp-*` headers, then the actual request.""" transport = httpx.ASGITransport(app=tutorial005.app) async with tutorial005.lifespan(tutorial005.app): async with httpx.AsyncClient(transport=transport, base_url="https://mcp.example.com") as http: diff --git a/tests/docs_src/test_authorization.py b/tests/docs_src/test_authorization.py index 4c7554ed75..f093b570f8 100644 --- a/tests/docs_src/test_authorization.py +++ b/tests/docs_src/test_authorization.py @@ -16,7 +16,6 @@ async def test_the_in_memory_client_never_authenticates() -> None: - """tutorial001: `Client(mcp)` connects to the server object directly, so no token is ever checked.""" async with Client(tutorial001.mcp) as client: result = await client.call_tool("list_notes", {}) assert not result.is_error @@ -24,13 +23,12 @@ async def test_the_in_memory_client_never_authenticates() -> None: async def test_token_verifier_and_auth_settings_must_travel_together() -> None: - """tutorial001: passing `token_verifier=` without `auth=` is refused at construction time.""" with pytest.raises(ValueError, match="Cannot specify auth_server_provider or token_verifier without auth settings"): MCPServer("Notes", token_verifier=tutorial001.StaticTokenVerifier()) async def test_the_app_grows_a_protected_resource_metadata_route() -> None: - """tutorial001: the HTTP app has the `/mcp` endpoint plus the RFC 9728 well-known route.""" + """The HTTP app has the `/mcp` endpoint plus the RFC 9728 well-known route.""" mcp_route, metadata_route = tutorial001.mcp.streamable_http_app().routes assert isinstance(mcp_route, Route) assert isinstance(metadata_route, Route) @@ -39,7 +37,6 @@ async def test_the_app_grows_a_protected_resource_metadata_route() -> None: async def test_the_metadata_document_is_built_from_auth_settings() -> None: - """tutorial001: `GET` on the well-known route returns the Protected Resource Metadata the page shows.""" transport = httpx.ASGITransport(app=tutorial001.mcp.streamable_http_app()) async with httpx.AsyncClient(transport=transport, base_url="http://127.0.0.1:8000") as http_client: response = await http_client.get("/.well-known/oauth-protected-resource/mcp") @@ -68,7 +65,6 @@ async def test_a_request_without_a_token_never_reaches_the_protocol() -> None: async def test_a_token_the_verifier_rejects_gets_the_same_401() -> None: - """tutorial001: `verify_token` returning `None` and a missing header are indistinguishable to the caller.""" transport = httpx.ASGITransport(app=tutorial001.mcp.streamable_http_app()) async with httpx.AsyncClient(transport=transport, base_url="http://127.0.0.1:8000") as http_client: response = await http_client.post("/mcp", json={}, headers={"Authorization": "Bearer not-a-real-token"}) @@ -77,14 +73,12 @@ async def test_a_token_the_verifier_rejects_gets_the_same_401() -> None: async def test_get_access_token_is_none_outside_an_authenticated_request() -> None: - """tutorial002: in-memory there is no HTTP layer, so `get_access_token()` returns `None`.""" async with Client(tutorial002.mcp) as client: result = await client.call_tool("whoami", {}) assert result.structured_content == {"result": "anonymous"} async def test_get_access_token_is_the_callers_access_token() -> None: - """tutorial002: over Streamable HTTP a valid bearer token reaches the tool as an `AccessToken`.""" url = "http://127.0.0.1:8000/mcp" transport = httpx.ASGITransport(app=tutorial002.mcp.streamable_http_app()) headers = {"Authorization": "Bearer alice-token"} diff --git a/tests/docs_src/test_caching.py b/tests/docs_src/test_caching.py index bc2feb9ac0..4a3efd7a49 100644 --- a/tests/docs_src/test_caching.py +++ b/tests/docs_src/test_caching.py @@ -14,7 +14,6 @@ async def test_a_mapped_method_carries_the_configured_hint() -> None: - """tutorial001: `tools/list` is in the map, so clients see one minute, public.""" async with Client(tutorial001.mcp) as client: tools = await client.list_tools() assert tools.ttl_ms == 60_000 @@ -22,7 +21,6 @@ async def test_a_mapped_method_carries_the_configured_hint() -> None: async def test_a_hint_without_a_scope_stays_private() -> None: - """tutorial001: `resources/read` set only `ttl_ms`; scope keeps the conservative default.""" async with Client(tutorial001.mcp) as client: result = await client.read_resource("config://units") assert result.ttl_ms == 5_000 @@ -30,7 +28,6 @@ async def test_a_hint_without_a_scope_stays_private() -> None: async def test_an_unmapped_method_stays_immediately_stale_and_private() -> None: - """tutorial001: `resources/list` is not in the map - the defaults hold.""" async with Client(tutorial001.mcp) as client: resources = await client.list_resources() assert resources.ttl_ms == 0 @@ -38,7 +35,6 @@ async def test_an_unmapped_method_stays_immediately_stale_and_private() -> None: async def test_a_non_cacheable_method_is_rejected_at_construction() -> None: - """The page's claim: anything but the six cacheable methods raises at construction.""" with pytest.raises(ValueError) as exc: MCPServer("Weather", cache_hints=cast(Any, {"tools/call": CacheHint(ttl_ms=1_000)})) assert str(exc.value) == snapshot( @@ -47,8 +43,7 @@ async def test_a_non_cacheable_method_is_rejected_at_construction() -> None: async def test_the_handler_value_wins_over_the_map_per_field() -> None: - """tutorial002: the handler's `ttl_ms=1_000` beats the map's `60_000`; the scope - the handler left unset takes the map's `"public"`.""" + """tutorial002's map sets `ttl_ms=60_000, scope="public"`; the handler overrides only `ttl_ms`.""" async with Client(tutorial002.server) as client: tools = await client.list_tools() assert tools.ttl_ms == 1_000 @@ -56,15 +51,12 @@ async def test_the_handler_value_wins_over_the_map_per_field() -> None: async def test_the_client_program_on_the_page_reads_the_hints(capsys: pytest.CaptureFixture[str]) -> None: - """tutorial003: `main()` is the literal client program on the page - the hints - arrive as parsed fields on the result.""" await tutorial003.main() assert capsys.readouterr().out == "1 tools, fresh for 60s, scope=public\n" async def test_the_wire_presence_check_the_page_recommends_works() -> None: - """The page's claim: `"ttl_ms" in result.model_fields_set` distinguishes a - server that sent the field from one that said nothing (model defaults).""" + """Presence in `model_fields_set` proves the server sent the field rather than the model defaulting it.""" async with Client(tutorial003.mcp) as client: tools = await client.list_tools() assert "ttl_ms" in tools.model_fields_set diff --git a/tests/docs_src/test_client.py b/tests/docs_src/test_client.py index 97cc327dcb..2c1720a4ac 100644 --- a/tests/docs_src/test_client.py +++ b/tests/docs_src/test_client.py @@ -13,7 +13,6 @@ async def test_every_client_program_on_the_page_runs(capsys: pytest.CaptureFixture[str]) -> None: - """Each `main()` is the literal client program shown on the page; all seven run clean in-memory.""" await tutorial001.main() await tutorial002.main() await tutorial003.main() @@ -25,7 +24,6 @@ async def test_every_client_program_on_the_page_runs(capsys: pytest.CaptureFixtu async def test_connected_properties_are_populated_inside_the_block() -> None: - """tutorial001: server_info, server_capabilities, protocol_version and instructions are just there.""" async with Client(tutorial001.mcp) as client: assert client.server_info.name == "Bookshop" assert client.protocol_version == "2026-07-28" @@ -35,7 +33,6 @@ async def test_connected_properties_are_populated_inside_the_block() -> None: async def test_a_client_is_not_reusable_after_the_block_ends() -> None: - """tutorial001: `async with` is the whole lifecycle. Construct a new Client per connection.""" client = Client(tutorial001.mcp) async with client: assert client.server_info.name == "Bookshop" @@ -44,7 +41,6 @@ async def test_a_client_is_not_reusable_after_the_block_ends() -> None: async def test_list_tools_returns_the_full_definition() -> None: - """tutorial002: each listed tool carries its name, title, description and the derived input schema.""" async with Client(tutorial002.mcp) as client: (tool,) = (await client.list_tools()).tools assert tool.name == "search_books" @@ -64,7 +60,7 @@ async def test_list_tools_returns_the_full_definition() -> None: def test_get_display_name_prefers_the_title() -> None: - """The `!!! tip`: get_display_name returns the title when there is one and the name when there isn't.""" + """Pins the page's `!!! tip` admonition.""" titled = Tool(name="search_books", title="Search the catalog", input_schema={"type": "object"}) untitled = Tool(name="search_books", input_schema={"type": "object"}) assert get_display_name(titled) == "Search the catalog" @@ -72,7 +68,6 @@ def test_get_display_name_prefers_the_title() -> None: async def test_call_tool_result_has_three_things_to_read() -> None: - """tutorial003: content for the model, structured_content for code, is_error for both.""" async with Client(tutorial003.mcp) as client: result = await client.call_tool("lookup_book", {"title": "Dune"}) assert not result.is_error @@ -83,7 +78,7 @@ async def test_call_tool_result_has_three_things_to_read() -> None: async def test_a_raising_tool_is_a_result_not_an_exception() -> None: - """tutorial003 `!!! check`: the exception's message comes back in content with is_error=True.""" + """Pins tutorial003's `!!! check` admonition.""" async with Client(tutorial003.mcp) as client: result = await client.call_tool("lookup_book", {"title": "Solaris"}) assert result.is_error @@ -94,7 +89,7 @@ async def test_a_raising_tool_is_a_result_not_an_exception() -> None: async def test_an_unknown_tool_name_is_a_result_not_an_exception() -> None: - """The `!!! warning`: a tool the server doesn't have comes back as is_error=True, not as MCPError.""" + """Pins the page's `!!! warning` admonition.""" async with Client(tutorial003.mcp) as client: result = await client.call_tool("does_not_exist", {}) assert result.is_error @@ -105,7 +100,6 @@ async def test_an_unknown_tool_name_is_a_result_not_an_exception() -> None: async def test_resources_and_templates_are_two_separate_lists() -> None: - """tutorial004: concrete resources and parameterised templates come back from different verbs.""" async with Client(tutorial004.mcp) as client: (resource,) = (await client.list_resources()).resources assert resource.uri == "catalog://genres" @@ -114,7 +108,6 @@ async def test_resources_and_templates_are_two_separate_lists() -> None: async def test_read_resource_fills_in_a_template() -> None: - """tutorial004: read_resource takes a plain str URI; narrow the contents with isinstance.""" async with Client(tutorial004.mcp) as client: (contents,) = (await client.read_resource("catalog://genres/poetry")).contents assert isinstance(contents, TextResourceContents) @@ -122,7 +115,6 @@ async def test_read_resource_fills_in_a_template() -> None: async def test_mcpserver_does_not_implement_resource_subscriptions() -> None: - """The Resources section: MCPServer advertises subscribe=False and rejects subscribe_resource with -32601.""" async with Client(tutorial004.mcp) as client: assert client.server_capabilities.resources is not None assert client.server_capabilities.resources.subscribe is False @@ -133,7 +125,6 @@ async def test_mcpserver_does_not_implement_resource_subscriptions() -> None: async def test_list_prompts_describes_the_arguments() -> None: - """tutorial005: a listed prompt carries its name, title and the arguments it needs.""" async with Client(tutorial005.mcp) as client: (prompt,) = (await client.list_prompts()).prompts assert prompt == snapshot( @@ -147,7 +138,6 @@ async def test_list_prompts_describes_the_arguments() -> None: async def test_get_prompt_renders_the_messages() -> None: - """tutorial005: get_prompt returns the rendered messages a host hands to the model.""" async with Client(tutorial005.mcp) as client: result = await client.get_prompt("recommend", {"genre": "poetry"}) (message,) = result.messages @@ -158,7 +148,6 @@ async def test_get_prompt_renders_the_messages() -> None: async def test_complete_suggests_values_for_an_argument() -> None: - """tutorial006: complete takes a ref and a name/value pair and returns the matching values.""" async with Client(tutorial006.mcp) as client: result = await client.complete( ref=PromptReference(type="ref/prompt", name="recommend"), @@ -168,7 +157,6 @@ async def test_complete_suggests_values_for_an_argument() -> None: async def test_a_single_page_server_ends_the_pagination_loop_immediately() -> None: - """tutorial007: every list_* takes cursor=; next_cursor is None when there is nothing left.""" async with Client(tutorial007.mcp) as client: page = await client.list_tools(cursor=None) assert page.next_cursor is None @@ -176,7 +164,7 @@ async def test_a_single_page_server_ends_the_pagination_loop_immediately() -> No async def test_raise_exceptions_is_a_constructor_flag() -> None: - """The `## In tests` section: `raise_exceptions=True` is accepted by the in-memory Client.""" + """Pins the page's `## In tests` section.""" async with Client(tutorial001.mcp, raise_exceptions=True) as client: result = await client.call_tool("search_books", {"query": "dune"}) assert result.structured_content == {"result": "Found 3 books matching 'dune'."} diff --git a/tests/docs_src/test_client_transports.py b/tests/docs_src/test_client_transports.py index 848eddd52e..fc7e3b8a59 100644 --- a/tests/docs_src/test_client_transports.py +++ b/tests/docs_src/test_client_transports.py @@ -14,13 +14,11 @@ async def test_the_in_memory_program_on_the_page_runs(capsys: pytest.CaptureFixture[str]) -> None: - """tutorial001's `main()` is the literal client program on the page; it runs clean end to end.""" await tutorial001.main() assert "Found 3 books matching 'dune'." in capsys.readouterr().out async def test_in_memory_client_talks_to_the_server_object() -> None: - """tutorial001: passing the server object connects in-process. No subprocess, no port.""" async with Client(tutorial001.mcp) as client: assert client.server_info.name == "Bookshop" assert client.protocol_version == "2026-07-28" @@ -41,14 +39,12 @@ async def test_streamable_http_configuration_lives_on_the_httpx_client() -> None async def test_stdio_parameters_are_wrapped_by_stdio_client() -> None: - """tutorial004: `stdio_client(params)` is the transport, and `Client` takes it like any other.""" client = Client(stdio_client(tutorial004.server)) with pytest.raises(RuntimeError, match="Client must be used within an async context manager"): client.session async def test_the_child_environment_is_an_allowlist(monkeypatch: pytest.MonkeyPatch) -> None: - """tutorial004: a variable set in the parent process is not inherited; `env=` adds it back explicitly.""" monkeypatch.setenv("BOOKSHOP_API_KEY", "from-the-parent") inherited = get_default_environment() assert "PATH" in inherited diff --git a/tests/docs_src/test_completions.py b/tests/docs_src/test_completions.py index b1f5c18164..a372f2ed83 100644 --- a/tests/docs_src/test_completions.py +++ b/tests/docs_src/test_completions.py @@ -22,7 +22,6 @@ async def test_a_server_with_no_handler_has_no_completions_capability() -> None: - """tutorial001: there is something worth completing, but no handler and no advertised capability.""" async with Client(tutorial001.mcp) as client: (template,) = (await client.list_resource_templates()).resource_templates assert template.uri_template == "github://repos/{owner}/{repo}" @@ -32,7 +31,6 @@ async def test_a_server_with_no_handler_has_no_completions_capability() -> None: async def test_completing_without_a_handler_is_method_not_found() -> None: - """tutorial001: nothing handles `completion/complete`, so the request is a JSON-RPC error.""" async with Client(tutorial001.mcp) as client: with pytest.raises(MCPError) as excinfo: await client.complete(ref=PROMPT_REF, argument={"name": "language", "value": "py"}) @@ -40,27 +38,23 @@ async def test_completing_without_a_handler_is_method_not_found() -> None: async def test_registering_the_handler_advertises_the_capability() -> None: - """tutorial002: `@mcp.completion()` is the whole declaration; the capability is derived from it.""" async with Client(tutorial002.mcp) as client: assert client.server_capabilities.completions == CompletionsCapability() async def test_prompt_argument_completion_filters_on_the_typed_prefix() -> None: - """tutorial002: the handler returns the languages that start with `argument.value`.""" async with Client(tutorial002.mcp) as client: result = await client.complete(ref=PROMPT_REF, argument={"name": "language", "value": "py"}) assert result.completion == snapshot(Completion(values=["python"])) async def test_empty_value_returns_every_suggestion() -> None: - """tutorial002: an empty prefix matches everything, so the client gets the whole list.""" async with Client(tutorial002.mcp) as client: result = await client.complete(ref=PROMPT_REF, argument={"name": "language", "value": ""}) assert result.completion.values == ["go", "javascript", "python", "rust", "typescript"] async def test_returning_none_is_an_empty_list_not_an_error() -> None: - """tutorial002: an argument the handler does not recognise produces `values=[]`, never a failure.""" async with Client(tutorial002.mcp) as client: result = await client.complete(ref=PROMPT_REF, argument={"name": "code", "value": "x"}) assert result.completion == snapshot(Completion(values=[])) @@ -69,7 +63,6 @@ async def test_returning_none_is_an_empty_list_not_an_error() -> None: async def test_context_arguments_resolve_a_dependent_parameter() -> None: - """tutorial003: the already-resolved `owner` arrives in `context.arguments` and picks the repo list.""" async with Client(tutorial003.mcp) as client: result = await client.complete( ref=TEMPLATE_REF, @@ -80,7 +73,6 @@ async def test_context_arguments_resolve_a_dependent_parameter() -> None: async def test_the_typed_prefix_still_filters_a_dependent_parameter() -> None: - """tutorial003: `argument.value` narrows the owner's repos exactly as it narrows a prompt argument.""" async with Client(tutorial003.mcp) as client: result = await client.complete( ref=TEMPLATE_REF, @@ -97,7 +89,6 @@ def test_context_arguments_is_optional() -> None: async def test_no_context_means_no_suggestions() -> None: - """tutorial003: without a resolved `owner` (or with an unknown one) the handler has nothing to offer.""" async with Client(tutorial003.mcp) as client: result = await client.complete(ref=TEMPLATE_REF, argument={"name": "repo", "value": ""}) assert result.completion.values == [] @@ -110,7 +101,6 @@ async def test_no_context_means_no_suggestions() -> None: async def test_the_prompt_branch_is_untouched_by_the_new_one() -> None: - """tutorial003: adding the resource-template branch leaves prompt-argument completion as it was.""" async with Client(tutorial003.mcp) as client: result = await client.complete(ref=PROMPT_REF, argument={"name": "language", "value": "type"}) assert result.completion.values == ["typescript"] diff --git a/tests/docs_src/test_context.py b/tests/docs_src/test_context.py index 2948b10f57..4ce4de7e00 100644 --- a/tests/docs_src/test_context.py +++ b/tests/docs_src/test_context.py @@ -14,7 +14,6 @@ async def test_the_context_parameter_is_not_in_the_input_schema() -> None: - """tutorial001: the injected `Context` never appears in the schema the model sees.""" async with Client(tutorial001.mcp) as client: (tool,) = (await client.list_tools()).tools assert tool.input_schema == snapshot( @@ -28,7 +27,6 @@ async def test_the_context_parameter_is_not_in_the_input_schema() -> None: async def test_every_request_gets_its_own_context() -> None: - """tutorial001: `ctx.request_id` identifies the request being served, so it changes per call.""" async with Client(tutorial001.mcp) as client: first = await client.call_tool("search_books", {"query": "dune"}) second = await client.call_tool("search_books", {"query": "dune"}) @@ -52,7 +50,6 @@ async def test_a_tool_reads_the_servers_own_resource() -> None: async def test_a_context_only_tool_takes_no_arguments() -> None: - """tutorial002: a tool whose only parameter is the `Context` has an empty input schema.""" async with Client(tutorial002.mcp) as client: tools = {tool.name: tool for tool in (await client.list_tools()).tools} assert tools["describe_catalog"].input_schema == snapshot( @@ -61,7 +58,6 @@ async def test_a_context_only_tool_takes_no_arguments() -> None: async def test_register_a_tool_at_runtime_and_notify_the_client() -> None: - """tutorial003: `mcp.add_tool` takes effect immediately and `send_tool_list_changed` reaches the client.""" messages: list[object] = [] async def collect(message: object) -> None: diff --git a/tests/docs_src/test_dependencies.py b/tests/docs_src/test_dependencies.py index 06d8935853..e4685ed52d 100644 --- a/tests/docs_src/test_dependencies.py +++ b/tests/docs_src/test_dependencies.py @@ -14,7 +14,6 @@ async def test_the_resolver_fills_the_parameter_from_the_tools_own_argument() -> None: - """tutorial001: `check_stock` receives `title` by name and its return value becomes `stock`.""" async with Client(tutorial001.mcp) as client: in_stock = await client.call_tool("reserve_book", {"title": "Dune"}) sold_out = await client.call_tool("reserve_book", {"title": "Neuromancer"}) @@ -24,7 +23,7 @@ async def test_the_resolver_fills_the_parameter_from_the_tools_own_argument() -> async def test_the_resolved_parameter_is_invisible_to_the_model() -> None: - """tutorial001: the input schema shown on the page is exactly what `tools/list` reports.""" + """The snapshot is the exact input schema shown on the docs page.""" async with Client(tutorial001.mcp) as client: (tool,) = (await client.list_tools()).tools @@ -39,7 +38,6 @@ async def test_the_resolved_parameter_is_invisible_to_the_model() -> None: async def test_a_client_supplied_value_for_a_resolved_parameter_is_ignored() -> None: - """tutorial001: the resolver's value is the only one the tool can receive.""" async with Client(tutorial001.mcp) as client: result = await client.call_tool("reserve_book", {"title": "Dune", "stock": {"title": "Dune", "copies": 999}}) @@ -47,7 +45,6 @@ async def test_a_client_supplied_value_for_a_resolved_parameter_is_ignored() -> async def test_a_resolver_can_depend_on_another_resolver() -> None: - """tutorial002: `estimate_delivery` consumes `check_stock`'s result, and the tool gets both.""" async with Client(tutorial002.mcp) as client: in_stock = await client.call_tool("order_book", {"title": "Dune"}) backorder = await client.call_tool("order_book", {"title": "Neuromancer"}) @@ -59,8 +56,6 @@ async def test_a_resolver_can_depend_on_another_resolver() -> None: async def test_a_shared_dependency_runs_once_per_call(monkeypatch: pytest.MonkeyPatch) -> None: - """tutorial002: `stock` and `delivery` both need `check_stock`; one call, one inventory lookup.""" - class CountingInventory: def __init__(self, data: dict[str, int]) -> None: self.data = data @@ -81,14 +76,11 @@ def get(self, key: str, default: int) -> int: assert inventory.lookups == ["Dune", "Dune"] -# The `!!! info` claims the tutorial003 behaviour is transport-independent, so each claim is -# proved on both: mode="legacy" elicits synchronously mid-call (2025-11-25 and earlier), while -# mode="auto" negotiates 2026-07-28, where the question rides a multi-round-trip `tools/call` -# and `Client` drives the retries. +# The docs' `!!! info` claims tutorial003 is transport-independent, so each claim is proved in both +# modes: "legacy" elicits synchronously mid-call (2025-11-25 and earlier); "auto" negotiates +# 2026-07-28, where the question rides a multi-round-trip `tools/call` and `Client` drives retries. @pytest.mark.parametrize("mode", ["legacy", "auto"]) async def test_an_in_stock_order_asks_no_question(mode: Literal["legacy", "auto"]) -> None: - """tutorial003: `confirm_backorder` returns directly when stock exists - no round-trip.""" - async def never(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: # pragma: no cover raise AssertionError("an in-stock order must not elicit") @@ -109,7 +101,6 @@ async def never(context: ClientRequestContext, params: ElicitRequestParams) -> E async def test_an_out_of_stock_order_asks_and_honours_the_answer( mode: Literal["legacy", "auto"], confirm: bool, expected: str ) -> None: - """tutorial003: the resolver elicits, the SDK validates the answer, the tool reads it.""" asked: list[str] = [] async def on_elicit(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: @@ -125,7 +116,7 @@ async def on_elicit(context: ClientRequestContext, params: ElicitRequestParams) @pytest.mark.parametrize("mode", ["legacy", "auto"]) async def test_declining_an_unwrapped_dependency_aborts_the_call(mode: Literal["legacy", "auto"]) -> None: - """tutorial003: no answer, no order - the error text on the page is the real one.""" + """The asserted error text is the one shown on the docs page.""" async def decline(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: return ElicitResult(action="decline") diff --git a/tests/docs_src/test_deprecated.py b/tests/docs_src/test_deprecated.py index 892a8f3627..d96f0decf6 100644 --- a/tests/docs_src/test_deprecated.py +++ b/tests/docs_src/test_deprecated.py @@ -1,11 +1,8 @@ """`docs/advanced/deprecated.md`: the page's behavioural claims, executed against the live SDK. -This chapter has no `docs_src/` example by design: it is the one page allowed to name -the deprecated methods, and a runnable example would teach exactly what the page tells -the reader not to build. So instead of importing an example, each test here runs a -claim the page states in prose (the warning category and text, the warn-*then*-raise -order on a modern connection, the `ping` removal, and both `filterwarnings` recipes) -so the prose cannot drift away from what the SDK does. +The chapter has no `docs_src/` example by design — a runnable one would teach exactly what +the page tells readers not to build — so each test runs a claim the page states in prose, +keeping the prose from drifting away from what the SDK does. """ import warnings @@ -42,12 +39,7 @@ async def old_log(ctx: Context) -> str: async def test_create_message_warns_and_then_raises_on_a_modern_connection() -> None: - """The `!!! warning`: on a modern connection sampling warns AND THEN the send raises. - - The two signals are independent: `@deprecated` fires the moment the method is - called, and only afterwards does the channel refuse the send. The page reports - both, in that order. - """ + """`@deprecated` warns the moment the method is called; only afterwards does the channel refuse the send.""" async with Client(mcp) as client: with ( pytest.warns( @@ -64,12 +56,7 @@ async def test_create_message_warns_and_then_raises_on_a_modern_connection() -> async def test_a_deprecated_feature_still_works_on_a_legacy_session() -> None: - """The page's headline: the deprecation is advisory. - - On a classic-handshake session, the same `ask_model` tool that fails on a modern - connection runs to completion: sampling round-trips through the client's callback - and the result comes back. The only difference is the visible warning. - """ + """The deprecation is advisory: under mode='legacy' the same tool completes, with only the warning.""" async def canned_sampling(context: ClientRequestContext, params: CreateMessageRequestParams) -> CreateMessageResult: return CreateMessageResult( @@ -89,12 +76,7 @@ async def canned_sampling(context: ClientRequestContext, params: CreateMessageRe async def test_send_ping_still_carries_the_deprecation_warning() -> None: - """The opening sentence: every retired method carries an `MCPDeprecationWarning`. - - `ping` is removed from the 2026-07-28 protocol rather than put in a deprecation - window, but the SDK method is still decorated (its message says *removed*) and - a modern connection answers the actual request with "Method not found". - """ + """`ping` is removed outright in 2026-07-28 (no deprecation window), yet the method is still decorated.""" async with Client(mcp) as client: with ( pytest.warns( @@ -107,22 +89,15 @@ async def test_send_ping_still_carries_the_deprecation_warning() -> None: def test_mcp_deprecation_warning_is_a_user_warning() -> None: - """The "Deprecated is advisory" section: the category subclasses `UserWarning`. - - Python's default filter hides `DeprecationWarning` outside `__main__`; deriving - from `UserWarning` is what makes the warning visible with no `-W` flag. - """ + """Deriving from `UserWarning` keeps the warning visible with no `-W` flag; + Python's default filter hides `DeprecationWarning` outside `__main__`.""" assert issubclass(MCPDeprecationWarning, UserWarning) assert not issubclass(MCPDeprecationWarning, DeprecationWarning) @pytest.mark.filterwarnings("error::mcp.MCPDeprecationWarning") async def test_error_filter_turns_the_deprecated_call_into_the_documented_tool_error() -> None: - """The `!!! check`: `"error::mcp.MCPDeprecationWarning"` makes `old_log` fail. - - Under the error filter the warning becomes the raised exception, the tool manager - wraps it, and the result is exactly the tool error the page quotes. - """ + """Under the error filter the warning is raised, wrapped, and surfaces as the tool error the page quotes.""" async with Client(mcp) as client: result = await client.call_tool("old_log", {}) assert result.is_error @@ -134,7 +109,6 @@ async def test_error_filter_turns_the_deprecated_call_into_the_documented_tool_e async def test_filterwarnings_ignore_silences_the_whole_category() -> None: - """The "Silencing the warning" snippet: one `filterwarnings` line quiets the category.""" async with Client(mcp) as client: with warnings.catch_warnings(record=True) as caught: warnings.simplefilter("always") diff --git a/tests/docs_src/test_elicitation.py b/tests/docs_src/test_elicitation.py index a28f1087fc..44f0ed1999 100644 --- a/tests/docs_src/test_elicitation.py +++ b/tests/docs_src/test_elicitation.py @@ -25,8 +25,6 @@ async def test_an_accepted_answer_resumes_the_tool() -> None: - """tutorial001: the user's answer comes back into the same call as a validated model.""" - async def on_elicit(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: return ElicitResult(action="accept", content={"accept_alternative": True, "date": "2025-12-26"}) @@ -55,7 +53,6 @@ async def on_elicit(context: ClientRequestContext, params: ElicitRequestParams) async def test_the_client_receives_the_message_and_the_generated_schema() -> None: - """tutorial001: form mode sends your message plus a JSON Schema built from the Pydantic model.""" received: list[ElicitRequestParams] = [] async def on_elicit(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: @@ -90,8 +87,6 @@ async def on_elicit(context: ClientRequestContext, params: ElicitRequestParams) async def test_decline_and_cancel_are_ordinary_return_values() -> None: - """tutorial001: a refusal is not an error; the tool sees the action and answers the model normally.""" - async def on_decline(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: return ElicitResult(action="decline") @@ -108,7 +103,6 @@ async def on_cancel(context: ClientRequestContext, params: ElicitRequestParams) async def test_a_tool_that_does_not_ask_needs_nothing_from_the_client() -> None: - """tutorial001: the elicitation only happens on the path that needs it.""" async with Client(tutorial001.mcp, mode="legacy") as client: result = await client.call_tool("book_table", {"date": "2025-12-30", "party_size": 4}) assert result.content == [TextContent(type="text", text="Booked a table for 4 on 2025-12-30.")] @@ -189,7 +183,6 @@ async def on_elicit(context: ClientRequestContext, params: ElicitRequestParams) async def test_url_mode_sends_a_url_and_gets_consent_back_not_data() -> None: - """tutorial002: the client receives the URL and the elicitation id; only the action comes back.""" received: list[ElicitRequestParams] = [] async def on_elicit(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: @@ -206,8 +199,6 @@ async def on_elicit(context: ClientRequestContext, params: ElicitRequestParams) async def test_a_declined_url_elicitation_is_an_ordinary_return_value() -> None: - """tutorial002: the tool decides what a refusal means.""" - async def on_elicit(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: return ElicitResult(action="decline") @@ -217,7 +208,6 @@ async def on_elicit(context: ClientRequestContext, params: ElicitRequestParams) async def test_send_elicit_complete_notifies_the_client_with_the_same_id() -> None: - """tutorial002: `send_elicit_complete` emits `notifications/elicitation/complete`.""" notifications: list[object] = [] async def on_message(message: object) -> None: @@ -232,7 +222,6 @@ async def on_message(message: object) -> None: async def test_the_docs_client_callback_handles_both_modes() -> None: - """tutorial003: one `elicitation_callback` answers the form and the URL consent.""" async with Client(tutorial001.mcp, mode="legacy", elicitation_callback=tutorial003.handle_elicitation) as client: booked = await client.call_tool("book_table", {"date": "2025-12-25", "party_size": 2}) async with Client(tutorial002.mcp, mode="legacy", elicitation_callback=tutorial003.handle_elicitation) as client: @@ -249,7 +238,6 @@ async def test_a_client_without_the_callback_cannot_be_asked() -> None: async def test_resolver_asks_only_when_the_folder_is_not_empty() -> None: - """tutorial004: `confirm_delete` resolves an empty folder directly and elicits otherwise.""" tutorial004._FOLDERS.update({"/tmp/empty": [], "/tmp/project": ["main.py", "README.md"]}) asked: list[str] = [] @@ -264,11 +252,10 @@ async def on_elicit(context: ClientRequestContext, params: ElicitRequestParams) assert empty.content == [TextContent(type="text", text="deleted /tmp/empty")] assert non_empty.content == [TextContent(type="text", text="deleted /tmp/project")] - assert asked == ["/tmp/project has 2 file(s). Delete anyway?"] # the empty folder was not queried + assert asked == ["/tmp/project has 2 file(s). Delete anyway?"] async def test_the_resolved_parameter_is_hidden_from_the_tool_schema() -> None: - """tutorial004: the `Resolve`-filled parameter never appears in the client-facing input schema.""" async with Client(tutorial004.mcp, mode="legacy") as client: (tool,) = (await client.list_tools()).tools assert tool.name == "delete_folder" @@ -288,7 +275,6 @@ async def test_the_tool_branches_on_every_elicitation_outcome( content: dict[str, str | int | float | bool | list[str] | None] | None, expected: str, ) -> None: - """tutorial004: annotating the result union lets the tool handle accept/decline/cancel.""" tutorial004._FOLDERS["/tmp/project"] = ["main.py", "README.md"] async def on_elicit(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: diff --git a/tests/docs_src/test_extensions.py b/tests/docs_src/test_extensions.py index ebe00e5a88..0288d2278f 100644 --- a/tests/docs_src/test_extensions.py +++ b/tests/docs_src/test_extensions.py @@ -17,15 +17,11 @@ async def test_using_an_extension_advertises_its_capability() -> None: - """tutorial001: `extensions=[Apps()]` is all it takes for the server to advertise - the extension under `capabilities.extensions`.""" async with Client(tutorial001.mcp) as client: assert client.server_capabilities.extensions == {"io.modelcontextprotocol/ui": {}} def test_a_prefixless_identifier_fails_at_class_definition() -> None: - """tutorial002 + the page's TypeError block: the identifier is validated when the - subclass is defined, with the exact message the page shows.""" assert tutorial002.Stamps.identifier == "com.example/stamps" with pytest.raises(TypeError) as exc_info: type("Stamps", (Extension,), {"identifier": "stamps"}) @@ -35,13 +31,11 @@ def test_a_prefixless_identifier_fails_at_class_definition() -> None: async def test_extension_settings_advertised_under_capabilities() -> None: - """tutorial003: `settings()` becomes the entry at `capabilities.extensions[identifier]`.""" async with Client(tutorial003.mcp) as client: assert client.server_capabilities.extensions == {"com.example/stamps": {"sealed": True}} async def test_contributed_tool_is_listed_and_callable() -> None: - """tutorial003: a `ToolBinding` registers like any `add_tool` call: listed and callable.""" async with Client(tutorial003.mcp) as client: listed = await client.list_tools() assert [tool.name for tool in listed.tools] == ["stamp"] @@ -50,8 +44,6 @@ async def test_contributed_tool_is_listed_and_callable() -> None: async def test_the_stamps_client_program_runs_as_shown(capsys: pytest.CaptureFixture[str]) -> None: - """tutorial003: `main()` is the literal client program on the page; both printed - lines match the page's comments.""" await tutorial003.main() out = capsys.readouterr().out assert "{'com.example/stamps': {'sealed': True}}" in out @@ -59,14 +51,11 @@ async def test_the_stamps_client_program_runs_as_shown(capsys: pytest.CaptureFix async def test_the_search_client_program_runs_as_shown(capsys: pytest.CaptureFixture[str]) -> None: - """tutorial004: `main()` declares the extension and gets the vendor method's result.""" await tutorial004.main() assert "['mcp-0', 'mcp-1', 'mcp-2']" in capsys.readouterr().out async def test_vendor_method_rejects_a_non_declaring_client_with_32021() -> None: - """tutorial004: `require_client_extension` answers a non-declaring client with `-32021` - and the machine-readable `requiredCapabilities` payload.""" async with Client(tutorial004.mcp) as client: request = tutorial004.SearchRequest(params=tutorial004.SearchParams(query="mcp")) with pytest.raises(MCPError) as exc_info: @@ -76,8 +65,7 @@ async def test_vendor_method_rejects_a_non_declaring_client_with_32021() -> None async def test_version_pinned_method_is_not_found_on_a_legacy_connection() -> None: - """tutorial004: `protocol_versions={"2026-07-28"}` makes the method METHOD_NOT_FOUND - at any other wire version; for a legacy client it doesn't exist.""" + """tutorial004 pins the method to `protocol_versions={"2026-07-28"}`; on a legacy connection it doesn't exist.""" async with Client(tutorial004.mcp, mode="legacy", extensions={tutorial004.EXTENSION_ID: {}}) as client: request = tutorial004.SearchRequest(params=tutorial004.SearchParams(query="mcp")) with pytest.raises(MCPError) as exc_info: @@ -88,7 +76,6 @@ async def test_version_pinned_method_is_not_found_on_a_legacy_connection() -> No async def test_interceptor_observes_the_call_and_passes_the_result_through( caplog: pytest.LogCaptureFixture, ) -> None: - """tutorial005: the interceptor logs the tool name and returns `call_next`'s result unchanged.""" with caplog.at_level(logging.INFO, logger=tutorial005.logger.name): async with Client(tutorial005.mcp) as client: result = await client.call_tool("add", {"a": 2, "b": 3}) diff --git a/tests/docs_src/test_first_steps.py b/tests/docs_src/test_first_steps.py index 15d6708ee2..d3d6e65efc 100644 --- a/tests/docs_src/test_first_steps.py +++ b/tests/docs_src/test_first_steps.py @@ -47,7 +47,7 @@ async def test_each_decorator_registers_one_primitive() -> None: async def test_call_the_tool() -> None: - """tutorial001: the Inspector walkthrough. `add` with 1 and 2 answers 3.""" + """tutorial001: the Inspector walkthrough.""" async with Client(tutorial001.mcp) as client: result = await client.call_tool("add", {"a": 1, "b": 2}) assert not result.is_error @@ -79,14 +79,10 @@ async def test_get_the_prompt() -> None: async def test_the_three_primitive_capabilities_are_always_declared() -> None: - """tutorial001: `MCPServer` always declares tools/resources/prompts; only `completions` follows your code. - - An `MCPServer` with nothing registered declares the same three, which is why the - page ties registration to the *optional* capabilities only. - """ + """tutorial001: `MCPServer` always declares tools/resources/prompts; only `completions` follows your code.""" async with Client(tutorial001.mcp) as client: declared = client.server_capabilities - # The exact dictionary the page prints from `model_dump(exclude_none=True)`. + # The exact dictionary the page prints. assert declared.model_dump(exclude_none=True) == snapshot( { "prompts": {"list_changed": False}, diff --git a/tests/docs_src/test_handling_errors.py b/tests/docs_src/test_handling_errors.py index 1a76a7bb77..faba80b543 100644 --- a/tests/docs_src/test_handling_errors.py +++ b/tests/docs_src/test_handling_errors.py @@ -11,7 +11,6 @@ async def test_a_plain_exception_becomes_a_tool_error_the_model_reads() -> None: - """tutorial001: any non-`MCPError` exception comes back as `is_error=True` with the message in `content`.""" async with Client(tutorial001.mcp) as client: result = await client.call_tool("get_author", {"title": "Nothing"}) assert result.is_error @@ -22,7 +21,6 @@ async def test_a_plain_exception_becomes_a_tool_error_the_model_reads() -> None: async def test_a_title_the_catalog_knows_is_an_ordinary_result() -> None: - """tutorial001: the non-raising path is a plain `is_error=False` result.""" async with Client(tutorial001.mcp) as client: result = await client.call_tool("get_author", {"title": "Dune"}) assert not result.is_error @@ -30,7 +28,6 @@ async def test_a_title_the_catalog_knows_is_an_ordinary_result() -> None: async def test_a_bad_argument_never_reaches_the_function() -> None: - """tutorial001: schema validation rejects the call before `get_author` runs, as the same kind of tool error.""" async with Client(tutorial001.mcp) as client: result = await client.call_tool("get_author", {"title": 42}) assert result.is_error @@ -39,7 +36,6 @@ async def test_a_bad_argument_never_reaches_the_function() -> None: async def test_mcp_error_makes_the_call_itself_fail() -> None: - """tutorial002: `MCPError` is not caught. It surfaces as a JSON-RPC error, with `code` and `message` intact.""" async with Client(tutorial002.mcp) as client: with pytest.raises(MCPError) as exc_info: await client.call_tool("get_author", {"title": "Nothing"}) @@ -48,7 +44,6 @@ async def test_mcp_error_makes_the_call_itself_fail() -> None: async def test_mcp_error_only_fires_on_the_raising_path() -> None: - """tutorial002: a title the catalog knows still returns a normal result.""" async with Client(tutorial002.mcp) as client: result = await client.call_tool("get_author", {"title": "Dune"}) assert not result.is_error @@ -56,7 +51,6 @@ async def test_mcp_error_only_fires_on_the_raising_path() -> None: async def test_resource_not_found_error_maps_to_invalid_params() -> None: - """tutorial003: `ResourceNotFoundError` from a template handler is `-32602` with the URI in `data`.""" async with Client(tutorial003.mcp) as client: with pytest.raises(MCPError) as exc_info: await client.read_resource("books://Nothing") @@ -78,7 +72,6 @@ async def test_raise_exceptions_does_not_turn_a_tool_error_into_a_traceback() -> async def test_a_title_the_template_knows_reads_normally() -> None: - """tutorial003: the non-raising path resolves the template and returns text contents.""" async with Client(tutorial003.mcp) as client: result = await client.read_resource("books://Dune") (contents,) = result.contents diff --git a/tests/docs_src/test_identity_assertion.py b/tests/docs_src/test_identity_assertion.py index afcfd83290..30c22850c4 100644 --- a/tests/docs_src/test_identity_assertion.py +++ b/tests/docs_src/test_identity_assertion.py @@ -28,8 +28,6 @@ class RecordingASGITransport(httpx.ASGITransport): - """An `httpx.ASGITransport` that appends every (method, path, body) it carries to a shared log.""" - def __init__(self, app: Starlette, log: list[tuple[str, str, bytes]]) -> None: super().__init__(app=app) self.log = log @@ -40,18 +38,15 @@ async def handle_async_request(self, request: httpx.Request) -> httpx.Response: async def test_the_provider_is_an_httpx_auth_but_not_an_oauth_client_provider() -> None: - """tutorial001: same `auth=` slot as the rest of OAuth clients, but nothing is discovered or registered.""" assert isinstance(tutorial001.oauth, httpx.Auth) assert not isinstance(tutorial001.oauth, OAuthClientProvider) async def test_main_is_the_main_from_the_oauth_clients_page() -> None: - """The page says `main()` is unchanged to the character from the OAuth clients page.""" assert inspect.getsource(tutorial001.main) == inspect.getsource(oauth_clients_tutorial001.main) async def test_a_client_secret_is_required() -> None: - """tutorial001: the provider refuses to be constructed as a public client.""" with pytest.raises(ValueError, match="client_secret is required"): IdentityAssertionOAuthProvider( server_url=MCP_SERVER_URL, @@ -64,7 +59,7 @@ async def test_a_client_secret_is_required() -> None: async def test_an_issuer_is_required() -> None: - """tutorial001: the authorization server is configuration, not discovery.""" + """The authorization server is configuration, not discovery — nothing fills in a missing issuer.""" with pytest.raises(ValueError, match="issuer is required"): IdentityAssertionOAuthProvider( server_url=MCP_SERVER_URL, @@ -77,7 +72,6 @@ async def test_an_issuer_is_required() -> None: async def test_the_id_jag_is_a_typed_jwt_carrying_the_claims_the_page_lists() -> None: - """tutorial001: the stand-in IdP signs a real ID-JAG; its header `typ` and claim set are the extension's.""" assertion = tutorial001.idp_issue_id_jag("alice@example.com", tutorial002.ISSUER, MCP_SERVER_URL) assert jwt.get_unverified_header(assertion)["typ"] == "oauth-id-jag+jwt" claims = jwt.decode(assertion, tutorial001.IDP_SIGNING_KEY, algorithms=["HS256"], audience=tutorial002.ISSUER) @@ -87,7 +81,6 @@ async def test_the_id_jag_is_a_typed_jwt_carrying_the_claims_the_page_lists() -> async def test_a_forged_assertion_is_rejected() -> None: - """tutorial002: the signature check fails closed with `invalid_grant`.""" client = tutorial002.REGISTERED_CLIENTS["finance-agent"] with pytest.raises(TokenError) as exc_info: await tutorial002.provider.exchange_identity_assertion( @@ -98,7 +91,6 @@ async def test_a_forged_assertion_is_rejected() -> None: async def test_an_assertion_for_another_audience_is_rejected() -> None: - """tutorial002: an ID-JAG whose `aud` is not this authorization server is `invalid_grant`.""" client = tutorial002.REGISTERED_CLIENTS["finance-agent"] assertion = tutorial001.idp_issue_id_jag("alice@example.com", "https://other.example.com/", MCP_SERVER_URL) with pytest.raises(TokenError) as exc_info: @@ -108,7 +100,6 @@ async def test_an_assertion_for_another_audience_is_rejected() -> None: async def test_an_assertion_for_an_unknown_resource_is_rejected() -> None: - """tutorial002: an ID-JAG naming a resource this server does not serve is `invalid_target`.""" client = tutorial002.REGISTERED_CLIENTS["finance-agent"] assertion = tutorial001.idp_issue_id_jag("alice@example.com", tutorial002.ISSUER, "https://other.example.com/mcp") with pytest.raises(TokenError) as exc_info: @@ -118,7 +109,6 @@ async def test_an_assertion_for_an_unknown_resource_is_rejected() -> None: async def test_a_replayed_assertion_is_rejected() -> None: - """tutorial002: `jti` is tracked, so presenting the same ID-JAG twice fails the second time.""" client = tutorial002.REGISTERED_CLIENTS["finance-agent"] assertion = tutorial001.idp_issue_id_jag("alice@example.com", tutorial002.ISSUER, MCP_SERVER_URL) params = IdentityAssertionParams(assertion=assertion) @@ -131,7 +121,6 @@ async def test_a_replayed_assertion_is_rejected() -> None: async def test_the_metadata_advertises_the_grant_type_and_the_id_jag_profile() -> None: - """tutorial002: the flag turns on both the `jwt-bearer` grant type and the grant-profile advertisement.""" transport = httpx.ASGITransport(app=tutorial002.auth_app) async with httpx.AsyncClient(transport=transport, base_url="https://auth.example.com") as http_client: response = await http_client.get("/.well-known/oauth-authorization-server") diff --git a/tests/docs_src/test_index.py b/tests/docs_src/test_index.py index 3012ae1a08..0639f69973 100644 --- a/tests/docs_src/test_index.py +++ b/tests/docs_src/test_index.py @@ -7,11 +7,9 @@ from docs_src.index.tutorial001 import mcp from mcp import Client -# `pyproject.toml` globally downgrades `mcp.MCPDeprecationWarning` to *ignore* because the -# SDK still calls those methods internally. A documentation example must never lean on -# that allowance, so every test that runs one re-arms the warning as an error. This is a -# per-module mark, not a conftest hook, because `pytest_collection_modifyitems` receives -# every item in the session. A hook here would break unrelated tests across the repo. +# `pyproject.toml` globally ignores `mcp.MCPDeprecationWarning` (the SDK still calls those methods +# internally), but doc examples must never lean on that, so each module re-arms it as an error. +# Per-module mark, not a conftest hook: a collection hook would affect every test in the session. pytestmark = [pytest.mark.anyio, pytest.mark.filterwarnings("error::mcp.MCPDeprecationWarning")] diff --git a/tests/docs_src/test_lifespan.py b/tests/docs_src/test_lifespan.py index d78764fd64..242045a213 100644 --- a/tests/docs_src/test_lifespan.py +++ b/tests/docs_src/test_lifespan.py @@ -23,7 +23,6 @@ async def test_lifespan_object_reaches_the_tool() -> None: async def test_context_parameter_never_reaches_the_input_schema() -> None: - """tutorial001: `ctx` is injected by the SDK, so `genre` is the only argument the model sees.""" async with Client(tutorial001.mcp) as client: (tool,) = (await client.list_tools()).tools assert tool.input_schema == snapshot( @@ -37,7 +36,6 @@ async def test_context_parameter_never_reaches_the_input_schema() -> None: async def test_startup_runs_before_the_first_request_and_shutdown_after_the_last() -> None: - """tutorial002: `connect()` runs at startup, the `finally` runs `disconnect()` at shutdown.""" assert not tutorial002.database.connected async with Client(tutorial002.mcp) as client: assert tutorial002.database.connected @@ -47,7 +45,6 @@ async def test_startup_runs_before_the_first_request_and_shutdown_after_the_last async def test_bare_context_reaches_the_lifespan_object_in_resources_and_prompts() -> None: - """A resource or prompt declaring a bare `ctx: Context` gets the same lifespan object a tool gets.""" mcp = MCPServer("Bookshop", lifespan=tutorial001.app_lifespan) @mcp.resource("books://{genre}/count") @@ -100,7 +97,6 @@ def stock_report(ctx: Context[tutorial001.AppContext]) -> str: async def test_default_lifespan_yields_an_empty_dict() -> None: - """No `lifespan=`: the SDK's default yields `{}`, so `lifespan_context` is never `None`.""" bare = MCPServer("Bare") @bare.tool() diff --git a/tests/docs_src/test_logging.py b/tests/docs_src/test_logging.py index fa4b995c6e..7978d21e17 100644 --- a/tests/docs_src/test_logging.py +++ b/tests/docs_src/test_logging.py @@ -15,7 +15,6 @@ async def test_the_tool_logs_through_the_standard_library(caplog: pytest.LogCaptureFixture) -> None: - """tutorial001: `logger.info(...)` inside a tool emits an ordinary stdlib record named after the module.""" caplog.set_level(logging.INFO) async with Client(tutorial001.mcp) as client: await client.call_tool("search_books", {"query": "dune"}) @@ -25,7 +24,6 @@ async def test_the_tool_logs_through_the_standard_library(caplog: pytest.LogCapt async def test_the_log_line_never_reaches_the_client() -> None: - """tutorial001: the result is only the return value. Log output is invisible to the model.""" async with Client(tutorial001.mcp) as client: result = await client.call_tool("search_books", {"query": "dune"}) assert result == snapshot( diff --git a/tests/docs_src/test_lowlevel.py b/tests/docs_src/test_lowlevel.py index 34746dd0b3..5fbb90e730 100644 --- a/tests/docs_src/test_lowlevel.py +++ b/tests/docs_src/test_lowlevel.py @@ -13,7 +13,6 @@ async def test_the_input_schema_on_the_wire_is_the_dict_you_wrote() -> None: - """tutorial001: nothing is derived. `tools/list` returns the literal `input_schema` dict.""" async with Client(tutorial001.server) as client: (tool,) = (await client.list_tools()).tools assert tool.name == "search_books" @@ -29,7 +28,6 @@ async def test_the_input_schema_on_the_wire_is_the_dict_you_wrote() -> None: async def test_the_client_does_not_care_which_server_class_it_connects_to() -> None: - """tutorial001: `Client(server)` accepts a low-level `Server` and the call answers like **Tools**.""" async with Client(tutorial001.server) as client: result = await client.call_tool("search_books", {"query": "dune", "limit": 5}) assert not result.is_error @@ -38,13 +36,11 @@ async def test_the_client_does_not_care_which_server_class_it_connects_to() -> N async def test_only_the_handlers_you_passed_become_capabilities() -> None: - """tutorial001: two tool handlers advertise `tools` and nothing else.""" async with Client(tutorial001.server) as client: assert client.server_capabilities.model_dump(exclude_none=True) == snapshot({"tools": {"list_changed": False}}) async def test_arguments_are_not_validated_against_your_schema() -> None: - """tutorial001: a call missing a `required` argument still reaches the handler and blows up there.""" async with Client(tutorial001.server) as client: with pytest.raises(MCPError) as exc_info: await client.call_tool("search_books", {"query": "dune"}) @@ -52,7 +48,6 @@ async def test_arguments_are_not_validated_against_your_schema() -> None: async def test_one_handler_routes_every_tool() -> None: - """tutorial002: `on_call_tool` is the single entry point; it dispatches on `params.name`.""" async with Client(tutorial002.server) as client: assert [tool.name for tool in (await client.list_tools()).tools] == ["search_books", "add_book"] result = await client.call_tool("add_book", {"title": "Dune", "author": "Frank Herbert", "year": 1965}) @@ -60,7 +55,6 @@ async def test_one_handler_routes_every_tool() -> None: async def test_an_unknown_tool_name_becomes_a_protocol_error_not_a_tool_error() -> None: - """tutorial002: raising from a handler is a `-32603` JSON-RPC error, never an `is_error` result.""" async with Client(tutorial002.server) as client: with pytest.raises(MCPError) as exc_info: await client.call_tool("does_not_exist", {}) @@ -68,7 +62,6 @@ async def test_an_unknown_tool_name_becomes_a_protocol_error_not_a_tool_error() async def test_output_schema_and_structured_content_are_both_yours_to_build() -> None: - """tutorial003: you declare the schema on the `Tool` and you build the matching payload.""" async with Client(tutorial003.server) as client: (tool,) = (await client.list_tools()).tools assert tool.output_schema == snapshot( @@ -84,8 +77,6 @@ async def test_output_schema_and_structured_content_are_both_yours_to_build() -> async def test_the_client_checks_the_schema_you_promised() -> None: - """The page's warning: a `structured_content` that violates your `output_schema` fails in `call_tool`.""" - async def promise_breaker(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: return CallToolResult(content=[TextContent(type="text", text="oops")], structured_content={"matches": "three"}) @@ -96,7 +87,6 @@ async def promise_breaker(ctx: ServerRequestContext, params: CallToolRequestPara async def test_meta_reaches_the_client_application() -> None: - """tutorial004: `_meta=` on the result comes back as `result.meta` and serialises under `_meta`.""" async with Client(tutorial004.server) as client: result = await client.call_tool("search_books", {"query": "dune", "limit": 5}) assert result.meta == {"bookshop/record_ids": ["bk_17", "bk_42", "bk_99"]} @@ -112,14 +102,12 @@ async def test_meta_reaches_the_client_application() -> None: async def test_the_lifespan_object_reaches_every_handler_with_its_type() -> None: - """tutorial005: what the lifespan yields is `ctx.lifespan_context`, typed by `Server[Catalog]`.""" async with Client(tutorial005.server) as client: result = await client.call_tool("search_books", {"query": "dune"}) assert result.content == [TextContent(type="text", text="Found 3 books: Dune, Dune Messiah, Children of Dune.")] async def test_add_request_handler_registers_a_method_the_constructor_does_not_know() -> None: - """tutorial006: the registry holds the handler and the params model it validates against.""" entry = tutorial006.server.get_request_handler("bookshop/reindex") assert entry is not None assert entry.params_type is tutorial006.ReindexParams @@ -127,13 +115,11 @@ async def test_add_request_handler_registers_a_method_the_constructor_does_not_k async def test_a_custom_method_never_changes_the_advertised_capabilities() -> None: - """tutorial006: only the spec's method families map to capabilities. `bookshop/reindex` is invisible.""" async with Client(tutorial006.server) as client: assert client.server_capabilities.model_dump(exclude_none=True) == snapshot({"tools": {"list_changed": False}}) def test_initialize_is_reserved() -> None: - """The page's `ValueError`: the handshake belongs to the runner, not to `add_request_handler`.""" server = Server("Bookshop") async def grab_the_handshake(ctx: ServerRequestContext, params: RequestParams) -> None: diff --git a/tests/docs_src/test_media.py b/tests/docs_src/test_media.py index 96ea42a0b1..177cf529ac 100644 --- a/tests/docs_src/test_media.py +++ b/tests/docs_src/test_media.py @@ -14,7 +14,6 @@ async def test_image_return_becomes_an_image_content_block() -> None: - """tutorial001: `-> Image` reaches the client as a base64 `ImageContent` block, not text.""" async with Client(tutorial001.mcp) as client: result = await client.call_tool("logo", {}) assert not result.is_error @@ -24,7 +23,6 @@ async def test_image_return_becomes_an_image_content_block() -> None: async def test_image_result_has_no_structured_content_and_no_output_schema() -> None: - """tutorial001: media is content for the model, not data for the application.""" async with Client(tutorial001.mcp) as client: (tool,) = (await client.list_tools()).tools assert tool.output_schema is None @@ -33,7 +31,6 @@ async def test_image_result_has_no_structured_content_and_no_output_schema() -> async def test_audio_return_becomes_an_audio_content_block() -> None: - """tutorial002: `Audio` is the same shape as `Image`.""" async with Client(tutorial002.mcp) as client: result = await client.call_tool("chime", {}) assert not result.is_error @@ -51,7 +48,6 @@ def test_raw_data_without_a_format_falls_back_to_a_default_mime_type() -> None: async def test_icons_are_visible_where_they_were_declared() -> None: - """tutorial003: server icons land on `server_info`, tool icons on the `Tool`, resource icons on the `Resource`.""" async with Client(tutorial003.mcp) as client: assert client.server_info.icons == [ Icon(src="https://example.com/brand-kit.png", mime_type="image/png", sizes=["48x48"]) diff --git a/tests/docs_src/test_middleware.py b/tests/docs_src/test_middleware.py index 97d9e96086..54ff6520cf 100644 --- a/tests/docs_src/test_middleware.py +++ b/tests/docs_src/test_middleware.py @@ -23,19 +23,16 @@ def _is_timing_record(record: logging.LogRecord) -> bool: - """A record emitted by tutorial001's `log_timing` middleware (and nothing else caplog caught).""" return record.name == tutorial001.logger.name def test_timing_record_predicate() -> None: - """The caplog filter keeps the middleware's own records and drops everyone else's.""" args = (logging.INFO, __file__, 1, "msg", None, None) assert _is_timing_record(logging.LogRecord(tutorial001.logger.name, *args)) assert not _is_timing_record(logging.LogRecord("somebody.elses.logger", *args)) async def test_middleware_observes_every_inbound_message(caplog: pytest.LogCaptureFixture) -> None: - """tutorial001: two client calls produce three timed lines. `server/discover` is wrapped too.""" with caplog.at_level(logging.INFO, logger=tutorial001.logger.name): async with Client(tutorial001.server) as client: await client.list_tools() @@ -46,7 +43,6 @@ async def test_middleware_observes_every_inbound_message(caplog: pytest.LogCaptu async def test_the_result_passes_through_unchanged() -> None: - """tutorial001: `log_timing` returns what `call_next` returned, so the client sees the real result.""" async with Client(tutorial001.server) as client: result = await client.call_tool("search_books", {"query": "dune"}) assert not result.is_error @@ -54,7 +50,6 @@ async def test_the_result_passes_through_unchanged() -> None: async def test_a_notification_has_no_request_id() -> None: - """`ctx.request_id is None` is how middleware tells a notification from a request.""" seen: list[tuple[str, RequestId | None]] = [] async def spy(ctx: ServerRequestContext, call_next: CallNext) -> HandlerResult: @@ -69,8 +64,6 @@ async def spy(ctx: ServerRequestContext, call_next: CallNext) -> HandlerResult: async def test_raising_before_call_next_refuses_the_message() -> None: - """A middleware that raises instead of calling `call_next` answers with a JSON-RPC error.""" - async def gate(ctx: ServerRequestContext, call_next: CallNext) -> HandlerResult: if ctx.method == "tools/call": raise MCPError(code=INVALID_REQUEST, message="No calls on Sundays.") @@ -87,7 +80,6 @@ async def gate(ctx: ServerRequestContext, call_next: CallNext) -> HandlerResult: async def test_an_unhandled_method_raises_through_the_middleware() -> None: - """A method without a handler raises `METHOD_NOT_FOUND` out of `call_next`, through the middleware.""" seen: list[tuple[str, int]] = [] async def spy(ctx: ServerRequestContext, call_next: CallNext) -> HandlerResult: @@ -107,7 +99,6 @@ async def spy(ctx: ServerRequestContext, call_next: CallNext) -> HandlerResult: async def test_initialize_cannot_be_replaced_only_wrapped() -> None: - """`add_request_handler("initialize", ...)` is rejected: middleware is the sanctioned hook.""" expected = ( "'initialize' is handled by the server runner and cannot be overridden; " "use Server.middleware to observe or wrap initialization" diff --git a/tests/docs_src/test_mrtr.py b/tests/docs_src/test_mrtr.py index 4be449edc0..62f0b90f1b 100644 --- a/tests/docs_src/test_mrtr.py +++ b/tests/docs_src/test_mrtr.py @@ -23,7 +23,6 @@ async def test_first_call_returns_an_input_required_result() -> None: - """tutorial001: a tool that is missing input returns `InputRequiredResult` instead of calling back.""" async with Client(tutorial001.server) as client: result = await client.session.call_tool("provision", {"name": "orders"}, allow_input_required=True) assert result == snapshot( @@ -49,7 +48,6 @@ async def test_first_call_returns_an_input_required_result() -> None: async def test_the_auto_loop_drives_the_call_to_completion() -> None: - """tutorial003: register `elicitation_callback`, call the tool, get a plain `CallToolResult` back.""" async with Client(tutorial001.server, elicitation_callback=tutorial003.handle_elicitation) as client: result = await client.call_tool("provision", {"name": "orders"}) assert result == snapshot( @@ -81,7 +79,6 @@ async def test_retry_with_input_responses_and_request_state_completes_the_call() async def test_the_manual_loop_drives_the_call_to_completion() -> None: - """tutorial002: `client.session.call_tool(..., allow_input_required=True)` for callers who own the loop.""" async with Client(tutorial001.server) as client: result = await tutorial002.provision(client, "billing") assert result == snapshot( @@ -105,7 +102,6 @@ async def test_a_pre_2026_session_has_nowhere_to_put_the_result() -> None: def test_fulfil_refuses_a_request_it_cannot_answer() -> None: - """tutorial002: `fulfil` is the dispatch point. This client only knows how to answer an `ElicitRequest`.""" request = CreateMessageRequest(params=CreateMessageRequestParams(messages=[], max_tokens=64)) with pytest.raises(NotImplementedError, match="sampling/createMessage"): tutorial002.fulfil(request) diff --git a/tests/docs_src/test_opentelemetry.py b/tests/docs_src/test_opentelemetry.py index 00f3af8aac..d5762a7952 100644 --- a/tests/docs_src/test_opentelemetry.py +++ b/tests/docs_src/test_opentelemetry.py @@ -11,7 +11,6 @@ async def test_a_plain_server_is_traced_with_no_extra_code(capfire: CaptureLogfire) -> None: - """tutorial001: calling a tool emits a `tools/call` SERVER span, though the example adds no middleware.""" async with Client(tutorial001.mcp) as client: await client.call_tool("search_books", {"query": "dune"}) diff --git a/tests/docs_src/test_pagination.py b/tests/docs_src/test_pagination.py index ab5949df96..6ce1269d1c 100644 --- a/tests/docs_src/test_pagination.py +++ b/tests/docs_src/test_pagination.py @@ -17,7 +17,6 @@ async def test_mcpserver_never_pages() -> None: - """The page's framing: `MCPServer` answers `resources/list` in one page with `next_cursor=None`.""" async with Client(mcp) as client: result = await client.list_resources() assert len(result.resources) == 100 @@ -25,7 +24,6 @@ async def test_mcpserver_never_pages() -> None: async def test_first_page_has_ten_resources_and_a_cursor() -> None: - """tutorial001: no cursor means page one: ten resources and a `next_cursor` the client may ignore.""" async with Client(tutorial001.server) as client: page = await client.list_resources() assert [resource.name for resource in page.resources] == [f"book-{n}" for n in range(1, 11)] @@ -33,7 +31,6 @@ async def test_first_page_has_ten_resources_and_a_cursor() -> None: async def test_the_cursor_resumes_where_the_last_page_stopped() -> None: - """tutorial001: handing `next_cursor` straight back yields the next page, no overlap.""" async with Client(tutorial001.server) as client: page = await client.list_resources(cursor="10") assert page.resources[0].name == "book-11" @@ -41,7 +38,6 @@ async def test_the_cursor_resumes_where_the_last_page_stopped() -> None: async def test_the_last_page_carries_no_cursor() -> None: - """tutorial001: `next_cursor=None` is the only end-of-list signal.""" async with Client(tutorial001.server) as client: page = await client.list_resources(cursor="90") assert len(page.resources) == 10 @@ -49,7 +45,6 @@ async def test_the_last_page_carries_no_cursor() -> None: async def test_the_loop_collects_all_one_hundred() -> None: - """tutorial001: the `cursor=` loop visits ten pages and reassembles the whole catalog.""" async with Client(tutorial001.server) as client: resources: list[Resource] = [] cursor: str | None = None @@ -66,7 +61,6 @@ async def test_the_loop_collects_all_one_hundred() -> None: async def test_the_client_program_on_the_page_runs(capsys: pytest.CaptureFixture[str]) -> None: - """tutorial002: `main()` is the literal client program on the page and prints the stitched total.""" await tutorial002.main() assert capsys.readouterr().out == "100 resources\n" diff --git a/tests/docs_src/test_progress.py b/tests/docs_src/test_progress.py index 45cc4df8eb..08a8229310 100644 --- a/tests/docs_src/test_progress.py +++ b/tests/docs_src/test_progress.py @@ -43,12 +43,10 @@ async def show(progress: float, total: float | None, message: str | None) -> Non async def test_over_a_wire_dispatcher_callbacks_race_the_result() -> None: - """The `!!! info`: only the in-memory connection runs the callback inline. + """The `!!! info`: on a wire dispatcher, progress callbacks can outlive `call_tool`. - On a wire dispatcher (`mode="legacy"` here) each progress notification starts its own task, so - `call_tool` can return while a slow callback is still running. The callbacks below block on an - event that is only set *after* `call_tool` has returned: exactly the situation the page tells - you not to rule out. + Only the in-memory connection runs callbacks inline; under `mode="legacy"` each one runs in + its own task, so the gated callbacks here finish only after `call_tool` has already returned. """ release = anyio.Event() done = anyio.Event() diff --git a/tests/docs_src/test_prompts.py b/tests/docs_src/test_prompts.py index 1cbab3af0a..5d55fda8ce 100644 --- a/tests/docs_src/test_prompts.py +++ b/tests/docs_src/test_prompts.py @@ -14,7 +14,6 @@ async def test_function_becomes_the_prompt() -> None: - """tutorial001: the name, the docstring and the parameters are the whole `prompts/list` entry.""" async with Client(tutorial001.mcp) as client: (prompt,) = (await client.list_prompts()).prompts assert prompt.model_dump(mode="json", by_alias=True, exclude_none=True) == snapshot( @@ -27,7 +26,6 @@ async def test_function_becomes_the_prompt() -> None: async def test_returned_string_becomes_one_user_message() -> None: - """tutorial001: a `str` return value is rendered as a single `user` message.""" async with Client(tutorial001.mcp) as client: result = await client.get_prompt("review_code", {"code": "def add(a, b): return a + b"}) assert result.model_dump(mode="json", by_alias=True, exclude_none=True) == snapshot( @@ -48,7 +46,6 @@ async def test_returned_string_becomes_one_user_message() -> None: async def test_missing_required_argument_is_a_protocol_error() -> None: - """tutorial001: omitting a required argument fails the request itself. There is no error result.""" async with Client(tutorial001.mcp) as client: with pytest.raises(MCPError) as exc_info: await client.get_prompt("review_code") @@ -61,7 +58,6 @@ async def test_missing_required_argument_is_a_protocol_error() -> None: async def test_message_list_becomes_a_multi_turn_template() -> None: - """tutorial002: a list of `UserMessage` / `AssistantMessage` renders in order, roles intact.""" async with Client(tutorial002.mcp) as client: assert [p.name for p in (await client.list_prompts()).prompts] == ["review_code", "debug_error"] result = await client.get_prompt("debug_error", {"error": "TypeError: 'int' object is not iterable"}) @@ -79,7 +75,6 @@ async def test_message_list_becomes_a_multi_turn_template() -> None: async def test_title_and_argument_descriptions() -> None: - """tutorial003: `title=` and `Field(description=...)` land in the `prompts/list` entry.""" async with Client(tutorial003.mcp) as client: (prompt,) = (await client.list_prompts()).prompts assert prompt.title == "Code review" @@ -90,7 +85,6 @@ async def test_title_and_argument_descriptions() -> None: async def test_default_value_makes_the_argument_optional() -> None: - """tutorial003: a parameter with a default can be omitted and the default is used in the render.""" async with Client(tutorial003.mcp) as client: result = await client.get_prompt("review_code", {"code": "x = 1"}) assert result.messages == [ diff --git a/tests/docs_src/test_protocol_versions.py b/tests/docs_src/test_protocol_versions.py index f8e5b19f16..92592d6f2f 100644 --- a/tests/docs_src/test_protocol_versions.py +++ b/tests/docs_src/test_protocol_versions.py @@ -13,7 +13,6 @@ async def test_auto_lands_on_the_modern_version() -> None: - """tutorial001: the default `mode="auto"` probes `server/discover` and adopts the result.""" async with Client(tutorial001.mcp) as client: assert client.protocol_version == "2026-07-28" assert client.server_info.name == "Bookshop" @@ -22,7 +21,6 @@ async def test_auto_lands_on_the_modern_version() -> None: async def test_legacy_forces_the_initialize_handshake() -> None: - """tutorial002: `mode="legacy"` runs `initialize` against the very same server.""" async with Client(tutorial002.mcp, mode="legacy") as client: assert client.protocol_version == "2025-11-25" assert client.server_info.name == "Bookshop" @@ -31,7 +29,6 @@ async def test_legacy_forces_the_initialize_handshake() -> None: async def test_version_pin_sends_nothing_and_knows_nothing() -> None: - """tutorial003: a pin adopts the version locally; `server_info` and capabilities are blank.""" async with Client(tutorial003.mcp, mode="2026-07-28") as client: assert client.protocol_version == "2026-07-28" assert client.server_info == Implementation(name="", version="") @@ -43,7 +40,7 @@ async def test_version_pin_sends_nothing_and_knows_nothing() -> None: def test_handshake_era_version_is_not_a_valid_pin() -> None: - """A pre-2026 version string is rejected at construction with the exact error the page shows.""" + # The match is the exact error text the page shows. with pytest.raises( ValueError, match=re.escape( @@ -55,7 +52,6 @@ def test_handshake_era_version_is_not_a_valid_pin() -> None: async def test_prior_discover_round_trips() -> None: - """tutorial004: save `discover_result`, reconnect with it, and the identity comes back.""" async with Client(tutorial004.mcp) as client: saved = client.session.discover_result assert saved is not None @@ -68,7 +64,6 @@ async def test_prior_discover_round_trips() -> None: async def test_discover_result_survives_json() -> None: - """`DiscoverResult` is a Pydantic model: dump it to JSON, validate it back, reconnect with it.""" async with Client(tutorial004.mcp) as client: saved = client.session.discover_result assert saved is not None diff --git a/tests/docs_src/test_resources.py b/tests/docs_src/test_resources.py index 85e827833d..432fa3e9b0 100644 --- a/tests/docs_src/test_resources.py +++ b/tests/docs_src/test_resources.py @@ -15,7 +15,6 @@ async def test_function_becomes_a_listed_resource() -> None: - """tutorial001: the URI, the function name and the docstring are the whole listing entry.""" async with Client(tutorial001.mcp) as client: (resource,) = (await client.list_resources()).resources assert resource == snapshot( @@ -29,7 +28,6 @@ async def test_function_becomes_a_listed_resource() -> None: async def test_read_returns_the_return_value_as_text() -> None: - """tutorial001: reading the URI runs the function and wraps the `str` in `TextResourceContents`.""" async with Client(tutorial001.mcp) as client: result = await client.read_resource("config://app") assert result.contents == [ @@ -38,7 +36,6 @@ async def test_read_returns_the_return_value_as_text() -> None: async def test_template_is_listed_separately_from_resources() -> None: - """tutorial002: a `{placeholder}` moves the entry from `resources/list` to `resources/templates/list`.""" async with Client(tutorial002.mcp) as client: assert [r.uri for r in (await client.list_resources()).resources] == ["config://app"] (template,) = (await client.list_resource_templates()).resource_templates @@ -53,7 +50,6 @@ async def test_template_is_listed_separately_from_resources() -> None: async def test_reading_a_template_fills_the_placeholder() -> None: - """tutorial002: the client reads a concrete URI; the matched value arrives as the function argument.""" async with Client(tutorial002.mcp) as client: result = await client.read_resource("users://42/profile") assert result.contents == [ @@ -78,7 +74,6 @@ def get_user_profile(user: str) -> None: async def test_mime_type_is_what_you_declare() -> None: - """tutorial003: `mime_type=` lands in the listing verbatim; the SDK never guesses it from the value.""" async with Client(tutorial003.mcp) as client: resources = (await client.list_resources()).resources assert {r.uri: r.mime_type for r in resources} == snapshot( @@ -91,7 +86,6 @@ async def test_mime_type_is_what_you_declare() -> None: async def test_str_return_is_sent_as_is() -> None: - """tutorial003: a `str` return value is the text content, untouched.""" async with Client(tutorial003.mcp) as client: (content,) = (await client.read_resource("docs://readme")).contents assert isinstance(content, TextResourceContents) @@ -99,7 +93,7 @@ async def test_str_return_is_sent_as_is() -> None: async def test_dict_return_becomes_json_text() -> None: - """tutorial003: a non-`str`, non-`bytes` return value is serialised to JSON text.""" + """Any non-`str`, non-`bytes` return value is serialised to JSON text.""" async with Client(tutorial003.mcp) as client: (content,) = (await client.read_resource("stats://catalog")).contents assert isinstance(content, TextResourceContents) @@ -107,7 +101,6 @@ async def test_dict_return_becomes_json_text() -> None: async def test_bytes_return_becomes_a_blob() -> None: - """tutorial003: a `bytes` return value arrives as `BlobResourceContents`, base64-encoded in `blob`.""" async with Client(tutorial003.mcp) as client: (content,) = (await client.read_resource("covers://placeholder")).contents assert isinstance(content, BlobResourceContents) diff --git a/tests/docs_src/test_run.py b/tests/docs_src/test_run.py index 4b9a8926ad..e0f5b28881 100644 --- a/tests/docs_src/test_run.py +++ b/tests/docs_src/test_run.py @@ -15,7 +15,6 @@ async def test_the_run_call_is_guarded_so_importing_does_not_start_a_server() -> None: - """tutorial001: `run()` sits under `__main__`, so the module imports cleanly and serves in-memory.""" async with Client(tutorial001.mcp) as client: result = await client.call_tool("search_books", {"query": "dune"}) assert result == snapshot( @@ -27,7 +26,6 @@ async def test_the_run_call_is_guarded_so_importing_does_not_start_a_server() -> async def test_the_transport_never_changes_what_the_server_is() -> None: - """tutorial001/002/003 differ only in how they run: every client sees the identical tool.""" async with ( Client(tutorial001.mcp) as stdio_client, Client(tutorial002.mcp) as http_client, @@ -39,14 +37,12 @@ async def test_the_transport_never_changes_what_the_server_is() -> None: def test_transport_options_are_not_constructor_options() -> None: - """The page's warning: `port=` belongs to `run()`; the constructor rejects it.""" options: dict[str, Any] = {"port": 3001} with pytest.raises(TypeError, match="unexpected keyword argument 'port'"): MCPServer("Bookshop", **options) def test_settings_are_constructor_arguments_and_land_on_settings() -> None: - """tutorial003: `log_level=` ends up on `mcp.settings`; the defaults are INFO and not-debug.""" assert tutorial001.mcp.settings.log_level == "INFO" assert tutorial001.mcp.settings.debug is False assert tutorial003.mcp.settings.log_level == "DEBUG" diff --git a/tests/docs_src/test_session_groups.py b/tests/docs_src/test_session_groups.py index e6fee8ce92..4b6b019dbe 100644 --- a/tests/docs_src/test_session_groups.py +++ b/tests/docs_src/test_session_groups.py @@ -1,7 +1,7 @@ -"""`docs/advanced/session-groups.md`: every claim the page makes, proved against the real SDK. +"""Prove every claim in `docs/advanced/session-groups.md` against the real SDK. -`connect_to_server` opens a real transport (a subprocess or a socket), so these tests drive the -exact same aggregation path through `connect_with_session` with in-memory sessions instead. +`connect_to_server` opens a real transport, so tests drive the same aggregation path +through `connect_with_session` with in-memory sessions instead. """ import traceback @@ -17,7 +17,6 @@ async def test_both_servers_call_their_tool_search() -> None: - """tutorial001 + tutorial002: two unrelated servers, one colliding tool name.""" async with Client(tutorial001.mcp) as library, Client(tutorial002.mcp) as web: (library_tool,) = (await library.list_tools()).tools (web_tool,) = (await web.list_tools()).tools @@ -53,7 +52,6 @@ async def test_colliding_names_are_rejected() -> None: async def test_component_name_hook_prefixes_every_name() -> None: - """tutorial004: the hook rewrites every registered name, so both servers coexist.""" async with Client(tutorial001.mcp) as library, Client(tutorial002.mcp) as web: group = ClientSessionGroup(component_name_hook=tutorial004.by_server) await group.connect_with_session(library.server_info, library.session) @@ -63,12 +61,10 @@ async def test_component_name_hook_prefixes_every_name() -> None: def test_the_hook_is_a_plain_function_of_name_and_server_info() -> None: - """tutorial004: `by_server` builds the key from `server_info.name`.""" assert tutorial004.by_server("search", Implementation(name="Web", version="1.0.0")) == "Web.search" async def test_the_key_is_prefixed_but_the_wire_name_is_not() -> None: - """tutorial004: the dict key is yours; the `Tool` inside keeps the name the server declared.""" async with Client(tutorial002.mcp) as web: group = ClientSessionGroup(component_name_hook=tutorial004.by_server) await group.connect_with_session(web.server_info, web.session) @@ -76,7 +72,6 @@ async def test_the_key_is_prefixed_but_the_wire_name_is_not() -> None: async def test_call_tool_routes_to_the_owning_server() -> None: - """tutorial004: `group.call_tool` resolves the prefixed name to the session that owns it.""" async with Client(tutorial001.mcp) as library, Client(tutorial002.mcp) as web: group = ClientSessionGroup(component_name_hook=tutorial004.by_server) await group.connect_with_session(library.server_info, library.session) @@ -88,7 +83,6 @@ async def test_call_tool_routes_to_the_owning_server() -> None: async def test_disconnect_removes_every_component_of_that_server() -> None: - """tutorial004: `disconnect_from_server` takes the session back out of all three dicts.""" async with Client(tutorial001.mcp) as library, Client(tutorial002.mcp) as web: group = ClientSessionGroup(component_name_hook=tutorial004.by_server) await group.connect_with_session(library.server_info, library.session) diff --git a/tests/docs_src/test_shape.py b/tests/docs_src/test_shape.py index 1636bd825e..22e4915c39 100644 --- a/tests/docs_src/test_shape.py +++ b/tests/docs_src/test_shape.py @@ -1,9 +1,8 @@ """Structural invariants every `docs_src/` example must satisfy. -These are deliberately string/regex checks, not an AST analyzer: each predicate -is branch-free at the call site so the suite stays compatible with the repo's -100% branch-coverage gate, and a contributor whose doc PR goes red gets a -one-line reason, not a parser traceback. +Deliberately string/regex checks, not an AST analyzer: branch-free predicates keep the +repo's 100% branch-coverage gate happy, and a failing doc PR gets a one-line reason, +not a parser traceback. """ import importlib @@ -20,7 +19,6 @@ DOCS_SRC = REPO_ROOT / "docs_src" EXAMPLE_FILES = sorted(p for p in DOCS_SRC.rglob("*.py") if p.name != "__init__.py") -"""Every example module under `docs_src/` (the `__init__.py` scaffolding is not an example).""" _PRIVATE_MCP_IMPORT = re.compile(r"^\s*(?:from|import)\s+(mcp(?:\.\w+)*\._\w+)", re.MULTILINE) """A `_`-private segment inside the imported MODULE path: `from mcp.client._memory import X`.""" @@ -29,11 +27,9 @@ """A `_`-private NAME imported from a public `mcp` module: `from mcp.client import _memory`.""" RETIRED_NAMES = ("UrlElicitationRequiredError",) -"""Public SDK names built on protocol surfaces retired by the 2026-07-28 spec. +"""Still-exported SDK names built on surfaces the 2026-07-28 spec retired. -`UrlElicitationRequiredError` is the `-32042` flow; the spec lists that code as -reserved-never-reused, so no documentation example may teach it even while the -symbol is still exported. +`-32042` is reserved-never-reused, so no example may teach the `UrlElicitationRequiredError` flow. """ _INCLUDE_DIRECTIVE = re.compile(r"(?:--8<--\s*\"|