diff --git a/.gitignore b/.gitignore index 74906da..f91d6c8 100644 --- a/.gitignore +++ b/.gitignore @@ -165,4 +165,7 @@ cython_debug/ ## custom commands.md compose.yml -config.json \ No newline at end of file +config.json +.qwen/ +tools.json +response.json diff --git a/compose.yml b/compose.yml index b48a1d4..df4beba 100644 --- a/compose.yml +++ b/compose.yml @@ -7,21 +7,20 @@ services: - path: mcp_bridge action: rebuild container_name: mcp-bridge - ports: - - "8000:8000" + network_mode: host environment: - MCP_BRIDGE__CONFIG__FILE=config.json # mount the config file for this to work # - MCP_BRIDGE__CONFIG__HTTP_URL=http://10.88.100.170:8888/config.json # - MCP_BRIDGE__CONFIG__JSON= - # volumes: - # - ./config.json:/mcp_bridge/config.json + volumes: + - ./config.json:/mcp_bridge/config.json restart: unless-stopped jaeger: image: jaegertracing/jaeger:latest ports: - "16686:16686" # Web UI - # - "4317:4317" # OTLP gRPC + - "4317:4317" # OTLP gRPC - "4318:4318" # OTLP HTTP # - "5778:5778" # Config server # - "9411:9411" # Zipkin compatible diff --git a/mcp_bridge/mcp_clients/session.py b/mcp_bridge/mcp_clients/session.py index 56d1f94..061f038 100644 --- a/mcp_bridge/mcp_clients/session.py +++ b/mcp_bridge/mcp_clients/session.py @@ -1,281 +1,256 @@ -from datetime import timedelta -from typing import Awaitable, Callable - -from loguru import logger -import mcp.types as types -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream -from mcp.shared.session import BaseSession, RequestResponder -from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS -from pydantic import AnyUrl - -from mcp_bridge import __version__ as version -from mcp_bridge.sampling.sampler import handle_sampling_message - -sampling_function_signature = Callable[ - [types.CreateMessageRequestParams], Awaitable[types.CreateMessageResult] -] - - -class McpClientSession( - BaseSession[ - types.ClientRequest, - types.ClientNotification, - types.ClientResult, - types.ServerRequest, - types.ServerNotification, - ] -): - - def __init__( - self, - read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception], - write_stream: MemoryObjectSendStream[types.JSONRPCMessage], - read_timeout_seconds: timedelta | None = None, - ) -> None: - super().__init__( - read_stream, - write_stream, - types.ServerRequest, - types.ServerNotification, - read_timeout_seconds=read_timeout_seconds, - ) - - async def __aenter__(self): - session = await super().__aenter__() - self._task_group.start_soon(self._consume_messages) - return session - - async def _consume_messages(self): - try: - async for message in self.incoming_messages: - try: - if isinstance(message, Exception): - logger.error(f"Received exception in message stream: {message}") - elif isinstance(message, RequestResponder): - logger.debug(f"Received request: {message.request}") - elif isinstance(message, types.ServerNotification): - if isinstance(message.root, types.LoggingMessageNotification): - logger.debug(f"Received notification from server: {message.root.params}") - else: - logger.debug(f"Received notification from server: {message}") - else: - logger.debug(f"Received notification: {message}") - except Exception as e: - logger.exception(f"Error processing message: {e}") - except Exception as e: - logger.exception(f"Message consumer task failed: {e}") - - async def initialize(self) -> types.InitializeResult: - result = await self.send_request( - types.ClientRequest( - types.InitializeRequest( - method="initialize", - params=types.InitializeRequestParams( - protocolVersion=types.LATEST_PROTOCOL_VERSION, - capabilities=types.ClientCapabilities( - sampling=types.SamplingCapability(), - experimental=None, - roots=types.RootsCapability( - listChanged=True - ), - ), - clientInfo=types.Implementation(name="MCP-Bridge", version=version), - ), - ) - ), - types.InitializeResult, - ) - - if result.protocolVersion not in SUPPORTED_PROTOCOL_VERSIONS: - raise RuntimeError( - "Unsupported protocol version from the server: " - f"{result.protocolVersion}" - ) - - await self.send_notification( - types.ClientNotification( - types.InitializedNotification(method="notifications/initialized") - ) - ) - - return result - - async def send_ping(self) -> types.EmptyResult: - """Send a ping request.""" - return await self.send_request( - types.ClientRequest( - types.PingRequest( - method="ping", - ) - ), - types.EmptyResult, - ) - - async def send_progress_notification( - self, progress_token: str | int, progress: float, total: float | None = None - ) -> None: - """Send a progress notification.""" - await self.send_notification( - types.ClientNotification( - types.ProgressNotification( - method="notifications/progress", - params=types.ProgressNotificationParams( - progressToken=progress_token, - progress=progress, - total=total, - ), - ), - ) - ) - - async def set_logging_level(self, level: types.LoggingLevel) -> types.EmptyResult: - """Send a logging/setLevel request.""" - return await self.send_request( - types.ClientRequest( - types.SetLevelRequest( - method="logging/setLevel", - params=types.SetLevelRequestParams(level=level), - ) - ), - types.EmptyResult, - ) - - async def list_resources(self) -> types.ListResourcesResult: - """Send a resources/list request.""" - return await self.send_request( - types.ClientRequest( - types.ListResourcesRequest( - method="resources/list", - ) - ), - types.ListResourcesResult, - ) - - async def read_resource(self, uri: AnyUrl) -> types.ReadResourceResult: - """Send a resources/read request.""" - return await self.send_request( - types.ClientRequest( - types.ReadResourceRequest( - method="resources/read", - params=types.ReadResourceRequestParams(uri=uri), - ) - ), - types.ReadResourceResult, - ) - - async def subscribe_resource(self, uri: AnyUrl) -> types.EmptyResult: - """Send a resources/subscribe request.""" - return await self.send_request( - types.ClientRequest( - types.SubscribeRequest( - method="resources/subscribe", - params=types.SubscribeRequestParams(uri=uri), - ) - ), - types.EmptyResult, - ) - - async def unsubscribe_resource(self, uri: AnyUrl) -> types.EmptyResult: - """Send a resources/unsubscribe request.""" - return await self.send_request( - types.ClientRequest( - types.UnsubscribeRequest( - method="resources/unsubscribe", - params=types.UnsubscribeRequestParams(uri=uri), - ) - ), - types.EmptyResult, - ) - - async def call_tool( - self, name: str, arguments: dict | None = None - ) -> types.CallToolResult: - """Send a tools/call request.""" - return await self.send_request( - types.ClientRequest( - types.CallToolRequest( - method="tools/call", - params=types.CallToolRequestParams(name=name, arguments=arguments), - ) - ), - types.CallToolResult, - ) - - async def list_prompts(self) -> types.ListPromptsResult: - """Send a prompts/list request.""" - return await self.send_request( - types.ClientRequest( - types.ListPromptsRequest( - method="prompts/list", - ) - ), - types.ListPromptsResult, - ) - - async def get_prompt( - self, name: str, arguments: dict[str, str] | None = None - ) -> types.GetPromptResult: - """Send a prompts/get request.""" - return await self.send_request( - types.ClientRequest( - types.GetPromptRequest( - method="prompts/get", - params=types.GetPromptRequestParams(name=name, arguments=arguments), - ) - ), - types.GetPromptResult, - ) - - async def complete( - self, ref: types.ResourceReference | types.PromptReference, argument: dict - ) -> types.CompleteResult: - """Send a completion/complete request.""" - return await self.send_request( - types.ClientRequest( - types.CompleteRequest( - method="completion/complete", - params=types.CompleteRequestParams( - ref=ref, - argument=types.CompletionArgument(**argument), - ), - ) - ), - types.CompleteResult, - ) - - async def list_tools(self) -> types.ListToolsResult: - """Send a tools/list request.""" - return await self.send_request( - types.ClientRequest( - types.ListToolsRequest( - method="tools/list", - ) - ), - types.ListToolsResult, - ) - - async def send_roots_list_changed(self) -> None: - """Send a roots/list_changed notification.""" - await self.send_notification( - types.ClientNotification( - types.RootsListChangedNotification( - method="notifications/roots/list_changed", - ) - ) - ) - - async def _received_request( - self, responder: RequestResponder["types.ServerRequest", "types.ClientResult"] - ) -> None: - if isinstance(responder.request.root, types.CreateMessageRequest): - # handle create message request (sampling) - response = await self.sample(responder.request.root.params) - client_response = types.ClientResult(**response.model_dump()) - await responder.respond(client_response) - - async def sample(self, params: types.CreateMessageRequestParams) -> types.CreateMessageResult: - logger.info("got sampling request from mcp server") - resp = await handle_sampling_message(params) - logger.info("finished sampling request from mcp server") - return resp +from datetime import timedelta +from typing import Awaitable, Callable + +from loguru import logger +import mcp.types as types +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from mcp.shared.session import BaseSession, RequestResponder +from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS +from pydantic import AnyUrl + +from mcp_bridge import __version__ as version +from mcp_bridge.sampling.sampler import handle_sampling_message + +sampling_function_signature = Callable[ + [types.CreateMessageRequestParams], Awaitable[types.CreateMessageResult] +] + + +class McpClientSession( + BaseSession[ + types.ClientRequest, + types.ClientNotification, + types.ClientResult, + types.ServerRequest, + types.ServerNotification, + ] +): + + def __init__( + self, + read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception], + write_stream: MemoryObjectSendStream[types.JSONRPCMessage], + read_timeout_seconds: timedelta | None = None, + ) -> None: + super().__init__( + read_stream, + write_stream, + types.ServerRequest, + types.ServerNotification, + read_timeout_seconds=read_timeout_seconds, + ) + + async def initialize(self) -> types.InitializeResult: + result = await self.send_request( + types.ClientRequest( + types.InitializeRequest( + method="initialize", + params=types.InitializeRequestParams( + protocolVersion=types.LATEST_PROTOCOL_VERSION, + capabilities=types.ClientCapabilities( + sampling=types.SamplingCapability(), + experimental=None, + roots=types.RootsCapability( + listChanged=True + ), + ), + clientInfo=types.Implementation(name="MCP-Bridge", version=version), + ), + ) + ), + types.InitializeResult, + ) + + if result.protocolVersion not in SUPPORTED_PROTOCOL_VERSIONS: + raise RuntimeError( + "Unsupported protocol version from the server: " + f"{result.protocolVersion}" + ) + + await self.send_notification( + types.ClientNotification( + types.InitializedNotification(method="notifications/initialized") + ) + ) + + return result + + async def send_ping(self) -> types.EmptyResult: + """Send a ping request.""" + return await self.send_request( + types.ClientRequest( + types.PingRequest( + method="ping", + ) + ), + types.EmptyResult, + ) + + async def send_progress_notification( + self, progress_token: str | int, progress: float, total: float | None = None + ) -> None: + """Send a progress notification.""" + await self.send_notification( + types.ClientNotification( + types.ProgressNotification( + method="notifications/progress", + params=types.ProgressNotificationParams( + progressToken=progress_token, + progress=progress, + total=total, + ), + ), + ) + ) + + async def set_logging_level(self, level: types.LoggingLevel) -> types.EmptyResult: + """Send a logging/setLevel request.""" + return await self.send_request( + types.ClientRequest( + types.SetLevelRequest( + method="logging/setLevel", + params=types.SetLevelRequestParams(level=level), + ) + ), + types.EmptyResult, + ) + + async def list_resources(self) -> types.ListResourcesResult: + """Send a resources/list request.""" + return await self.send_request( + types.ClientRequest( + types.ListResourcesRequest( + method="resources/list", + ) + ), + types.ListResourcesResult, + ) + + async def read_resource(self, uri: AnyUrl) -> types.ReadResourceResult: + """Send a resources/read request.""" + return await self.send_request( + types.ClientRequest( + types.ReadResourceRequest( + method="resources/read", + params=types.ReadResourceRequestParams(uri=uri), + ) + ), + types.ReadResourceResult, + ) + + async def subscribe_resource(self, uri: AnyUrl) -> types.EmptyResult: + """Send a resources/subscribe request.""" + return await self.send_request( + types.ClientRequest( + types.SubscribeRequest( + method="resources/subscribe", + params=types.SubscribeRequestParams(uri=uri), + ) + ), + types.EmptyResult, + ) + + async def unsubscribe_resource(self, uri: AnyUrl) -> types.EmptyResult: + """Send a resources/unsubscribe request.""" + return await self.send_request( + types.ClientRequest( + types.UnsubscribeRequest( + method="resources/unsubscribe", + params=types.UnsubscribeRequestParams(uri=uri), + ) + ), + types.EmptyResult, + ) + + async def call_tool( + self, name: str, arguments: dict | None = None + ) -> types.CallToolResult: + """Send a tools/call request.""" + return await self.send_request( + types.ClientRequest( + types.CallToolRequest( + method="tools/call", + params=types.CallToolRequestParams(name=name, arguments=arguments), + ) + ), + types.CallToolResult, + ) + + async def list_prompts(self) -> types.ListPromptsResult: + """Send a prompts/list request.""" + return await self.send_request( + types.ClientRequest( + types.ListPromptsRequest( + method="prompts/list", + ) + ), + types.ListPromptsResult, + ) + + async def get_prompt( + self, name: str, arguments: dict[str, str] | None = None + ) -> types.GetPromptResult: + """Send a prompts/get request.""" + return await self.send_request( + types.ClientRequest( + types.GetPromptRequest( + method="prompts/get", + params=types.GetPromptRequestParams(name=name, arguments=arguments), + ) + ), + types.GetPromptResult, + ) + + async def complete( + self, ref: types.ResourceReference | types.PromptReference, argument: dict + ) -> types.CompleteResult: + """Send a completion/complete request.""" + return await self.send_request( + types.ClientRequest( + types.CompleteRequest( + method="completion/complete", + params=types.CompleteRequestParams( + ref=ref, + argument=types.CompletionArgument(**argument), + ), + ) + ), + types.CompleteResult, + ) + + async def list_tools(self) -> types.ListToolsResult: + """Send a tools/list request.""" + return await self.send_request( + types.ClientRequest( + types.ListToolsRequest( + method="tools/list", + ) + ), + types.ListToolsResult, + ) + + async def send_roots_list_changed(self) -> None: + """Send a roots/list_changed notification.""" + await self.send_notification( + types.ClientNotification( + types.RootsListChangedNotification( + method="notifications/roots/list_changed", + ) + ) + ) + + async def _received_request( + self, responder: RequestResponder["types.ServerRequest", "types.ClientResult"] + ) -> None: + if isinstance(responder.request.root, types.CreateMessageRequest): + # handle create message request (sampling) + response = await self.sample(responder.request.root.params) + client_response = types.ClientResult(**response.model_dump()) + await responder.respond(client_response) + + async def sample(self, params: types.CreateMessageRequestParams) -> types.CreateMessageResult: + logger.info("got sampling request from mcp server") + resp = await handle_sampling_message(params) + logger.info("finished sampling request from mcp server") + return resp \ No newline at end of file diff --git a/test.sh b/test.sh new file mode 100644 index 0000000..c91a9ce --- /dev/null +++ b/test.sh @@ -0,0 +1,61 @@ +# Test script + +# curl http://localhost:11410/v1/models | jq '.data[]' +# exit 0 + +# qwen3.6:27b-q8_0 minimax-m3:cloud minimax-m3:cloud llama3.2:latest llama3.2:1b +# Use MCP Google Search Weather to answer: What is the weather forecast for this week? Please search for current conditions. + +curl -X POST http://localhost:11410/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "llama3.2:1b", + "messages": [ + { + "role": "user", + "content": "More MCPs Added, refresh MCP Tools. List all the MCP Tools in details. What are the commands you have access to in a list." + } + ] + }' > response.json + +cat response.json | jq --raw-output '.choices[0].message.content' + + + + +# docker compose build --no-cache +# docker compose up -d + +# docker compose logs mcp-bridge + + + + + # "filesystem": { + # "command": "npx", + # "args": [ + # "-y", + # "@modelcontextprotocol/server-filesystem", + # "/home/sameer/projects" + # ] + # }, + # "google-search": { + # "type": "http", + # "url": "http://latitude:11403/mcp", + # "auth": { + # "type": "none" + # }, + # "requestTimeout": 10000 + # }, + # "google-stitch-proxy": { + # "command": "npx", + # "args": [ + # "-y", + # "@_davideast/stitch-mcp", + # "serve", + # "--port", "11401" + # ], + # "env": { + # "STITCH_API_KEY": "API_KEY" + # } + # } \ No newline at end of file