diff --git a/docs/migration.md b/docs/migration.md index d94db1f60..8c1378d11 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -1624,6 +1624,20 @@ app = server.streamable_http_app( The lowlevel `Server` also now exposes a `session_manager` property to access the `StreamableHTTPSessionManager` after calling `streamable_http_app()`. +### `ElicitationResult` is now a subscriptable generic alias + +`ElicitationResult` is now a `TypeAliasType` instead of a plain union, so `ElicitationResult[Confirm]` works as an annotation (resolver dependency injection consumes it that way - see [Dependencies](tutorial/dependencies.md)). The members are unchanged: `AcceptedElicitation[T] | DeclinedElicitation | CancelledElicitation`. + +The one behavioral change: a runtime `isinstance(result, ElicitationResult)` now raises `TypeError`. Check against the member classes directly instead: + +```python +result = await ctx.elicit("Proceed?", Confirm) +if isinstance(result, AcceptedElicitation): + ... # result.data is a Confirm +``` + +Narrowing on `result.action` (`"accept"` / `"decline"` / `"cancel"`) is unaffected. + ## Need Help? If you encounter issues during migration: diff --git a/docs/tutorial/context.md b/docs/tutorial/context.md index 3a15e8fc8..17af592fb 100644 --- a/docs/tutorial/context.md +++ b/docs/tutorial/context.md @@ -63,6 +63,7 @@ The injected object is small. Besides `request_id`: * `await ctx.report_progress(progress, total, message)`: stream progress back to the caller during a long call. The whole story is in **Progress**. * `await ctx.elicit(message, schema)` and `await ctx.elicit_url(...)`: pause the tool and ask the user a question. That's **Elicitation**. * `ctx.session`: the server's side of the conversation with this client. Notifications you send to the client live here; the last section uses it. +* `ctx.headers`: the request headers the transport carried, or `None` on stdio. Read a custom header with `(ctx.headers or {}).get("x-...")`. Headers are client-supplied input - fine for a locale or a feature flag, never an identity. * `ctx.request_context`: the raw per-request record. The field you'll reach for is `lifespan_context`, the object your startup code yielded (see **Lifespan**). Logging is deliberately not on that list. A server logs with Python's `logging` module, like any other Python program. **Logging** is the short chapter on why. @@ -123,4 +124,4 @@ The siblings are `send_resource_list_changed()`, `send_prompt_list_changed()`, a * `ctx.session` is the channel back to the client: `send_tool_list_changed()` and its siblings tell it to re-fetch a list you changed. * Progress reporting and elicitation also start at `Context`; each has its own chapter. -Next: what happens when your tool fails, and how to choose who finds out, in **Handling errors**. +Next: parameters the model never sees, filled by your own functions, in **Dependencies**. diff --git a/docs/tutorial/dependencies.md b/docs/tutorial/dependencies.md new file mode 100644 index 000000000..0631ccd8f --- /dev/null +++ b/docs/tutorial/dependencies.md @@ -0,0 +1,127 @@ +# Dependencies + +A tool's arguments come from the model. Some values never should: a price looked up from your records, a confirmation only a person can give, anything the model could get wrong by inventing it. + +**Dependencies** are parameters filled by your own functions. You annotate the parameter, name the function, and the SDK calls it before your tool runs. + +## Declare one + +Wrap the parameter's type in `Annotated[...]` and add `Resolve(fn)`: + +```python title="server.py" hl_lines="18-19 23" +--8<-- "docs_src/dependencies/tutorial001.py" +``` + +* `check_stock` is a **resolver**: a plain function the SDK runs before `reserve_book`, whose return value becomes the `stock` argument. +* Its `title` parameter is the tool's own `title` argument, matched **by name**. The resolver sees exactly the validated value the tool body will see. +* The tool body starts from a `Stock` that already exists. No lookup code in the tool, no "what if it's missing" preamble. + +!!! info + If you've used FastAPI, this is `Depends`. Same move, same reason: the function declares what + it needs, the framework supplies it, and the wiring lives in the type annotation. + +### Invisible to the model + +Here is the input schema `tools/list` reports for `reserve_book`: + +```json +{ + "type": "object", + "properties": { + "title": {"title": "Title", "type": "string"} + }, + "required": ["title"], + "title": "reserve_bookArguments" +} +``` + +One property. Like the `Context` in **The Context**, a resolved parameter is a contract between you and the SDK: `stock` is not in the schema, the model is never told about it, and a client that sends a `stock` value anyway is ignored. The resolver's value is the only one your tool can receive. + +That last part is the point. A parameter the model cannot supply is a parameter the model cannot get wrong. + +### Try it + +Run the server with the MCP Inspector: + +```console +uv run mcp dev server.py +``` + +The form for `reserve_book` has a single `title` field. `stock` is nowhere on it. Call it with `Dune`: + +```text +Reserved 'Dune' (6 copies left). +``` + +The tool body never looked anything up: `check_stock` ran first, and the `Stock` it returned arrived as an argument. Try `Neuromancer` and the same resolver hands the tool a zero. + +!!! tip + You could just call `check_stock(title)` in the tool body. Declare it as a dependency when the + value deserves more than a helper call: every tool that needs stock declares the same parameter, + and the SDK runs the resolver at most once per call, no matter how many declare it. The next + sections add the rest: resolvers that depend on each other, and resolvers that ask the user. + +## Dependencies of dependencies + +A resolver can declare its own dependencies, with the same annotation: + +```python title="server.py" hl_lines="22 29-30" +--8<-- "docs_src/dependencies/tutorial002.py" +``` + +* `estimate_delivery` depends on `check_stock`. The SDK runs the graph in order: stock first, then the estimate, then the tool. +* Both `stock` and `delivery` ultimately need `check_stock`, but it runs **once per call**. One inventory lookup, two consumers. +* There is nothing to register. The graph *is* the annotations. + +!!! check + Don't take once-per-call on faith. Put a `print` in `check_stock` and call `order_book` from the + Inspector: one line per call. Two consumers, one lookup. + +The SDK analyses the graph when the tool is registered, not when it is called. A parameter it can't classify - not a `Context`, not a `Resolve(...)`, not a tool argument's name - and a cycle of resolvers both raise `InvalidSignature` at startup. Your server fails before a client ever connects, with the offending parameter or resolver named in the error. + +A resolver's parameters resolve exactly like a tool's: another `Resolve(...)`, the tool's own arguments by name, or the `Context` - `ctx.headers`, the lifespan object, all of it. + +!!! warning + On HTTP transports the `Context` includes `ctx.headers`. Headers are **client-supplied input**, + like any tool argument: fine for a locale or a feature flag, never an identity. Who the caller + is comes from your authorization layer (**Authorization**), not from a header anyone can set. + +!!! tip + *Once per call* means exactly that: the next `tools/call` runs `check_stock` again. A resource + that should outlive a request - a database pool, an HTTP client - belongs in **Lifespan**, and + a resolver can reach it through `ctx.request_context.lifespan_context`. + +## Ask when you must + +A resolver doesn't have to know the answer. It can return `Elicit(message, Model)` and the SDK asks the user - the **Elicitation** machinery, run for you: + +```python title="server.py" hl_lines="26-32 39" +--8<-- "docs_src/dependencies/tutorial003.py" +``` + +* In stock: `confirm_backorder` returns a `Backorder` directly. **No question, no round-trip.** The user is only interrupted when their answer matters. +* Out of stock: the SDK sends the elicitation, validates the answer against `Backorder`, and injects it. Your resolver never touches the protocol. +* The tool reads `backorder.confirm` like any other argument. Answering **no** is still an answer: the elicitation is accepted with `confirm=False`, the tool runs, and no order is placed. Asking became a precondition, not plumbing in the tool body. + +And if the user won't answer at all - declines the question, or cancels it? + +!!! check + Run `order_book` for `Neuromancer` and decline the question. With the annotation written as + `Annotated[Backorder, Resolve(...)]` the tool body never runs; the call fails with an error + result the model can read: + + ```text + Error executing tool order_book: Resolver for parameter 'backorder' could not resolve: elicitation was decline + ``` + +That's the right default for a precondition: no answer, no order. When declining is an outcome your tool wants to handle - skip the backorder but still suggest another title - annotate `ElicitationResult[Backorder]` instead and the tool receives the full accept/decline/cancel outcome to branch on. **Elicitation** shows that form, and everything else about asking: the schema rules, the three answers, the client's side of the conversation. + +## Recap + +* `Annotated[T, Resolve(fn)]` on a tool parameter: the SDK runs `fn` and injects its return value. +* A resolved parameter is invisible to the model and cannot be supplied by a client. Values the model must not invent - prices, identities, permissions - belong here. +* A resolver's parameters are resolved the same way: the `Context`, another `Resolve(...)`, or a tool argument by name. The graph runs each resolver at most once per call. +* Bad graphs fail at registration with `InvalidSignature`, not mid-call. +* Return `Elicit(message, Model)` to ask the user, only when you have to. Unwrapped annotations abort on decline; `ElicitationResult[T]` lets the tool branch. + +Next: what happens when your tool fails, and how to choose who finds out, in **Handling errors**. diff --git a/docs/tutorial/elicitation.md b/docs/tutorial/elicitation.md index df7ae477f..8a7b4c335 100644 --- a/docs/tutorial/elicitation.md +++ b/docs/tutorial/elicitation.md @@ -79,6 +79,24 @@ A refusal is not an error. The tool decides what declining means (here, no booki `"maybe"` for a `bool` doesn't corrupt your booking: the call fails with the `ValidationError`, your `if` never runs. +## Ask before the tool runs + +The booking tool above weaves the question into its own body. When the question is really a *precondition* - confirm before deleting, authenticate before acting - you can lift it out of the tool into a **resolver** and let the framework ask for you. + +A parameter annotated `Annotated[T, Resolve(fn)]` is filled by running `fn` before the tool body. The resolver returns the value directly when it already knows it, or returns `Elicit(...)` to have the framework ask: + +```python title="server.py" hl_lines="24-30 35-36" +--8<-- "docs_src/elicitation/tutorial004.py" +``` + +* `confirm_delete` reads the tool's own `path` argument by name, lists the folder, and **only elicits when it must** - an empty folder resolves to `Confirm(ok=True)` with no round-trip to the client. +* `delete_folder` annotates `ElicitationResult[Confirm]`, so the framework injects the whole outcome and the tool `match`es every case: accept-and-confirm, accept-but-keep (`ok=False`), decline, cancel. +* The `confirm` parameter never appears in the tool's input schema - the client supplies `path`, the resolver supplies `confirm`. + +Annotate the unwrapped model (`Annotated[Confirm, Resolve(confirm_delete)]`) instead when the tool doesn't need to branch: it receives the model on accept and the call aborts with an error on decline or cancel. + +Asking is only one thing a resolver can do. The general mechanism - dependencies that compute without asking, dependencies of dependencies, what the model can and cannot supply - is the **Dependencies** chapter. + ## Send the user to a URL Some things must not go through the model or the client: credentials, card numbers, OAuth consent. For those you don't ask for data; you ask the user to go somewhere: diff --git a/docs_src/dependencies/__init__.py b/docs_src/dependencies/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/docs_src/dependencies/tutorial001.py b/docs_src/dependencies/tutorial001.py new file mode 100644 index 000000000..182b54414 --- /dev/null +++ b/docs_src/dependencies/tutorial001.py @@ -0,0 +1,27 @@ +from typing import Annotated + +from pydantic import BaseModel + +from mcp.server import MCPServer +from mcp.server.mcpserver import Resolve + +mcp = MCPServer("Bookshop") + +INVENTORY = {"Dune": 7, "Neuromancer": 0} + + +class Stock(BaseModel): + title: str + copies: int + + +async def check_stock(title: str) -> Stock: + return Stock(title=title, copies=INVENTORY.get(title, 0)) + + +@mcp.tool() +async def reserve_book(title: str, stock: Annotated[Stock, Resolve(check_stock)]) -> str: + """Reserve a copy of a book.""" + if stock.copies == 0: + return f"{title!r} is out of stock." + return f"Reserved {title!r} ({stock.copies - 1} copies left)." diff --git a/docs_src/dependencies/tutorial002.py b/docs_src/dependencies/tutorial002.py new file mode 100644 index 000000000..3f24e2ceb --- /dev/null +++ b/docs_src/dependencies/tutorial002.py @@ -0,0 +1,35 @@ +from typing import Annotated + +from pydantic import BaseModel + +from mcp.server import MCPServer +from mcp.server.mcpserver import Resolve + +mcp = MCPServer("Bookshop") + +INVENTORY = {"Dune": 7, "Neuromancer": 0} + + +class Stock(BaseModel): + title: str + copies: int + + +async def check_stock(title: str) -> Stock: + return Stock(title=title, copies=INVENTORY.get(title, 0)) + + +async def estimate_delivery(stock: Annotated[Stock, Resolve(check_stock)]) -> str: + return "tomorrow" if stock.copies > 0 else "in 2-3 weeks" + + +@mcp.tool() +async def order_book( + title: str, + stock: Annotated[Stock, Resolve(check_stock)], + delivery: Annotated[str, Resolve(estimate_delivery)], +) -> str: + """Order a book from the shop.""" + if stock.copies == 0: + return f"{title!r} is on backorder; it would arrive {delivery}." + return f"Ordered {title!r}; it arrives {delivery}." diff --git a/docs_src/dependencies/tutorial003.py b/docs_src/dependencies/tutorial003.py new file mode 100644 index 000000000..51252668e --- /dev/null +++ b/docs_src/dependencies/tutorial003.py @@ -0,0 +1,46 @@ +from typing import Annotated + +from pydantic import BaseModel, Field + +from mcp.server import MCPServer +from mcp.server.mcpserver import Elicit, Resolve + +mcp = MCPServer("Bookshop") + +INVENTORY = {"Dune": 7, "Neuromancer": 0} + + +class Stock(BaseModel): + title: str + copies: int + + +class Backorder(BaseModel): + confirm: bool = Field(description="Order anyway and wait?") + + +async def check_stock(title: str) -> Stock: + return Stock(title=title, copies=INVENTORY.get(title, 0)) + + +async def confirm_backorder( + title: str, + stock: Annotated[Stock, Resolve(check_stock)], +) -> Backorder | Elicit[Backorder]: + if stock.copies > 0: + return Backorder(confirm=True) # in stock: nothing to ask + return Elicit(f"{title!r} is out of stock (2-3 weeks). Order anyway?", Backorder) + + +@mcp.tool() +async def order_book( + title: str, + stock: Annotated[Stock, Resolve(check_stock)], + backorder: Annotated[Backorder, Resolve(confirm_backorder)], +) -> str: + """Order a book from the shop.""" + if not backorder.confirm: + return "No order placed." + if stock.copies == 0: + return f"Backordered {title!r}; it ships in 2-3 weeks." + return f"Ordered {title!r}." diff --git a/docs_src/elicitation/tutorial004.py b/docs_src/elicitation/tutorial004.py new file mode 100644 index 000000000..1edec06cf --- /dev/null +++ b/docs_src/elicitation/tutorial004.py @@ -0,0 +1,47 @@ +from typing import Annotated + +from pydantic import BaseModel + +from mcp.server import MCPServer +from mcp.server.mcpserver import ( + AcceptedElicitation, + CancelledElicitation, + DeclinedElicitation, + Elicit, + ElicitationResult, + Resolve, +) + +mcp = MCPServer("Files") + +_FOLDERS: dict[str, list[str]] = {"/tmp/empty": [], "/tmp/project": ["main.py", "README.md"]} + + +class Confirm(BaseModel): + ok: bool + + +async def confirm_delete(path: str) -> Confirm | Elicit[Confirm]: + """Resolver: ask for confirmation only when the folder is not empty.""" + file_count = len(_FOLDERS.get(path, [])) + if file_count == 0: + return Confirm(ok=True) # nothing to confirm, no round-trip to the client + return Elicit(f"{path} has {file_count} file(s). Delete anyway?", Confirm) + + +@mcp.tool() +async def delete_folder( + path: str, + confirm: Annotated[ElicitationResult[Confirm], Resolve(confirm_delete)], +) -> str: + """Delete a folder, asking for confirmation when it is not empty.""" + match confirm: + case AcceptedElicitation(data=Confirm(ok=True)): + _FOLDERS.pop(path, None) + return f"deleted {path}" + case AcceptedElicitation(): + return "kept the folder" + case DeclinedElicitation(): + return "declined: folder not deleted" + case CancelledElicitation(): + return "cancelled: folder not deleted" diff --git a/examples/stories/README.md b/examples/stories/README.md index 8b267f392..8c1cceb5b 100644 --- a/examples/stories/README.md +++ b/examples/stories/README.md @@ -130,6 +130,7 @@ opens with a banner saying what replaces it. | [`streaming`](streaming/) | progress notifications, in-flight logging, cancellation | current | | [`mrtr`](mrtr/) | `InputRequiredResult` round-trip: the `Client` auto-loop and a manual session-level loop | current | | [`legacy_elicitation`](legacy_elicitation/) | server pauses a tool to ask the user (form + url) via a push request | legacy | +| [`refund_desk`](refund_desk/) | resolver DI: `Annotated[T, Resolve(fn)]` params filled server-side, hidden from the input schema | current | | [`sampling`](sampling/) | server asks the client's LLM mid-tool (push request) | deprecated | | [`stickynotes`](stickynotes/) | capstone: tools mutate state → resources + `list_changed` + elicit guard | current | | [`custom_methods`](custom_methods/) | vendor-prefixed JSON-RPC via `add_request_handler` / `send_request` | current | diff --git a/examples/stories/legacy_elicitation/README.md b/examples/stories/legacy_elicitation/README.md index 62f4379c3..1a9d48e60 100644 --- a/examples/stories/legacy_elicitation/README.md +++ b/examples/stories/legacy_elicitation/README.md @@ -69,4 +69,5 @@ uv run python -m stories.legacy_elicitation.client --http --legacy --server serv `sampling/` (same push-request shape, deprecated per SEP-2577), `mrtr/` (planned — the 2026-era carrier), `error_handling/` -(`UrlElicitationRequiredError`). +(`UrlElicitationRequiredError`), `refund_desk/` (resolver DI rides this push +mechanism today). diff --git a/examples/stories/manifest.toml b/examples/stories/manifest.toml index 0fb25a0f0..57ec0e8a4 100644 --- a/examples/stories/manifest.toml +++ b/examples/stories/manifest.toml @@ -39,6 +39,12 @@ era = "modern" era = "legacy" status = "legacy" +[story.refund_desk] +# Resolver DI rides push elicitation (ctx.elicit) today; era flips to "dual" once +# the SDK carries resolver elicitation over the 2026 input_required round-trip. +era = "legacy" +lowlevel = false + [story.sampling] era = "legacy" status = "deprecated" diff --git a/examples/stories/mrtr/README.md b/examples/stories/mrtr/README.md index d801b8ff0..de214988d 100644 --- a/examples/stories/mrtr/README.md +++ b/examples/stories/mrtr/README.md @@ -51,4 +51,5 @@ uv run python -m stories.mrtr.client --http --server server_lowlevel ## See also `legacy_elicitation/` and `sampling/` — the handshake-era push equivalents this -mechanism replaces on the 2026 protocol. +mechanism replaces on the 2026 protocol. `refund_desk/` — resolver DI at the +MCPServer tier: the questions a tool can declare instead of pushing by hand. diff --git a/examples/stories/refund_desk/README.md b/examples/stories/refund_desk/README.md new file mode 100644 index 000000000..0a77dd580 --- /dev/null +++ b/examples/stories/refund_desk/README.md @@ -0,0 +1,67 @@ +# refund-desk + +Resolver dependency injection: a tool parameter annotated `Annotated[T, +Resolve(fn)]` is filled by running the resolver `fn` before the tool body, +instead of from the LLM-supplied arguments. Here `refund_order(order_id, +reason)` refunds what the order record says — `cents` is resolver-computed and +does not appear in the input schema at all, so the model cannot supply or +inflate the amount. Resolvers form a DAG (`load_order` → `refund_scope` → +`refund_amount` / `ask_restock`), may return `Elicit[...]` to ask the human, +and run at most once per call. A resolver's own plain parameters are filled +from the tool's arguments by name — `load_order(order_id)` receives the +`order_id` the model passed to `refund_order`. + +## Run it + +```bash +# stdio (default — the client spawns the server as a subprocess) +uv run python -m stories.refund_desk.client + +# HTTP — the client self-hosts the server on a free port, runs, then tears it +# down (--legacy: resolver elicitation rides the push request today; the +# manifest pins this era, so bare --http runs the same leg) +uv run python -m stories.refund_desk.client --http --legacy +``` + +## What to look at + +- `server.py` `refund_order` — the signature is the whole story: `order_id` and + `reason` are model-facing; `cents` and `restock` carry `Resolve(...)` markers + and never reach the input schema. `client.py` asserts `properties` and + `required` are exactly `{order_id, reason}`. +- `server.py` `refund_scope` — the no-round-trip fast path: a one-line order + returns `Scope(full=True)` directly; only a multi-line order returns + `Elicit(...)`. The ORD-7001 call completes with zero elicitations. +- `server.py` `_scoped` — the elicited SKU is human-typed free text; it is + validated against the order (`ToolError` on a miss) before any amount is + computed. +- The decline contrast: `refund_amount` takes `scope` **unwrapped**, so + declining the scope question aborts the whole `cents` chain with an error + containing the framework's + `Resolver for parameter 'scope' could not resolve: elicitation was decline` + (the client sees it behind the usual `Error executing tool refund_order:` + prefix); `restock` keeps the `ElicitationResult` union, so declining restock + still refunds — just with `restocked: false`. +- `client.py` — the scope counter proves memoization from outside: one call + consumes `refund_scope` from two resolvers but the question fires once. + +## Caveats + +- **Decline order.** A declined unwrapped dependency aborts resolution in + tool-signature order — `cents` resolves before `restock`, so `ask_restock` + never runs. Don't rely on a later resolver's side effects after an earlier + consumer can abort. +- **Memoization scope.** Each resolver runs at most once per `tools/call`, + keyed by function identity; nothing is cached across calls or connections. +- **Validate elicited values.** Elicited answers are human-typed; check them + against your records (as `_scoped` does) before acting on them. + +## Spec + +[Elicitation — client features](https://modelcontextprotocol.io/specification/2025-11-25/client/elicitation) + +## See also + +`legacy_elicitation/` (the push mechanism resolver elicitation rides on today), +`mrtr/` (the 2026 `input_required` carrier; resolver DI will ride it once the +SDK wires them together). diff --git a/examples/stories/refund_desk/__init__.py b/examples/stories/refund_desk/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/stories/refund_desk/client.py b/examples/stories/refund_desk/client.py new file mode 100644 index 000000000..ee86d94b4 --- /dev/null +++ b/examples/stories/refund_desk/client.py @@ -0,0 +1,103 @@ +"""Prove the refund amount is schema-hidden, resolvers memoize per call, and decline semantics differ per consumer.""" + +import mcp_types as types + +from mcp.client import Client, ClientRequestContext +from stories._harness import Target, run_client + + +async def main(target: Target, *, mode: str = "auto") -> None: + # Scripted answers + per-topic counters; topics in `declines` are refused. + counts = {"scope": 0, "restock": 0} + answers: dict[str, dict[str, str | int | float | bool | list[str] | None]] = { + "scope": {"full": True}, + "restock": {"restock": True}, + } + declines: set[str] = set() + + async def on_elicit(context: ClientRequestContext, params: types.ElicitRequestParams) -> types.ElicitResult: + assert isinstance(params, types.ElicitRequestFormParams) + topic = "scope" if "full" in params.requested_schema["properties"] else "restock" + counts[topic] += 1 + if topic in declines: + return types.ElicitResult(action="decline") + return types.ElicitResult(action="accept", content=answers[topic]) + + async with Client(target, mode=mode, elicitation_callback=on_elicit) as client: + # The model-facing contract is order_id + reason only; cents and restock are resolver-filled. + listed = await client.list_tools() + (tool,) = listed.tools + assert set(tool.input_schema["properties"]) == {"order_id", "reason"}, tool.input_schema + assert set(tool.input_schema.get("required", ())) == {"order_id", "reason"}, tool.input_schema + + # One digital line: scope auto-fills (full), restock auto-fills (False) — zero round-trips. + receipt = await client.call_tool("refund_order", {"order_id": "ORD-7001", "reason": "download corrupted"}) + assert receipt.structured_content == { + "order_id": "ORD-7001", + "refunded_cents": 1500, + "restocked": False, + "reason": "download corrupted", + }, receipt.structured_content + assert counts == {"scope": 0, "restock": 0}, counts + + # Full refund of a three-line order. The scope question fires exactly ONCE even though + # both refund_amount and ask_restock consume it — memoized within the call. + receipt = await client.call_tool("refund_order", {"order_id": "ORD-7002", "reason": "arrived broken"}) + assert receipt.structured_content == { + "order_id": "ORD-7002", + "refunded_cents": 4800, + "restocked": True, + "reason": "arrived broken", + }, receipt.structured_content + assert counts == {"scope": 1, "restock": 1}, counts + + # Declining restock still refunds: the tool keeps the ElicitationResult union for + # `restock`, sees the decline, and just skips the restock. The scope counter moves + # again — the memo cache is per tools/call, not per connection. + declines.add("restock") + answers["scope"] = {"full": False, "sku": "canvas-tote"} + receipt = await client.call_tool("refund_order", {"order_id": "ORD-7002", "reason": "wrong colour"}) + assert receipt.structured_content == { + "order_id": "ORD-7002", + "refunded_cents": 2400, + "restocked": False, + "reason": "wrong colour", + }, receipt.structured_content + assert counts == {"scope": 2, "restock": 2}, counts + declines.clear() + + # An elicited SKU is human-typed: the server validates it against the order before + # any money is computed. + answers["scope"] = {"full": False, "sku": "mystery-hat"} + result = await client.call_tool("refund_order", {"order_id": "ORD-7002", "reason": "lost parcel"}) + assert result.is_error, result + assert isinstance(result.content[0], types.TextContent) + assert "order has no item 'mystery-hat'" in result.content[0].text, result.content[0].text + + # Declining scope aborts the whole call: refund_amount and ask_restock both consume scope + # unwrapped, so whichever resolves first (`cents`, in signature order) aborts, and + # ask_restock never runs under any order. + declines.add("scope") + restock_before = counts["restock"] + result = await client.call_tool("refund_order", {"order_id": "ORD-7002", "reason": "changed mind"}) + assert result.is_error, result + assert isinstance(result.content[0], types.TextContent) + assert "Resolver for parameter 'scope' could not resolve: elicitation was decline" in result.content[0].text, ( + result.content[0].text + ) + assert counts["restock"] == restock_before, counts + declines.clear() + + # A ToolError raised inside a resolver surfaces exactly like one from the tool body. + result = await client.call_tool("refund_order", {"order_id": "ORD-9999", "reason": "typo"}) + assert result.is_error, result + assert isinstance(result.content[0], types.TextContent) + assert "unknown order 'ORD-9999'" in result.content[0].text, result.content[0].text + + # Full elicitation trajectory: scope fired in legs 2-5 (memoized within each call), + # restock only in the two calls that reached it. + assert counts == {"scope": 4, "restock": 2}, counts + + +if __name__ == "__main__": + run_client(main) diff --git a/examples/stories/refund_desk/server.py b/examples/stories/refund_desk/server.py new file mode 100644 index 000000000..f29a266f0 --- /dev/null +++ b/examples/stories/refund_desk/server.py @@ -0,0 +1,125 @@ +"""Resolver DI: the refund amount is computed by resolvers from the order record — `cents` never appears in the +tool's input schema, so the model cannot supply or inflate it.""" + +from dataclasses import dataclass +from typing import Annotated + +from pydantic import BaseModel + +from mcp.server.mcpserver import ( + AcceptedElicitation, + Elicit, + ElicitationResult, + MCPServer, + Resolve, +) +from mcp.server.mcpserver.exceptions import ToolError +from stories._hosting import run_server_from_args + + +@dataclass(frozen=True) +class Line: + sku: str + cents: int + physical: bool + + +@dataclass(frozen=True) +class Order: + order_id: str + lines: tuple[Line, ...] + + +ORDERS: dict[str, Order] = { + "ORD-7001": Order("ORD-7001", (Line("ebook-fieldnotes", 1500, physical=False),)), + "ORD-7002": Order( + "ORD-7002", + ( + Line("enamel-mug", 1800, physical=True), + Line("canvas-tote", 2400, physical=True), + Line("sticker-pack", 600, physical=False), + ), + ), +} + + +class Scope(BaseModel): + """Which items to refund: the whole order, or a single SKU.""" + + full: bool + sku: str = "" + + +class RestockChoice(BaseModel): + restock: bool + + +class Receipt(BaseModel): + order_id: str + refunded_cents: int + restocked: bool + reason: str + + +def load_order(order_id: str) -> Order: + order = ORDERS.get(order_id) + if order is None: + raise ToolError(f"unknown order {order_id!r}") + return order + + +def refund_scope(order_id: str, order: Annotated[Order, Resolve(load_order)]) -> Scope | Elicit[Scope]: + if len(order.lines) == 1: + return Scope(full=True) + skus = ", ".join(line.sku for line in order.lines) + return Elicit(f"{order_id} has several items ({skus}). Refund the whole order, or one SKU?", Scope) + + +def _scoped(order: Order, scope: Scope) -> tuple[Line, ...]: + """The lines a scope covers. The SKU was typed by a human — validate it against the order.""" + if scope.full: + return order.lines + lines = tuple(line for line in order.lines if line.sku == scope.sku) + if not lines: + raise ToolError(f"order has no item {scope.sku!r}") + return lines + + +def refund_amount( + order: Annotated[Order, Resolve(load_order)], + scope: Annotated[Scope, Resolve(refund_scope)], +) -> int: + return sum(line.cents for line in _scoped(order, scope)) + + +def ask_restock( + order: Annotated[Order, Resolve(load_order)], + scope: Annotated[Scope, Resolve(refund_scope)], +) -> RestockChoice | Elicit[RestockChoice]: + physical = [line.sku for line in _scoped(order, scope) if line.physical] + if not physical: + return RestockChoice(restock=False) + return Elicit(f"The refund includes physical items ({', '.join(physical)}). Return them to stock?", RestockChoice) + + +def build_server() -> MCPServer: + mcp = MCPServer("refund-desk") + + @mcp.tool(description="Refund an order. The amount comes from the order record, not from the caller.") + def refund_order( + order_id: str, + reason: str, + cents: Annotated[int, Resolve(refund_amount)], + restock: Annotated[ElicitationResult[RestockChoice], Resolve(ask_restock)], + ) -> Receipt: + # `restock` keeps the full elicitation outcome: a declined restock still refunds. A plain + # (non-Elicit) resolver return arrives wrapped as an accepted outcome, so the fast path + # lands in the same `AcceptedElicitation` branch. + restocked = isinstance(restock, AcceptedElicitation) and restock.data.restock + return Receipt(order_id=order_id, refunded_cents=cents, restocked=restocked, reason=reason) + + return mcp + + +if __name__ == "__main__": + run_server_from_args(build_server) diff --git a/mkdocs.yml b/mkdocs.yml index 3e671da8c..7acee7d5d 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -21,6 +21,7 @@ nav: - Resources: tutorial/resources.md - Prompts: tutorial/prompts.md - The Context: tutorial/context.md + - Dependencies: tutorial/dependencies.md - Handling errors: tutorial/handling-errors.md - Lifespan: tutorial/lifespan.md - Media: tutorial/media.md diff --git a/src/mcp/server/elicitation.py b/src/mcp/server/elicitation.py index dc0e669c8..c6faf0065 100644 --- a/src/mcp/server/elicitation.py +++ b/src/mcp/server/elicitation.py @@ -11,6 +11,7 @@ from pydantic import BaseModel, ValidationError from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue from pydantic_core import core_schema +from typing_extensions import TypeAliasType from mcp.server.session import ServerSession @@ -36,7 +37,11 @@ class CancelledElicitation(BaseModel): action: Literal["cancel"] = "cancel" -ElicitationResult = AcceptedElicitation[ElicitSchemaModelT] | DeclinedElicitation | CancelledElicitation +ElicitationResult = TypeAliasType( + "ElicitationResult", + AcceptedElicitation[ElicitSchemaModelT] | DeclinedElicitation | CancelledElicitation, + type_params=(ElicitSchemaModelT,), +) class AcceptedUrlElicitation(BaseModel): diff --git a/src/mcp/server/mcpserver/__init__.py b/src/mcp/server/mcpserver/__init__.py index 7a8da42fe..8ee6c4e4e 100644 --- a/src/mcp/server/mcpserver/__init__.py +++ b/src/mcp/server/mcpserver/__init__.py @@ -5,6 +5,14 @@ from mcp.server.extension import Extension, MethodBinding, ResourceBinding, ToolBinding from .context import Context +from .resolve import ( + AcceptedElicitation, + CancelledElicitation, + DeclinedElicitation, + Elicit, + ElicitationResult, + Resolve, +) from .resources import DEFAULT_RESOURCE_SECURITY, ResourceSecurity from .server import MCPServer, require_client_extension from .utilities.types import Audio, Image @@ -15,6 +23,12 @@ "Image", "Audio", "Icon", + "Resolve", + "Elicit", + "ElicitationResult", + "AcceptedElicitation", + "DeclinedElicitation", + "CancelledElicitation", "Extension", "ToolBinding", "ResourceBinding", diff --git a/src/mcp/server/mcpserver/context.py b/src/mcp/server/mcpserver/context.py index 15b6fd4ad..4d494db6e 100644 --- a/src/mcp/server/mcpserver/context.py +++ b/src/mcp/server/mcpserver/context.py @@ -1,7 +1,7 @@ from __future__ import annotations -from collections.abc import Iterable -from typing import TYPE_CHECKING, Any, Generic +from collections.abc import Iterable, Mapping +from typing import TYPE_CHECKING, Any, Generic, cast from mcp_types import ClientCapabilities, InputResponseRequestParams, InputResponses, LoggingLevel from pydantic import AnyUrl, BaseModel @@ -217,6 +217,16 @@ def client_id(self) -> str | None: """ return self.request_context.meta.get("client_id") if self.request_context.meta else None # pragma: no cover + @property + def headers(self) -> Mapping[str, str] | None: + """Request headers carried by this message, when the transport has them. + + Populated by HTTP-based transports; `None` on stdio or when the + transport's request object carries no headers. Headers are + client-supplied input - never treat one as an identity assertion. + """ + return cast("Mapping[str, str] | None", getattr(self.request_context.request, "headers", None)) + @property def request_id(self) -> str: """Get the unique ID for this request.""" diff --git a/src/mcp/server/mcpserver/resolve.py b/src/mcp/server/mcpserver/resolve.py new file mode 100644 index 000000000..89843a716 --- /dev/null +++ b/src/mcp/server/mcpserver/resolve.py @@ -0,0 +1,324 @@ +"""Resolver dependency injection for MCPServer tools. + +A tool parameter annotated `Annotated[T, Resolve(fn)]` is filled by running the +resolver `fn` before the tool body, instead of from the LLM-supplied arguments. +Resolvers form a DAG: a resolver may declare its own `Resolve(...)` dependencies, +take tool arguments by name, and take the `Context`. A resolver may return +`Elicit[T]` to ask the client; the framework runs the elicitation and injects the +answer. + +Whether the consumer receives the unwrapped model or the full +`ElicitationResult` union is decided by the consumer's annotation: + +- `Annotated[T, Resolve(fn)]` -> unwrapped `T`; decline/cancel aborts the call. +- `Annotated[ElicitationResult[T], Resolve(fn)]` (or a specific member) -> the + full outcome; the consumer branches on accept/decline/cancel. + +Each resolver runs at most once per `tools/call` (memoized by function identity). +""" + +from __future__ import annotations + +import inspect +import typing +from collections.abc import Callable, Hashable, Mapping +from typing import Annotated, Any, Generic, cast, get_args, get_origin + +import anyio.to_thread +from pydantic import BaseModel +from typing_extensions import TypeVar + +from mcp.server.elicitation import ( + AcceptedElicitation, + CancelledElicitation, + DeclinedElicitation, + ElicitationResult, +) +from mcp.server.mcpserver.context import Context +from mcp.server.mcpserver.exceptions import InvalidSignature, ToolError +from mcp.shared._callable_inspection import is_async_callable + +T = TypeVar("T", bound=BaseModel) + +# The union members the framework injects when a consumer opts into the outcome. +_ELICITATION_RESULT_MEMBERS = (AcceptedElicitation, DeclinedElicitation, CancelledElicitation) + + +class Resolve: + """Marker for `Annotated[T, Resolve(fn)]`: fill the parameter by running `fn`.""" + + def __init__(self, fn: Callable[..., Any]) -> None: + self.fn = fn + + +class Elicit(Generic[T]): + """A resolver's request to ask the client. + + Returned from a resolver to signal that the value must be elicited. The + framework runs `ctx.elicit(message, schema)` and injects the outcome. + """ + + def __init__(self, message: str, schema: type[T]) -> None: + self.message = message + self.schema = schema + + +class _ParamPlan: + """How to fill one resolver parameter, decided once at registration.""" + + kind: str # "context" | "resolve" | "by_name" + resolve: Resolve | None + wants_union: bool + + def __init__(self, kind: str, resolve: Resolve | None = None, wants_union: bool = False) -> None: + self.kind = kind + self.resolve = resolve + self.wants_union = wants_union + + +class _ResolverPlan: + """A resolver's parameters and whether it is async, analyzed once.""" + + def __init__(self, fn: Callable[..., Any], params: dict[str, _ParamPlan], is_async: bool) -> None: + self.fn = fn + self.params = params + self.is_async = is_async + + +def _type_hints(fn: Callable[..., Any]) -> dict[str, Any]: + """Resolve type hints for a function or a callable object. + + `typing.get_type_hints` raises on a callable *instance*; fall back to its + `__call__`. Returns an empty mapping when hints cannot be resolved, matching + `find_context_parameter`'s tolerance so callables without annotations (or with + unresolvable ones) simply have no resolved parameters. + """ + target = fn if inspect.isroutine(fn) else getattr(type(fn), "__call__", fn) + try: + return typing.get_type_hints(target, include_extras=True) + except Exception: + return {} + + +def _resolver_name(fn: Callable[..., Any]) -> str: + """Best-effort display name for error messages (callable objects lack `__name__`).""" + return getattr(fn, "__name__", None) or type(fn).__name__ + + +def find_resolved_parameters(fn: Callable[..., Any]) -> dict[str, tuple[Resolve, bool]]: + """Find parameters of `fn` annotated `Annotated[_, Resolve(...)]`. + + Returns a mapping of parameter name to `(Resolve, wants_union)`, where + `wants_union` is True when the annotated type is an `ElicitationResult` member + (the consumer wants the full outcome rather than the unwrapped model). + """ + hints = _type_hints(fn) + resolved: dict[str, tuple[Resolve, bool]] = {} + for name in inspect.signature(fn).parameters: + annotation = hints.get(name) + if get_origin(annotation) is not Annotated: + # A `Resolve` marker is only honored at the top level; flag (rather than + # silently drop) one buried in a union, e.g. `Annotated[T, Resolve(f)] | None`. + if _contains_resolve(annotation): + raise InvalidSignature( + f"Parameter {name!r} of {_resolver_name(fn)!r} wraps `Resolve(...)` in a " + "union; annotate the parameter directly as `Annotated[T, Resolve(...)]`" + ) + continue + type_arg, *metadata = get_args(annotation) + marker = next((m for m in metadata if isinstance(m, Resolve)), None) + if marker is not None: + resolved[name] = (marker, _wants_union(type_arg)) + return resolved + + +def _contains_resolve(annotation: Any) -> bool: + """True when a `Resolve` marker is nested inside `annotation` (e.g. a union member).""" + if get_origin(annotation) is Annotated: + return any(isinstance(m, Resolve) for m in get_args(annotation)[1:]) + return any(_contains_resolve(arg) for arg in get_args(annotation)) + + +def _wants_union(type_arg: Any) -> bool: + """True when `type_arg` is an `ElicitationResult` member (or a union of them). + + Handles the subscripted `ElicitationResult[T]` alias (a `TypeAliasType` whose + union is on the origin's `__value__`), the bare `ElicitationResult` alias (the + `__value__` is on `type_arg` itself), an explicit `AcceptedElicitation[T] | ...` + union, and a single member. + """ + # Unwrap the `ElicitationResult` alias whether it is bare or subscripted. + value = getattr(type_arg, "__value__", None) or getattr(get_origin(type_arg), "__value__", None) + if value is not None: + type_arg = value + members = get_args(type_arg) if get_origin(type_arg) is not None else (type_arg,) + return any(isinstance(m, type) and issubclass(m, _ELICITATION_RESULT_MEMBERS) for m in members) + + +def _resolver_key(fn: Callable[..., Any]) -> Hashable: + """Identity key for memoizing a resolver. + + A bound method - pure-python (`inspect.ismethod`) or built-in (e.g. `obj.meth` + on a C-extension type) - is recreated on each attribute access, so `id(fn)` + differs every time. Key it by its underlying function (or name) plus its + `__self__` identity so `auth.login` referenced in two places memoizes to one + call. Everything else keys by `id`, so two distinct callables never collide + even if they compare equal. + """ + bound_self = getattr(fn, "__self__", None) + if bound_self is not None: + # `__func__` (pure-python) has a stable identity; built-ins expose only a + # stable `__name__`. Use the function's id or the name's value accordingly. + func = getattr(fn, "__func__", None) + underlying: Hashable = id(func) if func is not None else getattr(fn, "__name__", id(fn)) + return (underlying, id(bound_self)) + return id(fn) + + +def build_resolver_plans( + resolved_params: Mapping[str, tuple[Resolve, bool]], + tool_arg_names: set[str], +) -> dict[Hashable, _ResolverPlan]: + """Statically analyze the resolver DAG rooted at a tool's resolved parameters. + + Raises: + InvalidSignature: If a resolver has a cyclic dependency, or a resolver + parameter cannot be classified (not a `Context`, a nested `Resolve`, + or a tool argument by name). + """ + plans: dict[Hashable, _ResolverPlan] = {} + + def analyze(fn: Callable[..., Any], stack: tuple[Hashable, ...]) -> None: + key = _resolver_key(fn) + if key in stack: + raise InvalidSignature(f"Resolver {_resolver_name(fn)!r} has a cyclic dependency") + if key in plans: + return + + hints = _type_hints(fn) + sig = inspect.signature(fn) + params: dict[str, _ParamPlan] = {} + nested: list[Callable[..., Any]] = [] + for param_name in sig.parameters: + annotation = hints.get(param_name) + if annotation is not None and _is_context_annotation(annotation): + params[param_name] = _ParamPlan("context") + continue + marker, wants_union = _resolve_marker(annotation) + if marker is not None: + params[param_name] = _ParamPlan("resolve", marker, wants_union) + nested.append(marker.fn) + continue + if param_name in tool_arg_names: + params[param_name] = _ParamPlan("by_name") + continue + raise InvalidSignature( + f"Resolver {_resolver_name(fn)!r} parameter {param_name!r} cannot be resolved: " + "expected a Context, an Annotated[_, Resolve(...)], or a tool argument by name" + ) + + plans[key] = _ResolverPlan(fn, params, is_async_callable(fn)) + for dep in nested: + analyze(dep, stack + (key,)) + + for marker, _ in resolved_params.values(): + analyze(marker.fn, ()) + return plans + + +def _resolve_marker(annotation: Any) -> tuple[Resolve | None, bool]: + if get_origin(annotation) is not Annotated: + return None, False + type_arg, *metadata = get_args(annotation) + marker = next((m for m in metadata if isinstance(m, Resolve)), None) + return marker, (_wants_union(type_arg) if marker is not None else False) + + +def _is_context_annotation(annotation: Any) -> bool: + if get_origin(annotation) is Annotated: + annotation = get_args(annotation)[0] + candidates = get_args(annotation) if get_origin(annotation) is not None else (annotation,) + return any(isinstance(c, type) and issubclass(c, Context) for c in candidates) + + +async def resolve_arguments( + resolved_params: Mapping[str, tuple[Resolve, bool]], + plans: Mapping[Hashable, _ResolverPlan], + tool_args: Mapping[str, Any], + context: Context[Any, Any], +) -> dict[str, Any]: + """Resolve every `Resolve`-marked tool parameter into a concrete value. + + Each resolver runs at most once (memoized by function identity). Returns a + mapping of tool parameter name to the value to inject. + + Raises: + ToolError: If an elicited value is declined or cancelled and the consumer + asked for the unwrapped model (rather than the result union). + """ + cache: dict[Hashable, ElicitationResult[Any]] = {} + injected: dict[str, Any] = {} + for name, (marker, wants_union) in resolved_params.items(): + outcome = await _resolve(marker.fn, plans, tool_args, context, cache) + injected[name] = outcome if wants_union else _unwrap(outcome, name) + return injected + + +async def _resolve( + fn: Callable[..., Any], + plans: Mapping[Hashable, _ResolverPlan], + tool_args: Mapping[str, Any], + context: Context[Any, Any], + cache: dict[Hashable, ElicitationResult[Any]], +) -> ElicitationResult[Any]: + key = _resolver_key(fn) + if key in cache: + return cache[key] + + plan = plans[key] + kwargs: dict[str, Any] = {} + for param_name, param_plan in plan.params.items(): + if param_plan.kind == "context": + kwargs[param_name] = context + elif param_plan.kind == "by_name": + kwargs[param_name] = tool_args[param_name] + else: + assert param_plan.resolve is not None + dep_outcome = await _resolve(param_plan.resolve.fn, plans, tool_args, context, cache) + kwargs[param_name] = dep_outcome if param_plan.wants_union else _unwrap(dep_outcome, param_name) + + if plan.is_async: + result = await fn(**kwargs) + else: + result = await anyio.to_thread.run_sync(lambda: fn(**kwargs)) + + outcome: ElicitationResult[Any] + if isinstance(result, Elicit): + elicit = cast("Elicit[BaseModel]", result) + outcome = await context.elicit(elicit.message, elicit.schema) + else: + # A resolver may return any type (not just `BaseModel`); `model_construct` + # wraps it as an accepted result without validating against the schema bound. + outcome = cast("AcceptedElicitation[Any]", AcceptedElicitation.model_construct(data=result)) + + cache[key] = outcome + return outcome + + +def _unwrap(outcome: ElicitationResult[Any], name: str) -> Any: + if isinstance(outcome, AcceptedElicitation): + return outcome.data + raise ToolError(f"Resolver for parameter {name!r} could not resolve: elicitation was {outcome.action}") + + +__all__ = [ + "Resolve", + "Elicit", + "ElicitationResult", + "AcceptedElicitation", + "DeclinedElicitation", + "CancelledElicitation", + "find_resolved_parameters", + "build_resolver_plans", + "resolve_arguments", +] diff --git a/src/mcp/server/mcpserver/tools/base.py b/src/mcp/server/mcpserver/tools/base.py index 7eb87eed0..6aab3c777 100644 --- a/src/mcp/server/mcpserver/tools/base.py +++ b/src/mcp/server/mcpserver/tools/base.py @@ -1,6 +1,6 @@ from __future__ import annotations -from collections.abc import Callable +from collections.abc import Callable, Hashable from functools import cached_property from typing import TYPE_CHECKING, Any @@ -8,6 +8,11 @@ from pydantic import BaseModel, Field from mcp.server.mcpserver.exceptions import ToolError +from mcp.server.mcpserver.resolve import ( + build_resolver_plans, + find_resolved_parameters, + resolve_arguments, +) from mcp.server.mcpserver.utilities.context_injection import find_context_parameter from mcp.server.mcpserver.utilities.func_metadata import FuncMetadata, func_metadata from mcp.shared._callable_inspection import is_async_callable @@ -32,6 +37,14 @@ class Tool(BaseModel): ) is_async: bool = Field(description="Whether the tool is async") context_kwarg: str | None = Field(None, description="Name of the kwarg that should receive context") + resolved_params: dict[str, Any] = Field( + default_factory=lambda: {}, + exclude=True, + description="Parameters filled by resolvers, mapped to (Resolve, wants_union)", + ) + resolver_plans: dict[Hashable, Any] = Field( + default_factory=lambda: {}, exclude=True, description="Static per-resolver parameter plans" + ) annotations: ToolAnnotations | None = Field(None, description="Optional annotations for the tool") icons: list[Icon] | None = Field(default=None, description="Optional list of icons for this tool") meta: dict[str, Any] | None = Field(default=None, description="Optional metadata for this tool") @@ -67,13 +80,23 @@ def from_function( if context_kwarg is None: # pragma: no branch context_kwarg = find_context_parameter(fn) + resolved_params = find_resolved_parameters(fn) + + skip_names = [context_kwarg] if context_kwarg is not None else [] + skip_names.extend(resolved_params) + func_arg_metadata = func_metadata( fn, - skip_names=[context_kwarg] if context_kwarg is not None else [], + skip_names=skip_names, structured_output=structured_output, ) parameters = func_arg_metadata.arg_model.model_json_schema(by_alias=True) + # Match `model_dump_one_level`'s kwarg keys (alias when present, else field name) + # so a by-name resolver param resolves to a key that exists at call time. + tool_arg_names = {field.alias or name for name, field in func_arg_metadata.arg_model.model_fields.items()} + resolver_plans = build_resolver_plans(resolved_params, tool_arg_names) + return cls( fn=fn, name=func_name, @@ -83,6 +106,8 @@ def from_function( fn_metadata=func_arg_metadata, is_async=is_async, context_kwarg=context_kwarg, + resolved_params=dict(resolved_params), + resolver_plans=resolver_plans, annotations=annotations, icons=icons, meta=meta, @@ -100,11 +125,26 @@ async def run( ToolError: If the tool function raises during execution. """ try: + pass_directly: dict[str, Any] = {} + if self.context_kwarg is not None: + pass_directly[self.context_kwarg] = context + + # Resolvers see the same validated arguments the tool body receives: + # validate once and reuse it, so a `default_factory`/stateful validator + # can't hand a by-name resolver a different value than the body. + pre_validated: dict[str, Any] | None = None + if self.resolved_params: + pre_validated = self.fn_metadata.validate_arguments(arguments) + pass_directly |= await resolve_arguments( + self.resolved_params, self.resolver_plans, pre_validated, context + ) + result = await self.fn_metadata.call_fn_with_arg_validation( self.fn, self.is_async, arguments, - {self.context_kwarg: context} if self.context_kwarg is not None else None, + pass_directly or None, + pre_validated=pre_validated, ) if convert_result: diff --git a/src/mcp/server/mcpserver/utilities/func_metadata.py b/src/mcp/server/mcpserver/utilities/func_metadata.py index 97eb3909e..be4afb4e9 100644 --- a/src/mcp/server/mcpserver/utilities/func_metadata.py +++ b/src/mcp/server/mcpserver/utilities/func_metadata.py @@ -69,21 +69,36 @@ class FuncMetadata(BaseModel): output_model: Annotated[type[BaseModel], WithJsonSchema(None)] | None = None wrap_output: bool = False + def validate_arguments(self, arguments_to_validate: dict[str, Any]) -> dict[str, Any]: + """Validate raw arguments into a one-level kwargs dict (no function call). + + Used to feed resolver dependency injection the validated tool arguments + before the tool function itself runs. + """ + arguments_pre_parsed = self.pre_parse_json(arguments_to_validate) + arguments_parsed_model = self.arg_model.model_validate(arguments_pre_parsed) + return arguments_parsed_model.model_dump_one_level() + async def call_fn_with_arg_validation( self, fn: Callable[..., Any | Awaitable[Any]], fn_is_async: bool, arguments_to_validate: dict[str, Any], arguments_to_pass_directly: dict[str, Any] | None, + pre_validated: dict[str, Any] | None = None, ) -> Any: """Call the given function with arguments validated and injected. Arguments are first attempted to be parsed from JSON, then validated against - the argument model, before being passed to the function. + the argument model, before being passed to the function. Pass `pre_validated` + (the output of `validate_arguments`) to reuse an earlier validation pass - + validating twice can re-run `default_factory`/stateful validators and hand the + function different values than a caller already observed. """ - arguments_pre_parsed = self.pre_parse_json(arguments_to_validate) - arguments_parsed_model = self.arg_model.model_validate(arguments_pre_parsed) - arguments_parsed_dict = arguments_parsed_model.model_dump_one_level() + # Copy so a caller-provided `pre_validated` dict is never mutated in place. + arguments_parsed_dict = dict( + pre_validated if pre_validated is not None else self.validate_arguments(arguments_to_validate) + ) arguments_parsed_dict |= arguments_to_pass_directly or {} @@ -150,7 +165,7 @@ def pre_parse_json(self, data: dict[str, Any]) -> dict[str, Any]: key_to_field_info[field_info.alias] = field_info for data_key, data_value in data.items(): - if data_key not in key_to_field_info: # pragma: no cover + if data_key not in key_to_field_info: continue field_info = key_to_field_info[data_key] diff --git a/tests/docs_src/test_dependencies.py b/tests/docs_src/test_dependencies.py new file mode 100644 index 000000000..73355a892 --- /dev/null +++ b/tests/docs_src/test_dependencies.py @@ -0,0 +1,129 @@ +"""`docs/tutorial/dependencies.md`: every claim the page makes, proved against the real SDK.""" + +import pytest +from inline_snapshot import snapshot +from mcp_types import ElicitRequestParams, ElicitResult, TextContent + +from docs_src.dependencies import tutorial001, tutorial002, tutorial003 +from mcp import Client +from mcp.client import ClientRequestContext + +pytestmark = [pytest.mark.anyio, pytest.mark.filterwarnings("error::mcp.MCPDeprecationWarning")] + + +async def test_the_resolver_fills_the_parameter_from_the_tools_own_argument() -> None: + """tutorial001: `check_stock` receives `title` by name and its return value becomes `stock`.""" + async with Client(tutorial001.mcp) as client: + in_stock = await client.call_tool("reserve_book", {"title": "Dune"}) + sold_out = await client.call_tool("reserve_book", {"title": "Neuromancer"}) + + assert in_stock.content == [TextContent(type="text", text="Reserved 'Dune' (6 copies left).")] + assert sold_out.content == [TextContent(type="text", text="'Neuromancer' is out of stock.")] + + +async def test_the_resolved_parameter_is_invisible_to_the_model() -> None: + """tutorial001: the input schema shown on the page is exactly what `tools/list` reports.""" + async with Client(tutorial001.mcp) as client: + (tool,) = (await client.list_tools()).tools + + assert tool.input_schema == snapshot( + { + "type": "object", + "properties": {"title": {"title": "Title", "type": "string"}}, + "required": ["title"], + "title": "reserve_bookArguments", + } + ) + + +async def test_a_client_supplied_value_for_a_resolved_parameter_is_ignored() -> None: + """tutorial001: the resolver's value is the only one the tool can receive.""" + async with Client(tutorial001.mcp) as client: + result = await client.call_tool("reserve_book", {"title": "Dune", "stock": {"title": "Dune", "copies": 999}}) + + assert result.content == [TextContent(type="text", text="Reserved 'Dune' (6 copies left).")] + + +async def test_a_resolver_can_depend_on_another_resolver() -> None: + """tutorial002: `estimate_delivery` consumes `check_stock`'s result, and the tool gets both.""" + async with Client(tutorial002.mcp) as client: + in_stock = await client.call_tool("order_book", {"title": "Dune"}) + backorder = await client.call_tool("order_book", {"title": "Neuromancer"}) + + assert in_stock.content == [TextContent(type="text", text="Ordered 'Dune'; it arrives tomorrow.")] + assert backorder.content == [ + TextContent(type="text", text="'Neuromancer' is on backorder; it would arrive in 2-3 weeks.") + ] + + +async def test_a_shared_dependency_runs_once_per_call(monkeypatch: pytest.MonkeyPatch) -> None: + """tutorial002: `stock` and `delivery` both need `check_stock`; one call, one inventory lookup.""" + + class CountingInventory: + def __init__(self, data: dict[str, int]) -> None: + self.data = data + self.lookups: list[str] = [] + + def get(self, key: str, default: int) -> int: + self.lookups.append(key) + return self.data.get(key, default) + + inventory = CountingInventory(dict(tutorial002.INVENTORY)) + monkeypatch.setattr(tutorial002, "INVENTORY", inventory) + + async with Client(tutorial002.mcp) as client: + await client.call_tool("order_book", {"title": "Dune"}) + assert inventory.lookups == ["Dune"] + # Memoization is per call, not per server: the next call looks the title up again. + await client.call_tool("order_book", {"title": "Dune"}) + assert inventory.lookups == ["Dune", "Dune"] + + +async def test_an_in_stock_order_asks_no_question() -> None: + """tutorial003: `confirm_backorder` returns directly when stock exists - no round-trip.""" + + async def never(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: # pragma: no cover + raise AssertionError("an in-stock order must not elicit") + + async with Client(tutorial003.mcp, mode="legacy", elicitation_callback=never) as client: + result = await client.call_tool("order_book", {"title": "Dune"}) + + assert result.content == [TextContent(type="text", text="Ordered 'Dune'.")] + + +@pytest.mark.parametrize( + ("confirm", "expected"), + [ + (True, "Backordered 'Neuromancer'; it ships in 2-3 weeks."), + (False, "No order placed."), + ], +) +async def test_an_out_of_stock_order_asks_and_honours_the_answer(confirm: bool, expected: str) -> None: + """tutorial003: the resolver elicits, the SDK validates the answer, the tool reads it.""" + asked: list[str] = [] + + async def on_elicit(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: + asked.append(params.message) + return ElicitResult(action="accept", content={"confirm": confirm}) + + async with Client(tutorial003.mcp, mode="legacy", elicitation_callback=on_elicit) as client: + result = await client.call_tool("order_book", {"title": "Neuromancer"}) + + assert result.content == [TextContent(type="text", text=expected)] + assert asked == ["'Neuromancer' is out of stock (2-3 weeks). Order anyway?"] + + +async def test_declining_an_unwrapped_dependency_aborts_the_call() -> None: + """tutorial003: no answer, no order - the error text on the page is the real one.""" + + async def decline(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: + return ElicitResult(action="decline") + + async with Client(tutorial003.mcp, mode="legacy", elicitation_callback=decline) as client: + result = await client.call_tool("order_book", {"title": "Neuromancer"}) + + assert result.is_error + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == ( + "Error executing tool order_book: Resolver for parameter 'backorder' could not resolve: elicitation was decline" + ) diff --git a/tests/docs_src/test_elicitation.py b/tests/docs_src/test_elicitation.py index 44523a141..4c9bb4036 100644 --- a/tests/docs_src/test_elicitation.py +++ b/tests/docs_src/test_elicitation.py @@ -14,7 +14,7 @@ ) from pydantic import BaseModel -from docs_src.elicitation import tutorial001, tutorial002, tutorial003 +from docs_src.elicitation import tutorial001, tutorial002, tutorial003, tutorial004 from mcp import Client, MCPError from mcp.client import ClientRequestContext from mcp.server import MCPServer @@ -246,3 +246,54 @@ async def test_a_client_without_the_callback_cannot_be_asked() -> None: async with Client(tutorial001.mcp, mode="legacy") as client: with pytest.raises(MCPError, match="Elicitation not supported"): await client.call_tool("book_table", {"date": "2025-12-25", "party_size": 2}) + + +async def test_resolver_asks_only_when_the_folder_is_not_empty() -> None: + """tutorial004: `confirm_delete` resolves an empty folder directly and elicits otherwise.""" + tutorial004._FOLDERS.update({"/tmp/empty": [], "/tmp/project": ["main.py", "README.md"]}) + asked: list[str] = [] + + async def on_elicit(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: + assert isinstance(params, ElicitRequestFormParams) + asked.append(params.message) + return ElicitResult(action="accept", content={"ok": True}) + + async with Client(tutorial004.mcp, mode="legacy", elicitation_callback=on_elicit) as client: + empty = await client.call_tool("delete_folder", {"path": "/tmp/empty"}) + non_empty = await client.call_tool("delete_folder", {"path": "/tmp/project"}) + + assert empty.content == [TextContent(type="text", text="deleted /tmp/empty")] + assert non_empty.content == [TextContent(type="text", text="deleted /tmp/project")] + assert asked == ["/tmp/project has 2 file(s). Delete anyway?"] # the empty folder was not queried + + +async def test_the_resolved_parameter_is_hidden_from_the_tool_schema() -> None: + """tutorial004: the `Resolve`-filled parameter never appears in the client-facing input schema.""" + async with Client(tutorial004.mcp, mode="legacy") as client: + (tool,) = (await client.list_tools()).tools + assert tool.name == "delete_folder" + assert set(tool.input_schema["properties"]) == {"path"} + + +@pytest.mark.parametrize( + ("action", "content", "expected"), + [ + ("accept", {"ok": False}, "kept the folder"), + ("decline", None, "declined: folder not deleted"), + ("cancel", None, "cancelled: folder not deleted"), + ], +) +async def test_the_tool_branches_on_every_elicitation_outcome( + action: Literal["accept", "decline", "cancel"], + content: dict[str, str | int | float | bool | list[str] | None] | None, + expected: str, +) -> None: + """tutorial004: annotating the result union lets the tool handle accept/decline/cancel.""" + tutorial004._FOLDERS["/tmp/project"] = ["main.py", "README.md"] + + async def on_elicit(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: + return ElicitResult(action=action, content=content) + + async with Client(tutorial004.mcp, mode="legacy", elicitation_callback=on_elicit) as client: + result = await client.call_tool("delete_folder", {"path": "/tmp/project"}) + assert result.content == [TextContent(type="text", text=expected)] diff --git a/tests/server/mcpserver/test_func_metadata.py b/tests/server/mcpserver/test_func_metadata.py index edc3decbd..62a9612b9 100644 --- a/tests/server/mcpserver/test_func_metadata.py +++ b/tests/server/mcpserver/test_func_metadata.py @@ -155,6 +155,28 @@ async def test_complex_function_runtime_arg_validation_with_json(): assert result == "ok!" +@pytest.mark.anyio +async def test_call_fn_does_not_mutate_pre_validated(): + """A caller-provided `pre_validated` dict must not be mutated by the call.""" + + def fn(x: int, ctx: str) -> str: + return f"{x}:{ctx}" + + meta = func_metadata(fn, skip_names=["ctx"]) + pre_validated = meta.validate_arguments({"x": 1}) + snapshot = dict(pre_validated) + + result = await meta.call_fn_with_arg_validation( + fn, + fn_is_async=False, + arguments_to_validate={"x": 1}, + arguments_to_pass_directly={"ctx": "injected"}, + pre_validated=pre_validated, + ) + assert result == "1:injected" + assert pre_validated == snapshot # `ctx` was not leaked into the caller's dict + + def test_str_vs_list_str(): """Test handling of string vs list[str] type annotations. diff --git a/tests/server/mcpserver/test_resolve.py b/tests/server/mcpserver/test_resolve.py new file mode 100644 index 000000000..1f4f72408 --- /dev/null +++ b/tests/server/mcpserver/test_resolve.py @@ -0,0 +1,571 @@ +"""Tests for resolver dependency injection (MRTR) on MCPServer tools.""" + +from typing import Annotated, Any, Literal + +import pytest +from mcp_types import ElicitRequestParams, ElicitResult, TextContent +from pydantic import BaseModel, Field + +from mcp import Client +from mcp.client import ClientRequestContext +from mcp.server.mcpserver import ( + AcceptedElicitation, + CancelledElicitation, + Context, + DeclinedElicitation, + Elicit, + ElicitationResult, + MCPServer, + Resolve, +) +from mcp.server.mcpserver.exceptions import InvalidSignature +from mcp.server.mcpserver.resolve import _resolver_key, find_resolved_parameters +from mcp.server.mcpserver.tools.base import Tool + + +class Login(BaseModel): + username: str + + +class Confirm(BaseModel): + ok: bool + + +async def _alias_login(ctx: Context) -> Login: + return Login(username="x") # pragma: no cover - only the signature is inspected + + +def _accept(content: dict[str, str | int | float | bool | list[str] | None]): + async def callback(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: + return ElicitResult(action="accept", content=content) + + return callback + + +async def _decline(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: + return ElicitResult(action="decline") + + +async def _text(client: Client, tool: str, args: dict[str, object]) -> str: + result = await client.call_tool(tool, args) + assert len(result.content) == 1 + assert isinstance(result.content[0], TextContent) + return result.content[0].text + + +@pytest.mark.anyio +async def test_resolver_returns_value_directly_without_eliciting(): + mcp = MCPServer(name="Direct") + + async def login(ctx: Context) -> Login | Elicit[Login]: + username = (ctx.headers or {}).get("x-github-user") + if username: # pragma: no cover - no headers on in-memory transport + return Login(username=username) + return Login(username="from-resolver") + + @mcp.tool() + async def whoami(login: Annotated[Login, Resolve(login)]) -> str: + return login.username + + async def never(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: # pragma: no cover + raise AssertionError("should not elicit") + + async with Client(mcp, mode="legacy", elicitation_callback=never) as client: + assert await _text(client, "whoami", {}) == "from-resolver" + + +@pytest.mark.anyio +async def test_resolver_elicits_and_injects_unwrapped_model_on_accept(): + mcp = MCPServer(name="Accept") + + async def login(ctx: Context) -> Login | Elicit[Login]: + return Elicit("GitHub username?", Login) + + @mcp.tool() + async def whoami(login: Annotated[Login, Resolve(login)]) -> str: + return login.username + + async with Client(mcp, mode="legacy", elicitation_callback=_accept({"username": "octocat"})) as client: + assert await _text(client, "whoami", {}) == "octocat" + + +@pytest.mark.anyio +async def test_consumer_receives_result_union_and_branches(): + mcp = MCPServer(name="Union") + + async def login(ctx: Context) -> Login | Elicit[Login]: + return Elicit("GitHub username?", Login) + + @mcp.tool() + async def whoami(login: Annotated[ElicitationResult[Login], Resolve(login)]) -> str: + match login: + case AcceptedElicitation(data=data): + return f"hi {data.username}" + case _: # pragma: no cover - accepted in this test + return "no username" + + async with Client(mcp, mode="legacy", elicitation_callback=_accept({"username": "octocat"})) as client: + assert await _text(client, "whoami", {}) == "hi octocat" + + +@pytest.mark.anyio +async def test_decline_reaches_union_consumer_without_aborting(): + mcp = MCPServer(name="UnionDecline") + + async def login(ctx: Context) -> Login | Elicit[Login]: + return Elicit("GitHub username?", Login) + + @mcp.tool() + async def whoami( + login: Annotated[AcceptedElicitation[Login] | DeclinedElicitation | CancelledElicitation, Resolve(login)], + ) -> str: + if isinstance(login, DeclinedElicitation): + return "declined gracefully" + raise NotImplementedError + + async with Client(mcp, mode="legacy", elicitation_callback=_decline) as client: + assert await _text(client, "whoami", {}) == "declined gracefully" + + +@pytest.mark.anyio +async def test_decline_aborts_when_consumer_wants_unwrapped(): + mcp = MCPServer(name="UnwrappedDecline") + + async def login(ctx: Context) -> Login | Elicit[Login]: + return Elicit("GitHub username?", Login) + + @mcp.tool() + async def whoami(login: Annotated[Login, Resolve(login)]) -> str: + raise NotImplementedError # pragma: no cover - never reached + + async with Client(mcp, mode="legacy", elicitation_callback=_decline) as client: + result = await client.call_tool("whoami", {}) + assert result.is_error + assert isinstance(result.content[0], TextContent) + assert "decline" in result.content[0].text + + +@pytest.mark.anyio +async def test_nested_resolver_sees_dependency_and_tool_args(): + mcp = MCPServer(name="Nested") + + async def login(ctx: Context) -> Login | Elicit[Login]: + return Elicit("GitHub username?", Login) + + async def confirm(repo: str, login: Annotated[Login, Resolve(login)]) -> Elicit[Confirm]: + return Elicit(f"Star {repo} as {login.username}?", Confirm) + + @mcp.tool() + async def star_repo( + repo: str, + login: Annotated[Login, Resolve(login)], + confirm: Annotated[Confirm, Resolve(confirm)], + ) -> str: + if confirm.ok: + return f"starred {repo} as {login.username}" + raise NotImplementedError + + async def callback(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: + if "username" in params.message: + return ElicitResult(action="accept", content={"username": "octocat"}) + assert "Star modelcontextprotocol/python-sdk as octocat?" in params.message + return ElicitResult(action="accept", content={"ok": True}) + + async with Client(mcp, mode="legacy", elicitation_callback=callback) as client: + text = await _text(client, "star_repo", {"repo": "modelcontextprotocol/python-sdk"}) + assert text == "starred modelcontextprotocol/python-sdk as octocat" + + +@pytest.mark.anyio +async def test_resolver_runs_once_for_two_consumers(): + mcp = MCPServer(name="ExactlyOnce") + elicit_count = 0 + + async def login(ctx: Context) -> Login | Elicit[Login]: + return Elicit("GitHub username?", Login) + + async def confirm(login: Annotated[Login, Resolve(login)]) -> Elicit[Confirm]: + return Elicit(f"As {login.username}?", Confirm) + + @mcp.tool() + async def star_repo( + login: Annotated[Login, Resolve(login)], + confirm: Annotated[Confirm, Resolve(confirm)], + ) -> str: + return f"{login.username}:{confirm.ok}" + + async def callback(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: + nonlocal elicit_count + if "username" in params.message: + elicit_count += 1 + return ElicitResult(action="accept", content={"username": "octocat"}) + return ElicitResult(action="accept", content={"ok": True}) + + async with Client(mcp, mode="legacy", elicitation_callback=callback) as client: + assert await _text(client, "star_repo", {}) == "octocat:True" + assert elicit_count == 1 + + +@pytest.mark.anyio +async def test_sync_resolver(): + mcp = MCPServer(name="Sync") + + def login(ctx: Context) -> Login: + return Login(username="sync-user") + + @mcp.tool() + async def whoami(login: Annotated[Login, Resolve(login)]) -> str: + return login.username + + async def never(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: # pragma: no cover + raise AssertionError("should not elicit") + + async with Client(mcp, mode="legacy", elicitation_callback=never) as client: + assert await _text(client, "whoami", {}) == "sync-user" + + +def test_resolved_params_absent_from_input_schema(): + async def login(ctx: Context) -> Login: + return Login(username="x") # pragma: no cover - only the schema is inspected + + async def tool( + repo: Annotated[str, Field(description="repo name")], + login: Annotated[Login, Resolve(login)], + ) -> str: + return repo # pragma: no cover - only the schema is inspected + + built = Tool.from_function(tool) + properties = built.parameters["properties"] + assert "repo" in properties + assert "login" not in properties + + +def test_cycle_detection_raises_at_registration(): + async def a(dep: Login) -> Login: + return dep # pragma: no cover + + async def b(dep: Login) -> Login: + return dep # pragma: no cover + + # Close the loop after both exist: a depends on b, b depends on a. + a.__annotations__["dep"] = Annotated[Login, Resolve(b)] + b.__annotations__["dep"] = Annotated[Login, Resolve(a)] + + async def tool(value: Annotated[Login, Resolve(a)]) -> str: + return value.username # pragma: no cover + + with pytest.raises(InvalidSignature, match="cyclic"): + Tool.from_function(tool) + + +def test_find_resolved_parameters_tolerates_unresolvable_hints(): + def fn(x: int) -> int: + return x # pragma: no cover + + fn.__annotations__["x"] = "DoesNotExist" + assert find_resolved_parameters(fn) == {} + + +def test_elicitation_result_alias_resolves_under_postponed_annotations(): + # Reproduces the case where `from __future__ import annotations` stringifies + # `Annotated[ElicitationResult[Login], Resolve(_alias_login)]`: the alias must be + # subscriptable so the resolver is detected (not silently dropped) and the + # consumer is recognized as wanting the result union. + def tool(login: str) -> str: + return login # pragma: no cover + + tool.__annotations__["login"] = "Annotated[ElicitationResult[Login], Resolve(_alias_login)]" + resolved = find_resolved_parameters(tool) + assert "login" in resolved + assert resolved["login"][1] is True # wants_union + + +def test_unresolvable_resolver_param_raises_at_registration(): + async def login(mystery: int) -> Login: + return Login(username="x") # pragma: no cover + + async def tool(login: Annotated[Login, Resolve(login)]) -> str: + return login.username # pragma: no cover + + with pytest.raises(InvalidSignature, match="cannot be resolved"): + Tool.from_function(tool) + + +def test_resolve_marker_inside_a_union_raises_at_registration(): + async def login(ctx: Context) -> Login: + return Login(username="x") # pragma: no cover + + async def tool(login: Annotated[Login, Resolve(login)] | None = None) -> str: + return login.username if login else "" # pragma: no cover + + with pytest.raises(InvalidSignature, match="wraps `Resolve"): + Tool.from_function(tool) + + +def test_bare_elicitation_result_alias_wants_the_outcome_union(): + # The bare `ElicitationResult` alias (no `[T]` subscription) must still opt into + # the result union, not be treated as wanting the unwrapped model. + async def login(ctx: Context) -> Login: + return Login(username="x") # pragma: no cover + + async def tool(login: object) -> str: + return "x" # pragma: no cover + + bare_alias: Any = ElicitationResult + tool.__annotations__["login"] = Annotated[bare_alias, Resolve(login)] + (_, wants_union) = find_resolved_parameters(tool)["login"] + assert wants_union is True + + +def test_resolve_marker_on_return_annotation_is_ignored(): + async def login(ctx: Context) -> Login: + return Login(username="x") # pragma: no cover + + async def tool(repo: str) -> Annotated[str, Resolve(login)]: + return repo # pragma: no cover + + assert find_resolved_parameters(tool) == {} + + +def test_callable_object_resolver_error_uses_type_name(): + class BadResolver: + async def __call__(self, mystery: int) -> Login: + return Login(username="x") # pragma: no cover + + async def tool(login: Annotated[Login, Resolve(BadResolver())]) -> str: + return login.username # pragma: no cover + + with pytest.raises(InvalidSignature, match="'BadResolver'"): + Tool.from_function(tool) + + +@pytest.mark.anyio +async def test_by_name_resolver_param_uses_aliased_tool_arg(): + mcp = MCPServer(name="Aliased") + + # `schema` collides with a BaseModel attribute, so func_metadata aliases the field; + # the runtime kwarg key is the alias, which is what a by-name resolver must match. + async def upper(schema: str) -> Login: + return Login(username=schema.upper()) + + @mcp.tool() + async def run(schema: str, shouted: Annotated[Login, Resolve(upper)]) -> str: + return shouted.username + + async def never(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: # pragma: no cover + raise AssertionError("should not elicit") + + async with Client(mcp, mode="legacy", elicitation_callback=never) as client: + assert await _text(client, "run", {"schema": "gpt"}) == "GPT" + + +@pytest.mark.anyio +async def test_resolver_may_return_non_basemodel_value(): + mcp = MCPServer(name="NonModel") + + async def get_token(ctx: Context) -> str: + return "secret-token" + + @mcp.tool() + async def use_token(token: Annotated[str, Resolve(get_token)]) -> str: + return token + + async def never(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: # pragma: no cover + raise AssertionError("should not elicit") + + async with Client(mcp, mode="legacy", elicitation_callback=never) as client: + assert await _text(client, "use_token", {}) == "secret-token" + + +@pytest.mark.anyio +async def test_resolver_accepts_optional_context_annotation(): + mcp = MCPServer(name="OptionalContext") + + async def whoami(ctx: Context | None) -> str: + assert ctx is not None + return "has-context" + + @mcp.tool() + async def run(who: Annotated[str, Resolve(whoami)]) -> str: + return who + + async def never(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: # pragma: no cover + raise AssertionError("should not elicit") + + async with Client(mcp, mode="legacy", elicitation_callback=never) as client: + assert await _text(client, "run", {}) == "has-context" + + +@pytest.mark.anyio +async def test_bound_method_resolver_runs_once_across_references(): + mcp = MCPServer(name="BoundMethod") + calls = 0 + + class Service: + async def token(self, ctx: Context) -> str: + nonlocal calls + calls += 1 + return "tok" + + service = Service() + + # Each `service.token` access is a fresh bound-method object; keying by the + # callable (not id) keeps the resolver memoized to a single call. + async def downstream(token: Annotated[str, Resolve(service.token)]) -> str: + return token.upper() + + @mcp.tool() + async def run( + token: Annotated[str, Resolve(service.token)], + shouted: Annotated[str, Resolve(downstream)], + ) -> str: + return f"{token}:{shouted}" + + async def never(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: # pragma: no cover + raise AssertionError("should not elicit") + + async with Client(mcp, mode="legacy", elicitation_callback=never) as client: + assert await _text(client, "run", {}) == "tok:TOK" + assert calls == 1 + + +def test_bound_method_cycle_is_detected(): + class Service: + async def a(self, dep: Login) -> Login: + return dep # pragma: no cover + + async def b(self, dep: Login) -> Login: + return dep # pragma: no cover + + service = Service() + service.a.__func__.__annotations__["dep"] = Annotated[Login, Resolve(service.b)] + service.b.__func__.__annotations__["dep"] = Annotated[Login, Resolve(service.a)] + + async def tool(value: Annotated[Login, Resolve(service.a)]) -> str: + return value.username # pragma: no cover + + with pytest.raises(InvalidSignature, match="cyclic"): + Tool.from_function(tool) + + +@pytest.mark.anyio +async def test_resolver_and_body_see_the_same_validated_default(): + mcp = MCPServer(name="DefaultFactory") + counter = {"n": 0} + + def next_id() -> int: + counter["n"] += 1 + return counter["n"] + + # A by-name resolver and the tool body must observe one validation pass, so the + # `default_factory` runs once and both see the same generated value. + async def echo_id(request_id: int) -> int: + return request_id + + @mcp.tool() + async def run( + request_id: Annotated[int, Field(default_factory=next_id)], + resolved_id: Annotated[int, Resolve(echo_id)], + ) -> str: + return f"{request_id}:{resolved_id}" + + async def never(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: # pragma: no cover + raise AssertionError("should not elicit") + + async with Client(mcp, mode="legacy", elicitation_callback=never) as client: + assert await _text(client, "run", {}) == "1:1" + assert counter["n"] == 1 + + +def test_resolver_key_is_stable_for_methods_and_distinct_callables(): + class Service: + def handler(self) -> None: ... # pragma: no cover + + a, b = Service(), Service() + + # Pure-python bound methods: stable across accesses, distinct per instance. + assert _resolver_key(a.handler) == _resolver_key(a.handler) + assert _resolver_key(a.handler) != _resolver_key(b.handler) + + # Built-in bound methods (no `__func__`): fresh object each access, but the key + # is stable and keyed to `__self__`. + items: list[int] = [] + others: list[int] = [] + assert _resolver_key(items.append) == _resolver_key(items.append) + assert _resolver_key(items.append) != _resolver_key(others.append) + assert _resolver_key(items.append) != _resolver_key(items.pop) + + # Plain functions key by identity. + def fn() -> None: ... # pragma: no cover + + assert _resolver_key(fn) == _resolver_key(fn) + + +def _delete_folder_server() -> tuple[MCPServer, dict[str, list[str]]]: + """The `delete_folder` example from docs/migration.md, wired to an in-memory fs.""" + mcp = MCPServer(name="files") + fs: dict[str, list[str]] = {} + + async def confirm_delete(path: str) -> Confirm | Elicit[Confirm]: + file_count = len(fs.get(path, [])) + if file_count == 0: + return Confirm(ok=True) + return Elicit(f"{path} has {file_count} file(s). Delete anyway?", Confirm) + + @mcp.tool() + async def delete_folder( + path: str, + confirm: Annotated[ElicitationResult[Confirm], Resolve(confirm_delete)], + ) -> str: + match confirm: + case AcceptedElicitation(data=Confirm(ok=True)): + fs.pop(path, None) + return f"deleted {path}" + case AcceptedElicitation(): + return "kept the folder" + case DeclinedElicitation(): + return "declined: folder not deleted" + case CancelledElicitation(): # pragma: no branch + return "cancelled: folder not deleted" + + return mcp, fs + + +@pytest.mark.anyio +async def test_delete_empty_folder_does_not_elicit(): + mcp, fs = _delete_folder_server() + fs["/empty"] = [] + + async def never(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: # pragma: no cover + raise AssertionError("should not elicit for an empty folder") + + async with Client(mcp, mode="legacy", elicitation_callback=never) as client: + assert await _text(client, "delete_folder", {"path": "/empty"}) == "deleted /empty" + assert "/empty" not in fs + + +@pytest.mark.anyio +@pytest.mark.parametrize( + ("action", "content", "expected"), + [ + ("accept", {"ok": True}, "deleted /docs"), + ("accept", {"ok": False}, "kept the folder"), + ("decline", None, "declined: folder not deleted"), + ("cancel", None, "cancelled: folder not deleted"), + ], +) +async def test_delete_non_empty_folder_handles_every_outcome( + action: Literal["accept", "decline", "cancel"], + content: dict[str, str | int | float | bool | list[str] | None] | None, + expected: str, +): + mcp, fs = _delete_folder_server() + fs["/docs"] = ["a.txt", "b.txt"] + + async def callback(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: + assert "/docs has 2 file(s)" in params.message + return ElicitResult(action=action, content=content) + + async with Client(mcp, mode="legacy", elicitation_callback=callback) as client: + assert await _text(client, "delete_folder", {"path": "/docs"}) == expected + assert ("/docs" in fs) is (expected != "deleted /docs") diff --git a/tests/server/mcpserver/test_server.py b/tests/server/mcpserver/test_server.py index 70855f44b..d92ed5eaa 100644 --- a/tests/server/mcpserver/test_server.py +++ b/tests/server/mcpserver/test_server.py @@ -1,5 +1,6 @@ import base64 from pathlib import Path +from types import SimpleNamespace from typing import Any from unittest.mock import AsyncMock, MagicMock, patch @@ -1801,6 +1802,33 @@ async def test_report_progress_delegates_to_session_report_progress(): mock_session.report_progress.assert_awaited_once_with(50, 100, "halfway") +def _request_context(request: object | None) -> ServerRequestContext[None, object]: + return ServerRequestContext( + session=AsyncMock(), + method="tools/call", + lifespan_context=None, + protocol_version="2025-11-25", + request=request, + ) + + +def test_context_headers_returns_request_headers(): + request = SimpleNamespace(headers={"x-github-user": "octocat"}) + ctx = Context(request_context=_request_context(request), mcp_server=MagicMock()) + assert ctx.headers == {"x-github-user": "octocat"} + + +def test_context_headers_is_none_without_request(): + ctx = Context(request_context=_request_context(None), mcp_server=MagicMock()) + assert ctx.headers is None + + +def test_context_headers_is_none_when_request_carries_no_headers(): + """A transport may attach a custom request object that has no headers attribute.""" + ctx = Context(request_context=_request_context(object()), mcp_server=MagicMock()) + assert ctx.headers is None + + async def test_read_resource_template_error(): """Template-creation failure must surface as INTERNAL_ERROR, not INVALID_PARAMS (not-found).""" mcp = MCPServer()