Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 125 additions & 10 deletions src/spatialdata_plot/pl/_datashader.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from __future__ import annotations

from copy import copy
from typing import Any, Literal

import dask.dataframe as dd
Expand All @@ -19,13 +20,15 @@
from spatialdata_plot._logging import logger
from spatialdata_plot.pl.render_params import Color, FigParams, ShapesRenderParams
from spatialdata_plot.pl.utils import (
_DS_REDUCTION_FUNCS,
_ax_show_and_transform,
_convert_alpha_to_datashader_range,
_create_image_from_datashader_result,
_datashader_aggregate_with_function,
_datashader_map_aggregate_to_color,
_datshader_get_how_kw_for_spread,
_hex_no_alpha,
_make_continuous_mappable,
)

# ---------------------------------------------------------------------------
Expand All @@ -38,6 +41,11 @@
# missing (NaN) values. Must not collide with realistic user category names.
_DS_NAN_CATEGORY = "ds_nan"

# Private column name under which the outline color vector is attached to the
# datashader rasterizer element. Must not collide with a real user column;
# the leading/trailing dunders are deliberate.
_OUTLINE_INTERNAL_COL = "__sdp_outline_col__"

# ---------------------------------------------------------------------------
# Low-level helpers
# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -344,8 +352,16 @@ def _render_ds_outlines(
factor: float,
x_min: float = 0.0,
y_min: float = 0.0,
outline_color_vector: Any | None = None,
outline_color_source_vector: pd.Series | None = None,
) -> None:
"""Aggregate, shade, and render shape outlines (outer and inner) with datashader."""
"""Aggregate, shade, and render shape outlines (outer and inner) with datashader.

When ``outline_color_vector`` is provided, the outer outline is colored per-shape
via ``ds.by`` (categorical) or a numeric reduction (continuous) instead of a
single literal color. The two-outline form is rejected at validation, so this
only affects the outer outline.
"""
ds_lw_factor = fig_params.fig.dpi / 72
assert len(render_params.outline_alpha) == 2 # noqa: S101

