From 47456d83d82f35b1a411a6d76a4db17556b0dcbc Mon Sep 17 00:00:00 2001
From: Adam Gohain <68021524+akgohain@users.noreply.github.com>
Date: Mon, 13 Apr 2026 16:05:58 -0400
Subject: [PATCH] Add workflow export bundle endpoint and tests

---
Reviewer notes (this section is ignored by `git am`):
- server_api/workflow/__init__.py now re-exports the `router` submodule rather
  than the APIRouter object, so `workflow_router.router` in main.py resolves
  the same way `ehtool_router.router` does.
- service.py treats naive timestamps as UTC so mixed naive/aware events can be
  sorted, rewrites only a trailing "Z", and tolerates non-string dict keys and
  a bare-string `artifact_paths` value.
- Indentation and the blank context lines in the server_api/main.py hunks were
  re-derived from a whitespace-collapsed copy of this patch; verify them
  against the tree before applying.
- Post-image blob ids on the `index` lines predate this revision.

 server_api/main.py                   |   2 +
 server_api/workflow/__init__.py      |   3 +
 server_api/workflow/models.py        |  27 +++++
 server_api/workflow/router.py        |  25 +++++
 server_api/workflow/service.py       |  94 ++++++++++++++++++
 tests/test_workflow_export_bundle.py | 123 ++++++++++++++++++++++++
 6 files changed, 274 insertions(+)
 create mode 100644 server_api/workflow/__init__.py
 create mode 100644 server_api/workflow/models.py
 create mode 100644 server_api/workflow/router.py
 create mode 100644 server_api/workflow/service.py
 create mode 100644 tests/test_workflow_export_bundle.py

diff --git a/server_api/main.py b/server_api/main.py
index a97153f..3666ec5 100644
--- a/server_api/main.py
+++ b/server_api/main.py
@@ -21,6 +21,7 @@ from server_api.auth.database import get_db
 
 from server_api.auth.router import get_current_user
 from server_api.ehtool import router as ehtool_router
+from server_api.workflow import router as workflow_router
 from fastapi.staticfiles import StaticFiles
 import os
 
@@ -77,6 +78,7 @@ def _ensure_chatbot():
 
 app.include_router(auth_router.router)
 app.include_router(ehtool_router.router, prefix="/eh", tags=["ehtool"])
