diff --git a/src/spatialdata_plot/pl/_datashader.py b/src/spatialdata_plot/pl/_datashader.py index 1d6415ed..59b58d99 100644 --- a/src/spatialdata_plot/pl/_datashader.py +++ b/src/spatialdata_plot/pl/_datashader.py @@ -17,7 +17,7 @@ from matplotlib.colors import Normalize from spatialdata_plot._logging import logger -from spatialdata_plot.pl.render_params import Color, FigParams, ShapesRenderParams +from spatialdata_plot.pl.render_params import Color, FigParams, ShapesRenderParams, _DsReduction from spatialdata_plot.pl.utils import ( _ax_show_and_transform, _convert_alpha_to_datashader_range, @@ -32,8 +32,6 @@ # Type aliases and constants # --------------------------------------------------------------------------- -_DsReduction = Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] - # Sentinel category name used in datashader categorical paths to represent # missing (NaN) values. Must not collide with realistic user category names. _DS_NAN_CATEGORY = "ds_nan" diff --git a/src/spatialdata_plot/pl/basic.py b/src/spatialdata_plot/pl/basic.py index 996161ae..ab38b312 100644 --- a/src/spatialdata_plot/pl/basic.py +++ b/src/spatialdata_plot/pl/basic.py @@ -7,7 +7,7 @@ from collections.abc import Callable, Sequence from copy import deepcopy from pathlib import Path -from typing import Any, Literal, cast +from typing import Any, Literal, cast, get_args import matplotlib import matplotlib.pyplot as plt @@ -29,7 +29,7 @@ from xarray import DataArray, DataTree from spatialdata_plot._accessor import register_spatial_data_accessor -from spatialdata_plot._logging import _log_context +from spatialdata_plot._logging import _log_context, logger from spatialdata_plot.pl.render import ( _draw_channel_legend, _render_graph, @@ -52,8 +52,10 @@ LegendParams, PointsRenderParams, ShapesRenderParams, + _DsReduction, _FontSize, _FontWeight, + _ImageDsReduction, ) from spatialdata_plot.pl.utils import ( _RENDER_CMD_TO_CS_FLAG, @@ -194,7 +196,7 @@ def render_shapes( shape: Literal["circle", "hex", "visium_hex", "square"] | None = None, colorbar: bool | str | None = "auto", colorbar_params: dict[str, object] | None = None, - datashader_reduction: Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None = None, + datashader_reduction: _DsReduction | None = None, transfunc: Callable[[float], float] | None = None, ) -> sd.SpatialData: """ @@ -384,7 +386,7 @@ def render_points( gene_symbols: str | None = None, colorbar: bool | str | None = "auto", colorbar_params: dict[str, object] | None = None, - datashader_reduction: Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None = None, + datashader_reduction: _DsReduction | None = None, transfunc: Callable[[float], float] | None = None, ) -> sd.SpatialData: """ @@ -536,6 +538,8 @@ def render_images( colorbar: bool | str | None = "auto", colorbar_params: dict[str, object] | None = None, channels_as_legend: bool = False, + method: Literal["matplotlib", "datashader"] | None = None, + datashader_reduction: _ImageDsReduction | None = None, ) -> sd.SpatialData: """ Render image elements in SpatialData. @@ -616,6 +620,21 @@ def render_images( Ignored for single-channel and RGB(A) images. When multiple ``render_images`` calls use this flag on the same axes, all channel entries are combined into a single legend. + method : str | None, optional + Whether to use ``'matplotlib'`` (default) or ``'datashader'`` for + the downsampling step. When ``'datashader'`` is selected, the + rasterization-to-canvas step uses + :meth:`datashader.Canvas.raster` with ``datashader_reduction`` as the + downsample method (default ``'max'``), and ``imshow`` is rendered + with ``interpolation='nearest'`` so the chosen reduction is not + re-smoothed at display time. Useful for very sparse images + (mostly zeros) where mean aggregation collapses the signal — + ``method='datashader'`` with ``datashader_reduction='max'`` preserves the + rare non-zero pixels (``plt.spy``-style). + datashader_reduction : {"max", "min", "mean", "mode", "first", "last", "var", "std"} | None, optional + Downsample reduction used by the datashader path. Defaults to + ``'max'`` when ``method='datashader'``. Ignored otherwise (a + warning is emitted if set without ``method='datashader'``). Notes ----- @@ -634,6 +653,22 @@ def render_images( """ if grayscale and palette is not None: raise ValueError("Cannot combine grayscale=True with palette.") + + if method is not None and not isinstance(method, str): + raise TypeError("Parameter 'method' must be a string.") + if method is not None and method not in ("matplotlib", "datashader"): + raise ValueError("Parameter 'method' must be either 'matplotlib' or 'datashader'.") + _valid_image_reductions = get_args(_ImageDsReduction) + if datashader_reduction is not None and not isinstance(datashader_reduction, str): + raise TypeError("Parameter 'datashader_reduction' must be a string.") + if datashader_reduction is not None and datashader_reduction not in _valid_image_reductions: + raise ValueError( + f"Parameter 'datashader_reduction' must be one of {_valid_image_reductions}, " + f"got {datashader_reduction!r}." + ) + if datashader_reduction is not None and method != "datashader": + logger.warning("Parameter 'datashader_reduction' has no effect unless method='datashader'; ignoring.") + params_dict = _validate_image_render_params( self._sdata, element=element, @@ -699,6 +734,8 @@ def render_images( transfunc=transfunc, grayscale=grayscale, channels_as_legend=channels_as_legend, + method=method, + ds_reduction=datashader_reduction, ) n_steps += 1 diff --git a/src/spatialdata_plot/pl/render.py b/src/spatialdata_plot/pl/render.py index 5cd6688e..494f8333 100644 --- a/src/spatialdata_plot/pl/render.py +++ b/src/spatialdata_plot/pl/render.py @@ -39,7 +39,6 @@ _ds_aggregate, _ds_shade_categorical, _ds_shade_continuous, - _DsReduction, _render_ds_image, _render_ds_outlines, ) @@ -55,6 +54,7 @@ LegendParams, PointsRenderParams, ShapesRenderParams, + _DsReduction, ) from spatialdata_plot.pl.utils import ( _ax_show_and_transform, @@ -73,6 +73,7 @@ _prepare_cmap_norm, _prepare_transformation, _rasterize_if_necessary, + _rasterize_if_necessary_datashader, _set_color_source_vec, _validate_polygons, ) @@ -1279,7 +1280,24 @@ def _render_images( scale=scale, ) # rasterize spatial image if necessary to speed up performance - if rasterize: + use_datashader = render_params.method == "datashader" + if use_datashader: + downsample_method = render_params.ds_reduction or "max" + logger.info( + f"Using 'datashader' backend with '{downsample_method}' as downsample method. " + "Depending on the reduction, the value range of the plot might change. " + "Set method to 'matplotlib' to disable this behaviour." + ) + img = _rasterize_if_necessary_datashader( + image=img, + dpi=fig_params.fig.dpi, + width=fig_params.fig.get_size_inches()[0], + height=fig_params.fig.get_size_inches()[1], + coordinate_system=coordinate_system, + extent=extent, + downsample_method=downsample_method, + ) + elif rasterize: img = _rasterize_if_necessary( image=img, dpi=fig_params.fig.dpi, @@ -1389,6 +1407,10 @@ def _render_images( "Consider using 'palette' instead." ) + # Force nearest-neighbor at display time when the datashader reduction picked + # a non-mean aggregation; otherwise imshow's default interpolation would smear it. + _interp = "nearest" if use_datashader else None + # Detect RGB(A) images by channel names — skip when user overrides with palette/cmap is_rgb, has_alpha = _is_rgb_image(channels) has_explicit_cmap = ( @@ -1430,7 +1452,7 @@ def _render_images( render_params.alpha, ) - _ax_show_and_transform(stacked, trans_data, ax, **show_kwargs) + _ax_show_and_transform(stacked, trans_data, ax, interpolation=_interp, **show_kwargs) if render_params.channels_as_legend: logger.warning("channels_as_legend is not supported for true RGB images and will be ignored.") return @@ -1457,6 +1479,7 @@ def _render_images( cmap=cmap, zorder=render_params.zorder, norm=render_params.cmap_params.norm, + interpolation=_interp, ) wants_colorbar = _should_request_colorbar( @@ -1549,6 +1572,7 @@ def _render_images( ax, render_params.alpha, zorder=render_params.zorder, + interpolation=_interp, ) # 2B) Image has n channels, no palette/cmap info -> sample n categorical colors @@ -1613,6 +1637,7 @@ def _render_images( ax, render_params.alpha, zorder=render_params.zorder, + interpolation=_interp, ) # 2C) palette set; also covers `palette + norm=list` since synthesized @@ -1633,6 +1658,7 @@ def _render_images( ax, render_params.alpha, zorder=render_params.zorder, + interpolation=_interp, ) elif palette is None and got_multiple_cmaps: @@ -1654,6 +1680,7 @@ def _render_images( ax, render_params.alpha, zorder=render_params.zorder, + interpolation=_interp, ) # Collect channel legend entries (single point for all multi-channel paths) diff --git a/src/spatialdata_plot/pl/render_params.py b/src/spatialdata_plot/pl/render_params.py index e7232ec7..ec0459d0 100644 --- a/src/spatialdata_plot/pl/render_params.py +++ b/src/spatialdata_plot/pl/render_params.py @@ -12,6 +12,8 @@ _FontWeight = Literal["light", "normal", "medium", "semibold", "bold", "heavy", "black"] _FontSize = Literal["xx-small", "x-small", "small", "medium", "large", "x-large", "xx-large"] +_DsReduction = Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] +_ImageDsReduction = Literal["max", "min", "mean", "mode", "first", "last", "var", "std"] # replace with # from spatialdata._types import ColorLike @@ -243,7 +245,7 @@ class ShapesRenderParams: table_name: str | None = None table_layer: str | None = None shape: Literal["circle", "hex", "visium_hex", "square"] | None = None - ds_reduction: Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None = None + ds_reduction: _DsReduction | None = None colorbar: bool | str | None = "auto" colorbar_params: dict[str, object] | None = None @@ -265,7 +267,7 @@ class PointsRenderParams: zorder: int = 0 table_name: str | None = None table_layer: str | None = None - ds_reduction: Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None = None + ds_reduction: _DsReduction | None = None colorbar: bool | str | None = "auto" colorbar_params: dict[str, object] | None = None @@ -286,6 +288,8 @@ class ImageRenderParams: transfunc: Callable[[np.ndarray], np.ndarray] | list[Callable[[np.ndarray], np.ndarray]] | None = None grayscale: bool = False channels_as_legend: bool = False + method: Literal["matplotlib", "datashader"] | None = None + ds_reduction: _ImageDsReduction | None = None @dataclass diff --git a/src/spatialdata_plot/pl/utils.py b/src/spatialdata_plot/pl/utils.py index 35844c6a..ea2dd421 100644 --- a/src/spatialdata_plot/pl/utils.py +++ b/src/spatialdata_plot/pl/utils.py @@ -83,6 +83,7 @@ PointsRenderParams, ScalebarParams, ShapesRenderParams, + _DsReduction, _FontSize, _FontWeight, ) @@ -2048,6 +2049,61 @@ def _rasterize_if_necessary( return image +def _rasterize_if_necessary_datashader( + image: DataArray, + dpi: float, + width: float, + height: float, + coordinate_system: str, + extent: dict[str, tuple[float, float]], + downsample_method: str, +) -> DataArray: + """Downsample to canvas resolution with a configurable datashader reduction. + + Used by ``render_images(method='datashader')`` so sparse images (mostly + zeros, rare non-zero pixels) survive the downsample step instead of + being averaged away by the default mean aggregation. + """ + has_c_dim = len(image.shape) == 3 + y_dims, x_dims = (image.shape[1], image.shape[2]) if has_c_dim else image.shape + + target_y_dims = int(dpi * height) + target_x_dims = int(dpi * width) + + if y_dims <= target_y_dims and x_dims <= target_x_dims: + return image + + # spatialdata.rasterize is invoked solely to inherit the output coords and + # spatial transformation; its mean-aggregated values are overwritten below. + # TODO: this wastes a full per-channel resample pass. A future refactor can + # construct the target DataArray + transformation directly once spatialdata + # exposes a public geometry-only helper. + world_x = float(extent["x"][1]) - float(extent["x"][0]) + world_y = float(extent["y"][1]) - float(extent["y"][0]) + target_unit_to_pixels = min(target_y_dims / world_y, target_x_dims / world_x) + base = rasterize( + image, + ("y", "x"), + [extent["y"][0], extent["x"][0]], + [extent["y"][1], extent["x"][1]], + coordinate_system, + target_unit_to_pixels=target_unit_to_pixels, + ) + + out_y, out_x = (base.shape[1], base.shape[2]) if has_c_dim else base.shape + # Materialize once: per-chunk reductions across channels would otherwise + # trigger repeated dask graph evaluations on the same source array. + src = image.compute() if hasattr(image.data, "compute") else image + cvs = ds.Canvas( + plot_width=out_x, + plot_height=out_y, + x_range=(float(extent["x"][0]), float(extent["x"][1])), + y_range=(float(extent["y"][0]), float(extent["y"][1])), + ) + base.values = np.asarray(cvs.raster(src, downsample_method=downsample_method).values).astype(base.dtype, copy=False) + return base + + def _multiscale_to_spatial_image( multiscale_image: DataTree, dpi: float, @@ -3385,6 +3441,7 @@ def _ax_show_and_transform( cmap: ListedColormap | LinearSegmentedColormap | None = None, zorder: int = 0, norm: Normalize | None = None, + interpolation: str | None = None, ) -> matplotlib.image.AxesImage: # ``extent`` uses mpl's pixel-grid convention; world placement happens via # ``set_transform(trans_data)`` afterwards. @@ -3396,6 +3453,8 @@ def _ax_show_and_transform( imshow_kwargs["alpha"] = alpha else: imshow_kwargs["cmap"] = cmap + if interpolation is not None: + imshow_kwargs["interpolation"] = interpolation im = ax.imshow(array, **imshow_kwargs) im.set_transform(trans_data) return im @@ -3508,7 +3567,7 @@ def _create_image_from_datashader_result( def _datashader_aggregate_with_function( - reduction: Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None, + reduction: _DsReduction | None, cvs: Canvas, spatial_element: GeoDataFrame | dask.dataframe.core.DataFrame, col_for_color: str | None, @@ -3572,7 +3631,7 @@ def _datashader_aggregate_with_function( def _datshader_get_how_kw_for_spread( - reduction: Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None, + reduction: _DsReduction | None, ) -> str: # Get the best input for the how argument of ds.tf.spread(), needed for numerical values reduction = reduction or "sum" diff --git a/tests/_images/Images_method_datashader_preserves_sparse_pixels.png b/tests/_images/Images_method_datashader_preserves_sparse_pixels.png new file mode 100644 index 00000000..3333600d Binary files /dev/null and b/tests/_images/Images_method_datashader_preserves_sparse_pixels.png differ diff --git a/tests/_images/Images_method_datashader_reduction_grid.png b/tests/_images/Images_method_datashader_reduction_grid.png new file mode 100644 index 00000000..f12bd7bc Binary files /dev/null and b/tests/_images/Images_method_datashader_reduction_grid.png differ diff --git a/tests/pl/test_render_images.py b/tests/pl/test_render_images.py index 7c2390ab..cf4ef89b 100644 --- a/tests/pl/test_render_images.py +++ b/tests/pl/test_render_images.py @@ -162,6 +162,32 @@ def test_plot_constant_channel_renders_as_midgrey(self): sdata = SpatialData(images={"img": img}) sdata.pl.render_images("img").pl.show(title="constant channel: mid-value (not black)") + def test_plot_method_datashader_preserves_sparse_pixels(self): + # #449: bright pixels in a sparse image must survive the downsample step. + arr = np.zeros((1, 1024, 1024), dtype=np.float32) + rng = np.random.default_rng(0) + arr[0, rng.integers(0, 1024, 50), rng.integers(0, 1024, 50)] = 1.0 + sdata = SpatialData(images={"img": Image2DModel.parse(arr, c_coords=["c1"])}) + fig, axs = plt.subplots(1, 2, figsize=(8, 4)) + sdata.pl.render_images("img").pl.show(ax=axs[0], colorbar=False, title="default (mean)") + sdata.pl.render_images("img", method="datashader", datashader_reduction="max").pl.show( + ax=axs[1], colorbar=False, title="datashader (max)" + ) + + def test_plot_method_datashader_reduction_grid(self): + # Mid-grey background with sparse bright pixels: each reduction yields a + # visibly distinct panel — max preserves spots, min/mode show the + # background only, mean shows a slightly-lifted background. + rng = np.random.default_rng(0) + arr = np.full((1, 1024, 1024), 0.3, dtype=np.float32) + arr[0, rng.integers(0, 1024, 50), rng.integers(0, 1024, 50)] = 1.0 + sdata = SpatialData(images={"img": Image2DModel.parse(arr, c_coords=["c1"])}) + fig, axs = plt.subplots(2, 2, figsize=(8, 8)) + for ax, red in zip(axs.flat, ("max", "min", "mean", "mode"), strict=True): + sdata.pl.render_images("img", method="datashader", datashader_reduction=red).pl.show( + ax=ax, colorbar=False, title=red + ) + # --------------------------------------------------------------------------- # Grayscale + transfunc visual tests @@ -746,3 +772,115 @@ def test_lognorm_with_zeros_suppresses_colorbar_with_warning(): sdata.pl.render_images("img", norm=LogNorm()).pl.show(ax=ax) finally: plt.close(fig) + + +def _render_sparse_image_max(**kwargs) -> float: + arr = np.zeros((1, 1024, 1024), dtype=np.float32) + arr[0, 500, 500] = 1.0 + sdata = SpatialData(images={"img": Image2DModel.parse(arr, c_coords=["c1"])}) + fig, ax = plt.subplots(figsize=(2, 2), dpi=50) + try: + sdata.pl.render_images("img", **kwargs).pl.show(ax=ax) + return float(np.nanmax(ax.get_images()[0].get_array())) + finally: + plt.close(fig) + + +def test_render_images_datashader_preserves_sparse_max(): + # Regression test for #449. + default_max = _render_sparse_image_max() + datashader_max = _render_sparse_image_max(method="datashader", datashader_reduction="max") + assert default_max < 0.1, f"default path should collapse sparse signal, got max={default_max}" + assert datashader_max == pytest.approx(1.0, abs=1e-6), ( + f"datashader should preserve sparse signal at 1.0, got {datashader_max}" + ) + + +class TestRenderImagesDatashader: + """Tests for the method='datashader' code path on render_images (issue #449).""" + + @pytest.fixture(autouse=True) + def _cleanup(self): + yield + plt.close("all") + + def test_method_invalid_type_raises(self, sdata_blobs: SpatialData): + with pytest.raises(TypeError, match="must be a string"): + sdata_blobs.pl.render_images("blobs_image", method=123) # type: ignore[arg-type] + + def test_method_invalid_value_raises(self, sdata_blobs: SpatialData): + with pytest.raises(ValueError, match="matplotlib.*datashader"): + sdata_blobs.pl.render_images("blobs_image", method="bogus") + + def test_datashader_reduction_invalid_type_raises(self, sdata_blobs: SpatialData): + with pytest.raises(TypeError, match="must be a string"): + sdata_blobs.pl.render_images("blobs_image", datashader_reduction=42) # type: ignore[arg-type] + + def test_datashader_reduction_invalid_value_raises(self, sdata_blobs: SpatialData): + with pytest.raises(ValueError, match="datashader_reduction"): + sdata_blobs.pl.render_images("blobs_image", method="datashader", datashader_reduction="bogus") + + def test_datashader_reduction_without_datashader_warns(self, sdata_blobs: SpatialData, caplog): + with logger_warns(caplog, logger, match="datashader_reduction"): + _, ax = plt.subplots() + sdata_blobs.pl.render_images("blobs_image", datashader_reduction="max").pl.show(ax=ax) + + def test_datashader_basic_renders_single_image(self): + arr = np.zeros((1, 512, 512), dtype=np.float32) + arr[0, 100, 100] = 1.0 + sdata = SpatialData(images={"img": Image2DModel.parse(arr, c_coords=["c1"])}) + _, ax = plt.subplots(figsize=(2, 2), dpi=50) + sdata.pl.render_images("img", method="datashader").pl.show(ax=ax) + assert len(ax.get_images()) == 1 + + def test_datashader_multichannel(self): + arr = np.zeros((3, 512, 512), dtype=np.float32) + arr[0, 100, 100] = 1.0 + arr[1, 200, 200] = 1.0 + arr[2, 300, 300] = 1.0 + sdata = SpatialData(images={"img": Image2DModel.parse(arr, c_coords=["c1", "c2", "c3"])}) + _, ax = plt.subplots(figsize=(2, 2), dpi=50) + sdata.pl.render_images("img", method="datashader", datashader_reduction="max").pl.show(ax=ax) + assert len(ax.get_images()) == 1 + + def test_datashader_rgb_passthrough(self): + arr = np.zeros((3, 256, 256), dtype=np.float32) + arr[0] = 0.8 + arr[1] = 0.2 + arr[2] = 0.1 + sdata = SpatialData(images={"img": Image2DModel.parse(arr, c_coords=["r", "g", "b"])}) + _, ax = plt.subplots(figsize=(2, 2), dpi=50) + sdata.pl.render_images("img", method="datashader").pl.show(ax=ax) + assert ax.get_images()[0].get_array().shape[-1] == 3 + + def test_datashader_with_transfunc(self): + arr = np.zeros((1, 512, 512), dtype=np.float32) + arr[0, 100, 100] = 1.0 + sdata = SpatialData(images={"img": Image2DModel.parse(arr, c_coords=["c1"])}) + _, ax = plt.subplots(figsize=(2, 2), dpi=50) + sdata.pl.render_images("img", method="datashader", datashader_reduction="max", transfunc=np.log1p).pl.show( + ax=ax + ) + assert len(ax.get_images()) == 1 + + def test_datashader_with_multiscale(self, sdata_blobs: SpatialData): + _, ax = plt.subplots() + sdata_blobs.pl.render_images("blobs_multiscale_image", method="datashader", datashader_reduction="max").pl.show( + ax=ax + ) + assert len(ax.get_images()) == 1 + + def test_method_matplotlib_matches_default(self): + rng = np.random.default_rng(0) + arr = rng.random((1, 64, 64), dtype=np.float32) + sdata = SpatialData(images={"img": Image2DModel.parse(arr, c_coords=["c1"])}) + + def _render_and_grab(**kwargs): + fig, ax = plt.subplots(figsize=(2, 2), dpi=50) + try: + sdata.pl.render_images("img", **kwargs).pl.show(ax=ax) + return np.asarray(ax.get_images()[0].get_array()) + finally: + plt.close(fig) + + np.testing.assert_array_equal(_render_and_grab(), _render_and_grab(method="matplotlib"))