From 7955df947940a58b1c94affba6d0c60e989af450 Mon Sep 17 00:00:00 2001 From: Anmol Jaiswal <68013660+anmolg1997@users.noreply.github.com> Date: Wed, 15 Apr 2026 09:21:22 +0530 Subject: [PATCH] feat(memory): add DatabaseMemoryService with SQLAlchemy async backend Adds a durable, SQL-backed memory service that works with any SQLAlchemy-compatible async database (SQLite, PostgreSQL, MySQL/MariaDB). This fills the gap between the volatile InMemoryMemoryService and the cloud-only Firestore/Vertex AI options, giving self-hosted deployments a persistent memory backend with zero cloud dependencies. The implementation mirrors the keyword-extraction approach used by FirestoreMemoryService and reuses the existing SQLAlchemy patterns established by DatabaseSessionService. Closes #2524 Closes #2976 --- src/google/adk/memory/__init__.py | 16 + src/google/adk/memory/_memory_schemas.py | 65 +++ .../adk/memory/database_memory_service.py | 459 ++++++++++++++++++ .../memory/test_database_memory_service.py | 367 ++++++++++++++ 4 files changed, 907 insertions(+) create mode 100644 src/google/adk/memory/_memory_schemas.py create mode 100644 src/google/adk/memory/database_memory_service.py create mode 100644 tests/unittests/memory/test_database_memory_service.py diff --git a/src/google/adk/memory/__init__.py b/src/google/adk/memory/__init__.py index c47fb8ec40..220f434bf0 100644 --- a/src/google/adk/memory/__init__.py +++ b/src/google/adk/memory/__init__.py @@ -21,10 +21,26 @@ __all__ = [ 'BaseMemoryService', + 'DatabaseMemoryService', 'InMemoryMemoryService', 'VertexAiMemoryBankService', ] + +def __getattr__(name: str): + if name == 'DatabaseMemoryService': + try: + from .database_memory_service import DatabaseMemoryService + + return DatabaseMemoryService + except ImportError as e: + raise ImportError( + 'DatabaseMemoryService requires sqlalchemy>=2.0, please ensure it is' + ' installed correctly.' + ) from e + raise AttributeError(f'module {__name__!r} has no attribute {name!r}') + + try: from .vertex_ai_rag_memory_service import VertexAiRagMemoryService diff --git a/src/google/adk/memory/_memory_schemas.py b/src/google/adk/memory/_memory_schemas.py new file mode 100644 index 0000000000..6c814e07e9 --- /dev/null +++ b/src/google/adk/memory/_memory_schemas.py @@ -0,0 +1,65 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""SQLAlchemy ORM models for DatabaseMemoryService.""" + +from __future__ import annotations + +from typing import Any +from typing import Optional + +from sqlalchemy import Float +from sqlalchemy import Index +from sqlalchemy import Text +from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column +from sqlalchemy.types import String + +from ..sessions.schemas.shared import DynamicJSON + +_MAX_KEY_LENGTH = 128 + + +class Base(DeclarativeBase): + """Base class for memory database tables.""" + + pass + + +class StorageMemoryEntry(Base): + """Represents a single memory entry stored in the database.""" + + __tablename__ = "adk_memory_entries" + + id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True) + + app_name: Mapped[str] = mapped_column(String(_MAX_KEY_LENGTH), index=True) + user_id: Mapped[str] = mapped_column(String(_MAX_KEY_LENGTH), index=True) + + keywords: Mapped[str] = mapped_column(Text) + + author: Mapped[Optional[str]] = mapped_column( + String(_MAX_KEY_LENGTH), nullable=True + ) + content: Mapped[dict[str, Any]] = mapped_column(DynamicJSON) + timestamp: Mapped[Optional[float]] = mapped_column(Float, nullable=True) + + __table_args__ = (Index("idx_memory_app_user", "app_name", "user_id"),) + + def __repr__(self) -> str: + return ( + f"" + ) diff --git a/src/google/adk/memory/database_memory_service.py b/src/google/adk/memory/database_memory_service.py new file mode 100644 index 0000000000..24e61b9205 --- /dev/null +++ b/src/google/adk/memory/database_memory_service.py @@ -0,0 +1,459 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A durable memory service backed by any SQLAlchemy-compatible async database. + +Supported dialects include SQLite (via ``aiosqlite``), PostgreSQL (via +``asyncpg``), MySQL / MariaDB, and any other database for which an async +SQLAlchemy driver is available. + +The implementation mirrors the keyword-extraction approach used by +:class:`~google.adk.integrations.firestore.FirestoreMemoryService` but +stores memories in a relational table managed by SQLAlchemy. +""" + +from __future__ import annotations + +import asyncio +from collections.abc import Mapping +from collections.abc import Sequence +from contextlib import asynccontextmanager +import logging +import re +from typing import Any +from typing import AsyncIterator +from typing import Optional +from typing import TYPE_CHECKING + +from sqlalchemy import select +from sqlalchemy.engine import make_url +from sqlalchemy.exc import ArgumentError +from sqlalchemy.ext.asyncio import async_sessionmaker +from sqlalchemy.ext.asyncio import AsyncEngine +from sqlalchemy.ext.asyncio import AsyncSession as DatabaseSessionFactory +from sqlalchemy.ext.asyncio import create_async_engine +from sqlalchemy.pool import StaticPool +from typing_extensions import override + +from . import _utils +from ._memory_schemas import Base +from ._memory_schemas import StorageMemoryEntry +from .base_memory_service import BaseMemoryService +from .base_memory_service import SearchMemoryResponse +from .memory_entry import MemoryEntry + +if TYPE_CHECKING: + from ..events.event import Event + from ..sessions.session import Session + +logger = logging.getLogger("google_adk." + __name__) + +_SQLITE_DIALECT = "sqlite" + +DEFAULT_STOP_WORDS = { + "a", + "about", + "above", + "after", + "again", + "against", + "all", + "am", + "an", + "and", + "any", + "are", + "as", + "at", + "be", + "because", + "been", + "before", + "being", + "below", + "between", + "both", + "but", + "by", + "can", + "could", + "did", + "do", + "does", + "doing", + "don", + "down", + "during", + "each", + "else", + "few", + "for", + "from", + "further", + "had", + "has", + "have", + "having", + "he", + "her", + "here", + "hers", + "herself", + "him", + "himself", + "his", + "how", + "i", + "if", + "in", + "into", + "is", + "it", + "its", + "itself", + "just", + "may", + "me", + "might", + "more", + "most", + "must", + "my", + "myself", + "no", + "nor", + "not", + "now", + "of", + "off", + "on", + "once", + "only", + "or", + "other", + "our", + "ours", + "ourselves", + "out", + "over", + "own", + "s", + "same", + "shall", + "she", + "should", + "so", + "some", + "such", + "t", + "than", + "that", + "the", + "their", + "theirs", + "them", + "themselves", + "then", + "there", + "these", + "they", + "this", + "those", + "through", + "to", + "too", + "under", + "until", + "up", + "very", + "was", + "we", + "were", + "what", + "when", + "where", + "which", + "who", + "whom", + "why", + "will", + "with", + "would", + "you", + "your", + "yours", + "yourself", + "yourselves", +} + + +class DatabaseMemoryService(BaseMemoryService): # type: ignore[misc] + """Memory service backed by any SQLAlchemy-compatible async database. + + Uses keyword extraction (identical to the Firestore memory service) to + index session events, and keyword matching to search for relevant memories. + + Example usage:: + + # SQLite (zero-config) + memory = DatabaseMemoryService("sqlite+aiosqlite:///memory.db") + + # PostgreSQL + memory = DatabaseMemoryService( + "postgresql+asyncpg://user:pass@host/dbname" + ) + + async with memory: + await memory.add_session_to_memory(session) + result = await memory.search_memory( + app_name="my_app", user_id="u1", query="agent tool usage" + ) + """ + + def __init__( + self, + db_url: str, + *, + stop_words: Optional[set[str]] = None, + **engine_kwargs: Any, + ): + """Initializes the database memory service. + + Args: + db_url: A SQLAlchemy-compatible async database URL. + stop_words: A set of words to ignore when extracting keywords. Defaults + to a standard English stop-words list. + **engine_kwargs: Additional keyword arguments forwarded to + ``create_async_engine``. + """ + try: + url = make_url(db_url) + if ( + url.get_backend_name() == _SQLITE_DIALECT + and url.database == ":memory:" + ): + engine_kwargs.setdefault("poolclass", StaticPool) + connect_args = dict(engine_kwargs.get("connect_args", {})) + connect_args.setdefault("check_same_thread", False) + engine_kwargs["connect_args"] = connect_args + elif url.get_backend_name() != _SQLITE_DIALECT: + engine_kwargs.setdefault("pool_pre_ping", True) + + self._engine: AsyncEngine = create_async_engine(db_url, **engine_kwargs) + except Exception as e: + if isinstance(e, ArgumentError): + raise ValueError( + f"Invalid database URL format or argument '{db_url}'." + ) from e + if isinstance(e, ImportError): + raise ValueError( + f"Database related module not found for URL '{db_url}'." + ) from e + raise ValueError( + f"Failed to create database engine for URL '{db_url}'" + ) from e + + self._session_factory: async_sessionmaker[DatabaseSessionFactory] = ( + async_sessionmaker(bind=self._engine, expire_on_commit=False) + ) + + self._tables_created = False + self._table_creation_lock = asyncio.Lock() + + self.stop_words = ( + stop_words if stop_words is not None else DEFAULT_STOP_WORDS + ) + + # -- lifecycle helpers ---------------------------------------------------- + + @asynccontextmanager + async def _db_session(self) -> AsyncIterator[DatabaseSessionFactory]: + """Yields a DB session with guaranteed rollback on errors.""" + async with self._session_factory() as sql_session: + try: + yield sql_session + except BaseException: + await sql_session.rollback() + raise + + async def _prepare_tables(self) -> None: + """Lazily create the memory table (double-checked locking).""" + if self._tables_created: + return + async with self._table_creation_lock: + if self._tables_created: + return + async with self._engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + self._tables_created = True + + async def close(self) -> None: + """Disposes the SQLAlchemy engine and closes pooled connections.""" + await self._engine.dispose() + + async def __aenter__(self) -> DatabaseMemoryService: + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: + await self.close() + + # -- keyword helpers ------------------------------------------------------ + + def _extract_keywords(self, text: str) -> set[str]: + """Extracts keywords from *text*, ignoring stop words.""" + words = re.findall(r"[A-Za-z]+", text.lower()) + return {word for word in words if word not in self.stop_words} + + # -- BaseMemoryService implementation ------------------------------------- + + @override + async def add_session_to_memory(self, session: Session) -> None: + """Extracts keywords from session events and persists them.""" + await self._prepare_tables() + entries: list[StorageMemoryEntry] = [] + + for event in session.events: + if not event.content or not event.content.parts: + continue + + text = " ".join([part.text for part in event.content.parts if part.text]) + if not text: + continue + + keywords = self._extract_keywords(text) + if not keywords: + continue + + entries.append( + StorageMemoryEntry( + app_name=session.app_name, + user_id=session.user_id, + keywords=" ".join(sorted(keywords)), + author=event.author, + content=event.content.model_dump(exclude_none=True, mode="json"), + timestamp=event.timestamp, + ) + ) + + if not entries: + return + + async with self._db_session() as sql_session: + sql_session.add_all(entries) + await sql_session.commit() + + @override + async def add_events_to_memory( + self, + *, + app_name: str, + user_id: str, + events: Sequence[Event], + session_id: str | None = None, + custom_metadata: Mapping[str, object] | None = None, + ) -> None: + """Adds an explicit list of events as memory entries (delta ingestion).""" + await self._prepare_tables() + entries: list[StorageMemoryEntry] = [] + + for event in events: + if not event.content or not event.content.parts: + continue + + text = " ".join([part.text for part in event.content.parts if part.text]) + if not text: + continue + + keywords = self._extract_keywords(text) + if not keywords: + continue + + entries.append( + StorageMemoryEntry( + app_name=app_name, + user_id=user_id, + keywords=" ".join(sorted(keywords)), + author=event.author, + content=event.content.model_dump(exclude_none=True, mode="json"), + timestamp=event.timestamp, + ) + ) + + if not entries: + return + + async with self._db_session() as sql_session: + sql_session.add_all(entries) + await sql_session.commit() + + @override + async def search_memory( + self, + *, + app_name: str, + user_id: str, + query: str, + ) -> SearchMemoryResponse: + """Searches memory for entries whose keywords overlap with *query*.""" + keywords = self._extract_keywords(query) + if not keywords: + return SearchMemoryResponse() + + await self._prepare_tables() + async with self._db_session() as sql_session: + stmt = ( + select(StorageMemoryEntry) + .filter(StorageMemoryEntry.app_name == app_name) + .filter(StorageMemoryEntry.user_id == user_id) + ) + result = await sql_session.execute(stmt) + all_rows = result.scalars().all() + + seen: set[tuple[str | None, str, str | None]] = set() + memories: list[MemoryEntry] = [] + + for row in all_rows: + stored_keywords = set(row.keywords.split()) + if not stored_keywords.intersection(keywords): + continue + + try: + from google.genai import types + + content = types.Content.model_validate(row.content) + except Exception as e: + logger.warning(f"Failed to parse memory entry: {e}") + continue + + content_text = "" + if content.parts: + content_text = " ".join( + [part.text for part in content.parts if part.text] + ) + + timestamp_str = _utils.format_timestamp(row.timestamp or 0.0) + dedup_key = (row.author, content_text, timestamp_str) + if dedup_key in seen: + continue + seen.add(dedup_key) + + memories.append( + MemoryEntry( + content=content, + author=row.author or "", + timestamp=timestamp_str, + ) + ) + + return SearchMemoryResponse(memories=memories) diff --git a/tests/unittests/memory/test_database_memory_service.py b/tests/unittests/memory/test_database_memory_service.py new file mode 100644 index 0000000000..5a92c609fd --- /dev/null +++ b/tests/unittests/memory/test_database_memory_service.py @@ -0,0 +1,367 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from google.adk.events.event import Event +from google.adk.memory.database_memory_service import DatabaseMemoryService +from google.adk.sessions.session import Session +from google.genai import types +import pytest + +MOCK_APP_NAME = "test-app" +MOCK_USER_ID = "test-user" +MOCK_OTHER_USER_ID = "another-user" + +MOCK_SESSION_1 = Session( + app_name=MOCK_APP_NAME, + user_id=MOCK_USER_ID, + id="session-1", + last_update_time=1000, + events=[ + Event( + id="event-1a", + invocation_id="inv-1", + author="user", + timestamp=12345, + content=types.Content( + parts=[types.Part(text="The ADK is a great toolkit.")] + ), + ), + Event( + id="event-1b", + invocation_id="inv-2", + author="user", + timestamp=12346, + ), + Event( + id="event-1c", + invocation_id="inv-3", + author="model", + timestamp=12347, + content=types.Content( + parts=[ + types.Part( + text="I agree. The Agent Development Kit (ADK) rocks!" + ) + ] + ), + ), + ], +) + +MOCK_SESSION_2 = Session( + app_name=MOCK_APP_NAME, + user_id=MOCK_USER_ID, + id="session-2", + last_update_time=2000, + events=[ + Event( + id="event-2a", + invocation_id="inv-4", + author="user", + timestamp=54321, + content=types.Content( + parts=[types.Part(text="I like to code in Python.")] + ), + ), + ], +) + +MOCK_SESSION_DIFFERENT_USER = Session( + app_name=MOCK_APP_NAME, + user_id=MOCK_OTHER_USER_ID, + id="session-3", + last_update_time=3000, + events=[ + Event( + id="event-3a", + invocation_id="inv-5", + author="user", + timestamp=60000, + content=types.Content(parts=[types.Part(text="This is a secret.")]), + ), + ], +) + +MOCK_SESSION_WITH_NO_EVENTS = Session( + app_name=MOCK_APP_NAME, + user_id=MOCK_USER_ID, + id="session-4", + last_update_time=4000, +) + + +@pytest.fixture +async def memory_service(): + service = DatabaseMemoryService("sqlite+aiosqlite:///:memory:") + async with service: + yield service + + +def test_extract_keywords(): + service = DatabaseMemoryService("sqlite+aiosqlite:///:memory:") + text = "The quick brown fox jumps over the lazy dog." + keywords = service._extract_keywords(text) + + assert "the" not in keywords + assert "over" not in keywords + assert "quick" in keywords + assert "brown" in keywords + assert "fox" in keywords + assert "jumps" in keywords + assert "lazy" in keywords + assert "dog" in keywords + + +@pytest.mark.asyncio +async def test_add_session_to_memory(memory_service): + await memory_service.add_session_to_memory(MOCK_SESSION_1) + + result = await memory_service.search_memory( + app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query="ADK toolkit" + ) + assert result.memories + assert len(result.memories) >= 1 + + +@pytest.mark.asyncio +async def test_add_session_with_no_events(memory_service): + await memory_service.add_session_to_memory(MOCK_SESSION_WITH_NO_EVENTS) + + result = await memory_service.search_memory( + app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query="anything" + ) + assert not result.memories + + +@pytest.mark.asyncio +async def test_add_session_to_memory_filters_no_content_events(memory_service): + await memory_service.add_session_to_memory(MOCK_SESSION_1) + + result = await memory_service.search_memory( + app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query="ADK" + ) + assert len(result.memories) == 2 + authors = {m.author for m in result.memories} + assert "user" in authors + assert "model" in authors + + +@pytest.mark.asyncio +async def test_add_session_to_memory_skips_stop_words_only(memory_service): + session = Session( + app_name=MOCK_APP_NAME, + user_id=MOCK_USER_ID, + id="session-stop", + events=[ + Event( + id="e-stop", + invocation_id="inv-stop", + author="user", + content=types.Content(parts=[types.Part(text="the and or")]), + ), + ], + ) + await memory_service.add_session_to_memory(session) + + result = await memory_service.search_memory( + app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query="random" + ) + assert not result.memories + + +@pytest.mark.asyncio +async def test_search_memory_empty_query(memory_service): + await memory_service.add_session_to_memory(MOCK_SESSION_1) + + result = await memory_service.search_memory( + app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query="" + ) + assert not result.memories + + +@pytest.mark.asyncio +async def test_search_memory_only_stop_words(memory_service): + await memory_service.add_session_to_memory(MOCK_SESSION_1) + + result = await memory_service.search_memory( + app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query="the and or" + ) + assert not result.memories + + +@pytest.mark.asyncio +async def test_search_memory_simple_match(memory_service): + await memory_service.add_session_to_memory(MOCK_SESSION_1) + await memory_service.add_session_to_memory(MOCK_SESSION_2) + + result = await memory_service.search_memory( + app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query="Python" + ) + assert len(result.memories) == 1 + assert result.memories[0].content.parts[0].text == "I like to code in Python." + assert result.memories[0].author == "user" + + +@pytest.mark.asyncio +async def test_search_memory_case_insensitive(memory_service): + await memory_service.add_session_to_memory(MOCK_SESSION_1) + + result = await memory_service.search_memory( + app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query="development" + ) + assert len(result.memories) == 1 + assert ( + result.memories[0].content.parts[0].text + == "I agree. The Agent Development Kit (ADK) rocks!" + ) + + +@pytest.mark.asyncio +async def test_search_memory_multiple_matches(memory_service): + await memory_service.add_session_to_memory(MOCK_SESSION_1) + + result = await memory_service.search_memory( + app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query="ADK" + ) + assert len(result.memories) == 2 + texts = {m.content.parts[0].text for m in result.memories} + assert "The ADK is a great toolkit." in texts + assert "I agree. The Agent Development Kit (ADK) rocks!" in texts + + +@pytest.mark.asyncio +async def test_search_memory_no_match(memory_service): + await memory_service.add_session_to_memory(MOCK_SESSION_1) + + result = await memory_service.search_memory( + app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query="nonexistent" + ) + assert not result.memories + + +@pytest.mark.asyncio +async def test_search_memory_scoped_by_user(memory_service): + await memory_service.add_session_to_memory(MOCK_SESSION_1) + await memory_service.add_session_to_memory(MOCK_SESSION_DIFFERENT_USER) + + result = await memory_service.search_memory( + app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query="secret" + ) + assert not result.memories + + result_other = await memory_service.search_memory( + app_name=MOCK_APP_NAME, user_id=MOCK_OTHER_USER_ID, query="secret" + ) + assert len(result_other.memories) == 1 + assert result_other.memories[0].content.parts[0].text == "This is a secret." + + +@pytest.mark.asyncio +async def test_search_memory_deduplication(memory_service): + await memory_service.add_session_to_memory(MOCK_SESSION_1) + await memory_service.add_session_to_memory(MOCK_SESSION_1) + + result = await memory_service.search_memory( + app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query="ADK" + ) + assert len(result.memories) == 2 + + +@pytest.mark.asyncio +async def test_add_events_to_memory(memory_service): + events = [MOCK_SESSION_1.events[0]] + await memory_service.add_events_to_memory( + app_name=MOCK_APP_NAME, + user_id=MOCK_USER_ID, + session_id="session-1", + events=events, + ) + + result = await memory_service.search_memory( + app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query="ADK toolkit" + ) + assert len(result.memories) == 1 + assert result.memories[0].author == "user" + + +@pytest.mark.asyncio +async def test_add_events_to_memory_without_session_id(memory_service): + events = [MOCK_SESSION_2.events[0]] + await memory_service.add_events_to_memory( + app_name=MOCK_APP_NAME, + user_id=MOCK_USER_ID, + events=events, + ) + + result = await memory_service.search_memory( + app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query="Python" + ) + assert len(result.memories) == 1 + + +@pytest.mark.asyncio +async def test_add_events_to_memory_skips_empty(memory_service): + events = [ + Event( + id="e-empty", + invocation_id="inv-empty", + author="user", + timestamp=12345, + ), + ] + await memory_service.add_events_to_memory( + app_name=MOCK_APP_NAME, + user_id=MOCK_USER_ID, + events=events, + ) + + result = await memory_service.search_memory( + app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query="anything" + ) + assert not result.memories + + +@pytest.mark.asyncio +async def test_context_manager(): + service = DatabaseMemoryService("sqlite+aiosqlite:///:memory:") + async with service: + await service.add_session_to_memory(MOCK_SESSION_1) + result = await service.search_memory( + app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query="ADK" + ) + assert result.memories + + +@pytest.mark.asyncio +async def test_file_based_sqlite(tmp_path): + db_path = tmp_path / "test_memory.db" + async with DatabaseMemoryService(f"sqlite+aiosqlite:///{db_path}") as service: + await service.add_session_to_memory(MOCK_SESSION_1) + result = await service.search_memory( + app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query="ADK" + ) + assert len(result.memories) == 2 + + async with DatabaseMemoryService( + f"sqlite+aiosqlite:///{db_path}" + ) as service2: + result = await service2.search_memory( + app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query="ADK" + ) + assert len(result.memories) == 2 + + +def test_invalid_db_url(): + with pytest.raises(ValueError, match="Invalid database URL"): + DatabaseMemoryService("not-a-valid-url://")