diff --git a/docs/advanced/multi-round-trip.md b/docs/advanced/multi-round-trip.md index de11a8db8..665808a5d 100644 --- a/docs/advanced/multi-round-trip.md +++ b/docs/advanced/multi-round-trip.md @@ -19,7 +19,7 @@ That's the whole protocol. Every leg is an ordinary request from the client to t ## The server side -The high-level `@mcp.tool()` decorator has no sugar for this yet. Today you write it on the **low-level** `Server`, whose `on_call_tool` handler is allowed to return either result type: +On `@mcp.tool()` you rarely build this by hand: declare a dependency that asks the user and the SDK returns the `InputRequiredResult` for you - that form is the **[Dependencies](../tutorial/dependencies.md)** tutorial. The manual form is the **low-level** `Server`, whose `on_call_tool` handler is allowed to return either result type: ```python title="server.py" hl_lines="44-47" --8<-- "docs_src/mrtr/tutorial001.py" @@ -93,6 +93,6 @@ Drop to the underlying session, where `allow_input_required=True` hands you the * `input_requests` is what it needs. `request_state` is an opaque resume token only the server reads. * `Client` runs the retry loop for you: register `elicitation_callback` / `sampling_callback` / `list_roots_callback` and `call_tool` returns a plain `CallToolResult`. `input_required_max_rounds` (default 10) bounds it. * To inspect or persist rounds, use `client.session.call_tool(..., allow_input_required=True)` and own the `while isinstance(result, InputRequiredResult)` loop yourself. -* The server side is the **low-level** `Server` only; `@mcp.tool()` has no sugar for this yet. +* On `@mcp.tool()`, a dependency that asks the user produces this result for you (**[Dependencies](../tutorial/dependencies.md)**); the **low-level** `Server` is the manual form. This is the mechanism that replaces server-initiated sampling and the rest of the push-style back-channel; see **[Deprecated features](deprecated.md)**. 
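Reviewer sketch (not part of the patch): the retry loop that the multi-round-trip doc says `Client` runs for you can be pictured with a small, SDK-free simulation. Every name here (`InputRequired`, `Done`, `server_call`, `call_with_retries`) is a hypothetical stand-in, not the SDK's API; only the shape is taken from the doc text: questions plus an opaque resume token come back, the client answers and retries, and a round limit bounds the loop.

```python
from dataclasses import dataclass


@dataclass
class InputRequired:
    """Hypothetical stand-in for an input-required result."""
    input_requests: dict  # what the server needs, keyed by question id
    request_state: str    # opaque resume token; only the server reads it


@dataclass
class Done:
    """Hypothetical stand-in for a finished tool result."""
    text: str


def server_call(args, input_responses=None, request_state=None):
    """One ordinary request leg: either ask for input or finish."""
    if not input_responses or "confirm" not in input_responses:
        question = f"Refund order {args['order_id']}?"
        return InputRequired({"confirm": question}, request_state="round-1")
    verdict = "refunded" if input_responses["confirm"] else "kept"
    return Done(f"{args['order_id']}: {verdict}")


def call_with_retries(args, answer, max_rounds=10):
    """Retry until the server stops asking, bounded by max_rounds."""
    result = server_call(args)
    rounds = 0
    while isinstance(result, InputRequired):
        if rounds >= max_rounds:
            raise RuntimeError("input-required round limit exceeded")
        rounds += 1
        # Answer every question in this batch, then retry the same call,
        # echoing the resume token back untouched.
        responses = {key: answer(key, q) for key, q in result.input_requests.items()}
        result = server_call(args, input_responses=responses, request_state=result.request_state)
    return result, rounds
```

Calling `call_with_retries({"order_id": "ORD-1"}, lambda key, question: True)` finishes after one answered round. Owning the loop yourself, as the doc suggests for inspecting or persisting rounds, amounts to writing the `while isinstance(...)` body by hand.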
diff --git a/docs/migration.md b/docs/migration.md index 79c15d91f..fd76d8a4f 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -786,6 +786,8 @@ Positional calls (`await ctx.info("hello")`) are unaffected. `Context.elicit()` (and `elicit_with_validation()`) now render the schema first and validate each property against the spec's `PrimitiveSchemaDefinition`, raising `TypeError` at the call site for anything outside it. `Optional[T]` fields render as `{"type": ...}` with the field omitted from `required` (previously the non-spec `anyOf` shape). A bare `list[str]` field is rejected because it renders without the required enum items; use `list[Literal[...]]` or `list[str]` with `json_schema_extra` supplying the items. Unions of multiple primitives (e.g. `int | str`) and nested models are rejected. +A schema-mismatched *accepted* answer also fails differently: the call now raises `ValueError` with a stable message ("Received an accepted elicitation whose content does not match the requested schema") instead of letting pydantic's `ValidationError` escape with its internals. Code that caught `ValidationError` around `ctx.elicit()` should catch `ValueError` (or rely on the tool's error result). + ### Replace `RootModel` by union types with `TypeAdapter` validation The following union types are no longer `RootModel` subclasses: diff --git a/docs/tutorial/dependencies.md b/docs/tutorial/dependencies.md index e9b4c789b..4e91515ef 100644 --- a/docs/tutorial/dependencies.md +++ b/docs/tutorial/dependencies.md @@ -116,11 +116,26 @@ And if the user won't answer at all - declines the question, or cancels it? That's the right default for a precondition: no answer, no order. When declining is an outcome your tool wants to handle - skip the backorder but still suggest another title - annotate `ElicitationResult[Backorder]` instead and the tool receives the full accept/decline/cancel outcome to branch on. 
**[Elicitation](elicitation.md)** shows that form, and everything else about asking: the schema rules, the three answers, the client's side of the conversation. +!!! info + The framework picks the question's transport from the negotiated protocol version; the code + above is identical on both. On **2026-07-28** and later the question rides inside a + multi-round-trip `tools/call` - the server returns it, the client's `elicitation_callback` + answers it, and the `Client` retries the call for you (**[Multi-round-trip requests](../advanced/multi-round-trip.md)**). On + **2025-11-25** and earlier it is a synchronous elicitation request mid-call. Each question is + asked exactly once per call - a guarantee about the question, not the resolver. In the + multi-round-trip form an eliciting resolver runs again to consume its answer, so code before + its `return Elicit(...)` runs on the asking round and again on the answering one; a resolver + that answered *without* asking, like `check_stock`, may run again whenever the call resumes + after a question. When it resumes, each answer is matched back to its question, so an + eliciting resolver must derive its question deterministically from the tool's arguments and + earlier answers - a per-call generated value (a `default_factory` id, a timestamp) is + re-derived on each round and must not appear in a question the answer is meant to bind to. + ## Recap * `Annotated[T, Resolve(fn)]` on a tool parameter: the SDK runs `fn` and injects its return value. * A resolved parameter is invisible to the model and cannot be supplied by a client. Values the model must not invent - prices, identities, permissions - belong here. -* A resolver's parameters are resolved the same way: the `Context`, another `Resolve(...)`, or a tool argument by name. The graph runs each resolver at most once per call. +* A resolver's parameters are resolved the same way: the `Context`, another `Resolve(...)`, or a tool argument by name. 
The graph runs each resolver at most once per round, however many consumers it has; each question is asked exactly once, an eliciting resolver runs again to consume its answer, and a resolver that never asked may run again when a call resumes. * Bad graphs fail at registration with `InvalidSignature`, not mid-call. * Return `Elicit(message, Model)` to ask the user, only when you have to. Unwrapped annotations abort on decline; `ElicitationResult[T]` lets the tool branch. diff --git a/docs/tutorial/elicitation.md b/docs/tutorial/elicitation.md index aa4f16820..7bd27a78a 100644 --- a/docs/tutorial/elicitation.md +++ b/docs/tutorial/elicitation.md @@ -76,8 +76,8 @@ A refusal is not an error. The tool decides what declining means (here, no booki !!! tip The answer is validated against your model before your code sees it. A client that sends - `"maybe"` for a `bool` doesn't corrupt your booking: the call fails with the - `ValidationError`, your `if` never runs. + `"maybe"` for a `bool` doesn't corrupt your booking: the call fails with a + schema-mismatch error, your `if` never runs. ## Ask before the tool runs diff --git a/examples/stories/legacy_elicitation/README.md b/examples/stories/legacy_elicitation/README.md index 1a9d48e60..e9812aced 100644 --- a/examples/stories/legacy_elicitation/README.md +++ b/examples/stories/legacy_elicitation/README.md @@ -68,6 +68,6 @@ uv run python -m stories.legacy_elicitation.client --http --legacy --server serv ## See also `sampling/` (same push-request shape, deprecated per SEP-2577), `mrtr/` -(planned — the 2026-era carrier), `error_handling/` +(the 2026-era carrier), `error_handling/` (`UrlElicitationRequiredError`), `refund_desk/` (resolver DI rides this push -mechanism today). +mechanism on handshake-era connections). 
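Reviewer sketch (not part of the patch): the re-run rule in the dependencies tutorial ("each question is asked exactly once, an eliciting resolver runs again to consume its answer, and a resolver that never asked may run again when a call resumes") is easier to see with counters. This is a simplified model under stated assumptions, not the SDK's resolver machinery: `run_round` and `call_tool` are hypothetical, the resolver names are illustrative, and answers are keyed by the question text purely to make the determinism requirement visible.

```python
run_counts = {"check_stock": 0, "ask_backorder": 0}


def check_stock(title):
    """A resolver that answers without asking; it may re-run each round."""
    run_counts["check_stock"] += 1
    return 0  # illustrative: always out of stock


def ask_backorder(title, answers):
    """An eliciting resolver: runs once to ask, once more to consume the answer."""
    run_counts["ask_backorder"] += 1
    # The question is derived only from the tool argument, so it is identical
    # on the asking round and on the answering round.
    question = f"'{title}' is out of stock. Backorder it?"
    if question not in answers:
        return ("ask", question)
    return ("answer", answers[question])


def run_round(title, answers):
    """One round of the call: every resolver starts over from scratch."""
    if check_stock(title) > 0:
        return ("done", "in stock")
    kind, value = ask_backorder(title, answers)
    if kind == "ask":
        return ("input_required", value)
    return ("done", "backordered" if value else "skipped")


def call_tool(title, user_says):
    """Drive rounds until done, recording each question the user saw."""
    answers, asked = {}, []
    outcome = run_round(title, answers)
    while outcome[0] == "input_required":
        asked.append(outcome[1])
        answers[outcome[1]] = user_says
        outcome = run_round(title, answers)
    return outcome[1], asked
```

With `call_tool("Dune", True)` the user sees one question while both resolvers run twice. If `ask_backorder` mixed a timestamp or a freshly generated id into `question`, the answering round would build a different string, the stored answer would never match, and the call would ask again: the failure the tutorial's note warns about.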
diff --git a/examples/stories/manifest.toml b/examples/stories/manifest.toml index 57ec0e8a4..1ba2fe862 100644 --- a/examples/stories/manifest.toml +++ b/examples/stories/manifest.toml @@ -40,9 +40,8 @@ era = "legacy" status = "legacy" [story.refund_desk] -# Resolver DI rides push elicitation (ctx.elicit) today; era flips to "dual" once -# the SDK carries resolver elicitation over the 2026 input_required round-trip. -era = "legacy" +# Resolver elicitation picks its transport per era: input_required round-trips on +# the modern leg, push elicitation (ctx.elicit) on the legacy one. lowlevel = false [story.sampling] diff --git a/examples/stories/mrtr/README.md b/examples/stories/mrtr/README.md index de214988d..aaad86ca9 100644 --- a/examples/stories/mrtr/README.md +++ b/examples/stories/mrtr/README.md @@ -46,7 +46,7 @@ uv run python -m stories.mrtr.client --http --server server_lowlevel ## Spec -[Multi-round results — server features](https://modelcontextprotocol.io/specification/draft/server/tools#multi-round-results) +[Input required tool results — server features](https://modelcontextprotocol.io/specification/draft/server/tools#input-required-tool-results) ## See also diff --git a/examples/stories/refund_desk/README.md b/examples/stories/refund_desk/README.md index 0a77dd580..153504041 100644 --- a/examples/stories/refund_desk/README.md +++ b/examples/stories/refund_desk/README.md @@ -7,9 +7,10 @@ reason)` refunds what the order record says — `cents` is resolver-computed and does not appear in the input schema at all, so the model cannot supply or inflate the amount. Resolvers form a DAG (`load_order` → `refund_scope` → `refund_amount` / `ask_restock`), may return `Elicit[...]` to ask the human, -and run at most once per call. A resolver's own plain parameters are filled -from the tool's arguments by name — `load_order(order_id)` receives the -`order_id` the model passed to `refund_order`. +and ask each question at most once per call. 
A resolver's own plain +parameters are filled from the tool's arguments by name — +`load_order(order_id)` receives the `order_id` the model passed to +`refund_order`. ## Run it @@ -18,9 +19,9 @@ from the tool's arguments by name — `load_order(order_id)` receives the uv run python -m stories.refund_desk.client # HTTP — the client self-hosts the server on a free port, runs, then tears it -# down (--legacy: resolver elicitation rides the push request today; the -# manifest pins this era, so bare --http runs the same leg) -uv run python -m stories.refund_desk.client --http --legacy +# down (2026 protocol: the questions ride embedded input_required round-trips; +# add --legacy to ride synchronous push elicitation instead) +uv run python -m stories.refund_desk.client --http ``` ## What to look at @@ -47,21 +48,38 @@ uv run python -m stories.refund_desk.client --http --legacy ## Caveats +- **Transport per era.** The framework picks the elicitation transport from + the negotiated protocol: at >= 2026-07-28 the questions ride embedded + `input_required` round-trips (a resolver that depends on another's answer is + asked in a later round); at <= 2025-11-25 each is a synchronous + `elicitation/create` push request mid-call. Author code is identical on + both — this client runs unchanged on either era. - **Decline order.** A declined unwrapped dependency aborts resolution in tool-signature order — `cents` resolves before `restock`, so `ask_restock` never runs. Don't rely on a later resolver's side effects after an earlier consumer can abort. -- **Memoization scope.** Each resolver runs at most once per `tools/call`, - keyed by function identity; nothing is cached across calls or connections. +- **Memoization scope.** Each question is asked at most once per call, and + within a round each resolver runs at most once, keyed by function identity. 
+ Across 2026 rounds only *elicited* outcomes persist (in `requestState`); a + resolver that resolves without eliciting is pure and may re-run each round. + An eliciting resolver's body runs again too — once to ask, once more to + consume its answer. + An answer is matched back to its question when the call resumes, so an + eliciting resolver must derive its question deterministically from the + tool's arguments and earlier answers; a per-call generated value (a + `default_factory` id, a timestamp) is re-derived each round and must not + appear in a question the answer is meant to bind to. Nothing is cached + across calls or connections. - **Validate elicited values.** Elicited answers are human-typed; check them against your records (as `_scoped` does) before acting on them. ## Spec -[Elicitation — client features](https://modelcontextprotocol.io/specification/2025-11-25/client/elicitation) +[Elicitation — client features](https://modelcontextprotocol.io/specification/2025-11-25/client/elicitation), +[Input required tool results — server features](https://modelcontextprotocol.io/specification/draft/server/tools#input-required-tool-results) ## See also -`legacy_elicitation/` (the push mechanism resolver elicitation rides on today), -`mrtr/` (the 2026 `input_required` carrier; resolver DI will ride it once the -SDK wires them together). +`mrtr/` (the 2026 `input_required` carrier these questions ride at +>= 2026-07-28), `legacy_elicitation/` (the push mechanism they ride on +handshake-era connections). diff --git a/examples/stories/refund_desk/client.py b/examples/stories/refund_desk/client.py index ee86d94b4..0ff8d28fc 100644 --- a/examples/stories/refund_desk/client.py +++ b/examples/stories/refund_desk/client.py @@ -41,7 +41,9 @@ async def on_elicit(context: ClientRequestContext, params: types.ElicitRequestPa assert counts == {"scope": 0, "restock": 0}, counts # Full refund of a three-line order. 
The scope question fires exactly ONCE even though - # both refund_amount and ask_restock consume it — memoized within the call. + # both refund_amount and ask_restock consume it — asked at most once per call on either + # era. ask_restock needs the scope ANSWER, so at 2026 the two questions land in + # successive rounds, never one concurrent batch: counts and order are era-independent. receipt = await client.call_tool("refund_order", {"order_id": "ORD-7002", "reason": "arrived broken"}) assert receipt.structured_content == { "order_id": "ORD-7002", @@ -53,7 +55,7 @@ async def on_elicit(context: ClientRequestContext, params: types.ElicitRequestPa # Declining restock still refunds: the tool keeps the ElicitationResult union for # `restock`, sees the decline, and just skips the restock. The scope counter moves - # again — the memo cache is per tools/call, not per connection. + # again — questions are deduped per call, not per connection. declines.add("restock") answers["scope"] = {"full": False, "sku": "canvas-tote"} receipt = await client.call_tool("refund_order", {"order_id": "ORD-7002", "reason": "wrong colour"}) diff --git a/src/mcp/server/elicitation.py b/src/mcp/server/elicitation.py index c6faf0065..5a4acdd6c 100644 --- a/src/mcp/server/elicitation.py +++ b/src/mcp/server/elicitation.py @@ -87,6 +87,18 @@ def _validate_rendered_properties(json_schema: dict[str, Any]) -> None: ) from None +def render_elicitation_schema(schema: type[BaseModel]) -> dict[str, Any]: + """Render a model as the spec-valid `requested_schema` for an elicitation. + + Raises: + TypeError: If a field renders as something the spec's + `PrimitiveSchemaDefinition` does not accept. 
+ """ + json_schema = schema.model_json_schema(schema_generator=_ElicitationJsonSchema) + _validate_rendered_properties(json_schema) + return json_schema + + async def elicit_with_validation( session: ServerSession, message: str, @@ -102,9 +114,12 @@ async def elicit_with_validation( the user or automatically generating a response. For sensitive data like credentials or OAuth flows, use elicit_url() instead. + + Raises: + ValueError: If the client accepted the elicitation without supplying + content, or with content that does not match the requested schema. """ - json_schema = schema.model_json_schema(schema_generator=_ElicitationJsonSchema) - _validate_rendered_properties(json_schema) + json_schema = render_elicitation_schema(schema) result = await session.elicit_form( message=message, @@ -112,17 +127,19 @@ async def elicit_with_validation( related_request_id=related_request_id, ) - if result.action == "accept" and result.content is not None: - # Validate and parse the content using the schema - validated_data = schema.model_validate(result.content) + if result.action == "accept": + if result.content is None: + raise ValueError("Received an accepted elicitation with no content") + try: + validated_data = schema.model_validate(result.content) + except ValidationError as e: + raise ValueError( + "Received an accepted elicitation whose content does not match the requested schema" + ) from e return AcceptedElicitation(data=validated_data) - elif result.action == "decline": + if result.action == "decline": return DeclinedElicitation() - elif result.action == "cancel": - return CancelledElicitation() - else: # pragma: no cover - # This should never happen, but handle it just in case - raise ValueError(f"Unexpected elicitation action: {result.action}") + return CancelledElicitation() async def elicit_url( diff --git a/src/mcp/server/mcpserver/context.py b/src/mcp/server/mcpserver/context.py index 4d494db6e..82a6fa2b6 100644 --- a/src/mcp/server/mcpserver/context.py +++ 
b/src/mcp/server/mcpserver/context.py @@ -232,6 +232,11 @@ def request_id(self) -> str: """Get the unique ID for this request.""" return str(self.request_context.request_id) + @property + def protocol_version(self) -> str | None: + """The negotiated protocol version, or `None` outside of an active request.""" + return self._request_context.protocol_version if self._request_context is not None else None + @property def input_responses(self) -> InputResponses | None: """Client responses to a prior `InputRequiredResult.input_requests`. diff --git a/src/mcp/server/mcpserver/resolve.py b/src/mcp/server/mcpserver/resolve.py index 89843a716..323ce5cdd 100644 --- a/src/mcp/server/mcpserver/resolve.py +++ b/src/mcp/server/mcpserver/resolve.py @@ -7,25 +7,47 @@ `Elicit[T]` to ask the client; the framework runs the elicitation and injects the answer. +The framework picks the elicitation transport from the negotiated protocol. At +>= 2026-07-28 it returns an `InputRequiredResult` carrying the batched questions +and resumes when the client retries with `input_responses`/`request_state` +(independent resolvers are asked in one round; a resolver depending on another's +answer is asked in a later round). At <= 2025-11-25 it issues a synchronous +`elicitation/create` request mid-call. Only *elicited* outcomes are carried in +`request_state` across rounds (so the user is asked each question once); a +resolver that returns a value without eliciting is pure and may re-run each round. + Whether the consumer receives the unwrapped model or the full `ElicitationResult` union is decided by the consumer's annotation: - `Annotated[T, Resolve(fn)]` -> unwrapped `T`; decline/cancel aborts the call. - `Annotated[ElicitationResult[T], Resolve(fn)]` (or a specific member) -> the full outcome; the consumer branches on accept/decline/cancel. - -Each resolver runs at most once per `tools/call` (memoized by function identity). 
""" from __future__ import annotations import inspect +import types import typing from collections.abc import Callable, Hashable, Mapping -from typing import Annotated, Any, Generic, cast, get_args, get_origin +from typing import Annotated, Any, Generic, Literal, TypeGuard, get_args, get_origin import anyio.to_thread -from pydantic import BaseModel +from mcp_types import ( + MISSING_REQUIRED_CLIENT_CAPABILITY, + ClientCapabilities, + ElicitationCapability, + ElicitRequest, + ElicitRequestFormParams, + ElicitResult, + FormElicitationCapability, + InputRequests, + InputRequiredResult, + InputResponses, + MissingRequiredClientCapabilityErrorData, +) +from mcp_types.version import is_version_at_least +from pydantic import BaseModel, ValidationError from typing_extensions import TypeVar from mcp.server.elicitation import ( @@ -33,16 +55,24 @@ CancelledElicitation, DeclinedElicitation, ElicitationResult, + render_elicitation_schema, ) from mcp.server.mcpserver.context import Context from mcp.server.mcpserver.exceptions import InvalidSignature, ToolError from mcp.shared._callable_inspection import is_async_callable +from mcp.shared.exceptions import MCPError T = TypeVar("T", bound=BaseModel) # The union members the framework injects when a consumer opts into the outcome. _ELICITATION_RESULT_MEMBERS = (AcceptedElicitation, DeclinedElicitation, CancelledElicitation) +# First protocol revision whose `tools/call` carries elicitation inside +# `InputRequiredResult` rather than as a standalone server-to-client request. +# Pinned (not `LATEST_MODERN_VERSION`, which moves when newer revisions are added). 
+_INPUT_REQUIRED_VERSION = "2026-07-28" +_STATE_VERSION = 1 + class Resolve: """Marker for `Annotated[T, Resolve(fn)]`: fill the parameter by running `fn`.""" @@ -79,10 +109,24 @@ def __init__(self, kind: str, resolve: Resolve | None = None, wants_union: bool class _ResolverPlan: """A resolver's parameters and whether it is async, analyzed once.""" - def __init__(self, fn: Callable[..., Any], params: dict[str, _ParamPlan], is_async: bool) -> None: + def __init__( + self, + fn: Callable[..., Any], + params: dict[str, _ParamPlan], + is_async: bool, + elicit_schema: type[BaseModel] | None, + wire_key: str, + ) -> None: self.fn = fn self.params = params self.is_async = is_async + # The `T` from the resolver's `Elicit[T]` return arm, if annotated. Used to + # re-validate an outcome restored from `request_state` into a model. + self.elicit_schema = elicit_schema + # Deterministic, collision-free key for this resolver's elicitation on the + # wire (`input_requests`/`request_state`). Assigned at registration so it is + # stable across rounds even when `module:qualname` collides (closures). + self.wire_key = wire_key def _type_hints(fn: Callable[..., Any]) -> dict[str, Any]: @@ -139,6 +183,36 @@ def _contains_resolve(annotation: Any) -> bool: return any(_contains_resolve(arg) for arg in get_args(annotation)) +def _elicit_return_schema(return_annotation: Any, name: str) -> type[BaseModel] | None: + """Extract `T` from a resolver return type's `Elicit[T]` arm, if present. + + Handles a bare `-> Elicit[T]` and a `-> T | Elicit[T]` union. Lets an elicited + outcome restored from `request_state` (a plain dict) be re-validated into its + model so dependent resolvers and tools receive a typed value. + + Raises: + InvalidSignature: If the annotation has more than one `Elicit[...]` arm; + the runtime can honor only one static question schema per resolver. + """ + # A bare `Elicit[T]` is itself a candidate; a union contributes its members. 
+ candidates = get_args(return_annotation) if _is_union(return_annotation) else (return_annotation,) + # Typing dedupes equal union members, so two arms here are genuinely distinct. + arms = [c for c in candidates if get_origin(c) is Elicit] + if len(arms) > 1: + raise InvalidSignature( + f"Resolver {name!r} return annotation has multiple Elicit arms; " + "a resolver asks one question - split it into separate resolvers" + ) + if not arms: + return None + schema = get_args(arms[0])[0] + return schema if isinstance(schema, type) and issubclass(schema, BaseModel) else None + + +def _is_union(annotation: Any) -> bool: + return get_origin(annotation) in (typing.Union, types.UnionType) + + def _wants_union(type_arg: Any) -> bool: """True when `type_arg` is an `ElicitationResult` member (or a union of them). @@ -187,6 +261,9 @@ def build_resolver_plans( or a tool argument by name). """ plans: dict[Hashable, _ResolverPlan] = {} + # Count how many distinct resolvers share each `module:qualname` base so closures + # from one factory get distinct, deterministic wire keys (`base`, `base#1`, ...). 
+ base_counts: dict[str, int] = {} def analyze(fn: Callable[..., Any], stack: tuple[Hashable, ...]) -> None: key = _resolver_key(fn) @@ -195,6 +272,11 @@ def analyze(fn: Callable[..., Any], stack: tuple[Hashable, ...]) -> None: if key in plans: return + base = _state_key(fn) + seen = base_counts.get(base, 0) + base_counts[base] = seen + 1 + wire_key = base if seen == 0 else f"{base}#{seen}" + hints = _type_hints(fn) sig = inspect.signature(fn) params: dict[str, _ParamPlan] = {} @@ -217,7 +299,8 @@ def analyze(fn: Callable[..., Any], stack: tuple[Hashable, ...]) -> None: "expected a Context, an Annotated[_, Resolve(...)], or a tool argument by name" ) - plans[key] = _ResolverPlan(fn, params, is_async_callable(fn)) + elicit_schema = _elicit_return_schema(hints.get("return"), _resolver_name(fn)) + plans[key] = _ResolverPlan(fn, params, is_async_callable(fn), elicit_schema, wire_key) for dep in nested: analyze(dep, stack + (key,)) @@ -241,76 +324,337 @@ def _is_context_annotation(annotation: Any) -> bool: return any(isinstance(c, type) and issubclass(c, Context) for c in candidates) +class _Pending(Exception): + """Internal: a resolver needs client input not yet available this round.""" + + +class _Resolution: + """Per-`tools/call` resolution state, shared across the DAG walk. + + `input_required` selects the transport: at >= 2026-07-28 elicitations are + batched into `pending` and surfaced as an `InputRequiredResult`; at older + revisions each `Elicit` is answered synchronously via `ctx.elicit`. 
+ """ + + def __init__( + self, + plans: Mapping[Hashable, _ResolverPlan], + tool_args: Mapping[str, Any], + context: Context[Any, Any], + input_required: bool, + ) -> None: + self.plans = plans + self.tool_args = tool_args + self.context = context + self.input_required = input_required + self.answers: InputResponses = context.input_responses or {} if input_required else {} + self.state = _decode_state(context.request_state) if input_required else {} + # In-call dedup keyed by resolver identity (distinguishes two instances of + # the same bound method); `persist` holds the wire-shaped record of each + # elicited outcome, keyed by its wire key - exactly what the next round's + # `request_state` carries. Entries are the client's own (validated) wire + # data, never re-derived from a model, so encode-restore is the identity. + # Pure resolvers are cheap to re-run each round and are not persisted. + self.cache: dict[Hashable, ElicitationResult[Any]] = {} + self.persist: dict[str, _StateEntry] = {} + self.pending: InputRequests = {} + + +def _state_key(fn: Callable[..., Any]) -> str: + """Worker-stable base wire key for a resolver, derived only from registration data. + + `input_requests`/`request_state` must round-trip through the client and resume on + any worker (stateless HTTP), so the key carries no `id(...)`: it is the resolver's + `module:qualname` (a callable object uses its type's). Distinct resolvers that + share this base - two instances of one method, two closures from one factory - are + disambiguated deterministically by `build_resolver_plans` (`base`, `base#1`, ...). 
+ """ + qualname = getattr(fn, "__qualname__", None) or type(fn).__qualname__ + module = getattr(fn, "__module__", None) or type(fn).__module__ + return f"{module}:{qualname}" + + async def resolve_arguments( resolved_params: Mapping[str, tuple[Resolve, bool]], plans: Mapping[Hashable, _ResolverPlan], tool_args: Mapping[str, Any], context: Context[Any, Any], -) -> dict[str, Any]: +) -> dict[str, Any] | InputRequiredResult: """Resolve every `Resolve`-marked tool parameter into a concrete value. - Each resolver runs at most once (memoized by function identity). Returns a - mapping of tool parameter name to the value to inject. + Returns the mapping of tool parameter name to injected value when every + resolver is satisfied. When a resolver still needs client input (and the + negotiated protocol is >= 2026-07-28), returns an `InputRequiredResult` + carrying the batched questions instead; the tool body is not run. + + An eliciting resolver asks its question once - its answer is carried in + `request_state` across rounds - while a resolver that resolves without + eliciting is pure and may re-run on each round. Raises: ToolError: If an elicited value is declined or cancelled and the consumer asked for the unwrapped model (rather than the result union). """ - cache: dict[Hashable, ElicitationResult[Any]] = {} + # `ctx.protocol_version` is `None` outside an active request: `MCPServer.call_tool()` + # called directly builds such a `Context`, and a tool whose resolvers never elicit + # must still work there. A missing version means the synchronous (non-input_required) + # transport, which never reaches a server-to-client request anyway. 
+ res = _Resolution(plans, tool_args, context, _uses_input_required(context.protocol_version)) injected: dict[str, Any] = {} for name, (marker, wants_union) in resolved_params.items(): - outcome = await _resolve(marker.fn, plans, tool_args, context, cache) + try: + outcome = await _resolve(marker.fn, res) + except _Pending: + continue injected[name] = outcome if wants_union else _unwrap(outcome, name) + + if res.pending: + return InputRequiredResult(input_requests=res.pending, request_state=_encode_state(res.persist)) return injected -async def _resolve( - fn: Callable[..., Any], - plans: Mapping[Hashable, _ResolverPlan], - tool_args: Mapping[str, Any], - context: Context[Any, Any], - cache: dict[Hashable, ElicitationResult[Any]], -) -> ElicitationResult[Any]: - key = _resolver_key(fn) - if key in cache: - return cache[key] +async def _resolve(fn: Callable[..., Any], res: _Resolution) -> ElicitationResult[Any]: + """Resolve one resolver, deduped within the call by its resolver identity. + + Raises `_Pending` when the resolver (or one of its dependencies) needs client + input that has not arrived yet. + """ + cache_key = _resolver_key(fn) + if cache_key in res.cache: + return res.cache[cache_key] + + plan = res.plans[cache_key] + wire_key = plan.wire_key + if wire_key in res.pending: + # Already asked this round by another consumer; don't run the resolver again. + raise _Pending + # Restore a prior round's outcome directly only when its model is known from the + # `Elicit[T]` return arm. Without that (a resolver that elicits but isn't annotated + # `-> ... Elicit[T]`), fall through and re-run the resolver so `_elicit` can + # re-validate the stored answer against the live `Elicit.schema`. 
+ if wire_key in res.state and (plan.elicit_schema is not None or res.state[wire_key].action != "accept"): + outcome = _restore_outcome(res, wire_key, plan.elicit_schema) + if outcome is not None: + res.cache[cache_key] = outcome + return outcome - plan = plans[key] kwargs: dict[str, Any] = {} + dep_pending = False for param_name, param_plan in plan.params.items(): if param_plan.kind == "context": - kwargs[param_name] = context + kwargs[param_name] = res.context elif param_plan.kind == "by_name": - kwargs[param_name] = tool_args[param_name] + kwargs[param_name] = res.tool_args[param_name] else: assert param_plan.resolve is not None - dep_outcome = await _resolve(param_plan.resolve.fn, plans, tool_args, context, cache) + try: + # Visit every dependency so independent ones that need input are all + # collected into `res.pending` and batched into a single round. + dep_outcome = await _resolve(param_plan.resolve.fn, res) + except _Pending: + dep_pending = True + continue kwargs[param_name] = dep_outcome if param_plan.wants_union else _unwrap(dep_outcome, param_name) + if dep_pending: + raise _Pending + result: Any if plan.is_async: result = await fn(**kwargs) else: result = await anyio.to_thread.run_sync(lambda: fn(**kwargs)) - outcome: ElicitationResult[Any] - if isinstance(result, Elicit): - elicit = cast("Elicit[BaseModel]", result) - outcome = await context.elicit(elicit.message, elicit.schema) + if _is_elicit(result): + outcome = await _elicit(result, wire_key, res) else: - # A resolver may return any type (not just `BaseModel`); `model_construct` - # wraps it as an accepted result without validating against the schema bound. - outcome = cast("AcceptedElicitation[Any]", AcceptedElicitation.model_construct(data=result)) + # A resolver may return any type (not just `BaseModel`), so accept it as the + # outcome without validating against the schema bound. Plain outcomes are not + # persisted in `request_state`; the resolver re-runs next round instead. 
+        outcome = _accepted(result)
 
-    cache[key] = outcome
+    res.cache[cache_key] = outcome
     return outcome
 
 
+async def _elicit(elicit: Elicit[Any], key: str, res: _Resolution) -> ElicitationResult[Any]:
+    """Turn a resolver's `Elicit` into an outcome via the negotiated transport."""
+    if not res.input_required:
+        return await res.context.elicit(elicit.message, elicit.schema)
+
+    # Answered in a prior round (restored without a known schema, e.g. an unannotated
+    # resolver): re-validate the stored entry against the live `Elicit.schema`. A
+    # recorded outcome wins over a re-sent answer; an invalid entry self-deletes and
+    # falls through to the fresh answer (or to re-asking).
+    outcome = _restore_outcome(res, key, elicit.schema)
+    if outcome is not None:
+        return outcome
+
+    answer = res.answers.get(key)
+    if answer is None:
+        _require_form_elicitation(res.context, key)
+        res.pending[key] = _elicit_request(elicit)
+        raise _Pending
+    if not isinstance(answer, ElicitResult):
+        raise ToolError(f"Resolver {key!r} received a non-elicitation response")
+    if answer.action == "accept":
+        if answer.content is None:
+            raise ToolError(f"Resolver {key!r} received an accepted elicitation with no content")
+        try:
+            data = elicit.schema.model_validate(answer.content)
+        except ValidationError as e:
+            raise ToolError(
+                f"Resolver {key!r} received an accepted elicitation whose content does not match the requested schema"
+            ) from e
+        # Persist the exact wire content that just passed validation - never the
+        # model - so restoring next round revalidates the same bytes the client sent.
+        res.persist[key] = _StateEntry(action="accept", data=answer.content)
+        return AcceptedElicitation(data=data)
+    if answer.action == "decline":
+        res.persist[key] = _StateEntry(action="decline")
+        return DeclinedElicitation()
+    res.persist[key] = _StateEntry(action="cancel")
+    return CancelledElicitation()
+
+
 def _unwrap(outcome: ElicitationResult[Any], name: str) -> Any:
     if isinstance(outcome, AcceptedElicitation):
         return outcome.data
     raise ToolError(f"Resolver for parameter {name!r} could not resolve: elicitation was {outcome.action}")
 
 
+def _is_elicit(value: Any) -> TypeGuard[Elicit[Any]]:
+    """Runtime narrow of a resolver's return value to a (parameter-erased) `Elicit`."""
+    return isinstance(value, Elicit)
+
+
+def _accepted(data: Any) -> AcceptedElicitation[Any]:
+    """Wrap a resolved value as an accepted outcome without schema validation.
+
+    A resolver may return any type (the schema bound only constrains `Elicit[T]`),
+    and a value restored from `request_state` is already validated.
+    """
+    return AcceptedElicitation[Any].model_construct(data=data)
+
+
+def _uses_input_required(protocol_version: str | None) -> bool:
+    """True when this request must elicit via `InputRequiredResult` (>= 2026-07-28).
+
+    Older revisions still carry a standalone `elicitation/create` server-to-client
+    request, so the framework keeps the synchronous `ctx.elicit()` path for them.
+    """
+    return protocol_version is not None and is_version_at_least(protocol_version, _INPUT_REQUIRED_VERSION)
+
+
+def _require_form_elicitation(context: Context[Any, Any], key: str) -> None:
+    """Assert the client declared form elicitation before queueing a question for it.
+
+    The spec forbids sending an `input_requests` entry the client has not declared a
+    capability for. A bare `elicitation: {}` declaration (the only shape before modes
+    existed) counts as form support; an explicit url-only declaration does not.
+
+    Raises:
+        MCPError: With code `MISSING_REQUIRED_CLIENT_CAPABILITY` and a
+            `requiredCapabilities` payload when form elicitation is not declared.
+    """
+    capabilities = context.client_capabilities
+    elicitation = capabilities.elicitation if capabilities is not None else None
+    if elicitation is not None and (elicitation.form is not None or elicitation.url is None):
+        return
+    data = MissingRequiredClientCapabilityErrorData(
+        required_capabilities=ClientCapabilities(elicitation=ElicitationCapability(form=FormElicitationCapability()))
+    )
+    raise MCPError(
+        code=MISSING_REQUIRED_CLIENT_CAPABILITY,
+        message=f"Client did not declare the form elicitation capability required by resolver {key!r}",
+        data=data.model_dump(by_alias=True, mode="json", exclude_none=True),
+    )
+
+
+def _elicit_request(elicit: Elicit[Any]) -> ElicitRequest:
+    """Render an `Elicit[T]` as the embedded `elicitation/create` request for `input_requests`."""
+    json_schema = render_elicitation_schema(elicit.schema)
+    return ElicitRequest(params=ElicitRequestFormParams(message=elicit.message, requested_schema=json_schema))
+
+
+class _StateEntry(BaseModel):
+    """One resolver's recorded outcome inside `request_state`."""
+
+    action: Literal["accept", "decline", "cancel"]
+    data: Any = None
+
+
+class _State(BaseModel):
+    """The decoded `request_state`: resolver outcomes from earlier rounds."""
+
+    v: int
+    outcomes: dict[str, _StateEntry] = {}
+
+
+def _decode_state(request_state: str | None) -> dict[str, _StateEntry]:
+    """Decode the per-call resolution progress from `request_state`.
+
+    `request_state` is client-trusted (integrity sealing is a follow-up); validate
+    it through `_State` and treat anything malformed as "no progress yet".
+    """
+    if not request_state:
+        return {}
+    try:
+        state = _State.model_validate_json(request_state)
+    except ValidationError:
+        return {}
+    return state.outcomes if state.v == _STATE_VERSION else {}
+
+
+def _encode_state(outcomes: Mapping[str, _StateEntry]) -> str:
+    """Encode recorded elicitation outcomes (keyed by wire key) for the next round.
+
+    Entries already hold the client's wire-shaped data exactly as it was sent (and
+    validated), so encoding is pure wrapping: encode-restore is the identity.
+    """
+    return _State(v=_STATE_VERSION, outcomes=dict(outcomes)).model_dump_json()
+
+
+def _outcome_from_state(entry: _StateEntry, schema: type[BaseModel] | None) -> ElicitationResult[Any]:
+    """Rebuild an `ElicitationResult` from a decoded `request_state` entry.
+
+    Raises:
+        ValidationError: If `schema` is known and the entry's data does not
+            validate against it.
+    """
+    if entry.action == "decline":
+        return DeclinedElicitation()
+    if entry.action == "cancel":
+        return CancelledElicitation()
+    data = entry.data
+    if schema is not None:
+        data = schema.model_validate(data)
+    return _accepted(data)
+
+
+def _restore_outcome(res: _Resolution, key: str, schema: type[BaseModel] | None) -> ElicitationResult[Any] | None:
+    """Restore `key`'s recorded outcome from a prior round, or `None` when absent.
+
+    `request_state` is client-trusted, so an entry whose data fails validation gets
+    the `_decode_state` treatment - dropped as if no progress was recorded, so the
+    question is asked again - rather than surfacing a validation error.
+
+    Carries the original decoded entry forward unchanged in `res.persist`: if a
+    later resolver is still pending, the next round's `request_state` is built from
+    `res.persist`, so an earlier answer must stay there - byte-identical, never
+    re-derived - or it would be dropped and re-asked.
+    """
+    entry = res.state.get(key)
+    if entry is None:
+        return None
+    try:
+        outcome = _outcome_from_state(entry, schema)
+    except ValidationError:
+        del res.state[key]
+        return None
+    res.persist[key] = entry
+    return outcome
+
+
 __all__ = [
     "Resolve",
     "Elicit",
diff --git a/src/mcp/server/mcpserver/tools/base.py b/src/mcp/server/mcpserver/tools/base.py
index 6aab3c777..50d28f574 100644
--- a/src/mcp/server/mcpserver/tools/base.py
+++ b/src/mcp/server/mcpserver/tools/base.py
@@ -4,7 +4,7 @@
 from functools import cached_property
 from typing import TYPE_CHECKING, Any
 
-from mcp_types import Icon, ToolAnnotations
+from mcp_types import Icon, InputRequiredResult, ToolAnnotations
 from pydantic import BaseModel, Field
 
 from mcp.server.mcpserver.exceptions import ToolError
@@ -135,9 +135,12 @@ async def run(
             pre_validated: dict[str, Any] | None = None
             if self.resolved_params:
                 pre_validated = self.fn_metadata.validate_arguments(arguments)
-                pass_directly |= await resolve_arguments(
-                    self.resolved_params, self.resolver_plans, pre_validated, context
-                )
+                resolved = await resolve_arguments(self.resolved_params, self.resolver_plans, pre_validated, context)
+                if isinstance(resolved, InputRequiredResult):
+                    # A resolver still needs client input (>= 2026-07-28): surface the
+                    # batched questions instead of running the tool body this round.
+ return self.fn_metadata.convert_result(resolved) if convert_result else resolved + pass_directly |= resolved result = await self.fn_metadata.call_fn_with_arg_validation( self.fn, diff --git a/tests/docs_src/test_dependencies.py b/tests/docs_src/test_dependencies.py index 73355a892..06d893585 100644 --- a/tests/docs_src/test_dependencies.py +++ b/tests/docs_src/test_dependencies.py @@ -1,5 +1,7 @@ """`docs/tutorial/dependencies.md`: every claim the page makes, proved against the real SDK.""" +from typing import Literal + import pytest from inline_snapshot import snapshot from mcp_types import ElicitRequestParams, ElicitResult, TextContent @@ -79,18 +81,24 @@ def get(self, key: str, default: int) -> int: assert inventory.lookups == ["Dune", "Dune"] -async def test_an_in_stock_order_asks_no_question() -> None: +# The `!!! info` claims the tutorial003 behaviour is transport-independent, so each claim is +# proved on both: mode="legacy" elicits synchronously mid-call (2025-11-25 and earlier), while +# mode="auto" negotiates 2026-07-28, where the question rides a multi-round-trip `tools/call` +# and `Client` drives the retries. 
+@pytest.mark.parametrize("mode", ["legacy", "auto"]) +async def test_an_in_stock_order_asks_no_question(mode: Literal["legacy", "auto"]) -> None: """tutorial003: `confirm_backorder` returns directly when stock exists - no round-trip.""" async def never(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: # pragma: no cover raise AssertionError("an in-stock order must not elicit") - async with Client(tutorial003.mcp, mode="legacy", elicitation_callback=never) as client: + async with Client(tutorial003.mcp, mode=mode, elicitation_callback=never) as client: result = await client.call_tool("order_book", {"title": "Dune"}) assert result.content == [TextContent(type="text", text="Ordered 'Dune'.")] +@pytest.mark.parametrize("mode", ["legacy", "auto"]) @pytest.mark.parametrize( ("confirm", "expected"), [ @@ -98,7 +106,9 @@ async def never(context: ClientRequestContext, params: ElicitRequestParams) -> E (False, "No order placed."), ], ) -async def test_an_out_of_stock_order_asks_and_honours_the_answer(confirm: bool, expected: str) -> None: +async def test_an_out_of_stock_order_asks_and_honours_the_answer( + mode: Literal["legacy", "auto"], confirm: bool, expected: str +) -> None: """tutorial003: the resolver elicits, the SDK validates the answer, the tool reads it.""" asked: list[str] = [] @@ -106,20 +116,21 @@ async def on_elicit(context: ClientRequestContext, params: ElicitRequestParams) asked.append(params.message) return ElicitResult(action="accept", content={"confirm": confirm}) - async with Client(tutorial003.mcp, mode="legacy", elicitation_callback=on_elicit) as client: + async with Client(tutorial003.mcp, mode=mode, elicitation_callback=on_elicit) as client: result = await client.call_tool("order_book", {"title": "Neuromancer"}) assert result.content == [TextContent(type="text", text=expected)] assert asked == ["'Neuromancer' is out of stock (2-3 weeks). 
Order anyway?"] -async def test_declining_an_unwrapped_dependency_aborts_the_call() -> None: +@pytest.mark.parametrize("mode", ["legacy", "auto"]) +async def test_declining_an_unwrapped_dependency_aborts_the_call(mode: Literal["legacy", "auto"]) -> None: """tutorial003: no answer, no order - the error text on the page is the real one.""" async def decline(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: return ElicitResult(action="decline") - async with Client(tutorial003.mcp, mode="legacy", elicitation_callback=decline) as client: + async with Client(tutorial003.mcp, mode=mode, elicitation_callback=decline) as client: result = await client.call_tool("order_book", {"title": "Neuromancer"}) assert result.is_error diff --git a/tests/docs_src/test_elicitation.py b/tests/docs_src/test_elicitation.py index 4c9bb4036..a28f1087f 100644 --- a/tests/docs_src/test_elicitation.py +++ b/tests/docs_src/test_elicitation.py @@ -124,7 +124,7 @@ async def on_elicit(context: ClientRequestContext, params: ElicitRequestParams) result = await client.call_tool("book_table", {"date": "2025-12-25", "party_size": 2}) assert result.is_error assert isinstance(result.content[0], TextContent) - assert "Input should be a valid boolean" in result.content[0].text + assert "does not match the requested schema" in result.content[0].text class Address(BaseModel): diff --git a/tests/server/mcpserver/test_resolve.py b/tests/server/mcpserver/test_resolve.py index 1f4f72408..7e92f1c4e 100644 --- a/tests/server/mcpserver/test_resolve.py +++ b/tests/server/mcpserver/test_resolve.py @@ -1,12 +1,26 @@ """Tests for resolver dependency injection (MRTR) on MCPServer tools.""" +import json +from collections.abc import Callable +from datetime import datetime from typing import Annotated, Any, Literal +import anyio import pytest -from mcp_types import ElicitRequestParams, ElicitResult, TextContent +from mcp_types import ( + MISSING_REQUIRED_CLIENT_CAPABILITY, + CallToolResult, + 
CreateMessageResult, + ElicitRequestFormParams, + ElicitRequestParams, + ElicitResult, + InputRequiredResult, + InputResponses, + TextContent, +) from pydantic import BaseModel, Field -from mcp import Client +from mcp import Client, InputRequiredRoundsExceededError from mcp.client import ClientRequestContext from mcp.server.mcpserver import ( AcceptedElicitation, @@ -19,8 +33,19 @@ Resolve, ) from mcp.server.mcpserver.exceptions import InvalidSignature -from mcp.server.mcpserver.resolve import _resolver_key, find_resolved_parameters +from mcp.server.mcpserver.resolve import ( + _decode_state, + _elicit_return_schema, + _encode_state, + _outcome_from_state, + _resolver_key, + _state_key, + _StateEntry, + _uses_input_required, + find_resolved_parameters, +) from mcp.server.mcpserver.tools.base import Tool +from mcp.shared.exceptions import MCPError class Login(BaseModel): @@ -31,6 +56,14 @@ class Confirm(BaseModel): ok: bool +class Handle(BaseModel): + user_name: str = Field(alias="userName") + + +class Account(BaseModel): + user_name: str = Field(validation_alias="vUser", serialization_alias="sUser") + + async def _alias_login(ctx: Context) -> Login: return Login(username="x") # pragma: no cover - only the signature is inspected @@ -46,6 +79,12 @@ async def _decline(context: ClientRequestContext, params: ElicitRequestParams) - return ElicitResult(action="decline") +async def _never(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: # pragma: no cover + # Declares the form elicitation capability for clients that drive the + # input_required loop manually; the auto-driver never invokes it. 
+ raise AssertionError("should not be called") + + async def _text(client: Client, tool: str, args: dict[str, object]) -> str: result = await client.call_tool(tool, args) assert len(result.content) == 1 @@ -53,6 +92,18 @@ async def _text(client: Client, tool: str, args: dict[str, object]) -> str: return result.content[0].text +def _answer_round( + result: InputRequiredResult, answer: Callable[[str, ElicitRequestFormParams], ElicitResult] +) -> InputResponses: + """Fulfil every question in one `InputRequiredResult` round via `answer(key, request_params)`.""" + assert result.input_requests is not None + responses: InputResponses = {} + for key, req in result.input_requests.items(): + assert isinstance(req.params, ElicitRequestFormParams) + responses[key] = answer(key, req.params) + return responses + + @pytest.mark.anyio async def test_resolver_returns_value_directly_without_eliciting(): mcp = MCPServer(name="Direct") @@ -291,6 +342,20 @@ async def tool(login: Annotated[Login, Resolve(login)]) -> str: Tool.from_function(tool) +def test_multiple_elicit_arms_raise_at_registration(): + # The runtime can honor only one static question schema per resolver, so an + # ambiguous `-> Elicit[A] | Elicit[B]` must not register (the second arm used + # to be silently ignored). 
+ async def ambiguous(ctx: Context) -> Elicit[Login] | Elicit[Confirm]: + raise NotImplementedError # pragma: no cover + + async def tool(login: Annotated[Login, Resolve(ambiguous)]) -> str: + return login.username # pragma: no cover + + with pytest.raises(InvalidSignature, match="multiple Elicit arms"): + Tool.from_function(tool) + + def test_resolve_marker_inside_a_union_raises_at_registration(): async def login(ctx: Context) -> Login: return Login(username="x") # pragma: no cover @@ -569,3 +634,960 @@ async def callback(context: ClientRequestContext, params: ElicitRequestParams) - async with Client(mcp, mode="legacy", elicitation_callback=callback) as client: assert await _text(client, "delete_folder", {"path": "/docs"}) == expected assert ("/docs" in fs) is (expected != "deleted /docs") + + +@pytest.mark.anyio +async def test_input_required_first_round_returns_the_question(): + mcp, fs = _delete_folder_server() + fs["/docs"] = ["a.txt", "b.txt"] + + async with Client(mcp, elicitation_callback=_never) as client: # mode="auto" negotiates 2026-07-28 + assert client.session.protocol_version == "2026-07-28" + result = await client.session.call_tool("delete_folder", {"path": "/docs"}, allow_input_required=True) + assert isinstance(result, InputRequiredResult) + assert result.input_requests is not None + (request,) = result.input_requests.values() + assert request.method == "elicitation/create" + assert "/docs has 2 file(s)" in request.params.message + assert result.request_state is not None + assert "/docs" in fs # nothing deleted before the answer arrives + + +@pytest.mark.anyio +@pytest.mark.parametrize( + ("action", "content", "expected"), + [ + ("accept", {"ok": True}, "deleted /docs"), + ("accept", {"ok": False}, "kept the folder"), + ("decline", None, "declined: folder not deleted"), + ("cancel", None, "cancelled: folder not deleted"), + ], +) +async def test_input_required_loop_handles_every_outcome( + action: Literal["accept", "decline", "cancel"], + content: 
dict[str, str | int | float | bool | list[str] | None] | None, + expected: str, +): + # End-to-end at 2026-07-28: the client's auto-driver answers the embedded + # elicitation through the ordinary `elicitation_callback` and retries. + mcp, fs = _delete_folder_server() + fs["/docs"] = ["a.txt", "b.txt"] + + async def callback(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: + assert "/docs has 2 file(s)" in params.message + return ElicitResult(action=action, content=content) + + async with Client(mcp, elicitation_callback=callback) as client: # mode="auto" negotiates 2026-07-28 + result = await client.call_tool("delete_folder", {"path": "/docs"}) + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == expected + assert ("/docs" in fs) is (expected != "deleted /docs") + + +@pytest.mark.anyio +async def test_input_required_empty_folder_completes_without_eliciting(): + mcp, fs = _delete_folder_server() + fs["/empty"] = [] + + async def never(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: # pragma: no cover + raise AssertionError("should not elicit for an empty folder") + + async with Client(mcp, elicitation_callback=never) as client: + result = await client.call_tool("delete_folder", {"path": "/empty"}) + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "deleted /empty" + assert "/empty" not in fs + + +@pytest.mark.anyio +async def test_input_required_resolver_asks_and_consumes_then_never_reruns(): + mcp = MCPServer(name="ExactlyOnceMRTR") + counts = {"login": 0, "confirm": 0} + + async def login(ctx: Context) -> Login | Elicit[Login]: + counts["login"] += 1 + return Elicit("Username?", Login) + + async def confirm(login: Annotated[Login, Resolve(login)]) -> Elicit[Confirm]: + counts["confirm"] += 1 + return Elicit(f"As {login.username}?", Confirm) + + @mcp.tool() + async def act( + login: Annotated[Login, Resolve(login)], + confirm: 
Annotated[Confirm, Resolve(confirm)], + ) -> str: + return f"{login.username}:{confirm.ok}" + + asked: list[str] = [] + + async def callback(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: + asked.append(params.message) + if "Username" in params.message: + return ElicitResult(action="accept", content={"username": "octocat"}) + return ElicitResult(action="accept", content={"ok": True}) + + async with Client(mcp, elicitation_callback=callback) as client: + result = await client.call_tool("act", {}) + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "octocat:True" + + # `confirm` can only form its question from `login`'s answer, so the auto-driver + # sees the questions in two successive rounds and answers each exactly once. + assert asked == ["Username?", "As octocat?"] + # An eliciting resolver runs twice - once to ask, once to consume the answer - + # then its outcome is carried in `request_state` and it never runs again. `login` + # asks in round 1 and is consumed in round 2; `confirm` (which depends on + # `login`) only forms its question once `login` is known, so it asks in round 2 + # and is consumed in round 3. Neither re-runs beyond consuming its own answer. 
+ assert counts == {"login": 2, "confirm": 2} + + +@pytest.mark.anyio +async def test_input_required_batches_independent_elicits_in_one_round(): + mcp = MCPServer(name="BatchedMRTR") + + async def ask_name(ctx: Context) -> Elicit[Login]: + return Elicit("Name?", Login) + + async def ask_confirm(ctx: Context) -> Elicit[Confirm]: + return Elicit("Confirm?", Confirm) + + @mcp.tool() + async def both( + name: Annotated[Login, Resolve(ask_name)], + confirm: Annotated[Confirm, Resolve(ask_confirm)], + ) -> str: + return f"{name.username}:{confirm.ok}" + + def answer(key: str, params: ElicitRequestFormParams) -> ElicitResult: + if "Name" in params.message: + return ElicitResult(action="accept", content={"username": "octocat"}) + return ElicitResult(action="accept", content={"ok": True}) + + async with Client(mcp, elicitation_callback=_never) as client: + # Both independent resolvers are asked together in the first round. + first = await client.session.call_tool("both", {}, allow_input_required=True) + assert isinstance(first, InputRequiredResult) + assert first.input_requests is not None + assert len(first.input_requests) == 2 + + # Answering both and echoing `request_state` completes in a single retry. + final = await client.session.call_tool( + "both", + {}, + input_responses=_answer_round(first, answer), + request_state=first.request_state, + allow_input_required=True, + ) + assert isinstance(final, CallToolResult) + assert isinstance(final.content[0], TextContent) + assert final.content[0].text == "octocat:True" + + +@pytest.mark.anyio +async def test_auto_driver_answers_independent_questions_in_a_single_round(): + # The pure `count_round` resolver is never persisted in `request_state`, so it + # re-runs on every round: its run count is the number of rounds the call took. 
+ mcp = MCPServer(name="AutoBatch") + rounds = 0 + + async def count_round(ctx: Context) -> int: + nonlocal rounds + rounds += 1 + return rounds + + async def ask_name(ctx: Context) -> Elicit[Login]: + return Elicit("Name?", Login) + + async def ask_confirm(ctx: Context) -> Elicit[Confirm]: + return Elicit("Confirm?", Confirm) + + @mcp.tool() + async def both( + round_no: Annotated[int, Resolve(count_round)], + name: Annotated[Login, Resolve(ask_name)], + confirm: Annotated[Confirm, Resolve(ask_confirm)], + ) -> str: + return f"{name.username}:{confirm.ok}" + + asked: list[str] = [] + + async def callback(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: + asked.append(params.message) + if "Name" in params.message: + return ElicitResult(action="accept", content={"username": "octocat"}) + return ElicitResult(action="accept", content={"ok": True}) + + async with Client(mcp, elicitation_callback=callback) as client: + result = await client.call_tool("both", {}) + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "octocat:True" + + # The driver dispatches batched questions concurrently, so order is unspecified. 
+ assert sorted(asked) == ["Confirm?", "Name?"] # both questions, each exactly once + assert rounds == 2 # one question round, then the completing round + + +def test_uses_input_required_version_gate(): + assert _uses_input_required("2026-07-28") is True + assert _uses_input_required("2025-11-25") is False + assert _uses_input_required(None) is False + + +@pytest.mark.parametrize( + "request_state", + [ + None, + "", + "not json", + '{"v": 99, "outcomes": {}}', # wrong version + '{"v": 1}', # missing outcomes + '{"v": 1, "outcomes": []}', # outcomes not a dict + "[1, 2, 3]", # not an object + ], +) +def test_decode_state_tolerates_malformed_request_state(request_state: str | None): + assert _decode_state(request_state) == {} + + +def test_state_round_trips_accept_decline_cancel(): + entries = { + "a": _StateEntry(action="accept", data={"username": "octocat"}), + "b": _StateEntry(action="decline"), + "c": _StateEntry(action="cancel"), + "d": _StateEntry(action="accept", data="raw-token"), # non-dict wire value + } + decoded = _decode_state(_encode_state(entries)) + assert decoded == entries # encode-restore is the identity on the stored entries + + accepted = _outcome_from_state(decoded["a"], Login) + assert isinstance(accepted, AcceptedElicitation) and accepted.data == Login(username="octocat") + assert isinstance(_outcome_from_state(decoded["b"], None), DeclinedElicitation) + assert isinstance(_outcome_from_state(decoded["c"], None), CancelledElicitation) + raw = _outcome_from_state(decoded["d"], None) + assert isinstance(raw, AcceptedElicitation) and raw.data == "raw-token" + + +def test_elicit_return_schema_extraction(): + assert _elicit_return_schema(Elicit[Login], "r") is Login # bare Elicit[T] + assert _elicit_return_schema(Login | Elicit[Login], "r") is Login # union arm + assert _elicit_return_schema(Login, "r") is None # no Elicit arm + assert _elicit_return_schema(None, "r") is None + # The bound on `Elicit`'s parameter is unenforced at runtime, so a 
non-model + # subscription is constructible and must yield no schema rather than crash. + unbounded_elicit: Any = Elicit + assert _elicit_return_schema(unbounded_elicit[int], "r") is None + # Two distinct Elicit arms are ambiguous: the runtime can honor only one schema. + with pytest.raises(InvalidSignature, match="'r' return annotation has multiple Elicit arms"): + _elicit_return_schema(Elicit[Login] | Elicit[Confirm], "r") + + +@pytest.mark.anyio +async def test_non_elicitation_response_raises(): + mcp = MCPServer(name="WrongResponse") + + async def ask(ctx: Context) -> Elicit[Login]: + return Elicit("Name?", Login) + + @mcp.tool() + async def tool(name: Annotated[Login, Resolve(ask)]) -> str: + return name.username # pragma: no cover + + async with Client(mcp, elicitation_callback=_never) as client: + r1 = await client.session.call_tool("tool", {}, allow_input_required=True) + assert isinstance(r1, InputRequiredResult) + assert r1.input_requests is not None + (key,) = r1.input_requests + # Answer with a sampling result instead of an elicitation result. + r2 = await client.session.call_tool( + "tool", + {}, + input_responses={ + key: CreateMessageResult(role="assistant", content=TextContent(type="text", text="x"), model="m") + }, + request_state=r1.request_state, + allow_input_required=True, + ) + assert isinstance(r2, CallToolResult) + assert r2.is_error + assert isinstance(r2.content[0], TextContent) + assert "non-elicitation response" in r2.content[0].text + + +@pytest.mark.anyio +async def test_direct_call_tool_with_non_eliciting_resolver(): + # `MCPServer.call_tool()` called directly builds a Context with no request, so + # `ctx.protocol_version` is None. A tool whose resolvers never elicit must still + # work there (regression: it used to raise "Context is not available"). 
+ mcp = MCPServer(name="Direct") + + async def whoami(ctx: Context) -> Login: + return Login(username="direct") + + @mcp.tool() + async def tool(login: Annotated[Login, Resolve(whoami)]) -> str: + return login.username + + result = await mcp.call_tool("tool", {}, Context(mcp_server=mcp)) + assert isinstance(result, CallToolResult) + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "direct" + + +@pytest.mark.anyio +async def test_two_instances_of_one_method_do_not_collide(): + mcp = MCPServer(name="Instances") + + class Service: + def __init__(self, name: str) -> None: + self.name = name + + async def who(self, ctx: Context) -> Login: + return Login(username=self.name) + + alice, bob = Service("alice"), Service("bob") + + @mcp.tool() + async def both( + a: Annotated[Login, Resolve(alice.who)], + b: Annotated[Login, Resolve(bob.who)], + ) -> str: + return f"{a.username},{b.username}" + + result = await mcp.call_tool("both", {}, Context(mcp_server=mcp)) + assert isinstance(result, CallToolResult) + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "alice,bob" + + +@pytest.mark.anyio +async def test_non_serializable_sibling_resolver_does_not_break_rounds(): + mcp = MCPServer(name="NonSerializable") + + async def clock(ctx: Context) -> datetime: + return datetime(2026, 1, 1) + + async def ask(ctx: Context) -> Elicit[Confirm]: + return Elicit("ok?", Confirm) + + @mcp.tool() + async def act( + when: Annotated[datetime, Resolve(clock)], + confirm: Annotated[Confirm, Resolve(ask)], + ) -> str: + return f"{when.year}:{confirm.ok}" + + async def callback(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: + return ElicitResult(action="accept", content={"ok": True}) + + async with Client(mcp, elicitation_callback=callback) as client: + result = await client.call_tool("act", {}) + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "2026:True" + + 
+@pytest.mark.anyio +async def test_bare_elicit_dependency_restored_as_model(): + # A `-> Elicit[Login]` (bare, no union) resolver feeds a dependent resolver. After + # the round-trip the dependency must come back as a Login model, not a raw dict. + mcp = MCPServer(name="BareElicitDep") + + async def login(ctx: Context) -> Elicit[Login]: + return Elicit("user?", Login) + + async def confirm(login: Annotated[Login, Resolve(login)]) -> Elicit[Confirm]: + return Elicit(f"as {login.username}?", Confirm) + + @mcp.tool() + async def act( + login: Annotated[Login, Resolve(login)], + confirm: Annotated[Confirm, Resolve(confirm)], + ) -> str: + return f"{login.username}:{confirm.ok}" + + async def callback(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: + if "user" in params.message: + return ElicitResult(action="accept", content={"username": "octocat"}) + assert "as octocat?" in params.message # proves login was a real model + return ElicitResult(action="accept", content={"ok": True}) + + async with Client(mcp, elicitation_callback=callback) as client: + result = await client.call_tool("act", {}) + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "octocat:True" + + +@pytest.mark.anyio +@pytest.mark.parametrize("mode", ["legacy", "auto"]) +async def test_accept_with_no_content_is_an_error_not_a_cancel(mode: Literal["legacy", "auto"]): + # Both transports must agree: mode="legacy" elicits synchronously mid-call, + # mode="auto" rides the 2026-07-28 input_required loop. 
+ mcp = MCPServer(name="AcceptNoContent") + + async def ask(ctx: Context) -> Elicit[Login]: + return Elicit("user?", Login) + + @mcp.tool() + async def tool(login: Annotated[Login, Resolve(ask)]) -> str: + return login.username # pragma: no cover + + async def empty_accept(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: + return ElicitResult(action="accept", content=None) + + async with Client(mcp, mode=mode, elicitation_callback=empty_accept) as client: + result = await client.call_tool("tool", {}) + assert result.is_error + assert isinstance(result.content[0], TextContent) + assert "no content" in result.content[0].text + + +@pytest.mark.anyio +async def test_eliciting_tool_without_client_capability_is_a_protocol_error(): + # The server must not send an `input_requests` entry the client has not declared + # capability for: with no `elicitation` declared (no callback), the call fails as + # a -32021 protocol error, not a CallToolResult execution failure. + mcp = MCPServer(name="NoElicitationCapability") + + async def ask(ctx: Context) -> Elicit[Login]: + return Elicit("user?", Login) + + @mcp.tool() + async def tool(login: Annotated[Login, Resolve(ask)]) -> str: + return login.username # pragma: no cover + + async with Client(mcp) as client: + with pytest.raises(MCPError) as exc_info: + await client.session.call_tool("tool", {}, allow_input_required=True) + assert exc_info.value.code == MISSING_REQUIRED_CLIENT_CAPABILITY + assert exc_info.value.error.data is not None + assert "elicitation" in exc_info.value.error.data["requiredCapabilities"] + + +@pytest.mark.anyio +async def test_independent_nested_deps_batch_into_one_round(): + mcp = MCPServer(name="NestedBatch") + + async def ask_a(ctx: Context) -> Elicit[Login]: + return Elicit("A name?", Login) + + async def ask_b(ctx: Context) -> Elicit[Confirm]: + return Elicit("B confirm?", Confirm) + + # `combine` depends on two independent eliciting resolvers; both must be asked + # in the 
same round, not serialized across two InputRequiredResult rounds. + async def combine( + a: Annotated[Login, Resolve(ask_a)], + b: Annotated[Confirm, Resolve(ask_b)], + ) -> Login: + return Login(username=f"{a.username}:{b.ok}") + + @mcp.tool() + async def tool(combined: Annotated[Login, Resolve(combine)]) -> str: + return combined.username + + def answer(key: str, params: ElicitRequestFormParams) -> ElicitResult: + if "name" in params.message: + return ElicitResult(action="accept", content={"username": "octocat"}) + return ElicitResult(action="accept", content={"ok": True}) + + async with Client(mcp, elicitation_callback=_never) as client: + first = await client.session.call_tool("tool", {}, allow_input_required=True) + assert isinstance(first, InputRequiredResult) + assert first.input_requests is not None + assert len(first.input_requests) == 2 # batched, not serialized + + final = await client.session.call_tool( + "tool", + {}, + input_responses=_answer_round(first, answer), + request_state=first.request_state, + allow_input_required=True, + ) + assert isinstance(final, CallToolResult) + assert isinstance(final.content[0], TextContent) + assert final.content[0].text == "octocat:True" + + +@pytest.mark.anyio +async def test_deep_chain_keeps_early_answers_across_rounds(): + # A 4-round dependency chain where an early answer (A) must survive in + # request_state while later resolvers are asked. It must be asked exactly once. + mcp = MCPServer(name="DeepChain") + + async def ra(ctx: Context) -> Elicit[Login]: + return Elicit("A name?", Login) + + async def rb(a: Annotated[Login, Resolve(ra)]) -> Elicit[Confirm]: + return Elicit("B?", Confirm) + + async def rc(b: Annotated[Confirm, Resolve(rb)]) -> Elicit[Confirm]: + return Elicit("C?", Confirm) + + async def rd(c: Annotated[Confirm, Resolve(rc)]) -> Elicit[Confirm]: + return Elicit("D?", Confirm) + + # Depends on `ra` directly AND on `rd` (which transitively needs ra->rb->rc). 
+ @mcp.tool() + async def tool( + a: Annotated[Login, Resolve(ra)], + d: Annotated[Confirm, Resolve(rd)], + ) -> str: + return f"{a.username}:{d.ok}" + + a_asks = 0 + + async def callback(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: + nonlocal a_asks + if "name" in params.message: + a_asks += 1 + return ElicitResult(action="accept", content={"username": "octocat"}) + return ElicitResult(action="accept", content={"ok": True}) + + async with Client(mcp, elicitation_callback=callback) as client: + result = await client.call_tool("tool", {}) + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "octocat:True" + assert a_asks == 1 # ra's answer survived in request_state; never re-asked + + +@pytest.mark.anyio +async def test_factory_closures_get_distinct_wire_keys(): + # Two resolvers from one factory share module:qualname; they must still get + # distinct questions and their own values (regression: they collided on the wire). + mcp = MCPServer(name="FactoryClosures") + + def make(label: str): + async def resolver(ctx: Context) -> Elicit[Login]: + return Elicit(f"{label}?", Login) + + return resolver + + ask_a, ask_b = make("A"), make("B") + + @mcp.tool() + async def tool( + a: Annotated[Login, Resolve(ask_a)], + b: Annotated[Login, Resolve(ask_b)], + ) -> str: + return f"{a.username},{b.username}" + + def answer(key: str, params: ElicitRequestFormParams) -> ElicitResult: + return ElicitResult(action="accept", content={"username": params.message[0]}) + + async with Client(mcp, elicitation_callback=_never) as client: + first = await client.session.call_tool("tool", {}, allow_input_required=True) + assert isinstance(first, InputRequiredResult) + assert first.input_requests is not None + assert len(first.input_requests) == 2 # distinct keys, not collapsed to one + + final = await client.session.call_tool( + "tool", + {}, + input_responses=_answer_round(first, answer), + request_state=first.request_state, + 
allow_input_required=True, + ) + assert isinstance(final, CallToolResult) + assert isinstance(final.content[0], TextContent) + assert final.content[0].text == "A,B" + + +@pytest.mark.anyio +async def test_eliciting_resolver_without_elicit_arm_restores_a_typed_model(): + # A resolver annotated `-> object` that actually returns `Elicit(...)` has no + # `Elicit[T]` return arm, so `elicit_schema` is None. Its answer, restored from + # request_state in a 3+ round flow, must still come back as a Login model (not a + # raw dict) so a dependent resolver/tool can use its attributes. + mcp = MCPServer(name="LyingAnnotation") + + # Annotated without an `Elicit[T]` return arm, so `elicit_schema` is None. + async def login(ctx: Context) -> object: + return Elicit("user?", Login) + + async def confirm(login: Annotated[Login, Resolve(login)]) -> Elicit[Confirm]: + return Elicit(f"as {login.username}?", Confirm) + + @mcp.tool() + async def act( + login: Annotated[Login, Resolve(login)], + confirm: Annotated[Confirm, Resolve(confirm)], + ) -> str: + return f"{login.username}:{confirm.ok}" + + async def callback(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: + if "user" in params.message: + return ElicitResult(action="accept", content={"username": "octocat"}) + assert "as octocat?" 
in params.message # login restored as a real model + return ElicitResult(action="accept", content={"ok": True}) + + async with Client(mcp, elicitation_callback=callback) as client: + result = await client.call_tool("act", {}) + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "octocat:True" + + +def test_wire_key_is_worker_stable_for_methods_and_callable_objects(): + class Service: + async def token(self, ctx: Context) -> Login: + return Login(username="x") # pragma: no cover + + class CallableResolver: + async def __call__(self, ctx: Context) -> Login: + return Login(username="x") # pragma: no cover + + a, b = Service(), Service() + # No id(...) in the key: two instances of one method get the same base (they are + # disambiguated at registration, not here), and the key carries no memory address. + assert _state_key(a.token) == _state_key(b.token) + assert "#" not in _state_key(a.token) + assert _state_key(a.token).endswith("Service.token") + # Callable objects key by their type's qualname (they have no `__qualname__`). + assert _state_key(CallableResolver()).endswith("CallableResolver") + + +@pytest.mark.anyio +async def test_declined_outcome_persists_in_request_state_and_is_not_reasked(): + # A decline is recorded in `request_state` just like an accept: RB elicits only + # after seeing RA's decline, so RA's outcome must survive into the round that + # answers RB without RA being asked again. 
+ mcp = MCPServer(name="DeclinePersists") + + async def ra(ctx: Context) -> Elicit[Login]: + return Elicit("user?", Login) + + async def rb(a: Annotated[ElicitationResult[Login], Resolve(ra)]) -> Elicit[Confirm]: + assert isinstance(a, DeclinedElicitation) + return Elicit("proceed anonymously?", Confirm) + + @mcp.tool() + async def act( + a: Annotated[ElicitationResult[Login], Resolve(ra)], + c: Annotated[Confirm, Resolve(rb)], + ) -> str: + assert isinstance(a, DeclinedElicitation) + return f"anonymous:{c.ok}" + + async with Client(mcp, elicitation_callback=_never) as client: + first = await client.session.call_tool("act", {}, allow_input_required=True) + assert isinstance(first, InputRequiredResult) + assert first.input_requests is not None + (ra_key,) = first.input_requests + + second = await client.session.call_tool( + "act", + {}, + input_responses={ra_key: ElicitResult(action="decline")}, + request_state=first.request_state, + allow_input_required=True, + ) + assert isinstance(second, InputRequiredResult) + assert second.input_requests is not None + (rb_key,) = second.input_requests # only RB's question; RA is not re-asked + assert rb_key != ra_key + assert _decode_state(second.request_state)[ra_key].action == "decline" + + final = await client.session.call_tool( + "act", + {}, + input_responses={rb_key: ElicitResult(action="accept", content={"ok": True})}, + request_state=second.request_state, + allow_input_required=True, + ) + assert isinstance(final, CallToolResult) + assert isinstance(final.content[0], TextContent) + assert final.content[0].text == "anonymous:True" + + +@pytest.mark.anyio +async def test_unknown_response_keys_and_ghost_state_entries_are_ignored(): + # `input_responses` keys the server never asked for and `request_state` outcome + # entries matching no resolver are tolerated (both are client-supplied), and the + # ghost state entry is not echoed into any later round's `request_state`. 
+ mcp = MCPServer(name="GhostKeys") + + async def ra(ctx: Context) -> Elicit[Login]: + return Elicit("user?", Login) + + async def rb(a: Annotated[Login, Resolve(ra)]) -> Elicit[Confirm]: + return Elicit(f"as {a.username}?", Confirm) + + @mcp.tool() + async def act( + a: Annotated[Login, Resolve(ra)], + c: Annotated[Confirm, Resolve(rb)], + ) -> str: + return f"{a.username}:{c.ok}" + + async with Client(mcp, elicitation_callback=_never) as client: + first = await client.session.call_tool("act", {}, allow_input_required=True) + assert isinstance(first, InputRequiredResult) + assert first.input_requests is not None + assert first.request_state is not None + (ra_key,) = first.input_requests + + spliced = json.loads(first.request_state) + spliced["outcomes"]["ghost"] = {"action": "accept", "data": {"username": "spooky"}} + second = await client.session.call_tool( + "act", + {}, + input_responses={ + ra_key: ElicitResult(action="accept", content={"username": "octocat"}), + "ghost": ElicitResult(action="accept", content={"username": "spooky"}), + }, + request_state=json.dumps(spliced), + allow_input_required=True, + ) + assert isinstance(second, InputRequiredResult) + assert second.input_requests is not None + (rb_key,) = second.input_requests + outcomes = _decode_state(second.request_state) + assert ra_key in outcomes + assert "ghost" not in outcomes # the spliced entry is dropped, not carried onward + + final = await client.session.call_tool( + "act", + {}, + input_responses={rb_key: ElicitResult(action="accept", content={"ok": True})}, + request_state=second.request_state, + allow_input_required=True, + ) + assert isinstance(final, CallToolResult) + assert isinstance(final.content[0], TextContent) + assert final.content[0].text == "octocat:True" + + +@pytest.mark.anyio +@pytest.mark.parametrize( + "forged_data", + [ + pytest.param("not-a-dict", id="non-dict-data"), + pytest.param({"hacked": True}, id="dict-failing-schema"), + ], +) +async def 
test_forged_state_entry_failing_the_schema_is_reasked_not_an_error(forged_data: str | dict[str, bool]): + # `request_state` is client-supplied JSON: an accept entry whose data does not + # validate against the resolver's schema reads as no recorded progress, so the + # question is asked again (not an error) and a proper answer completes the call. + mcp = MCPServer(name="ForgedState") + + async def ask(ctx: Context) -> Elicit[Login]: + return Elicit("user?", Login) + + @mcp.tool() + async def whoami(login: Annotated[Login, Resolve(ask)]) -> str: + return login.username + + async with Client(mcp, elicitation_callback=_never) as client: + first = await client.session.call_tool("whoami", {}, allow_input_required=True) + assert isinstance(first, InputRequiredResult) + assert first.input_requests is not None + assert first.request_state is not None + (key,) = first.input_requests + + forged = json.loads(first.request_state) + forged["outcomes"][key] = {"action": "accept", "data": forged_data} + second = await client.session.call_tool( + "whoami", {}, request_state=json.dumps(forged), allow_input_required=True + ) + assert isinstance(second, InputRequiredResult) # re-asked, not an error + assert second.input_requests is not None + assert set(second.input_requests) == {key} + assert _decode_state(second.request_state) == {} # the forged entry is dropped + + final = await client.session.call_tool( + "whoami", + {}, + input_responses={key: ElicitResult(action="accept", content={"username": "octocat"})}, + request_state=second.request_state, + allow_input_required=True, + ) + assert isinstance(final, CallToolResult) + assert isinstance(final.content[0], TextContent) + assert final.content[0].text == "octocat" + + +@pytest.mark.anyio +@pytest.mark.parametrize("mode", ["legacy", "auto"]) +async def test_schema_mismatched_fresh_answer_fails_the_call_without_pydantic_leakage(mode: Literal["legacy", "auto"]): + # An accepted answer whose content fails the requested schema fails the 
call + # with the framework's own message on both transports; pydantic's error text + # (which carries an "errors.pydantic.dev" link) must not leak to the client. + mcp = MCPServer(name="MismatchedAnswer") + + async def ask(ctx: Context) -> Elicit[Login]: + return Elicit("user?", Login) + + @mcp.tool() + async def whoami(login: Annotated[Login, Resolve(ask)]) -> str: + raise NotImplementedError # pragma: no cover - the mismatched answer never reaches the body + + async with Client(mcp, mode=mode, elicitation_callback=_accept({"nope": "x"})) as client: + result = await client.call_tool("whoami", {}) + assert result.is_error + assert isinstance(result.content[0], TextContent) + text = result.content[0].text + assert "does not match the requested schema" in text + assert "errors.pydantic.dev" not in text + if mode == "auto": + assert "Resolver" in text # the input_required transport names the offending resolver key + else: + assert "Received an accepted elicitation" in text # the legacy path has no wire key to name + + +@pytest.mark.anyio +async def test_auto_driver_gives_up_when_the_chain_outlasts_its_round_budget(): + # A dependency chain of 11 eliciting resolvers needs 11 retry rounds, one more + # than the default `input_required_max_rounds`, so `client.call_tool` must raise + # rather than loop on. The pure `count_leg` resolver is never persisted, so it + # re-runs on every server leg: its final value is the exact number of legs. 
+ mcp = MCPServer(name="TooDeep") + legs = 0 + + async def count_leg(ctx: Context) -> int: + nonlocal legs + legs += 1 + return legs + + async def root(ctx: Context) -> Elicit[Confirm]: + return Elicit("Q1?", Confirm) + + def extend(dep: Callable[..., Any], n: int) -> Callable[..., Any]: + async def link(prev: Annotated[Confirm, Resolve(dep)]) -> Elicit[Confirm]: + return Elicit(f"Q{n}?", Confirm) + + return link + + chain: Callable[..., Any] = root + for n in range(2, 12): # 11 eliciting resolvers in total + chain = extend(chain, n) + + @mcp.tool() + async def long_haul( + leg: Annotated[int, Resolve(count_leg)], + last: Annotated[Confirm, Resolve(chain)], + ) -> str: + raise NotImplementedError # pragma: no cover - the driver gives up first + + answered = 0 + + async def callback(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: + nonlocal answered + answered += 1 + return ElicitResult(action="accept", content={"ok": True}) + + async with Client(mcp, elicitation_callback=callback) as client: + with anyio.fail_after(5): # the loop must end by raising, not spin on retries + with pytest.raises(InputRequiredRoundsExceededError) as exc_info: + await client.call_tool("long_haul", {}) + assert exc_info.value.max_rounds == client.input_required_max_rounds + assert answered == client.input_required_max_rounds # one question answered per retry round + assert legs == client.input_required_max_rounds + 1 # the initial call plus one leg per retry + + +@pytest.mark.anyio +async def test_aliased_elicitation_model_round_trips_through_request_state(): + # The stored entry is the client's raw wire content, so it restores through + # the same validation the answer originally passed - aliases and all. A + # re-derived (field-name) shape would fail validation on the round after + # next, drop the stored answer, and re-ask the user forever. 
+ mcp = MCPServer(name="AliasState") + + async def who(ctx: Context) -> Elicit[Handle]: + return Elicit("handle?", Handle) + + async def confirm(h: Annotated[Handle, Resolve(who)]) -> Elicit[Confirm]: + return Elicit(f"go as {h.user_name}?", Confirm) + + @mcp.tool() + async def act( + h: Annotated[Handle, Resolve(who)], + c: Annotated[Confirm, Resolve(confirm)], + ) -> str: + return f"{h.user_name}:{c.ok}" + + async with Client(mcp, elicitation_callback=_never) as client: + first = await client.session.call_tool("act", {}, allow_input_required=True) + assert isinstance(first, InputRequiredResult) + assert first.input_requests is not None + (who_key,) = first.input_requests + + second = await client.session.call_tool( + "act", + {}, + input_responses={who_key: ElicitResult(action="accept", content={"userName": "octocat"})}, + request_state=first.request_state, + allow_input_required=True, + ) + assert isinstance(second, InputRequiredResult) + assert second.input_requests is not None + (confirm_key,) = second.input_requests # only the dependent question; the stored answer holds + assert confirm_key != who_key + + final = await client.session.call_tool( + "act", + {}, + input_responses={confirm_key: ElicitResult(action="accept", content={"ok": True})}, + request_state=second.request_state, + allow_input_required=True, + ) + assert isinstance(final, CallToolResult) + assert isinstance(final.content[0], TextContent) + assert final.content[0].text == "octocat:True" + + +@pytest.mark.anyio +async def test_divergent_validation_and_serialization_aliases_round_trip(): + # `request_state` must carry the client's answer exactly as it was sent: the + # rendered question is validation-aliased, so re-deriving the stored shape from + # the validated model (which serializes under the *serialization* alias) would + # produce data the schema's own validation rejects, dropping the stored answer + # on the round after next and re-asking the user. 
+ mcp = MCPServer(name="DivergentAliases") + + async def who(ctx: Context) -> Elicit[Account]: + return Elicit("account?", Account) + + async def confirm(a: Annotated[Account, Resolve(who)]) -> Elicit[Confirm]: + return Elicit(f"go as {a.user_name}?", Confirm) + + @mcp.tool() + async def act( + a: Annotated[Account, Resolve(who)], + c: Annotated[Confirm, Resolve(confirm)], + ) -> str: + return f"{a.user_name}:{c.ok}" + + async with Client(mcp, elicitation_callback=_never) as client: + first = await client.session.call_tool("act", {}, allow_input_required=True) + assert isinstance(first, InputRequiredResult) + assert first.input_requests is not None + (who_key,) = first.input_requests + question = first.input_requests[who_key].params + assert isinstance(question, ElicitRequestFormParams) + assert "vUser" in question.requested_schema["properties"] # the client answers validation-aliased + + second = await client.session.call_tool( + "act", + {}, + input_responses={who_key: ElicitResult(action="accept", content={"vUser": "octocat"})}, + request_state=first.request_state, + allow_input_required=True, + ) + assert isinstance(second, InputRequiredResult) + assert second.input_requests is not None + (go_key,) = second.input_requests # only the dependent question; the stored answer holds + assert go_key != who_key + # The stored entry is the client's wire content, not a re-serialization of it. + assert _decode_state(second.request_state)[who_key].data == {"vUser": "octocat"} + + final = await client.session.call_tool( + "act", + {}, + input_responses={go_key: ElicitResult(action="accept", content={"ok": True})}, + request_state=second.request_state, + allow_input_required=True, + ) + assert isinstance(final, CallToolResult) + assert isinstance(final.content[0], TextContent) + assert final.content[0].text == "octocat:True"