diff --git a/pyproject.toml b/pyproject.toml index a3b8cab..394004b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,6 +16,7 @@ license = { file = "LICENSE" } dependencies = [ "fastapi>=0.109.0", + "python-multipart>=0.0.9", "uvicorn>=0.27.0", "pydantic>=2.6.0", "pydantic-settings>=2.1.0", diff --git a/server.py b/server.py index 4bd1859..ff1025d 100644 --- a/server.py +++ b/server.py @@ -192,11 +192,13 @@ async def scrape_chat_link(req: ScrapeRequest): elapsed = round((time.perf_counter() - start) * 1000, 2) if not pairs: + from src.api.chat_share import scrape_failure_message + return JSONResponse( { "status": "error", "data": None, - "error": "Failed to extract messages from the provided link.", + "error": scrape_failure_message(result), "elapsed_ms": elapsed, }, status_code=400, @@ -554,7 +556,7 @@ def _render_chat_share_sync(url: str) -> tuple[str, str]: with sync_playwright() as p: browser = None launch_errors = [] - for channel in (None, "msedge", "chrome"): + for channel in ("chromium", None, "msedge", "chrome"): try: kwargs = {"headless": True} if channel: @@ -591,15 +593,15 @@ def _block_heavy_assets(route): page.route("**/*", _block_heavy_assets) try: - page.goto(url, wait_until="networkidle", timeout=20000) + page.goto(url, wait_until="domcontentloaded", timeout=30000) except Exception as exc: print(f"[scrape] navigation warning: {exc}", flush=True) provider = _detect_chat_provider(page.url or url) selector = { "chatgpt": "div[data-message-author-role]", - "claude": "script", - "gemini": "message-content, div.user-query, div.model-response", + "claude": "div[data-testid='user-message'], div.font-claude-response", + "gemini": "message-content, div.user-query, div.model-response, .query-text", }.get(provider) if selector: try: @@ -607,7 +609,7 @@ def _block_heavy_assets(route): except Exception as exc: print(f"[scrape] timed out waiting for {provider} content: {exc}", flush=True) - page.wait_for_timeout(2000) + page.wait_for_timeout(5000 if provider == "claude" else 2000) final_url = page.url html = page.content() finally: @@ -660,14 +662,37 @@ def _extract_chat_pairs(url: str, html: str) -> tuple[str, str, list[dict[str, s extraction_method = "structured" except Exception as exc: print(f"[scrape] Claude parse warning: {exc}", flush=True) + if not pairs: + user_msgs = soup.select("div[data-testid='user-message']") + asst_msgs = soup.select("div.font-claude-response") + for user_msg, assistant_msg in zip(user_msgs, asst_msgs): + pairs.append({ + "user_query": user_msg.get_text(separator="\n", strip=True), + "agent_response": assistant_msg.get_text(separator="\n", strip=True), + }) + if pairs: + extraction_method = "dom" elif provider == "gemini": - user_blocks = soup.select("message-content[role='user'], div.user-query") - model_blocks = soup.select("message-content[role='model'], div.model-response") + user_blocks = soup.select( + "message-content[role='user'], div.user-query, .query-text" + ) + model_blocks = soup.select( + "message-content[role='model'], div.model-response, " + "structured-content-container.message-content message-content, " + "message-content:not([role])" + ) for user_block, model_block in zip(user_blocks, model_blocks): + user_text = user_block.get_text(separator="\n", strip=True) + user_labels = {"you said", "your prompt", "あなたの入力", "あなたのプロンプト"} + user_lines = [ + line.strip() + for line in user_text.splitlines() + if line.strip() and line.strip().lower() not in user_labels + ] pairs.append({ - "user_query": user_block.get_text(separator="\n").strip(), - "agent_response": model_block.get_text(separator="\n").strip(), + "user_query": "\n".join(user_lines), + "agent_response": model_block.get_text(separator="\n", strip=True), }) if pairs: extraction_method = "dom" diff --git a/src/api/chat_share.py b/src/api/chat_share.py new file mode 100644 index 0000000..0c4cd17 --- /dev/null +++ b/src/api/chat_share.py @@ -0,0 +1,23 @@ +from __future__ import annotations + +from typing import Any + + +def scrape_failure_message(result: dict[str, Any]) -> str: + provider = result.get("provider") or "unknown" + + if provider in {"chatgpt", "claude", "gemini"}: + display_name = { + "chatgpt": "ChatGPT", + "claude": "Claude", + "gemini": "Gemini", + }[provider] + return ( + f"Could not extract messages from this {display_name} share link. " + "Make sure the link is public, still exists, and has not expired." + ) + + return ( + "Failed to extract messages from the provided link. " + "Supported public share links are ChatGPT, Claude, and Gemini." + ) diff --git a/src/api/routes/memory.py b/src/api/routes/memory.py index 3397a69..0579ba5 100644 --- a/src/api/routes/memory.py +++ b/src/api/routes/memory.py @@ -21,6 +21,7 @@ require_api_key, require_ready, ) +from src.api.chat_share import scrape_failure_message from src.api.schemas import ( APIResponse, BatchIngestRequest, @@ -154,7 +155,7 @@ def _get_or_create_browser(): _pw_instance = sync_playwright().start() launch_errors = [] - for channel in (None, "msedge", "chrome"): + for channel in ("chromium", None, "msedge", "chrome"): try: kwargs = {"headless": True} if channel: @@ -209,8 +210,8 @@ def _block_heavy_assets(route): provider = _detect_chat_provider(page.url or url) selector = { "chatgpt": "div[data-message-author-role]", - "claude": "script", - "gemini": "message-content, div.user-query, div.model-response", + "claude": "div[data-testid='user-message'], div.font-claude-response", + "gemini": "message-content, div.user-query, div.model-response, .query-text", }.get(provider) if selector: try: @@ -218,8 +219,8 @@ def _block_heavy_assets(route): except Exception as exc: logger.warning("Timed out waiting for %s content: %s", provider, exc) - # No hardcoded sleep — the selector wait above already guarantees - # the chat content DOM nodes are present. + if provider == "claude": + page.wait_for_timeout(5000) final_url = page.url html = page.content() @@ -273,14 +274,37 @@ def _extract_chat_pairs(url: str, html: str) -> tuple[str, str, List[MessagePair extraction_method = "structured" except Exception as exc: logger.warning("Failed to parse Claude preloaded state: %s", exc) + if not pairs: + user_msgs = soup.select("div[data-testid='user-message']") + asst_msgs = soup.select("div.font-claude-response") + for u, a in zip(user_msgs, asst_msgs): + pairs.append(MessagePair( + user_query=u.get_text(separator="\n", strip=True), + agent_response=a.get_text(separator="\n", strip=True), + )) + if pairs: + extraction_method = "dom" elif provider == "gemini": - user_blocks = soup.select("message-content[role='user'], div.user-query") - model_blocks = soup.select("message-content[role='model'], div.model-response") + user_blocks = soup.select( + "message-content[role='user'], div.user-query, .query-text" + ) + model_blocks = soup.select( + "message-content[role='model'], div.model-response, " + "structured-content-container.message-content message-content, " + "message-content:not([role])" + ) for u, m in zip(user_blocks, model_blocks): + user_text = u.get_text(separator="\n", strip=True) + user_labels = {"you said", "your prompt", "あなたの入力", "あなたのプロンプト"} + user_lines = [ + line.strip() + for line in user_text.splitlines() + if line.strip() and line.strip().lower() not in user_labels + ] pairs.append(MessagePair( - user_query=u.get_text(separator="\n").strip(), - agent_response=m.get_text(separator="\n").strip(), + user_query="\n".join(user_lines), + agent_response=m.get_text(separator="\n", strip=True), )) if pairs: extraction_method = "dom" @@ -757,7 +781,8 @@ async def scrape_chat_link(req: ScrapeRequest, request: Request): pairs = result["pairs"] if not pairs: - return _error(request, "Failed to extract messages from the provided link.", 400) + elapsed = round((time.perf_counter() - start) * 1000, 2) + return _error(request, scrape_failure_message(result), 400, elapsed) data = ScrapeResponse(pairs=pairs) elapsed = round((time.perf_counter() - start) * 1000, 2) diff --git a/tests/test_chat_share_extraction.py b/tests/test_chat_share_extraction.py new file mode 100644 index 0000000..1a70b67 --- /dev/null +++ b/tests/test_chat_share_extraction.py @@ -0,0 +1,182 @@ +import json +import os +from types import SimpleNamespace + +os.environ.setdefault("PINECONE_API_KEY", "test-pinecone-key") +os.environ.setdefault("NEO4J_PASSWORD", "test-neo4j-password") +os.environ.setdefault("GEMINI_API_KEY", "test-gemini-key") + +from src.api.chat_share import scrape_failure_message +from src.api.routes.memory import ( + _detect_chat_provider, + _extract_chat_pairs, + scrape_chat_link, +) +from src.api.schemas import MessagePair, ScrapeRequest + + +def test_detects_supported_chat_share_providers(): + assert _detect_chat_provider("https://chatgpt.com/share/abc") == "chatgpt" + assert _detect_chat_provider("https://chat.openai.com/share/abc") == "chatgpt" + assert _detect_chat_provider("https://claude.ai/share/abc") == "claude" + assert _detect_chat_provider("https://gemini.google.com/share/abc") == "gemini" + assert _detect_chat_provider("https://g.co/gemini/share/abc") == "gemini" + + +def test_extracts_chatgpt_dom_pairs(): + html = """ +
What is XMem?
+
A long-term memory layer.
+ """ + + provider, method, pairs = _extract_chat_pairs("https://chatgpt.com/share/abc", html) + + assert provider == "chatgpt" + assert method == "dom" + assert pairs == [ + MessagePair( + user_query="What is XMem?", + agent_response="A long-term memory layer.", + ) + ] + + +def test_extracts_claude_preloaded_state_pairs(): + state = { + "chat": { + "messages": [ + {"sender": "human", "text": "Summarize this repo."}, + {"sender": "assistant", "text": "It stores memories for agents."}, + ] + } + } + html = ( + "" + ) + + provider, method, pairs = _extract_chat_pairs("https://claude.ai/share/abc", html) + + assert provider == "claude" + assert method == "structured" + assert pairs == [ + MessagePair( + user_query="Summarize this repo.", + agent_response="It stores memories for agents.", + ) + ] + + +def test_extracts_claude_current_public_share_dom_pairs(): + html = """ +
+

test test

+
+
+
+
+

Hey! I'm here and working.

+
+
+
+ """ + + provider, method, pairs = _extract_chat_pairs("https://claude.ai/share/abc", html) + + assert provider == "claude" + assert method == "dom" + assert pairs == [ + MessagePair( + user_query="test test", + agent_response="Hey! I'm here and working.", + ) + ] + + +def test_extracts_gemini_dom_pairs(): + html = """ + Compare memory tools. + XMem focuses on persistent agent memory. + """ + + provider, method, pairs = _extract_chat_pairs( + "https://gemini.google.com/share/abc", + html, + ) + + assert provider == "gemini" + assert method == "dom" + assert pairs == [ + MessagePair( + user_query="Compare memory tools.", + agent_response="XMem focuses on persistent agent memory.", + ) + ] + + +def test_extracts_gemini_current_public_share_dom_pairs(): + html = """ +
+ You said +

Test test

+
+ + +
+

Loud and clear! I'm here and ready to roll.

+
+
+
+ """ + + provider, method, pairs = _extract_chat_pairs( + "https://gemini.google.com/share/abc", + html, + ) + + assert provider == "gemini" + assert method == "dom" + assert pairs == [ + MessagePair( + user_query="Test test", + agent_response="Loud and clear! I'm here and ready to roll.", + ) + ] + + +def test_scrape_failure_message_names_private_or_missing_provider_links(): + message = scrape_failure_message({"provider": "claude"}) + + assert "Claude share link" in message + assert "public" in message + assert "expired" in message + + +def test_scrape_failure_message_lists_supported_unknown_links(): + message = scrape_failure_message({"provider": "unknown"}) + + assert "Supported public share links" in message + assert "ChatGPT" in message + assert "Claude" in message + assert "Gemini" in message + + +async def test_scrape_route_failure_uses_elapsed_ms(monkeypatch): + async def fake_scrape(url: str): + return {"provider": "gemini", "pairs": []} + + ticks = iter([10.0, 10.12345]) + + monkeypatch.setattr("src.api.routes.memory._scrape_chat_share", fake_scrape) + monkeypatch.setattr("src.api.routes.memory.time.perf_counter", lambda: next(ticks)) + + response = await scrape_chat_link( + ScrapeRequest(url="https://gemini.google.com/share/abc"), + SimpleNamespace(state=SimpleNamespace(request_id="req-test")), + ) + body = json.loads(response.body) + + assert response.status_code == 400 + assert body["elapsed_ms"] == 123.45 + assert "Gemini share link" in body["error"]