diff --git a/src/spatialdata_plot/pl/_datashader.py b/src/spatialdata_plot/pl/_datashader.py index 1d6415ed..8b9415e4 100644 --- a/src/spatialdata_plot/pl/_datashader.py +++ b/src/spatialdata_plot/pl/_datashader.py @@ -227,6 +227,7 @@ def _ds_shade_continuous( na_color_hex: str, spread_px: int | None = None, ds_reduction: _DsReduction | None = None, + how: str = "linear", ) -> tuple[Any, Any | None, tuple[Any, Any] | None]: """Shade a continuous datashader aggregate, optionally applying spread and NaN coloring. @@ -255,6 +256,7 @@ def _ds_shade_continuous( min_alpha=_convert_alpha_to_datashader_range(alpha), span=color_span, clip=norm.clip, + how=how, ) shaded = _apply_user_alpha(shaded, alpha) @@ -278,6 +280,8 @@ def _ds_shade_categorical( color_vector: Any, alpha: float, spread_px: int | None = None, + how: str = "linear", + density: bool = False, ) -> Any: """Shade a categorical or no-color datashader aggregate.""" ds_cmap = None @@ -286,12 +290,20 @@ def _ds_shade_categorical( if isinstance(ds_cmap, str) and ds_cmap[0] == "#": ds_cmap = _hex_no_alpha(ds_cmap) + # The default min_alpha (~254) is a near-full-opacity floor — right for scatter + # plots, but it collapses the count-driven alpha range and makes categorical + # density read as a flat hue cloud. Drop the floor under density so per-pixel + # alpha can actually encode count. A small non-zero floor (~15%) keeps the + # sparse edges visible under density_how="linear" instead of vanishing. + min_alpha = 40.0 if density else _convert_alpha_to_datashader_range(alpha) + agg_to_shade = ds.tf.spread(agg, px=spread_px) if spread_px is not None else agg shaded = _datashader_map_aggregate_to_color( agg_to_shade, cmap=ds_cmap, color_key=color_key, - min_alpha=_convert_alpha_to_datashader_range(alpha), + min_alpha=min_alpha, + how=how, ) return _apply_user_alpha(shaded, alpha) diff --git a/src/spatialdata_plot/pl/basic.py b/src/spatialdata_plot/pl/basic.py index 996161ae..37a80593 100644 --- a/src/spatialdata_plot/pl/basic.py +++ b/src/spatialdata_plot/pl/basic.py @@ -385,6 +385,8 @@ def render_points( colorbar: bool | str | None = "auto", colorbar_params: dict[str, object] | None = None, datashader_reduction: Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None = None, + density: bool = False, + density_how: Literal["linear", "log", "cbrt", "eq_hist"] = "linear", transfunc: Callable[[float], float] | None = None, ) -> sd.SpatialData: """ @@ -455,6 +457,19 @@ def render_points( in another column of ``var``. Mimics scanpy's ``gene_symbols`` parameter. datashader_reduction : Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None, optional Reduction method for datashader when coloring by continuous values. When ``None``, defaults to ``"sum"``. + density : bool, default False + Render the points as a 2-D count density via datashader instead of plotting individual markers. + When ``True``, ``method`` is forced to ``"datashader"`` (passing ``method="matplotlib"`` raises). + Density supports ``color=None`` (plain density) or a categorical ``color`` column (per-category + density via :func:`datashader.by`). A continuous ``color`` column or a literal color value + (e.g. ``"red"``) raises an error. Under ``density=True`` the following parameters are ignored + (with a warning if explicitly set): ``size``, ``transfunc``, ``norm.vmin/vmax``, and + ``datashader_reduction``. + density_how : Literal["linear", "log", "cbrt", "eq_hist"], default "linear" + How datashader maps aggregated counts to color intensity. ``"linear"`` (default) keeps the + colorbar axis as a count; ``"log"`` and ``"cbrt"`` compress dynamic range; ``"eq_hist"`` + equalizes the histogram (rank-based, surfaces the most structure but the colorbar axis is + no longer a count). Ignored when ``density=False``. transfunc : Callable[[float], float] | None, optional Optional transformation applied to the continuous color vector before normalization and colormap mapping. @@ -462,6 +477,18 @@ def render_points( ------- sd.SpatialData A copy of the SpatialData object with the rendering parameters stored in its plotting tree. + + Examples + -------- + Plain density of all transcripts: + + >>> sdata.pl.render_points("transcripts", density=True).pl.show() + + Per-gene density with a categorical palette: + + >>> sdata.pl.render_points( + ... "transcripts", color="gene", groups=["Gad1", "Slc17a7"], palette="tab20", density=True + ... ).pl.show() """ params_dict = _validate_points_render_params( self._sdata, @@ -480,6 +507,10 @@ def render_points( colorbar=colorbar, colorbar_params=colorbar_params, gene_symbols=gene_symbols, + density=density, + density_how=density_how, + transfunc=transfunc, + method=method, ) if method is not None: @@ -488,6 +519,9 @@ def render_points( if method not in ["matplotlib", "datashader"]: raise ValueError("Parameter 'method' must be either 'matplotlib' or 'datashader'.") + if density and method is None: + method = "datashader" + sdata = self._copy() sdata = _verify_plotting_tree(sdata) n_steps = len(sdata.plotting_tree.keys()) @@ -515,6 +549,8 @@ def render_points( ds_reduction=param_values["ds_reduction"], colorbar=param_values["colorbar"], colorbar_params=param_values["colorbar_params"], + density=density, + density_how=density_how, ) n_steps += 1 diff --git a/src/spatialdata_plot/pl/render.py b/src/spatialdata_plot/pl/render.py index eec55481..eb39824d 100644 --- a/src/spatialdata_plot/pl/render.py +++ b/src/spatialdata_plot/pl/render.py @@ -170,6 +170,44 @@ def _warn_groups_ignored_continuous( ) +def _is_categorical_like_dtype(dtype: Any) -> bool: + return ( + isinstance(dtype, pd.CategoricalDtype) + or pd.api.types.is_object_dtype(dtype) + or pd.api.types.is_string_dtype(dtype) + ) + + +def _reject_continuous_color_under_density( + sdata_filt: sd.SpatialData, + element: str, + col_for_color: str | None, + color_source_vector: Any, + color_vector: Any, +) -> None: + """Raise before any materialization if density+continuous-color was requested. + + ``color_source_vector`` is only populated by ``_set_color_source_vec`` for the categorical + branch, so a non-None value is sufficient to accept the call. Otherwise we read the dtype + from the dask source (points element column) or the pre-computed color vector — neither + forces a ``.compute()``. + """ + if col_for_color is None or color_source_vector is not None: + return + points_columns = sdata_filt.points[element].columns + if col_for_color in points_columns: + dtype = sdata_filt.points[element][col_for_color].dtype + else: + dtype = getattr(color_vector, "dtype", None) + if dtype is None or _is_categorical_like_dtype(dtype): + return + raise ValueError( + f"density=True is only supported with no color or a categorical color column; " + f"got continuous column {col_for_color!r}. To color a density plot by a continuous " + f"variable, set density=False and use method='datashader' with datashader_reduction=." + ) + + def _warn_missing_groups( groups: str | list[str], color_source_vector: pd.Categorical, @@ -950,7 +988,10 @@ def _render_points( method = render_params.method - if method is None: + if render_params.density: + method = "datashader" + _reject_continuous_color_under_density(sdata_filt, element, col_for_color, color_source_vector, color_vector) + elif method is None: method = "datashader" if n_points > 10000 else "matplotlib" _default_reduction: _DsReduction = "sum" @@ -960,7 +1001,11 @@ def _render_points( # NOTE: s in matplotlib is in units of points**2 # use dpi/100 as a factor for cases where dpi!=100 - px = int(np.round(np.sqrt(render_params.size) * (fig_params.fig.dpi / 100))) + # Under density, spreading would smear the count signal across pixels and + # distort apparent density at sparse edges, so disable it unconditionally. + px: int | None = ( + None if render_params.density else int(np.round(np.sqrt(render_params.size) * (fig_params.fig.dpi / 100))) + ) # Apply transformations and materialize to pandas immediately so # datashader aggregates without dask scheduler overhead. See #379. @@ -1045,14 +1090,22 @@ def _render_points( ): color_vector = np.asarray([_hex_no_alpha(c) for c in color_vector]) + shade_how = render_params.density_how if render_params.density else "linear" + # Plain density (no color column) must use the user-facing cmap as a sequential + # gradient over counts; the categorical path collapses to a single color and only + # modulates alpha, which renders as a flat hue regardless of density. + plain_density = render_params.density and col_for_color is None + nan_shaded = None - if color_by_categorical or col_for_color is None: + if not plain_density and (color_by_categorical or col_for_color is None): shaded = _ds_shade_categorical( agg, color_key, color_vector, render_params.alpha, spread_px=px, + how=shade_how, + density=render_params.density, ) else: shaded, nan_shaded, reduction_bounds = _ds_shade_continuous( @@ -1066,6 +1119,7 @@ def _render_points( na_color_hex, spread_px=px, ds_reduction=render_params.ds_reduction, + how=shade_how, ) _render_ds_image( diff --git a/src/spatialdata_plot/pl/render_params.py b/src/spatialdata_plot/pl/render_params.py index e7232ec7..a0414d6b 100644 --- a/src/spatialdata_plot/pl/render_params.py +++ b/src/spatialdata_plot/pl/render_params.py @@ -268,6 +268,8 @@ class PointsRenderParams: ds_reduction: Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None = None colorbar: bool | str | None = "auto" colorbar_params: dict[str, object] | None = None + density: bool = False + density_how: Literal["linear", "log", "cbrt", "eq_hist"] = "linear" @dataclass diff --git a/src/spatialdata_plot/pl/utils.py b/src/spatialdata_plot/pl/utils.py index c6ae3350..a64b0fd2 100644 --- a/src/spatialdata_plot/pl/utils.py +++ b/src/spatialdata_plot/pl/utils.py @@ -2,8 +2,9 @@ import math import os +import warnings from collections import OrderedDict -from collections.abc import Iterable, Mapping, Sequence +from collections.abc import Callable, Iterable, Mapping, Sequence from copy import copy from functools import partial from pathlib import Path @@ -2820,7 +2821,17 @@ def _validate_points_render_params( colorbar: bool | str | None, colorbar_params: dict[str, object] | None, gene_symbols: str | None = None, + density: bool = False, + density_how: Literal["linear", "log", "cbrt", "eq_hist"] = "linear", + transfunc: Callable[[float], float] | None = None, + method: str | None = None, ) -> dict[str, dict[str, Any]]: + if not isinstance(density, bool): + raise TypeError("Parameter 'density' must be a bool.") + allowed_how = ("linear", "log", "cbrt", "eq_hist") + if density_how not in allowed_how: + raise ValueError(f"Parameter 'density_how' must be one of {allowed_how}; got {density_how!r}.") + param_dict: dict[str, Any] = { "sdata": sdata, "element": element, @@ -2840,6 +2851,47 @@ def _validate_points_render_params( } param_dict = _type_check_params(param_dict, "points") + if density: + if method == "matplotlib": + raise ValueError( + "density=True requires the datashader backend; got method='matplotlib'. " + "Either drop method= or set method='datashader'." + ) + # Literal color (resolved into param_dict["color"] as a Color instance, with + # col_for_color set to None) is ambiguous with density: it could mean a + # single-hue cmap or a one-entry palette. Force the user to choose. + if param_dict["color"] is not None and param_dict["col_for_color"] is None: + raise ValueError( + "density=True with a literal color is ambiguous. Pass cmap= to recolor the " + "density, or palette= to assign a categorical color, but not color=." + ) + # Warn-and-ignore: these parameters do not interact meaningfully with a + # count-based density and are silently dropped to keep the API consistent. + if size != 1.0: + warnings.warn( + "size is ignored when density=True; spreading would distort the count signal.", + UserWarning, + stacklevel=3, + ) + if transfunc is not None: + warnings.warn( + "transfunc is ignored when density=True (no continuous color vector to transform).", + UserWarning, + stacklevel=3, + ) + if isinstance(norm, Normalize) and (norm.vmin is not None or norm.vmax is not None): + warnings.warn( + "norm.vmin/vmax are ignored when density=True; use density_how= to control intensity mapping.", + UserWarning, + stacklevel=3, + ) + if ds_reduction is not None: + warnings.warn( + "datashader_reduction is ignored when density=True; counts are forced.", + UserWarning, + stacklevel=3, + ) + element_params: dict[str, dict[str, Any]] = {} for el in param_dict["element"]: # ensure that the element exists in the SpatialData object @@ -3715,11 +3767,17 @@ def _datashader_map_aggregate_to_color( min_alpha: float = 40, span: None | list[float] = None, clip: bool = True, + how: str = "linear", ) -> ds.tf.Image | np.ndarray[Any, np.dtype[np.uint8]]: """ds.tf.shade() part, ensuring correct clipping behavior. If necessary (norm.clip=False), split shading in 3 parts and in the end, stack results. This ensures the correct clipping behavior, because else datashader would always automatically clip. + + ``how`` controls the count-to-color mapping passed to :func:`datashader.transfer_functions.shade` + (``"linear"`` by default; ``"log"``/``"cbrt"``/``"eq_hist"`` compress dynamic range). The split-shade + branch used for ``norm.clip=False`` always uses ``"linear"`` since per-segment shading would otherwise + interact poorly with rank-based mappings. """ if not clip and isinstance(cmap, Colormap) and span is not None: # in case we use datashader together with a Normalize object where clip=False @@ -3768,7 +3826,7 @@ def _datashader_map_aggregate_to_color( color_key=color_key, min_alpha=min_alpha, span=span, - how="linear", + how=how, ) return _apply_cmap_alpha_to_datashader_result(result, agg, cmap, span) diff --git a/tests/_images/Points_density_categorical.png b/tests/_images/Points_density_categorical.png new file mode 100644 index 00000000..76ce49e2 Binary files /dev/null and b/tests/_images/Points_density_categorical.png differ diff --git a/tests/_images/Points_density_how_eq_hist.png b/tests/_images/Points_density_how_eq_hist.png new file mode 100644 index 00000000..0807af6d Binary files /dev/null and b/tests/_images/Points_density_how_eq_hist.png differ diff --git a/tests/_images/Points_density_plain.png b/tests/_images/Points_density_plain.png new file mode 100644 index 00000000..efd526af Binary files /dev/null and b/tests/_images/Points_density_plain.png differ diff --git a/tests/conftest.py b/tests/conftest.py index b70dc567..36b255f3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -80,6 +80,34 @@ def sdata_blobs() -> SpatialData: return blobs() +@pytest.fixture() +def sdata_dense_points() -> SpatialData: + """Dense (~20k) multi-cluster points dataset for density-rendering visual tests. + + The blobs fixture is too sparse (~200 points across 500x500) for density to render + meaningfully without spreading; this fixture provides a Gaussian-cluster cloud with + a categorical ``gene`` column so the per-category density branch is exercised too. + """ + rng = get_standard_RNG() + n_per_cluster = 20000 + centers = [(120, 120), (380, 150), (250, 380)] + genes = ["gene_a", "gene_b", "gene_c"] + xs, ys, gs = [], [], [] + for (cx, cy), gene in zip(centers, genes, strict=True): + xs.append(rng.normal(loc=cx, scale=18, size=n_per_cluster)) + ys.append(rng.normal(loc=cy, scale=18, size=n_per_cluster)) + gs.extend([gene] * n_per_cluster) + df = pd.DataFrame( + { + "x": np.concatenate(xs).clip(0, 500), + "y": np.concatenate(ys).clip(0, 500), + "gene": pd.Categorical(gs, categories=genes), + } + ) + points = PointsModel.parse(df) + return SpatialData(points={"dense_points": points}) + + @pytest.fixture() def sdata_blobs_str() -> SpatialData: return blobs(n_channels=5, c_coords=["c1", "c2", "c3", "c4", "c5"]) diff --git a/tests/pl/test_render_points.py b/tests/pl/test_render_points.py index 67dd71c9..7c0a1286 100644 --- a/tests/pl/test_render_points.py +++ b/tests/pl/test_render_points.py @@ -172,6 +172,15 @@ def test_plot_datashader_continuous_color(self, sdata_blobs: SpatialData): method="datashader", ).pl.show() + def test_plot_density_plain(self, sdata_dense_points: SpatialData): + sdata_dense_points.pl.render_points("dense_points", density=True).pl.show() + + def test_plot_density_categorical(self, sdata_dense_points: SpatialData): + sdata_dense_points.pl.render_points("dense_points", color="gene", density=True).pl.show() + + def test_plot_density_how_eq_hist(self, sdata_dense_points: SpatialData): + sdata_dense_points.pl.render_points("dense_points", density=True, density_how="eq_hist").pl.show() + def test_plot_points_categorical_color_column_matplotlib(self, sdata_blobs: SpatialData): sdata_blobs.pl.render_points("blobs_points", color="genes", method="matplotlib").pl.show() @@ -1178,3 +1187,44 @@ def test_datashader_canvas_from_empty_dataframe_does_not_crash(): assert plot_width == 0 and plot_height == 0 finally: plt.close(fig) + + +# --------------------------------------------------------------------------- +# Density mode (unit tests; visual tests live in the TestPoints class above) +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + ("kwargs", "match"), + [ + ({"color": "instance_id"}, "density=True is only supported"), + ({"color": "red"}, "literal color is ambiguous"), + ({"method": "matplotlib"}, "datashader backend"), + ({"density_how": "magic"}, "density_how"), + ], +) +def test_density_rejects_invalid_combinations(sdata_blobs: SpatialData, kwargs, match): + with pytest.raises(ValueError, match=match): + sdata_blobs.pl.render_points("blobs_points", density=True, **kwargs).pl.show() + plt.close("all") + + +@pytest.mark.parametrize( + ("kwargs", "match"), + [ + ({"size": 5.0}, "size is ignored"), + ({"transfunc": lambda x: x}, "transfunc is ignored"), + ({"norm": Normalize(vmin=0, vmax=1)}, "norm.vmin/vmax are ignored"), + ({"datashader_reduction": "mean"}, "datashader_reduction is ignored"), + ], +) +def test_density_warns_on_ignored_params(sdata_blobs: SpatialData, kwargs, match): + with pytest.warns(UserWarning, match=match): + sdata_blobs.pl.render_points("blobs_points", density=True, **kwargs) + + +def test_density_defaults_silent_and_force_datashader(sdata_blobs: SpatialData, recwarn): + out = sdata_blobs.pl.render_points("blobs_points", density=True) + last = list(out.plotting_tree.values())[-1] + assert (last.density, last.density_how, last.method) == (True, "linear", "datashader") + assert not any("ignored when density=True" in str(w.message) for w in recwarn.list)