From 6ad2acaad02347417233fe3d99b90f82b04863fe Mon Sep 17 00:00:00 2001 From: Tim Treis Date: Thu, 21 May 2026 16:42:33 +0200 Subject: [PATCH 1/6] Spec interactive region selection + add pixi interactive env (ipympl v0) Add plans/interactive-selection.md documenting the v0 design for sdata.pl.interactive(...): in-notebook selector widget that draws a region on a spatialdata-plot canvas and persists it back into the SpatialData object as a ShapesModel. Includes resolved Q1-Q4, coordinate- system rules, downsampling strategy, persistence policy, and a 12-task implementation queue. Add a pixi `interactive` dep-group (ipympl, ipywidgets, squidpy) and a new `dev-interactive-py313` environment for prototyping. Register a dedicated `sdata-plot-interactive` kernel-install task to avoid the existing `pixi-dev` kernel name collision. Rewrite the broken [tool.pixi] inline-dotted block to explicit table headers ([tool.pixi.workspace], etc.) so pixi 0.54.2 actually loads the manifest. This commit records the ipympl-based prototype iteration. The notebook prototype (Sandbox.ipynb in lustre, not tracked here) revealed that websocket-streamed PNG frames are too laggy over SSH for full-slide interactive drawing; the next iteration switches to Plotly's client-side draw tools while keeping the same spec and task queue. Co-Authored-By: Claude Opus 4.7 (1M context) --- plans/interactive-selection.md | 186 +++++++++++++++++++++++++++++++++ pyproject.toml | 60 +++++++---- 2 files changed, 226 insertions(+), 20 deletions(-) create mode 100644 plans/interactive-selection.md diff --git a/plans/interactive-selection.md b/plans/interactive-selection.md new file mode 100644 index 00000000..f0e56540 --- /dev/null +++ b/plans/interactive-selection.md @@ -0,0 +1,186 @@ +# Interactive region selection in spatialdata-plot + +Status: spec (v0). Materialized from session handoff on 2026-05-21. + +## Goal + +A minimal, in-notebook (Jupyter / VSCode-Remote-SSH) widget that lets the user +draw a region on a spatialdata-plot canvas and persist it back into the +SpatialData object as a ShapesModel element. Works over an SSH bridge to a +SLURM compute node. No napari, no desktop GUI. + +## Confirmed design decisions + +- Output: persisted ShapesModel written back to the on-disk zarr via + `sdata.write_element`. Survives kernel restarts. +- Selector shapes in v0: rectangle, polygon (click vertices), lasso (freehand). +- Scale handling: auto-downsample on the fly. Pyramid-aware when available; + `dask.coarsen` fallback when not. +- Layers beneath the selector in v0: images only. Selector attaches to the + `Axes` returned by the existing `sdata.pl.render_images().pl.show()` pipeline + — we reuse the existing canvas, no duplicate render path. +- Backend: `%matplotlib widget` (ipympl) + `matplotlib.widgets.{Rectangle, + Polygon,Lasso}Selector`. Pure server-side render, PNG frames over websocket. + No bokeh/datashader. + +## Resolved questions (locked 2026-05-21, task #1) + +- **Q1 — Channel/contrast widgets**: **No live widgets in v0.** `channel=` and + `clims=` remain optional kwargs that forward to `render_images`. No + ipywidgets-driven controls. Widget toolbar deferred to v1. +- **Q2 — Auto-redraw on zoom**: **v1.** v0 renders once at the chosen scale; + `xlim_changed`/`ylim_changed` does not re-pick pyramid level. Static extent + ships sooner. +- **Q3 — Selector kind switching**: **One per call.** `selector=` is fixed at + session construction; no mid-session switching. Switchable kinds deferred to + v1. +- **Q4 — `name=` default**: **Required.** No default; omitting `name=` raises. + Keeps persisted element names intentional and zarr listings legible. + +## Public API sketch + +```python +import spatialdata_plot # registers .pl + +session = sdata.pl.interactive( + element="he_image", + coordinate_system="global", + channel=[0, 1, 2], # optional + clims=(0, 30000), # optional + selector="polygon", # 'rectangle' | 'polygon' | 'lasso' + name="tumor_region", + overwrite=False, + persist=True, + max_render_pixels=2_000_000, +) +session.show() # returns the ipympl Figure +# user draws on canvas, double-click / release to commit +sdata["tumor_region"] # ShapesModel +sub = sdata.query.polygon(sdata, sdata["tumor_region"]) +``` + +## Module layout + +``` +src/spatialdata_plot/pl/interactive/ + __init__.py # exports InteractiveSession + _session.py # InteractiveSession class, public entrypoint + _render.py # thin wrapper around existing render_images + _downsample.py # pyramid-aware scale picker; in-memory coarsen + _selectors.py # RectangleAdapter, PolygonAdapter, LassoAdapter + _commit.py # vertices → CS-correct shapely → ShapesModel + _persist.py # write_element + overwrite/timestamp policy + +tests/test_interactive/ + test_commit.py + test_downsample.py + test_selectors_headless.py +``` + +`sdata.pl.interactive(...)` becomes a method on `PlotAccessor` in +`src/spatialdata_plot/_accessor.py`, returning an `InteractiveSession`. + +## Coordinate-system rules (highest-risk surface) + +1. Session is bound to ONE coordinate system at construction. +2. Render is in that CS; axes coords on the canvas equal coords in the CS + (1:1). +3. On commit, vertices are already in the rendered CS — no transform needed + for the selection itself. +4. The committed ShapesModel is registered with `{cs_name: Identity()}`. +5. Cross-CS selection is the user's job downstream. Not v0. + +Avoids the classic double-applied-transform bug. + +## Downsampling + +`_downsample.pick_scale(image, bbox, max_pixels) -> (level_or_factor, array)` + +- `MultiscaleSpatialImage`: walk scales coarse→fine, pick finest within budget. +- Single-scale: `dask.array.coarsen` with integer factor, warn once. +- Static extent in v0. Auto-redraw on `xlim_changed` is v1. +- Default `max_render_pixels ≈ 2M` (~1500×1500), tuned for ipympl PNG over SSH. + +## Selector adapters + +| kind | matplotlib class | commit trigger | +|-------------|-----------------------|-------------------------------| +| rectangle | `RectangleSelector` | mouse release | +| polygon | `PolygonSelector` | close (double-click / enter) | +| lasso | `LassoSelector` | mouse release | + +Lasso vertices simplified via `shapely.simplify(tolerance=0.5px)` before +persist. + +## Persistence policy + +- `sdata.path` set → `sdata.write_element(name)` on every commit. +- Not zarr-backed → warn once, keep in memory. +- `overwrite=False` default. Collision → rename to `"_"`. +- `session.commits` list tracks names committed this session. + +## Risks (pre-mitigated) + +1. CS mistakes → identity transform + unit tests. +2. Image too large → `max_render_pixels` hard cap with clear error. +3. ipympl flakiness in VSCode → documented fallback to browser-Jupyter via + `ssh -L 8888:localhost:8888 node`. +4. Walltime kill → auto-persist every commit. +5. Lasso 10k vertices → `shapely.simplify`. +6. Concurrent zarr writers → documented, no locking in v0. +7. 3D / z-stacks → refuse with same error as static render (commit 3ebefe1). +8. Auto-zoom redraw not in v0 → static extent ships first. + +## Test strategy + +- Unit: `_commit` (synthetic vertices → ShapesModel correctness). +- Unit: `_downsample` (scale picker correctness on synthetic arrays). +- Headless: `_selectors` via programmatic `_press`/`_onmove`/`_release`. +- NO visual tests in v0. CI does not need a live canvas. +- Manual checklist in PR description for the canvas itself. + +## Dependencies + +`[project.dependencies]`: + +- `ipympl` (NEW) +- `ipywidgets` (NEW or pin existing transitive) +- `shapely` (already transitive via geopandas) +- `geopandas` (already transitive via spatialdata) + +Only `ipympl` is genuinely new. + +## v1 roadmap (after v0 ships) + +1. Auto-downsample on zoom (pyramid-aware redraw on `xlim_changed`). +2. Channel + contrast widget controls in the figure toolbar. +3. Labels overlay (segmentation visible during selection). +4. Multiple selectors per session; switchable kinds. +5. Datashader path for points-heavy elements. + +## Task queue + +1. Resolve spec open questions Q1–Q4 +2. Add ipympl dep + pixi interactive feature +3. Scaffold `pl/interactive` submodule +4. Wire `sdata.pl.interactive` entrypoint +5. Implement `_commit`: vertices → ShapesModel +6. Implement `_persist`: zarr write policy +7. Implement `_downsample`: scale picker + warn +8. Implement `_render`: image render to ax +9. Implement `_selectors`: Rectangle/Polygon/Lasso adapters +10. Wire `InteractiveSession` end-to-end +11. Manual end-to-end test on cluster +12. Document feature in module docstring + README + +## Operating rules + +- Repo CLAUDE.md rules apply: plan-first for multi-file work, no drive-by + refactors, run pixi-defined tasks (lint/format/test) before commits, no + pre-commit / no visual tests locally (CI only). +- Pixi only. No venv/pip. `dev-py313` environment. +- Don't stage with `-A`; stage only what's touched. +- Human drives the actual ipympl canvas; agent cannot see it. Agent can + drive a parallel headless kernel on the same node for non-UI checks. +- If task #1 answers change the spec materially, update this file before + starting #2. diff --git a/pyproject.toml b/pyproject.toml index 7ebf7057..aa557005 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,6 +61,12 @@ doc = [ "sphinxcontrib-katex", "sphinxext-opengraph", ] +interactive = [ + "ipykernel", + "ipympl", + "ipywidgets", + "squidpy", +] [tool.hatch] build.hooks.vcs.version-file = "_version.py" @@ -86,29 +92,43 @@ envs.hatch-test.scripts.cov-report = [ "coverage report", "coverage xml -o cover metadata.allow-direct-references = true version.source = "vcs" -[tool.pixi] -workspace.channels = [ "conda-forge" ] -workspace.platforms = [ "linux-64", "osx-arm64" ] -dependencies.python = ">=3.11" -pypi-dependencies.spatialdata-plot = { path = ".", editable = true } -tasks.format = "ruff format ." -tasks.kernel-install = 'python -m ipykernel install --user --name pixi-dev --display-name "sdata-plot (dev)"' -tasks.lab = "jupyter lab" -tasks.lint = "ruff check ." -tasks.pre-commit-install = "pre-commit install" -tasks.pre-commit-run = "pre-commit run --all-files" -tasks.test = "pytest -v --color=yes --tb=short --durations=10" +[tool.pixi.workspace] +channels = [ "conda-forge" ] +platforms = [ "linux-64", "osx-arm64" ] + +[tool.pixi.dependencies] +python = ">=3.11" + +[tool.pixi.pypi-dependencies] +spatialdata-plot = { path = ".", editable = true } + +[tool.pixi.tasks] +format = "ruff format ." +kernel-install = 'python -m ipykernel install --user --name pixi-dev --display-name "sdata-plot (dev)"' +kernel-install-interactive = 'python -m ipykernel install --user --name sdata-plot-interactive --display-name "sdata-plot (interactive)"' +lab = "jupyter lab" +lint = "ruff check ." +pre-commit-install = "pre-commit install" +pre-commit-run = "pre-commit run --all-files" +test = "pytest -v --color=yes --tb=short --durations=10" + # for gh-actions -feature.py311.dependencies.python = "3.11.*" -feature.py313.dependencies.python = "3.13.*" +[tool.pixi.feature.py311.dependencies] +python = "3.11.*" + +[tool.pixi.feature.py313.dependencies] +python = "3.13.*" + +[tool.pixi.environments] # 3.13 lane -environments.default = { features = [ "py313" ], solve-group = "py313" } +default = { features = [ "py313" ], solve-group = "py313" } # 3.11 lane (for gh-actions) -environments.dev-py311 = { features = [ "dev", "test", "py311" ], solve-group = "py311" } -environments.dev-py313 = { features = [ "dev", "test", "py313" ], solve-group = "py313" } -environments.docs-py311 = { features = [ "doc", "py311" ], solve-group = "py311" } -environments.docs-py313 = { features = [ "doc", "py313" ], solve-group = "py313" } -environments.test-py313 = { features = [ "test", "py313" ], solve-group = "py313" } +dev-py311 = { features = [ "dev", "test", "py311" ], solve-group = "py311" } +dev-py313 = { features = [ "dev", "test", "py313" ], solve-group = "py313" } +dev-interactive-py313 = { features = [ "dev", "test", "interactive", "py313" ], solve-group = "py313" } +docs-py311 = { features = [ "doc", "py311" ], solve-group = "py311" } +docs-py313 = { features = [ "doc", "py313" ], solve-group = "py313" } +test-py313 = { features = [ "test", "py313" ], solve-group = "py313" } [tool.ruff] line-length = 120 From a002ffbbd48a13f32f318db4ae8f3d0fca726b51 Mon Sep 17 00:00:00 2001 From: Tim Treis Date: Thu, 21 May 2026 19:17:59 +0200 Subject: [PATCH 2/6] Add anywidget + plotly to interactive feature for client-side drawing MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add anywidget and plotly>=5.20,<6 to the pixi interactive dep-group so the prototype notebook can render a custom HTML5/SVG drawing widget. anywidget is the canonical path for traitlet-based widget sync in VSCode-Remote; plotly is pinned to 5.x because its 6.0 anywidget-backed FigureWidget does not relay client-side relayout events back to Python (so layout.shapes never syncs there). The Sandbox.ipynb prototype itself lives outside this repo (/home/.../lustre/projects/spatialdata-plot/), but its current state implements a working anywidget-based draw canvas: pure client-side SVG drawing (rectangle drag, polygon click-then-Close-polygon, lasso freehand drag), shapes pushed back via the `shapes` traitlet, pixel→CS coordinate mapping that respects matplotlib's origin='upper' image axis, multi-shape commit per Save, and an explicit "Write last to disk" button for persistence. Sandbox.anywidget-v0.ipynb is preserved alongside as a reference snapshot before optimization. Co-Authored-By: Claude Opus 4.7 (1M context) --- pyproject.toml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index aa557005..62f2d326 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,9 +62,13 @@ doc = [ "sphinxext-opengraph", ] interactive = [ + "anywidget", "ipykernel", "ipympl", "ipywidgets", + # pinned to 5.x: plotly 6's anywidget-backed FigureWidget doesn't relay + # client-side draw events back to Python, so layout.shapes never syncs. + "plotly>=5.20,<6", "squidpy", ] From 2e218c477cd50aa7b1652f64a4881568bb7d353c Mon Sep 17 00:00:00 2001 From: Tim Treis Date: Thu, 21 May 2026 21:28:11 +0200 Subject: [PATCH 3/6] =?UTF-8?q?Add=20sdata.pl.annotate()=20=E2=80=94=20int?= =?UTF-8?q?eractive=20region=20selection=20via=20anywidget?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Productionises the Sandbox.ipynb prototype as a user-facing method on PlotAccessor. Public surface is a single function: sdata.pl.annotate(coordinate_system, element, *, persist=True) -> None Both args are required positional. The function validates that the image element is registered in the given coordinate system, renders it to a PNG, constructs an internal _InteractiveSession with anywidget-driven drawing tools (rectangle / polygon / lasso), and displays the widget. Drawn shapes are written into sdata.shapes[name] on click of the Save button; the optional "Write to disk" button persists via sdata.write_element. Module layout (src/spatialdata_plot/pl/interactive/): - _canvas.py DrawCanvas anywidget class - static/draw_canvas.js ESM module read from disk by anywidget (HMR-friendly) - _render.py render_to_png: sdata.pl → PNG + ax extent - _commit.py pixel-coord shape → CS-coord shapely Polygon → ShapesModel - _persist.py commit_to_memory + persist_to_disk (collision policy) - _session.py _InteractiveSession orchestrating the widget The new optional extra `interactive` (anywidget, ipykernel, ipywidgets) gates this feature behind a clear ImportError when missing: pip install 'spatialdata-plot[interactive]' The prototype iteration explored ipympl (rejected: PNG-over-websocket latency unusable over SSH) and plotly's FigureWidget (rejected: client- side relayout events don't sync back to Python in VSCode-Remote, plus plotly 6's anywidget-backed FigureWidget broke the comm path entirely). The custom anywidget approach was the only architecture that worked reliably over SSH while staying responsive. Drawing UX: - Tools: rect (drag), polygon (click + snap-close), lasso (drag freehand) - Wheel zoom, shift-drag pan, alt-click shape to delete - Ctrl+Z undo, R/P/L tool shortcuts, F fit view, Enter close polygon - Multi-shape bundling: each Save commits all canvas shapes as one ShapesModel with multiple rows under a single name Tests cover the unit surface (pixel→CS conversion, ShapesModel transform registration, render-to-PNG correctness, commit/persist policy, widget smoke). Spec at plans/interactive-selection.md updated to document the architectural pivot from the original ipympl approach. Co-Authored-By: Claude Opus 4.7 (1M context) --- plans/interactive-selection.md | 152 ++++--- pyproject.toml | 21 +- src/spatialdata_plot/pl/basic.py | 68 ++++ .../pl/interactive/__init__.py | 10 + .../pl/interactive/_canvas.py | 49 +++ .../pl/interactive/_commit.py | 66 +++ .../pl/interactive/_persist.py | 38 ++ .../pl/interactive/_render.py | 49 +++ .../pl/interactive/_session.py | 275 +++++++++++++ .../pl/interactive/static/draw_canvas.js | 380 ++++++++++++++++++ tests/test_interactive/__init__.py | 0 tests/test_interactive/test_canvas.py | 44 ++ tests/test_interactive/test_commit.py | 78 ++++ tests/test_interactive/test_persist.py | 57 +++ tests/test_interactive/test_render.py | 39 ++ 15 files changed, 1267 insertions(+), 59 deletions(-) create mode 100644 src/spatialdata_plot/pl/interactive/__init__.py create mode 100644 src/spatialdata_plot/pl/interactive/_canvas.py create mode 100644 src/spatialdata_plot/pl/interactive/_commit.py create mode 100644 src/spatialdata_plot/pl/interactive/_persist.py create mode 100644 src/spatialdata_plot/pl/interactive/_render.py create mode 100644 src/spatialdata_plot/pl/interactive/_session.py create mode 100644 src/spatialdata_plot/pl/interactive/static/draw_canvas.js create mode 100644 tests/test_interactive/__init__.py create mode 100644 tests/test_interactive/test_canvas.py create mode 100644 tests/test_interactive/test_commit.py create mode 100644 tests/test_interactive/test_persist.py create mode 100644 tests/test_interactive/test_render.py diff --git a/plans/interactive-selection.md b/plans/interactive-selection.md index f0e56540..f2f9c150 100644 --- a/plans/interactive-selection.md +++ b/plans/interactive-selection.md @@ -16,13 +16,29 @@ SLURM compute node. No napari, no desktop GUI. - Selector shapes in v0: rectangle, polygon (click vertices), lasso (freehand). - Scale handling: auto-downsample on the fly. Pyramid-aware when available; `dask.coarsen` fallback when not. -- Layers beneath the selector in v0: images only. Selector attaches to the - `Axes` returned by the existing `sdata.pl.render_images().pl.show()` pipeline - — we reuse the existing canvas, no duplicate render path. -- Backend: `%matplotlib widget` (ipympl) + `matplotlib.widgets.{Rectangle, - Polygon,Lasso}Selector`. Pure server-side render, PNG frames over websocket. +- Layers in v0: images only. The image is rendered once via the existing + `sdata.pl.render_images().pl.show()` pipeline into a matplotlib figure, + exported to PNG, and laid under a client-side drawing canvas. +- Backend: **custom anywidget** with HTML5/SVG drawing tools (rectangle, + polygon, freehand-lasso). All drawing happens in the browser; shape + geometry is reported back to Python via traitlet sync. Image is sent + once as a base64 data URL; mouse moves never round-trip the kernel. No bokeh/datashader. +### Why anywidget, not ipympl or plotly + +The original spec called for `%matplotlib widget` (ipympl). The prototype +revealed two showstoppers over SSH: +1. **ipympl streams PNG frames per mouse-move** over websocket — every drag + incurs SSH round-trip latency, making freehand drawing unusable. +2. **plotly's `FigureWidget`** has broken two-way shape sync in + VSCode-Remote-SSH (regardless of plotly 5 vs 6 — different bugs each). + +A small (~250-line) anywidget with traitlet-synced shape geometry was the +only architecture that worked reliably in VSCode-Remote and produced +responsive drawing. The image render still uses sdata-plot's matplotlib +pipeline; we just don't drive interaction through it. + ## Resolved questions (locked 2026-05-21, task #1) - **Q1 — Channel/contrast widgets**: **No live widgets in v0.** `channel=` and @@ -43,42 +59,59 @@ SLURM compute node. No napari, no desktop GUI. import spatialdata_plot # registers .pl session = sdata.pl.interactive( - element="he_image", - coordinate_system="global", - channel=[0, 1, 2], # optional - clims=(0, 30000), # optional - selector="polygon", # 'rectangle' | 'polygon' | 'lasso' - name="tumor_region", - overwrite=False, - persist=True, - max_render_pixels=2_000_000, + coordinate_system=None, # optional pre-selection; None = let user pick in UI + element=None, # optional pre-selection; None = let user pick in UI + persist=True, # show "Write to disk" button (False = memory only) ) -session.show() # returns the ipympl Figure -# user draws on canvas, double-click / release to commit -sdata["tumor_region"] # ShapesModel +session.show() # renders the ipywidgets controls + draw canvas + +# User picks CS + image, clicks Render, draws shapes, names + Saves each set. +# Each Save adds an entry to sdata.shapes (memory). Write to disk persists +# the most recent commit via sdata.write_element. + +sdata["tumor_region"] # ShapesModel sub = sdata.query.polygon(sdata, sdata["tumor_region"]) ``` +Removed kwargs vs original spec: +- `selector=` — UI has a tool toggle (rect/polygon/lasso); no need to bind one + selector at construction (Q3 resolution). +- `name=` — typed in the UI before each Save (Q4 resolution). +- `channel=`, `clims=` — deferred to v1 (Q1 resolution). +- `max_render_pixels=` — render is fixed at `figsize=(7,7), dpi=120` ≈ 840×840 + PNG; pyramid-aware downsampling deferred to v1. +- `overwrite=` — collision handling is automatic: same name → append UTC + timestamp. + ## Module layout ``` src/spatialdata_plot/pl/interactive/ - __init__.py # exports InteractiveSession - _session.py # InteractiveSession class, public entrypoint - _render.py # thin wrapper around existing render_images - _downsample.py # pyramid-aware scale picker; in-memory coarsen - _selectors.py # RectangleAdapter, PolygonAdapter, LassoAdapter - _commit.py # vertices → CS-correct shapely → ShapesModel - _persist.py # write_element + overwrite/timestamp policy + __init__.py # exports interactive, InteractiveSession, DrawCanvas + _session.py # InteractiveSession class — ipywidgets controls + _canvas.py # DrawCanvas anywidget + traitlets + _render.py # render_to_png helper (sdata.pl → PNG + extent) + _commit.py # pixel-shape → CS-correct shapely Polygon → ShapesModel + _persist.py # write_element + collision/timestamp policy + static/ + draw_canvas.js # the ESM module; _esm = Path(...) reads at import tests/test_interactive/ - test_commit.py - test_downsample.py - test_selectors_headless.py + test_commit.py # pixel→CS conversion + ShapesModel correctness + test_render.py # render_to_png returns valid PNG + extent + test_persist.py # collision/timestamp policy + test_canvas.py # smoke: instantiate widget, check traitlet defaults ``` -`sdata.pl.interactive(...)` becomes a method on `PlotAccessor` in -`src/spatialdata_plot/_accessor.py`, returning an `InteractiveSession`. +`sdata.pl.interactive(...)` is a method on `PlotAccessor` in +`src/spatialdata_plot/_accessor.py`. It constructs an `InteractiveSession` +and returns it; `session.show()` displays the controls + draw canvas. + +Dropped from the original spec: +- `_downsample.py` — pyramid-aware downsampling deferred to v1; v0 renders + at a fixed dpi (`figsize=(7,7), dpi=120`). +- `_selectors.py` — matplotlib selectors are replaced by the anywidget; the + three drawing tools (rect/polygon/lasso) live in `static/draw_canvas.js`. ## Coordinate-system rules (highest-risk surface) @@ -92,25 +125,31 @@ tests/test_interactive/ Avoids the classic double-applied-transform bug. -## Downsampling +## Rendering + +`_render.render_to_png(sdata, element, coordinate_system) -> (png_bytes, image_w, image_h, xlim, ylim)` -`_downsample.pick_scale(image, bbox, max_pixels) -> (level_or_factor, array)` +- Uses `sdata.pl.render_images(element=...).pl.show(coordinate_systems=..., ax=...)`. +- Axes fills the figure (`ax.add_axes([0,0,1,1])`, `set_axis_off()`) so PNG pixel + coordinates map exactly to data coordinates via `xlim`/`ylim`. +- Fixed at `figsize=(7,7)` × `dpi=120` ≈ 840×840 PNG for v0. Pyramid-aware + downsampling deferred to v1. +- 3D / z-stacks: refused by `render_images` itself (commit 3ebefe1) — we + propagate that error. -- `MultiscaleSpatialImage`: walk scales coarse→fine, pick finest within budget. -- Single-scale: `dask.array.coarsen` with integer factor, warn once. -- Static extent in v0. Auto-redraw on `xlim_changed` is v1. -- Default `max_render_pixels ≈ 2M` (~1500×1500), tuned for ipympl PNG over SSH. +## Drawing tools (in `static/draw_canvas.js`) -## Selector adapters +| kind | gesture | commit trigger | +|-------------|------------------------------------------------|-----------------------------------------------| +| rectangle | left-drag corner → corner | mouse release | +| polygon | click each vertex | snap-to-first-vertex (within 10 px) or Enter | +| lasso | left-drag freehand | mouse release | -| kind | matplotlib class | commit trigger | -|-------------|-----------------------|-------------------------------| -| rectangle | `RectangleSelector` | mouse release | -| polygon | `PolygonSelector` | close (double-click / enter) | -| lasso | `LassoSelector` | mouse release | +Plus client-side: wheel-zoom, shift-drag-pan, alt-click-shape-to-delete, +hover-highlight, Ctrl+Z undo, Delete clear, R/P/L tool shortcuts, F fit. -Lasso vertices simplified via `shapely.simplify(tolerance=0.5px)` before -persist. +Lasso vertices are simplified server-side via `shapely.simplify(tolerance=0.5)` +in `_commit` before persisting. ## Persistence policy @@ -133,22 +172,27 @@ persist. ## Test strategy -- Unit: `_commit` (synthetic vertices → ShapesModel correctness). -- Unit: `_downsample` (scale picker correctness on synthetic arrays). -- Headless: `_selectors` via programmatic `_press`/`_onmove`/`_release`. -- NO visual tests in v0. CI does not need a live canvas. -- Manual checklist in PR description for the canvas itself. +- Unit: `_commit` (synthetic pixel-coord shapes → CS-coord ShapesModel correctness). +- Unit: `_render` (returns valid PNG bytes + extent matching the axis limits). +- Unit: `_persist` (collision-rename + timestamp policy). +- Smoke: `_canvas` (instantiate `DrawCanvas`, check traitlet defaults). +- NO visual / live-canvas tests in v0 — the JS widget can't be driven from Python. + Manual checklist in PR description covers the canvas behaviour. ## Dependencies -`[project.dependencies]`: +Exposed as `[project.optional-dependencies].interactive` so the feature is +opt-in (`pip install spatialdata-plot[interactive]`). Mirrors the pixi +`interactive` dep-group. -- `ipympl` (NEW) -- `ipywidgets` (NEW or pin existing transitive) -- `shapely` (already transitive via geopandas) -- `geopandas` (already transitive via spatialdata) +- `anywidget` (NEW) — the widget framework. +- `ipywidgets` (NEW or pin existing transitive) — for the controls VBox. +- `ipykernel` — needed by anywidget for comm channel. +- `shapely`, `geopandas` — already transitive via spatialdata. -Only `ipympl` is genuinely new. +`ipympl` and `plotly` are NOT runtime deps of the new architecture (we tried +both and rejected them). They remain in the prototype/pixi feature only for +historical comparison and may be dropped from the interactive feature later. ## v1 roadmap (after v0 ships) diff --git a/pyproject.toml b/pyproject.toml index 62f2d326..a4444830 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,11 @@ dependencies = [ "scikit-learn", "spatialdata>=0.3", ] +optional-dependencies.interactive = [ + "anywidget", + "ipykernel", + "ipywidgets", +] urls.Documentation = "https://spatialdata.scverse.org/projects/plot/en/latest/index.html" urls.Home-page = "https://github.com/scverse/spatialdata-plot.git" urls.Source = "https://github.com/scverse/spatialdata-plot.git" @@ -61,11 +66,11 @@ doc = [ "sphinxcontrib-katex", "sphinxext-opengraph", ] -interactive = [ - "anywidget", - "ipykernel", +interactive-extras = [ + # Prototype-only helpers used by Sandbox.ipynb. The published runtime extra + # is [project.optional-dependencies].interactive above (anywidget/ipykernel/ + # ipywidgets only) — these are kept here for the dev-interactive-py313 env. "ipympl", - "ipywidgets", # pinned to 5.x: plotly 6's anywidget-backed FigureWidget doesn't relay # client-side draw events back to Python, so layout.shapes never syncs. "plotly>=5.20,<6", @@ -106,6 +111,12 @@ python = ">=3.11" [tool.pixi.pypi-dependencies] spatialdata-plot = { path = ".", editable = true } +# When the `interactive` feature is active, install the package with the +# `interactive` PyPI extra (anywidget, ipykernel, ipywidgets) so the pixi +# env mirrors what `pip install spatialdata-plot[interactive]` would give. +[tool.pixi.feature.interactive.pypi-dependencies] +spatialdata-plot = { path = ".", editable = true, extras = [ "interactive" ] } + [tool.pixi.tasks] format = "ruff format ." kernel-install = 'python -m ipykernel install --user --name pixi-dev --display-name "sdata-plot (dev)"' @@ -129,7 +140,7 @@ default = { features = [ "py313" ], solve-group = "py313" } # 3.11 lane (for gh-actions) dev-py311 = { features = [ "dev", "test", "py311" ], solve-group = "py311" } dev-py313 = { features = [ "dev", "test", "py313" ], solve-group = "py313" } -dev-interactive-py313 = { features = [ "dev", "test", "interactive", "py313" ], solve-group = "py313" } +dev-interactive-py313 = { features = [ "dev", "test", "interactive", "interactive-extras", "py313" ], solve-group = "py313" } docs-py311 = { features = [ "doc", "py311" ], solve-group = "py311" } docs-py313 = { features = [ "doc", "py313" ], solve-group = "py313" } test-py313 = { features = [ "test", "py313" ], solve-group = "py313" } diff --git a/src/spatialdata_plot/pl/basic.py b/src/spatialdata_plot/pl/basic.py index 37a80593..2f9c2d9e 100644 --- a/src/spatialdata_plot/pl/basic.py +++ b/src/spatialdata_plot/pl/basic.py @@ -171,6 +171,74 @@ def _copy( return sdata + def annotate( + self, + coordinate_system: str, + element: str, + *, + persist: bool = True, + ) -> None: + """Draw and save regions interactively on an image element. + + Renders the image element in the given coordinate system as a + client-side drawing canvas (rectangle / polygon / lasso tools). + Drawn shapes are saved into ``sdata.shapes`` under a user-typed name + on click of the *Save* button — each save creates one ShapesModel + with one row per drawn shape, registered with an ``Identity`` + transformation in the chosen coordinate system. + + Requires the ``interactive`` extra: ``pip install 'spatialdata-plot[interactive]'``. + + Parameters + ---------- + coordinate_system : + Coordinate system to render and resolve drawn shapes against. + Drawn polygons are stored with an ``Identity`` transformation + in this CS. + element : + Name of the image element to render. + persist : + If ``True`` (default), show a *Write to disk* button that calls + :meth:`SpatialData.write_element` for the most recent save. + Set to ``False`` to limit the session to in-memory commits. + + Returns + ------- + None + Displays the widget in the current notebook cell. Drawn and + saved shapes appear in ``sdata.shapes``; inspect them there. + + Raises + ------ + ValueError + If ``coordinate_system`` is unknown, ``element`` is unknown, + or ``element`` is not registered in ``coordinate_system``. + ImportError + If the ``interactive`` extra is not installed. + + Examples + -------- + >>> import spatialdata_plot # noqa: F401 registers .pl + >>> sdata.pl.annotate("global", "he_image") + >>> # ... user draws and clicks Save with name "tumor" ... + >>> sdata.shapes["tumor"] + """ + try: + from spatialdata_plot.pl.interactive._session import _InteractiveSession + except ImportError as exc: + raise ImportError( + "sdata.pl.annotate() requires the `interactive` extra. " + "Install with: pip install 'spatialdata-plot[interactive]'" + ) from exc + + session = _InteractiveSession( + self._sdata, + coordinate_system=coordinate_system, + element=element, + persist=persist, + ) + session.show() + @_deprecation_alias(elements="element", version="0.3.0") def render_shapes( self, diff --git a/src/spatialdata_plot/pl/interactive/__init__.py b/src/spatialdata_plot/pl/interactive/__init__.py new file mode 100644 index 00000000..9668d0bc --- /dev/null +++ b/src/spatialdata_plot/pl/interactive/__init__.py @@ -0,0 +1,10 @@ +"""Interactive region selection on a SpatialData image. + +Use via :meth:`spatialdata_plot.pl.basic.PlotAccessor.annotate`: + +>>> import spatialdata_plot # noqa: F401 registers .pl +>>> sdata.pl.annotate("global", "he_image") +""" +from __future__ import annotations + +__all__: list[str] = [] diff --git a/src/spatialdata_plot/pl/interactive/_canvas.py b/src/spatialdata_plot/pl/interactive/_canvas.py new file mode 100644 index 00000000..d78b878d --- /dev/null +++ b/src/spatialdata_plot/pl/interactive/_canvas.py @@ -0,0 +1,49 @@ +"""anywidget wrapping a client-side SVG drawing canvas.""" +from __future__ import annotations + +from pathlib import Path + +import anywidget +import traitlets + +_ESM_PATH = Path(__file__).parent / "static" / "draw_canvas.js" + + +class DrawCanvas(anywidget.AnyWidget): + """Client-side SVG drawing surface for interactive region selection. + + The image (PNG data URL) is shown as a CSS-transformed background; an + overlay SVG catches mouse events and emits committed shapes in image- + pixel coordinates via the ``shapes`` traitlet. + + Convert the pixel-coord shapes to data/CS coordinates with + :func:`spatialdata_plot.pl.interactive._commit.pixel_shape_to_polygon`. + + Traitlets + --------- + image_url + ``data:image/png;base64,...`` for the rendered image. + image_width, image_height + Pixel dimensions of the PNG (used to set the SVG ``viewBox``). + tool + ``"rectangle"``, ``"polygon"``, or ``"lasso"``. + shapes + List of ``{"type": "rect"|"polygon", "verts": [[x, y], ...]}`` in + image-pixel coordinates. JS pushes to this on commit. + clear_trigger, close_poly_trigger, undo_trigger, fit_trigger + Integer counters. Increment from Python to invoke the corresponding + JS action; JS observers are stateless w.r.t. the value, only the + change event matters. + """ + + _esm = _ESM_PATH + + image_url = traitlets.Unicode("").tag(sync=True) + image_width = traitlets.Int(720).tag(sync=True) + image_height = traitlets.Int(720).tag(sync=True) + tool = traitlets.Unicode("rectangle").tag(sync=True) + shapes = traitlets.List([]).tag(sync=True) + clear_trigger = traitlets.Int(0).tag(sync=True) + close_poly_trigger = traitlets.Int(0).tag(sync=True) + undo_trigger = traitlets.Int(0).tag(sync=True) + fit_trigger = traitlets.Int(0).tag(sync=True) diff --git a/src/spatialdata_plot/pl/interactive/_commit.py b/src/spatialdata_plot/pl/interactive/_commit.py new file mode 100644 index 00000000..8a0ae949 --- /dev/null +++ b/src/spatialdata_plot/pl/interactive/_commit.py @@ -0,0 +1,66 @@ +"""Convert canvas pixel-coord shapes into a CS-coord ShapesModel.""" +from __future__ import annotations + +from typing import Any + +import geopandas as gpd +import spatialdata as sd +from shapely.geometry import Polygon +from spatialdata.models import ShapesModel + +# Tolerance for lasso simplification, in CS units. The lasso path is sampled +# at every mouse-move so a freehand loop easily exceeds 1000 vertices; +# 0.5 keeps shape fidelity while collapsing co-linear points. +_LASSO_SIMPLIFY_TOL = 0.5 + + +def pixel_shape_to_polygon( + shape: dict[str, Any], + image_w: int, + image_h: int, + xlim: tuple[float, float], + ylim: tuple[float, float], +) -> Polygon | None: + """Convert a single ``DrawCanvas`` shape entry to a CS-coord shapely Polygon. + + Returns ``None`` if the shape is invalid (no verts, <3 verts after + construction, or empty). + + The matplotlib image axes use ``origin='upper'`` — the *smaller* y-value + of ``ylim`` corresponds to the top of the rendered image (PNG row 0). + """ + verts = shape.get("verts") if isinstance(shape, dict) else None + if not verts: + return None + + xmin, xmax = float(xlim[0]), float(xlim[1]) + y_lo, y_hi = sorted((float(ylim[0]), float(ylim[1]))) + + def px_to_cs(px: float, py: float) -> tuple[float, float]: + return ( + xmin + (px / image_w) * (xmax - xmin), + y_lo + (py / image_h) * (y_hi - y_lo), + ) + + cs_verts = [px_to_cs(v[0], v[1]) for v in verts] + if len(cs_verts) < 3: + return None + poly = Polygon(cs_verts) + if poly.is_empty: + return None + if shape.get("type") == "polygon" and len(cs_verts) > 50: + # Lasso-like (high vertex count) polygons benefit from simplification. + poly = poly.simplify(_LASSO_SIMPLIFY_TOL, preserve_topology=True) + return poly + + +def build_shapes_model( + polygons: list[Polygon], + coordinate_system: str, +) -> Any: + """Wrap shapely polygons in a ShapesModel registered with Identity in ``coordinate_system``.""" + gdf = gpd.GeoDataFrame({"geometry": polygons}) + return ShapesModel.parse( + gdf, + transformations={coordinate_system: sd.transformations.Identity()}, + ) diff --git a/src/spatialdata_plot/pl/interactive/_persist.py b/src/spatialdata_plot/pl/interactive/_persist.py new file mode 100644 index 00000000..5d496187 --- /dev/null +++ b/src/spatialdata_plot/pl/interactive/_persist.py @@ -0,0 +1,38 @@ +"""Commit a ShapesModel into sdata.shapes (memory) and optionally to zarr.""" +from __future__ import annotations + +import datetime as _dt +from typing import Any + +import spatialdata as sd + + +def commit_to_memory( + sdata: sd.SpatialData, + shapes_model: Any, + name: str, +) -> str: + """Add ``shapes_model`` to ``sdata.shapes`` under ``name``. + + On collision, the existing element is preserved and the new one is + renamed to ``{name}_{utc-iso}``. Returns the final committed name. + """ + target = name + if target in sdata.shapes: + ts = _dt.datetime.utcnow().strftime("%Y%m%dT%H%M%SZ") + target = f"{name}_{ts}" + sdata.shapes[target] = shapes_model + return target + + +def persist_to_disk(sdata: sd.SpatialData, name: str) -> None: + """Persist ``sdata.shapes[name]`` to the backing zarr store. + + Raises ValueError if ``sdata`` is not zarr-backed. + """ + if sdata.path is None: + raise ValueError( + "SpatialData is not zarr-backed (sdata.path is None); cannot persist. " + "Write the SpatialData object to a zarr store first, then re-open it." + ) + sdata.write_element(name) diff --git a/src/spatialdata_plot/pl/interactive/_render.py b/src/spatialdata_plot/pl/interactive/_render.py new file mode 100644 index 00000000..8c4ebaee --- /dev/null +++ b/src/spatialdata_plot/pl/interactive/_render.py @@ -0,0 +1,49 @@ +"""Render an image element to a PNG suitable for the DrawCanvas background.""" +from __future__ import annotations + +from io import BytesIO + +import matplotlib.pyplot as plt +import spatialdata as sd + +# Render size: 7 in × 120 dpi = 840 px square. Fixed for v0; pyramid-aware +# downsampling is a v1 feature. +_FIGSIZE = (7, 7) +_DPI = 120 +_IMAGE_W = _FIGSIZE[0] * _DPI # 840 +_IMAGE_H = _FIGSIZE[1] * _DPI # 840 + + +def render_to_png( + sdata: sd.SpatialData, + element: str, + coordinate_system: str, +) -> tuple[bytes, int, int, tuple[float, float], tuple[float, float]]: + """Render ``element`` in ``coordinate_system`` to PNG bytes. + + The matplotlib axes fills the figure (``[0, 0, 1, 1]`` with axis off) so + the PNG-pixel ↔ data-coord mapping is exactly ``xlim`` × ``ylim``. + + Returns + ------- + png_bytes + PNG-encoded image. + image_w, image_h + Pixel dimensions of the PNG. + xlim, ylim + ``ax.get_xlim()`` / ``ax.get_ylim()`` at render time. For image axes + (``origin='upper'``) ``ylim`` is reversed — see + :func:`._commit.pixel_shape_to_polygon` for the conversion. + """ + fig = plt.figure(figsize=_FIGSIZE, dpi=_DPI) + ax = fig.add_axes([0, 0, 1, 1]) + sdata.pl.render_images(element=element).pl.show(coordinate_systems=coordinate_system, ax=ax) + xlim = ax.get_xlim() + ylim = ax.get_ylim() + ax.set_axis_off() + for spine in ax.spines.values(): + spine.set_visible(False) + buf = BytesIO() + fig.savefig(buf, format="png", dpi=_DPI, pad_inches=0) + plt.close(fig) + return buf.getvalue(), _IMAGE_W, _IMAGE_H, tuple(xlim), tuple(ylim) diff --git a/src/spatialdata_plot/pl/interactive/_session.py b/src/spatialdata_plot/pl/interactive/_session.py new file mode 100644 index 00000000..15363475 --- /dev/null +++ b/src/spatialdata_plot/pl/interactive/_session.py @@ -0,0 +1,275 @@ +"""ipywidgets-based session orchestrating the DrawCanvas. + +Internal class — users invoke :meth:`PlotAccessor.annotate`, which constructs +a session and displays it. +""" +from __future__ import annotations + +import base64 +from typing import Any + +import ipywidgets as W +import spatialdata as sd +from IPython.display import display + +from ._canvas import DrawCanvas +from ._commit import build_shapes_model, pixel_shape_to_polygon +from ._persist import commit_to_memory, persist_to_disk +from ._render import render_to_png + +_CSS = """ + +""" + + +def _fmt_banner(msg: str, kind: str = "info") -> str: + cls = { + "info": "sdp-banner sdp-banner-info", + "success": "sdp-banner sdp-banner-success", + "error": "sdp-banner sdp-banner-error", + "hint": "sdp-banner sdp-banner-hint", + }.get(kind, "sdp-banner sdp-banner-info") + return f"
{msg}
" + + +def _validate(sdata: sd.SpatialData, coordinate_system: str, element: str) -> None: + if coordinate_system not in sdata.coordinate_systems: + raise ValueError( + f"Unknown coordinate system {coordinate_system!r}. " + f"Available: {list(sdata.coordinate_systems)}" + ) + if element not in sdata.images: + raise ValueError( + f"Unknown image element {element!r}. Available: {list(sdata.images)}" + ) + transforms = sd.transformations.get_transformation(sdata.images[element], get_all=True) + if coordinate_system not in transforms: + raise ValueError( + f"Image {element!r} is not registered in coordinate system " + f"{coordinate_system!r}. Registered in: {list(transforms)}" + ) + + +class _InteractiveSession: + """Internal session class driving the DrawCanvas widget. + + Constructed by :meth:`PlotAccessor.annotate`. Not part of the public API. + """ + + def __init__( + self, + sdata: sd.SpatialData, + coordinate_system: str, + element: str, + *, + persist: bool = True, + ) -> None: + _validate(sdata, coordinate_system, element) + + self.sdata = sdata + self._cs = coordinate_system + self._element = element + self._persist_enabled = persist + self.canvas: DrawCanvas | None = None + self._image_w: int | None = None + self._image_h: int | None = None + self._xlim: tuple[float, float] | None = None + self._ylim: tuple[float, float] | None = None + self.commits: list[str] = [] + + self._style = W.HTML(value=_CSS) + + # Tool toggle + self.tool_tb = W.ToggleButtons( + options=[("Rect", "rectangle"), ("Polygon", "polygon"), ("Lasso", "lasso")], + value="rectangle", + description="Tool:", + ) + self.tool_tb.observe(self._on_tool_change, names="value") + self.close_poly_btn = W.Button(description="Close polygon", icon="check", tooltip="Enter") + self.close_poly_btn.on_click(self._on_close_polygon) + self.close_poly_btn.disabled = True + self.undo_btn = W.Button(description="Undo", icon="rotate-left", tooltip="Ctrl+Z") + self.undo_btn.on_click(self._on_undo) + self.undo_btn.disabled = True + self.clear_btn = W.Button(description="Clear", icon="trash", tooltip="Delete") + self.clear_btn.on_click(self._on_clear) + self.fit_btn = W.Button(description="Fit view", icon="compress", tooltip="F") + self.fit_btn.on_click(self._on_fit) + self.shape_count_lbl = W.Label(value="0 shape(s) on canvas") + + # Save + self.name_tx = W.Text(value="", placeholder="name…", description="Name:") + self.save_btn = W.Button(description="Save", button_style="success", icon="save") + self.save_btn.on_click(self._on_save) + self.persist_btn = W.Button(description="Write to disk", button_style="warning", icon="hdd-o") + self.persist_btn.on_click(self._on_persist) + self.persist_btn.disabled = True + + # Banner + canvas holder + self.banner = W.HTML(value=_fmt_banner( + f"Annotating {element!r} in coordinate system {coordinate_system!r}. " + "Pick a tool and draw. Click canvas first so keyboard shortcuts work. " + "R/P/L tools · Wheel zoom · Shift+drag pan · " + "Alt+click shape to delete · Ctrl+Z undo · F fit", + "hint", + )) + self.plot_box = W.VBox([]) + + def section(label: str) -> W.HTML: + return W.HTML(value=f"
{label}
") + + save_row_widgets = [self.name_tx, self.save_btn] + if persist: + save_row_widgets.append(self.persist_btn) + + controls_card = W.VBox([ + W.HTML(value=( + f"
Annotate
" + f"
{element!r} · {coordinate_system!r}
" + )), + section("Draw"), + W.HBox([self.tool_tb, self.close_poly_btn, self.undo_btn, self.clear_btn, self.fit_btn]), + W.HBox([self.shape_count_lbl]), + section("Save"), + W.HBox(save_row_widgets), + self.banner, + ]) + controls_card.add_class("sdp-card") + + canvas_card = W.VBox([self.plot_box]) + canvas_card.add_class("sdp-card") + + self.controls = W.VBox([self._style, controls_card, canvas_card]) + + def show(self) -> None: + """Render the image and display the controls + canvas.""" + self._render() + display(self.controls) + + def _set_banner(self, msg: str, kind: str = "info") -> None: + self.banner.value = _fmt_banner(msg, kind) + + # ----- render ----- + + def _render(self) -> None: + png_bytes, image_w, image_h, xlim, ylim = render_to_png( + self.sdata, self._element, self._cs, + ) + data_url = "data:image/png;base64," + base64.b64encode(png_bytes).decode("ascii") + self._image_w, self._image_h = image_w, image_h + self._xlim, self._ylim = xlim, ylim + + self.canvas = DrawCanvas( + image_url=data_url, + image_width=image_w, + image_height=image_h, + tool=self.tool_tb.value, + ) + self.canvas.observe(self._on_shapes_change, names="shapes") + self.plot_box.children = (self.canvas,) + self.shape_count_lbl.value = "0 shape(s) on canvas" + + def _on_shapes_change(self, change: dict[str, Any]) -> None: + shapes = change["new"] or [] + self.shape_count_lbl.value = f"{len(shapes)} shape(s) on canvas" + self.undo_btn.disabled = len(shapes) == 0 + + # ----- tool / clear / undo / fit / close ----- + + def _on_tool_change(self, change: dict[str, Any]) -> None: + if self.canvas is None: + return + self.canvas.tool = change["new"] + self.close_poly_btn.disabled = change["new"] != "polygon" + self._set_banner(f"Tool: {change['new']}", "info") + + def _on_close_polygon(self, _btn: W.Button) -> None: + if self.canvas is None: + return + self.canvas.close_poly_trigger += 1 + + def _on_undo(self, _btn: W.Button) -> None: + if self.canvas is None: + return + self.canvas.undo_trigger += 1 + + def _on_clear(self, _btn: W.Button) -> None: + if self.canvas is None: + return + self.canvas.clear_trigger += 1 + self._set_banner("Canvas cleared.", "info") + + def _on_fit(self, _btn: W.Button) -> None: + if self.canvas is None: + return + self.canvas.fit_trigger += 1 + + # ----- save / persist ----- + + def _on_save(self, _btn: W.Button) -> None: + name = self.name_tx.value.strip() + if not name: + self._set_banner("Name is required.", "error") + return + if self.canvas is None or not self.canvas.shapes: + self._set_banner("No shapes drawn yet.", "error") + return + + polys = [] + for sh in self.canvas.shapes: + p = pixel_shape_to_polygon(sh, self._image_w, self._image_h, self._xlim, self._ylim) + if p is not None: + polys.append(p) + if not polys: + self._set_banner( + f"{len(self.canvas.shapes)} shape(s) on canvas but none parsed as valid polygons.", + "error", + ) + return + + shapes_model = build_shapes_model(polys, self._cs) + target = commit_to_memory(self.sdata, shapes_model, name) + self.commits.append(target) + + self.canvas.clear_trigger += 1 + self.shape_count_lbl.value = "0 shape(s) on canvas" + renamed = target != name + msg = f"Saved {target!r} with {len(polys)} polygon(s)." + if renamed: + msg += " (name collided; renamed)" + self._set_banner(msg, "success") + if self._persist_enabled: + self.persist_btn.disabled = self.sdata.path is None + + def _on_persist(self, _btn: W.Button) -> None: + if not self.commits: + self._set_banner("Nothing saved this session yet.", "error") + return + target = self.commits[-1] + try: + persist_to_disk(self.sdata, target) + except ValueError as exc: + self._set_banner(str(exc), "error") + return + self._set_banner(f"Persisted {target!r} → {self.sdata.path}", "success") diff --git a/src/spatialdata_plot/pl/interactive/static/draw_canvas.js b/src/spatialdata_plot/pl/interactive/static/draw_canvas.js new file mode 100644 index 00000000..e86683b2 --- /dev/null +++ b/src/spatialdata_plot/pl/interactive/static/draw_canvas.js @@ -0,0 +1,380 @@ +// anywidget ESM for spatialdata_plot.pl.interactive.DrawCanvas. +// Pure client-side drawing on an SVG overlay above the rendered image PNG. +// Shape geometry (in image-pixel coordinates) is synced back to Python via +// the `shapes` traitlet; conversion to data/CS coords happens server-side. + +function render({ model, el }) { + const W = model.get('image_width'); + const H = model.get('image_height'); + const DISP_MAX = 760; + const aspect = W / H; + const dispW = aspect >= 1 ? DISP_MAX : Math.round(DISP_MAX * aspect); + const dispH = aspect >= 1 ? Math.round(DISP_MAX / aspect) : DISP_MAX; + + const wrap = document.createElement('div'); + wrap.style.cssText = ` + display: inline-block; + background: #18181b; + padding: 6px; + border-radius: 10px; + box-shadow: 0 2px 6px rgba(0,0,0,0.08); + `; + const container = document.createElement('div'); + container.style.cssText = ` + position: relative; + width: ${dispW}px; + height: ${dispH}px; + user-select: none; + background: #000; + border-radius: 6px; + overflow: hidden; + `; + wrap.appendChild(container); + + const img = document.createElement('img'); + img.src = model.get('image_url'); + img.style.cssText = ` + position: absolute; inset: 0; width: 100%; height: 100%; + pointer-events: none; + `; + img.draggable = false; + container.appendChild(img); + + const svgNS = 'http://www.w3.org/2000/svg'; + const svg = document.createElementNS(svgNS, 'svg'); + svg.style.cssText = ` + position: absolute; inset: 0; width: 100%; height: 100%; + cursor: crosshair; touch-action: none; + `; + svg.setAttribute('preserveAspectRatio', 'none'); + container.appendChild(svg); + + el.appendChild(wrap); + + let shapes = []; + let drawing = null; + let pendingPoly = null; + let hoverIndex = -1; + let vbox = { x: 0, y: 0, w: W, h: H }; + const SNAP_PX = 10; + + function applyViewbox() { + const sx = W / vbox.w; + const sy = H / vbox.h; + img.style.transformOrigin = '0 0'; + img.style.transform = `scale(${sx}, ${sy}) translate(${-vbox.x}px, ${-vbox.y}px)`; + svg.setAttribute('viewBox', `${vbox.x} ${vbox.y} ${vbox.w} ${vbox.h}`); + } + applyViewbox(); + + function setShapes(next) { + shapes = next; + model.set('shapes', shapes); + model.save_changes(); + } + + function getXY(e) { + const r = svg.getBoundingClientRect(); + const fx = (e.clientX - r.left) / r.width; + const fy = (e.clientY - r.top) / r.height; + return [vbox.x + fx * vbox.w, vbox.y + fy * vbox.h]; + } + + function vboxScalePerSvgPx() { + return vbox.w / svg.getBoundingClientRect().width; + } + + function makeEl(tag, attrs) { + const n = document.createElementNS(svgNS, tag); + for (const k in attrs) n.setAttribute(k, attrs[k]); + return n; + } + + function shapeNode(s, color, opts) { + opts = opts || {}; + const sw = opts.lw || 2; + const dash = opts.dashed ? '6,4' : ''; + const fillOp = opts.fillOp == null ? 0.15 : opts.fillOp; + if (s.type === 'rect') { + const [x0, y0] = s.verts[0]; + const [x1, y1] = s.verts[2]; + return makeEl('rect', { + x: Math.min(x0, x1), y: Math.min(y0, y1), + width: Math.abs(x1 - x0), height: Math.abs(y1 - y0), + stroke: color, 'stroke-width': sw, + 'vector-effect': 'non-scaling-stroke', + fill: color, 'fill-opacity': fillOp, + 'stroke-dasharray': dash, + }); + } else if (s.type === 'polygon') { + return makeEl('polygon', { + points: s.verts.map(v => v.join(',')).join(' '), + stroke: color, 'stroke-width': sw, + 'vector-effect': 'non-scaling-stroke', + fill: color, 'fill-opacity': fillOp, + 'stroke-dasharray': dash, + }); + } else if (s.type === 'polyline') { + return makeEl('polyline', { + points: s.verts.map(v => v.join(',')).join(' '), + stroke: color, 'stroke-width': sw, + 'vector-effect': 'non-scaling-stroke', + fill: 'none', + 'stroke-dasharray': dash, + }); + } + return null; + } + + function distPx(a, b) { return Math.hypot(a[0] - b[0], a[1] - b[1]); } + + function shouldSnapClosePoly(e) { + if (!pendingPoly || pendingPoly.verts.length < 3) return false; + const r = svg.getBoundingClientRect(); + const fx = pendingPoly.verts[0][0]; + const fy = pendingPoly.verts[0][1]; + const cx = r.left + (fx - vbox.x) / vbox.w * r.width; + const cy = r.top + (fy - vbox.y) / vbox.h * r.height; + return distPx([e.clientX, e.clientY], [cx, cy]) <= SNAP_PX; + } + + function redraw() { + while (svg.firstChild) svg.removeChild(svg.firstChild); + shapes.forEach((s, i) => { + const isHover = i === hoverIndex; + const n = shapeNode(s, isHover ? '#fb923c' : '#22d3ee', + { lw: isHover ? 3 : 2, fillOp: isHover ? 0.25 : 0.15 }); + if (n) { + n.style.cursor = 'pointer'; + n.dataset.idx = String(i); + n.addEventListener('mouseenter', () => { + hoverIndex = i; redraw(); + }); + n.addEventListener('mouseleave', () => { + if (hoverIndex === i) { hoverIndex = -1; redraw(); } + }); + n.addEventListener('click', (ev) => { + if (ev.altKey) { + const next = shapes.slice(); next.splice(i, 1); + hoverIndex = -1; + setShapes(next); + ev.stopPropagation(); + } + }); + svg.appendChild(n); + } + }); + if (drawing) { + const n = shapeNode(drawing, '#ec4899', { dashed: true }); + if (n) { n.style.pointerEvents = 'none'; svg.appendChild(n); } + } + if (pendingPoly && pendingPoly.verts.length > 0) { + const px = vboxScalePerSvgPx(); + const rPx = 5 * px; + pendingPoly.verts.forEach(([x, y], i) => { + const c = makeEl('circle', { + cx: x, cy: y, r: i === 0 ? rPx * 1.3 : rPx, + fill: i === 0 ? '#facc15' : '#ec4899', + stroke: 'white', 'stroke-width': 1.5 * px, + 'vector-effect': 'non-scaling-stroke', + }); + c.style.pointerEvents = 'none'; + svg.appendChild(c); + }); + } + } + + function commitPendingPolygon() { + if (pendingPoly && pendingPoly.verts.length >= 3) { + setShapes([...shapes, { type: 'polygon', verts: pendingPoly.verts }]); + } + pendingPoly = null; + drawing = null; + redraw(); + } + + function zoomAt(clientX, clientY, factor) { + const r = svg.getBoundingClientRect(); + const fx = (clientX - r.left) / r.width; + const fy = (clientY - r.top) / r.height; + const px = vbox.x + fx * vbox.w; + const py = vbox.y + fy * vbox.h; + let newW = vbox.w / factor; + let newH = vbox.h / factor; + const minW = Math.max(5, W * 0.02); + const minH = Math.max(5, H * 0.02); + if (newW < minW) newW = minW; + if (newH < minH) newH = minH; + if (newW > W) { newW = W; newH = H; } + vbox.x = px - fx * newW; + vbox.y = py - fy * newH; + vbox.w = newW; vbox.h = newH; + clampVbox(); + applyViewbox(); + redraw(); + } + function panBy(dxClient, dyClient) { + const r = svg.getBoundingClientRect(); + vbox.x -= dxClient * (vbox.w / r.width); + vbox.y -= dyClient * (vbox.h / r.height); + clampVbox(); + applyViewbox(); + redraw(); + } + function clampVbox() { + if (vbox.x < 0) vbox.x = 0; + if (vbox.y < 0) vbox.y = 0; + if (vbox.x + vbox.w > W) vbox.x = W - vbox.w; + if (vbox.y + vbox.h > H) vbox.y = H - vbox.h; + } + function fitView() { + vbox = { x: 0, y: 0, w: W, h: H }; + applyViewbox(); + redraw(); + } + + let panning = false; + let panStart = null; + + function onWheel(e) { + e.preventDefault(); + const factor = e.deltaY < 0 ? 1.2 : 1 / 1.2; + zoomAt(e.clientX, e.clientY, factor); + } + + function onMouseDown(e) { + if (e.button === 1 || (e.button === 0 && e.shiftKey)) { + panning = true; + panStart = [e.clientX, e.clientY]; + svg.style.cursor = 'grabbing'; + e.preventDefault(); + return; + } + if (e.button !== 0) return; + svg.focus(); + const tool = model.get('tool'); + if (tool === 'polygon' && shouldSnapClosePoly(e)) { + commitPendingPolygon(); + e.preventDefault(); + return; + } + const [x, y] = getXY(e); + if (tool === 'rectangle') { + drawing = { type: 'rect', verts: [[x, y], [x, y], [x, y], [x, y]] }; + redraw(); + } else if (tool === 'lasso') { + drawing = { type: 'polygon', verts: [[x, y]] }; + redraw(); + } else if (tool === 'polygon') { + if (!pendingPoly) pendingPoly = { type: 'polygon', verts: [] }; + pendingPoly.verts.push([x, y]); + drawing = { type: 'polyline', verts: [...pendingPoly.verts] }; + redraw(); + } + e.preventDefault(); + } + + function onMouseMove(e) { + if (panning) { + const dx = e.clientX - panStart[0]; + const dy = e.clientY - panStart[1]; + panStart = [e.clientX, e.clientY]; + panBy(dx, dy); + return; + } + if (!drawing) return; + const tool = model.get('tool'); + const [x, y] = getXY(e); + if (tool === 'rectangle') { + const [x0, y0] = drawing.verts[0]; + drawing.verts = [[x0, y0], [x, y0], [x, y], [x0, y]]; + redraw(); + } else if (tool === 'lasso') { + drawing.verts.push([x, y]); + redraw(); + } + } + + function onMouseUp(e) { + if (panning) { + panning = false; panStart = null; + svg.style.cursor = 'crosshair'; + return; + } + const tool = model.get('tool'); + if (tool === 'rectangle' && drawing) { + const [[x0, y0], , [x1, y1]] = drawing.verts; + if (Math.abs(x1 - x0) >= 2 && Math.abs(y1 - y0) >= 2) { + setShapes([...shapes, { type: 'rect', verts: drawing.verts }]); + } + drawing = null; + redraw(); + } else if (tool === 'lasso' && drawing && drawing.verts.length >= 3) { + setShapes([...shapes, { type: 'polygon', verts: drawing.verts }]); + drawing = null; + redraw(); + } + } + + function onKeyDown(e) { + const tool = model.get('tool'); + if (e.key === 'r' || e.key === 'R') { model.set('tool', 'rectangle'); model.save_changes(); e.preventDefault(); return; } + if (e.key === 'p' || e.key === 'P') { model.set('tool', 'polygon'); model.save_changes(); e.preventDefault(); return; } + if (e.key === 'l' || e.key === 'L') { model.set('tool', 'lasso'); model.save_changes(); e.preventDefault(); return; } + if (e.key === 'f' || e.key === 'F') { fitView(); e.preventDefault(); return; } + if (e.key === 'Enter') { + if (tool === 'polygon' && pendingPoly) commitPendingPolygon(); + e.preventDefault(); + return; + } + if (e.key === 'Escape') { + pendingPoly = null; drawing = null; redraw(); + e.preventDefault(); + return; + } + if ((e.ctrlKey || e.metaKey) && (e.key === 'z' || e.key === 'Z')) { + if (shapes.length > 0) setShapes(shapes.slice(0, -1)); + e.preventDefault(); + return; + } + if (e.key === 'Delete' || e.key === 'Backspace') { + if (shapes.length > 0) setShapes([]); + e.preventDefault(); + return; + } + } + + svg.tabIndex = 0; + svg.addEventListener('wheel', onWheel, { passive: false }); + svg.addEventListener('mousedown', onMouseDown); + svg.addEventListener('mousemove', onMouseMove); + svg.addEventListener('mouseup', onMouseUp); + svg.addEventListener('mouseleave', (e) => { if (!panning) onMouseUp(e); }); + svg.addEventListener('keydown', onKeyDown); + svg.addEventListener('contextmenu', (e) => e.preventDefault()); + + function updateCursor() { + svg.style.cursor = 'crosshair'; + svg.title = `Tool: ${model.get('tool')}. R/P/L: tools. Enter: close poly. Esc: cancel. Ctrl+Z: undo. Alt+click shape: delete. Wheel: zoom. Shift+drag: pan. F: fit.`; + } + updateCursor(); + + model.on('change:tool', () => { + pendingPoly = null; + drawing = null; + updateCursor(); + redraw(); + }); + model.on('change:clear_trigger', () => { + shapes = []; drawing = null; pendingPoly = null; + model.set('shapes', []); model.save_changes(); + redraw(); + }); + model.on('change:close_poly_trigger', () => { commitPendingPolygon(); }); + model.on('change:undo_trigger', () => { + if (shapes.length > 0) setShapes(shapes.slice(0, -1)); + }); + model.on('change:fit_trigger', () => { fitView(); }); +} + +export default { render }; diff --git a/tests/test_interactive/__init__.py b/tests/test_interactive/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_interactive/test_canvas.py b/tests/test_interactive/test_canvas.py new file mode 100644 index 00000000..163e0c99 --- /dev/null +++ b/tests/test_interactive/test_canvas.py @@ -0,0 +1,44 @@ +"""Smoke tests for the DrawCanvas anywidget class.""" +from __future__ import annotations + +from pathlib import Path + + +def test_draw_canvas_imports(): + from spatialdata_plot.pl.interactive import DrawCanvas + + assert DrawCanvas is not None + + +def test_draw_canvas_default_traitlets(): + from spatialdata_plot.pl.interactive import DrawCanvas + + c = DrawCanvas() + assert c.tool == "rectangle" + assert c.shapes == [] + assert c.image_width == 720 + assert c.image_height == 720 + assert c.clear_trigger == 0 + assert c.close_poly_trigger == 0 + assert c.undo_trigger == 0 + assert c.fit_trigger == 0 + + +def test_draw_canvas_esm_file_is_bundled(): + """The ESM module file must ship with the package.""" + from spatialdata_plot.pl.interactive import _canvas + + assert _canvas._ESM_PATH.exists(), f"{_canvas._ESM_PATH} not bundled" + assert _canvas._ESM_PATH.suffix == ".js" + assert _canvas._ESM_PATH.stat().st_size > 0 + + +def test_draw_canvas_traitlet_assignment(): + """Setting traitlets from Python should work (Python → JS sync).""" + from spatialdata_plot.pl.interactive import DrawCanvas + + c = DrawCanvas() + c.tool = "polygon" + assert c.tool == "polygon" + c.clear_trigger += 1 + assert c.clear_trigger == 1 diff --git a/tests/test_interactive/test_commit.py b/tests/test_interactive/test_commit.py new file mode 100644 index 00000000..7e8606ad --- /dev/null +++ b/tests/test_interactive/test_commit.py @@ -0,0 +1,78 @@ +"""Tests for pixel-coord → CS-coord conversion and ShapesModel construction.""" +from __future__ import annotations + +import pytest +from shapely.geometry import Polygon + +from spatialdata_plot.pl.interactive._commit import ( + build_shapes_model, + pixel_shape_to_polygon, +) + + +def test_rect_maps_to_full_cs_extent(): + """A rect spanning the full PNG maps to the full CS extent.""" + shape = {"type": "rect", "verts": [[0, 0], [100, 0], [100, 100], [0, 100]]} + poly = pixel_shape_to_polygon(shape, 100, 100, (0.0, 50.0), (50.0, 0.0)) + assert poly.bounds == (0.0, 0.0, 50.0, 50.0) + + +def test_rect_subregion(): + """A 50% rect maps to the central CS quadrant correctly.""" + shape = {"type": "rect", "verts": [[25, 25], [75, 25], [75, 75], [25, 75]]} + poly = pixel_shape_to_polygon(shape, 100, 100, (0.0, 100.0), (100.0, 0.0)) + assert poly.bounds == (25.0, 25.0, 75.0, 75.0) + + +def test_y_axis_orientation_matplotlib_image(): + """matplotlib image axes have origin='upper'; py=0 maps to the smaller y.""" + # PNG pixel (0, 0) = top-left. With ylim=(100, 0) (matplotlib image style), + # the smaller y (0) is at the top, larger y (100) at the bottom. + shape = {"type": "polygon", "verts": [[0, 0], [10, 0], [10, 10], [0, 10]]} + poly = pixel_shape_to_polygon(shape, 100, 100, (0.0, 100.0), (100.0, 0.0)) + # Pixel (0, 0) (top) → CS y = 0; pixel (0, 10) (10px down) → CS y = 10. + assert poly.bounds == (0.0, 0.0, 10.0, 10.0) + + +def test_y_axis_non_reversed_ylim(): + """sorted() handles ylim either way round — no orientation flip.""" + shape = {"type": "polygon", "verts": [[0, 0], [10, 0], [10, 10], [0, 10]]} + # Hand it ylim in non-reversed order; result should be the same. + poly = pixel_shape_to_polygon(shape, 100, 100, (0.0, 100.0), (0.0, 100.0)) + assert poly.bounds == (0.0, 0.0, 10.0, 10.0) + + +def test_invalid_shapes_return_none(): + """Empty verts, <3 verts, or empty geometry all yield None.""" + assert pixel_shape_to_polygon({"type": "polygon", "verts": []}, 100, 100, (0, 1), (0, 1)) is None + assert pixel_shape_to_polygon({"type": "polygon", "verts": [[1, 1], [2, 2]]}, 100, 100, (0, 1), (0, 1)) is None + assert pixel_shape_to_polygon({}, 100, 100, (0, 1), (0, 1)) is None + + +def test_lasso_simplification_for_high_vertex_count(): + """Polygons with > 50 verts get .simplify() applied; rectangle-shaped ones stay rectangles.""" + n = 200 + # A near-rectangular path with many co-linear noise points along each edge. + verts = ( + [[i, 0] for i in range(n)] + + [[n, j] for j in range(n)] + + [[n - i, n] for i in range(n)] + + [[0, n - j] for j in range(n)] + ) + shape = {"type": "polygon", "verts": verts} + poly = pixel_shape_to_polygon(shape, n, n, (0.0, n), (float(n), 0.0)) + # After simplification, the rectangle should have far fewer than 4*n vertices. + assert len(poly.exterior.coords) < 4 * n + # Bounds preserved. + assert poly.bounds == (0.0, 0.0, float(n), float(n)) + + +def test_build_shapes_model_registers_identity_transform(): + """build_shapes_model wraps polygons with an Identity transform in the given CS.""" + polys = [Polygon([(0, 0), (1, 0), (1, 1), (0, 1)])] + sm = build_shapes_model(polys, "my_cs") + # ShapesModel inherits GeoDataFrame; transformations live in .attrs. + import spatialdata as sd + transforms = sd.transformations.get_transformation(sm, get_all=True) + assert "my_cs" in transforms + assert isinstance(transforms["my_cs"], sd.transformations.Identity) diff --git a/tests/test_interactive/test_persist.py b/tests/test_interactive/test_persist.py new file mode 100644 index 00000000..671def28 --- /dev/null +++ b/tests/test_interactive/test_persist.py @@ -0,0 +1,57 @@ +"""Tests for the in-memory commit + zarr write policy.""" +from __future__ import annotations + +import geopandas as gpd +import pytest +import spatialdata as sd +from shapely.geometry import Polygon +from spatialdata.models import ShapesModel + +from spatialdata_plot.pl.interactive._persist import commit_to_memory, persist_to_disk + + +def _make_sdata(tmp_path=None) -> sd.SpatialData: + sdata = sd.SpatialData() + if tmp_path is not None: + sdata.write(tmp_path / "test.zarr") + sdata = sd.read_zarr(tmp_path / "test.zarr") + return sdata + + +def _make_shape() -> ShapesModel: + gdf = gpd.GeoDataFrame({"geometry": [Polygon([(0, 0), (1, 0), (1, 1), (0, 1)])]}) + return ShapesModel.parse(gdf, transformations={"global": sd.transformations.Identity()}) + + +def test_commit_to_memory_stores_under_name(): + sdata = _make_sdata() + target = commit_to_memory(sdata, _make_shape(), "tumor_region") + assert target == "tumor_region" + assert "tumor_region" in sdata.shapes + + +def test_commit_to_memory_renames_on_collision(): + sdata = _make_sdata() + sdata.shapes["tumor_region"] = _make_shape() + target = commit_to_memory(sdata, _make_shape(), "tumor_region") + # Original preserved, new one gets a timestamp suffix. + assert "tumor_region" in sdata.shapes + assert target.startswith("tumor_region_") + assert target != "tumor_region" + assert target in sdata.shapes + + +def test_persist_raises_when_not_zarr_backed(): + sdata = _make_sdata() + sdata.shapes["foo"] = _make_shape() + with pytest.raises(ValueError, match="not zarr-backed"): + persist_to_disk(sdata, "foo") + + +def test_persist_writes_to_zarr(tmp_path): + sdata = _make_sdata(tmp_path=tmp_path) + sdata.shapes["foo"] = _make_shape() + persist_to_disk(sdata, "foo") + # Re-read from disk and check the element survives. + sdata2 = sd.read_zarr(tmp_path / "test.zarr") + assert "foo" in sdata2.shapes diff --git a/tests/test_interactive/test_render.py b/tests/test_interactive/test_render.py new file mode 100644 index 00000000..93f6b5c2 --- /dev/null +++ b/tests/test_interactive/test_render.py @@ -0,0 +1,39 @@ +"""Smoke test for the matplotlib → PNG render path.""" +from __future__ import annotations + +from io import BytesIO + +import numpy as np +import spatialdata as sd +from PIL import Image + +from spatialdata_plot.pl.interactive._render import _IMAGE_H, _IMAGE_W, render_to_png + + +def _make_sdata_with_image() -> sd.SpatialData: + """Build a tiny SpatialData with a single 2D image in CS 'global'.""" + from spatialdata.models import Image2DModel + + arr = np.random.default_rng(0).integers(0, 255, size=(3, 64, 64), dtype=np.uint8) + img = Image2DModel.parse(arr, dims=("c", "y", "x")) + return sd.SpatialData(images={"img": img}) + + +def test_render_to_png_returns_valid_png(): + sdata = _make_sdata_with_image() + png_bytes, w, h, xlim, ylim = render_to_png(sdata, "img", "global") + # PNG signature + assert png_bytes.startswith(b"\x89PNG\r\n\x1a\n") + # Decode and check dimensions roughly match the configured render size. + decoded = Image.open(BytesIO(png_bytes)) + assert decoded.size == (w, h) == (_IMAGE_W, _IMAGE_H) + + +def test_render_to_png_returns_extent_matching_image(): + sdata = _make_sdata_with_image() + _, _, _, xlim, ylim = render_to_png(sdata, "img", "global") + # For an Image2DModel with shape (c=3, y=64, x=64) and no transformations, + # xlim should cover [0, 64] and ylim [64, 0] (origin='upper'). Allow ±1 + # for matplotlib edge padding. + assert xlim[0] <= 0 and xlim[1] >= 63 + assert ylim[0] >= 63 and ylim[1] <= 0 From 390222de088089b2bf2dae3ff1ddbb4b693b5a31 Mon Sep 17 00:00:00 2001 From: Tim Treis Date: Thu, 21 May 2026 22:39:52 +0200 Subject: [PATCH 4/6] Address review feedback on sdata.pl.annotate() MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Tests: import `DrawCanvas` from `._canvas` (internal class is not re-exported), and gate `test_canvas` with `pytest.importorskip` so CI envs without the `interactive` extra skip rather than fail. - `_persist.py`: replace deprecated `datetime.utcnow()` with timezone-aware `datetime.now(timezone.utc)`. Document on-disk overwrite behaviour (asymmetric with in-memory rename-on-collision). - `_render.py`: wrap render in `try/finally` so figures don't leak if `render_images().show()` or `savefig` raises. - `draw_canvas.js`: Delete/Backspace now removes the most recent shape (matches Ctrl+Z) instead of wiping the whole canvas — the toolbar Clear button covers the wipe case. - `basic.py` docstring: note that the canvas clears on every Save and that the Write-to-disk button overwrites same-named on-disk elements. - Add `tests/test_interactive/test_annotate.py` covering the three validation paths (`unknown CS`, `unknown element`, `element not in CS`) by stubbing `_InteractiveSession.show`. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/spatialdata_plot/pl/basic.py | 8 ++- .../pl/interactive/_persist.py | 7 +- .../pl/interactive/_render.py | 22 ++++--- .../pl/interactive/static/draw_canvas.js | 2 +- tests/test_interactive/test_annotate.py | 64 +++++++++++++++++++ tests/test_interactive/test_canvas.py | 11 ++-- 6 files changed, 97 insertions(+), 17 deletions(-) create mode 100644 tests/test_interactive/test_annotate.py diff --git a/src/spatialdata_plot/pl/basic.py b/src/spatialdata_plot/pl/basic.py index 2f9c2d9e..ce841011 100644 --- a/src/spatialdata_plot/pl/basic.py +++ b/src/spatialdata_plot/pl/basic.py @@ -185,7 +185,13 @@ def annotate( Drawn shapes are saved into ``sdata.shapes`` under a user-typed name on click of the *Save* button — each save creates one ShapesModel with one row per drawn shape, registered with an ``Identity`` - transformation in the chosen coordinate system. + transformation in the chosen coordinate system. The canvas is + cleared on every Save so the next set of shapes can be drawn + independently. + + In-memory name collisions are renamed to ``{name}_{UTC-ISO}``. The + on-disk *Write to disk* button calls ``SpatialData.write_element``, + which overwrites an existing on-disk element of the same name. Requires the ``interactive`` extra: ``pip install 'spatialdata-plot[interactive]'``. diff --git a/src/spatialdata_plot/pl/interactive/_persist.py b/src/spatialdata_plot/pl/interactive/_persist.py index 5d496187..ff816c42 100644 --- a/src/spatialdata_plot/pl/interactive/_persist.py +++ b/src/spatialdata_plot/pl/interactive/_persist.py @@ -19,7 +19,7 @@ def commit_to_memory( """ target = name if target in sdata.shapes: - ts = _dt.datetime.utcnow().strftime("%Y%m%dT%H%M%SZ") + ts = _dt.datetime.now(_dt.timezone.utc).strftime("%Y%m%dT%H%M%SZ") target = f"{name}_{ts}" sdata.shapes[target] = shapes_model return target @@ -28,6 +28,11 @@ def commit_to_memory( def persist_to_disk(sdata: sd.SpatialData, name: str) -> None: """Persist ``sdata.shapes[name]`` to the backing zarr store. + ``sdata.write_element`` overwrites any existing on-disk element of the + same name. Disk-side collision handling differs from + :func:`commit_to_memory`, which renames in memory — call sites that + care should pass in the name returned by ``commit_to_memory``. + Raises ValueError if ``sdata`` is not zarr-backed. """ if sdata.path is None: diff --git a/src/spatialdata_plot/pl/interactive/_render.py b/src/spatialdata_plot/pl/interactive/_render.py index 8c4ebaee..4084be7f 100644 --- a/src/spatialdata_plot/pl/interactive/_render.py +++ b/src/spatialdata_plot/pl/interactive/_render.py @@ -36,14 +36,16 @@ def render_to_png( :func:`._commit.pixel_shape_to_polygon` for the conversion. """ fig = plt.figure(figsize=_FIGSIZE, dpi=_DPI) - ax = fig.add_axes([0, 0, 1, 1]) - sdata.pl.render_images(element=element).pl.show(coordinate_systems=coordinate_system, ax=ax) - xlim = ax.get_xlim() - ylim = ax.get_ylim() - ax.set_axis_off() - for spine in ax.spines.values(): - spine.set_visible(False) - buf = BytesIO() - fig.savefig(buf, format="png", dpi=_DPI, pad_inches=0) - plt.close(fig) + try: + ax = fig.add_axes([0, 0, 1, 1]) + sdata.pl.render_images(element=element).pl.show(coordinate_systems=coordinate_system, ax=ax) + xlim = ax.get_xlim() + ylim = ax.get_ylim() + ax.set_axis_off() + for spine in ax.spines.values(): + spine.set_visible(False) + buf = BytesIO() + fig.savefig(buf, format="png", dpi=_DPI, pad_inches=0) + finally: + plt.close(fig) return buf.getvalue(), _IMAGE_W, _IMAGE_H, tuple(xlim), tuple(ylim) diff --git a/src/spatialdata_plot/pl/interactive/static/draw_canvas.js b/src/spatialdata_plot/pl/interactive/static/draw_canvas.js index e86683b2..03f494bd 100644 --- a/src/spatialdata_plot/pl/interactive/static/draw_canvas.js +++ b/src/spatialdata_plot/pl/interactive/static/draw_canvas.js @@ -338,7 +338,7 @@ function render({ model, el }) { return; } if (e.key === 'Delete' || e.key === 'Backspace') { - if (shapes.length > 0) setShapes([]); + if (shapes.length > 0) setShapes(shapes.slice(0, -1)); e.preventDefault(); return; } diff --git a/tests/test_interactive/test_annotate.py b/tests/test_interactive/test_annotate.py new file mode 100644 index 00000000..ab435745 --- /dev/null +++ b/tests/test_interactive/test_annotate.py @@ -0,0 +1,64 @@ +"""Tests for the user-facing sdata.pl.annotate() validation paths.""" +from __future__ import annotations + +import numpy as np +import pytest +import spatialdata as sd + +pytest.importorskip("anywidget") +pytest.importorskip("ipywidgets") + +import spatialdata_plot # noqa: F401 registers .pl + + +def _make_sdata_with_image() -> sd.SpatialData: + from spatialdata.models import Image2DModel + + arr = np.random.default_rng(0).integers(0, 255, size=(3, 32, 32), dtype=np.uint8) + img = Image2DModel.parse(arr, dims=("c", "y", "x")) + return sd.SpatialData(images={"img": img}) + + +def test_annotate_unknown_coordinate_system_raises(monkeypatch): + sdata = _make_sdata_with_image() + monkeypatch.setattr( + "spatialdata_plot.pl.interactive._session._InteractiveSession.show", + lambda self: None, + ) + with pytest.raises(ValueError, match="Unknown coordinate system"): + sdata.pl.annotate("does_not_exist", "img") + + +def test_annotate_unknown_element_raises(monkeypatch): + sdata = _make_sdata_with_image() + monkeypatch.setattr( + "spatialdata_plot.pl.interactive._session._InteractiveSession.show", + lambda self: None, + ) + with pytest.raises(ValueError, match="Unknown image element"): + sdata.pl.annotate("global", "no_such_image") + + +def test_annotate_element_not_in_cs_raises(monkeypatch): + from spatialdata.models import Image2DModel + + # 'img' is registered only in 'other_cs'; a second element keeps 'global' + # in sdata.coordinate_systems so we trigger the "not registered in CS" + # branch rather than the "unknown CS" one. + rng = np.random.default_rng(0) + arr = rng.integers(0, 255, size=(3, 32, 32), dtype=np.uint8) + img = Image2DModel.parse( + arr, dims=("c", "y", "x"), transformations={"other_cs": sd.transformations.Identity()}, + ) + anchor = Image2DModel.parse( + rng.integers(0, 255, size=(3, 32, 32), dtype=np.uint8), + dims=("c", "y", "x"), + transformations={"global": sd.transformations.Identity()}, + ) + sdata = sd.SpatialData(images={"img": img, "anchor": anchor}) + monkeypatch.setattr( + "spatialdata_plot.pl.interactive._session._InteractiveSession.show", + lambda self: None, + ) + with pytest.raises(ValueError, match="not registered in coordinate system"): + sdata.pl.annotate("global", "img") diff --git a/tests/test_interactive/test_canvas.py b/tests/test_interactive/test_canvas.py index 163e0c99..010d9eac 100644 --- a/tests/test_interactive/test_canvas.py +++ b/tests/test_interactive/test_canvas.py @@ -1,17 +1,20 @@ """Smoke tests for the DrawCanvas anywidget class.""" from __future__ import annotations -from pathlib import Path +import pytest + +pytest.importorskip("anywidget") +pytest.importorskip("ipywidgets") def test_draw_canvas_imports(): - from spatialdata_plot.pl.interactive import DrawCanvas + from spatialdata_plot.pl.interactive._canvas import DrawCanvas assert DrawCanvas is not None def test_draw_canvas_default_traitlets(): - from spatialdata_plot.pl.interactive import DrawCanvas + from spatialdata_plot.pl.interactive._canvas import DrawCanvas c = DrawCanvas() assert c.tool == "rectangle" @@ -35,7 +38,7 @@ def test_draw_canvas_esm_file_is_bundled(): def test_draw_canvas_traitlet_assignment(): """Setting traitlets from Python should work (Python → JS sync).""" - from spatialdata_plot.pl.interactive import DrawCanvas + from spatialdata_plot.pl.interactive._canvas import DrawCanvas c = DrawCanvas() c.tool = "polygon" From 3b770599664afa63f3bfe08fda7a400832ee757a Mon Sep 17 00:00:00 2001 From: Tim Treis Date: Thu, 21 May 2026 23:16:48 +0200 Subject: [PATCH 5/6] Refactor sdata.pl.annotate() from three-agent code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Reuse / convention fixes: - Delete `_persist.persist_to_disk` — `SpatialData.write_element` already raises `ValueError` when `path is None`. Inline `write_element(name, overwrite=True)` in `_on_persist` and fix the docstring claim about overwrite semantics. - Match project import convention: `from spatialdata.transformations. operations import get_transformation` and `...transformations import Identity`, replacing `sd.transformations.*` references in `_session`, `_commit`, and tests. - `_validate` uses `sdata[element]` indexing + `get_transformation` to match the established pattern in `pl/utils.py` / `pl/render.py`. Quality: - Introduce frozen `RenderExtent` dataclass returned by `render_to_png`; collapses the 5-tuple return + 4 cached attrs on `_InteractiveSession` into one object. `pixel_shape_to_polygon(shape, extent)` drops 4 args. - `traitlets.Enum(TOOLS, ...)` for `DrawCanvas.tool` so a typo raises. - `BannerKind = Literal["info","success","error","hint"]`; drop the silent `.get(..., default)` fallback so a banner-kind typo raises. - Factor `_trigger_btn(description, icon, trait_name, after=...)` — replaces 4 near-identical `_on_close_polygon` / `_on_undo` / `_on_clear` / `_on_fit` methods. - Split `_on_save` into `_collect_polygons` / `_commit_polygons` / `_reset_canvas_state`; orchestrator stays ~10 lines. - Drop redundant `spine.set_visible(False)` loop after `set_axis_off()`. - Guard `persist_btn` construction behind `persist=True`; `_on_persist` early-returns if disabled. - Strip restate-the-code comments (`_canvas` module docstring, `_render` v0/v1 narration, `_commit` lasso-restating comment). - Underscore truly-private attrs (`_sdata`, `_commits`); keep `canvas` un-underscored since `_validate` callers in tests still need a way in. Efficiency (JS): - Incremental in-progress shape update during rect/lasso drag — keep a stable reference to the in-progress SVG node; mutate its attributes in `onMouseMove` instead of full `redraw()` (60 Hz × O(N) DOM ops → O(1)). - Lasso vert-push gated on viewbox-px ≥ 1 from the last vert; cuts vertex count ~5-10× for typical drags and the kernel-side traitlet payload on commit. - `setShapes` early-returns on `next === shapes` or both-empty; the `clear_trigger` handler routes through `setShapes([])` and skips when there is nothing to clear or cancel. - `zoomAt` / `panBy` / `fitView` snapshot the vbox pre-clamp and skip `applyViewbox` + `redraw` if the clamped vbox is unchanged. - `change:tool` only redraws if there was an in-progress shape to clear. - Dedup: `popLastShape` helper used by Ctrl+Z, Delete/Backspace, and `change:undo_trigger`. `shapeNode` extracts the common stroke/fill attrs and uses a single `pointsAttr` formatter for polygon/polyline. Tests: - Pytest `no_display` fixture replaces three duplicate `monkeypatch. setattr(...)` calls in `test_annotate.py`. - `pixel_shape_to_polygon` tests updated to the `RenderExtent` signature via a small `_extent(...)` helper. - `test_render` reads `extent.image_w` / `extent.xlim` from the dataclass instead of unpacking a 5-tuple. - Drop the two `test_persist` tests for the deleted `persist_to_disk` wrapper; `commit_to_memory` policy tests remain. All 18 interactive tests pass in dev-interactive-py313. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../pl/interactive/_canvas.py | 7 +- .../pl/interactive/_commit.py | 48 ++--- .../pl/interactive/_persist.py | 26 +-- .../pl/interactive/_render.py | 42 ++-- .../pl/interactive/_session.py | 198 +++++++++--------- .../pl/interactive/static/draw_canvas.js | 178 ++++++++++------ tests/test_interactive/test_annotate.py | 39 ++-- tests/test_interactive/test_commit.py | 45 ++-- tests/test_interactive/test_persist.py | 37 +--- tests/test_interactive/test_render.py | 18 +- 10 files changed, 299 insertions(+), 339 deletions(-) diff --git a/src/spatialdata_plot/pl/interactive/_canvas.py b/src/spatialdata_plot/pl/interactive/_canvas.py index d78b878d..ead15676 100644 --- a/src/spatialdata_plot/pl/interactive/_canvas.py +++ b/src/spatialdata_plot/pl/interactive/_canvas.py @@ -1,4 +1,3 @@ -"""anywidget wrapping a client-side SVG drawing canvas.""" from __future__ import annotations from pathlib import Path @@ -8,6 +7,8 @@ _ESM_PATH = Path(__file__).parent / "static" / "draw_canvas.js" +TOOLS = ("rectangle", "polygon", "lasso") + class DrawCanvas(anywidget.AnyWidget): """Client-side SVG drawing surface for interactive region selection. @@ -26,7 +27,7 @@ class DrawCanvas(anywidget.AnyWidget): image_width, image_height Pixel dimensions of the PNG (used to set the SVG ``viewBox``). tool - ``"rectangle"``, ``"polygon"``, or ``"lasso"``. + One of ``TOOLS``. shapes List of ``{"type": "rect"|"polygon", "verts": [[x, y], ...]}`` in image-pixel coordinates. JS pushes to this on commit. @@ -41,7 +42,7 @@ class DrawCanvas(anywidget.AnyWidget): image_url = traitlets.Unicode("").tag(sync=True) image_width = traitlets.Int(720).tag(sync=True) image_height = traitlets.Int(720).tag(sync=True) - tool = traitlets.Unicode("rectangle").tag(sync=True) + tool = traitlets.Enum(TOOLS, default_value="rectangle").tag(sync=True) shapes = traitlets.List([]).tag(sync=True) clear_trigger = traitlets.Int(0).tag(sync=True) close_poly_trigger = traitlets.Int(0).tag(sync=True) diff --git a/src/spatialdata_plot/pl/interactive/_commit.py b/src/spatialdata_plot/pl/interactive/_commit.py index 8a0ae949..c60bba1a 100644 --- a/src/spatialdata_plot/pl/interactive/_commit.py +++ b/src/spatialdata_plot/pl/interactive/_commit.py @@ -4,63 +4,45 @@ from typing import Any import geopandas as gpd -import spatialdata as sd from shapely.geometry import Polygon from spatialdata.models import ShapesModel +from spatialdata.transformations.transformations import Identity + +from ._render import RenderExtent -# Tolerance for lasso simplification, in CS units. The lasso path is sampled -# at every mouse-move so a freehand loop easily exceeds 1000 vertices; -# 0.5 keeps shape fidelity while collapsing co-linear points. _LASSO_SIMPLIFY_TOL = 0.5 +_DENSE_POLYGON_VERTEX_THRESHOLD = 50 -def pixel_shape_to_polygon( - shape: dict[str, Any], - image_w: int, - image_h: int, - xlim: tuple[float, float], - ylim: tuple[float, float], -) -> Polygon | None: +def pixel_shape_to_polygon(shape: dict[str, Any], extent: RenderExtent) -> Polygon | None: """Convert a single ``DrawCanvas`` shape entry to a CS-coord shapely Polygon. Returns ``None`` if the shape is invalid (no verts, <3 verts after construction, or empty). - - The matplotlib image axes use ``origin='upper'`` — the *smaller* y-value - of ``ylim`` corresponds to the top of the rendered image (PNG row 0). """ verts = shape.get("verts") if isinstance(shape, dict) else None if not verts: return None - xmin, xmax = float(xlim[0]), float(xlim[1]) - y_lo, y_hi = sorted((float(ylim[0]), float(ylim[1]))) - - def px_to_cs(px: float, py: float) -> tuple[float, float]: - return ( - xmin + (px / image_w) * (xmax - xmin), - y_lo + (py / image_h) * (y_hi - y_lo), - ) + xmin, xmax = float(extent.xlim[0]), float(extent.xlim[1]) + y_lo, y_hi = sorted((float(extent.ylim[0]), float(extent.ylim[1]))) + w, h = extent.image_w, extent.image_h - cs_verts = [px_to_cs(v[0], v[1]) for v in verts] + cs_verts = [ + (xmin + (v[0] / w) * (xmax - xmin), y_lo + (v[1] / h) * (y_hi - y_lo)) + for v in verts + ] if len(cs_verts) < 3: return None poly = Polygon(cs_verts) if poly.is_empty: return None - if shape.get("type") == "polygon" and len(cs_verts) > 50: - # Lasso-like (high vertex count) polygons benefit from simplification. + if shape.get("type") == "polygon" and len(cs_verts) > _DENSE_POLYGON_VERTEX_THRESHOLD: poly = poly.simplify(_LASSO_SIMPLIFY_TOL, preserve_topology=True) return poly -def build_shapes_model( - polygons: list[Polygon], - coordinate_system: str, -) -> Any: +def build_shapes_model(polygons: list[Polygon], coordinate_system: str) -> Any: """Wrap shapely polygons in a ShapesModel registered with Identity in ``coordinate_system``.""" gdf = gpd.GeoDataFrame({"geometry": polygons}) - return ShapesModel.parse( - gdf, - transformations={coordinate_system: sd.transformations.Identity()}, - ) + return ShapesModel.parse(gdf, transformations={coordinate_system: Identity()}) diff --git a/src/spatialdata_plot/pl/interactive/_persist.py b/src/spatialdata_plot/pl/interactive/_persist.py index ff816c42..1fec0bd4 100644 --- a/src/spatialdata_plot/pl/interactive/_persist.py +++ b/src/spatialdata_plot/pl/interactive/_persist.py @@ -1,4 +1,4 @@ -"""Commit a ShapesModel into sdata.shapes (memory) and optionally to zarr.""" +"""Commit a ShapesModel into sdata.shapes under a collision-safe name.""" from __future__ import annotations import datetime as _dt @@ -7,11 +7,7 @@ import spatialdata as sd -def commit_to_memory( - sdata: sd.SpatialData, - shapes_model: Any, - name: str, -) -> str: +def commit_to_memory(sdata: sd.SpatialData, shapes_model: Any, name: str) -> str: """Add ``shapes_model`` to ``sdata.shapes`` under ``name``. On collision, the existing element is preserved and the new one is @@ -23,21 +19,3 @@ def commit_to_memory( target = f"{name}_{ts}" sdata.shapes[target] = shapes_model return target - - -def persist_to_disk(sdata: sd.SpatialData, name: str) -> None: - """Persist ``sdata.shapes[name]`` to the backing zarr store. - - ``sdata.write_element`` overwrites any existing on-disk element of the - same name. Disk-side collision handling differs from - :func:`commit_to_memory`, which renames in memory — call sites that - care should pass in the name returned by ``commit_to_memory``. - - Raises ValueError if ``sdata`` is not zarr-backed. - """ - if sdata.path is None: - raise ValueError( - "SpatialData is not zarr-backed (sdata.path is None); cannot persist. " - "Write the SpatialData object to a zarr store first, then re-open it." - ) - sdata.write_element(name) diff --git a/src/spatialdata_plot/pl/interactive/_render.py b/src/spatialdata_plot/pl/interactive/_render.py index 4084be7f..490121c3 100644 --- a/src/spatialdata_plot/pl/interactive/_render.py +++ b/src/spatialdata_plot/pl/interactive/_render.py @@ -1,39 +1,42 @@ """Render an image element to a PNG suitable for the DrawCanvas background.""" from __future__ import annotations +from dataclasses import dataclass from io import BytesIO import matplotlib.pyplot as plt import spatialdata as sd -# Render size: 7 in × 120 dpi = 840 px square. Fixed for v0; pyramid-aware -# downsampling is a v1 feature. _FIGSIZE = (7, 7) _DPI = 120 -_IMAGE_W = _FIGSIZE[0] * _DPI # 840 -_IMAGE_H = _FIGSIZE[1] * _DPI # 840 +_IMAGE_W = _FIGSIZE[0] * _DPI +_IMAGE_H = _FIGSIZE[1] * _DPI + + +@dataclass(frozen=True) +class RenderExtent: + """Geometry of a render — PNG pixel dims + CS-coord limits at render time. + + For matplotlib image axes (``origin='upper'``) ``ylim`` is reversed: + the smaller y maps to PNG row 0. ``pixel_shape_to_polygon`` accepts + either orientation. + """ + + image_w: int + image_h: int + xlim: tuple[float, float] + ylim: tuple[float, float] def render_to_png( sdata: sd.SpatialData, element: str, coordinate_system: str, -) -> tuple[bytes, int, int, tuple[float, float], tuple[float, float]]: - """Render ``element`` in ``coordinate_system`` to PNG bytes. +) -> tuple[bytes, RenderExtent]: + """Render ``element`` in ``coordinate_system`` to PNG + its extent. The matplotlib axes fills the figure (``[0, 0, 1, 1]`` with axis off) so the PNG-pixel ↔ data-coord mapping is exactly ``xlim`` × ``ylim``. - - Returns - ------- - png_bytes - PNG-encoded image. - image_w, image_h - Pixel dimensions of the PNG. - xlim, ylim - ``ax.get_xlim()`` / ``ax.get_ylim()`` at render time. For image axes - (``origin='upper'``) ``ylim`` is reversed — see - :func:`._commit.pixel_shape_to_polygon` for the conversion. """ fig = plt.figure(figsize=_FIGSIZE, dpi=_DPI) try: @@ -42,10 +45,9 @@ def render_to_png( xlim = ax.get_xlim() ylim = ax.get_ylim() ax.set_axis_off() - for spine in ax.spines.values(): - spine.set_visible(False) buf = BytesIO() fig.savefig(buf, format="png", dpi=_DPI, pad_inches=0) finally: plt.close(fig) - return buf.getvalue(), _IMAGE_W, _IMAGE_H, tuple(xlim), tuple(ylim) + extent = RenderExtent(_IMAGE_W, _IMAGE_H, tuple(xlim), tuple(ylim)) + return buf.getvalue(), extent diff --git a/src/spatialdata_plot/pl/interactive/_session.py b/src/spatialdata_plot/pl/interactive/_session.py index 15363475..ce6c5c0f 100644 --- a/src/spatialdata_plot/pl/interactive/_session.py +++ b/src/spatialdata_plot/pl/interactive/_session.py @@ -1,21 +1,28 @@ -"""ipywidgets-based session orchestrating the DrawCanvas. - -Internal class — users invoke :meth:`PlotAccessor.annotate`, which constructs -a session and displays it. -""" +"""ipywidgets-based session orchestrating the DrawCanvas. Internal.""" from __future__ import annotations import base64 -from typing import Any +from typing import Any, Literal import ipywidgets as W import spatialdata as sd from IPython.display import display +from shapely.geometry import Polygon +from spatialdata.transformations.operations import get_transformation from ._canvas import DrawCanvas from ._commit import build_shapes_model, pixel_shape_to_polygon -from ._persist import commit_to_memory, persist_to_disk -from ._render import render_to_png +from ._persist import commit_to_memory +from ._render import RenderExtent, render_to_png + +BannerKind = Literal["info", "success", "error", "hint"] + +_BANNER_CLASS = { + "info": "sdp-banner sdp-banner-info", + "success": "sdp-banner sdp-banner-success", + "error": "sdp-banner sdp-banner-error", + "hint": "sdp-banner sdp-banner-hint", +} _CSS = """