diff --git a/Framework/Core/scripts/dpl-mcp-server/dpl_mcp_server.py b/Framework/Core/scripts/dpl-mcp-server/dpl_mcp_server.py index 3900a646632a1..dca5058b01dcd 100644 --- a/Framework/Core/scripts/dpl-mcp-server/dpl_mcp_server.py +++ b/Framework/Core/scripts/dpl-mcp-server/dpl_mcp_server.py @@ -14,19 +14,21 @@ Bridges the DPL driver /status WebSocket endpoint to MCP tools so that an AI assistant (e.g. Claude) can inspect and monitor a running DPL workflow. +Supports multiple concurrent workflows. Use the ``connect`` tool to attach +to a running topology by port or PID, then pass the returned workflow name +to every other tool. + Usage ----- - python3 dpl_mcp_server.py --port 8080 - python3 dpl_mcp_server.py --pid 12345 # port derived as 8080 + pid % 30000 - DPL_STATUS_PORT=8080 python3 dpl_mcp_server.py + python3 dpl_mcp_server.py -Wire protocol (client → driver) +Wire protocol (client -> driver) -------------------------------- {"cmd":"list_metrics","device":""} {"cmd":"subscribe","device":"","metrics":["m1","m2"]} {"cmd":"unsubscribe","device":"","metrics":["m1"]} -Wire protocol (driver → client) +Wire protocol (driver -> client) -------------------------------- {"type":"snapshot","devices":[{"name","pid","active","streamingState","deviceState"},...]} {"type":"update","device":,"name":"","metrics":{}} @@ -35,80 +37,115 @@ from __future__ import annotations -import argparse import asyncio import json -import os -import sys from typing import Any import websockets from mcp.server.fastmcp import FastMCP + # --------------------------------------------------------------------------- -# Global connection state (all access from the single asyncio event loop) +# Per-workflow connection state # --------------------------------------------------------------------------- -_port: int = 8080 -_ws: Any = None -_reader_task: asyncio.Task | None = None -_snapshot: dict = {} -_updates: list[dict] = [] -_logs: list[dict] = [] -_metrics_lists: dict[str, list[str]] = {} - - -async def _ensure_connected() -> None: - """Connect (or reconnect) to the driver's /status WebSocket.""" - global _ws, _reader_task - - # Check liveness of existing connection. - if _ws is not None: +class WorkflowConnection: + """Holds WebSocket connection and buffered state for one DPL workflow.""" + + def __init__(self, port: int, name: str): + self.port = port + self.name = name + self.ws: Any = None + self.reader_task: asyncio.Task | None = None + self.snapshot: dict = {} + self.updates: list[dict] = [] + self.logs: list[dict] = [] + self.metrics_lists: dict[str, list[str]] = {} + + async def ensure_connected(self) -> None: + """Connect (or reconnect) to the driver's /status WebSocket.""" + if self.ws is not None: + try: + pong = await asyncio.wait_for(self.ws.ping(), timeout=2.0) + await pong + return + except Exception: + old_ws = self.ws + self.ws = None + if self.reader_task is not None and not self.reader_task.done(): + self.reader_task.cancel() + try: + await self.reader_task + except (asyncio.CancelledError, Exception): + pass + self.reader_task = None + try: + await old_ws.close() + except Exception: + pass + + url = f"ws://localhost:{self.port}/status" + self.ws = await websockets.connect(url, subprotocols=["dpl"]) + if self.reader_task is None or self.reader_task.done(): + self.reader_task = asyncio.create_task(self._reader()) + + async def _reader(self) -> None: + """Background task: read frames from the driver and buffer them.""" try: - pong = await asyncio.wait_for(_ws.ping(), timeout=2.0) - await pong - return + async for raw in self.ws: + try: + msg = json.loads(raw) + except json.JSONDecodeError: + continue + t = msg.get("type") + if t == "snapshot": + self.snapshot = msg + self.metrics_lists.clear() + elif t == "update": + self.updates.append(msg) + elif t == "log": + self.logs.append(msg) + elif t == "metrics_list": + device = msg.get("device", "") + self.metrics_lists[device] = msg.get("metrics", []) except Exception: - _ws = None - if _reader_task is not None and not _reader_task.done(): - _reader_task.cancel() - _reader_task = None - - url = f"ws://localhost:{_port}/status" - _ws = await websockets.connect(url, subprotocols=["dpl"]) - if _reader_task is None or _reader_task.done(): - _reader_task = asyncio.create_task(_reader()) - - -async def _reader() -> None: - """Background task: read frames from the driver and buffer them.""" - global _ws, _snapshot, _updates, _logs, _metrics_lists - try: - async for raw in _ws: + pass + finally: + self.ws = None + + async def send(self, obj: dict) -> None: + await self.ensure_connected() + await self.ws.send(json.dumps(obj, separators=(",", ":"))) + + async def close(self) -> None: + ws = self.ws + self.ws = None + if self.reader_task is not None and not self.reader_task.done(): + self.reader_task.cancel() try: - msg = json.loads(raw) - except json.JSONDecodeError: - continue - t = msg.get("type") - if t == "snapshot": - _snapshot = msg - # Clear stale metric lists from a previous driver instance. - _metrics_lists.clear() - elif t == "update": - _updates.append(msg) - elif t == "log": - _logs.append(msg) - elif t == "metrics_list": - device = msg.get("device", "") - _metrics_lists[device] = msg.get("metrics", []) - except Exception: - pass - finally: - _ws = None - - -async def _send(obj: dict) -> None: - await _ensure_connected() - await _ws.send(json.dumps(obj, separators=(",", ":"))) + await self.reader_task + except (asyncio.CancelledError, Exception): + pass + self.reader_task = None + if ws is not None: + await ws.close() + + +# --------------------------------------------------------------------------- +# Workflow registry +# --------------------------------------------------------------------------- +_workflows: dict[str, WorkflowConnection] = {} + + +def _get(workflow: str) -> WorkflowConnection: + """Look up a workflow by name, raising a clear error if not found.""" + conn = _workflows.get(workflow) + if conn is None: + available = ", ".join(_workflows.keys()) if _workflows else "(none)" + raise ValueError( + f"No workflow named '{workflow}'. Connected workflows: {available}. " + f"Use the connect tool first." + ) + return conn # --------------------------------------------------------------------------- @@ -118,16 +155,81 @@ async def _send(obj: dict) -> None: @mcp.tool() -async def list_devices() -> str: +async def connect(port: int = 0, pid: int = 0, name: str = "") -> str: + """Connect to a running DPL workflow. + + Provide either ``port`` (the driver's WebSocket port) or ``pid`` (the + driver PID, port derived as 8080 + pid % 30000). An optional ``name`` + gives the workflow a human-friendly label; if omitted the port number is + used. + + Args: + port: TCP port of the DPL driver status WebSocket. + pid: PID of the DPL driver process (alternative to port). + name: Optional human-friendly name for this workflow. + """ + if pid: + port = 8080 + pid % 30000 + if not port: + return "Provide either port or pid." + + wf_name = name or str(port) + if wf_name in _workflows: + old = _workflows[wf_name] + await old.close() + + conn = WorkflowConnection(port, wf_name) + await conn.ensure_connected() + _workflows[wf_name] = conn + + devices = conn.snapshot.get("devices", []) + return ( + f"Connected to workflow '{wf_name}' on port {port} " + f"({len(devices)} device(s))." + ) + + +@mcp.tool() +async def disconnect(workflow: str) -> str: + """Disconnect from a DPL workflow and release its resources. + + Args: + workflow: Workflow name as returned by connect. + """ + conn = _get(workflow) + await conn.close() + del _workflows[workflow] + return f"Disconnected from workflow '{workflow}'." + + +@mcp.tool() +async def list_workflows() -> str: + """List all currently connected DPL workflows.""" + if not _workflows: + return "No workflows connected. Use the connect tool first." + lines = [] + for wf_name, conn in _workflows.items(): + n = len(conn.snapshot.get("devices", [])) + status = "connected" if conn.ws is not None else "disconnected" + lines.append(f"{wf_name}: port={conn.port} devices={n} status={status}") + return "\n".join(lines) + + +@mcp.tool() +async def list_devices(workflow: str) -> str: """List all DPL devices with their current status. Returns each device's name, PID, active flag, streaming state, and device state as reported by the driver snapshot. + + Args: + workflow: Workflow name as returned by connect. """ - await _ensure_connected() - if not _snapshot: - return "No snapshot received yet — the driver may still be starting." - devices = _snapshot.get("devices", []) + conn = _get(workflow) + await conn.ensure_connected() + if not conn.snapshot: + return "No snapshot received yet -- the driver may still be starting." + devices = conn.snapshot.get("devices", []) if not devices: return "No devices in snapshot." lines = [] @@ -140,7 +242,7 @@ async def list_devices() -> str: @mcp.tool() -async def list_metrics(device: str) -> str: +async def list_metrics(workflow: str, device: str) -> str: """List the available numeric metrics for a DPL device. Sends a list_metrics command to the driver and waits up to 3 seconds for @@ -148,15 +250,16 @@ async def list_metrics(device: str) -> str: and enum metrics are excluded. Args: + workflow: Workflow name as returned by connect. device: Device name exactly as shown by list_devices. """ - # Remove any stale cached result so we can detect the fresh reply. - _metrics_lists.pop(device, None) - await _send({"cmd": "list_metrics", "device": device}) + conn = _get(workflow) + conn.metrics_lists.pop(device, None) + await conn.send({"cmd": "list_metrics", "device": device}) for _ in range(60): # up to 3 s await asyncio.sleep(0.05) - if device in _metrics_lists: - names = _metrics_lists[device] + if device in conn.metrics_lists: + names = conn.metrics_lists[device] if not names: return f"Device '{device}' has no numeric metrics yet." return f"{len(names)} metric(s): " + ", ".join(names) @@ -164,7 +267,7 @@ async def list_metrics(device: str) -> str: @mcp.tool() -async def subscribe(device: str, metrics: list[str]) -> str: +async def subscribe(workflow: str, device: str, metrics: list[str]) -> str: """Subscribe to one or more metrics for a DPL device. After subscribing, the driver will push update frames for the device @@ -172,60 +275,70 @@ async def subscribe(device: str, metrics: list[str]) -> str: the buffer. Args: + workflow: Workflow name as returned by connect. device: Device name exactly as shown by list_devices. metrics: List of metric names to subscribe to (from list_metrics). """ - await _send({"cmd": "subscribe", "device": device, "metrics": metrics}) + conn = _get(workflow) + await conn.send({"cmd": "subscribe", "device": device, "metrics": metrics}) return f"Subscribed to {len(metrics)} metric(s) for '{device}': {', '.join(metrics)}" @mcp.tool() -async def unsubscribe(device: str, metrics: list[str]) -> str: +async def unsubscribe(workflow: str, device: str, metrics: list[str]) -> str: """Stop receiving updates for specific metrics of a DPL device. Args: + workflow: Workflow name as returned by connect. device: Device name exactly as shown by list_devices. metrics: List of metric names to unsubscribe from. """ - await _send({"cmd": "unsubscribe", "device": device, "metrics": metrics}) + conn = _get(workflow) + await conn.send({"cmd": "unsubscribe", "device": device, "metrics": metrics}) return f"Unsubscribed from {len(metrics)} metric(s) for '{device}'." @mcp.tool() -async def subscribe_logs(device: str) -> str: +async def subscribe_logs(workflow: str, device: str) -> str: """Subscribe to log output for a DPL device. After subscribing, new log lines from the device will be buffered and can be retrieved with get_logs(). Args: + workflow: Workflow name as returned by connect. device: Device name exactly as shown by list_devices. """ - await _send({"cmd": "subscribe_logs", "device": device}) + conn = _get(workflow) + await conn.send({"cmd": "subscribe_logs", "device": device}) return f"Subscribed to logs for '{device}'." @mcp.tool() -async def unsubscribe_logs(device: str) -> str: +async def unsubscribe_logs(workflow: str, device: str) -> str: """Stop receiving log output for a DPL device. Args: + workflow: Workflow name as returned by connect. device: Device name exactly as shown by list_devices. """ - await _send({"cmd": "unsubscribe_logs", "device": device}) + conn = _get(workflow) + await conn.send({"cmd": "unsubscribe_logs", "device": device}) return f"Unsubscribed from logs for '{device}'." @mcp.tool() -async def get_logs(max_lines: int = 100) -> str: +async def get_logs(workflow: str, max_lines: int = 100) -> str: """Drain and return buffered log lines received since the last call. Args: + workflow: Workflow name as returned by connect. max_lines: Maximum number of log lines to return (default 100). """ - await _ensure_connected() - batch = _logs[:max_lines] - del _logs[:max_lines] + conn = _get(workflow) + await conn.ensure_connected() + batch = conn.logs[:max_lines] + del conn.logs[:max_lines] if not batch: return "No buffered log lines." lines = [] @@ -238,17 +351,21 @@ async def get_logs(max_lines: int = 100) -> str: @mcp.tool() -async def start_devices() -> str: +async def start_devices(workflow: str) -> str: """Resume all stopped DPL devices (send SIGCONT). Use this when the workflow was started with -s (all devices paused). + + Args: + workflow: Workflow name as returned by connect. """ - await _send({"cmd": "start_devices"}) + conn = _get(workflow) + await conn.send({"cmd": "start_devices"}) return "Sent SIGCONT to all active devices." @mcp.tool() -async def enable_signpost(device: str, streams: list[str]) -> str: +async def enable_signpost(workflow: str, device: str, streams: list[str]) -> str: """Enable one or more signpost log streams for a DPL device. Signpost streams produce detailed trace output visible in the device logs. @@ -259,27 +376,31 @@ async def enable_signpost(device: str, streams: list[str]) -> str: ch.cern.aliceo2.data_processor_context, ch.cern.aliceo2.stream_context. Args: + workflow: Workflow name as returned by connect. device: Device name as shown by list_devices, or "" for the driver. streams: List of full signpost log names to enable. """ - await _send({"cmd": "enable_signpost", "device": device, "streams": streams}) + conn = _get(workflow) + await conn.send({"cmd": "enable_signpost", "device": device, "streams": streams}) return f"Enabled {len(streams)} signpost stream(s) for '{device or 'driver'}': {', '.join(streams)}" @mcp.tool() -async def disable_signpost(device: str, streams: list[str]) -> str: +async def disable_signpost(workflow: str, device: str, streams: list[str]) -> str: """Disable one or more signpost log streams for a DPL device. Args: + workflow: Workflow name as returned by connect. device: Device name as shown by list_devices, or "" for the driver. streams: List of full signpost log names to disable. """ - await _send({"cmd": "disable_signpost", "device": device, "streams": streams}) + conn = _get(workflow) + await conn.send({"cmd": "disable_signpost", "device": device, "streams": streams}) return f"Disabled {len(streams)} signpost stream(s) for '{device or 'driver'}': {', '.join(streams)}" @mcp.tool() -async def get_updates(max_updates: int = 50) -> str: +async def get_updates(workflow: str, max_updates: int = 50) -> str: """Drain and return buffered metric update frames received since the last call. Each frame contains the latest values of all subscribed metrics that @@ -287,11 +408,13 @@ async def get_updates(max_updates: int = 50) -> str: time-ordered view of metric evolution. Args: + workflow: Workflow name as returned by connect. max_updates: Maximum number of update frames to return (default 50). """ - await _ensure_connected() - batch = _updates[:max_updates] - del _updates[:max_updates] + conn = _get(workflow) + await conn.ensure_connected() + batch = conn.updates[:max_updates] + del conn.updates[:max_updates] if not batch: return "No buffered updates." lines = [] @@ -310,34 +433,6 @@ async def get_updates(max_updates: int = 50) -> str: # Entry point # --------------------------------------------------------------------------- def main() -> None: - global _port - - parser = argparse.ArgumentParser( - description="DPL status MCP server — expose DPL driver metrics via MCP tools" - ) - group = parser.add_mutually_exclusive_group() - group.add_argument( - "--port", - type=int, - default=None, - help="TCP port of the DPL driver status WebSocket (default: 8080 or DPL_STATUS_PORT env var)", - ) - group.add_argument( - "--pid", - type=int, - default=None, - help="PID of the DPL driver process; port is derived as 8080 + pid %% 30000", - ) - args = parser.parse_args() - - if args.pid is not None: - _port = 8080 + args.pid % 30000 - elif args.port is not None: - _port = args.port - elif "DPL_STATUS_PORT" in os.environ: - _port = int(os.environ["DPL_STATUS_PORT"]) - # else leave _port at the default 8080 - mcp.run()