Expand All @@ -358,6 +374,24 @@ def _render_ds_outlines(
alpha = render_params.outline_alpha[idx]
if alpha <= 0:
continue
if idx == 0 and outline_color_vector is not None:
_render_ds_outline_by_column(
cvs=cvs,
transformed_element=transformed_element,
outline_color_vector=outline_color_vector,
outline_color_source_vector=outline_color_source_vector,
cmap_params=render_params.cmap_params,
ds_reduction=render_params.ds_reduction,
line_width=linewidth * ds_lw_factor,
alpha=alpha,
fig_params=fig_params,
ax=ax,
factor=factor,
x_min=x_min,
y_min=y_min,
zorder=render_params.zorder,
)
continue
agg_outline = cvs.line(
transformed_element,
geometry="geometry",
Expand All @@ -375,6 +409,95 @@ def _render_ds_outlines(
_ax_show_and_transform(rgba, trans, ax, zorder=render_params.zorder)


def _render_ds_outline_by_column(
cvs: Any,
transformed_element: Any,
outline_color_vector: Any | None,
outline_color_source_vector: pd.Series | None,
cmap_params: Any,
ds_reduction: _DsReduction | None,
line_width: float,
alpha: float,
fig_params: FigParams,
ax: matplotlib.axes.SubplotBase,
factor: float,
x_min: float,
y_min: float,
zorder: int,
) -> None:
"""Aggregate + shade an outline colored by an obs column via datashader.

Two-outline form is not supported for column-driven outline coloring,
so this only renders the outer outline.
"""
color_by_categorical = outline_color_source_vector is not None
na_color_hex = _hex_no_alpha(cmap_params.na_color.get_hex())

# Attach the outline vector under a private column name so a fill column with the
# same key never gets overwritten. Assign positionally (via a Series indexed to the
# element) — `.assign(col=series)` aligns by index, which silently inserts NaN when
# the element's index is non-contiguous (e.g. after an inner-join). The NaNs would
# then be lifted to the `ds_nan` sentinel and one polygon's outline would render as
# `na_color` instead of its real category.
transformed_element = transformed_element.copy()
if color_by_categorical:
cat = pd.Categorical(outline_color_source_vector)
attach_cat = _inject_ds_nan_sentinel(pd.Series(cat))
transformed_element[_OUTLINE_INTERNAL_COL] = pd.Categorical(
attach_cat.to_numpy(), categories=attach_cat.cat.categories
)
else:
transformed_element[_OUTLINE_INTERNAL_COL] = np.asarray(outline_color_vector)

if color_by_categorical:
agg_outline = cvs.line(
transformed_element,
geometry="geometry",
agg=ds.by(_OUTLINE_INTERNAL_COL, ds.count()),
line_width=line_width,
)
color_key = _build_datashader_color_key(
_coerce_categorical_source(transformed_element[_OUTLINE_INTERNAL_COL]),
outline_color_vector,
na_color_hex,
)
shaded = ds.tf.shade(
agg_outline,
color_key=color_key,
min_alpha=_convert_alpha_to_datashader_range(alpha),
how="linear",
)
else:
reduction_name = ds_reduction if ds_reduction is not None else "max"
try:
reduction_function = _DS_REDUCTION_FUNCS[reduction_name](column=_OUTLINE_INTERNAL_COL)
except KeyError as e:
raise ValueError(
f"Reduction '{reduction_name}' is not supported. Use one of: {', '.join(_DS_REDUCTION_FUNCS.keys())}."
) from e
agg_outline = cvs.line(
transformed_element,
geometry="geometry",
agg=reduction_function,
line_width=line_width,
)
# Apply the user-provided norm (vmin/vmax) the same way the fill path does so
# an explicit Normalize takes effect for the outline cmap.
norm = copy(cmap_params.norm)
agg_outline, color_span = _apply_ds_norm(agg_outline, norm)
shaded = ds.tf.shade(
agg_outline,
cmap=cmap_params.cmap,
span=color_span,
min_alpha=_convert_alpha_to_datashader_range(alpha),
how="linear",
)

shaded = _apply_user_alpha(shaded, alpha)
rgba, trans = _create_image_from_datashader_result(shaded, factor, ax, x_min, y_min)
_ax_show_and_transform(rgba, trans, ax, zorder=zorder)


def _build_ds_colorbar(
reduction_bounds: tuple[Any, Any] | None,
norm: Normalize,
Expand All @@ -388,12 +511,4 @@ def _build_ds_colorbar(
return None
vmin = reduction_bounds[0].values if norm.vmin is None else norm.vmin
vmax = reduction_bounds[1].values if norm.vmax is None else norm.vmax
if (norm.vmin is not None or norm.vmax is not None) and norm.vmin == norm.vmax:
assert norm.vmin is not None
assert norm.vmax is not None
vmin = norm.vmin - 0.5
vmax = norm.vmin + 0.5
return ScalarMappable(
norm=matplotlib.colors.Normalize(vmin=vmin, vmax=vmax),
cmap=cmap,
)
return _make_continuous_mappable(vmin, vmax, cmap)
26 changes: 24 additions & 2 deletions src/spatialdata_plot/pl/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,13 +241,18 @@ def render_shapes(
Width of the border. If 2 values are given (tuple), 2 borders are shown with these widths (outer & inner).
If `outline_color` and/or `outline_alpha` are used to indicate that one/two outlines should be drawn, the
default outline widths 1.5 and 0.5 are used for outer/only and inner outline respectively.
outline_color : ColorLike | tuple[ColorLike], optional
outline_color : ColorLike | tuple[ColorLike] | str, optional
Color of the border. Can either be a named color ("red"), a hex representation ("#000000") or a list of
floats that represent RGB/RGBA values (1.0, 0.0, 0.0, 1.0). If the hex representation includes alpha, e.g.
"#000000ff", and `outline_alpha` is not given, this value controls the opacity of the outline. If 2 values
are given (tuple), 2 borders are shown with these colors (outer & inner). If `outline_width` and/or
`outline_alpha` are used to indicate that one/two outlines should be drawn, the default outline colors
"#000000" and "#ffffff are used for outer/only and inner outline respectively.
A string that is not a recognized color is interpreted as a column key (in `obs` of the annotating table
or in the element's own dataframe), mirroring how ``color`` is parsed. The outline is then colored
per-shape using the same ``palette`` / ``cmap`` / ``na_color`` as the fill. When both ``color`` and
``outline_color`` resolve to columns, two stacked legends are drawn. Column-based outline coloring is
only supported for a single outline (not the 2-tuple form).
outline_alpha : float | int | tuple[float | int, float | int] | None, optional
Alpha value for the outline of shapes. Invisible by default, meaning outline_alpha=0.0 if both outline_color
and outline_width are not specified. Else, outlines are rendered with the alpha implied by outline_color, or
Expand Down Expand Up @@ -344,6 +349,8 @@ def render_shapes(
element=element,
color=param_values["color"],
col_for_color=param_values["col_for_color"],
col_for_outline_color=param_values["col_for_outline_color"],
outline_table_name=param_values["outline_table_name"],
groups=param_values["groups"],
scale=param_values["scale"],
outline_params=outline_params,
Expand Down Expand Up @@ -810,11 +817,14 @@ def render_labels(
fill_alpha : float | int | None, optional
Alpha value for the fill of the labels. By default, it is set to 0.4 or, if a color is given that implies
an alpha, that value is used for `fill_alpha`.
outline_color : ColorLike | None
outline_color : ColorLike | str | None
Color of the outline of the labels. Can either be a named color ("red"), a hex representation
("#000000") or a list of floats that represent RGB/RGBA values (1.0, 0.0, 0.0, 1.0). If ``None``,
the outline inherits from the ``color`` parameter when it is a literal color, or uses data-driven
per-label colors when ``color`` refers to a column.
A string that is not a recognized color is interpreted as a column key (in `obs` of the annotating
table), mirroring how ``color`` is parsed. The outline is then colored per-label using the same
``palette`` / ``cmap`` / ``na_color`` as the fill.
scale : str | None
Influences the resolution of the rendering. Possibilities for setting this parameter:
1) None (default). The image is rasterized to fit the canvas size. For multiscale images, the best scale
Expand Down Expand Up @@ -880,6 +890,8 @@ def render_labels(
element=element,
color=param_values["color"],
col_for_color=param_values["col_for_color"],
col_for_outline_color=param_values["col_for_outline_color"],
outline_table_name=param_values["outline_table_name"],
groups=param_values["groups"],
contour_px=param_values["contour_px"],
cmap_params=cmap_params,
Expand Down Expand Up @@ -1046,6 +1058,8 @@ def show(
na_in_legend: bool = True,
colorbar: bool = True,
colorbar_params: dict[str, object] | None = None,
legend_title: str | None = None,
outline_legend_title: str | None = None,
wspace: float | None = None,
hspace: float = 0.25,
ncols: int = 4,
Expand Down Expand Up @@ -1090,6 +1104,12 @@ def show(
colorbar_params : dict[str, object] | None
Global overrides passed to colorbars for all axes. Accepts the same keys as per-layer ``colorbar_params``
(e.g., ``loc``, ``width``, ``pad``, ``label``).
legend_title : str | None
Title for the fill categorical legend. When both fill and outline are colored by an obs column, the
two legends default to ``"fill"`` / ``"outline"`` to disambiguate; pass an explicit string to override
the fill title. Set to ``None`` (default) to keep the auto-title behavior.
outline_legend_title : str | None
Title for the outline categorical legend. Mirrors ``legend_title`` for the outline channel.
wspace : float | None
Horizontal spacing between panels (passed to :class:`matplotlib.gridspec.GridSpec`).
hspace : float, default 0.25
Expand Down Expand Up @@ -1326,6 +1346,8 @@ def show(
legend_fontoutline=legend_fontoutline,
na_in_legend=na_in_legend,
colorbar=colorbar,
legend_title=legend_title,
outline_legend_title=outline_legend_title,
)

def _draw_colorbar(
Expand Down
Loading
Loading