+app.include_router(workflow_router.router)
 
 app.add_middleware(
     CORSMiddleware,
diff --git a/server_api/workflow/__init__.py b/server_api/workflow/__init__.py
new file mode 100644
index 0000000..5bc0c2e
--- /dev/null
+++ b/server_api/workflow/__init__.py
@@ -0,0 +1,3 @@
+from . import router
+
+__all__ = ["router"]
diff --git a/server_api/workflow/models.py b/server_api/workflow/models.py
new file mode 100644
index 0000000..b8b5611
--- /dev/null
+++ b/server_api/workflow/models.py
@@ -0,0 +1,27 @@
+from __future__ import annotations
+
+from datetime import datetime
+from typing import Any, Dict, List
+
+from pydantic import BaseModel, Field
+
+
+class WorkflowArtifactEntry(BaseModel):
+    path: str
+    exists: bool
+
+
+class WorkflowEvent(BaseModel):
+    id: str
+    type: str
+    timestamp: datetime
+    payload: Dict[str, Any] = Field(default_factory=dict)
+
+
+class WorkflowExportBundle(BaseModel):
+    schema_version: str
+    exported_at: datetime
+    workflow_id: int
+    session_snapshot: Dict[str, Any]
+    events: List[WorkflowEvent]
+    artifact_paths: List[WorkflowArtifactEntry]
diff --git a/server_api/workflow/router.py b/server_api/workflow/router.py
new file mode 100644
index 0000000..8e4db10
--- /dev/null
+++ b/server_api/workflow/router.py
@@ -0,0 +1,25 @@
+from __future__ import annotations
+
+from typing import Any, Dict, Optional
+
+from fastapi import APIRouter, HTTPException
+
+from .models import WorkflowExportBundle
+from .service import build_export_bundle
+
+router = APIRouter()
+
+
+# Placeholder repository interface. Tests patch this for deterministic fixtures.
+def get_workflow_export_record(workflow_id: int) -> Optional[Dict[str, Any]]:
+    return None
+
+
+@router.post("/api/workflows/{workflow_id}/export-bundle", response_model=WorkflowExportBundle)
+def export_workflow_bundle(workflow_id: int):
+    record = get_workflow_export_record(workflow_id)
+    if record is None:
+        raise HTTPException(status_code=404, detail=f"Workflow {workflow_id} not found")
+
+    bundle = build_export_bundle(workflow_id, record)
+    return bundle
diff --git a/server_api/workflow/service.py b/server_api/workflow/service.py
new file mode 100644
index 0000000..e56a612
--- /dev/null
+++ b/server_api/workflow/service.py
@@ -0,0 +1,94 @@
+from __future__ import annotations
+
+import os
+from datetime import datetime, timezone
+from typing import Any, Dict, Iterable, List, Tuple
+
+
+def _parse_timestamp(value: str) -> datetime:
+    """Parse an ISO-8601 timestamp into a timezone-aware datetime.
+
+    A trailing ``Z`` is rewritten to ``+00:00`` because
+    ``datetime.fromisoformat`` only accepts it on Python 3.11+. Naive values
+    are assumed to be UTC so that every result can be compared with any other.
+
+    Raises:
+        ValueError: If ``value`` is not a valid ISO-8601 timestamp.
+    """
+    if value.endswith("Z"):
+        # Rewrite the suffix only; str.replace would touch every "Z".
+        value = value[:-1] + "+00:00"
+    parsed = datetime.fromisoformat(value)
+    if parsed.tzinfo is None:
+        # Mixing naive and aware datetimes makes sorted() raise TypeError.
+        parsed = parsed.replace(tzinfo=timezone.utc)
+    return parsed
+
+
+def _event_sort_key(event: Dict[str, Any]) -> Tuple[datetime, str]:
+    """Return the ``(timestamp, id)`` key used to order events.
+
+    Events without a timestamp sort first (Unix epoch); ties fall back to id.
+    """
+    timestamp = str(event.get("timestamp") or "1970-01-01T00:00:00+00:00")
+    event_id = str(event.get("id") or "")
+    return (_parse_timestamp(timestamp), event_id)
+
+
+def _collect_paths(value: Any) -> Iterable[str]:
+    """Yield every string stored under a ``path`` or ``*_path`` key.
+
+    Dicts and lists are walked recursively; any other value yields nothing.
+    """
+    if isinstance(value, dict):
+        for key, inner in value.items():
+            if not isinstance(inner, str):
+                yield from _collect_paths(inner)
+            elif isinstance(key, str) and (key == "path" or key.endswith("_path")):
+                # Non-string keys are skipped rather than crashing on .endswith.
+                yield inner
+    elif isinstance(value, list):
+        for item in value:
+            yield from _collect_paths(item)
+
+
+def build_export_bundle(workflow_id: int, record: Dict[str, Any]) -> Dict[str, Any]:
+    """Assemble the export bundle payload for one workflow record.
+
+    Args:
+        workflow_id: Identifier echoed back in the bundle.
+        record: Stored workflow data; ``session_snapshot``, ``events`` and the
+            optional ``artifact_paths`` keys are read.
+
+    Returns:
+        A dict shaped like ``WorkflowExportBundle``: events ordered by
+        ``(timestamp, id)`` and one ``{"path", "exists"}`` entry per unique
+        artifact path, sorted by path.
+    """
+    session_snapshot = dict(record.get("session_snapshot") or {})
+    events = list(record.get("events") or [])
+    ordered_events = sorted(events, key=_event_sort_key)
+
+    explicit_paths = record.get("artifact_paths")
+    if explicit_paths is None:
+        # No explicit list: fall back to paths referenced by snapshot/events.
+        discovered = set(_collect_paths(session_snapshot))
+        for event in ordered_events:
+            discovered.update(_collect_paths(event))
+        artifact_paths = sorted(path for path in discovered if path)
+    else:
+        if isinstance(explicit_paths, str):
+            # A bare string would otherwise be split into single characters.
+            explicit_paths = [explicit_paths]
+        artifact_paths = sorted({str(path) for path in explicit_paths if path})
+
+    artifacts = [{"path": path, "exists": os.path.exists(path)} for path in artifact_paths]
+
+    return {
+        "schema_version": "workflow-export-bundle/v1",
+        "exported_at": datetime.now(timezone.utc).isoformat(),
+        "workflow_id": workflow_id,
+        "session_snapshot": session_snapshot,
+        "events": ordered_events,
+        "artifact_paths": artifacts,
+    }
diff --git a/tests/test_workflow_export_bundle.py b/tests/test_workflow_export_bundle.py
new file mode 100644
index 0000000..b092856
--- /dev/null
+++ b/tests/test_workflow_export_bundle.py
@@ -0,0 +1,123 @@
+import pathlib
+import tempfile
+import unittest
+from unittest.mock import patch
+
+from fastapi import FastAPI
+from fastapi.testclient import TestClient
+
+import importlib
+
+workflow_router_module = importlib.import_module("server_api.workflow.router")
+
+
+class WorkflowExportBundleTests(unittest.TestCase):
+    def setUp(self):
+        app = FastAPI()
+        app.include_router(workflow_router_module.router)
+        self.client = TestClient(app)
+
+    def test_export_bundle_happy_path_sorts_events_and_sets_file_flags(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            existing = pathlib.Path(tmpdir) / "existing.zarr"
+            existing.write_text("ok")
+            missing = pathlib.Path(tmpdir) / "missing.zarr"
+
+            record = {
+                "session_snapshot": {
+                    "id": 7,
+                    "name": "proofreading",
+                    "primary_artifact_path": str(existing),
+                },
+                "events": [
+                    {
+                        "id": "evt-2",
+                        "type": "annotation",
+                        "timestamp": "2026-04-10T10:00:00+00:00",
+                        "payload": {"artifact": {"path": str(missing)}},
+                    },
+                    {
+                        "id": "evt-1",
+                        "type": "start",
+                        "timestamp": "2026-04-10T09:00:00+00:00",
+                        "payload": {},
+                    },
+                ],
+            }
+
+            with patch.object(
+                workflow_router_module,
+                "get_workflow_export_record",
+                return_value=record,
+            ):
+                response = self.client.post("/api/workflows/7/export-bundle")
+
+            self.assertEqual(response.status_code, 200)
+            data = response.json()
+
+            self.assertEqual(data["schema_version"], "workflow-export-bundle/v1")
+            self.assertEqual(data["workflow_id"], 7)
+            self.assertEqual([e["id"] for e in data["events"]], ["evt-1", "evt-2"])
+
+            artifacts = {item["path"]: item["exists"] for item in data["artifact_paths"]}
+            self.assertTrue(artifacts[str(existing)])
+            self.assertFalse(artifacts[str(missing)])
+
+    def test_export_bundle_uses_explicit_artifact_paths_and_missing_is_safe(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            existing = pathlib.Path(tmpdir) / "proofread.tif"
+            existing.write_text("ok")
+            missing = pathlib.Path(tmpdir) / "not_here.tif"
+
+            record = {
+                "session_snapshot": {"id": 10, "name": "workflow"},
+                "events": [],
+                "artifact_paths": [str(missing), str(existing)],
+            }
+
+            with patch.object(
+                workflow_router_module,
+                "get_workflow_export_record",
+                return_value=record,
+            ):
+                response = self.client.post("/api/workflows/10/export-bundle")
+
+            self.assertEqual(response.status_code, 200)
+            data = response.json()
+            self.assertEqual([e for e in data["events"]], [])
+
+            self.assertEqual(
+                [entry["path"] for entry in data["artifact_paths"]],
+                sorted([str(existing), str(missing)]),
+            )
+            exists_flags = {entry["path"]: entry["exists"] for entry in data["artifact_paths"]}
+            self.assertTrue(exists_flags[str(existing)])
+            self.assertFalse(exists_flags[str(missing)])
+
+    def test_export_bundle_orders_mixed_naive_and_aware_timestamps(self):
+        record = {
+            "session_snapshot": {"id": 11, "name": "mixed"},
+            "events": [
+                {"id": "evt-c", "type": "end", "timestamp": "2026-04-10T10:00:00+00:00"},
+                {"id": "evt-b", "type": "edit", "timestamp": "2026-04-10T09:00:00"},
+                {"id": "evt-a", "type": "start", "timestamp": "2026-04-10T08:00:00Z"},
+            ],
+            "artifact_paths": [],
+        }
+
+        with patch.object(
+            workflow_router_module,
+            "get_workflow_export_record",
+            return_value=record,
+        ):
+            response = self.client.post("/api/workflows/11/export-bundle")
+
+        self.assertEqual(response.status_code, 200)
+        self.assertEqual(
+            [e["id"] for e in response.json()["events"]],
+            ["evt-a", "evt-b", "evt-c"],
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()