From 0ba1f29ea00bce578b869db0dc54dd84597c8272 Mon Sep 17 00:00:00 2001 From: Adam Gohain Date: Sun, 12 Apr 2026 20:57:52 -0400 Subject: [PATCH 1/6] Add backend workflow spine (cherry picked from commit 2b48cf5d5cc90aa95762ffd5f9c7de020c60065e) --- pytest.ini | 8 + server_api/ehtool/db_models.py | 1 + server_api/ehtool/models.py | 1 + server_api/ehtool/router.py | 97 +++++++ server_api/main.py | 131 ++++++++- server_api/workflows/__init__.py | 2 + server_api/workflows/db_models.py | 55 ++++ server_api/workflows/router.py | 439 ++++++++++++++++++++++++++++++ server_api/workflows/service.py | 246 +++++++++++++++++ tests/test_workflow_routes.py | 227 +++++++++++++++ 10 files changed, 1204 insertions(+), 3 deletions(-) create mode 100644 pytest.ini create mode 100644 server_api/workflows/__init__.py create mode 100644 server_api/workflows/db_models.py create mode 100644 server_api/workflows/router.py create mode 100644 server_api/workflows/service.py create mode 100644 tests/test_workflow_routes.py diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..081f5cd --- /dev/null +++ b/pytest.ini @@ -0,0 +1,8 @@ +[pytest] +testpaths = tests +pythonpath = . +norecursedirs = + .git + .venv + client + pytorch_connectomics diff --git a/server_api/ehtool/db_models.py b/server_api/ehtool/db_models.py index 475d465..38daab4 100644 --- a/server_api/ehtool/db_models.py +++ b/server_api/ehtool/db_models.py @@ -19,6 +19,7 @@ class EHToolSession(Base): dataset_path = Column(String) # Path to uploaded dataset mask_path = Column(String, nullable=True) # Path to mask dataset (optional) total_layers = Column(Integer, default=0) + workflow_id = Column(Integer, nullable=True, index=True) created_at = Column(DateTime(timezone=True), server_default=func.now()) updated_at = Column(DateTime(timezone=True), onupdate=func.now()) diff --git a/server_api/ehtool/models.py b/server_api/ehtool/models.py index d1b88c4..d1ec7ac 100644 --- a/server_api/ehtool/models.py +++ b/server_api/ehtool/models.py @@ -15,6 +15,7 @@ class DetectionLoadRequest(BaseModel): dataset_path: str mask_path: Optional[str] = None project_name: str = "Untitled Project" + workflow_id: Optional[int] = None @field_validator("dataset_path", "mask_path") @classmethod diff --git a/server_api/ehtool/router.py b/server_api/ehtool/router.py index 2a93923..567ce39 100644 --- a/server_api/ehtool/router.py +++ b/server_api/ehtool/router.py @@ -36,6 +36,12 @@ from .db_models import EHToolSession, EHToolLayer from .data_manager import DataManager from .utils import array_to_base64, glasbey_color +from server_api.workflows.service import ( + append_event_for_workflow_if_present, + append_workflow_event, + get_user_workflow_or_404, + update_workflow_fields, +) router = APIRouter() @@ -86,6 +92,12 @@ async def load_detection_dataset( db: Session = Depends(get_db), ): try: + workflow = None + if request.workflow_id: + workflow = get_user_workflow_or_404( + db, workflow_id=request.workflow_id, user_id=current_user.id + ) + # Create DataManager and load dataset data_manager = DataManager() dataset_info = data_manager.load_dataset( @@ -101,6 +113,7 @@ async def load_detection_dataset( dataset_path=request.dataset_path, mask_path=request.mask_path, total_layers=dataset_info["total_layers"], + workflow_id=request.workflow_id, ) db.add(db_session) db.commit() @@ -122,6 +135,47 @@ async def load_detection_dataset( # Cache DataManager _data_managers[db_session.id] = data_manager + if workflow: + update_workflow_fields( + db, + workflow, + { + "stage": "proofreading", + "title": request.project_name or workflow.title, + "dataset_path": request.dataset_path, + "image_path": request.dataset_path, + "mask_path": request.mask_path, + "proofreading_session_id": db_session.id, + }, + commit=True, + ) + append_workflow_event( + db, + workflow_id=workflow.id, + actor="user", + event_type="dataset.loaded", + stage="proofreading", + summary=f"Loaded dataset for proofreading: {request.project_name}", + payload={ + "dataset_path": request.dataset_path, + "mask_path": request.mask_path, + "total_layers": dataset_info["total_layers"], + "ehtool_session_id": db_session.id, + }, + ) + append_workflow_event( + db, + workflow_id=workflow.id, + actor="system", + event_type="proofreading.session_loaded", + stage="proofreading", + summary="Mask proofreading session linked to workflow.", + payload={ + "ehtool_session_id": db_session.id, + "project_name": request.project_name, + }, + ) + return DetectionLoadResponse( session_id=db_session.id, total_layers=dataset_info["total_layers"], @@ -613,6 +667,20 @@ async def classify_instances( ui_state = request.ui_state.dict() if request.ui_state else None data_manager.save_progress(ui_state=ui_state) + append_event_for_workflow_if_present( + db, + workflow_id=db_session.workflow_id, + actor="user", + event_type="proofreading.instance_classified", + stage="proofreading", + summary=f"Classified {updated} instance(s) as {request.classification}.", + payload={ + "ehtool_session_id": request.session_id, + "instance_ids": request.instance_ids, + "classification": request.classification, + "updated_count": updated, + }, + ) return ClassifyResponse( updated_count=updated, @@ -754,6 +822,20 @@ async def save_instance_mask( ) if request.ui_state: data_manager.save_progress(ui_state=request.ui_state.dict()) + append_event_for_workflow_if_present( + db, + workflow_id=db_session.workflow_id, + actor="user", + event_type="proofreading.mask_saved", + stage="proofreading", + summary=f"Saved mask edit for instance {request.instance_id}.", + payload={ + "ehtool_session_id": request.session_id, + "instance_id": request.instance_id, + "axis": request.axis, + "z_index": request.z_index, + }, + ) return {"message": "Instance mask saved successfully"} except ValueError as exc: raise HTTPException(status_code=400, detail=str(exc)) from exc @@ -794,6 +876,21 @@ async def export_masks( output_path=request.output_path, create_backup=request.create_backup, ) + append_event_for_workflow_if_present( + db, + workflow_id=db_session.workflow_id, + actor="user", + event_type="proofreading.masks_exported", + stage="proofreading", + summary=f"Exported edited masks to {result['written_path']}.", + payload={ + "ehtool_session_id": request.session_id, + "mode": request.mode, + "written_path": result["written_path"], + "backup_path": result.get("backup_path"), + "timestamp": result["timestamp"], + }, + ) return ExportMasksResponse( message=result["message"], written_path=result["written_path"], diff --git a/server_api/main.py b/server_api/main.py index a97153f..8c33c79 100644 --- a/server_api/main.py +++ b/server_api/main.py @@ -10,6 +10,7 @@ import uvicorn from fastapi import Depends, FastAPI, HTTPException, Request, UploadFile from fastapi.middleware.cors import CORSMiddleware +from sqlalchemy import inspect, text from sqlalchemy.orm import Session from runtime_settings import ( get_allowed_origins, @@ -21,6 +22,12 @@ from server_api.auth.database import get_db from server_api.auth.router import get_current_user from server_api.ehtool import router as ehtool_router +from server_api.workflows import router as workflow_router +from server_api.workflows.service import ( + append_event_for_workflow_if_present, + get_user_workflow_or_404, + update_workflow_fields, +) from fastapi.staticfiles import StaticFiles import os @@ -69,6 +76,22 @@ def _ensure_chatbot(): models.Base.metadata.create_all(bind=database.engine) + +def _ensure_sqlite_column(table_name: str, column_name: str, ddl: str) -> None: + if database.engine.dialect.name != "sqlite": + return + inspector = inspect(database.engine) + if table_name not in inspector.get_table_names(): + return + existing = {column["name"] for column in inspector.get_columns(table_name)} + if column_name in existing: + return + with database.engine.begin() as connection: + connection.execute(text(f"ALTER TABLE {table_name} ADD COLUMN {ddl}")) + + +_ensure_sqlite_column("ehtool_sessions", "workflow_id", "workflow_id INTEGER") + app = FastAPI() # Ensure uploads directory exists @@ -77,6 +100,7 @@ def _ensure_chatbot(): app.include_router(auth_router.router) app.include_router(ehtool_router.router, prefix="/eh", tags=["ehtool"]) +app.include_router(workflow_router.router, prefix="/api/workflows", tags=["workflows"]) app.add_middleware( CORSMiddleware, @@ -363,7 +387,11 @@ def _is_probable_label_volume(image_array) -> bool: @app.post("/neuroglancer") -async def neuroglancer(req: Request): +async def neuroglancer( + req: Request, + current_user: models.User = Depends(get_current_user), + db: Session = Depends(get_db), +): import neuroglancer cleanup_paths: List[pathlib.Path] = [] @@ -383,6 +411,8 @@ async def neuroglancer(req: Request): raise HTTPException( status_code=400, detail="Scales payload is invalid." ) + workflow_id_raw = form.get("workflow_id") + workflow_id = int(workflow_id_raw) if workflow_id_raw else None image = save_upload_to_tempfile(image_upload) cleanup_paths.append(image) @@ -397,6 +427,7 @@ async def neuroglancer(req: Request): image = process_path(payload["image"]) label = process_path(payload.get("label")) scales = payload["scales"] + workflow_id = payload.get("workflow_id") or payload.get("workflowId") print(image, label, scales) @@ -433,6 +464,35 @@ def ngLayer(data, res, oo=[0, 0, 0], tt="segmentation"): s.layers.append(name="gt", layer=ngLayer(gt, res, tt="segmentation")) public_url = _build_neuroglancer_public_url(str(viewer), req) + if workflow_id: + workflow = get_user_workflow_or_404( + db, workflow_id=int(workflow_id), user_id=current_user.id + ) + update_workflow_fields( + db, + workflow, + { + "stage": "visualization", + "image_path": str(image), + "label_path": str(label) if label else None, + "neuroglancer_url": public_url, + }, + commit=True, + ) + append_event_for_workflow_if_present( + db, + workflow_id=workflow.id, + actor="user", + event_type="viewer.created", + stage="visualization", + summary="Created Neuroglancer viewer.", + payload={ + "image_path": str(image), + "label_path": str(label) if label else None, + "scales": scales, + "neuroglancer_url": public_url, + }, + ) print(public_url) return public_url finally: @@ -446,8 +506,39 @@ def ngLayer(data, res, oo=[0, 0, 0], tt="segmentation"): @app.post("/start_model_training") -async def start_model_training(req: Request): +async def start_model_training( + req: Request, + current_user: models.User = Depends(get_current_user), + db: Session = Depends(get_db), +): body = await req.json() + workflow_id = body.get("workflow_id") or body.get("workflowId") + if workflow_id: + workflow = get_user_workflow_or_404( + db, workflow_id=int(workflow_id), user_id=current_user.id + ) + update_workflow_fields( + db, + workflow, + { + "stage": "retraining_staged", + "training_output_path": body.get("outputPath"), + }, + commit=True, + ) + append_event_for_workflow_if_present( + db, + workflow_id=workflow.id, + actor="user", + event_type="training.started", + stage=workflow.stage, + summary="Started model training from the workflow.", + payload={ + "outputPath": body.get("outputPath"), + "logPath": body.get("logPath"), + "configOriginPath": body.get("configOriginPath"), + }, + ) worker_data = _proxy_to_worker( "post", "/start_model_training", @@ -481,8 +572,42 @@ async def get_training_logs(): @app.post("/start_model_inference") -async def start_model_inference(req: Request): +async def start_model_inference( + req: Request, + current_user: models.User = Depends(get_current_user), + db: Session = Depends(get_db), +): body = await req.json() + workflow_id = body.get("workflow_id") or body.get("workflowId") + if workflow_id: + workflow = get_user_workflow_or_404( + db, workflow_id=int(workflow_id), user_id=current_user.id + ) + update_workflow_fields( + db, + workflow, + { + "stage": "inference", + "inference_output_path": body.get("outputPath"), + "checkpoint_path": (body.get("arguments") or {}).get("checkpoint") + or body.get("checkpointPath"), + }, + commit=True, + ) + append_event_for_workflow_if_present( + db, + workflow_id=workflow.id, + actor="user", + event_type="inference.started", + stage="inference", + summary="Started model inference from the workflow.", + payload={ + "outputPath": body.get("outputPath"), + "checkpointPath": (body.get("arguments") or {}).get("checkpoint") + or body.get("checkpointPath"), + "configOriginPath": body.get("configOriginPath"), + }, + ) worker_data = _proxy_to_worker( "post", "/start_model_inference", diff --git a/server_api/workflows/__init__.py b/server_api/workflows/__init__.py new file mode 100644 index 0000000..d03ab02 --- /dev/null +++ b/server_api/workflows/__init__.py @@ -0,0 +1,2 @@ +"""Workflow spine package for iterative segmentation sessions.""" + diff --git a/server_api/workflows/db_models.py b/server_api/workflows/db_models.py new file mode 100644 index 0000000..f7d2b56 --- /dev/null +++ b/server_api/workflows/db_models.py @@ -0,0 +1,55 @@ +from sqlalchemy import Column, DateTime, ForeignKey, Integer, String, Text +from sqlalchemy.orm import relationship +from sqlalchemy.sql import func + +from server_api.auth.database import Base + + +class WorkflowSession(Base): + __tablename__ = "workflow_sessions" + + id = Column(Integer, primary_key=True, index=True) + user_id = Column(Integer, ForeignKey("users.id"), nullable=False, index=True) + title = Column(String, default="Segmentation Workflow") + stage = Column(String, default="setup", index=True) + dataset_path = Column(String, nullable=True) + image_path = Column(String, nullable=True) + label_path = Column(String, nullable=True) + mask_path = Column(String, nullable=True) + neuroglancer_url = Column(Text, nullable=True) + inference_output_path = Column(String, nullable=True) + checkpoint_path = Column(String, nullable=True) + proofreading_session_id = Column(Integer, nullable=True, index=True) + corrected_mask_path = Column(String, nullable=True) + training_output_path = Column(String, nullable=True) + metadata_json = Column(Text, nullable=True) + created_at = Column(DateTime(timezone=True), server_default=func.now()) + updated_at = Column( + DateTime(timezone=True), server_default=func.now(), onupdate=func.now() + ) + + events = relationship( + "WorkflowEvent", + back_populates="workflow", + cascade="all, delete-orphan", + order_by="WorkflowEvent.created_at", + ) + + +class WorkflowEvent(Base): + __tablename__ = "workflow_events" + + id = Column(Integer, primary_key=True, index=True) + workflow_id = Column( + Integer, ForeignKey("workflow_sessions.id"), nullable=False, index=True + ) + actor = Column(String, default="system", index=True) + event_type = Column(String, nullable=False, index=True) + stage = Column(String, nullable=True, index=True) + summary = Column(Text, nullable=False) + payload_json = Column(Text, nullable=True) + approval_status = Column(String, default="not_required", index=True) + created_at = Column(DateTime(timezone=True), server_default=func.now()) + + workflow = relationship("WorkflowSession", back_populates="events") + diff --git a/server_api/workflows/router.py b/server_api/workflows/router.py new file mode 100644 index 0000000..6422502 --- /dev/null +++ b/server_api/workflows/router.py @@ -0,0 +1,439 @@ +from __future__ import annotations + +from typing import Any, Dict, List, Optional + +from fastapi import APIRouter, Depends, HTTPException +from pydantic import BaseModel, Field +from sqlalchemy.orm import Session + +from server_api.auth import models as auth_models +from server_api.auth.database import get_db +from server_api.auth.router import get_current_user + +from .db_models import WorkflowEvent, WorkflowSession +from .service import ( + append_workflow_event, + decode_json, + event_to_dict, + get_current_or_create_workflow, + get_user_workflow_or_404, + update_workflow_fields, + validate_stage, + workflow_to_dict, +) + +router = APIRouter() + + +class WorkflowResponse(BaseModel): + id: int + user_id: int + title: Optional[str] = None + stage: str + dataset_path: Optional[str] = None + image_path: Optional[str] = None + label_path: Optional[str] = None + mask_path: Optional[str] = None + neuroglancer_url: Optional[str] = None + inference_output_path: Optional[str] = None + checkpoint_path: Optional[str] = None + proofreading_session_id: Optional[int] = None + corrected_mask_path: Optional[str] = None + training_output_path: Optional[str] = None + metadata_json: Optional[str] = None + metadata: Dict[str, Any] = Field(default_factory=dict) + created_at: Any + updated_at: Any + + +class WorkflowEventResponse(BaseModel): + id: int + workflow_id: int + actor: str + event_type: str + stage: Optional[str] = None + summary: str + payload_json: Optional[str] = None + payload: Dict[str, Any] = Field(default_factory=dict) + approval_status: str + created_at: Any + + +class WorkflowDetailResponse(BaseModel): + workflow: WorkflowResponse + events: List[WorkflowEventResponse] + + +class WorkflowUpdateRequest(BaseModel): + title: Optional[str] = None + stage: Optional[str] = None + dataset_path: Optional[str] = None + image_path: Optional[str] = None + label_path: Optional[str] = None + mask_path: Optional[str] = None + neuroglancer_url: Optional[str] = None + inference_output_path: Optional[str] = None + checkpoint_path: Optional[str] = None + proofreading_session_id: Optional[int] = None + corrected_mask_path: Optional[str] = None + training_output_path: Optional[str] = None + metadata: Optional[Dict[str, Any]] = None + + +class WorkflowEventCreateRequest(BaseModel): + actor: str = "system" + event_type: str + stage: Optional[str] = None + summary: str + payload: Optional[Dict[str, Any]] = None + approval_status: str = "not_required" + + +class AgentActionCreateRequest(BaseModel): + action: str + summary: Optional[str] = None + payload: Dict[str, Any] = Field(default_factory=dict) + + +class AgentQueryRequest(BaseModel): + query: str + + +class AgentQueryResponse(BaseModel): + response: str + proposals: List[WorkflowEventResponse] = Field(default_factory=list) + + +class AgentActionResult(BaseModel): + workflow: WorkflowResponse + proposal: WorkflowEventResponse + events: List[WorkflowEventResponse] + client_effects: Dict[str, Any] = Field(default_factory=dict) + + +def _workflow_response(workflow: WorkflowSession) -> WorkflowResponse: + return WorkflowResponse(**workflow_to_dict(workflow)) + + +def _event_response(event: WorkflowEvent) -> WorkflowEventResponse: + return WorkflowEventResponse(**event_to_dict(event)) + + +def _event_list(db: Session, workflow_id: int) -> List[WorkflowEventResponse]: + events = ( + db.query(WorkflowEvent) + .filter(WorkflowEvent.workflow_id == workflow_id) + .order_by(WorkflowEvent.created_at.asc(), WorkflowEvent.id.asc()) + .all() + ) + return [_event_response(event) for event in events] + + +def _get_pending_proposal_or_404( + db: Session, *, workflow_id: int, event_id: int +) -> WorkflowEvent: + event = ( + db.query(WorkflowEvent) + .filter( + WorkflowEvent.id == event_id, + WorkflowEvent.workflow_id == workflow_id, + WorkflowEvent.event_type == "agent.proposal_created", + ) + .first() + ) + if not event: + raise HTTPException(status_code=404, detail="Agent proposal not found") + if event.approval_status != "pending": + raise HTTPException(status_code=400, detail="Agent proposal is not pending") + return event + + +def _latest_exported_mask_path(db: Session, workflow_id: int) -> Optional[str]: + event = ( + db.query(WorkflowEvent) + .filter( + WorkflowEvent.workflow_id == workflow_id, + WorkflowEvent.event_type == "proofreading.masks_exported", + ) + .order_by(WorkflowEvent.created_at.desc(), WorkflowEvent.id.desc()) + .first() + ) + if not event: + return None + payload = decode_json(event.payload_json) + return payload.get("written_path") or payload.get("output_path") + + +def _proposal_action_payload(proposal: WorkflowEvent) -> Dict[str, Any]: + payload = decode_json(proposal.payload_json) + params = payload.get("params") + if not isinstance(params, dict): + params = {} + return {"action": payload.get("action"), "params": params} + + +def _recommendation_for_workflow( + workflow: WorkflowSession, + events: List[WorkflowEventResponse], +) -> str: + if workflow.stage == "setup": + return "Start by loading an image volume and, if available, the current mask or label volume." + if workflow.stage == "visualization": + return "The next useful step is to run inference or open proofreading on the current result." + if workflow.stage == "inference": + return "Review the inference output and send likely failure regions into proofreading." + if workflow.stage == "proofreading": + has_export = any( + event.event_type == "proofreading.masks_exported" for event in events + ) + if has_export or workflow.corrected_mask_path: + return "Corrected masks are available. Stage them for retraining so the next model iteration is linked to the edits." + return "Continue classifying instances and save or export corrected masks before retraining." + if workflow.stage == "retraining_staged": + return "The corrected masks are staged. Review the training configuration before launching retraining." + return "Review the workflow timeline and compare results before starting another iteration." + + +@router.get("/current", response_model=WorkflowDetailResponse) +def get_current_workflow( + user: auth_models.User = Depends(get_current_user), + db: Session = Depends(get_db), +): + workflow = get_current_or_create_workflow(db, user_id=user.id) + return { + "workflow": _workflow_response(workflow), + "events": _event_list(db, workflow.id), + } + + +@router.patch("/{workflow_id}", response_model=WorkflowResponse) +async def update_workflow( + workflow_id: int, + body: WorkflowUpdateRequest, + user: auth_models.User = Depends(get_current_user), + db: Session = Depends(get_db), +): + workflow = get_user_workflow_or_404(db, workflow_id=workflow_id, user_id=user.id) + updates = body.model_dump(exclude_unset=True) + workflow = update_workflow_fields(db, workflow, updates, commit=True) + return _workflow_response(workflow) + + +@router.get("/{workflow_id}/events", response_model=List[WorkflowEventResponse]) +def list_workflow_events( + workflow_id: int, + user: auth_models.User = Depends(get_current_user), + db: Session = Depends(get_db), +): + get_user_workflow_or_404(db, workflow_id=workflow_id, user_id=user.id) + return _event_list(db, workflow_id) + + +@router.post("/{workflow_id}/events", response_model=WorkflowEventResponse) +async def create_workflow_event( + workflow_id: int, + body: WorkflowEventCreateRequest, + user: auth_models.User = Depends(get_current_user), + db: Session = Depends(get_db), +): + workflow = get_user_workflow_or_404(db, workflow_id=workflow_id, user_id=user.id) + stage = validate_stage(body.stage or workflow.stage) + event = append_workflow_event( + db, + workflow_id=workflow.id, + actor=body.actor, + event_type=body.event_type, + stage=stage, + summary=body.summary, + payload=body.payload, + approval_status=body.approval_status, + commit=True, + ) + return _event_response(event) + + +@router.post("/{workflow_id}/agent-actions", response_model=WorkflowEventResponse) +async def create_agent_action( + workflow_id: int, + body: AgentActionCreateRequest, + user: auth_models.User = Depends(get_current_user), + db: Session = Depends(get_db), +): + workflow = get_user_workflow_or_404(db, workflow_id=workflow_id, user_id=user.id) + summary = body.summary or f"Agent proposed: {body.action}" + event = append_workflow_event( + db, + workflow_id=workflow.id, + actor="agent", + event_type="agent.proposal_created", + stage=workflow.stage, + summary=summary, + payload={"action": body.action, "params": body.payload}, + approval_status="pending", + commit=True, + ) + return _event_response(event) + + +@router.post( + "/{workflow_id}/agent-actions/{event_id}/approve", + response_model=AgentActionResult, +) +async def approve_agent_action( + workflow_id: int, + event_id: int, + user: auth_models.User = Depends(get_current_user), + db: Session = Depends(get_db), +): + workflow = get_user_workflow_or_404(db, workflow_id=workflow_id, user_id=user.id) + proposal = _get_pending_proposal_or_404( + db, workflow_id=workflow.id, event_id=event_id + ) + action_payload = _proposal_action_payload(proposal) + action = action_payload.get("action") + params = action_payload.get("params", {}) + + if action != "stage_retraining_from_corrections": + raise HTTPException(status_code=400, detail=f"Unsupported action: {action}") + + corrected_mask_path = ( + params.get("corrected_mask_path") + or params.get("written_path") + or workflow.corrected_mask_path + or _latest_exported_mask_path(db, workflow.id) + ) + if not corrected_mask_path: + raise HTTPException( + status_code=400, + detail="No corrected mask artifact is available to stage for retraining.", + ) + + proposal.approval_status = "approved" + update_workflow_fields( + db, + workflow, + { + "stage": "retraining_staged", + "corrected_mask_path": corrected_mask_path, + "training_output_path": params.get("training_output_path") + or workflow.training_output_path, + }, + commit=False, + ) + db.commit() + db.refresh(workflow) + db.refresh(proposal) + + approved = append_workflow_event( + db, + workflow_id=workflow.id, + actor="user", + event_type="agent.proposal_approved", + stage=workflow.stage, + summary=f"Approved agent proposal: {proposal.summary}", + payload={"proposal_event_id": proposal.id, "action": action}, + commit=True, + ) + staged = append_workflow_event( + db, + workflow_id=workflow.id, + actor="system", + event_type="retraining.staged", + stage=workflow.stage, + summary="Corrected masks staged for retraining.", + payload={ + "corrected_mask_path": corrected_mask_path, + "source": "agent_action", + "proposal_event_id": proposal.id, + }, + commit=True, + ) + return AgentActionResult( + workflow=_workflow_response(workflow), + proposal=_event_response(proposal), + events=[_event_response(approved), _event_response(staged)], + client_effects={ + "navigate_to": "training", + "set_training_label_path": corrected_mask_path, + "workflow_stage": workflow.stage, + }, + ) + + +@router.post( + "/{workflow_id}/agent-actions/{event_id}/reject", + response_model=WorkflowEventResponse, +) +async def reject_agent_action( + workflow_id: int, + event_id: int, + user: auth_models.User = Depends(get_current_user), + db: Session = Depends(get_db), +): + workflow = get_user_workflow_or_404(db, workflow_id=workflow_id, user_id=user.id) + proposal = _get_pending_proposal_or_404( + db, workflow_id=workflow.id, event_id=event_id + ) + proposal.approval_status = "rejected" + db.commit() + db.refresh(proposal) + event = append_workflow_event( + db, + workflow_id=workflow.id, + actor="user", + event_type="agent.proposal_rejected", + stage=workflow.stage, + summary=f"Rejected agent proposal: {proposal.summary}", + payload={"proposal_event_id": proposal.id}, + commit=True, + ) + return _event_response(event) + + +@router.post("/{workflow_id}/agent/query", response_model=AgentQueryResponse) +async def query_workflow_agent( + workflow_id: int, + body: AgentQueryRequest, + user: auth_models.User = Depends(get_current_user), + db: Session = Depends(get_db), +): + if not body.query.strip(): + raise HTTPException(status_code=400, detail="query must be non-empty") + + workflow = get_user_workflow_or_404(db, workflow_id=workflow_id, user_id=user.id) + events = _event_list(db, workflow.id) + recommendation = _recommendation_for_workflow(workflow, events) + proposals: List[WorkflowEventResponse] = [] + lower_query = body.query.lower() + wants_retraining = any( + term in lower_query for term in ["retrain", "training", "stage", "corrected"] + ) + corrected_mask_path = workflow.corrected_mask_path or _latest_exported_mask_path( + db, workflow.id + ) + + if wants_retraining and corrected_mask_path: + proposal = append_workflow_event( + db, + workflow_id=workflow.id, + actor="agent", + event_type="agent.proposal_created", + stage=workflow.stage, + summary="Stage corrected masks for retraining.", + payload={ + "action": "stage_retraining_from_corrections", + "params": {"corrected_mask_path": corrected_mask_path}, + }, + approval_status="pending", + commit=True, + ) + proposals.append(_event_response(proposal)) + response = ( + "I found a corrected mask artifact and prepared a retraining-stage " + "proposal. Approve it when you want the app to link those corrections " + "to the next training configuration." + ) + else: + response = recommendation + + return AgentQueryResponse(response=response, proposals=proposals) diff --git a/server_api/workflows/service.py b/server_api/workflows/service.py new file mode 100644 index 0000000..850a585 --- /dev/null +++ b/server_api/workflows/service.py @@ -0,0 +1,246 @@ +import json +from typing import Any, Dict, Optional + +from fastapi import HTTPException +from sqlalchemy.orm import Session + +from .db_models import WorkflowEvent, WorkflowSession + +ALLOWED_STAGES = { + "setup", + "visualization", + "inference", + "proofreading", + "retraining_staged", + "evaluation", +} + +ALLOWED_ACTORS = {"user", "agent", "system"} +ALLOWED_APPROVAL_STATUSES = {"not_required", "pending", "approved", "rejected"} + + +def encode_json(value: Optional[Dict[str, Any]]) -> Optional[str]: + if value is None: + return None + return json.dumps(value, ensure_ascii=False, sort_keys=True) + + +def decode_json(value: Optional[str]) -> Dict[str, Any]: + if not value: + return {} + try: + parsed = json.loads(value) + except (TypeError, json.JSONDecodeError): + return {} + return parsed if isinstance(parsed, dict) else {} + + +def validate_stage(stage: Optional[str]) -> Optional[str]: + if stage is None: + return None + if stage not in ALLOWED_STAGES: + raise HTTPException( + status_code=400, + detail=f"stage must be one of: {', '.join(sorted(ALLOWED_STAGES))}", + ) + return stage + + +def validate_actor(actor: str) -> str: + if actor not in ALLOWED_ACTORS: + raise HTTPException( + status_code=400, + detail=f"actor must be one of: {', '.join(sorted(ALLOWED_ACTORS))}", + ) + return actor + + +def validate_approval_status(status: str) -> str: + if status not in ALLOWED_APPROVAL_STATUSES: + raise HTTPException( + status_code=400, + detail=( + "approval_status must be one of: " + f"{', '.join(sorted(ALLOWED_APPROVAL_STATUSES))}" + ), + ) + return status + + +def event_to_dict(event: WorkflowEvent) -> Dict[str, Any]: + return { + "id": event.id, + "workflow_id": event.workflow_id, + "actor": event.actor, + "event_type": event.event_type, + "stage": event.stage, + "summary": event.summary, + "payload_json": event.payload_json, + "payload": decode_json(event.payload_json), + "approval_status": event.approval_status, + "created_at": event.created_at, + } + + +def workflow_to_dict(workflow: WorkflowSession) -> Dict[str, Any]: + return { + "id": workflow.id, + "user_id": workflow.user_id, + "title": workflow.title, + "stage": workflow.stage, + "dataset_path": workflow.dataset_path, + "image_path": workflow.image_path, + "label_path": workflow.label_path, + "mask_path": workflow.mask_path, + "neuroglancer_url": workflow.neuroglancer_url, + "inference_output_path": workflow.inference_output_path, + "checkpoint_path": workflow.checkpoint_path, + "proofreading_session_id": workflow.proofreading_session_id, + "corrected_mask_path": workflow.corrected_mask_path, + "training_output_path": workflow.training_output_path, + "metadata_json": workflow.metadata_json, + "metadata": decode_json(workflow.metadata_json), + "created_at": workflow.created_at, + "updated_at": workflow.updated_at, + } + + +def get_user_workflow_or_404( + db: Session, *, workflow_id: int, user_id: int +) -> WorkflowSession: + workflow = ( + db.query(WorkflowSession) + .filter(WorkflowSession.id == workflow_id, WorkflowSession.user_id == user_id) + .first() + ) + if not workflow: + raise HTTPException(status_code=404, detail="Workflow not found") + return workflow + + +def get_current_or_create_workflow(db: Session, *, user_id: int) -> WorkflowSession: + workflow = ( + db.query(WorkflowSession) + .filter(WorkflowSession.user_id == user_id) + .order_by(WorkflowSession.updated_at.desc(), WorkflowSession.id.desc()) + .first() + ) + if workflow: + return workflow + + workflow = WorkflowSession(user_id=user_id, title="Segmentation Workflow") + db.add(workflow) + db.commit() + db.refresh(workflow) + append_workflow_event( + db, + workflow_id=workflow.id, + actor="system", + event_type="workflow.created", + stage=workflow.stage, + summary="Workflow session created.", + commit=True, + ) + db.refresh(workflow) + return workflow + + +WORKFLOW_PATCH_FIELDS = { + "title", + "stage", + "dataset_path", + "image_path", + "label_path", + "mask_path", + "neuroglancer_url", + "inference_output_path", + "checkpoint_path", + "proofreading_session_id", + "corrected_mask_path", + "training_output_path", +} + + +def update_workflow_fields( + db: Session, + workflow: WorkflowSession, + updates: Dict[str, Any], + *, + commit: bool = True, +) -> WorkflowSession: + for key, value in updates.items(): + if key == "metadata": + workflow.metadata_json = encode_json(value if isinstance(value, dict) else {}) + continue + if key == "metadata_json": + workflow.metadata_json = value + continue + if key not in WORKFLOW_PATCH_FIELDS: + continue + if key == "stage": + if value is None: + continue + value = validate_stage(value) + setattr(workflow, key, value) + + if commit: + db.commit() + db.refresh(workflow) + return workflow + + +def append_workflow_event( + db: Session, + *, + workflow_id: Optional[int], + actor: str, + event_type: str, + summary: str, + stage: Optional[str] = None, + payload: Optional[Dict[str, Any]] = None, + approval_status: str = "not_required", + commit: bool = True, +) -> Optional[WorkflowEvent]: + if not workflow_id: + return None + actor = validate_actor(actor) + approval_status = validate_approval_status(approval_status) + stage = validate_stage(stage) if stage else stage + event = WorkflowEvent( + workflow_id=workflow_id, + actor=actor, + event_type=event_type, + stage=stage, + summary=summary, + payload_json=encode_json(payload), + approval_status=approval_status, + ) + db.add(event) + if commit: + db.commit() + db.refresh(event) + return event + + +def append_event_for_workflow_if_present( + db: Session, + *, + workflow_id: Optional[int], + actor: str, + event_type: str, + summary: str, + stage: Optional[str] = None, + payload: Optional[Dict[str, Any]] = None, +) -> Optional[WorkflowEvent]: + if not workflow_id: + return None + return append_workflow_event( + db, + workflow_id=workflow_id, + actor=actor, + event_type=event_type, + summary=summary, + stage=stage, + payload=payload, + commit=True, + ) diff --git a/tests/test_workflow_routes.py b/tests/test_workflow_routes.py new file mode 100644 index 0000000..6c77f47 --- /dev/null +++ b/tests/test_workflow_routes.py @@ -0,0 +1,227 @@ +import pathlib +import tempfile +import unittest + +import numpy as np +import tifffile +from fastapi.testclient import TestClient +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker + +from server_api.auth import database as auth_database +from server_api.auth import models +import server_api.ehtool.router as ehtool_router_module +from server_api.ehtool.utils import array_to_base64 +from server_api.main import app as server_api_app + + +class WorkflowRouteTests(unittest.TestCase): + def setUp(self): + self.temp_dir = tempfile.TemporaryDirectory() + self.db_path = pathlib.Path(self.temp_dir.name) / "workflow-test.db" + self.engine = create_engine( + f"sqlite:///{self.db_path}", connect_args={"check_same_thread": False} + ) + self.SessionLocal = sessionmaker( + autocommit=False, autoflush=False, bind=self.engine + ) + models.Base.metadata.create_all(bind=self.engine) + + def override_get_db(): + db = self.SessionLocal() + try: + yield db + finally: + db.close() + + server_api_app.dependency_overrides[auth_database.get_db] = override_get_db + self.client = TestClient(server_api_app) + + def tearDown(self): + server_api_app.dependency_overrides.clear() + ehtool_router_module._data_managers.clear() + self.engine.dispose() + self.temp_dir.cleanup() + + def _current_workflow(self): + response = self.client.get("/api/workflows/current") + self.assertEqual(response.status_code, 200) + payload = response.json() + return payload["workflow"], payload["events"] + + def test_current_workflow_create_update_and_events(self): + workflow, events = self._current_workflow() + + self.assertEqual(workflow["stage"], "setup") + self.assertEqual(len(events), 1) + self.assertEqual(events[0]["event_type"], "workflow.created") + + second_workflow, second_events = self._current_workflow() + self.assertEqual(second_workflow["id"], workflow["id"]) + self.assertEqual(len(second_events), 1) + + patch_response = self.client.patch( + f"/api/workflows/{workflow['id']}", + json={"stage": "visualization", "image_path": "/tmp/image.tif"}, + ) + self.assertEqual(patch_response.status_code, 200) + self.assertEqual(patch_response.json()["stage"], "visualization") + + event_response = self.client.post( + f"/api/workflows/{workflow['id']}/events", + json={ + "actor": "user", + "event_type": "dataset.loaded", + "stage": "visualization", + "summary": "Loaded a test dataset.", + "payload": {"image_path": "/tmp/image.tif"}, + }, + ) + self.assertEqual(event_response.status_code, 200) + + events_response = self.client.get(f"/api/workflows/{workflow['id']}/events") + self.assertEqual(events_response.status_code, 200) + event_types = [event["event_type"] for event in events_response.json()] + self.assertEqual(event_types, ["workflow.created", "dataset.loaded"]) + + def test_agent_action_approve_and_reject_flow(self): + workflow, _ = self._current_workflow() + workflow_id = workflow["id"] + + reject_proposal = self.client.post( + f"/api/workflows/{workflow_id}/agent-actions", + json={ + "action": "stage_retraining_from_corrections", + "summary": "Rejectable staging proposal.", + "payload": {"corrected_mask_path": "/tmp/rejected.tif"}, + }, + ) + self.assertEqual(reject_proposal.status_code, 200) + self.assertEqual(reject_proposal.json()["approval_status"], "pending") + + reject_response = self.client.post( + f"/api/workflows/{workflow_id}/agent-actions/" + f"{reject_proposal.json()['id']}/reject" + ) + self.assertEqual(reject_response.status_code, 200) + self.assertEqual(reject_response.json()["event_type"], "agent.proposal_rejected") + + approve_proposal = self.client.post( + f"/api/workflows/{workflow_id}/agent-actions", + json={ + "action": "stage_retraining_from_corrections", + "summary": "Stage corrected masks.", + "payload": {"corrected_mask_path": "/tmp/corrected.tif"}, + }, + ) + self.assertEqual(approve_proposal.status_code, 200) + + approve_response = self.client.post( + f"/api/workflows/{workflow_id}/agent-actions/" + f"{approve_proposal.json()['id']}/approve" + ) + self.assertEqual(approve_response.status_code, 200) + approve_payload = approve_response.json() + self.assertEqual(approve_payload["workflow"]["stage"], "retraining_staged") + self.assertEqual( + approve_payload["workflow"]["corrected_mask_path"], "/tmp/corrected.tif" + ) + self.assertEqual( + approve_payload["client_effects"]["set_training_label_path"], + "/tmp/corrected.tif", + ) + + events_response = self.client.get(f"/api/workflows/{workflow_id}/events") + event_types = [event["event_type"] for event in events_response.json()] + self.assertIn("agent.proposal_approved", event_types) + self.assertIn("agent.proposal_rejected", event_types) + self.assertIn("retraining.staged", event_types) + + def test_ehtool_load_classify_save_and_export_append_workflow_events(self): + workflow, _ = self._current_workflow() + workflow_id = workflow["id"] + data_root = pathlib.Path(self.temp_dir.name) / "volumes" + data_root.mkdir() + image_path = data_root / "image.tif" + mask_path = data_root / "mask.tif" + export_path = data_root / "corrected-mask.tif" + + image = np.arange(2 * 6 * 6, dtype=np.uint8).reshape(2, 6, 6) + mask = np.zeros((2, 6, 6), dtype=np.uint16) + mask[0, 0:2, 0:2] = 1 + mask[0, 2:4, 2:4] = 2 + mask[1, 4:6, 4:6] = 3 + tifffile.imwrite(str(image_path), image) + tifffile.imwrite(str(mask_path), mask) + + load_response = self.client.post( + "/eh/detection/load", + json={ + "dataset_path": str(image_path), + "mask_path": str(mask_path), + "project_name": "Workflow EHTool", + "workflow_id": workflow_id, + }, + ) + self.assertEqual(load_response.status_code, 200) + session_id = load_response.json()["session_id"] + + workflow_response = self.client.get("/api/workflows/current") + updated_workflow = workflow_response.json()["workflow"] + self.assertEqual(updated_workflow["stage"], "proofreading") + self.assertEqual(updated_workflow["proofreading_session_id"], session_id) + + instances_response = self.client.get( + "/eh/detection/instances", params={"session_id": session_id} + ) + self.assertEqual(instances_response.status_code, 200) + instance_id = instances_response.json()["instances"][0]["id"] + + classify_response = self.client.post( + "/eh/detection/instance-classify", + json={ + "session_id": session_id, + "instance_ids": [instance_id], + "classification": "correct", + }, + ) + self.assertEqual(classify_response.status_code, 200) + self.assertEqual(classify_response.json()["updated_count"], 1) + + edited_mask = np.ones((6, 6), dtype=np.uint8) * 255 + save_response = self.client.post( + "/eh/detection/instance-mask", + json={ + "session_id": session_id, + "instance_id": instance_id, + "axis": "xy", + "z_index": 0, + "mask_base64": array_to_base64(edited_mask, format="PNG"), + }, + ) + self.assertEqual(save_response.status_code, 200) + + export_response = self.client.post( + "/eh/detection/export-masks", + json={ + "session_id": session_id, + "mode": "new_file", + "output_path": str(export_path), + "create_backup": True, + }, + ) + self.assertEqual(export_response.status_code, 200) + self.assertEqual(export_response.json()["written_path"], str(export_path)) + + events_response = self.client.get(f"/api/workflows/{workflow_id}/events") + self.assertEqual(events_response.status_code, 200) + event_types = [event["event_type"] for event in events_response.json()] + self.assertIn("dataset.loaded", event_types) + self.assertIn("proofreading.session_loaded", event_types) + self.assertIn("proofreading.instance_classified", event_types) + self.assertIn("proofreading.mask_saved", event_types) + self.assertIn("proofreading.masks_exported", event_types) + + +if __name__ == "__main__": + unittest.main() From a73c927ed50be00e947a1a4c08657c3ec81e2bb2 Mon Sep 17 00:00:00 2001 From: Adam Gohain Date: Mon, 13 Apr 2026 16:54:48 -0400 Subject: [PATCH 2/6] feat(workflows): add hotspots insights, metrics, and export bundle (cherry picked from commit 0bfcee99199c5dfb77e7a76abdb9873af86f6318) --- docs/research/workflow-evidence-export.md | 28 ++ server_api/workflows/bundle_export.py | 72 ++++ server_api/workflows/evidence_export.py | 102 ++++++ server_api/workflows/metrics.py | 68 ++++ server_api/workflows/router.py | 409 ++++++++++++++++++++++ tests/test_workflow_evidence_export.py | 98 ++++++ tests/test_workflow_export_bundle.py | 93 +++++ tests/test_workflow_metrics.py | 120 +++++++ tests/test_workflow_routes.py | 85 ++++- tests/test_workflow_spine_smoke.py | 103 ++++++ 10 files changed, 1176 insertions(+), 2 deletions(-) create mode 100644 docs/research/workflow-evidence-export.md create mode 100644 server_api/workflows/bundle_export.py create mode 100644 server_api/workflows/evidence_export.py create mode 100644 server_api/workflows/metrics.py create mode 100644 tests/test_workflow_evidence_export.py create mode 100644 tests/test_workflow_export_bundle.py create mode 100644 tests/test_workflow_metrics.py create mode 100644 tests/test_workflow_spine_smoke.py diff --git a/docs/research/workflow-evidence-export.md b/docs/research/workflow-evidence-export.md new file mode 100644 index 0000000..6ca60b8 --- /dev/null +++ b/docs/research/workflow-evidence-export.md @@ -0,0 +1,28 @@ +# Workflow Evidence Export + +Use the workflow utilities to generate stable evidence summaries for paper notes and analysis. + +## Primary endpoints + +- `GET /api/workflows/{workflow_id}/metrics` +- `POST /api/workflows/{workflow_id}/export-bundle` + +## Bundle schema highlights + +`export-bundle` returns: + +- `schema_version` (`workflow-export-bundle/v1`) +- `workflow_id` +- `session_snapshot` (current workflow state) +- `events` (chronologically ordered) +- `artifact_paths` (`[{ path, exists }]`) + +## Paper-note utility + +`server_api.workflows.evidence_export` provides a deterministic summary shape: + +- `stage_progression_summary` +- `agent_proposal_approval_summary` +- `key_event_timeline_snippet` + +Use this when exporting workflow evidence into local research logs for manuscript drafting. diff --git a/server_api/workflows/bundle_export.py b/server_api/workflows/bundle_export.py new file mode 100644 index 0000000..1abe8be --- /dev/null +++ b/server_api/workflows/bundle_export.py @@ -0,0 +1,72 @@ +from __future__ import annotations + +import os +from datetime import datetime, timezone +from typing import Any, Dict, Iterable, List, Tuple + +from .db_models import WorkflowEvent, WorkflowSession +from .service import event_to_dict, workflow_to_dict + + +def _parse_timestamp(value: Any) -> datetime: + if isinstance(value, datetime): + return value + text = str(value or "1970-01-01T00:00:00+00:00") + if text.endswith("Z"): + text = text.replace("Z", "+00:00") + return datetime.fromisoformat(text) + + +def _event_sort_key(event: Dict[str, Any]) -> Tuple[datetime, str]: + timestamp = event.get("created_at") + event_id = str(event.get("id") or "") + return (_parse_timestamp(timestamp), event_id) + + +def _normalize_value(value: Any) -> Any: + if isinstance(value, datetime): + return value.isoformat() + if isinstance(value, dict): + return {k: _normalize_value(v) for k, v in value.items()} + if isinstance(value, list): + return [_normalize_value(item) for item in value] + return value + + +def _collect_paths(value: Any) -> Iterable[str]: + if isinstance(value, dict): + for key, inner in value.items(): + if isinstance(inner, str) and key.endswith("_path"): + yield inner + if key == "path" and isinstance(inner, str): + yield inner + yield from _collect_paths(inner) + elif isinstance(value, list): + for item in value: + yield from _collect_paths(item) + + +def build_export_bundle( + workflow: WorkflowSession, + events: List[WorkflowEvent], +) -> Dict[str, Any]: + session_snapshot = _normalize_value(workflow_to_dict(workflow)) + ordered_events = sorted( + (_normalize_value(event_to_dict(event)) for event in events), + key=_event_sort_key, + ) + + discovered = set(_collect_paths(session_snapshot)) + for event in ordered_events: + discovered.update(_collect_paths(event)) + artifact_paths = sorted(path for path in discovered if path) + artifacts = [{"path": path, "exists": os.path.exists(path)} for path in artifact_paths] + + return { + "schema_version": "workflow-export-bundle/v1", + "exported_at": datetime.now(timezone.utc).isoformat(), + "workflow_id": workflow.id, + "session_snapshot": session_snapshot, + "events": ordered_events, + "artifact_paths": artifacts, + } diff --git a/server_api/workflows/evidence_export.py b/server_api/workflows/evidence_export.py new file mode 100644 index 0000000..ca51394 --- /dev/null +++ b/server_api/workflows/evidence_export.py @@ -0,0 +1,102 @@ +from __future__ import annotations + +from collections import Counter +from typing import Any, Dict, List + +from .db_models import WorkflowEvent, WorkflowSession +from .service import decode_json + +EVIDENCE_EXPORT_VERSION = "1.0" + + +def _iso(value: Any) -> str | None: + if value is None: + return None + if hasattr(value, "isoformat"): + return value.isoformat() + return str(value) + + +def _stage_progression(events: List[WorkflowEvent]) -> Dict[str, Any]: + observed: List[str] = [] + entered_at: Dict[str, str] = {} + transition_count = 0 + previous_stage: str | None = None + + for event in events: + stage = event.stage + if not stage: + continue + if stage not in observed: + observed.append(stage) + if stage not in entered_at: + entered_at[stage] = _iso(event.created_at) or "" + if previous_stage and previous_stage != stage: + transition_count += 1 + previous_stage = stage + + return { + "observed_stages": observed, + "entered_at": entered_at, + "transition_count": transition_count, + } + + +def _proposal_approval_summary(events: List[WorkflowEvent]) -> Dict[str, Any]: + proposals_by_actor: Counter[str] = Counter() + decisions_by_actor: Counter[str] = Counter() + approved = 0 + rejected = 0 + + for event in events: + if event.event_type == "agent.proposal_created": + proposals_by_actor[event.actor or "unknown"] += 1 + elif event.event_type == "agent.proposal_approved": + decisions_by_actor[event.actor or "unknown"] += 1 + approved += 1 + elif event.event_type == "agent.proposal_rejected": + decisions_by_actor[event.actor or "unknown"] += 1 + rejected += 1 + + return { + "proposal_count": int(sum(proposals_by_actor.values())), + "approved_count": approved, + "rejected_count": rejected, + "proposals_by_actor": dict(proposals_by_actor), + "decisions_by_actor": dict(decisions_by_actor), + } + + +def _timeline_snippet(events: List[WorkflowEvent], max_events: int = 20) -> List[Dict[str, Any]]: + snippet: List[Dict[str, Any]] = [] + for event in events[:max_events]: + payload = decode_json(event.payload_json) + snippet.append( + { + "timestamp": _iso(event.created_at), + "event_type": event.event_type, + "stage": event.stage, + "actor": event.actor, + "summary": event.summary, + "payload_keys": sorted(payload.keys()), + } + ) + return snippet + + +def build_workflow_evidence_export( + workflow: WorkflowSession, + events: List[WorkflowEvent], +) -> Dict[str, Any]: + stage_progression = _stage_progression(events) + proposal_summary = _proposal_approval_summary(events) + timeline = _timeline_snippet(events) + + return { + "version": EVIDENCE_EXPORT_VERSION, + "workflow_id": workflow.id, + "workflow_stage": workflow.stage, + "stage_progression_summary": stage_progression, + "agent_proposal_approval_summary": proposal_summary, + "key_event_timeline_snippet": timeline, + } diff --git a/server_api/workflows/metrics.py b/server_api/workflows/metrics.py new file mode 100644 index 0000000..8012d1a --- /dev/null +++ b/server_api/workflows/metrics.py @@ -0,0 +1,68 @@ +from __future__ import annotations + +from collections import Counter +from datetime import datetime +from typing import Any, Dict, List + +from .db_models import WorkflowEvent + + +def _to_iso8601(value: Any) -> str | None: + if value is None: + return None + if isinstance(value, datetime): + return value.isoformat() + return str(value) + + +def compute_workflow_metrics(events: List[WorkflowEvent]) -> Dict[str, Any]: + event_type_counts: Counter[str] = Counter() + stage_transition_counts: Counter[str] = Counter() + + approvals = 0 + rejections = 0 + first_timestamp: Any = None + last_timestamp: Any = None + last_stage: str | None = None + + for event in events: + event_type = event.event_type or "unknown" + event_type_counts[event_type] += 1 + + if event_type == "agent.proposal_approved": + approvals += 1 + if event_type == "agent.proposal_rejected": + rejections += 1 + + stage = event.stage + if stage: + if last_stage and last_stage != stage: + stage_transition_counts[f"{last_stage}->{stage}"] += 1 + last_stage = stage + + timestamp = event.created_at + if first_timestamp is None or (timestamp and timestamp < first_timestamp): + first_timestamp = timestamp + if last_timestamp is None or (timestamp and timestamp > last_timestamp): + last_timestamp = timestamp + + total_decisions = approvals + rejections + approval_rate = (approvals / total_decisions) if total_decisions else 0.0 + rejection_rate = (rejections / total_decisions) if total_decisions else 0.0 + + return { + "event_counts": dict(event_type_counts), + "decision_metrics": { + "approvals": approvals, + "rejections": rejections, + "approval_rate": approval_rate, + "rejection_rate": rejection_rate, + "total_decisions": total_decisions, + }, + "stage_transition_counts": dict(stage_transition_counts), + "timestamps": { + "first_event_at": _to_iso8601(first_timestamp), + "last_event_at": _to_iso8601(last_timestamp), + }, + "total_events": len(events), + } diff --git a/server_api/workflows/router.py b/server_api/workflows/router.py index 6422502..5bcfe80 100644 --- a/server_api/workflows/router.py +++ b/server_api/workflows/router.py @@ -1,5 +1,6 @@ from __future__ import annotations +from datetime import datetime, timezone from typing import Any, Dict, List, Optional from fastapi import APIRouter, Depends, HTTPException @@ -21,6 +22,8 @@ validate_stage, workflow_to_dict, ) +from .bundle_export import build_export_bundle +from .metrics import compute_workflow_metrics router = APIRouter() @@ -111,6 +114,49 @@ class AgentActionResult(BaseModel): client_effects: Dict[str, Any] = Field(default_factory=dict) +class WorkflowHotspotItem(BaseModel): + rank: int + region_key: str + score: float + severity: str + summary: str + recommended_action: str + evidence: Dict[str, Any] = Field(default_factory=dict) + + +class WorkflowHotspotsResponse(BaseModel): + workflow_id: int + generated_at: str + hotspots: List[WorkflowHotspotItem] = Field(default_factory=list) + + +class WorkflowImpactPreviewResponse(BaseModel): + workflow_id: int + generated_at: str + can_stage_retraining: bool + recommended_stage: str + corrected_mask_path: Optional[str] = None + confidence: str + projected_improvement: float + summary: str + signals: Dict[str, int] = Field(default_factory=dict) + next_actions: List[str] = Field(default_factory=list) + + +class WorkflowMetricsResponse(BaseModel): + workflow_id: int + metrics: Dict[str, Any] = Field(default_factory=dict) + + +class WorkflowExportBundleResponse(BaseModel): + schema_version: str + exported_at: str + workflow_id: int + session_snapshot: Dict[str, Any] = Field(default_factory=dict) + events: List[Dict[str, Any]] = Field(default_factory=list) + artifact_paths: List[Dict[str, Any]] = Field(default_factory=list) + + def _workflow_response(workflow: WorkflowSession) -> WorkflowResponse: return WorkflowResponse(**workflow_to_dict(workflow)) @@ -129,6 +175,15 @@ def _event_list(db: Session, workflow_id: int) -> List[WorkflowEventResponse]: return [_event_response(event) for event in events] +def _event_rows(db: Session, workflow_id: int) -> List[WorkflowEvent]: + return ( + db.query(WorkflowEvent) + .filter(WorkflowEvent.workflow_id == workflow_id) + .order_by(WorkflowEvent.created_at.asc(), WorkflowEvent.id.asc()) + .all() + ) + + def _get_pending_proposal_or_404( db: Session, *, workflow_id: int, event_id: int ) -> WorkflowEvent: @@ -172,6 +227,280 @@ def _proposal_action_payload(proposal: WorkflowEvent) -> Dict[str, Any]: return {"action": payload.get("action"), "params": params} +NEGATIVE_CLASSIFICATIONS = { + "incorrect", + "uncertain", + "error", + "false_positive", + "false_negative", + "needs_review", +} + + +def _region_key_from_payload(payload: Dict[str, Any]) -> Optional[str]: + for key in ("region_key", "region_id", "region"): + value = payload.get(key) + if isinstance(value, (str, int, float)): + return str(value) + + instance_id = payload.get("instance_id") + if isinstance(instance_id, (str, int)): + return f"instance:{instance_id}" + + instance_ids = payload.get("instance_ids") + if isinstance(instance_ids, list) and instance_ids: + first = instance_ids[0] + if isinstance(first, (str, int)): + return f"instance:{first}" + + axis = payload.get("axis") + z_index = payload.get("z_index") + if axis is not None and z_index is not None: + return f"{axis}:{z_index}" + if z_index is not None: + return f"z:{z_index}" + return None + + +def _hotspot_severity(score: float) -> str: + if score >= 8: + return "high" + if score >= 4: + return "medium" + return "low" + + +def _default_region_action(workflow: WorkflowSession, severity: str) -> str: + if workflow.stage == "proofreading": + return "Open this region in proofreading and apply mask corrections." + if severity == "high": + return "Prioritize this region for proofreading before the next training iteration." + if workflow.stage in {"visualization", "inference"}: + return "Inspect this region in visualization and route it into proofreading." + return "Inspect this region and log whether model output is acceptable." + + +def _compute_hotspots( + workflow: WorkflowSession, + events: List[WorkflowEvent], +) -> List[WorkflowHotspotItem]: + region_stats: Dict[str, Dict[str, Any]] = {} + + def ensure_region(region_key: str) -> Dict[str, Any]: + stat = region_stats.get(region_key) + if stat: + return stat + stat = { + "raw_score": 0.0, + "inference_failures": 0, + "classifications": 0, + "negative_classifications": 0, + "mask_saves": 0, + "exports": 0, + "last_event_index": -1, + } + region_stats[region_key] = stat + return stat + + for index, event in enumerate(events): + payload = decode_json(event.payload_json) + region_key = _region_key_from_payload(payload) + stat = ensure_region(region_key) if region_key else None + + if event.event_type == "inference.failed" and stat: + stat["raw_score"] += 4.0 + stat["inference_failures"] += 1 + stat["last_event_index"] = index + + if event.event_type == "proofreading.instance_classified" and stat: + classification = str(payload.get("classification") or "").lower() + if classification in NEGATIVE_CLASSIFICATIONS: + stat["raw_score"] += 2.5 + stat["negative_classifications"] += 1 + else: + stat["raw_score"] += 1.0 + stat["classifications"] += 1 + stat["last_event_index"] = index + + if event.event_type == "proofreading.mask_saved" and stat: + stat["raw_score"] += 3.0 + stat["mask_saves"] += 1 + stat["last_event_index"] = index + + if event.event_type == "proofreading.masks_exported": + if stat: + stat["raw_score"] += 1.0 + stat["exports"] += 1 + stat["last_event_index"] = index + else: + # Export happened without region metadata; keep evidence visible. + global_region = ensure_region("current_view") + global_region["raw_score"] += 1.0 + global_region["exports"] += 1 + global_region["last_event_index"] = index + + if not region_stats: + region_stats["current_view"] = { + "raw_score": 1.0, + "inference_failures": 0, + "classifications": 0, + "negative_classifications": 0, + "mask_saves": 0, + "exports": 0, + "last_event_index": len(events) - 1, + } + + total_events = max(len(events), 1) + ranked: List[WorkflowHotspotItem] = [] + for region_key, stat in region_stats.items(): + distance = max(0, (total_events - 1) - stat["last_event_index"]) + recency_bonus = max(0.0, 1.5 - (distance * 0.25)) + score = round(float(stat["raw_score"] + recency_bonus), 2) + severity = _hotspot_severity(score) + + summary = ( + f"{region_key}: {stat['inference_failures']} inference failures, " + f"{stat['mask_saves']} mask edits, " + f"{stat['negative_classifications']} uncertain/incorrect classifications." + ) + recommended_action = _default_region_action(workflow, severity) + if stat["exports"] > 0 and workflow.stage != "retraining_staged": + recommended_action = ( + "Corrections already exist; stage this region's masks for retraining." + ) + + ranked.append( + WorkflowHotspotItem( + rank=0, + region_key=region_key, + score=score, + severity=severity, + summary=summary, + recommended_action=recommended_action, + evidence={ + "inference_failures": stat["inference_failures"], + "classifications": stat["classifications"], + "negative_classifications": stat["negative_classifications"], + "mask_saves": stat["mask_saves"], + "exports": stat["exports"], + }, + ) + ) + + ranked.sort(key=lambda item: item.score, reverse=True) + return [ + item.model_copy(update={"rank": index + 1}) for index, item in enumerate(ranked) + ] + + +def _compute_impact_preview( + workflow: WorkflowSession, + events: List[WorkflowEvent], + hotspots: List[WorkflowHotspotItem], + corrected_mask_path: Optional[str], +) -> WorkflowImpactPreviewResponse: + signals = { + "inference_started": 0, + "inference_completed": 0, + "inference_failed": 0, + "proofreading_classified": 0, + "proofreading_mask_saved": 0, + "proofreading_masks_exported": 0, + "pending_agent_proposals": 0, + } + for event in events: + if event.event_type == "inference.started": + signals["inference_started"] += 1 + if event.event_type == "inference.completed": + signals["inference_completed"] += 1 + if event.event_type == "inference.failed": + signals["inference_failed"] += 1 + if event.event_type == "proofreading.instance_classified": + signals["proofreading_classified"] += 1 + if event.event_type == "proofreading.mask_saved": + signals["proofreading_mask_saved"] += 1 + if event.event_type == "proofreading.masks_exported": + signals["proofreading_masks_exported"] += 1 + if ( + event.event_type == "agent.proposal_created" + and event.approval_status == "pending" + ): + signals["pending_agent_proposals"] += 1 + + correction_signal = ( + (signals["proofreading_classified"] * 2) + + (signals["proofreading_mask_saved"] * 3) + + (signals["proofreading_masks_exported"] * 5) + ) + failure_signal = signals["inference_failed"] * 3 + hotspot_signal = len([item for item in hotspots if item.severity == "high"]) * 4 + projected_improvement = min( + 0.95, + round( + 0.05 + + (correction_signal * 0.03) + + (failure_signal * 0.02) + + (hotspot_signal * 0.01), + 3, + ), + ) + + confidence = "low" + if correction_signal >= 8 and signals["proofreading_masks_exported"] > 0: + confidence = "high" + elif correction_signal >= 3: + confidence = "medium" + + can_stage_retraining = ( + bool(corrected_mask_path) and workflow.stage != "retraining_staged" + ) + recommended_stage = workflow.stage + if can_stage_retraining: + recommended_stage = "retraining_staged" + elif workflow.stage == "retraining_staged": + recommended_stage = "evaluation" + + next_actions: List[str] = [] + if hotspots: + next_actions.append(hotspots[0].recommended_action) + if not corrected_mask_path: + next_actions.append( + "Export corrected masks from proofreading to create a retraining artifact." + ) + elif can_stage_retraining: + next_actions.append("Approve or trigger retraining staging from corrected masks.") + if workflow.stage == "retraining_staged": + next_actions.append( + "Open Model Training and launch the next experiment using staged labels." + ) + + if confidence == "low": + summary = ( + "Correction evidence is still sparse; prioritize proofreading edits before retraining." + ) + elif can_stage_retraining: + summary = ( + "Corrections are substantial enough to justify retraining staging for the next model iteration." + ) + else: + summary = ( + "Correction evidence is accumulating; compare outcomes after the next staged loop." + ) + + return WorkflowImpactPreviewResponse( + workflow_id=workflow.id, + generated_at=datetime.now(timezone.utc).isoformat(), + can_stage_retraining=can_stage_retraining, + recommended_stage=recommended_stage, + corrected_mask_path=corrected_mask_path, + confidence=confidence, + projected_improvement=projected_improvement, + summary=summary, + signals=signals, + next_actions=next_actions, + ) + + def _recommendation_for_workflow( workflow: WorkflowSession, events: List[WorkflowEventResponse], @@ -229,6 +558,68 @@ def list_workflow_events( return _event_list(db, workflow_id) +@router.get("/{workflow_id}/hotspots", response_model=WorkflowHotspotsResponse) +def get_workflow_hotspots( + workflow_id: int, + user: auth_models.User = Depends(get_current_user), + db: Session = Depends(get_db), +): + workflow = get_user_workflow_or_404(db, workflow_id=workflow_id, user_id=user.id) + events = _event_rows(db, workflow.id) + hotspots = _compute_hotspots(workflow, events) + return WorkflowHotspotsResponse( + workflow_id=workflow.id, + generated_at=datetime.now(timezone.utc).isoformat(), + hotspots=hotspots, + ) + + +@router.get( + "/{workflow_id}/impact-preview", + response_model=WorkflowImpactPreviewResponse, +) +def get_workflow_impact_preview( + workflow_id: int, + user: auth_models.User = Depends(get_current_user), + db: Session = Depends(get_db), +): + workflow = get_user_workflow_or_404(db, workflow_id=workflow_id, user_id=user.id) + events = _event_rows(db, workflow.id) + hotspots = _compute_hotspots(workflow, events) + corrected_mask_path = workflow.corrected_mask_path or _latest_exported_mask_path( + db, workflow.id + ) + return _compute_impact_preview(workflow, events, hotspots, corrected_mask_path) + + +@router.get("/{workflow_id}/metrics", response_model=WorkflowMetricsResponse) +def get_workflow_metrics( + workflow_id: int, + user: auth_models.User = Depends(get_current_user), + db: Session = Depends(get_db), +): + workflow = get_user_workflow_or_404(db, workflow_id=workflow_id, user_id=user.id) + events = _event_rows(db, workflow.id) + return WorkflowMetricsResponse( + workflow_id=workflow.id, + metrics=compute_workflow_metrics(events), + ) + + +@router.post( + "/{workflow_id}/export-bundle", + response_model=WorkflowExportBundleResponse, +) +def export_workflow_bundle( + workflow_id: int, + user: auth_models.User = Depends(get_current_user), + db: Session = Depends(get_db), +): + workflow = get_user_workflow_or_404(db, workflow_id=workflow_id, user_id=user.id) + events = _event_rows(db, workflow.id) + return WorkflowExportBundleResponse(**build_export_bundle(workflow, events)) + + @router.post("/{workflow_id}/events", response_model=WorkflowEventResponse) async def create_workflow_event( workflow_id: int, @@ -401,6 +792,7 @@ async def query_workflow_agent( raise HTTPException(status_code=400, detail="query must be non-empty") workflow = get_user_workflow_or_404(db, workflow_id=workflow_id, user_id=user.id) + event_rows = _event_rows(db, workflow.id) events = _event_list(db, workflow.id) recommendation = _recommendation_for_workflow(workflow, events) proposals: List[WorkflowEventResponse] = [] @@ -408,9 +800,19 @@ async def query_workflow_agent( wants_retraining = any( term in lower_query for term in ["retrain", "training", "stage", "corrected"] ) + wants_failure_analysis = any( + term in lower_query for term in ["fail", "failure", "error", "hotspot", "where"] + ) corrected_mask_path = workflow.corrected_mask_path or _latest_exported_mask_path( db, workflow.id ) + hotspots = _compute_hotspots(workflow, event_rows) + impact = _compute_impact_preview( + workflow, + event_rows, + hotspots, + corrected_mask_path, + ) if wants_retraining and corrected_mask_path: proposal = append_workflow_event( @@ -433,6 +835,13 @@ async def query_workflow_agent( "proposal. Approve it when you want the app to link those corrections " "to the next training configuration." ) + elif wants_failure_analysis and hotspots: + top_hotspot = hotspots[0] + response = ( + f"Top hotspot: {top_hotspot.summary} " + f"Recommended action: {top_hotspot.recommended_action} " + f"Impact preview: {impact.summary}" + ) else: response = recommendation diff --git a/tests/test_workflow_evidence_export.py b/tests/test_workflow_evidence_export.py new file mode 100644 index 0000000..9792ff1 --- /dev/null +++ b/tests/test_workflow_evidence_export.py @@ -0,0 +1,98 @@ +import pathlib +import tempfile +import unittest + +import pytest +pytest.importorskip("sqlalchemy") +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker + +pytest.importorskip("fastapi") +from fastapi.testclient import TestClient + +from server_api.auth import database as auth_database +from server_api.auth import models +from server_api.main import app as server_api_app +from server_api.workflows.db_models import WorkflowEvent, WorkflowSession +from server_api.workflows.evidence_export import ( + EVIDENCE_EXPORT_VERSION, + build_workflow_evidence_export, +) + + +class WorkflowEvidenceExportTests(unittest.TestCase): + def setUp(self): + self.temp_dir = tempfile.TemporaryDirectory() + self.db_path = pathlib.Path(self.temp_dir.name) / "workflow-evidence-test.db" + self.engine = create_engine( + f"sqlite:///{self.db_path}", connect_args={"check_same_thread": False} + ) + self.SessionLocal = sessionmaker( + autocommit=False, autoflush=False, bind=self.engine + ) + models.Base.metadata.create_all(bind=self.engine) + + def override_get_db(): + db = self.SessionLocal() + try: + yield db + finally: + db.close() + + server_api_app.dependency_overrides[auth_database.get_db] = override_get_db + self.client = TestClient(server_api_app) + + def tearDown(self): + server_api_app.dependency_overrides.clear() + self.engine.dispose() + self.temp_dir.cleanup() + + def test_evidence_export_contains_required_sections(self): + response = self.client.get("/api/workflows/current") + self.assertEqual(response.status_code, 200) + workflow_id = response.json()["workflow"]["id"] + + self.client.post( + f"/api/workflows/{workflow_id}/events", + json={ + "actor": "agent", + "event_type": "agent.proposal_created", + "stage": "proofreading", + "summary": "Proposal created", + "approval_status": "pending", + }, + ) + self.client.post( + f"/api/workflows/{workflow_id}/events", + json={ + "actor": "user", + "event_type": "agent.proposal_approved", + "stage": "retraining_staged", + "summary": "Proposal approved", + }, + ) + + with self.SessionLocal() as db: + workflow = ( + db.query(WorkflowSession).filter(WorkflowSession.id == workflow_id).first() + ) + events = ( + db.query(WorkflowEvent) + .filter(WorkflowEvent.workflow_id == workflow_id) + .order_by(WorkflowEvent.created_at.asc(), WorkflowEvent.id.asc()) + .all() + ) + + export_payload = build_workflow_evidence_export(workflow, events) + self.assertEqual(export_payload["version"], EVIDENCE_EXPORT_VERSION) + self.assertEqual(export_payload["workflow_id"], workflow_id) + self.assertIn("stage_progression_summary", export_payload) + self.assertIn("agent_proposal_approval_summary", export_payload) + self.assertIn("key_event_timeline_snippet", export_payload) + self.assertGreaterEqual( + export_payload["agent_proposal_approval_summary"]["approved_count"], 1 + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_workflow_export_bundle.py b/tests/test_workflow_export_bundle.py new file mode 100644 index 0000000..d166e53 --- /dev/null +++ b/tests/test_workflow_export_bundle.py @@ -0,0 +1,93 @@ +import pathlib +import tempfile +import unittest + +import pytest +pytest.importorskip("sqlalchemy") +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker + +pytest.importorskip("fastapi") +from fastapi.testclient import TestClient + +from server_api.auth import database as auth_database +from server_api.auth import models +from server_api.main import app as server_api_app + + +class WorkflowExportBundleTests(unittest.TestCase): + def setUp(self): + self.temp_dir = tempfile.TemporaryDirectory() + self.db_path = pathlib.Path(self.temp_dir.name) / "workflow-export-test.db" + self.engine = create_engine( + f"sqlite:///{self.db_path}", connect_args={"check_same_thread": False} + ) + self.SessionLocal = sessionmaker( + autocommit=False, autoflush=False, bind=self.engine + ) + models.Base.metadata.create_all(bind=self.engine) + + def override_get_db(): + db = self.SessionLocal() + try: + yield db + finally: + db.close() + + server_api_app.dependency_overrides[auth_database.get_db] = override_get_db + self.client = TestClient(server_api_app) + + def tearDown(self): + server_api_app.dependency_overrides.clear() + self.engine.dispose() + self.temp_dir.cleanup() + + def _current_workflow_id(self) -> int: + response = self.client.get("/api/workflows/current") + self.assertEqual(response.status_code, 200) + return response.json()["workflow"]["id"] + + def test_export_bundle_sorts_events_and_reports_artifact_existence(self): + workflow_id = self._current_workflow_id() + existing_path = pathlib.Path(self.temp_dir.name) / "existing-mask.tif" + existing_path.write_text("ok", encoding="utf-8") + missing_path = pathlib.Path(self.temp_dir.name) / "missing-mask.tif" + + self.client.patch( + f"/api/workflows/{workflow_id}", + json={ + "image_path": str(existing_path), + "corrected_mask_path": str(missing_path), + "stage": "proofreading", + }, + ) + + self.client.post( + f"/api/workflows/{workflow_id}/events", + json={ + "actor": "system", + "event_type": "proofreading.masks_exported", + "stage": "proofreading", + "summary": "Masks exported", + "payload": {"output_path": str(existing_path)}, + }, + ) + + response = self.client.post(f"/api/workflows/{workflow_id}/export-bundle") + self.assertEqual(response.status_code, 200) + payload = response.json() + + self.assertEqual(payload["schema_version"], "workflow-export-bundle/v1") + self.assertEqual(payload["workflow_id"], workflow_id) + self.assertIsInstance(payload["events"], list) + self.assertGreaterEqual(len(payload["events"]), 2) # includes workflow.created + + artifacts = {entry["path"]: entry["exists"] for entry in payload["artifact_paths"]} + self.assertIn(str(existing_path), artifacts) + self.assertIn(str(missing_path), artifacts) + self.assertTrue(artifacts[str(existing_path)]) + self.assertFalse(artifacts[str(missing_path)]) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_workflow_metrics.py b/tests/test_workflow_metrics.py new file mode 100644 index 0000000..5f72479 --- /dev/null +++ b/tests/test_workflow_metrics.py @@ -0,0 +1,120 @@ +import pathlib +import tempfile +import unittest + +import pytest +pytest.importorskip("sqlalchemy") +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker + +pytest.importorskip("fastapi") +from fastapi.testclient import TestClient + +from server_api.auth import database as auth_database +from server_api.auth import models +from server_api.main import app as server_api_app + + +class WorkflowMetricsTests(unittest.TestCase): + def setUp(self): + self.temp_dir = tempfile.TemporaryDirectory() + self.db_path = pathlib.Path(self.temp_dir.name) / "workflow-metrics-test.db" + self.engine = create_engine( + f"sqlite:///{self.db_path}", connect_args={"check_same_thread": False} + ) + self.SessionLocal = sessionmaker( + autocommit=False, autoflush=False, bind=self.engine + ) + models.Base.metadata.create_all(bind=self.engine) + + def override_get_db(): + db = self.SessionLocal() + try: + yield db + finally: + db.close() + + server_api_app.dependency_overrides[auth_database.get_db] = override_get_db + self.client = TestClient(server_api_app) + + def tearDown(self): + server_api_app.dependency_overrides.clear() + self.engine.dispose() + self.temp_dir.cleanup() + + def _current_workflow_id(self) -> int: + response = self.client.get("/api/workflows/current") + self.assertEqual(response.status_code, 200) + return response.json()["workflow"]["id"] + + def test_metrics_empty_workflow_returns_stable_shape(self): + workflow_id = self._current_workflow_id() + response = self.client.get(f"/api/workflows/{workflow_id}/metrics") + self.assertEqual(response.status_code, 200) + payload = response.json() + metrics = payload["metrics"] + self.assertEqual(payload["workflow_id"], workflow_id) + self.assertIn("event_counts", metrics) + self.assertIn("decision_metrics", metrics) + self.assertIn("stage_transition_counts", metrics) + self.assertIn("timestamps", metrics) + self.assertIn("total_events", metrics) + + def test_metrics_aggregates_events_and_decisions(self): + workflow_id = self._current_workflow_id() + self.client.patch( + f"/api/workflows/{workflow_id}", + json={"stage": "visualization"}, + ) + self.client.patch( + f"/api/workflows/{workflow_id}", + json={"stage": "inference"}, + ) + + events = [ + { + "actor": "system", + "event_type": "inference.started", + "stage": "inference", + "summary": "Inference started", + }, + { + "actor": "agent", + "event_type": "agent.proposal_created", + "stage": "proofreading", + "summary": "Proposal created", + "approval_status": "pending", + }, + { + "actor": "user", + "event_type": "agent.proposal_approved", + "stage": "retraining_staged", + "summary": "Proposal approved", + }, + { + "actor": "user", + "event_type": "agent.proposal_rejected", + "stage": "proofreading", + "summary": "Proposal rejected", + }, + ] + for event in events: + response = self.client.post(f"/api/workflows/{workflow_id}/events", json=event) + self.assertEqual(response.status_code, 200) + + response = self.client.get(f"/api/workflows/{workflow_id}/metrics") + self.assertEqual(response.status_code, 200) + metrics = response.json()["metrics"] + + self.assertEqual(metrics["event_counts"]["inference.started"], 1) + self.assertEqual(metrics["event_counts"]["agent.proposal_created"], 1) + self.assertEqual(metrics["decision_metrics"]["approvals"], 1) + self.assertEqual(metrics["decision_metrics"]["rejections"], 1) + self.assertEqual(metrics["decision_metrics"]["total_decisions"], 2) + self.assertGreaterEqual(metrics["decision_metrics"]["approval_rate"], 0.0) + self.assertGreaterEqual(metrics["decision_metrics"]["rejection_rate"], 0.0) + self.assertGreaterEqual(metrics["total_events"], 5) # includes workflow.created + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_workflow_routes.py b/tests/test_workflow_routes.py index 6c77f47..da48d11 100644 --- a/tests/test_workflow_routes.py +++ b/tests/test_workflow_routes.py @@ -3,11 +3,15 @@ import unittest import numpy as np -import tifffile -from fastapi.testclient import TestClient +import pytest +pytest.importorskip("sqlalchemy") from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker +pytest.importorskip("fastapi") +tifffile = pytest.importorskip("tifffile") +from fastapi.testclient import TestClient + from server_api.auth import database as auth_database from server_api.auth import models import server_api.ehtool.router as ehtool_router_module @@ -137,6 +141,83 @@ def test_agent_action_approve_and_reject_flow(self): self.assertIn("agent.proposal_rejected", event_types) self.assertIn("retraining.staged", event_types) + def test_hotspots_and_impact_preview(self): + workflow, _ = self._current_workflow() + workflow_id = workflow["id"] + patch_response = self.client.patch( + f"/api/workflows/{workflow_id}", + json={"stage": "proofreading"}, + ) + self.assertEqual(patch_response.status_code, 200) + + self.client.post( + f"/api/workflows/{workflow_id}/events", + json={ + "actor": "system", + "event_type": "inference.failed", + "stage": "inference", + "summary": "Inference failed on z:12", + "payload": {"region_id": "z:12"}, + }, + ) + self.client.post( + f"/api/workflows/{workflow_id}/events", + json={ + "actor": "user", + "event_type": "proofreading.instance_classified", + "stage": "proofreading", + "summary": "Classified uncertain instances.", + "payload": { + "region_id": "z:12", + "classification": "incorrect", + "instance_ids": [101, 202], + }, + }, + ) + self.client.post( + f"/api/workflows/{workflow_id}/events", + json={ + "actor": "user", + "event_type": "proofreading.mask_saved", + "stage": "proofreading", + "summary": "Saved corrected mask for z:12", + "payload": {"region_id": "z:12", "instance_id": 101}, + }, + ) + self.client.post( + f"/api/workflows/{workflow_id}/events", + json={ + "actor": "system", + "event_type": "proofreading.masks_exported", + "stage": "proofreading", + "summary": "Exported corrected masks.", + "payload": {"written_path": "/tmp/corrected-z12.tif"}, + }, + ) + + hotspot_response = self.client.get(f"/api/workflows/{workflow_id}/hotspots") + self.assertEqual(hotspot_response.status_code, 200) + hotspot_payload = hotspot_response.json() + self.assertEqual(hotspot_payload["workflow_id"], workflow_id) + self.assertGreaterEqual(len(hotspot_payload["hotspots"]), 1) + self.assertEqual(hotspot_payload["hotspots"][0]["region_key"], "z:12") + self.assertIn( + hotspot_payload["hotspots"][0]["severity"], + {"low", "medium", "high"}, + ) + + impact_response = self.client.get( + f"/api/workflows/{workflow_id}/impact-preview" + ) + self.assertEqual(impact_response.status_code, 200) + impact_payload = impact_response.json() + self.assertTrue(impact_payload["can_stage_retraining"]) + self.assertEqual( + impact_payload["corrected_mask_path"], "/tmp/corrected-z12.tif" + ) + self.assertIn(impact_payload["confidence"], {"low", "medium", "high"}) + self.assertIn("proofreading_mask_saved", impact_payload["signals"]) + def test_ehtool_load_classify_save_and_export_append_workflow_events(self): workflow, _ = self._current_workflow() workflow_id = workflow["id"] diff --git a/tests/test_workflow_spine_smoke.py b/tests/test_workflow_spine_smoke.py new file mode 100644 index 0000000..d46ace5 --- /dev/null +++ b/tests/test_workflow_spine_smoke.py @@ -0,0 +1,103 @@ +import pathlib +import tempfile +import unittest + +import pytest +pytest.importorskip("sqlalchemy") +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker + +pytest.importorskip("fastapi") +from fastapi.testclient import TestClient + +from server_api.auth import database as auth_database +from server_api.auth import models +from server_api.main import app as server_api_app + + +class WorkflowSpineSmokeTests(unittest.TestCase): + def setUp(self): + self.temp_dir = tempfile.TemporaryDirectory() + self.db_path = pathlib.Path(self.temp_dir.name) / "workflow-spine-smoke.db" + self.engine = create_engine( + f"sqlite:///{self.db_path}", connect_args={"check_same_thread": False} + ) + self.SessionLocal = sessionmaker( + autocommit=False, autoflush=False, bind=self.engine + ) + models.Base.metadata.create_all(bind=self.engine) + + def override_get_db(): + db = self.SessionLocal() + try: + yield db + finally: + db.close() + + server_api_app.dependency_overrides[auth_database.get_db] = override_get_db + self.client = TestClient(server_api_app) + + def tearDown(self): + server_api_app.dependency_overrides.clear() + self.engine.dispose() + self.temp_dir.cleanup() + + def test_spine_loop_load_proofread_stage_and_export_evidence(self): + workflow_response = self.client.get("/api/workflows/current") + self.assertEqual(workflow_response.status_code, 200) + workflow = workflow_response.json()["workflow"] + workflow_id = workflow["id"] + + export_path = pathlib.Path(self.temp_dir.name) / "corrected-mask.tif" + export_path.write_text("mask", encoding="utf-8") + + self.client.patch( + f"/api/workflows/{workflow_id}", + json={"stage": "proofreading", "corrected_mask_path": str(export_path)}, + ) + self.client.post( + f"/api/workflows/{workflow_id}/events", + json={ + "actor": "system", + "event_type": "proofreading.masks_exported", + "stage": "proofreading", + "summary": "Corrected masks exported.", + "payload": {"written_path": str(export_path), "region_id": "z:12"}, + }, + ) + + query_response = self.client.post( + f"/api/workflows/{workflow_id}/agent/query", + json={"query": "stage corrected masks for retraining"}, + ) + self.assertEqual(query_response.status_code, 200) + proposals = query_response.json()["proposals"] + self.assertEqual(len(proposals), 1) + proposal_id = proposals[0]["id"] + + approve_response = self.client.post( + f"/api/workflows/{workflow_id}/agent-actions/{proposal_id}/approve" + ) + self.assertEqual(approve_response.status_code, 200) + self.assertEqual( + approve_response.json()["workflow"]["stage"], "retraining_staged" + ) + + hotspots_response = self.client.get(f"/api/workflows/{workflow_id}/hotspots") + impact_response = self.client.get(f"/api/workflows/{workflow_id}/impact-preview") + metrics_response = self.client.get(f"/api/workflows/{workflow_id}/metrics") + bundle_response = self.client.post( + f"/api/workflows/{workflow_id}/export-bundle" + ) + + self.assertEqual(hotspots_response.status_code, 200) + self.assertEqual(impact_response.status_code, 200) + self.assertEqual(metrics_response.status_code, 200) + self.assertEqual(bundle_response.status_code, 200) + self.assertEqual( + bundle_response.json()["schema_version"], "workflow-export-bundle/v1" + ) + + +if __name__ == "__main__": + unittest.main() From b2b6ca20e0043e5321f2301859559b644bfee11a Mon Sep 17 00:00:00 2001 From: Adam Gohain Date: Mon, 13 Apr 2026 16:55:13 -0400 Subject: [PATCH 3/6] feat(chat): add request lifecycle observability logs (cherry picked from commit 74c0d3ec82100d05bb60329feddeb1d56a4d0993) --- server_api/chatbot/logging_utils.py | 29 ++++ server_api/main.py | 202 ++++++++++++++++++++++++++-- tests/test_chat_logging_fields.py | 43 ++++++ 3 files changed, 263 insertions(+), 11 deletions(-) create mode 100644 server_api/chatbot/logging_utils.py create mode 100644 tests/test_chat_logging_fields.py diff --git a/server_api/chatbot/logging_utils.py b/server_api/chatbot/logging_utils.py new file mode 100644 index 0000000..41dcdfe --- /dev/null +++ b/server_api/chatbot/logging_utils.py @@ -0,0 +1,29 @@ +import logging +import time +import uuid +from typing import Any, Optional + +logger = logging.getLogger(__name__) + + +def request_id_from_request(request: Any) -> str: + return request.headers.get("x-request-id") or str(uuid.uuid4()) + + +def log_request_summary( + *, + request_id: str, + endpoint: str, + start_time: float, + status: str, + error_type: Optional[str] = None, +) -> None: + latency_ms = round((time.perf_counter() - start_time) * 1000, 2) + logger.info( + "request_summary request_id=%s endpoint=%s latency_ms=%s status=%s error_type=%s", + request_id, + endpoint, + latency_ms, + status, + error_type or "none", + ) diff --git a/server_api/main.py b/server_api/main.py index 8c33c79..736ba3a 100644 --- a/server_api/main.py +++ b/server_api/main.py @@ -3,6 +3,9 @@ import re import shutil import tempfile +import traceback +import time +from concurrent.futures import ThreadPoolExecutor, TimeoutError from typing import List, Optional from urllib.parse import urlsplit, urlunsplit @@ -22,6 +25,10 @@ from server_api.auth.database import get_db from server_api.auth.router import get_current_user from server_api.ehtool import router as ehtool_router +from server_api.chatbot.logging_utils import ( + log_request_summary, + request_id_from_request, +) from server_api.workflows import router as workflow_router from server_api.workflows.service import ( append_event_for_workflow_if_present, @@ -35,12 +42,25 @@ # Chatbot is optional; keep the server running if dependencies or model endpoints # are unavailable. We initialize lazily on demand. try: - from server_api.chatbot.chatbot import build_chain, build_helper_chain + from server_api.chatbot.chatbot import ( + build_chain, + build_helper_chain, + _format_admin_llm_error, + ) except Exception as exc: # pragma: no cover - exercised indirectly via endpoints build_chain = None build_helper_chain = None _chatbot_error = exc + def _format_admin_llm_error(error): + return ( + "The AI assistant could not connect to its configured language model. " + "Please contact your system administrator with this error: " + f"{str(error).strip() or error.__class__.__name__}" + ) +else: + _chatbot_error = None + chain = None _reset_search = None @@ -57,20 +77,57 @@ def _ensure_chatbot(): global chain, _reset_search, _chatbot_error if chain is not None and _reset_search is not None: + print("[CHATBOT] Reusing initialized main chat chain") return True if build_chain is None: + print("[CHATBOT] build_chain is unavailable; chatbot backend not configured") return False + start_time = time.perf_counter() + print("[CHATBOT] Initializing main chat chain...") try: chain, _reset_search = build_chain() _chatbot_error = None + elapsed = time.perf_counter() - start_time + print(f"[CHATBOT] Main chat chain ready in {elapsed:.2f}s") return True except Exception as exc: # pragma: no cover - runtime config issue chain = None _reset_search = None _chatbot_error = exc + print( + "[CHATBOT] Failed to initialize LLM backend: " + f"{exc.__class__.__name__}: {exc!r}" + ) + traceback.print_exc() return False +def _llm_unavailable_detail(error): + return { + "user_message": _format_admin_llm_error(error), + "error": str(error), + "reason": "llm_unavailable", + } + + +def _invoke_with_progress(invoke_fn, *, label: str, request_id: str, poll_seconds=5.0): + start_time = time.perf_counter() + with ThreadPoolExecutor(max_workers=1) as executor: + future = executor.submit(invoke_fn) + while True: + try: + result = future.result(timeout=poll_seconds) + elapsed = time.perf_counter() - start_time + print( + f"[CHATBOT][{request_id}] {label} completed in {elapsed:.2f}s" + ) + return result + except TimeoutError: + elapsed = time.perf_counter() - start_time + print( + f"[CHATBOT][{request_id}] {label} still running " + f"after {elapsed:.2f}s..." + ) REACT_APP_SERVER_PROTOCOL = "http" REACT_APP_SERVER_URL = "localhost:4243" @@ -834,16 +891,31 @@ async def chat_query( user: models.User = Depends(get_current_user), db: Session = Depends(get_db), ): + global _chatbot_error + request_id = request_id_from_request(req).replace("-", "")[:8] + request_start = time.perf_counter() + print(f"[CHATBOT][{request_id}] Incoming /chat/query request from user={user.id}") if not _ensure_chatbot(): - detail = "Chatbot is not configured" - if "_chatbot_error" in globals(): - detail = f"{detail}: {_chatbot_error}" - raise HTTPException(status_code=503, detail=detail) + print(f"[CHATBOT][{request_id}] Main chain unavailable before request") + log_request_summary( + request_id=request_id, + endpoint="/chat/query", + start_time=request_start, + status="error", + error_type="llm_unavailable", + ) + raise HTTPException( + status_code=503, detail=_llm_unavailable_detail(_chatbot_error) + ) body = await req.json() query = body.get("query") convo_id = body.get("conversationId") if not isinstance(query, str) or not query.strip(): raise HTTPException(status_code=400, detail="Query must be a non-empty string.") + print( + f"[CHATBOT][{request_id}] Parsed request: convo_id={convo_id}, " + f"query_len={len(query.strip())}" + ) # Auto-create a conversation if none supplied if not convo_id: @@ -866,13 +938,48 @@ async def chat_query( # Rebuild in-memory history from DB when switching conversations _load_history_for_convo(convo_id, db) + print( + f"[CHATBOT][{request_id}] Loaded history messages={len(_chat_history)} " + f"for convo_id={convo_id}" + ) if _reset_search is not None: _reset_search() + print(f"[CHATBOT][{request_id}] Reset documentation search call counter") all_messages = _chat_history + [{"role": "user", "content": query}] - result = chain.invoke({"messages": all_messages}) + print( + f"[CHATBOT][{request_id}] Invoking main chain with " + f"{len(all_messages)} message(s)" + ) + try: + result = _invoke_with_progress( + lambda: chain.invoke({"messages": all_messages}), + label="main chain invoke", + request_id=request_id, + ) + except Exception as exc: + _chatbot_error = exc + print( + "[CHATBOT] LLM request failed: " + f"{exc.__class__.__name__}: {exc!r}" + ) + traceback.print_exc() + log_request_summary( + request_id=request_id, + endpoint="/chat/query", + start_time=request_start, + status="error", + error_type=exc.__class__.__name__, + ) + raise HTTPException( + status_code=503, detail=_llm_unavailable_detail(exc) + ) from exc messages = result.get("messages", []) response = messages[-1].content if messages else "No response generated" + print( + f"[CHATBOT][{request_id}] Chain returned messages={len(messages)}, " + f"response_len={len(response) if isinstance(response, str) else 0}" + ) # Persist to DB db.add(models.ChatMessage(conversation_id=convo_id, role="user", content=query)) @@ -889,6 +996,17 @@ async def chat_query( # Update in-memory history _chat_history.append({"role": "user", "content": query}) _chat_history.append({"role": "assistant", "content": response}) + total_elapsed = time.perf_counter() - request_start + print( + f"[CHATBOT][{request_id}] /chat/query completed in {total_elapsed:.2f}s " + f"(convo_id={convo_id})" + ) + log_request_summary( + request_id=request_id, + endpoint="/chat/query", + start_time=request_start, + status="ok", + ) return {"response": response, "conversationId": convo_id} @@ -929,21 +1047,36 @@ def _ensure_helper_chat(task_key: str): """Lazily build a helper agent for *task_key*, reusing it on subsequent calls.""" global _chatbot_error if task_key in _helper_chains: + print(f"[CHATBOT] Reusing helper chain for task_key={task_key}") return True if build_helper_chain is None: + print("[CHATBOT] build_helper_chain is unavailable") return False + start_time = time.perf_counter() + print(f"[CHATBOT] Initializing helper chain for task_key={task_key}...") try: agent, reset_fn = build_helper_chain() _helper_chains[task_key] = (agent, reset_fn) _helper_histories[task_key] = [] + elapsed = time.perf_counter() - start_time + print( + f"[CHATBOT] Helper chain ready for task_key={task_key} in {elapsed:.2f}s" + ) return True except Exception as exc: _chatbot_error = exc + print( + "[CHATBOT] Failed to initialize helper LLM backend: " + f"{exc.__class__.__name__}: {exc!r}" + ) + traceback.print_exc() return False @app.post("/chat/helper/query") async def chat_helper_query(req: Request): + request_id = request_id_from_request(req).replace("-", "")[:8] + request_start = time.perf_counter() body = await req.json() task_key = body.get("taskKey") query = body.get("query") @@ -953,12 +1086,23 @@ async def chat_helper_query(req: Request): raise HTTPException(status_code=400, detail="taskKey is required") if not isinstance(query, str) or not query.strip(): raise HTTPException(status_code=400, detail="query must be a non-empty string.") + print( + f"[CHATBOT][{request_id}] Incoming /chat/helper/query " + f"task_key={task_key} query_len={len(query.strip())}" + ) if not _ensure_helper_chat(task_key): - detail = "Helper chatbot is not configured" - if "_chatbot_error" in globals(): - detail = f"{detail}: {_chatbot_error}" - raise HTTPException(status_code=503, detail=detail) + print(f"[CHATBOT][{request_id}] Helper chain unavailable for task_key={task_key}") + log_request_summary( + request_id=request_id, + endpoint="/chat/helper/query", + start_time=request_start, + status="error", + error_type="llm_unavailable", + ) + raise HTTPException( + status_code=503, detail=_llm_unavailable_detail(_chatbot_error) + ) agent, reset_fn = _helper_chains[task_key] history = _helper_histories[task_key] @@ -970,12 +1114,48 @@ async def chat_helper_query(req: Request): ) reset_fn() + print( + f"[CHATBOT][{request_id}] Helper history_len={len(history)}; invoking helper " + f"with context_len={len(user_content)}" + ) all_messages = history + [{"role": "user", "content": user_content}] - result = agent.invoke({"messages": all_messages}) + try: + result = _invoke_with_progress( + lambda: agent.invoke({"messages": all_messages}), + label="helper chain invoke", + request_id=request_id, + ) + except Exception as exc: + print( + "[CHATBOT] Helper LLM request failed: " + f"{exc.__class__.__name__}: {exc!r}" + ) + traceback.print_exc() + log_request_summary( + request_id=request_id, + endpoint="/chat/helper/query", + start_time=request_start, + status="error", + error_type=exc.__class__.__name__, + ) + raise HTTPException( + status_code=503, detail=_llm_unavailable_detail(exc) + ) from exc messages = result.get("messages", []) response = messages[-1].content if messages else "No response generated" history.append({"role": "user", "content": user_content}) history.append({"role": "assistant", "content": response}) + total_elapsed = time.perf_counter() - request_start + print( + f"[CHATBOT][{request_id}] /chat/helper/query completed in {total_elapsed:.2f}s " + f"response_len={len(response) if isinstance(response, str) else 0}" + ) + log_request_summary( + request_id=request_id, + endpoint="/chat/helper/query", + start_time=request_start, + status="ok", + ) return {"response": response} diff --git a/tests/test_chat_logging_fields.py b/tests/test_chat_logging_fields.py new file mode 100644 index 0000000..71c02ee --- /dev/null +++ b/tests/test_chat_logging_fields.py @@ -0,0 +1,43 @@ +import logging +import time + +from server_api.chatbot import logging_utils + + +def test_standardized_summary_log_success_fields(caplog): + caplog.set_level(logging.INFO) + + logging_utils.log_request_summary( + request_id="req-123", + endpoint="/chat/query", + start_time=time.perf_counter() - 0.01, + status="ok", + ) + + message = caplog.records[-1].getMessage() + assert "request_id=req-123" in message + assert "endpoint=/chat/query" in message + assert "latency_ms=" in message + assert "status=ok" in message + assert "error_type=none" in message + + +def test_standardized_summary_log_error_fields_and_no_payload_leak(caplog): + caplog.set_level(logging.INFO) + sensitive_query = "my secret token is abc123" + + logging_utils.log_request_summary( + request_id="req-456", + endpoint="/chat/helper/query", + start_time=time.perf_counter() - 0.02, + status="error", + error_type="HTTPException", + ) + + message = caplog.records[-1].getMessage() + assert "request_id=req-456" in message + assert "endpoint=/chat/helper/query" in message + assert "latency_ms=" in message + assert "status=error" in message + assert "error_type=HTTPException" in message + assert sensitive_query not in message From 2b3f140a2f1058cdae48246fa49630d27364fa0b Mon Sep 17 00:00:00 2001 From: Adam Gohain Date: Sun, 12 Apr 2026 20:58:12 -0400 Subject: [PATCH 4/6] Wire workflow spine into frontend (cherry picked from commit fa97e5e0b1c8bf498be6618ab697af9d4fe3cf76) --- client/package.json | 1 + client/src/App.js | 13 +- client/src/api.js | 93 +++++++++- client/src/components/Chatbot.js | 39 ++++ client/src/components/WorkflowTimeline.js | 142 +++++++++++++++ .../src/components/WorkflowTimeline.test.js | 138 +++++++++++++++ client/src/contexts/WorkflowContext.js | 166 ++++++++++++++++++ client/src/contexts/WorkflowContext.test.js | 95 ++++++++++ client/src/views/EHTool.js | 2 + client/src/views/MaskProofreading.js | 4 + client/src/views/MaskProofreading.test.js | 26 +++ client/src/views/ModelInference.js | 59 ++++++- client/src/views/ModelTraining.js | 3 + client/src/views/Views.js | 13 ++ client/src/views/Views.test.js | 12 +- client/src/views/Visualization.js | 3 + client/src/views/ehtool/DetectionWorkflow.js | 65 ++++++- 17 files changed, 863 insertions(+), 11 deletions(-) create mode 100644 client/src/components/WorkflowTimeline.js create mode 100644 client/src/components/WorkflowTimeline.test.js create mode 100644 client/src/contexts/WorkflowContext.js create mode 100644 client/src/contexts/WorkflowContext.test.js create mode 100644 client/src/views/MaskProofreading.test.js diff --git a/client/package.json b/client/package.json index 7a08018..0107480 100644 --- a/client/package.json +++ b/client/package.json @@ -28,6 +28,7 @@ }, "scripts": { "start": "react-scripts start", + "test": "react-scripts test", "build": "cross-env CI=false react-scripts build", "electron": "electron ." }, diff --git a/client/src/App.js b/client/src/App.js index 6d14652..9c26a2c 100644 --- a/client/src/App.js +++ b/client/src/App.js @@ -3,6 +3,7 @@ import "./App.css"; import Views from "./views/Views"; import { AppContext, ContextWrapper } from "./contexts/GlobalContext"; import { YamlContextWrapper } from "./contexts/YamlContext"; +import { WorkflowProvider } from "./contexts/WorkflowContext"; function CacheBootstrapper({ children }) { const { resetFileState } = useContext(AppContext); @@ -38,11 +39,13 @@ function App() { return ( - -
- -
-
+ + +
+ +
+
+
); diff --git a/client/src/api.js b/client/src/api.js index fffe5f2..18df605 100644 --- a/client/src/api.js +++ b/client/src/api.js @@ -52,7 +52,7 @@ const getErrorDetailMessage = (detail) => { return String(detail); }; -export async function getNeuroglancerViewer(image, label, scales) { +export async function getNeuroglancerViewer(image, label, scales, workflowId = null) { try { const url = `${BASE_URL}/neuroglancer`; if (hasBrowserFile(image)) { @@ -70,6 +70,9 @@ export async function getNeuroglancerViewer(image, label, scales) { ); } formData.append("scales", JSON.stringify(scales)); + if (workflowId) { + formData.append("workflow_id", String(workflowId)); + } const res = await axios.post(url, formData); return res.data; } @@ -78,6 +81,7 @@ export async function getNeuroglancerViewer(image, label, scales) { image: buildFilePath(image), label: buildFilePath(label), scales, + workflow_id: workflowId, }); const res = await axios.post(url, data); return res.data; @@ -141,6 +145,7 @@ export async function startModelTraining( logPath, outputPath, configOriginPath = "", + workflowId = null, ) { try { console.log("[API] ===== Starting Training Configuration ====="); @@ -178,6 +183,7 @@ export async function startModelTraining( outputPath, // TensorBoard will use this instead trainingConfig: configToSend, configOriginPath, + workflow_id: workflowId, }); console.log("[API] Request payload size:", data.length, "bytes"); @@ -232,6 +238,7 @@ export async function startModelInference( outputPath, checkpointPath, configOriginPath = "", + workflowId = null, ) { console.log("\n========== API.JS: START_MODEL_INFERENCE CALLED =========="); console.log("[API] Function arguments:"); @@ -293,6 +300,7 @@ export async function startModelInference( outputPath, inferenceConfig: configToSend, configOriginPath, + workflow_id: workflowId, }; console.log("[API] Payload structure:"); @@ -462,3 +470,86 @@ export async function getConfigPresetContent(path) { export async function getModelArchitectures() { return makeApiRequest("pytc/architectures", "get"); } + +// ── Workflow spine ─────────────────────────────────────────────────────────── + +export async function getCurrentWorkflow() { + try { + const res = await apiClient.get("/api/workflows/current"); + return res.data; + } catch (error) { + handleError(error); + } +} + +export async function updateWorkflow(workflowId, patch) { + try { + const res = await apiClient.patch(`/api/workflows/${workflowId}`, patch); + return res.data; + } catch (error) { + handleError(error); + } +} + +export async function listWorkflowEvents(workflowId) { + try { + const res = await apiClient.get(`/api/workflows/${workflowId}/events`); + return res.data; + } catch (error) { + handleError(error); + } +} + +export async function appendWorkflowEvent(workflowId, event) { + try { + const res = await apiClient.post(`/api/workflows/${workflowId}/events`, event); + return res.data; + } catch (error) { + handleError(error); + } +} + +export async function createAgentAction(workflowId, action) { + try { + const res = await apiClient.post( + `/api/workflows/${workflowId}/agent-actions`, + action, + ); + return res.data; + } catch (error) { + handleError(error); + } +} + +export async function approveAgentAction(workflowId, eventId) { + try { + const res = await apiClient.post( + `/api/workflows/${workflowId}/agent-actions/${eventId}/approve`, + ); + return res.data; + } catch (error) { + handleError(error); + } +} + +export async function rejectAgentAction(workflowId, eventId) { + try { + const res = await apiClient.post( + `/api/workflows/${workflowId}/agent-actions/${eventId}/reject`, + ); + return res.data; + } catch (error) { + handleError(error); + } +} + +export async function queryWorkflowAgent(workflowId, query) { + try { + const res = await apiClient.post(`/api/workflows/${workflowId}/agent/query`, { + query, + }); + return res.data; + } catch (error) { + handleError(error); + } +} diff --git a/client/src/components/Chatbot.js b/client/src/components/Chatbot.js index 1787503..38de04e 100644 --- a/client/src/components/Chatbot.js +++ b/client/src/components/Chatbot.js @@ -27,6 +27,8 @@ import { } from "../api"; import ReactMarkdown from "react-markdown"; import remarkGfm from "remark-gfm"; +import WorkflowTimeline from "./WorkflowTimeline"; +import { useWorkflow } from "../contexts/WorkflowContext"; const { TextArea } = Input; const { Text } = Typography; @@ -41,6 +43,20 @@ const GREETING = { const truncate = (str, n = 50) => str.length > n ? str.slice(0, n).trimEnd() + "…" : str; +const WORKFLOW_QUERY_TERMS = [ + "workflow", + "next", + "stage", + "retrain", + "training", + "corrected", + "proofread", + "mask", + "inference", + "visualize", + "evaluate", +]; + /* ═══════════════════════════════════════════════════════════════════════════ */ function Chatbot({ onClose }) { @@ -54,6 +70,15 @@ function Chatbot({ onClose }) { const [isLoadingConvo, setIsLoadingConvo] = useState(false); const lastMessageRef = useRef(null); + const workflowContext = useWorkflow(); + + const shouldUseWorkflowAgent = (query) => { + if (!workflowContext?.workflow?.id || !workflowContext?.queryAgent) { + return false; + } + const lower = query.toLowerCase(); + return WORKFLOW_QUERY_TERMS.some((term) => lower.includes(term)); + }; /* ── scroll ────────────────────────────────────────────────────────────── */ const scrollToBottom = useCallback(() => { @@ -119,6 +144,18 @@ function Chatbot({ onClose }) { setMessages((prev) => [...prev, { role: "user", content: query }]); setIsSending(true); try { + if (shouldUseWorkflowAgent(query)) { + const data = await workflowContext.queryAgent(query); + const response = + data?.response || + "I could not inspect the workflow state for that request."; + setMessages((prev) => [ + ...prev, + { role: "assistant", content: response }, + ]); + return; + } + const data = await queryChatBot(query, activeConvoId); const response = data?.response || "Sorry, I could not generate a response."; @@ -418,6 +455,8 @@ function Chatbot({ onClose }) { + + {/* messages */}
{isLoadingConvo ? ( diff --git a/client/src/components/WorkflowTimeline.js b/client/src/components/WorkflowTimeline.js new file mode 100644 index 0000000..b667687 --- /dev/null +++ b/client/src/components/WorkflowTimeline.js @@ -0,0 +1,142 @@ +import React from "react"; +import { Button, Empty, List, Space, Tag, Typography } from "antd"; +import { CheckOutlined, CloseOutlined } from "@ant-design/icons"; +import { useWorkflow } from "../contexts/WorkflowContext"; + +const { Text } = Typography; + +const STAGE_LABELS = { + setup: "Setup", + visualization: "Visualization", + inference: "Inference", + proofreading: "Proofreading", + retraining_staged: "Retraining staged", + evaluation: "Evaluation", +}; + +function formatEventTime(value) { + if (!value) return ""; + const date = new Date(value); + if (Number.isNaN(date.getTime())) return ""; + return date.toLocaleTimeString([], { hour: "2-digit", minute: "2-digit" }); +} + +function WorkflowTimeline({ limit = 8 }) { + const workflowContext = useWorkflow(); + const workflow = workflowContext?.workflow; + const events = workflowContext?.events || []; + const approveAgentAction = workflowContext?.approveAgentAction; + const rejectAgentAction = workflowContext?.rejectAgentAction; + + if (!workflowContext) return null; + + const visibleEvents = events.slice(-limit).reverse(); + const stageLabel = STAGE_LABELS[workflow?.stage] || workflow?.stage || "Loading"; + + return ( +
+ + + + Workflow + + + {stageLabel} + + + + {workflow?.title || "Segmentation Workflow"} + + + + {visibleEvents.length === 0 ? ( + + ) : ( + { + const isPendingProposal = + event.event_type === "agent.proposal_created" && + event.approval_status === "pending"; + return ( + } + onClick={() => approveAgentAction?.(event.id)} + > + Approve + , + , + ] + : [] + } + > + + {event.summary} + {event.approval_status !== "not_required" && ( + + {event.approval_status} + + )} + + } + description={ + + + {event.actor} + + + {event.event_type} + + + {formatEventTime(event.created_at)} + + + } + /> + + ); + }} + /> + )} +
+ ); +} + +export default WorkflowTimeline; diff --git a/client/src/components/WorkflowTimeline.test.js b/client/src/components/WorkflowTimeline.test.js new file mode 100644 index 0000000..f3a40c3 --- /dev/null +++ b/client/src/components/WorkflowTimeline.test.js @@ -0,0 +1,138 @@ +import React from "react"; +import { fireEvent, render, screen } from "@testing-library/react"; + +import WorkflowTimeline from "./WorkflowTimeline"; +import { WorkflowContext } from "../contexts/WorkflowContext"; + +jest.mock("antd", () => { + const React = require("react"); + const List = ({ dataSource = [], renderItem }) => ( +
+ {dataSource.map((item, index) => ( + + {renderItem(item, index)} + + ))} +
+ ); + List.Item = ({ children, actions = [] }) => ( +
+ {children} + {actions.map((action, index) => ( + {action} + ))} +
+ ); + List.Item.Meta = ({ title, description }) => ( +
+
{title}
+
{description}
+
+ ); + + const Empty = ({ description }) =>
{description}
; + Empty.PRESENTED_IMAGE_SIMPLE = "simple"; + + return { + Button: ({ children, icon, ...props }) => ( + + ), + Empty, + List, + Space: ({ children }) => {children}, + Tag: ({ children }) => {children}, + Typography: { + Text: ({ children }) => {children}, + }, + }; +}); + +jest.mock("@ant-design/icons", () => { + const Icon = () => ; + return { + CheckOutlined: Icon, + CloseOutlined: Icon, + }; +}); + +jest.mock("../contexts/WorkflowContext", () => { + const React = require("react"); + const WorkflowContext = React.createContext(null); + return { + WorkflowContext, + useWorkflow: () => React.useContext(WorkflowContext), + }; +}); + +const workflow = { + id: 1, + title: "Test Workflow", + stage: "proofreading", +}; + +function renderTimeline(overrides = {}) { + const value = { + workflow, + events: [], + approveAgentAction: jest.fn(), + rejectAgentAction: jest.fn(), + ...overrides, + }; + render( + + + , + ); + return value; +} + +describe("WorkflowTimeline", () => { + it("renders workflow stage and chronological evidence", () => { + renderTimeline({ + events: [ + { + id: 1, + actor: "user", + event_type: "dataset.loaded", + stage: "proofreading", + summary: "Loaded dataset.", + approval_status: "not_required", + created_at: "2026-04-12T12:00:00Z", + }, + ], + }); + + expect(screen.getByText("Proofreading")).toBeTruthy(); + expect(screen.getByText("Loaded dataset.")).toBeTruthy(); + expect(screen.getByText("dataset.loaded")).toBeTruthy(); + }); + + it("exposes approve and reject controls for pending agent proposals", () => { + const approveAgentAction = jest.fn(); + const rejectAgentAction = jest.fn(); + renderTimeline({ + approveAgentAction, + rejectAgentAction, + events: [ + { + id: 7, + actor: "agent", + event_type: "agent.proposal_created", + stage: "proofreading", + summary: "Stage corrected masks.", + approval_status: "pending", + created_at: "2026-04-12T12:00:00Z", + }, + ], + }); + + fireEvent.click(screen.getByText("Approve")); + fireEvent.click(screen.getByText("Reject")); + + expect(approveAgentAction).toHaveBeenCalledWith(7); + expect(rejectAgentAction).toHaveBeenCalledWith(7); + }); +}); diff --git a/client/src/contexts/WorkflowContext.js b/client/src/contexts/WorkflowContext.js new file mode 100644 index 0000000..bb267f5 --- /dev/null +++ b/client/src/contexts/WorkflowContext.js @@ -0,0 +1,166 @@ +import React, { + createContext, + useCallback, + useContext, + useEffect, + useState, +} from "react"; +import { message } from "antd"; +import { + appendWorkflowEvent, + approveAgentAction as approveAgentActionApi, + createAgentAction, + getCurrentWorkflow, + listWorkflowEvents, + queryWorkflowAgent, + rejectAgentAction as rejectAgentActionApi, + updateWorkflow as updateWorkflowApi, +} from "../api"; +import { AppContext } from "./GlobalContext"; + +export const WorkflowContext = createContext(null); + +export function useWorkflow() { + return useContext(WorkflowContext); +} + +export function WorkflowProvider({ children }) { + const appContext = useContext(AppContext); + const [workflow, setWorkflow] = useState(null); + const [events, setEvents] = useState([]); + const [loading, setLoading] = useState(true); + const [lastClientEffects, setLastClientEffects] = useState(null); + + const refreshWorkflow = useCallback(async () => { + setLoading(true); + try { + const data = await getCurrentWorkflow(); + setWorkflow(data?.workflow || null); + setEvents(data?.events || []); + return data; + } catch (error) { + message.error("Failed to load workflow state."); + return null; + } finally { + setLoading(false); + } + }, []); + + const refreshEvents = useCallback(async () => { + if (!workflow?.id) return []; + const nextEvents = await listWorkflowEvents(workflow.id); + setEvents(nextEvents || []); + return nextEvents || []; + }, [workflow?.id]); + + useEffect(() => { + refreshWorkflow(); + }, [refreshWorkflow]); + + const updateWorkflow = useCallback( + async (patch) => { + if (!workflow?.id) return null; + const nextWorkflow = await updateWorkflowApi(workflow.id, patch); + setWorkflow(nextWorkflow); + return nextWorkflow; + }, + [workflow?.id], + ); + + const appendEvent = useCallback( + async (event) => { + if (!workflow?.id) return null; + const nextEvent = await appendWorkflowEvent(workflow.id, event); + setEvents((prev) => [...prev, nextEvent]); + return nextEvent; + }, + [workflow?.id], + ); + + const proposeAgentAction = useCallback( + async (action) => { + if (!workflow?.id) return null; + const proposal = await createAgentAction(workflow.id, action); + await refreshEvents(); + return proposal; + }, + [workflow?.id, refreshEvents], + ); + + const applyClientEffects = useCallback( + (effects) => { + if (!effects) return; + if (effects.set_training_label_path && appContext?.trainingState) { + appContext.trainingState.setInputLabel(effects.set_training_label_path); + } + setLastClientEffects(effects); + }, + [appContext], + ); + + const approveAgentAction = useCallback( + async (eventId) => { + if (!workflow?.id) return null; + const result = await approveAgentActionApi(workflow.id, eventId); + if (result?.workflow) { + setWorkflow(result.workflow); + } + applyClientEffects(result?.client_effects); + await refreshEvents(); + message.success("Agent proposal approved."); + return result; + }, + [workflow?.id, refreshEvents, applyClientEffects], + ); + + const rejectAgentAction = useCallback( + async (eventId) => { + if (!workflow?.id) return null; + const result = await rejectAgentActionApi(workflow.id, eventId); + await refreshEvents(); + message.info("Agent proposal rejected."); + return result; + }, + [workflow?.id, refreshEvents], + ); + + const queryAgent = useCallback( + async (query) => { + if (!workflow?.id) return null; + const result = await queryWorkflowAgent(workflow.id, query); + if (result?.proposals?.length) { + await refreshEvents(); + } + return result; + }, + [workflow?.id, refreshEvents], + ); + + const consumeClientEffects = useCallback(() => { + const effects = lastClientEffects; + setLastClientEffects(null); + return effects; + }, [lastClientEffects]); + + return ( + + {children} + + ); +} diff --git a/client/src/contexts/WorkflowContext.test.js b/client/src/contexts/WorkflowContext.test.js new file mode 100644 index 0000000..3bac823 --- /dev/null +++ b/client/src/contexts/WorkflowContext.test.js @@ -0,0 +1,95 @@ +import React from "react"; +import { fireEvent, render, screen, waitFor } from "@testing-library/react"; + +import { AppContext } from "./GlobalContext"; +import { WorkflowProvider, useWorkflow } from "./WorkflowContext"; +import { + approveAgentAction, + getCurrentWorkflow, + listWorkflowEvents, +} from "../api"; + +jest.mock("../api", () => ({ + approveAgentAction: jest.fn(), + appendWorkflowEvent: jest.fn(), + createAgentAction: jest.fn(), + getCurrentWorkflow: jest.fn(), + listWorkflowEvents: jest.fn(), + queryWorkflowAgent: jest.fn(), + rejectAgentAction: jest.fn(), + updateWorkflow: jest.fn(), +})); + +const baseWorkflow = { + id: 1, + title: "Segmentation Workflow", + stage: "setup", +}; + +function Probe() { + const workflowContext = useWorkflow(); + return ( +
+
{workflowContext.workflow?.stage || "loading"}
+
{workflowContext.events.map((event) => event.event_type).join(",")}
+ +
+ ); +} + +function renderProvider(appContextValue) { + render( + + + + + , + ); +} + +describe("WorkflowProvider", () => { + beforeEach(() => { + jest.clearAllMocks(); + getCurrentWorkflow.mockResolvedValue({ + workflow: baseWorkflow, + events: [{ id: 1, event_type: "workflow.created" }], + }); + listWorkflowEvents.mockResolvedValue([]); + }); + + it("loads the current workflow and events on startup", async () => { + renderProvider({ trainingState: { setInputLabel: jest.fn() } }); + + expect(await screen.findByText("setup")).toBeTruthy(); + expect(screen.getByText("workflow.created")).toBeTruthy(); + expect(getCurrentWorkflow).toHaveBeenCalledTimes(1); + }); + + it("applies client effects when an agent proposal is approved", async () => { + const setInputLabel = jest.fn(); + approveAgentAction.mockResolvedValue({ + workflow: { ...baseWorkflow, stage: "retraining_staged" }, + client_effects: { + navigate_to: "training", + set_training_label_path: "/tmp/corrected.tif", + }, + }); + + renderProvider({ trainingState: { setInputLabel } }); + await screen.findByText("setup"); + + fireEvent.click(screen.getByText("Approve proposal")); + + await waitFor(() => { + expect(setInputLabel).toHaveBeenCalledWith("/tmp/corrected.tif"); + }); + await waitFor(() => { + expect(screen.getByText("retraining_staged")).toBeTruthy(); + }); + }); +}); diff --git a/client/src/views/EHTool.js b/client/src/views/EHTool.js index 5d546fd..3cb38e3 100644 --- a/client/src/views/EHTool.js +++ b/client/src/views/EHTool.js @@ -13,6 +13,7 @@ function EHTool({ onSessionChange, refreshTrigger, savedSessionId, + workflowId, }) { // Initialize with saved session if available const [sessionId, setSessionId] = useState(savedSessionId || null); @@ -37,6 +38,7 @@ function EHTool({ diff --git a/client/src/views/MaskProofreading.js b/client/src/views/MaskProofreading.js index 1050a1b..b86bb28 100644 --- a/client/src/views/MaskProofreading.js +++ b/client/src/views/MaskProofreading.js @@ -1,15 +1,19 @@ import React, { useState } from "react"; import EHTool from "./EHTool"; +import { useWorkflow } from "../contexts/WorkflowContext"; function MaskProofreading() { const [ehToolSession, setEhToolSession] = useState(null); const [refreshTrigger, setRefreshTrigger] = useState(0); + const workflowContext = useWorkflow(); + const workflowId = workflowContext?.workflow?.id ?? null; return (
{ // This prop is now nominally used to trigger internal modal diff --git a/client/src/views/MaskProofreading.test.js b/client/src/views/MaskProofreading.test.js new file mode 100644 index 0000000..4fb7fa9 --- /dev/null +++ b/client/src/views/MaskProofreading.test.js @@ -0,0 +1,26 @@ +import React from "react"; +import { render, screen } from "@testing-library/react"; + +import MaskProofreading from "./MaskProofreading"; + +jest.mock("../EHTool", () => (props) => ( +
+)); + +jest.mock("../../contexts/WorkflowContext", () => ({ + useWorkflow: () => ({ workflow: { id: 42 } }), +})); + +describe("MaskProofreading", () => { + it("renders EHTool and passes the active workflow id", () => { + render(); + + const ehTool = screen.getByTestId("eh-tool"); + expect(ehTool).toBeTruthy(); + expect(ehTool.getAttribute("data-workflow-id")).toBe("42"); + }); +}); diff --git a/client/src/views/ModelInference.js b/client/src/views/ModelInference.js index 8ea3bf4..bd3384d 100644 --- a/client/src/views/ModelInference.js +++ b/client/src/views/ModelInference.js @@ -11,13 +11,18 @@ import Configurator from "../components/Configurator"; import { applyInputPaths } from "../configSchema"; import RuntimeLogPanel from "../components/RuntimeLogPanel"; import { AppContext } from "../contexts/GlobalContext"; +import { useWorkflow } from "../contexts/WorkflowContext"; function ModelInference({ isInferring, setIsInferring }) { const context = useContext(AppContext); + const workflowContext = useWorkflow(); + const appendWorkflowEvent = workflowContext?.appendEvent; + const workflowId = workflowContext?.workflow?.id; const inference = context.inferenceState; const [inferenceStatus, setInferenceStatus] = useState(""); const [inferenceRuntime, setInferenceRuntime] = useState(null); const pollingIntervalRef = useRef(null); + const terminalLoggedRef = useRef(false); const getPath = (val) => { if (!val) return ""; @@ -81,6 +86,26 @@ function ModelInference({ isInferring, setIsInferring }) { if (!status.isRunning) { setIsInferring(false); + if (!terminalLoggedRef.current && appendWorkflowEvent) { + terminalLoggedRef.current = true; + const succeeded = status.exitCode === 0; + await appendWorkflowEvent({ + actor: "system", + event_type: succeeded + ? "inference.completed" + : "inference.failed", + stage: "inference", + summary: succeeded + ? "Inference completed successfully." + : "Inference finished without a successful exit.", + payload: { + exitCode: status.exitCode, + phase: status.phase, + lastError: status.lastError, + outputPath: getPath(inference.outputPath), + }, + }); + } if (status.exitCode === 0) { setInferenceStatus("Inference completed successfully! ✓"); } else if (status.exitCode !== null && status.exitCode !== undefined) { @@ -96,6 +121,19 @@ function ModelInference({ isInferring, setIsInferring }) { } catch (error) { console.error("Error polling inference status:", error); setIsInferring(false); + if (!terminalLoggedRef.current && appendWorkflowEvent) { + terminalLoggedRef.current = true; + await appendWorkflowEvent({ + actor: "system", + event_type: "inference.failed", + stage: "inference", + summary: "Inference status polling failed.", + payload: { + error: error.message || "unknown error", + outputPath: getPath(inference.outputPath), + }, + }); + } setInferenceStatus( `Inference status polling failed: ${error.message || "unknown error"}`, ); @@ -109,9 +147,10 @@ function ModelInference({ isInferring, setIsInferring }) { pollingIntervalRef.current = null; } }; - }, [isInferring, setIsInferring]); + }, [isInferring, setIsInferring, appendWorkflowEvent, inference.outputPath]); const handleStartButton = async () => { + let checkpointPath = ""; try { const inferenceConfig = context.inferenceConfig; if (!inferenceConfig) { @@ -121,7 +160,7 @@ function ModelInference({ isInferring, setIsInferring }) { return; } - const checkpointPath = getPath(inference.checkpointPath); + checkpointPath = getPath(inference.checkpointPath); if (!checkpointPath) { setInferenceStatus("Error: Please set checkpoint path first."); return; @@ -129,6 +168,7 @@ function ModelInference({ isInferring, setIsInferring }) { setIsInferring(true); setInferenceStatus("Starting inference..."); + terminalLoggedRef.current = false; const preparedInferenceConfig = getPreparedInferenceConfig(inferenceConfig); @@ -137,6 +177,7 @@ function ModelInference({ isInferring, setIsInferring }) { getPath(inference.outputPath), checkpointPath, getConfigOriginPath(), + workflowId, ); console.log(res); await refreshInferenceLogs(); @@ -144,6 +185,20 @@ function ModelInference({ isInferring, setIsInferring }) { } catch (e) { console.log(e); setIsInferring(false); + if (!terminalLoggedRef.current && appendWorkflowEvent) { + terminalLoggedRef.current = true; + await appendWorkflowEvent({ + actor: "system", + event_type: "inference.failed", + stage: "inference", + summary: "Inference failed to start.", + payload: { + error: e.message || "unknown error", + outputPath: getPath(inference.outputPath), + checkpointPath, + }, + }); + } await refreshInferenceLogs(); setInferenceStatus( `Inference error: ${e.message || "Please check console for details."}`, diff --git a/client/src/views/ModelTraining.js b/client/src/views/ModelTraining.js index 312724d..220062d 100644 --- a/client/src/views/ModelTraining.js +++ b/client/src/views/ModelTraining.js @@ -12,9 +12,11 @@ import Configurator from "../components/Configurator"; import { applyInputPaths } from "../configSchema"; import RuntimeLogPanel from "../components/RuntimeLogPanel"; import { AppContext } from "../contexts/GlobalContext"; +import { useWorkflow } from "../contexts/WorkflowContext"; function ModelTraining() { const context = useContext(AppContext); + const workflowContext = useWorkflow(); const training = context.trainingState; const [isTraining, setIsTraining] = useState(false); const [trainingStatus, setTrainingStatus] = useState(""); @@ -161,6 +163,7 @@ function ModelTraining() { getPath(training.logPath) || getPath(training.outputPath), getPath(training.outputPath), getConfigOriginPath(), + workflowContext?.workflow?.id, ); console.log(res); await refreshTrainingLogs(); diff --git a/client/src/views/Views.js b/client/src/views/Views.js index e442adf..c6e5e06 100644 --- a/client/src/views/Views.js +++ b/client/src/views/Views.js @@ -16,6 +16,7 @@ import ModelInference from "./ModelInference"; import Monitoring from "./Monitoring"; import MaskProofreading from "./MaskProofreading"; import Chatbot from "../components/Chatbot"; +import { useWorkflow } from "../contexts/WorkflowContext"; const { Content } = Layout; @@ -39,6 +40,9 @@ const MODULE_ITEMS = [ function Views() { const [current, setCurrent] = useState("files"); const [visitedTabs, setVisitedTabs] = useState(new Set(["files"])); + const workflowContext = useWorkflow(); + const lastClientEffects = workflowContext?.lastClientEffects; + const consumeClientEffects = workflowContext?.consumeClientEffects; const [isChatOpen, setIsChatOpen] = useState(false); const [chatWidth, setChatWidth] = useState(560); const isResizing = useRef(false); @@ -78,6 +82,15 @@ function Views() { }; }, [resize, stopResizing]); + useEffect(() => { + const target = lastClientEffects?.navigate_to; + if (!target) return; + const targetKey = target === "model-training" ? "training" : target; + setCurrent(targetKey); + setVisitedTabs((prev) => new Set(prev).add(targetKey)); + consumeClientEffects?.(); + }, [lastClientEffects, consumeClientEffects]); + const renderTabContent = (key, component) => { if (!visitedTabs.has(key)) return null; return ( diff --git a/client/src/views/Views.test.js b/client/src/views/Views.test.js index 1f08242..979bbde 100644 --- a/client/src/views/Views.test.js +++ b/client/src/views/Views.test.js @@ -3,6 +3,10 @@ import { fireEvent, render, screen } from "@testing-library/react"; import Views from "./Views"; +jest.mock("../contexts/WorkflowContext", () => ({ + useWorkflow: () => null, +})); + jest.mock("antd", () => { const React = require("react"); @@ -41,6 +45,7 @@ jest.mock("@ant-design/icons", () => { DashboardOutlined: Icon, BugOutlined: Icon, MessageOutlined: Icon, + ProjectOutlined: Icon, }; }); @@ -49,7 +54,12 @@ jest.mock("./Visualization", () => () =>
Visualization Content
); jest.mock("./ModelTraining", () => () =>
Training Content
); jest.mock("./ModelInference", () => () =>
Inference Content
); jest.mock("./Monitoring", () => () =>
Monitoring Content
); -jest.mock("./MaskProofreading", () => () =>
Mask Proofreading Content
); +jest.mock("./mask-proofreading/MaskProofreading", () => () => ( +
Mask Proofreading Content
+)); +jest.mock("./project-manager/ProjectManager", () => () => ( +
Project Manager Content
+)); jest.mock("../components/Chatbot", () => () =>
Chatbot
); describe("Views", () => { diff --git a/client/src/views/Visualization.js b/client/src/views/Visualization.js index ab2c58e..4c2de86 100644 --- a/client/src/views/Visualization.js +++ b/client/src/views/Visualization.js @@ -7,11 +7,13 @@ import { } from "@ant-design/icons"; import { getNeuroglancerViewer } from "../api"; import UnifiedFileInput from "../components/UnifiedFileInput"; +import { useWorkflow } from "../contexts/WorkflowContext"; const { Title } = Typography; function Visualization(props) { const { viewers, setViewers } = props; + const workflowContext = useWorkflow(); const [activeKey, setActiveKey] = useState( viewers.length > 0 ? viewers[0].key : null, ); @@ -77,6 +79,7 @@ function Visualization(props) { imagePath, labelPath, scalesArray, + workflowContext?.workflow?.id, ); console.log("Current Viewer at ", res); diff --git a/client/src/views/ehtool/DetectionWorkflow.js b/client/src/views/ehtool/DetectionWorkflow.js index 2207f60..48ee921 100644 --- a/client/src/views/ehtool/DetectionWorkflow.js +++ b/client/src/views/ehtool/DetectionWorkflow.js @@ -1,4 +1,4 @@ -import React, { useState, useEffect, useMemo, useRef } from "react"; +import React, { useContext, useState, useEffect, useMemo, useRef } from "react"; import { Layout, message, @@ -20,6 +20,8 @@ import InstanceNavigator from "./InstanceNavigator"; import ProofreadingEditor from "./ProofreadingEditor"; import SliceScheduler from "./SliceScheduler"; import { apiClient } from "../../api"; +import { AppContext } from "../../contexts/GlobalContext"; +import { useWorkflow } from "../../contexts/WorkflowContext"; const { Sider, Content } = Layout; const { Title, Text } = Typography; @@ -55,7 +57,14 @@ const SCRUB_IDLE_MS = parsePositiveInt( 120, ); -function DetectionWorkflow({ sessionId, setSessionId, refreshTrigger }) { +function DetectionWorkflow({ + sessionId, + setSessionId, + refreshTrigger, + workflowId, +}) { + const appContext = useContext(AppContext); + const workflowContext = useWorkflow(); const [projectName, setProjectName] = useState(""); const [totalLayers, setTotalLayers] = useState(0); const [instances, setInstances] = useState([]); @@ -108,6 +117,7 @@ function DetectionWorkflow({ sessionId, setSessionId, refreshTrigger }) { const [persistence, setPersistence] = useState(null); const [showExportModal, setShowExportModal] = useState(false); const [exportPath, setExportPath] = useState(""); + const [lastExportPath, setLastExportPath] = useState(""); const [exportingMasks, setExportingMasks] = useState(false); const [overwritingSource, setOverwritingSource] = useState(false); const lastPersistenceErrorRef = useRef(null); @@ -448,6 +458,7 @@ function DetectionWorkflow({ sessionId, setSessionId, refreshTrigger }) { dataset_path: datasetPath, mask_path: maskPath || null, project_name: projectName, + workflow_id: workflowId || null, }); setSessionId(response.data.session_id); @@ -1767,6 +1778,7 @@ function DetectionWorkflow({ sessionId, setSessionId, refreshTrigger }) { create_backup: true, }); setShowExportModal(false); + setLastExportPath(response.data.written_path); message.success(`Exported masks to ${response.data.written_path}`); if (response.data.backup_path) { message.info(`Backup created at ${response.data.backup_path}`); @@ -1805,6 +1817,7 @@ function DetectionWorkflow({ sessionId, setSessionId, refreshTrigger }) { if (response.data.backup_path) { message.info(`Backup created at ${response.data.backup_path}`); } + setLastExportPath(response.data.written_path); refreshPersistenceStatus(); } catch (error) { Modal.error({ @@ -1822,6 +1835,44 @@ function DetectionWorkflow({ sessionId, setSessionId, refreshTrigger }) { }); }; + const handleStageForRetraining = async () => { + const correctedMaskPath = lastExportPath || persistence?.last_export_path; + if (!correctedMaskPath) { + message.warning("Export corrected masks before staging retraining."); + return; + } + if (!workflowContext?.workflow?.id) { + message.warning("Workflow state is not available yet."); + return; + } + + try { + await workflowContext.updateWorkflow({ + stage: "retraining_staged", + corrected_mask_path: correctedMaskPath, + }); + await workflowContext.appendEvent({ + actor: "user", + event_type: "retraining.staged", + stage: "retraining_staged", + summary: "Staged corrected masks for retraining.", + payload: { + corrected_mask_path: correctedMaskPath, + ehtool_session_id: sessionId, + source: "proofreading_export", + }, + }); + if (appContext?.trainingState?.setInputLabel) { + appContext.trainingState.setInputLabel(correctedMaskPath); + } + message.success("Corrected masks staged for retraining."); + } catch (error) { + message.error( + getErrorMessage(error, "Failed to stage corrected masks for retraining"), + ); + } + }; + if (!sessionId) { return (
@@ -1834,6 +1885,7 @@ function DetectionWorkflow({ sessionId, setSessionId, refreshTrigger }) { sliderZ ?? viewState.zIndex, axisTotal || totalLayers, ); + const exportedMaskPath = lastExportPath || persistence?.last_export_path || ""; const reviewControls = activeInstance && instanceMode !== "none" ? ( @@ -1929,6 +1981,15 @@ function DetectionWorkflow({ sessionId, setSessionId, refreshTrigger }) { + {exportedMaskPath && ( + + )} Date: Mon, 13 Apr 2026 16:55:35 -0400 Subject: [PATCH 5/6] feat(chat-ui): add timeline filters and proposal cards (cherry picked from commit ecc24938038812241b2dfe4f774252b2a39dd700) --- .../src/__tests__/agentProposalCards.test.js | 78 ++++++++ .../__tests__/workflowTimelineFilters.test.js | 34 ++++ client/src/api.js | 36 ++++ client/src/components/WorkflowTimeline.js | 173 ++++++++++++++---- .../src/components/WorkflowTimeline.test.js | 90 +++++++++ .../src/components/chat/AgentProposalCard.js | 60 ++++++ client/src/contexts/WorkflowContext.js | 44 +++++ client/src/contexts/WorkflowContext.test.js | 23 +++ .../contexts/workflow/proposalCardConfig.js | 92 ++++++++++ .../src/contexts/workflow/timelineFilters.js | 42 +++++ 10 files changed, 635 insertions(+), 37 deletions(-) create mode 100644 client/src/__tests__/agentProposalCards.test.js create mode 100644 client/src/__tests__/workflowTimelineFilters.test.js create mode 100644 client/src/components/chat/AgentProposalCard.js create mode 100644 client/src/contexts/workflow/proposalCardConfig.js create mode 100644 client/src/contexts/workflow/timelineFilters.js diff --git a/client/src/__tests__/agentProposalCards.test.js b/client/src/__tests__/agentProposalCards.test.js new file mode 100644 index 0000000..afd415e --- /dev/null +++ b/client/src/__tests__/agentProposalCards.test.js @@ -0,0 +1,78 @@ +import React from "react"; +import { fireEvent, render, screen } from "@testing-library/react"; + +import AgentProposalCard from "../components/chat/AgentProposalCard"; + +jest.mock("antd", () => ({ + Button: ({ children, ...props }) => ( + + ), + Space: ({ children }) =>
{children}
, + Tag: ({ children }) => {children}, + Typography: { + Text: ({ children }) => {children}, + }, +})); + +describe("AgentProposalCard", () => { + it("renders hotspot proposal fields and approve/reject actions", () => { + const onApprove = jest.fn(); + const onReject = jest.fn(); + const proposal = { + type: "prioritize_failure_hotspots", + rationale: "Focus annotation where the model fails most often.", + target_dataset: "set-a", + hotspots: ["z:11", "z:12"], + priority_metric: "error_rate", + min_failure_rate: 0.2, + }; + + render( + , + ); + + expect(screen.getByText("Prioritize Failure Hotspots")).toBeTruthy(); + expect(screen.getByText("set-a")).toBeTruthy(); + + fireEvent.click(screen.getByRole("button", { name: "Approve" })); + fireEvent.click(screen.getByRole("button", { name: "Reject" })); + + expect(onApprove).toHaveBeenCalledWith(proposal); + expect(onReject).toHaveBeenCalledWith(proposal); + }); + + it("supports correction-impact cards with compact rationale", () => { + const proposal = { + proposal_type: "preview_correction_impact", + rationale: "A".repeat(200), + target_metric: "f1", + expected_delta: "+0.05", + sample_size: 128, + confidence: "high", + }; + + render(); + + expect(screen.getByText("Preview Correction Impact")).toBeTruthy(); + expect(screen.getByText(/A{157}…/)).toBeTruthy(); + expect(screen.getByText("+0.05")).toBeTruthy(); + }); + + it("renders fallback proposal content", () => { + render( + , + ); + + expect(screen.getByText("Agent Proposal")).toBeTruthy(); + expect(screen.getByText("Keep behavior stable")).toBeTruthy(); + expect(screen.getByText("bar")).toBeTruthy(); + }); +}); diff --git a/client/src/__tests__/workflowTimelineFilters.test.js b/client/src/__tests__/workflowTimelineFilters.test.js new file mode 100644 index 0000000..3b60804 --- /dev/null +++ b/client/src/__tests__/workflowTimelineFilters.test.js @@ -0,0 +1,34 @@ +import { + DEFAULT_TIMELINE_FILTERS, + filterTimelineEvents, +} from "../contexts/workflow/timelineFilters"; + +const EVENTS = [ + { id: "1", actor: "user", event_type: "dataset.loaded" }, + { id: "2", actor: "agent", event_type: "agent.proposal_created" }, + { id: "3", actor: "system", event_type: "inference.completed" }, + { id: "4", actor: "agent", event_type: "agent.proposal_approved" }, +]; + +describe("workflow timeline filters", () => { + it("preserves the full timeline by default", () => { + const visible = filterTimelineEvents(EVENTS, DEFAULT_TIMELINE_FILTERS); + expect(visible).toHaveLength(EVENTS.length); + expect(visible).toBe(EVENTS); + }); + + it("filters by actor and event type combinations", () => { + expect(filterTimelineEvents(EVENTS, { actor: "agent", eventType: "" })).toEqual([ + EVENTS[1], + EVENTS[3], + ]); + + expect( + filterTimelineEvents(EVENTS, { actor: "agent", eventType: "approved" }), + ).toEqual([EVENTS[3]]); + + expect( + filterTimelineEvents(EVENTS, { actor: "all", eventType: "proposal" }), + ).toEqual([EVENTS[1], EVENTS[3]]); + }); +}); diff --git a/client/src/api.js b/client/src/api.js index 18df605..c90e416 100644 --- a/client/src/api.js +++ b/client/src/api.js @@ -500,6 +500,42 @@ export async function listWorkflowEvents(workflowId) { } } +export async function getWorkflowHotspots(workflowId) { + try { + const res = await apiClient.get(`/api/workflows/${workflowId}/hotspots`); + return res.data; + } catch (error) { + handleError(error); + } +} + +export async function getWorkflowImpactPreview(workflowId) { + try { + const res = await apiClient.get(`/api/workflows/${workflowId}/impact-preview`); + return res.data; + } catch (error) { + handleError(error); + } +} + +export async function getWorkflowMetrics(workflowId) { + try { + const res = await apiClient.get(`/api/workflows/${workflowId}/metrics`); + return res.data; + } catch (error) { + handleError(error); + } +} + +export async function exportWorkflowBundle(workflowId) { + try { + const res = await apiClient.post(`/api/workflows/${workflowId}/export-bundle`); + return res.data; + } catch (error) { + handleError(error); + } +} + export async function appendWorkflowEvent(workflowId, event) { try { const res = await apiClient.post(`/api/workflows/${workflowId}/events`, event); diff --git a/client/src/components/WorkflowTimeline.js b/client/src/components/WorkflowTimeline.js index b667687..f2569f9 100644 --- a/client/src/components/WorkflowTimeline.js +++ b/client/src/components/WorkflowTimeline.js @@ -1,7 +1,13 @@ -import React from "react"; -import { Button, Empty, List, Space, Tag, Typography } from "antd"; -import { CheckOutlined, CloseOutlined } from "@ant-design/icons"; +import React, { useMemo, useState } from "react"; +import { Button, Empty, Input, List, Select, Space, Tag, Typography } from "antd"; import { useWorkflow } from "../contexts/WorkflowContext"; +import AgentProposalCard from "./chat/AgentProposalCard"; +import { + DEFAULT_TIMELINE_FILTERS, + filterTimelineEvents, + normalizeTimelineFilters, + TIMELINE_ACTOR_OPTIONS, +} from "../contexts/workflow/timelineFilters"; const { Text } = Typography; @@ -14,6 +20,12 @@ const STAGE_LABELS = { evaluation: "Evaluation", }; +const SEVERITY_COLORS = { + low: "default", + medium: "orange", + high: "red", +}; + function formatEventTime(value) { if (!value) return ""; const date = new Date(value); @@ -24,14 +36,22 @@ function formatEventTime(value) { function WorkflowTimeline({ limit = 8 }) { const workflowContext = useWorkflow(); const workflow = workflowContext?.workflow; - const events = workflowContext?.events || []; + const events = workflowContext?.events; + const hotspots = workflowContext?.hotspots || []; + const impactPreview = workflowContext?.impactPreview; + const refreshInsights = workflowContext?.refreshInsights; const approveAgentAction = workflowContext?.approveAgentAction; const rejectAgentAction = workflowContext?.rejectAgentAction; + const [filters, setFilters] = useState(DEFAULT_TIMELINE_FILTERS); - if (!workflowContext) return null; - - const visibleEvents = events.slice(-limit).reverse(); + const reversedEvents = useMemo(() => [...(events || [])].reverse(), [events]); + const visibleEvents = useMemo(() => { + return filterTimelineEvents(reversedEvents, filters).slice(0, limit); + }, [reversedEvents, filters, limit]); const stageLabel = STAGE_LABELS[workflow?.stage] || workflow?.stage || "Loading"; + const topHotspot = hotspots[0] || null; + + if (!workflowContext) return null; return (
- - {workflow?.title || "Segmentation Workflow"} - + + + {workflow?.title || "Segmentation Workflow"} + + + + + + {(topHotspot || impactPreview) && ( +
+ {topHotspot && ( + + + Hotspot + + + {topHotspot.severity} + + {topHotspot.summary} + + )} + {impactPreview && ( +
+ + Impact + + + {` ${impactPreview.summary} (confidence: ${impactPreview.confidence})`} + +
+ )} +
+ )} + + + + setFilters((prev) => + normalizeTimelineFilters({ ...prev, eventType: event.target.value }), + ) + } + /> + {visibleEvents.length === 0 ? ( @@ -85,22 +183,9 @@ function WorkflowTimeline({ limit = 8 }) { actions={ isPendingProposal ? [ - , - , + + Needs review + , ] : [] } @@ -117,17 +202,31 @@ function WorkflowTimeline({ limit = 8 }) { } description={ - - - {event.actor} - - - {event.event_type} - - - {formatEventTime(event.created_at)} - - +
+ + + {event.actor} + + + {event.event_type} + + + {formatEventTime(event.created_at)} + + + {isPendingProposal && ( + approveAgentAction?.(event.id)} + onReject={() => rejectAgentAction?.(event.id)} + /> + )} +
} /> diff --git a/client/src/components/WorkflowTimeline.test.js b/client/src/components/WorkflowTimeline.test.js index f3a40c3..4792681 100644 --- a/client/src/components/WorkflowTimeline.test.js +++ b/client/src/components/WorkflowTimeline.test.js @@ -40,8 +40,22 @@ jest.mock("antd", () => { {children} ), + Input: ({ ...props }) => , Empty, List, + Select: ({ options = [], value, onChange, ...props }) => ( + + ), Space: ({ children }) => {children}, Tag: ({ children }) => {children}, Typography: { @@ -77,6 +91,9 @@ function renderTimeline(overrides = {}) { const value = { workflow, events: [], + hotspots: [], + impactPreview: null, + refreshInsights: jest.fn(), approveAgentAction: jest.fn(), rejectAgentAction: jest.fn(), ...overrides, @@ -135,4 +152,77 @@ describe("WorkflowTimeline", () => { expect(approveAgentAction).toHaveBeenCalledWith(7); expect(rejectAgentAction).toHaveBeenCalledWith(7); }); + + it("renders hotspot + impact summary and refresh action", () => { + const refreshInsights = jest.fn(); + renderTimeline({ + refreshInsights, + hotspots: [ + { + rank: 1, + region_key: "z:12", + score: 9.5, + severity: "high", + summary: "z:12 has repeated failures.", + recommended_action: "Open proofreading.", + evidence: {}, + }, + ], + impactPreview: { + confidence: "high", + summary: "Corrections are ready to stage.", + }, + }); + + expect(screen.getByText("z:12 has repeated failures.")).toBeTruthy(); + expect( + screen.getByText( + /Corrections are ready to stage\. \(confidence: high\)/, + ), + ).toBeTruthy(); + + fireEvent.click(screen.getByText("Refresh Insights")); + expect(refreshInsights).toHaveBeenCalledTimes(1); + }); + + it("filters timeline entries by actor and event type", () => { + renderTimeline({ + events: [ + { + id: 1, + actor: "user", + event_type: "dataset.loaded", + stage: "proofreading", + summary: "Loaded dataset.", + approval_status: "not_required", + created_at: "2026-04-12T12:00:00Z", + }, + { + id: 2, + actor: "agent", + event_type: "agent.proposal_created", + stage: "proofreading", + summary: "Stage corrected masks.", + approval_status: "pending", + payload: { + action: "stage_retraining_from_corrections", + params: { corrected_mask_path: "/tmp/corrected.tif" }, + }, + created_at: "2026-04-12T12:02:00Z", + }, + ], + }); + + fireEvent.change(screen.getByLabelText("Actor filter"), { + target: { value: "agent" }, + }); + expect(screen.queryByText("Loaded dataset.")).toBeNull(); + expect(screen.getAllByText("Stage corrected masks.").length).toBeGreaterThan(0); + + fireEvent.change(screen.getByLabelText("Event type filter"), { + target: { value: "proposal" }, + }); + expect(screen.getAllByText("Stage corrected masks.").length).toBeGreaterThan(0); + expect(screen.getByText("Stage Retraining From Corrections")).toBeTruthy(); + }); }); diff --git a/client/src/components/chat/AgentProposalCard.js b/client/src/components/chat/AgentProposalCard.js new file mode 100644 index 0000000..d6f06ab --- /dev/null +++ b/client/src/components/chat/AgentProposalCard.js @@ -0,0 +1,60 @@ +import React from "react"; +import { Button, Space, Tag, Typography } from "antd"; + +import { getProposalCardContent } from "../../contexts/workflow/proposalCardConfig"; + +const { Text } = Typography; + +function AgentProposalCard({ proposal, onApprove, onReject, disabled = false }) { + const content = getProposalCardContent(proposal || {}); + + return ( +
+ + + + {content.title} + + {content.type} + + {content.rationale} + {content.fields?.length > 0 && ( +
+ {content.fields.map((field) => ( + + + {field.label} + + {field.value} + + ))} +
+ )} + + + + +
+
+ ); +} + +export default AgentProposalCard; diff --git a/client/src/contexts/WorkflowContext.js b/client/src/contexts/WorkflowContext.js index bb267f5..fc9f23d 100644 --- a/client/src/contexts/WorkflowContext.js +++ b/client/src/contexts/WorkflowContext.js @@ -11,6 +11,8 @@ import { approveAgentAction as approveAgentActionApi, createAgentAction, getCurrentWorkflow, + getWorkflowHotspots, + getWorkflowImpactPreview, listWorkflowEvents, queryWorkflowAgent, rejectAgentAction as rejectAgentActionApi, @@ -28,6 +30,8 @@ export function WorkflowProvider({ children }) { const appContext = useContext(AppContext); const [workflow, setWorkflow] = useState(null); const [events, setEvents] = useState([]); + const [hotspots, setHotspots] = useState([]); + const [impactPreview, setImpactPreview] = useState(null); const [loading, setLoading] = useState(true); const [lastClientEffects, setLastClientEffects] = useState(null); @@ -53,10 +57,47 @@ export function WorkflowProvider({ children }) { return nextEvents || []; }, [workflow?.id]); + const refreshInsights = useCallback(async () => { + if (!workflow?.id) { + setHotspots([]); + setImpactPreview(null); + return null; + } + try { + const [hotspotData, impactData] = await Promise.all([ + getWorkflowHotspots(workflow.id), + getWorkflowImpactPreview(workflow.id), + ]); + setHotspots(hotspotData?.hotspots || []); + setImpactPreview(impactData || null); + return { + hotspots: hotspotData?.hotspots || [], + impactPreview: impactData || null, + }; + } catch (_error) { + return null; + } + }, [workflow?.id]); + useEffect(() => { refreshWorkflow(); }, [refreshWorkflow]); + useEffect(() => { + if (!workflow?.id) { + setHotspots([]); + setImpactPreview(null); + return; + } + refreshInsights(); + }, [ + workflow?.id, + workflow?.stage, + workflow?.corrected_mask_path, + events.length, + refreshInsights, + ]); + const updateWorkflow = useCallback( async (patch) => { if (!workflow?.id) return null; @@ -147,10 +188,13 @@ export function WorkflowProvider({ children }) { value={{ workflow, events, + hotspots, + impactPreview, loading, lastClientEffects, refreshWorkflow, refreshEvents, + refreshInsights, updateWorkflow, appendEvent, proposeAgentAction, diff --git a/client/src/contexts/WorkflowContext.test.js b/client/src/contexts/WorkflowContext.test.js index 3bac823..5754983 100644 --- a/client/src/contexts/WorkflowContext.test.js +++ b/client/src/contexts/WorkflowContext.test.js @@ -6,6 +6,8 @@ import { WorkflowProvider, useWorkflow } from "./WorkflowContext"; import { approveAgentAction, getCurrentWorkflow, + getWorkflowHotspots, + getWorkflowImpactPreview, listWorkflowEvents, } from "../api"; @@ -14,6 +16,8 @@ jest.mock("../api", () => ({ appendWorkflowEvent: jest.fn(), createAgentAction: jest.fn(), getCurrentWorkflow: jest.fn(), + getWorkflowHotspots: jest.fn(), + getWorkflowImpactPreview: jest.fn(), listWorkflowEvents: jest.fn(), queryWorkflowAgent: jest.fn(), rejectAgentAction: jest.fn(), @@ -32,6 +36,8 @@ function Probe() {
{workflowContext.workflow?.stage || "loading"}
{workflowContext.events.map((event) => event.event_type).join(",")}
+
{workflowContext.hotspots?.[0]?.summary || "no-hotspot"}
+
{workflowContext.impactPreview?.confidence || "no-impact"}