From 3b0fe3dcb0ddc52335cf398aa56e4f6a6c999b3e Mon Sep 17 00:00:00 2001 From: Andrej Simurka Date: Fri, 5 Jun 2026 11:02:20 +0200 Subject: [PATCH 1/2] Added tool processing module for pydantic-ai --- src/utils/agents/tool_processor.py | 528 ++++++++++++++++ .../unit/utils/agents/test_tool_processor.py | 597 ++++++++++++++++++ 2 files changed, 1125 insertions(+) create mode 100644 src/utils/agents/tool_processor.py create mode 100644 tests/unit/utils/agents/test_tool_processor.py diff --git a/src/utils/agents/tool_processor.py b/src/utils/agents/tool_processor.py new file mode 100644 index 000000000..59377b763 --- /dev/null +++ b/src/utils/agents/tool_processor.py @@ -0,0 +1,528 @@ +"""Process and record pydantic-ai tool parts during agent stream dispatch.""" + +from __future__ import annotations + +import json +from typing import Any, Optional, cast + +from openai.types.responses.response_file_search_tool_call import ( + Result as OpenAIFileSearchResult, +) +from pydantic import AnyUrl +from pydantic_ai.messages import ( + NativeToolCallPart, + NativeToolReturnPart, + ToolCallPart, + ToolReturnPart, +) +from pydantic_ai.native_tools import FileSearchTool, MCPServerTool, WebSearchTool + +from constants import DEFAULT_RAG_TOOL +from log import get_logger +from models.common.agents import AgentTurnAccumulator +from models.common.turn_summary import ( + MCPListToolsSummary, + RAGChunk, + ReferencedDocument, + ToolCallSummary, + ToolInfoSummary, + ToolResultSummary, +) +from utils.responses import resolve_source_for_result + +logger = get_logger(__name__) + +_FILE_SEARCH_URL_KEYS = ("doc_url", "docs_url", "url", "link", "reference_url") + + +def summarize_function_tool_call(part: ToolCallPart) -> ToolCallSummary: + """Build a tool-call summary for a client function tool call. + + Args: + part: Function tool call part emitted by the agent. + + Returns: + Tool call summary in LCS turn-summary format. + """ + return ToolCallSummary( + id=part.tool_call_id, + name=part.tool_name, + args=part.args_as_dict(), + type="function_call", + ) + + +def summarize_native_tool_call( + part: NativeToolCallPart, +) -> Optional[ToolCallSummary]: + """Build a tool-call summary for a native agent tool call. + + Args: + part: Native tool call part emitted by the model. + + Returns: + Tool call summary in LCS turn-summary format. + """ + call_id = part.tool_call_id + args = part.args_as_dict() + match part.tool_name: + case WebSearchTool.kind: + return ToolCallSummary( + id=call_id, + name=part.tool_name, + args=args, + type="web_search_call", + ) + case FileSearchTool.kind: + return ToolCallSummary( + id=call_id, + name=DEFAULT_RAG_TOOL, + args=args, + type="file_search_call", + ) + case MCPServerTool.kind: + label = part.tool_name.removeprefix(f"{MCPServerTool.kind}:") + action = args.get("action") + # MCP list tools + if action == "list_tools": + return ToolCallSummary( + id=call_id, + name="mcp_list_tools", + args={"server_label": label}, + type="mcp_list_tools", + ) + + # MCP call + return ToolCallSummary( + id=call_id, + name=args.get("tool_name") or "", + args=args.get("tool_args", {}), + type="mcp_call", + ) + case _: + logger.warning(f"Unknown tool name: {part.tool_name}") + return None + + +def process_function_tool_call( + state: AgentTurnAccumulator, + part: ToolCallPart, +) -> Optional[ToolCallSummary]: + """Record a client function tool call on dispatch state. + + Args: + state: Mutable dispatch reducer state. + part: Function tool call part from the agent. + + Returns: + Tool call summary when recorded, otherwise None if already emitted. + """ + if part.tool_call_id in state.emitted_tool_call_ids: + return None + summary = summarize_function_tool_call(part) + state.increment_round_if_pending() + state.emitted_tool_call_ids.add(summary.id) + state.turn_summary.tool_calls.append(summary) + return summary + + +def process_native_tool_call( + state: AgentTurnAccumulator, + part: NativeToolCallPart, +) -> Optional[ToolCallSummary]: + """Record a native tool call on dispatch state. + + Args: + state: Mutable dispatch reducer state. + part: Native tool call part from the model. + + Returns: + Tool call summary when recorded, otherwise None if already emitted. + """ + if part.tool_call_id in state.emitted_tool_call_ids: + return None + if summary := summarize_native_tool_call(part): + state.increment_round_if_pending() + state.emitted_tool_call_ids.add(summary.id) + state.turn_summary.tool_calls.append(summary) + return summary + return None + + +def process_native_tool_result( + state: AgentTurnAccumulator, + part: NativeToolReturnPart, +) -> Optional[ToolResultSummary]: + """Record a native tool return on dispatch state. + + Args: + state: Mutable dispatch reducer state. + part: Native tool return part from the model. + + Returns: + Tool result summary when recorded, otherwise None if already emitted. + """ + if part.tool_call_id in state.emitted_tool_result_ids: + return None + + match part.tool_name: + case FileSearchTool.kind: + tool_result, rag_chunks, referenced_documents = ( + summarize_file_search_result( + part, + state.tool_round, + state.seen_docs, + state.vector_store_ids, + state.rag_id_mapping, + ) + ) + state.turn_summary.rag_chunks.extend(rag_chunks) + state.turn_summary.referenced_documents.extend(referenced_documents) + case WebSearchTool.kind: + tool_result = summarize_web_search_result(part, state.tool_round) + case MCPServerTool.kind: + tool_result = summarize_mcp_tool_result(part, state.tool_round) + case _: + logger.warning(f"Unknown tool name: {part.tool_name}") + return None + + state.emitted_tool_result_ids.add(tool_result.id) + state.turn_summary.tool_results.append(tool_result) + state.round_increment_pending = True + return tool_result + + +def process_function_tool_result( + state: AgentTurnAccumulator, + part: ToolReturnPart, +) -> Optional[ToolResultSummary]: + """Record a client function tool return on dispatch state. + + Args: + state: Mutable dispatch reducer state. + part: Function tool return part from the agent. + + Returns: + Tool result summary when recorded, otherwise None if already emitted. + """ + if part.tool_call_id in state.emitted_tool_result_ids: + return None + tool_result = summarize_function_tool_result(part, state.tool_round) + state.emitted_tool_result_ids.add(tool_result.id) + state.turn_summary.tool_results.append(tool_result) + state.round_increment_pending = True + return tool_result + + +def summarize_function_tool_result( + part: ToolReturnPart, + tool_round: int, +) -> ToolResultSummary: + """Build a tool-result summary for a client function tool return. + + Args: + part: Function tool return part emitted by the agent. + tool_round: Tool execution round number for this result. + + Returns: + Tool result summary in LCS turn-summary format. + """ + return ToolResultSummary( + id=part.tool_call_id, + status="success", + content=part.model_response_str(), + type="function_call_output", + round=tool_round, + ) + + +def referenced_documents_from_file_search_results( + results: list[OpenAIFileSearchResult], + seen_docs: set[tuple[str, str]], + vector_store_ids: list[str], + rag_id_mapping: dict[str, str], +) -> list[ReferencedDocument]: + """Parse referenced documents from OpenAI file-search result rows. + + Args: + results: Validated file-search result rows. + seen_docs: Dedupe keys already emitted; updated in place. + vector_store_ids: Vector store IDs used for source mapping. + rag_id_mapping: Mapping from vector store IDs to user-facing source labels. + + Returns: + Newly discovered referenced documents from these result rows. + """ + documents: list[ReferencedDocument] = [] + for result in results: + doc = build_referenced_document(result, vector_store_ids, rag_id_mapping) + if doc is None: + continue + + dedup_key = (str(doc.doc_url or ""), doc.doc_title or "") + if dedup_key in seen_docs: + continue + + seen_docs.add(dedup_key) + documents.append(doc) + + return documents + + +def build_referenced_document( + result: OpenAIFileSearchResult, + vector_store_ids: list[str], + rag_id_mapping: dict[str, str], +) -> Optional[ReferencedDocument]: + """Build one referenced document from a single file-search result row. + + Args: + result: OpenAI file-search result row. + vector_store_ids: Vector store IDs used for source mapping. + rag_id_mapping: Mapping from vector store IDs to user-facing source labels. + + Returns: + Referenced document when metadata is present, otherwise None. + """ + attributes = result.attributes or {} + + doc_url = _file_search_attribute_url(attributes) + doc_title = _file_search_attribute_str(attributes, "title") + if not (doc_title or doc_url): + return None + + doc_id = _file_search_attribute_str( + attributes, "document_id" + ) or _file_search_attribute_str(attributes, "doc_id") + return ReferencedDocument( + doc_url=AnyUrl(doc_url) if doc_url else None, + doc_title=doc_title, + source=resolve_source_for_result(attributes, vector_store_ids, rag_id_mapping), + document_id=doc_id, + ) + + +def _file_search_attribute_str( + attributes: dict[str, str | float | bool], + key: str, +) -> Optional[str]: + """Read a non-empty string metadata field from file-search attributes. + + Args: + attributes: File-search result metadata attributes. + key: Metadata key to read. + + Returns: + Non-empty string value for the key, or None. + """ + return str(value) if (value := attributes.get(key)) else None + + +def _file_search_attribute_url( + attributes: dict[str, str | float | bool], +) -> Optional[str]: + """Extract the first available document URL from file-search attributes. + + Args: + attributes: File-search result metadata attributes. + + Returns: + First matching URL value as a string, or None. + """ + for key in _FILE_SEARCH_URL_KEYS: + if url := _file_search_attribute_str(attributes, key): + return url + return None + + +def rag_chunks_from_file_search_results( + results: list[OpenAIFileSearchResult], + vector_store_ids: list[str], + rag_id_mapping: dict[str, str], +) -> list[RAGChunk]: + """Extract RAG chunks from OpenAI file-search result rows. + + Args: + results: Validated file-search result rows. + vector_store_ids: Vector store IDs used for source mapping. + rag_id_mapping: Mapping from vector store IDs to user-facing source labels. + + Returns: + RAG chunks extracted from these result rows. + """ + return [ + RAGChunk( + content=result.text, + source=resolve_source_for_result( + result.attributes or {}, vector_store_ids, rag_id_mapping + ), + score=result.score, + attributes=result.attributes or None, + ) + for result in results + if result.text + ] + + +def summarize_web_search_result( + part: NativeToolReturnPart, + tool_round: int, +) -> ToolResultSummary: + """Build a tool-result summary from a native web-search return. + + Args: + part: Native web-search tool return part from the model stream. + tool_round: Tool execution round number for this result. + + Returns: + Tool result summary in LCS turn-summary format. + """ + content = cast(dict[str, Any], part.content) + status = str(content.pop("status")) + return ToolResultSummary( + id=part.tool_call_id, + status=status, + content=json.dumps(content) if content else "", + type="web_search_call", + round=tool_round, + ) + + +def summarize_mcp_list_tools_result( + part: NativeToolReturnPart, + tool_round: int, +) -> ToolResultSummary: + """Build a tool-result summary from a native MCP list-tools return. + + Args: + part: Native MCP list-tools return part from the model stream. + tool_round: Tool execution round number for this result. + + Returns: + Tool result summary in LCS turn-summary format. + """ + content = cast(dict[str, Any], part.content) + call_id = part.tool_call_id + label = part.tool_name.removeprefix(f"{MCPServerTool.kind}:") + + if error := content.get("error"): + return ToolResultSummary( + id=call_id, + status="failure", + content=str(error), + type="mcp_list_tools", + round=tool_round, + ) + + list_summary = MCPListToolsSummary( + server_label=label, + tools=[ToolInfoSummary.model_validate(tool) for tool in content["tools"]], + ) + return ToolResultSummary( + id=call_id, + status="success", + content=json.dumps(list_summary.model_dump()), + type="mcp_list_tools", + round=tool_round, + ) + + +def summarize_mcp_call_result( + part: NativeToolReturnPart, + tool_round: int, +) -> ToolResultSummary: + """Build a tool-result summary from a native MCP tool call return. + + Args: + part: Native MCP call return part from the model stream. + tool_round: Tool execution round number for this result. + + Returns: + Tool result summary in LCS turn-summary format. + """ + content = cast(dict[str, Any], part.content) + call_id = part.tool_call_id + + if error := content.get("error"): + return ToolResultSummary( + id=call_id, + status="failure", + content=str(error), + type="mcp_call", + round=tool_round, + ) + + output = content.get("output", "") + return ToolResultSummary( + id=call_id, + status="success", + content=str(output), + type="mcp_call", + round=tool_round, + ) + + +def summarize_mcp_tool_result( + part: NativeToolReturnPart, + tool_round: int, +) -> ToolResultSummary: + """Build a tool-result summary from a native MCP server tool return. + + Dispatches to list-tools or call processors based on return shape. + + Args: + part: Native MCP tool return part from the model stream. + tool_round: Tool execution round number for this result. + + Returns: + Tool result summary in LCS turn-summary format. + """ + content = cast(dict[str, Any], part.content) + if "tools" in content or "error" in content: + return summarize_mcp_list_tools_result(part, tool_round) + return summarize_mcp_call_result(part, tool_round) + + +def summarize_file_search_result( + part: NativeToolReturnPart, + tool_round: int, + seen_docs: set[tuple[str, str]], + vector_store_ids: list[str], + rag_id_mapping: dict[str, str], +) -> tuple[ToolResultSummary, list[RAGChunk], list[ReferencedDocument]]: + """Build tool result, RAG chunks, and referenced docs from a file-search return. + + Args: + part: Native file-search tool return part from the model stream. + tool_round: Tool execution round number for this result. + seen_docs: Dedupe keys for referenced documents; updated in place. + vector_store_ids: Vector store IDs used for source mapping. + rag_id_mapping: Mapping from vector store IDs to user-facing source labels. + + Returns: + Tool result summary, RAG chunks, and referenced documents for this return. + """ + content = cast(dict[str, Any], part.content) + tool_result = ToolResultSummary( + id=part.tool_call_id, + status=str(content.pop("status")), + content=json.dumps(content), + type="file_search_call", + round=tool_round, + ) + results = [ + OpenAIFileSearchResult.model_validate(result) + for result in content.get("results", []) + ] + rag_chunks = rag_chunks_from_file_search_results( + results, + vector_store_ids=vector_store_ids, + rag_id_mapping=rag_id_mapping, + ) + referenced_documents = referenced_documents_from_file_search_results( + results, + seen_docs, + vector_store_ids=vector_store_ids, + rag_id_mapping=rag_id_mapping, + ) + return tool_result, rag_chunks, referenced_documents diff --git a/tests/unit/utils/agents/test_tool_processor.py b/tests/unit/utils/agents/test_tool_processor.py new file mode 100644 index 000000000..58980e2b1 --- /dev/null +++ b/tests/unit/utils/agents/test_tool_processor.py @@ -0,0 +1,597 @@ +"""Unit tests for utils.agents.tool_processor module.""" + +import json + +import pytest +from openai.types.responses.response_file_search_tool_call import ( + Result as OpenAIFileSearchResult, +) +from pydantic import AnyUrl +from pydantic_ai.messages import ( + NativeToolCallPart, + NativeToolReturnPart, + ToolCallPart, + ToolReturnPart, +) +from pydantic_ai.native_tools import FileSearchTool, MCPServerTool, WebSearchTool +from pytest_mock import MockerFixture + +from constants import DEFAULT_RAG_TOOL +from models.common.agents import AgentTurnAccumulator +from models.common.turn_summary import TurnSummary +from utils.agents.tool_processor import ( + build_referenced_document, + process_function_tool_call, + process_function_tool_result, + process_native_tool_call, + process_native_tool_result, + rag_chunks_from_file_search_results, + referenced_documents_from_file_search_results, + summarize_file_search_result, + summarize_function_tool_call, + summarize_function_tool_result, + summarize_mcp_call_result, + summarize_mcp_list_tools_result, + summarize_mcp_tool_result, + summarize_native_tool_call, + summarize_web_search_result, +) + + +@pytest.fixture(name="turn_state") +def turn_state_fixture() -> AgentTurnAccumulator: + """Create a fresh agent turn accumulator for dispatch tests.""" + return AgentTurnAccumulator( + vector_store_ids=["vs-001"], + rag_id_mapping={"vs-001": "ocp-docs"}, + turn_summary=TurnSummary(), + ) + + +def _file_search_result(**kwargs: object) -> OpenAIFileSearchResult: + """Build a validated OpenAI file-search result row.""" + return OpenAIFileSearchResult.model_validate(kwargs) + + +class TestSummarizeFunctionToolCall: + """Tests for summarize_function_tool_call.""" + + def test_builds_function_call_summary(self) -> None: + """Test function tool call is mapped to ToolCallSummary.""" + part = ToolCallPart( + tool_name="my_fn", + args={"key": "value"}, + tool_call_id="call-fn-1", + ) + + summary = summarize_function_tool_call(part) + + assert summary.id == "call-fn-1" + assert summary.name == "my_fn" + assert summary.args == {"key": "value"} + assert summary.type == "function_call" + + +class TestSummarizeNativeToolCall: + """Tests for summarize_native_tool_call.""" + + def test_web_search_call(self) -> None: + """Test web search native tool call summary.""" + part = NativeToolCallPart( + tool_name=WebSearchTool.kind, + args={"query": "OpenShift"}, + tool_call_id="ws-1", + ) + + summary = summarize_native_tool_call(part) + + assert summary is not None + assert summary.type == "web_search_call" + assert summary.name == WebSearchTool.kind + + def test_file_search_call(self) -> None: + """Test file search native tool call uses DEFAULT_RAG_TOOL name.""" + part = NativeToolCallPart( + tool_name=FileSearchTool.kind, + args={"queries": ["docs"]}, + tool_call_id="fs-1", + ) + + summary = summarize_native_tool_call(part) + + assert summary is not None + assert summary.name == DEFAULT_RAG_TOOL + assert summary.type == "file_search_call" + + def test_mcp_list_tools_call(self) -> None: + """Test MCP list-tools action summary.""" + part = NativeToolCallPart( + tool_name=MCPServerTool.kind, + args={"action": "list_tools"}, + tool_call_id="mcp-list-1", + ) + + summary = summarize_native_tool_call(part) + + assert summary is not None + assert summary.name == "mcp_list_tools" + assert summary.args == {"server_label": MCPServerTool.kind} + assert summary.type == "mcp_list_tools" + + def test_mcp_call(self) -> None: + """Test MCP tool call summary.""" + part = NativeToolCallPart( + tool_name=MCPServerTool.kind, + args={ + "action": "call", + "tool_name": "remote_tool", + "tool_args": {"arg": 1}, + }, + tool_call_id="mcp-call-1", + ) + + summary = summarize_native_tool_call(part) + + assert summary is not None + assert summary.name == "remote_tool" + assert summary.args == {"arg": 1} + assert summary.type == "mcp_call" + + def test_unknown_tool_returns_none(self, mocker: MockerFixture) -> None: + """Test unknown native tool logs warning and returns None.""" + mock_warning = mocker.patch("utils.agents.tool_processor.logger.warning") + part = NativeToolCallPart( + tool_name="unknown_tool", + args={}, + tool_call_id="unk-1", + ) + + assert summarize_native_tool_call(part) is None + mock_warning.assert_called_once() + + +class TestProcessFunctionToolCall: + """Tests for process_function_tool_call.""" + + def test_records_tool_call_on_state(self, turn_state: AgentTurnAccumulator) -> None: + """Test first function tool call is recorded on turn state.""" + part = ToolCallPart( + tool_name="fn", + args={"x": 1}, + tool_call_id="call-1", + ) + + summary = process_function_tool_call(turn_state, part) + + assert summary is not None + assert turn_state.turn_summary.tool_calls == [summary] + assert "call-1" in turn_state.emitted_tool_call_ids + + def test_skips_duplicate_tool_call(self, turn_state: AgentTurnAccumulator) -> None: + """Test duplicate function tool call id is not recorded twice.""" + part = ToolCallPart(tool_name="fn", args={}, tool_call_id="call-dup") + process_function_tool_call(turn_state, part) + + assert process_function_tool_call(turn_state, part) is None + assert len(turn_state.turn_summary.tool_calls) == 1 + + def test_increments_round_when_pending( + self, turn_state: AgentTurnAccumulator + ) -> None: + """Test pending round increment runs before recording tool call.""" + turn_state.round_increment_pending = True + turn_state.tool_round = 2 + part = ToolCallPart(tool_name="fn", args={}, tool_call_id="call-round") + + process_function_tool_call(turn_state, part) + + assert turn_state.tool_round == 3 + assert not turn_state.round_increment_pending + + +class TestProcessNativeToolCall: + """Tests for process_native_tool_call.""" + + def test_records_native_tool_call(self, turn_state: AgentTurnAccumulator) -> None: + """Test native tool call is recorded on turn state.""" + part = NativeToolCallPart( + tool_name=WebSearchTool.kind, + args={"query": "q"}, + tool_call_id="ws-record", + ) + + summary = process_native_tool_call(turn_state, part) + + assert summary is not None + assert turn_state.turn_summary.tool_calls == [summary] + + def test_skips_duplicate_and_unknown( + self, turn_state: AgentTurnAccumulator, mocker: MockerFixture + ) -> None: + """Test duplicate ids and unknown tools are not recorded.""" + mocker.patch("utils.agents.tool_processor.logger.warning") + part = NativeToolCallPart( + tool_name="unknown", + args={}, + tool_call_id="unk-record", + ) + + assert process_native_tool_call(turn_state, part) is None + assert not turn_state.turn_summary.tool_calls + + known = NativeToolCallPart( + tool_name=WebSearchTool.kind, + args={}, + tool_call_id="ws-dup", + ) + process_native_tool_call(turn_state, known) + assert process_native_tool_call(turn_state, known) is None + + +class TestSummarizeFunctionToolResult: + """Tests for summarize_function_tool_result.""" + + def test_builds_function_tool_result(self) -> None: + """Test function tool return maps to ToolResultSummary.""" + part = ToolReturnPart( + tool_name="fn", + content={"answer": 42}, + tool_call_id="result-1", + ) + + result = summarize_function_tool_result(part, tool_round=3) + + assert result.id == "result-1" + assert result.status == "success" + assert result.type == "function_call_output" + assert result.round == 3 + assert json.loads(result.content) == {"answer": 42} + + +class TestProcessFunctionToolResult: + """Tests for process_function_tool_result.""" + + def test_records_function_tool_result( + self, turn_state: AgentTurnAccumulator + ) -> None: + """Test function tool result is recorded and marks round pending.""" + part = ToolReturnPart( + tool_name="fn", + content="ok", + tool_call_id="result-record", + ) + + result = process_function_tool_result(turn_state, part) + + assert result is not None + assert turn_state.turn_summary.tool_results == [result] + assert turn_state.round_increment_pending + assert "result-record" in turn_state.emitted_tool_result_ids + + def test_skips_duplicate_result(self, turn_state: AgentTurnAccumulator) -> None: + """Test duplicate function tool result id is ignored.""" + part = ToolReturnPart(tool_name="fn", content="ok", tool_call_id="result-dup") + process_function_tool_result(turn_state, part) + + assert process_function_tool_result(turn_state, part) is None + assert len(turn_state.turn_summary.tool_results) == 1 + + +class TestBuildReferencedDocument: + """Tests for build_referenced_document.""" + + def test_returns_none_without_title_or_url(self) -> None: + """Test result without title or URL metadata is skipped.""" + result = _file_search_result(attributes={"document_id": "only-id"}) + + assert build_referenced_document(result, ["vs-001"], {}) is None + + def test_builds_from_url_and_title_with_source_mapping(self) -> None: + """Test referenced document resolves source from vector store mapping.""" + result = _file_search_result( + attributes={ + "link": "https://example.com/doc", + "title": "Example Doc", + "document_id": "doc-1", + } + ) + + doc = build_referenced_document(result, ["vs-001"], {"vs-001": "mapped-source"}) + + assert doc is not None + assert doc.doc_url == AnyUrl("https://example.com/doc") + assert doc.doc_title == "Example Doc" + assert doc.document_id == "doc-1" + assert doc.source == "mapped-source" + + def test_supports_alternate_url_and_id_keys(self) -> None: + """Test doc_url and doc_id attribute key fallbacks.""" + result = _file_search_result( + attributes={ + "docs_url": "https://example.com/alt", + "title": "Alt Doc", + "doc_id": "alt-id", + } + ) + + doc = build_referenced_document(result, [], {}) + + assert doc is not None + assert doc.doc_url == AnyUrl("https://example.com/alt") + assert doc.document_id == "alt-id" + + def test_title_only_document(self) -> None: + """Test referenced document can be built with title only.""" + result = _file_search_result(attributes={"title": "Title Only"}) + + doc = build_referenced_document(result, [], {}) + + assert doc is not None + assert doc.doc_url is None + assert doc.doc_title == "Title Only" + + +class TestReferencedDocumentsFromFileSearchResults: + """Tests for referenced_documents_from_file_search_results.""" + + def test_deduplicates_documents(self) -> None: + """Test seen_docs prevents duplicate referenced documents.""" + results = [ + _file_search_result(attributes={"url": "https://dup.com", "title": "Same"}), + _file_search_result( + attributes={"link": "https://dup.com", "title": "Same"} + ), + _file_search_result( + attributes={"url": "https://other.com", "title": "Other"} + ), + _file_search_result(attributes={"document_id": "no-metadata"}), + ] + seen_docs: set[tuple[str, str]] = set() + + documents = referenced_documents_from_file_search_results( + results, seen_docs, ["vs-001"], {"vs-001": "source"} + ) + + assert len(documents) == 2 + assert len(seen_docs) == 2 + + +class TestRagChunksFromFileSearchResults: + """Tests for rag_chunks_from_file_search_results.""" + + def test_skips_empty_text_and_maps_source(self) -> None: + """Test chunks without text are skipped and source is resolved.""" + results = [ + _file_search_result(text="chunk one", score=0.8, attributes={}), + _file_search_result(text="", score=0.5, attributes={}), + ] + + chunks = rag_chunks_from_file_search_results( + results, ["vs-001"], {"vs-001": "mapped"} + ) + + assert len(chunks) == 1 + assert chunks[0].content == "chunk one" + assert chunks[0].source == "mapped" + assert chunks[0].score == 0.8 + + +class TestSummarizeWebSearchResult: + """Tests for summarize_web_search_result.""" + + def test_serializes_remaining_content(self) -> None: + """Test web search result keeps non-status fields as JSON content.""" + part = NativeToolReturnPart( + tool_name=WebSearchTool.kind, + tool_call_id="ws-result", + content={"status": "success", "results": [{"title": "hit"}]}, + ) + + result = summarize_web_search_result(part, tool_round=1) + + assert result.status == "success" + assert result.type == "web_search_call" + assert json.loads(result.content) == {"results": [{"title": "hit"}]} + + def test_empty_content_when_only_status(self) -> None: + """Test web search result content is empty when only status remains.""" + part = NativeToolReturnPart( + tool_name=WebSearchTool.kind, + tool_call_id="ws-empty", + content={"status": "success"}, + ) + + result = summarize_web_search_result(part, tool_round=2) + + assert not result.content + + +class TestSummarizeMcpResults: + """Tests for MCP tool result summarizers.""" + + def test_list_tools_success(self) -> None: + """Test MCP list-tools success payload is serialized.""" + part = NativeToolReturnPart( + tool_name=f"{MCPServerTool.kind}:srv", + tool_call_id="mcp-list", + content={ + "tools": [ + {"name": "tool_a", "description": "does things"}, + ] + }, + ) + + result = summarize_mcp_list_tools_result(part, tool_round=1) + + assert result.status == "success" + assert result.type == "mcp_list_tools" + payload = json.loads(result.content) + assert payload["server_label"] == "srv" + assert payload["tools"][0]["name"] == "tool_a" + + def test_list_tools_error(self) -> None: + """Test MCP list-tools error returns failure summary.""" + part = NativeToolReturnPart( + tool_name=f"{MCPServerTool.kind}:srv", + tool_call_id="mcp-list-err", + content={"error": "unavailable"}, + ) + + result = summarize_mcp_list_tools_result(part, tool_round=1) + + assert result.status == "failure" + assert result.content == "unavailable" + + def test_mcp_call_success_and_error(self) -> None: + """Test MCP call success and error summaries.""" + success_part = NativeToolReturnPart( + tool_name=f"{MCPServerTool.kind}:srv", + tool_call_id="mcp-call-ok", + content={"output": "done"}, + ) + error_part = NativeToolReturnPart( + tool_name=f"{MCPServerTool.kind}:srv", + tool_call_id="mcp-call-err", + content={"error": "failed"}, + ) + + success = summarize_mcp_call_result(success_part, tool_round=2) + error = summarize_mcp_call_result(error_part, tool_round=2) + + assert success.status == "success" + assert success.content == "done" + assert error.status == "failure" + assert error.content == "failed" + + def test_mcp_tool_result_dispatches_by_shape(self) -> None: + """Test summarize_mcp_tool_result routes list-tools vs call payloads.""" + list_part = NativeToolReturnPart( + tool_name=f"{MCPServerTool.kind}:srv", + tool_call_id="dispatch-list", + content={"tools": []}, + ) + call_part = NativeToolReturnPart( + tool_name=f"{MCPServerTool.kind}:srv", + tool_call_id="dispatch-call", + content={"output": "ok"}, + ) + + list_result = summarize_mcp_tool_result(list_part, tool_round=1) + call_result = summarize_mcp_tool_result(call_part, tool_round=1) + + assert list_result.type == "mcp_list_tools" + assert call_result.type == "mcp_call" + + +class TestSummarizeFileSearchResult: + """Tests for summarize_file_search_result.""" + + def test_builds_tool_result_rag_chunks_and_referenced_docs(self) -> None: + """Test file-search return produces result, chunks, and referenced docs.""" + part = NativeToolReturnPart( + tool_name=FileSearchTool.kind, + tool_call_id="fs-result", + content={ + "status": "success", + "results": [ + { + "text": "chunk text", + "score": 0.95, + "attributes": { + "title": "Doc", + "url": "https://example.com", + }, + }, + {"text": "", "attributes": {}}, + ], + }, + ) + seen_docs: set[tuple[str, str]] = set() + + tool_result, rag_chunks, referenced_docs = summarize_file_search_result( + part, + tool_round=4, + seen_docs=seen_docs, + vector_store_ids=["vs-001"], + rag_id_mapping={"vs-001": "mapped"}, + ) + + assert tool_result.status == "success" + assert tool_result.type == "file_search_call" + assert tool_result.round == 4 + assert len(rag_chunks) == 1 + assert rag_chunks[0].content == "chunk text" + assert len(referenced_docs) == 1 + assert referenced_docs[0].doc_title == "Doc" + assert len(seen_docs) == 1 + + +class TestProcessNativeToolResult: + """Tests for process_native_tool_result.""" + + def test_records_file_search_result(self, turn_state: AgentTurnAccumulator) -> None: + """Test file-search result updates tool results, RAG chunks, and docs.""" + part = NativeToolReturnPart( + tool_name=FileSearchTool.kind, + tool_call_id="fs-process", + content={ + "status": "success", + "results": [ + { + "text": "rag", + "attributes": {"title": "RAG Doc", "url": "https://rag"}, + } + ], + }, + ) + + result = process_native_tool_result(turn_state, part) + + assert result is not None + assert turn_state.turn_summary.tool_results == [result] + assert len(turn_state.turn_summary.rag_chunks) == 1 + assert len(turn_state.turn_summary.referenced_documents) == 1 + assert turn_state.round_increment_pending + + def test_records_web_search_and_mcp_results( + self, turn_state: AgentTurnAccumulator + ) -> None: + """Test web search and MCP results are recorded on turn state.""" + web_part = NativeToolReturnPart( + tool_name=WebSearchTool.kind, + tool_call_id="ws-process", + content={"status": "success"}, + ) + mcp_part = NativeToolReturnPart( + tool_name=MCPServerTool.kind, + tool_call_id="mcp-process", + content={"output": "mcp-output"}, + ) + + web_result = process_native_tool_result(turn_state, web_part) + mcp_result = process_native_tool_result(turn_state, mcp_part) + + assert web_result is not None + assert mcp_result is not None + assert len(turn_state.turn_summary.tool_results) == 2 + + def test_skips_duplicate_and_unknown( + self, turn_state: AgentTurnAccumulator, mocker: MockerFixture + ) -> None: + """Test duplicate ids and unknown tool returns are ignored.""" + mocker.patch("utils.agents.tool_processor.logger.warning") + part = NativeToolReturnPart( + tool_name="unknown", + tool_call_id="unk-result", + content={"status": "success"}, + ) + + assert process_native_tool_result(turn_state, part) is None + + known = NativeToolReturnPart( + tool_name=WebSearchTool.kind, + tool_call_id="ws-dup-result", + content={"status": "success"}, + ) + process_native_tool_result(turn_state, known) + assert process_native_tool_result(turn_state, known) is None From be576760cab3170bbb857d3db3f7e15c0552f5a8 Mon Sep 17 00:00:00 2001 From: Andrej Simurka Date: Fri, 5 Jun 2026 11:37:19 +0200 Subject: [PATCH 2/2] Fixed dynamic MCP naming by pydantic --- src/utils/agents/tool_processor.py | 7 ++-- .../unit/utils/agents/test_tool_processor.py | 34 ++++++++++++++++--- 2 files changed, 34 insertions(+), 7 deletions(-) diff --git a/src/utils/agents/tool_processor.py b/src/utils/agents/tool_processor.py index 59377b763..6b0b910e2 100644 --- a/src/utils/agents/tool_processor.py +++ b/src/utils/agents/tool_processor.py @@ -33,6 +33,7 @@ logger = get_logger(__name__) _FILE_SEARCH_URL_KEYS = ("doc_url", "docs_url", "url", "link", "reference_url") +_MCP_SERVER_TOOL_PREFIX = f"{MCPServerTool.kind}:" def summarize_function_tool_call(part: ToolCallPart) -> ToolCallSummary: @@ -80,8 +81,8 @@ def summarize_native_tool_call( args=args, type="file_search_call", ) - case MCPServerTool.kind: - label = part.tool_name.removeprefix(f"{MCPServerTool.kind}:") + case tool_name if tool_name.startswith(_MCP_SERVER_TOOL_PREFIX): + label = tool_name.removeprefix(_MCP_SERVER_TOOL_PREFIX) action = args.get("action") # MCP list tools if action == "list_tools": @@ -180,7 +181,7 @@ def process_native_tool_result( state.turn_summary.referenced_documents.extend(referenced_documents) case WebSearchTool.kind: tool_result = summarize_web_search_result(part, state.tool_round) - case MCPServerTool.kind: + case tool_name if tool_name.startswith(_MCP_SERVER_TOOL_PREFIX): tool_result = summarize_mcp_tool_result(part, state.tool_round) case _: logger.warning(f"Unknown tool name: {part.tool_name}") diff --git a/tests/unit/utils/agents/test_tool_processor.py b/tests/unit/utils/agents/test_tool_processor.py index 58980e2b1..4bc33484a 100644 --- a/tests/unit/utils/agents/test_tool_processor.py +++ b/tests/unit/utils/agents/test_tool_processor.py @@ -106,7 +106,7 @@ def test_file_search_call(self) -> None: def test_mcp_list_tools_call(self) -> None: """Test MCP list-tools action summary.""" part = NativeToolCallPart( - tool_name=MCPServerTool.kind, + tool_name=f"{MCPServerTool.kind}:srv", args={"action": "list_tools"}, tool_call_id="mcp-list-1", ) @@ -115,13 +115,26 @@ def test_mcp_list_tools_call(self) -> None: assert summary is not None assert summary.name == "mcp_list_tools" - assert summary.args == {"server_label": MCPServerTool.kind} + assert summary.args == {"server_label": "srv"} assert summary.type == "mcp_list_tools" + def test_mcp_list_tools_call_with_label(self) -> None: + """Test labeled MCP list-tools action uses the server label suffix.""" + part = NativeToolCallPart( + tool_name=f"{MCPServerTool.kind}:myserver", + args={"action": "list_tools"}, + tool_call_id="mcp-list-labeled", + ) + + summary = summarize_native_tool_call(part) + + assert summary is not None + assert summary.args == {"server_label": "myserver"} + def test_mcp_call(self) -> None: """Test MCP tool call summary.""" part = NativeToolCallPart( - tool_name=MCPServerTool.kind, + tool_name=f"{MCPServerTool.kind}:srv", args={ "action": "call", "tool_name": "remote_tool", @@ -553,6 +566,19 @@ def test_records_file_search_result(self, turn_state: AgentTurnAccumulator) -> N assert len(turn_state.turn_summary.referenced_documents) == 1 assert turn_state.round_increment_pending + def test_records_labeled_mcp_result(self, turn_state: AgentTurnAccumulator) -> None: + """Test labeled MCP tool return is processed like unlabeled MCP returns.""" + part = NativeToolReturnPart( + tool_name=f"{MCPServerTool.kind}:srv", + tool_call_id="mcp-labeled", + content={"output": "labeled-output"}, + ) + + result = process_native_tool_result(turn_state, part) + + assert result is not None + assert result.content == "labeled-output" + def test_records_web_search_and_mcp_results( self, turn_state: AgentTurnAccumulator ) -> None: @@ -563,7 +589,7 @@ def test_records_web_search_and_mcp_results( content={"status": "success"}, ) mcp_part = NativeToolReturnPart( - tool_name=MCPServerTool.kind, + tool_name=f"{MCPServerTool.kind}:srv", tool_call_id="mcp-process", content={"output": "mcp-output"}, )