diff --git a/graphcore/graph.py b/graphcore/graph.py index 3b47afb..0b31d45 100644 --- a/graphcore/graph.py +++ b/graphcore/graph.py @@ -30,7 +30,7 @@ from langgraph.prebuilt.tool_node import ToolInvocationError from langchain_anthropic import ChatAnthropic from pydantic import BaseModel, ValidationError -from .utils import current_prompt_tokens, default_max_prompt_tokens, get_token_usage +from .utils import ainvoke, invoke, current_prompt_tokens, default_max_prompt_tokens, get_token_usage from .summary import SummaryConfig logger = logging.getLogger(__name__) @@ -180,7 +180,7 @@ def _async_llm( async def impl( s: list[AnyMessage] ) -> BaseMessage: - res = await llm.ainvoke(s) + res = await ainvoke(llm, s) _log_usage(res) return res return impl @@ -189,7 +189,7 @@ def _sync_llm( llm: LLM ) -> SyncLLM: def impl(m: list[AnyMessage]) -> BaseMessage: - res = llm.invoke(m) + res = invoke(llm, m) _log_usage(res) return res return impl diff --git a/graphcore/tools/memory.py b/graphcore/tools/memory.py index 51b1dc2..c467353 100644 --- a/graphcore/tools/memory.py +++ b/graphcore/tools/memory.py @@ -13,8 +13,92 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . +# ============================================================================= +# Memory tool architecture +# ============================================================================= +# +# Three layers, each one a translation of the one above: +# +# PURE LOGIC ──> BACKEND ──> TOOL +# (effects as (drives the (langchain BaseTool, +# generators) generators provider-specific +# against I/O) schema) +# +# ── Pure layer (`PureMemoryBackend[P, Update, Row, RList]`) ──────────────── +# The semantic operations (view / create / str_replace / insert / rename / +# delete, plus their _pure helpers) are written as generators that *yield* +# I/O requests of type `P` and *receive* I/O results (`Row` for one-row +# reads, `RList` for multi-row reads, `Update` for write row-counts). +# These generators do no I/O themselves — they describe the sequence of +# primitive operations they need, and each primitive op is itself a +# generator method (`read_file_pure`, `write_file_pure`, `stat_pure`, +# `list_dir_pure`, `rm_pure`, `do_rename_pure`) that subclasses define in +# terms of `P`. +# +# Why this shape? **It solves the function-coloring problem.** Python's +# `async` is infectious — if the core "view a file with a line range" +# logic awaited an I/O call, every caller up the stack would have to be +# async too, and you couldn't share that body between a sync driver +# (e.g. a synchronous Postgres connection) and an async driver (e.g. an +# async pool). By representing I/O as yielded requests instead of awaited +# calls, the core generator is *colorless*: a sync driver advances it +# with `next()` / `.send(result)`, an async driver does the same with +# `await some_io(...)` between the steps. Same logic, two drivers. +# +# That's also how `SQLBackendPure` (`P = (sql, params)`, yields SQL, +# consumes rows) ends up shared by both `SyncSqlBackend` and +# `AsyncSQLBackend`. `PureFilesystemLogic` uses `P = Never` via +# `to_generator` since its primitives don't need a wedge — but the same +# `_view_file_pure` / `create_pure` / etc. logic is reused by both +# `FileSystemMemoryBackend` and `AsyncFileMemoryBackend`. +# +# ── Backend layer (`MemoryBackend` / `AsyncMemoryBackend`) ───────────────── +# These are the drivers. They hold a `PureMemoryBackend` (`self.logic`) +# and implement `_run_row` / `_run_multi` / `_run_update`: feed the +# generator, get a request `P`, perform the I/O, send the result back into +# the generator, repeat until `StopIteration`. +# +# Concrete impls: +# • `PostgresMemoryBackend`, `SqliteMemoryBackend` — sync SQL drivers +# holding a real connection/pool and a `SQLBackendPure` logic. +# • `AsyncPostgresBackend` — async variant of the same. +# • `FileSystemMemoryBackend` / `AsyncFileMemoryBackend` — drivers over +# `PureFilesystemLogic` whose `_run_*` methods just exhaust the +# generator (no I/O wedge needed; primitives already did the work). +# +# The public surface of a Backend is the six `MemoryToolImpl` methods +# (view, create, delete, rename, insert, str_replace). Sync impls return +# `str`; async impls return `Awaitable[str]`. This is what the tool layer +# consumes. +# +# ── Tool layer (`*_memory_tool` factories) ──────────────────────────────── +# `MemoryToolImpl[R]` is a structural Protocol — both `MemoryBackend` and +# `AsyncMemoryBackend` satisfy it, parameterized on `R = str` vs +# `R = Awaitable[str]`. The tool factories close over a `MemoryToolImpl` +# and wrap it in a `BaseTool` with a provider-specific args schema: +# +# • `memory_tool` — sync, Anthropic schema (UnifiedMemorySchema) +# • `async_memory_tool` — async, Anthropic schema +# • `openai_memory_tool` — sync, OpenAI schema (_OpenAIMemorySchema) +# • `openai_async_memory_tool` — async, OpenAI schema +# +# Anthropic's `UnifiedMemorySchema` is a flat bag of nullable fields, +# shape-matched to what Anthropic's trained-on memory tool emits. The +# Anthropic Files-API beta (`memory_20250818`) replaces this schema +# server-side, so the "sparse" tool and field documentations don't matter. +# +# OpenAI's `_OpenAIMemorySchema` wraps a single `memory_op` field whose +# type is a Pydantic discriminated union over six variant BaseModels +# (one per command). Top level is `type: "object"` so strict-mode JSON +# schema validation works; the `anyOf` lives inside `memory_op`. +# +# Dispatch: `_memory_tool_impl` handles the flat-bag → backend-method +# path; `_dispatch_openai_op` handles the sum-type → backend-method path. +# Both end up calling the same six MemoryToolImpl methods. +# ============================================================================= + from typing import ( - Literal, Optional, override, Protocol, Any, TypeVar, Iterator, + Annotated, Literal, Optional, override, Protocol, Any, TypeVar, Iterator, ContextManager, LiteralString, cast, Generator, AsyncContextManager, AsyncIterator, Callable, Awaitable, Sequence, Never, ParamSpec @@ -1193,7 +1277,7 @@ def memory_tool(backend: MemoryToolImpl[str]) -> BaseTool: """ def missing_required(s: str): return f"Error: missing required {s} argument" - + class MemoryTool(WithImplementation[str], UnifiedMemorySchema): """ Here to make the tool annotation happy @@ -1204,3 +1288,158 @@ def run(self) -> str: backend, self, missing_required ) return MemoryTool.as_tool("memory") + + +# --------------------------------------------------------------------------- +# OpenAI-targeted memory tool: discriminated-union schema. +# +# OpenAI (and OpenAI-compat backends) doesn't have a trained-on memory tool +# schema the way Anthropic does, so we can pick the shape. A sum type with a +# single ``memory_op`` field is much friendlier to the model — each variant +# only carries the fields its command actually needs — and stays +# strict-mode-compatible because the top level is still ``type: "object"``. +# --------------------------------------------------------------------------- + + +class _CreateOp(BaseModel): + """Create a file at ``path`` with the given contents.""" + op: Literal["create"] + path: str = Field(description="Absolute memory path (must start with /memories).") + file_text: str = Field(description="Full contents of the new file.") + + +class _ViewOp(BaseModel): + """Read a file or list a directory under ``path``.""" + op: Literal["view"] + path: str = Field(description="Absolute memory path (must start with /memories).") + view_range: list[int] | None = Field( + default=None, + description=( + "Optional [start, end] line range (1-indexed, inclusive). " + "Use -1 for end to read to EOF. Ignored for directories." + ), + ) + + +class _StrReplaceOp(BaseModel): + """Replace a unique substring in the file at ``path``.""" + op: Literal["str_replace"] + path: str = Field(description="Absolute memory path (must start with /memories).") + old_str: str = Field(description="Exact substring to replace. Must occur exactly once.") + new_str: str = Field(description="Replacement text.") + + +class _InsertOp(BaseModel): + """Insert a new line at ``insert_line`` in the file at ``path``.""" + op: Literal["insert"] + path: str = Field(description="Absolute memory path (must start with /memories).") + insert_line: int = Field(description="0-based line index at which to insert.") + insert_text: str = Field(description="Text to insert (a trailing newline is added).") + + +class _DeleteOp(BaseModel): + """Delete the file or directory at ``path``.""" + op: Literal["delete"] + path: str = Field(description="Absolute memory path (must start with /memories).") + + +class _RenameOp(BaseModel): + """Rename or move a file or directory.""" + op: Literal["rename"] + old_path: str = Field(description="Existing absolute memory path.") + new_path: str = Field(description="Destination absolute memory path.") + + +type _MemoryOpUnion = Annotated[ + _CreateOp | _ViewOp | _StrReplaceOp | _InsertOp | _DeleteOp | _RenameOp, + Field(discriminator="op"), +] + + +class _OpenAIMemorySchema(BaseModel): + """OpenAI-targeted memory tool schema. + + A single ``memory_op`` field carrying a discriminated union over + the six memory commands. Top level is still ``type: "object"`` so + OpenAI strict-mode validation works.""" + memory_op: _MemoryOpUnion = Field( + description=( + "The memory operation to perform. The 'op' tag selects the variant; " + "each variant carries only the fields relevant to that operation." + ), + ) + + +def _dispatch_openai_op[R]( + backend: MemoryToolImpl[R], + op: _CreateOp | _ViewOp | _StrReplaceOp | _InsertOp | _DeleteOp | _RenameOp, +) -> R: + """Route a parsed sum-type variant into the backend's typed methods.""" + match op: + case _CreateOp(path=p, file_text=ft): + return backend.create(p, ft) + case _ViewOp(path=p, view_range=vr): + rng: tuple[int, int] | None = None + if vr is not None and len(vr) >= 2: + rng = (vr[0], vr[1]) + return backend.view(p, rng) + case _StrReplaceOp(path=p, old_str=os_, new_str=ns): + return backend.str_replace(p, os_, ns) + case _InsertOp(path=p, insert_line=il, insert_text=it): + return backend.insert(p, il, it) + case _DeleteOp(path=p): + return backend.delete(p) + case _RenameOp(old_path=op_, new_path=np): + return backend.rename(op_, np) + + +_OPENAI_MEMORY_TOOL_DESCRIPTION = """\ +Persistent filesystem-style memory that survives across turns and +across conversations within this workflow. All paths live under +``/memories`` and are sandboxed there. Use this tool to record +intermediate observations, decisions, partial results, and any +context you want to recall later — anything not written here is +forgotten when the conversation ends. + +The ``memory_op`` field selects the operation. Each variant carries +only the fields that operation needs: + +- ``view``: read a file (optionally a line range) or list a directory. +- ``create``: write a brand-new file. +- ``str_replace``: replace a unique substring in an existing file. +- ``insert``: insert a line at a specific index in an existing file. +- ``delete``: remove a file or directory. +- ``rename``: move/rename a file or directory. + +Prefer ``view`` before ``create`` to avoid clobbering, and prefer +``str_replace`` over rewriting a whole file when only part changes. +""" + + +def openai_async_memory_tool( + backend: MemoryToolImpl[Awaitable[str]], +) -> BaseTool: + """Async OpenAI-flavored memory tool. Same backend contract as + :func:`async_memory_tool`; differs only in the tool-args schema + seen by the model.""" + + class OpenAIMemoryTool(WithAsyncImplementation[str], _OpenAIMemorySchema): + __doc__ = _OPENAI_MEMORY_TOOL_DESCRIPTION + + @override + async def run(self) -> str: + return await _dispatch_openai_op(backend, self.memory_op) + return OpenAIMemoryTool.as_tool("memory") + + +def openai_memory_tool(backend: MemoryToolImpl[str]) -> BaseTool: + """Sync OpenAI-flavored memory tool. Companion to + :func:`memory_tool`.""" + + class OpenAIMemoryTool(WithImplementation[str], _OpenAIMemorySchema): + __doc__ = _OPENAI_MEMORY_TOOL_DESCRIPTION + + @override + def run(self) -> str: + return _dispatch_openai_op(backend, self.memory_op) + return OpenAIMemoryTool.as_tool("memory") diff --git a/graphcore/utils.py b/graphcore/utils.py index ef1c93f..b9fa35e 100644 --- a/graphcore/utils.py +++ b/graphcore/utils.py @@ -13,8 +13,11 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . -from typing import TypedDict, Literal, List -from langchain_core.messages import AIMessage, AnyMessage +from typing import TypedDict, Literal, List, Sequence + +from langchain_core.language_models.chat_models import BaseChatModel +from langchain_core.messages import AIMessage, AnyMessage, BaseMessage +from langchain_core.runnables import Runnable type TokenUsageKeysT = Literal[ @@ -54,6 +57,47 @@ def get_token_usage(m: AIMessage) -> TokenUsageDict: to_ret[k] = to_ret[k] + tok return to_ret +class NormalizedTokenUsage(TypedDict): + total_input_tokens: int + total_output_tokens: int + + cache_read_tokens: int + cache_write_tokens: int + thinking_tokens: int + + model_name: str | None + +def get_normalized_token_usage(m: AIMessage) -> NormalizedTokenUsage: + to_ret : NormalizedTokenUsage = { + "total_input_tokens": 0, + "model_name": m.response_metadata.get("model_name"), + "cache_read_tokens": 0, + "cache_write_tokens": 0, + "thinking_tokens": 0, + "total_output_tokens": 0 + } + + if not (usage := m.usage_metadata): + return to_ret + + to_ret["total_input_tokens"] = usage["input_tokens"] + to_ret["total_output_tokens"] = usage["output_tokens"] + + if "output_token_details" in usage: + out_details = usage["output_token_details"] + to_ret["thinking_tokens"] = out_details.get("reasoning", 0) + if "input_token_details" in usage: + in_details = usage["input_token_details"] + to_ret["cache_read_tokens"] = in_details.get("cache_read", 0) + + cache_write = in_details.get("cache_creation", 0) + if not cache_write: + # thanks langchain + for t in ("ephemeral_5m_input_tokens", "ephemeral_1h_input_tokens"): + cache_write += in_details.get(t, 0) + to_ret["cache_write_tokens"] = cache_write + + return to_ret def current_prompt_tokens(messages: List[AnyMessage]) -> int: """ @@ -66,12 +110,8 @@ def current_prompt_tokens(messages: List[AnyMessage]) -> int: """ for m in reversed(messages): if isinstance(m, AIMessage): - usage = get_token_usage(m) - return ( - usage["input_tokens"] - + usage["cache_read_input_tokens"] - + usage["cache_creation_input_tokens"] - ) + usage = get_normalized_token_usage(m) + return usage["total_input_tokens"] return 0 @@ -90,3 +130,70 @@ def default_max_prompt_tokens(model_name: str | None) -> int: return 500_000 # 1M context window case _: return 100_000 # fallback for unknown models + + +# --------------------------------------------------------------------------- +# Content normalization for LLM invocation +# +# OpenAI's Chat Completions API rejects bare strings inside a list-shaped +# ``content`` — every list element must be a content-part dict with a +# ``type`` key. Anthropic's Messages API is more permissive and tolerates +# ``list[str | dict]``, but it also accepts the strict ``list[dict]`` +# form, so we normalize everything to ``list[dict]`` unconditionally +# before invoking. ``invoke`` / ``ainvoke`` are the wrappers every +# workflow LLM call should go through. +# --------------------------------------------------------------------------- + + +def _normalize_content(content: str | list[str | dict]) -> str | list[dict]: + """Promote bare strings inside a list-content to ``{"type": "text", + "text": s}`` dicts. ``str`` content (the single-text form) is + passed through unchanged.""" + if not isinstance(content, list): + return content + out: list[dict] = [] + for item in content: + if isinstance(item, str): + out.append({"type": "text", "text": item}) + else: + out.append(item) + return out + + +def _normalize_messages(messages: Sequence[BaseMessage]) -> list[BaseMessage]: + """Return a list of messages whose ``content`` (where list-shaped) + has every bare string promoted to a text content-part dict. Each + affected message is copied; messages that already conform are + passed through unchanged so we don't churn references that the + caller may still hold.""" + out: list[BaseMessage] = [] + for m in messages: + content = m.content + if isinstance(content, list): + normalized = _normalize_content(content) + if normalized is not content: + m = m.model_copy(update={"content": normalized}) + out.append(m) + return out + + +def invoke( + llm: BaseChatModel | Runnable, + messages: Sequence[BaseMessage], + **kwargs, +) -> BaseMessage: + """Synchronous LLM invocation wrapper that normalizes message + content shapes before calling ``llm.invoke``. Use this in place of + ``llm.invoke(messages)`` everywhere a workflow talks to the model + — the normalization keeps OpenAI's Chat Completions happy and + leaves Anthropic's Messages API behavior unchanged.""" + return llm.invoke(_normalize_messages(messages), **kwargs) + + +async def ainvoke( + llm: BaseChatModel | Runnable, + messages: Sequence[BaseMessage], + **kwargs, +) -> BaseMessage: + """Async counterpart to :func:`invoke`.""" + return await llm.ainvoke(_normalize_messages(messages), **kwargs)