diff --git a/src/spatialdata_plot/pl/render.py b/src/spatialdata_plot/pl/render.py index c1bafeba..5cd6688e 100644 --- a/src/spatialdata_plot/pl/render.py +++ b/src/spatialdata_plot/pl/render.py @@ -4,7 +4,7 @@ from collections import abc from collections.abc import Sequence from copy import copy -from typing import Any +from typing import Any, Literal import dask import dask.dataframe as dd @@ -80,6 +80,24 @@ _Normalize = Normalize | abc.Sequence[Normalize] +def _get_top_data_array(element: xr.DataArray | DataTree) -> xr.DataArray: + if isinstance(element, DataTree): + return next(iter(next(iter(element.values())).data_vars.values())) + return element + + +def _guard_2d_only(element: xr.DataArray | DataTree, element_name: str, kind: Literal["images", "labels"]) -> None: + top = _get_top_data_array(element) + if "z" in top.dims: + z_size = top.sizes["z"] + raise ValueError( + f"render_{kind} does not support 3D {kind}. Element '{element_name}' has a 'z' dimension " + f"with {z_size} slices. Select a 2D slice before plotting:\n" + f" sdata['{element_name}'].isel(z=0)\n" + "or use sd.bounding_box_query() to extract a 2D region." + ) + + def _want_decorations(color_vector: Any, na_color: Color) -> bool: """Return whether legend/colorbar decorations should be shown. @@ -1247,6 +1265,7 @@ def _render_images( palette = render_params.palette img = sdata_filt[render_params.element] + _guard_2d_only(img, render_params.element, "images") extent = get_extent(img, coordinate_system=coordinate_system) scale = render_params.scale @@ -1674,6 +1693,7 @@ def _render_labels( ) label = sdata_filt.labels[element] + _guard_2d_only(label, element, "labels") extent = get_extent(label, coordinate_system=coordinate_system) # get best scale out of multiscale label diff --git a/tests/pl/test_render_images.py b/tests/pl/test_render_images.py index 7f871257..7c2390ab 100644 --- a/tests/pl/test_render_images.py +++ b/tests/pl/test_render_images.py @@ -7,7 +7,7 @@ from matplotlib.colors import LogNorm, Normalize from spatial_image import to_spatial_image from spatialdata import SpatialData -from spatialdata.models import Image2DModel +from spatialdata.models import Image2DModel, Image3DModel import spatialdata_plot # noqa: F401 from spatialdata_plot._logging import logger, logger_warns @@ -720,6 +720,21 @@ def test_channels_as_legend_coexists_with_other_elements(self, sdata_blobs: Spat plt.close("all") +@pytest.mark.parametrize("scale_factors", [None, [2]]) +def test_render_images_raises_on_3d(scale_factors): + # Regression test for #608: 3D images must raise a clear ValueError, not crash + # deep in matplotlib with "Invalid shape" / opaque numpy errors. + img = np.random.default_rng(0).random((2, 4, 16, 16), dtype=np.float32) + image3d = Image3DModel.parse(img, dims=["c", "z", "y", "x"], c_coords=["DAPI", "GFP"], scale_factors=scale_factors) + sdata = SpatialData(images={"img3d": image3d}) + fig, ax = plt.subplots() + try: + with pytest.raises(ValueError, match=r"render_images does not support 3D.*img3d.*z.*4"): + sdata.pl.render_images("img3d").pl.show(ax=ax) + finally: + plt.close(fig) + + def test_lognorm_with_zeros_suppresses_colorbar_with_warning(): # regression test for #604: LogNorm + non-positive data must not raise an opaque # matplotlib ValueError; instead suppress the colorbar with an actionable UserWarning. diff --git a/tests/pl/test_render_labels.py b/tests/pl/test_render_labels.py index 931ac24c..4c76b1a1 100644 --- a/tests/pl/test_render_labels.py +++ b/tests/pl/test_render_labels.py @@ -9,7 +9,7 @@ from matplotlib.colors import Normalize from spatial_image import to_spatial_image from spatialdata import SpatialData, deepcopy, get_element_instances -from spatialdata.models import Labels2DModel, TableModel +from spatialdata.models import Labels2DModel, Labels3DModel, TableModel import spatialdata_plot # noqa: F401 from spatialdata_plot._logging import logger, logger_warns @@ -550,3 +550,18 @@ def test_render_labels_disjoint_instance_ids_clear_error(): sdata.pl.render_labels("lbl", color="cat", table_name="t").pl.show(ax=ax) finally: plt.close(fig) + + +@pytest.mark.parametrize("scale_factors", [None, [2]]) +def test_render_labels_raises_on_3d(scale_factors): + # Regression test for #608: 3D labels must raise a clear ValueError, not crash + # deep in numpy with an opaque concatenation error. + arr = np.random.default_rng(0).integers(0, 5, size=(4, 16, 16), dtype=np.int32) + labels3d = Labels3DModel.parse(arr, dims=["z", "y", "x"], scale_factors=scale_factors) + sdata = SpatialData(labels={"lbl3d": labels3d}) + fig, ax = plt.subplots() + try: + with pytest.raises(ValueError, match=r"render_labels does not support 3D.*lbl3d.*z.*4"): + sdata.pl.render_labels("lbl3d").pl.show(ax=ax) + finally: + plt.close(fig)