diff --git a/src/spatialdata_plot/pl/basic.py b/src/spatialdata_plot/pl/basic.py index e1d28d3b..3d99cc77 100644 --- a/src/spatialdata_plot/pl/basic.py +++ b/src/spatialdata_plot/pl/basic.py @@ -20,7 +20,7 @@ from geopandas import GeoDataFrame from matplotlib.axes import Axes from matplotlib.backend_bases import RendererBase -from matplotlib.colors import Colormap, Normalize +from matplotlib.colors import Colormap, LogNorm, Normalize from matplotlib.figure import Figure from mpl_toolkits.axes_grid1.inset_locator import inset_axes from spatialdata import get_extent @@ -1299,6 +1299,18 @@ def _draw_colorbar( base_offsets_axes: dict[str, float], trackers_axes: dict[str, float], ) -> None: + norm = spec.mappable.norm + if isinstance(norm, LogNorm): + vmin, vmax = norm.vmin, norm.vmax + if vmin is None or vmax is None or vmin <= 0 or vmin >= vmax: + warnings.warn( + "Data contains zeros or non-positive values; colorbar suppressed for `LogNorm`. " + "Pass `colorbar=False` to silence this warning, or clip the data to positive values.", + UserWarning, + stacklevel=2, + ) + return + base_layout = { "location": CBAR_DEFAULT_LOCATION, "fraction": CBAR_DEFAULT_FRACTION, diff --git a/tests/pl/test_render_images.py b/tests/pl/test_render_images.py index fb723c2f..13fe675c 100644 --- a/tests/pl/test_render_images.py +++ b/tests/pl/test_render_images.py @@ -4,7 +4,7 @@ import numpy as np import pytest import scanpy as sc -from matplotlib.colors import Normalize +from matplotlib.colors import LogNorm, Normalize from spatial_image import to_spatial_image from spatialdata import SpatialData from spatialdata.models import Image2DModel @@ -682,3 +682,16 @@ def test_channels_as_legend_coexists_with_other_elements(self, sdata_blobs: Spat assert "0" in labels assert "1" in labels plt.close("all") + + +def test_lognorm_with_zeros_suppresses_colorbar_with_warning(): + # regression test for #604: LogNorm + non-positive data must not raise an opaque + # matplotlib ValueError; instead suppress the colorbar with an actionable UserWarning. + img = np.zeros((1, 5, 5), dtype=np.float32) + sdata = SpatialData(images={"img": Image2DModel.parse(img, c_coords=["DAPI"])}) + fig, ax = plt.subplots() + try: + with pytest.warns(UserWarning, match="LogNorm"): + sdata.pl.render_images("img", norm=LogNorm()).pl.show(ax=ax) + finally: + plt.close(fig)