diff --git a/.github/workflows/integrity-gate.yml b/.github/workflows/integrity-gate.yml index aab2417..0184106 100644 --- a/.github/workflows/integrity-gate.yml +++ b/.github/workflows/integrity-gate.yml @@ -168,7 +168,7 @@ jobs: --lock e2e.warden.lock \ --verify \ --certificate-identity \ - "https://github.com/ernestprovo23/mcp-warden/.github/workflows/integrity-gate.yml@${GITHUB_REF}" \ + "https://github.com/DataScience-EngineeringExperts/mcp-warden/.github/workflows/integrity-gate.yml@${GITHUB_REF}" \ --certificate-oidc-issuer "https://token.actions.githubusercontent.com" # NEGATIVE PROOF (real crypto): tamper the signed lock's overall_digest and @@ -185,7 +185,7 @@ jobs: --lock e2e.warden.lock \ --verify \ --certificate-identity \ - "https://github.com/ernestprovo23/mcp-warden/.github/workflows/integrity-gate.yml@${GITHUB_REF}" \ + "https://github.com/DataScience-EngineeringExperts/mcp-warden/.github/workflows/integrity-gate.yml@${GITHUB_REF}" \ --certificate-oidc-issuer "https://token.actions.githubusercontent.com" ; then echo "ERROR: real sigstore verified a TAMPERED overall_digest — fail closed is broken" exit 1 diff --git a/CHANGELOG.md b/CHANGELOG.md index 5e934cb..7148ddd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,15 +22,25 @@ CI. The v0.3 `guard` proxy adds deterministic runtime *result* inspection **Explicitly out of scope in v1 (documented post-1.0 roadmap):** - **HTTP/SSE transport** — v1 is stdio-only; HTTP/SSE is the headline v1.x item (#9). -- **DNS-name resolution** of exfil-domain matches (raw-IP-literal handling is the D6 - work item) and **prompt-injection default-block** (stays opt-in / MONITOR until - field false-positive data justifies blocking by default). +- **Prompt-injection default-block** — stays opt-in / MONITOR until field + false-positive data justifies blocking by default. - Behavioral-attack defense (`T-BEHAVE`), full agent-firewall mediation, and any compliance/regulatory claim. See `docs/THREAT_MODEL.md` for the limits. ## [Unreleased] -_No unreleased changes yet._ +### Added + +- **Runtime DNS resolution SSRF bypass detection (`WRD-RES-EXFIL-DNS-SSRF`)** (#11): + the `guard` proxy now resolves URL hostnames from `tools/call` results at runtime + and blocks (error-replace) when any resolved IP falls in a deny range + (`SSRF_NETWORKS` — link-local, loopback, RFC1918, IPv6 ULA/loopback/link-local). + This closes the bypass where `WRD-RES-EXFIL-IP-LITERAL` could not fire because the + result contained a DNS hostname (e.g. `169.254.169.254.nip.io`) rather than a raw + IP literal. Resolution is bounded by 1 s across all hostnames per result frame, + fail-open (any DNS error = no hit), and opt-out via `--no-block-exfil-dns-ssrf` + (or `--no-block-deterministic`). Raw IP literals and the offline `inspect` command + are unchanged. New module `res_dns.py`; 23 new tests. ## [1.0.1] — 2026-06-13 diff --git a/src/mcp_warden/capture.py b/src/mcp_warden/capture.py index 3416144..ebdfe4d 100644 --- a/src/mcp_warden/capture.py +++ b/src/mcp_warden/capture.py @@ -1,8 +1,9 @@ -"""MCP stdio capture client. +"""MCP capture client — stdio and HTTP/SSE transports. Spawns the target MCP server **over stdio as an argv array, never via a shell** -(WARDEN_LOCK_SCHEMA.md §10.4), runs ``initialize`` + ``tools/list`` + -``resources/list`` + ``prompts/list``, and captures the declared surface. +(WARDEN_LOCK_SCHEMA.md §10.4), *or* connects to an already-running server over +HTTP/SSE (Streamable HTTP), then runs ``initialize`` + ``tools/list`` + +``resources/list`` + ``prompts/list`` and captures the declared surface. A server that hangs, crashes, or exits nonzero must produce a clear ``CaptureError``, not a traceback. @@ -16,6 +17,7 @@ import anyio from mcp import ClientSession, StdioServerParameters from mcp.client.stdio import stdio_client +from mcp.client.streamable_http import streamable_http_client from .models import ( CapturedPrompt, @@ -195,3 +197,79 @@ def capture_surface_sync( CaptureError: On any capture failure (see :func:`capture_surface`). """ return anyio.run(capture_surface, command, args, timeout_s) + + +async def _capture_http_async(url: str, timeout_s: float) -> CapturedSurface: + """Inner async HTTP/SSE capture; wrapped with a timeout by :func:`capture_surface_http`.""" + async with streamable_http_client(url) as (read_stream, write_stream, _get_session_id): + async with ClientSession(read_stream, write_stream) as session: + init_result = await session.initialize() + protocol_version = str(getattr(init_result, "protocolVersion", "") or "") + + tools = await _list_tools(session) + resources = await _list_resources(session) + prompts = await _list_prompts(session) + + return CapturedSurface( + url=url, + protocol_version=protocol_version, + tools=tools, + resources=resources, + prompts=prompts, + ) + + +async def capture_surface_http( + url: str, + timeout_s: float = DEFAULT_TIMEOUT_S, +) -> CapturedSurface: + """Connect to a running MCP server over HTTP/SSE and capture its declared surface. + + Connects to ``url`` using the Streamable HTTP transport (MCP SDK + ``streamable_http_client``). The server must already be running and + reachable; no process is spawned. + + Args: + url: HTTP/HTTPS endpoint of the MCP server (e.g. ``https://example.com/mcp``). + timeout_s: Wall-clock timeout for the whole handshake. + + Returns: + The :class:`CapturedSurface` with ``url`` set and ``command``/``args`` empty. + + Raises: + CaptureError: On timeout, connection error, or MCP handshake failure. + """ + logger.debug("connecting to MCP server over HTTP/SSE: url=%r", url) + try: + with anyio.fail_after(timeout_s): + return await _capture_http_async(url, timeout_s) + except TimeoutError as exc: + raise CaptureError( + f"MCP server at '{url}' did not complete the handshake within {timeout_s:.0f}s " + f"(it may be unreachable or hung)." + ) from exc + except CaptureError: + raise + except Exception as exc: + raise CaptureError( + f"Failed to capture MCP server at '{url}': {type(exc).__name__}: {exc}" + ) from exc + + +def capture_surface_http_sync( + url: str, + timeout_s: float = DEFAULT_TIMEOUT_S, +) -> CapturedSurface: + """Synchronous wrapper around :func:`capture_surface_http` for the CLI. + + Args: + url: HTTP/HTTPS endpoint URL. + timeout_s: Wall-clock timeout. + + Returns: + The captured surface. + + Raises: + CaptureError: On any capture failure. + """ + return anyio.run(capture_surface_http, url, timeout_s) diff --git a/src/mcp_warden/check_core.py b/src/mcp_warden/check_core.py index 24487f1..95725b0 100644 --- a/src/mcp_warden/check_core.py +++ b/src/mcp_warden/check_core.py @@ -5,9 +5,13 @@ (issue: "a hook that disagrees with CI is worse than no hook"). The sequence here mirrors what ``check`` has always done: -``read_lock`` -> ``capture_surface_sync`` -> ``run_checks`` -> ``build_lock`` +``read_lock`` -> capture -> ``run_checks`` -> ``build_lock`` (an in-memory CURRENT lock, never persisted) -> ``compute_drift``. +Capture routing: +- stdio (command + args): :func:`~mcp_warden.capture.capture_surface_sync` +- HTTP/SSE (url): :func:`~mcp_warden.capture.capture_surface_http_sync` + # INTERNAL STABILITY NOTE: the pre-commit wrapper (precommit.py) depends on this # function's signature and exception contract (CaptureError for spawn/timeout # failures; FileNotFoundError / ValueError for a missing/invalid lock). Do not @@ -26,7 +30,7 @@ from dataclasses import dataclass from pathlib import Path -from .capture import capture_surface_sync +from .capture import capture_surface_http_sync, capture_surface_sync from .checks import run_checks from .drift import DriftItem, compute_drift from .lockfile import build_lock, read_lock @@ -50,6 +54,8 @@ def run_check_full( args: list[str], lock_path: Path, timeout_s: float, + *, + url: str | None = None, ) -> CheckResult: """Run the full check verdict path: read lock -> capture -> checks -> drift. @@ -58,10 +64,12 @@ def run_check_full( calls the thinner :func:`run_check` which discards ``findings``. Args: - command: The MCP server launch command (argv[0]). - args: The remaining server launch argv. + command: The MCP server launch command (argv[0]). Ignored when ``url`` is set. + args: The remaining server launch argv. Ignored when ``url`` is set. lock_path: Path to the baseline ``warden.lock``. timeout_s: Capture timeout in seconds. + url: When set, connect to this HTTP/SSE endpoint instead of spawning a + subprocess. Mutually exclusive with meaningful ``command``/``args``. Returns: A :class:`CheckResult` (``drift`` empty == clean). @@ -69,10 +77,13 @@ def run_check_full( Raises: FileNotFoundError: The lock file does not exist. ValueError: The lock file is invalid JSON or fails schema validation. - CaptureError: The server could not be spawned or did not respond in time. + CaptureError: The server could not be spawned/reached or did not respond in time. """ baseline = read_lock(lock_path) - surface = capture_surface_sync(command, args, timeout_s=timeout_s) + if url: + surface = capture_surface_http_sync(url, timeout_s=timeout_s) + else: + surface = capture_surface_sync(command, args, timeout_s=timeout_s) findings = run_checks(surface) # build_lock constructs an IN-MEMORY current lock for diffing only; it is # never written to disk on the check path. diff --git a/src/mcp_warden/cli.py b/src/mcp_warden/cli.py index 2b05e66..9b1a78e 100644 --- a/src/mcp_warden/cli.py +++ b/src/mcp_warden/cli.py @@ -25,7 +25,7 @@ from rich.table import Table from . import __version__ -from .capture import CaptureError, capture_surface_sync +from .capture import CaptureError, capture_surface_http_sync, capture_surface_sync from .check_core import run_check_full from .checks import run_checks from .cli_diff import register as register_diff_command @@ -94,7 +94,9 @@ def _split_server_cmd(server_cmd: list[str]) -> tuple[str, list[str]]: @app.command() def pin( - server_cmd: list[str] = typer.Argument(..., help="MCP server launch argv (e.g. node ./server.js)"), + server_cmd: Optional[list[str]] = typer.Argument( + None, help="MCP server launch argv (e.g. node ./server.js); omit when using --url" + ), lock: Path = typer.Option(Path(DEFAULT_LOCK_NAME), "--lock", help="Output lock path"), approve: bool = typer.Option(False, "--approve", help="Record a human approval attestation"), approver: Optional[str] = typer.Option(None, "--approver", help="Approver identity (or WARDEN_APPROVER env)"), @@ -105,16 +107,32 @@ def pin( identity_token: Optional[str] = typer.Option( None, "--identity-token", help="Explicit OIDC token for signing (default: ambient/CI OIDC)" ), + url: Optional[str] = typer.Option( + None, "--url", help="HTTP/SSE endpoint of a running MCP server (mutually exclusive with server-cmd)" + ), ) -> None: - """Pin an MCP server's declared surface into ``warden.lock`` (TOFU baseline).""" - command, args = _split_server_cmd(server_cmd) + """Pin an MCP server's declared surface into ``warden.lock`` (TOFU baseline). + + Pass either a server launch command (stdio) or ``--url`` (HTTP/SSE). + """ + if url and server_cmd: + err_console.print("[red]error:[/red] --url and server-cmd are mutually exclusive") + raise typer.Exit(code=2) + if not url and not server_cmd: + err_console.print("[red]error:[/red] provide either a server command or --url") + raise typer.Exit(code=2) + approver_id = approver or os.environ.get("WARDEN_APPROVER") if approve and not approver_id: err_console.print("[red]error:[/red] --approve requires --approver or WARDEN_APPROVER env") raise typer.Exit(code=2) try: - surface = capture_surface_sync(command, args, timeout_s=timeout) + if url: + surface = capture_surface_http_sync(url, timeout_s=timeout) + else: + command, args = _split_server_cmd(server_cmd) + surface = capture_surface_sync(command, args, timeout_s=timeout) except CaptureError as exc: err_console.print(f"[red]capture failed:[/red] {exc}") raise typer.Exit(code=2) from exc @@ -146,7 +164,7 @@ def pin( @app.command() def check( server_cmd: Optional[list[str]] = typer.Argument( - None, help="MCP server launch argv (must match the pinned launch); omit with --verify" + None, help="MCP server launch argv (must match the pinned launch); omit with --url or --verify" ), lock: Path = typer.Option(Path(DEFAULT_LOCK_NAME), "--lock", help="Baseline lock path"), json_out: bool = typer.Option(False, "--json", help="Emit findings+drift as JSONL to stdout"), @@ -164,9 +182,13 @@ def check( offline_bundle: Optional[Path] = typer.Option( None, "--offline-bundle", help="Explicit bundle path (default: .sigstore next to the lock)" ), + url: Optional[str] = typer.Option( + None, "--url", help="HTTP/SSE endpoint of a running MCP server (mutually exclusive with server-cmd)" + ), ) -> None: """Re-capture and verify a server against ``warden.lock``; fail on drift. + Pass either a server launch command (stdio) or ``--url`` (HTTP/SSE). With ``--verify`` the command ALSO/INSTEAD verifies the lock's Sigstore signature (a no-server-spawn cryptographic check). ``--verify`` requires ``--certificate-identity`` and ``--certificate-oidc-issuer`` and exits 0 only @@ -179,12 +201,22 @@ def check( ) return - command, args = _split_server_cmd(server_cmd) + if url and server_cmd: + err_console.print("[red]error:[/red] --url and server-cmd are mutually exclusive") + raise typer.Exit(code=2) + if not url and not server_cmd: + err_console.print("[red]error:[/red] provide either a server command or --url") + raise typer.Exit(code=2) + + if url: + command, args = "", [] + else: + command, args = _split_server_cmd(server_cmd) try: # Single source of truth shared with the pre-commit wrapper (precommit.py) # so a local hook and CI can never disagree on a drift verdict. - result = run_check_full(command, args, lock, timeout_s=timeout) + result = run_check_full(command, args, lock, timeout_s=timeout, url=url) except (FileNotFoundError, ValueError) as exc: err_console.print(f"[red]error:[/red] {exc}") raise typer.Exit(code=2) from exc diff --git a/src/mcp_warden/cli_guard.py b/src/mcp_warden/cli_guard.py index 4ac2d21..f5c2faf 100644 --- a/src/mcp_warden/cli_guard.py +++ b/src/mcp_warden/cli_guard.py @@ -85,6 +85,7 @@ def guard( no_block_exfil_domain: bool = typer.Option(False, "--no-block-exfil-domain", help="Demote WRD-RES-EXFIL-DOMAIN to shadow"), allow_exfil_domain: bool = typer.Option(False, "--allow-exfil-domain", help="Alias of --no-block-exfil-domain"), no_block_exfil_ip_literal: bool = typer.Option(False, "--no-block-exfil-ip-literal", help="Demote WRD-RES-EXFIL-IP-LITERAL to shadow"), + no_block_exfil_dns_ssrf: bool = typer.Option(False, "--no-block-exfil-dns-ssrf", help="Demote WRD-RES-EXFIL-DNS-SSRF to shadow (disables runtime DNS resolution)"), no_block_list_changed: bool = typer.Option(False, "--no-block-list-changed", help="Demote tools/list_changed gate to shadow"), no_block_policy: bool = typer.Option(False, "--no-block-policy", help="Demote argument-policy deny to shadow"), no_block_deterministic: bool = typer.Option(False, "--no-block-deterministic", help="Demote the WHOLE deterministic tier + both gates"), @@ -177,6 +178,7 @@ def guard( no_block_secret_echo=no_block_secret_echo or no_block_deterministic, no_block_exfil_domain=no_block_exfil_domain or allow_exfil_domain or no_block_deterministic, no_block_exfil_ip_literal=no_block_exfil_ip_literal or no_block_deterministic, + no_block_exfil_dns_ssrf=no_block_exfil_dns_ssrf or no_block_deterministic, no_block_list_changed=no_block_list_changed or no_block_deterministic, no_block_policy=no_block_policy or no_block_deterministic, block_inject_phrase=block_inject_phrase, diff --git a/src/mcp_warden/guard_loop.py b/src/mcp_warden/guard_loop.py index f706201..e578e75 100644 --- a/src/mcp_warden/guard_loop.py +++ b/src/mcp_warden/guard_loop.py @@ -86,6 +86,7 @@ class GuardConfig: no_block_secret_echo: bool = False no_block_exfil_domain: bool = False no_block_exfil_ip_literal: bool = False + no_block_exfil_dns_ssrf: bool = False no_block_list_changed: bool = False no_block_policy: bool = False block_inject_phrase: bool = False @@ -116,6 +117,7 @@ class GuardConfig: "WRD-RES-SECRET-ECHO": "no_block_secret_echo", "WRD-RES-EXFIL-DOMAIN": "no_block_exfil_domain", "WRD-RES-EXFIL-IP-LITERAL": "no_block_exfil_ip_literal", + "WRD-RES-EXFIL-DNS-SSRF": "no_block_exfil_dns_ssrf", } def category_enabled(self, rule_id: str) -> bool: diff --git a/src/mcp_warden/guard_result.py b/src/mcp_warden/guard_result.py index 2223d10..24c0c17 100644 --- a/src/mcp_warden/guard_result.py +++ b/src/mcp_warden/guard_result.py @@ -98,6 +98,17 @@ def handle_s2c(state, frame: Frame, mode: str) -> bytes: ) from None state.emit(_frame_error_note("s2c", rpc_id, f"inspect error: {exc}")) return frame.raw + + # WRD-RES-EXFIL-DNS-SSRF: resolve URL hostnames to catch SSRF bypasses + # (fail-open — any DNS error produces no hits, never aborts). + if state.config.category_enabled("WRD-RES-EXFIL-DNS-SSRF"): + try: + dns_findings = _dns_ssrf_findings(result, tool) + if dns_findings: + findings = list(findings) + dns_findings + except Exception as exc: # pragma: no cover + state.emit(_frame_error_note("s2c", rpc_id, f"dns-ssrf error: {exc}")) + return _apply_result_findings(state, frame, mode, rpc_id, tool, result, findings, pol) @@ -179,6 +190,47 @@ def _tool_name_from_result(obj: dict[str, Any]) -> str: return "" +def _dns_ssrf_findings(result: dict, tool: str) -> list[ResultFinding]: + """Resolve URL hostnames in ``result`` blocks; return WRD-RES-EXFIL-DNS-SSRF findings. + + Fail-open: DNS errors within :func:`~mcp_warden.res_dns.resolve_ssrf_hits` + are already swallowed there (return ``[]``). This function raises only if + the catalog/extract plumbing itself fails — caught by the caller. + + Args: + result: The ``tools/call`` result dict. + tool: The tool name (for finding messages). + + Returns: + Per-block ``WRD-RES-EXFIL-DNS-SSRF`` findings (empty if no SSRF hits). + """ + from . import res_catalog, res_dns + + blocks, _ = res_catalog.extract_blocks(result) + if not blocks: + return [] + + # Collect unique candidates across all blocks (resolve once, match per block). + all_candidates: set[str] = set() + for _idx, text in blocks: + all_candidates.update(res_dns.extract_dns_candidates(text)) + if not all_candidates: + return [] + + hits = res_dns.resolve_ssrf_hits(sorted(all_candidates)) + if not hits: + return [] + + hit_map: dict[str, tuple[str, str]] = {h: (ip, lbl) for h, ip, lbl in hits} + findings: list[ResultFinding] = [] + for idx, text in blocks: + block_candidates = res_dns.extract_dns_candidates(text) + block_hits = [(h, hit_map[h][0], hit_map[h][1]) for h in block_candidates if h in hit_map] + if block_hits: + findings.extend(res_catalog.inspect_exfil_dns_ssrf(block_hits, tool, idx)) + return findings + + def _apply_result_findings( state, frame: Frame, @@ -197,7 +249,7 @@ def _apply_result_findings( if state.config.block_inject_phrase and not state.config.audit_only: block_findings += [f for f in findings if f.tier == TIER_MONITOR] - error_rules = {"WRD-RES-EXFIL-DOMAIN", "WRD-RES-INJECT-PHRASE"} + error_rules = {"WRD-RES-EXFIL-DOMAIN", "WRD-RES-INJECT-PHRASE", "WRD-RES-EXFIL-DNS-SSRF"} if not state.config.redact_secret_echo: error_rules.add("WRD-RES-SECRET-ECHO") redact_rules = {"WRD-RES-ANSI"} diff --git a/src/mcp_warden/lockfile.py b/src/mcp_warden/lockfile.py index dab3c3d..34c7495 100644 --- a/src/mcp_warden/lockfile.py +++ b/src/mcp_warden/lockfile.py @@ -47,7 +47,15 @@ def _now_rfc3339() -> str: def _server_identity(surface: CapturedSurface) -> ServerIdentity: - """Build the server identity block + ``command_digest`` (§4.1).""" + """Build the server identity block + ``command_digest`` (§4.1). + + For HTTP/SSE captures (``surface.url`` is set) the digest covers the URL; + command and args remain empty strings so existing overall-digest logic is + unchanged in the stdio path. + """ + if surface.url: + command_digest = hash_value({"url": surface.url}) + return ServerIdentity(url=surface.url, command_digest=command_digest) command_digest = hash_value({"command": surface.command, "args": surface.args}) return ServerIdentity(command=surface.command, args=list(surface.args), command_digest=command_digest) diff --git a/src/mcp_warden/models.py b/src/mcp_warden/models.py index a734219..dd940d4 100644 --- a/src/mcp_warden/models.py +++ b/src/mcp_warden/models.py @@ -44,10 +44,11 @@ class CapturedPrompt(BaseModel): class CapturedSurface(BaseModel): - """The full captured declared surface of an MCP server over stdio.""" + """The full captured declared surface of an MCP server (stdio or HTTP/SSE).""" - command: str + command: str = "" args: list[str] = Field(default_factory=list) + url: str | None = None # set for HTTP/SSE captures; mutually exclusive with command/args protocol_version: str tools: list[CapturedTool] = Field(default_factory=list) resources: list[CapturedResource] = Field(default_factory=list) @@ -113,8 +114,9 @@ class Finding(BaseModel): class ServerIdentity(BaseModel): """Server identity block (WARDEN_LOCK_SCHEMA.md §4).""" - command: str - args: list[str] + command: str = "" + args: list[str] = Field(default_factory=list) + url: str | None = None # set for HTTP/SSE-transport pins command_digest: str diff --git a/src/mcp_warden/res_catalog.py b/src/mcp_warden/res_catalog.py index 4c67ad5..2f6ebc7 100644 --- a/src/mcp_warden/res_catalog.py +++ b/src/mcp_warden/res_catalog.py @@ -187,6 +187,42 @@ def inspect_inject(text: str, tool: str, idx: int, phrases: tuple[str, ...] | li ] +def inspect_exfil_dns_ssrf( + dns_hits: list[tuple[str, str, str]], + tool: str, + idx: int, +) -> list["ResultFinding"]: + """WRD-RES-EXFIL-DNS-SSRF: a URL hostname resolved to a private/SSRF IP at runtime. + + Closes the bypass where ``WRD-RES-EXFIL-IP-LITERAL`` would miss a destination + whose hostname (not a raw IP literal) resolves to a deny-range IP. + + Args: + dns_hits: Pre-resolved ``(host, resolved_ip, range_label)`` tuples from + :func:`mcp_warden.res_dns.resolve_ssrf_hits`. + tool: The tool name (for the finding message). + idx: The content-block index. + + Returns: + A single block-tier finding when ``dns_hits`` is non-empty, else ``[]``. + """ + if not dns_hits: + return [] + detail = ", ".join(f"{h} -> {ip} ({label})" for h, ip, label in dns_hits) + return [ + _RF( + rule_id="WRD-RES-EXFIL-DNS-SSRF", + severity="high", + tier=TIER_BLOCK, + message=( + f"tools/{tool}: result URL hostname(s) resolve to " + f"private/SSRF IP(s): {detail}" + ), + block_index=idx, + ) + ] + + def uninspectable_note(tool: str, idx: int) -> "ResultFinding": """WRD-RES-UNINSPECTABLE (§5.2): a content block could not be inspected.""" return _RF( diff --git a/src/mcp_warden/res_dns.py b/src/mcp_warden/res_dns.py new file mode 100644 index 0000000..8799cd4 --- /dev/null +++ b/src/mcp_warden/res_dns.py @@ -0,0 +1,108 @@ +"""DNS-resolution SSRF bypass detection for runtime result inspection. + +Resolves extracted hostnames at runtime to catch the class of SSRF bypass +where a non-literal host resolves to a private/loopback/metadata IP — a gap +``WRD-RES-EXFIL-IP-LITERAL`` cannot close (that rule only matches raw IP +literals already present in result text). + +Design constraints (POLICY_MODEL.md §5): +- All DNS IO is isolated here; callers that are pure (``inspect_result``) + receive pre-resolved hits, not raw hostnames to resolve themselves. +- Fail-open: any ``OSError``, timeout, or unexpected error returns no hits for + the affected host — the guard continues normally. +- Raw IP literals are skipped (already handled by ``WRD-RES-EXFIL-IP-LITERAL``). +- Resolution is bounded by ``timeout`` seconds across ALL hosts combined. +""" + +from __future__ import annotations + +import concurrent.futures +import socket +from typing import Sequence + +from .net_rules import SSRF_NETWORKS, parse_ip +from .res_net import extract_urls + + +def _resolve_ips(host: str) -> list[str]: + """Return string-form IPs from ``getaddrinfo`` for ``host``, or ``[]``. + + Any ``OSError`` (NXDOMAIN, refused, unreachable) silently returns empty. + + Args: + host: A DNS hostname (never a raw IP literal). + + Returns: + IP address strings for all addresses ``getaddrinfo`` returns. + """ + try: + return [info[4][0] for info in socket.getaddrinfo(host, None)] + except OSError: + return [] + + +def extract_dns_candidates(text: str) -> list[str]: + """Return unique non-IP-literal hostnames from ``scheme://`` URLs in ``text``. + + Only scheme-qualified URLs (``https://``, ``http://``, etc.) are + considered — bare hostname tokens are not resolved (too noisy). Raw IP + literals are filtered out (already handled by ``WRD-RES-EXFIL-IP-LITERAL``). + + Args: + text: Raw result text. + + Returns: + Sorted, de-duplicated list of DNS name candidates. + """ + seen: set[str] = set() + for host, _path, _full in extract_urls(text): + if parse_ip(host) is None and host not in seen: + seen.add(host) + return sorted(seen) + + +def resolve_ssrf_hits( + hosts: Sequence[str], + *, + timeout: float = 1.0, +) -> list[tuple[str, str, str]]: + """Resolve ``hosts`` and return those whose IPs fall in ``SSRF_NETWORKS``. + + DNS lookups run in a :class:`concurrent.futures.ThreadPoolExecutor` so the + total wall-clock is bounded by ``timeout`` seconds (``wait`` semantics: any + host not resolved within the budget is silently skipped — fail-open). + + Args: + hosts: DNS name candidates (no raw IP literals; use + :func:`extract_dns_candidates` to derive these from result text). + timeout: Max seconds to wait for ALL resolutions combined (default 1.0). + + Returns: + Sorted ``(host, resolved_ip_str, range_label)`` for every host that + resolved to at least one IP inside a deny range. Empty on timeout or + error. + """ + if not hosts: + return [] + workers = min(len(hosts), 8) + hits: list[tuple[str, str, str]] = [] + with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as exe: + futs: dict[concurrent.futures.Future[list[str]], str] = { + exe.submit(_resolve_ips, h): h for h in hosts + } + done, _ = concurrent.futures.wait(futs, timeout=timeout) + for fut in done: + host = futs[fut] + try: + ips = fut.result() + except Exception: # pragma: no cover — _resolve_ips already swallows + continue + for ip_str in ips: + ip = parse_ip(ip_str) + if ip is None: + continue + for net, label in SSRF_NETWORKS: + if ip in net: + hits.append((host, ip_str, label)) + break # first matching range wins per resolved IP + return sorted(hits) diff --git a/src/mcp_warden/result_inspection.py b/src/mcp_warden/result_inspection.py index 65e1631..ca7f6d2 100644 --- a/src/mcp_warden/result_inspection.py +++ b/src/mcp_warden/result_inspection.py @@ -32,7 +32,13 @@ TIER_NOTE = "note" BLOCK_RULES = frozenset( - {"WRD-RES-ANSI", "WRD-RES-SECRET-ECHO", "WRD-RES-EXFIL-DOMAIN", "WRD-RES-EXFIL-IP-LITERAL"} + { + "WRD-RES-ANSI", + "WRD-RES-SECRET-ECHO", + "WRD-RES-EXFIL-DOMAIN", + "WRD-RES-EXFIL-IP-LITERAL", + "WRD-RES-EXFIL-DNS-SSRF", + } ) #: Severity -> SARIF level (CHECKS.md §2), mirrored for WRD-RES-*. diff --git a/tests/test_capture_http.py b/tests/test_capture_http.py new file mode 100644 index 0000000..59a3bd2 --- /dev/null +++ b/tests/test_capture_http.py @@ -0,0 +1,186 @@ +"""Tests for HTTP/SSE capture path (DSE-57). + +Uses unittest.mock to avoid real network calls; verifies that +capture_surface_http routes through streamable_http_client and correctly +constructs CapturedSurface with url set. +""" + +from __future__ import annotations + +import contextlib +from typing import Any +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from mcp_warden.capture import ( + CaptureError, + capture_surface_http, + capture_surface_http_sync, +) +from mcp_warden.models import CapturedSurface + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_tool(name: str = "echo", desc: str = "echo tool") -> MagicMock: + t = MagicMock() + t.model_dump.return_value = {"name": name, "description": desc, "inputSchema": None} + return t + + +def _make_resource(uri: str = "res://x") -> MagicMock: + r = MagicMock() + r.model_dump.return_value = {"uri": uri, "name": "x", "description": None, "mimeType": None} + return r + + +def _make_prompt(name: str = "greet") -> MagicMock: + p = MagicMock() + p.model_dump.return_value = {"name": name, "description": None, "arguments": None} + return p + + +def _session_mock(tools=None, resources=None, prompts=None, protocol_version="2024-11-05"): + """Return a mock ClientSession whose list_* methods return given items.""" + session = AsyncMock() + init_result = MagicMock() + init_result.protocolVersion = protocol_version + session.initialize.return_value = init_result + + tools_result = MagicMock() + tools_result.tools = tools or [] + session.list_tools.return_value = tools_result + + resources_result = MagicMock() + resources_result.resources = resources or [] + session.list_resources.return_value = resources_result + + prompts_result = MagicMock() + prompts_result.prompts = prompts or [] + session.list_prompts.return_value = prompts_result + + return session + + +@contextlib.asynccontextmanager +async def _fake_streamable_http(url, **kwargs): + """Async context manager factory that yields the mocked streams.""" + read_stream = AsyncMock() + write_stream = AsyncMock() + get_session_id = MagicMock(return_value=None) + yield read_stream, write_stream, get_session_id + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + +class TestCaptureHttpAsync: + @pytest.mark.anyio + async def test_returns_surface_with_url_set(self): + session = _session_mock(tools=[_make_tool()]) + with ( + patch("mcp_warden.capture.streamable_http_client", _fake_streamable_http), + patch("mcp_warden.capture.ClientSession") as MockSession, + ): + MockSession.return_value.__aenter__ = AsyncMock(return_value=session) + MockSession.return_value.__aexit__ = AsyncMock(return_value=False) + + surface = await capture_surface_http("http://localhost:8080/mcp", timeout_s=5.0) + + assert isinstance(surface, CapturedSurface) + assert surface.url == "http://localhost:8080/mcp" + assert surface.command == "" + assert surface.args == [] + assert len(surface.tools) == 1 + assert surface.tools[0].name == "echo" + + @pytest.mark.anyio + async def test_protocol_version_captured(self): + session = _session_mock(protocol_version="2025-03-26") + with ( + patch("mcp_warden.capture.streamable_http_client", _fake_streamable_http), + patch("mcp_warden.capture.ClientSession") as MockSession, + ): + MockSession.return_value.__aenter__ = AsyncMock(return_value=session) + MockSession.return_value.__aexit__ = AsyncMock(return_value=False) + + surface = await capture_surface_http("http://example.com/mcp", timeout_s=5.0) + + assert surface.protocol_version == "2025-03-26" + + @pytest.mark.anyio + async def test_tools_resources_prompts_all_captured(self): + session = _session_mock( + tools=[_make_tool("t1"), _make_tool("t2")], + resources=[_make_resource("res://a")], + prompts=[_make_prompt("p1")], + ) + with ( + patch("mcp_warden.capture.streamable_http_client", _fake_streamable_http), + patch("mcp_warden.capture.ClientSession") as MockSession, + ): + MockSession.return_value.__aenter__ = AsyncMock(return_value=session) + MockSession.return_value.__aexit__ = AsyncMock(return_value=False) + + surface = await capture_surface_http("http://example.com/mcp", timeout_s=5.0) + + assert len(surface.tools) == 2 + assert len(surface.resources) == 1 + assert len(surface.prompts) == 1 + + @pytest.mark.anyio + async def test_connection_error_raises_capture_error(self): + @contextlib.asynccontextmanager + async def _failing_client(url, **kwargs): + raise RuntimeError("connection refused") + yield # pragma: no cover — unreachable; satisfies async-generator protocol + + with ( + patch("mcp_warden.capture.streamable_http_client", _failing_client), + ): + with pytest.raises(CaptureError, match="connection refused"): + await capture_surface_http("http://bad-host/mcp", timeout_s=5.0) + + @pytest.mark.anyio + async def test_timeout_raises_capture_error(self): + import anyio + + @contextlib.asynccontextmanager + async def _slow_client(url, **kwargs): + read_stream = AsyncMock() + write_stream = AsyncMock() + get_session_id = MagicMock(return_value=None) + yield read_stream, write_stream, get_session_id + + async def _slow_session_init(self_arg=None): + await anyio.sleep(999) + + with patch("mcp_warden.capture.streamable_http_client", _slow_client): + with patch("mcp_warden.capture.ClientSession") as MockSession: + slow_session = AsyncMock() + slow_session.initialize = _slow_session_init + MockSession.return_value.__aenter__ = AsyncMock(return_value=slow_session) + MockSession.return_value.__aexit__ = AsyncMock(return_value=False) + + with pytest.raises(CaptureError, match="did not complete the handshake"): + await capture_surface_http("http://slow/mcp", timeout_s=0.05) + + +class TestCaptureHttpSync: + def test_sync_wrapper_returns_surface(self): + session = _session_mock() + with ( + patch("mcp_warden.capture.streamable_http_client", _fake_streamable_http), + patch("mcp_warden.capture.ClientSession") as MockSession, + ): + MockSession.return_value.__aenter__ = AsyncMock(return_value=session) + MockSession.return_value.__aexit__ = AsyncMock(return_value=False) + + surface = capture_surface_http_sync("http://localhost:9000/mcp", timeout_s=5.0) + + assert surface.url == "http://localhost:9000/mcp" + assert isinstance(surface, CapturedSurface) diff --git a/tests/test_guard_posture.py b/tests/test_guard_posture.py index f41fabb..45338aa 100644 --- a/tests/test_guard_posture.py +++ b/tests/test_guard_posture.py @@ -144,6 +144,25 @@ def test_exfil_ip_literal_no_block_deterministic_fold_demotes(): assert GuardConfig(no_block_exfil_ip_literal=True).category_enabled("WRD-RES-EXFIL-IP-LITERAL") is False +# --- WRD-RES-EXFIL-DNS-SSRF category posture (DSE-58) ------------------------- + + +def test_exfil_dns_ssrf_blocks_by_default(): + assert GuardConfig().category_enabled("WRD-RES-EXFIL-DNS-SSRF") is True + + +def test_exfil_dns_ssrf_opt_out_demotes(): + assert GuardConfig(no_block_exfil_dns_ssrf=True).category_enabled("WRD-RES-EXFIL-DNS-SSRF") is False + + +def test_exfil_dns_ssrf_audit_only_demotes(): + assert GuardConfig(audit_only=True).category_enabled("WRD-RES-EXFIL-DNS-SSRF") is False + + +def test_exfil_dns_ssrf_no_block_deterministic_demotes(): + assert GuardConfig(no_block_exfil_dns_ssrf=True).category_enabled("WRD-RES-EXFIL-DNS-SSRF") is False + + # --- framing: both modes round-trip ------------------------------------------ diff --git a/tests/test_res_dns.py b/tests/test_res_dns.py new file mode 100644 index 0000000..2381c36 --- /dev/null +++ b/tests/test_res_dns.py @@ -0,0 +1,141 @@ +"""Tests for res_dns — DNS-resolution SSRF bypass detection (DSE-58).""" + +from __future__ import annotations + +import socket +from unittest.mock import patch + +from mcp_warden.res_dns import extract_dns_candidates, resolve_ssrf_hits + + +# --------------------------------------------------------------------------- +# extract_dns_candidates +# --------------------------------------------------------------------------- + + +def test_extract_candidates_from_https_url(): + hosts = extract_dns_candidates("fetch https://evil.callback.io/data") + assert "evil.callback.io" in hosts + + +def test_extract_candidates_from_http_url(): + hosts = extract_dns_candidates("see http://internal.corp.example.com/meta") + assert "internal.corp.example.com" in hosts + + +def test_extract_candidates_skips_raw_ipv4_literal(): + # Raw IP literals are already handled by WRD-RES-EXFIL-IP-LITERAL. + hosts = extract_dns_candidates("https://169.254.169.254/latest/meta-data") + assert hosts == [] + + +def test_extract_candidates_skips_raw_ipv6_literal(): + hosts = extract_dns_candidates("https://[::1]/admin") + assert hosts == [] + + +def test_extract_candidates_deduplicates(): + text = "https://foo.example.com/a and https://foo.example.com/b" + hosts = extract_dns_candidates(text) + assert hosts.count("foo.example.com") == 1 + + +def test_extract_candidates_multiple_distinct_hosts(): + text = "https://a.example.com/x https://b.example.com/y" + hosts = extract_dns_candidates(text) + assert "a.example.com" in hosts and "b.example.com" in hosts + + +def test_extract_candidates_no_urls_returns_empty(): + assert extract_dns_candidates("no URLs here at all") == [] + + +def test_extract_candidates_returns_sorted(): + text = "https://z.example.com/1 https://a.example.com/2" + hosts = extract_dns_candidates(text) + assert hosts == sorted(hosts) + + +# --------------------------------------------------------------------------- +# resolve_ssrf_hits — mocked socket.getaddrinfo +# --------------------------------------------------------------------------- + + +def _addrinfo(ip: str, family: int = socket.AF_INET) -> list: + """Minimal getaddrinfo result for a single IP.""" + return [(family, socket.SOCK_STREAM, 0, "", (ip, 0))] + + +def test_resolve_detects_link_local_metadata_ip(monkeypatch): + monkeypatch.setattr(socket, "getaddrinfo", lambda h, p: _addrinfo("169.254.169.254")) + hits = resolve_ssrf_hits(["bypass.nip.io"]) + assert hits == [("bypass.nip.io", "169.254.169.254", "link-local")] + + +def test_resolve_detects_loopback(monkeypatch): + monkeypatch.setattr(socket, "getaddrinfo", lambda h, p: _addrinfo("127.0.0.1")) + hits = resolve_ssrf_hits(["loopback-bypass.example.com"]) + assert hits and hits[0][2] == "loopback" + + +def test_resolve_detects_rfc1918_10_block(monkeypatch): + monkeypatch.setattr(socket, "getaddrinfo", lambda h, p: _addrinfo("10.1.2.3")) + hits = resolve_ssrf_hits(["internal.corp"]) + assert hits and "RFC1918" in hits[0][2] + + +def test_resolve_detects_rfc1918_172_block(monkeypatch): + monkeypatch.setattr(socket, "getaddrinfo", lambda h, p: _addrinfo("172.16.0.1")) + hits = resolve_ssrf_hits(["vpn-internal.example.com"]) + assert hits and "RFC1918" in hits[0][2] + + +def test_resolve_detects_rfc1918_192_168(monkeypatch): + monkeypatch.setattr(socket, "getaddrinfo", lambda h, p: _addrinfo("192.168.1.100")) + hits = resolve_ssrf_hits(["router.local"]) + assert hits and "RFC1918" in hits[0][2] + + +def test_resolve_detects_ipv6_loopback(monkeypatch): + monkeypatch.setattr( + socket, "getaddrinfo", lambda h, p: [(socket.AF_INET6, socket.SOCK_STREAM, 0, "", ("::1", 0, 0, 0))] + ) + hits = resolve_ssrf_hits(["loopback6.example.com"]) + assert hits and "loopback" in hits[0][2] + + +def test_resolve_public_ip_produces_no_hits(monkeypatch): + monkeypatch.setattr(socket, "getaddrinfo", lambda h, p: _addrinfo("8.8.8.8")) + hits = resolve_ssrf_hits(["public.example.com"]) + assert hits == [] + + +def test_resolve_fail_open_on_nxdomain(monkeypatch): + def _raise(host, port): + raise OSError("NXDOMAIN") + + monkeypatch.setattr(socket, "getaddrinfo", _raise) + hits = resolve_ssrf_hits(["nonexistent.invalid"]) + assert hits == [] + + +def test_resolve_empty_input_returns_empty(): + assert resolve_ssrf_hits([]) == [] + + +def test_resolve_result_is_sorted(monkeypatch): + # Return same IP for any host so all hits land in the result. + monkeypatch.setattr(socket, "getaddrinfo", lambda h, p: _addrinfo("10.0.0.1")) + hosts = ["z.example.com", "a.example.com", "m.example.com"] + hits = resolve_ssrf_hits(hosts) + assert hits == sorted(hits) + + +def test_resolve_multiple_hosts_all_private(monkeypatch): + def _multi(host, port): + ip = "10.0.0.1" if "a" in host else "127.0.0.1" + return _addrinfo(ip) + + monkeypatch.setattr(socket, "getaddrinfo", _multi) + hits = resolve_ssrf_hits(["alpha.example.com", "beta.example.com"]) + assert len(hits) == 2