diff --git a/src/spatialdata_plot/pl/_datashader.py b/src/spatialdata_plot/pl/_datashader.py index 8b9415e4..3b6d9bcf 100644 --- a/src/spatialdata_plot/pl/_datashader.py +++ b/src/spatialdata_plot/pl/_datashader.py @@ -5,6 +5,7 @@ from __future__ import annotations +from copy import copy from typing import Any, Literal import dask.dataframe as dd @@ -19,6 +20,7 @@ from spatialdata_plot._logging import logger from spatialdata_plot.pl.render_params import Color, FigParams, ShapesRenderParams from spatialdata_plot.pl.utils import ( + _DS_REDUCTION_FUNCS, _ax_show_and_transform, _convert_alpha_to_datashader_range, _create_image_from_datashader_result, @@ -26,6 +28,7 @@ _datashader_map_aggregate_to_color, _datshader_get_how_kw_for_spread, _hex_no_alpha, + _make_continuous_mappable, ) # --------------------------------------------------------------------------- @@ -38,6 +41,11 @@ # missing (NaN) values. Must not collide with realistic user category names. _DS_NAN_CATEGORY = "ds_nan" +# Private column name under which the outline color vector is attached to the +# datashader rasterizer element. Must not collide with a real user column; +# the leading/trailing dunders are deliberate. +_OUTLINE_INTERNAL_COL = "__sdp_outline_col__" + # --------------------------------------------------------------------------- # Low-level helpers # --------------------------------------------------------------------------- @@ -344,8 +352,16 @@ def _render_ds_outlines( factor: float, x_min: float = 0.0, y_min: float = 0.0, + outline_color_vector: Any | None = None, + outline_color_source_vector: pd.Series | None = None, ) -> None: - """Aggregate, shade, and render shape outlines (outer and inner) with datashader.""" + """Aggregate, shade, and render shape outlines (outer and inner) with datashader. + + When ``outline_color_vector`` is provided, the outer outline is colored per-shape + via ``ds.by`` (categorical) or a numeric reduction (continuous) instead of a + single literal color. The two-outline form is rejected at validation, so this + only affects the outer outline. + """ ds_lw_factor = fig_params.fig.dpi / 72 assert len(render_params.outline_alpha) == 2 # noqa: S101 @@ -358,6 +374,24 @@ def _render_ds_outlines( alpha = render_params.outline_alpha[idx] if alpha <= 0: continue + if idx == 0 and outline_color_vector is not None: + _render_ds_outline_by_column( + cvs=cvs, + transformed_element=transformed_element, + outline_color_vector=outline_color_vector, + outline_color_source_vector=outline_color_source_vector, + cmap_params=render_params.cmap_params, + ds_reduction=render_params.ds_reduction, + line_width=linewidth * ds_lw_factor, + alpha=alpha, + fig_params=fig_params, + ax=ax, + factor=factor, + x_min=x_min, + y_min=y_min, + zorder=render_params.zorder, + ) + continue agg_outline = cvs.line( transformed_element, geometry="geometry", @@ -375,6 +409,95 @@ def _render_ds_outlines( _ax_show_and_transform(rgba, trans, ax, zorder=render_params.zorder) +def _render_ds_outline_by_column( + cvs: Any, + transformed_element: Any, + outline_color_vector: Any | None, + outline_color_source_vector: pd.Series | None, + cmap_params: Any, + ds_reduction: _DsReduction | None, + line_width: float, + alpha: float, + fig_params: FigParams, + ax: matplotlib.axes.SubplotBase, + factor: float, + x_min: float, + y_min: float, + zorder: int, +) -> None: + """Aggregate + shade an outline colored by an obs column via datashader. + + Two-outline form is not supported for column-driven outline coloring, + so this only renders the outer outline. + """ + color_by_categorical = outline_color_source_vector is not None + na_color_hex = _hex_no_alpha(cmap_params.na_color.get_hex()) + + # Attach the outline vector under a private column name so a fill column with the + # same key never gets overwritten. Assign positionally (via a Series indexed to the + # element) — `.assign(col=series)` aligns by index, which silently inserts NaN when + # the element's index is non-contiguous (e.g. after an inner-join). The NaNs would + # then be lifted to the `ds_nan` sentinel and one polygon's outline would render as + # `na_color` instead of its real category. + transformed_element = transformed_element.copy() + if color_by_categorical: + cat = pd.Categorical(outline_color_source_vector) + attach_cat = _inject_ds_nan_sentinel(pd.Series(cat)) + transformed_element[_OUTLINE_INTERNAL_COL] = pd.Categorical( + attach_cat.to_numpy(), categories=attach_cat.cat.categories + ) + else: + transformed_element[_OUTLINE_INTERNAL_COL] = np.asarray(outline_color_vector) + + if color_by_categorical: + agg_outline = cvs.line( + transformed_element, + geometry="geometry", + agg=ds.by(_OUTLINE_INTERNAL_COL, ds.count()), + line_width=line_width, + ) + color_key = _build_datashader_color_key( + _coerce_categorical_source(transformed_element[_OUTLINE_INTERNAL_COL]), + outline_color_vector, + na_color_hex, + ) + shaded = ds.tf.shade( + agg_outline, + color_key=color_key, + min_alpha=_convert_alpha_to_datashader_range(alpha), + how="linear", + ) + else: + reduction_name = ds_reduction if ds_reduction is not None else "max" + try: + reduction_function = _DS_REDUCTION_FUNCS[reduction_name](column=_OUTLINE_INTERNAL_COL) + except KeyError as e: + raise ValueError( + f"Reduction '{reduction_name}' is not supported. Use one of: {', '.join(_DS_REDUCTION_FUNCS.keys())}." + ) from e + agg_outline = cvs.line( + transformed_element, + geometry="geometry", + agg=reduction_function, + line_width=line_width, + ) + # Apply the user-provided norm (vmin/vmax) the same way the fill path does so + # an explicit Normalize takes effect for the outline cmap. + norm = copy(cmap_params.norm) + agg_outline, color_span = _apply_ds_norm(agg_outline, norm) + shaded = ds.tf.shade( + agg_outline, + cmap=cmap_params.cmap, + span=color_span, + min_alpha=_convert_alpha_to_datashader_range(alpha), + how="linear", + ) + + shaded = _apply_user_alpha(shaded, alpha) + rgba, trans = _create_image_from_datashader_result(shaded, factor, ax, x_min, y_min) + _ax_show_and_transform(rgba, trans, ax, zorder=zorder) + + def _build_ds_colorbar( reduction_bounds: tuple[Any, Any] | None, norm: Normalize, @@ -388,12 +511,4 @@ def _build_ds_colorbar( return None vmin = reduction_bounds[0].values if norm.vmin is None else norm.vmin vmax = reduction_bounds[1].values if norm.vmax is None else norm.vmax - if (norm.vmin is not None or norm.vmax is not None) and norm.vmin == norm.vmax: - assert norm.vmin is not None - assert norm.vmax is not None - vmin = norm.vmin - 0.5 - vmax = norm.vmin + 0.5 - return ScalarMappable( - norm=matplotlib.colors.Normalize(vmin=vmin, vmax=vmax), - cmap=cmap, - ) + return _make_continuous_mappable(vmin, vmax, cmap) diff --git a/src/spatialdata_plot/pl/basic.py b/src/spatialdata_plot/pl/basic.py index 37a80593..4d32ff8b 100644 --- a/src/spatialdata_plot/pl/basic.py +++ b/src/spatialdata_plot/pl/basic.py @@ -241,13 +241,18 @@ def render_shapes( Width of the border. If 2 values are given (tuple), 2 borders are shown with these widths (outer & inner). If `outline_color` and/or `outline_alpha` are used to indicate that one/two outlines should be drawn, the default outline widths 1.5 and 0.5 are used for outer/only and inner outline respectively. - outline_color : ColorLike | tuple[ColorLike], optional + outline_color : ColorLike | tuple[ColorLike] | str, optional Color of the border. Can either be a named color ("red"), a hex representation ("#000000") or a list of floats that represent RGB/RGBA values (1.0, 0.0, 0.0, 1.0). If the hex representation includes alpha, e.g. "#000000ff", and `outline_alpha` is not given, this value controls the opacity of the outline. If 2 values are given (tuple), 2 borders are shown with these colors (outer & inner). If `outline_width` and/or `outline_alpha` are used to indicate that one/two outlines should be drawn, the default outline colors "#000000" and "#ffffff are used for outer/only and inner outline respectively. + A string that is not a recognized color is interpreted as a column key (in `obs` of the annotating table + or in the element's own dataframe), mirroring how ``color`` is parsed. The outline is then colored + per-shape using the same ``palette`` / ``cmap`` / ``na_color`` as the fill. When both ``color`` and + ``outline_color`` resolve to columns, two stacked legends are drawn. Column-based outline coloring is + only supported for a single outline (not the 2-tuple form). outline_alpha : float | int | tuple[float | int, float | int] | None, optional Alpha value for the outline of shapes. Invisible by default, meaning outline_alpha=0.0 if both outline_color and outline_width are not specified. Else, outlines are rendered with the alpha implied by outline_color, or @@ -344,6 +349,8 @@ def render_shapes( element=element, color=param_values["color"], col_for_color=param_values["col_for_color"], + col_for_outline_color=param_values["col_for_outline_color"], + outline_table_name=param_values["outline_table_name"], groups=param_values["groups"], scale=param_values["scale"], outline_params=outline_params, @@ -810,11 +817,14 @@ def render_labels( fill_alpha : float | int | None, optional Alpha value for the fill of the labels. By default, it is set to 0.4 or, if a color is given that implies an alpha, that value is used for `fill_alpha`. - outline_color : ColorLike | None + outline_color : ColorLike | str | None Color of the outline of the labels. Can either be a named color ("red"), a hex representation ("#000000") or a list of floats that represent RGB/RGBA values (1.0, 0.0, 0.0, 1.0). If ``None``, the outline inherits from the ``color`` parameter when it is a literal color, or uses data-driven per-label colors when ``color`` refers to a column. + A string that is not a recognized color is interpreted as a column key (in `obs` of the annotating + table), mirroring how ``color`` is parsed. The outline is then colored per-label using the same + ``palette`` / ``cmap`` / ``na_color`` as the fill. scale : str | None Influences the resolution of the rendering. Possibilities for setting this parameter: 1) None (default). The image is rasterized to fit the canvas size. For multiscale images, the best scale @@ -880,6 +890,8 @@ def render_labels( element=element, color=param_values["color"], col_for_color=param_values["col_for_color"], + col_for_outline_color=param_values["col_for_outline_color"], + outline_table_name=param_values["outline_table_name"], groups=param_values["groups"], contour_px=param_values["contour_px"], cmap_params=cmap_params, @@ -1046,6 +1058,8 @@ def show( na_in_legend: bool = True, colorbar: bool = True, colorbar_params: dict[str, object] | None = None, + legend_title: str | None = None, + outline_legend_title: str | None = None, wspace: float | None = None, hspace: float = 0.25, ncols: int = 4, @@ -1090,6 +1104,12 @@ def show( colorbar_params : dict[str, object] | None Global overrides passed to colorbars for all axes. Accepts the same keys as per-layer ``colorbar_params`` (e.g., ``loc``, ``width``, ``pad``, ``label``). + legend_title : str | None + Title for the fill categorical legend. When both fill and outline are colored by an obs column, the + two legends default to ``"fill"`` / ``"outline"`` to disambiguate; pass an explicit string to override + the fill title. Set to ``None`` (default) to keep the auto-title behavior. + outline_legend_title : str | None + Title for the outline categorical legend. Mirrors ``legend_title`` for the outline channel. wspace : float | None Horizontal spacing between panels (passed to :class:`matplotlib.gridspec.GridSpec`). hspace : float, default 0.25 @@ -1326,6 +1346,8 @@ def show( legend_fontoutline=legend_fontoutline, na_in_legend=na_in_legend, colorbar=colorbar, + legend_title=legend_title, + outline_legend_title=outline_legend_title, ) def _draw_colorbar( diff --git a/src/spatialdata_plot/pl/render.py b/src/spatialdata_plot/pl/render.py index 1c9b97eb..b4db9ed5 100644 --- a/src/spatialdata_plot/pl/render.py +++ b/src/spatialdata_plot/pl/render.py @@ -4,7 +4,7 @@ from collections import abc from collections.abc import Sequence from copy import copy -from typing import Any, Literal +from typing import Any, Literal, cast import dask import dask.dataframe as dd @@ -24,7 +24,7 @@ from matplotlib.colors import ListedColormap, Normalize from scanpy._settings import settings as sc_settings from scanpy.plotting._tools.scatterplots import _add_categorical_legend -from spatialdata import get_extent, get_values, join_spatialelement_table +from spatialdata import get_extent, get_values from spatialdata._core.query.relational_query import match_table_to_element from spatialdata.models import PointsModel, ShapesModel, get_table_keys from spatialdata.transformations import set_transformation @@ -57,8 +57,11 @@ ShapesRenderParams, ) from spatialdata_plot.pl.utils import ( + _align_outline_vector_to_length, + _apply_mask_to_outline_vectors, _ax_show_and_transform, _check_obs_var_shadow, + _color_vector_to_rgba, _convert_shapes, _datashader_canvas_from_dataframe, _decorate_axs, @@ -67,6 +70,8 @@ _get_extent_and_range_for_datashader_canvas, _get_linear_colormap, _hex_no_alpha, + _join_table_for_element, + _make_continuous_mappable, _map_color_seg, _maybe_set_colors, _mpl_ax_contains_elements, @@ -334,12 +339,21 @@ def _add_legend_and_colorbar( colorbar: bool | str | None, colorbar_params: dict[str, object] | None, colorbar_requests: list[ColorbarSpec] | None, + outline_col_for_color: str | None = None, + outline_color_source_vector: pd.Series | None = None, + outline_color_vector: Any | None = None, + outline_cmap_params: CmapParams | None = None, ) -> None: """Add legend and colorbar decorations if the color vector warrants them.""" - if not _want_decorations(color_vector, na_color): + fill_has_decorations = _want_decorations(color_vector, na_color) and col_for_color is not None + outline_has_decorations = outline_col_for_color is not None and ( + outline_color_source_vector is not None or outline_color_vector is not None + ) + + if not fill_has_decorations and not outline_has_decorations: return - if palette is None: + if palette is None and fill_has_decorations: palette = _make_palette(color_source_vector, color_vector) if color_source_vector is not None and hasattr(color_source_vector, "remove_unused_categories"): @@ -351,29 +365,192 @@ def _add_legend_and_colorbar( is_continuous=col_for_color is not None and color_source_vector is None, ) - _decorate_axs( - ax=ax, - cax=cax, - fig_params=fig_params, - adata=adata, - value_to_plot=col_for_color, - color_source_vector=color_source_vector, - color_vector=color_vector, - palette=palette, - alpha=alpha, - na_color=na_color, - legend_fontsize=legend_params.legend_fontsize, - legend_fontweight=legend_params.legend_fontweight, - legend_loc=legend_params.legend_loc, - legend_fontoutline=legend_params.legend_fontoutline, - na_in_legend=legend_params.na_in_legend, - colorbar=wants_colorbar and legend_params.colorbar, - colorbar_params=colorbar_params, - colorbar_requests=colorbar_requests, - colorbar_label=_resolve_colorbar_label( - colorbar_params, - col_for_color if isinstance(col_for_color, str) else None, - ), + if fill_has_decorations: + # Auto-title the fill legend only when an outline legend will also be drawn. + outline_legend_will_render = outline_has_decorations and outline_color_source_vector is not None + if legend_params.legend_title is not None: + fill_title: str | None = legend_params.legend_title or None + elif outline_legend_will_render and color_source_vector is not None: + fill_title = "fill" + else: + fill_title = None + _decorate_axs( + ax=ax, + cax=cax, + fig_params=fig_params, + adata=adata, + value_to_plot=col_for_color, + color_source_vector=color_source_vector, + color_vector=color_vector, + palette=palette, + alpha=alpha, + na_color=na_color, + legend_fontsize=legend_params.legend_fontsize, + legend_fontweight=legend_params.legend_fontweight, + legend_loc=legend_params.legend_loc, + legend_fontoutline=legend_params.legend_fontoutline, + na_in_legend=legend_params.na_in_legend, + colorbar=wants_colorbar and legend_params.colorbar, + colorbar_params=colorbar_params, + colorbar_requests=colorbar_requests, + colorbar_label=_resolve_colorbar_label( + colorbar_params, + col_for_color if isinstance(col_for_color, str) else None, + ), + legend_title=fill_title, + ) + + if outline_has_decorations and outline_cmap_params is not None: + _decorate_outline( + ax=ax, + fig_params=fig_params, + outline_col=cast(str, outline_col_for_color), + outline_color_source_vector=outline_color_source_vector, + outline_color_vector=outline_color_vector, + cmap_params=outline_cmap_params, + colorbar_params=colorbar_params, + colorbar_requests=colorbar_requests, + legend_params=legend_params, + fill_has_legend=fill_has_decorations and color_source_vector is not None, + alpha=alpha, + ) + + +def _decorate_outline( + ax: matplotlib.axes.SubplotBase, + fig_params: FigParams, + outline_col: str, + outline_color_source_vector: pd.Series | None, + outline_color_vector: Any, + cmap_params: CmapParams, + colorbar_params: dict[str, object] | None, + colorbar_requests: list[ColorbarSpec] | None, + legend_params: LegendParams, + fill_has_legend: bool, + alpha: float, +) -> None: + """Dispatch a categorical legend or continuous colorbar for an outline column.""" + if outline_color_source_vector is not None: + _add_outline_legend( + ax=ax, + fig_params=fig_params, + outline_col=outline_col, + outline_color_source_vector=outline_color_source_vector, + outline_color_vector=outline_color_vector, + fill_has_legend=fill_has_legend, + legend_params=legend_params, + ) + elif colorbar_requests is not None and legend_params.colorbar and outline_color_vector is not None: + _append_outline_colorbar( + colorbar_requests=colorbar_requests, + ax=ax, + outline_color_vector=outline_color_vector, + cmap_params=cmap_params, + colorbar_params=colorbar_params, + outline_col=outline_col, + alpha=alpha, + ) + + +def _append_outline_colorbar( + colorbar_requests: list[ColorbarSpec], + ax: matplotlib.axes.SubplotBase, + outline_color_vector: Any, + cmap_params: CmapParams, + colorbar_params: dict[str, object] | None, + outline_col: str, + alpha: float, +) -> None: + """Append a `ColorbarSpec` for a continuous outline column. + + No-op when ``outline_color_vector`` has no finite values. Honors user-supplied + `vmin`/`vmax` on ``cmap_params.norm``; falls back to data range. Mirrors the + `vmin == vmax` ±0.5 expansion used by the fill colorbar. + """ + arr = pd.to_numeric(pd.Series(np.asarray(outline_color_vector)), errors="coerce").to_numpy() + finite = np.isfinite(arr) + if not finite.any(): + return + norm = cmap_params.norm + vmin = norm.vmin if norm.vmin is not None else float(np.nanmin(arr[finite])) + vmax = norm.vmax if norm.vmax is not None else float(np.nanmax(arr[finite])) + colorbar_requests.append( + ColorbarSpec( + ax=ax, + mappable=_make_continuous_mappable(vmin, vmax, cmap_params.cmap), + params=colorbar_params, + label=outline_col, + alpha=alpha, + ) + ) + + +def _add_outline_legend( + ax: matplotlib.axes.SubplotBase, + fig_params: FigParams, + outline_col: str, + outline_color_source_vector: pd.Series, + outline_color_vector: Any, + fill_has_legend: bool, + legend_params: LegendParams, +) -> None: + """Add a second legend for outline-by-column, auto-positioned below the fill legend. + + Uses the rendered fill legend's window extent to anchor the outline legend just + below it in axes-fraction coordinates. Falls back to anchoring at the bottom-right + of the axes when the measurement is unavailable. + """ + cats = outline_color_source_vector.remove_unused_categories().unique() + cats = cats[~cats.isnull()] + mapping_df = pd.DataFrame( + {"cats": outline_color_source_vector.remove_unused_categories(), "color": outline_color_vector} + ) + color_map = mapping_df.drop_duplicates("cats").set_index("cats")["color"].to_dict() + + outline_handles = [ax.scatter([], [], c=color_map[c], label=str(c)) for c in cats] + + anchor_y: float | None = None + if fill_has_legend: + fill_legend = ax.get_legend() + if fill_legend is not None: + # Reposition the fill legend to the top of the right margin so the two + # stack contiguously. Scanpy's default `bbox_to_anchor=(1, 0.5)` centers + # the fill legend vertically, which looks unbalanced once a second legend + # is added below. + fill_legend.set_bbox_to_anchor((1.02, 1.0)) + if hasattr(fill_legend, "set_loc"): + fill_legend.set_loc("upper left") + ax.add_artist(fill_legend) # keep fill legend on the axes + # Force layout so get_window_extent returns the real (not stale) bbox. + fig_params.fig.canvas.draw() + bbox_axes = fill_legend.get_window_extent().transformed(ax.transAxes.inverted()) + anchor_y = float(bbox_axes.y0) - 0.02 + + # If the measured extent is degenerate (no fill legend, or its bbox sits at/below + # the axes' bottom edge), fall back to an opposite-anchor layout that still avoids + # overlap regardless of legend height. + if anchor_y is not None and anchor_y > 0: + loc = "upper left" + anchor = (1.02, anchor_y) + else: + loc = "lower left" if fill_has_legend else "center left" + anchor = (1.02, 0.0) if fill_has_legend else (1.0, 0.5) + + # Auto-title only when a fill legend is also present (so the user can tell which is which). + # User-provided `outline_legend_title` always wins; pass empty string to suppress. + if legend_params.outline_legend_title is not None: + title = legend_params.outline_legend_title or None + else: + title = "outline" if fill_has_legend else None + + ax.legend( + handles=outline_handles, + title=title, + frameon=False, + loc=loc, + bbox_to_anchor=anchor, + fontsize=legend_params.legend_fontsize, + ncol=(1 if len(outline_handles) <= 14 else 2 if len(outline_handles) <= 30 else 3), ) @@ -424,36 +601,8 @@ def _render_shapes( shapes = sdata_filt[element] else: _check_instance_ids_overlap(sdata_filt, table_name, element, sdata_filt[element].index) - - # Workaround for upstream spatialdata bug (scverse/spatialdata#1099): - # join_spatialelement_table calls table.obs.reset_index() which fails when - # the obs index name matches an existing column (e.g. "EntityID" in Merfish - # data). When that collision is present, the obs index may also be a - # non-RangeIndex of int dtype, which AnnData's `_normalize_index` rejects - # when the join indexes back into the table. Temporarily swap to a clean - # RangeIndex / drop the conflicting name; restore on exit. - _obs = sdata[table_name].obs - _saved_index_name = _obs.index.name - _saved_index: pd.Index | None = None - _name_collides = _saved_index_name is not None and _saved_index_name in _obs.columns - if _name_collides and not isinstance(_obs.index, pd.RangeIndex): - _saved_index = _obs.index - _obs.index = pd.RangeIndex(len(_obs)) - elif _name_collides: - _obs.index.name = None - - try: - element_dict, joined_table = join_spatialelement_table( - sdata, spatial_element_names=element, table_name=table_name, how="inner" - ) - finally: - if _saved_index is not None: - _obs.index = _saved_index - _obs.index.name = _saved_index_name - sdata_filt[element] = shapes = element_dict[element] - joined_table.uns["spatialdata_attrs"]["region"] = ( - joined_table.obs[joined_table.uns["spatialdata_attrs"]["region_key"]].unique().tolist() - ) + joined_element, joined_table = _join_table_for_element(sdata, element, table_name) + sdata_filt[element] = shapes = joined_element sdata_filt[table_name] = table = joined_table shapes = sdata_filt[element] @@ -479,6 +628,54 @@ def _render_shapes( values_are_categorical = color_source_vector is not None + col_for_outline_color = render_params.col_for_outline_color + outline_table_name = render_params.outline_table_name + outline_color_source_vector: pd.Series | None = None + outline_color_vector: Any = None + if col_for_outline_color is not None: + # When the outline column lives in a table that hasn't been joined yet + # (no fill table, or a different table than fill's), inner-join it onto + # the element so the lookup is aligned and the element row count matches + # the outline vector length. + if outline_table_name is not None and outline_table_name != table_name: + joined_outline_element, joined_outline_table = _join_table_for_element( + sdata_filt, element, outline_table_name + ) + sdata_filt[outline_table_name] = joined_outline_table + # If no fill join happened, replace the element with the outline-joined version + # so the per-shape outline vector length matches the rendered shapes. + if table_name is None: + sdata_filt[element] = shapes = joined_outline_element + outline_color_source_vector, outline_color_vector, _ = _set_color_source_vec( + sdata=sdata_filt, + element=sdata_filt[element], + element_name=element, + value_to_plot=col_for_outline_color, + groups=None, + palette=render_params.palette, + na_color=render_params.cmap_params.na_color, + cmap_params=render_params.cmap_params, + table_name=outline_table_name, + table_layer=table_layer, + coordinate_system=coordinate_system, + ) + # Cross-table case: if fill and outline tables differ and the outline table does + # not annotate every row of the (fill-joined) element, the vector length will + # differ from the rendered element row count. Warn + align so per-shape lookup stays + # well-defined. + _n_shapes = len(sdata_filt[element]) + if outline_color_vector is not None and len(outline_color_vector) != _n_shapes: + logger.warning( + f"Outline column '{col_for_outline_color}' does not fully annotate " + f"element '{element}' under its fill-joined alignment " + f"({len(outline_color_vector)} of {_n_shapes} rows). Missing rows will use na_color." + ) + outline_color_vector, outline_color_source_vector = _align_outline_vector_to_length( + outline_color_vector, + outline_color_source_vector, + _n_shapes, + ) + _warn_groups_ignored_continuous(groups, color_source_vector, col_for_color) if groups is not None and color_source_vector is not None: @@ -495,6 +692,10 @@ def _render_shapes( if len(shapes) == 0: return sdata_filt[element] = shapes + if outline_color_vector is not None: + outline_color_vector, outline_color_source_vector = _apply_mask_to_outline_vectors( + outline_color_vector, outline_color_source_vector, keep + ) # color_source_vector is None when the values aren't categorical if not values_are_categorical and render_params.transfunc is not None: @@ -686,6 +887,8 @@ def _render_shapes( factor, x_min=x_ext[0], y_min=y_ext[0], + outline_color_vector=outline_color_vector, + outline_color_source_vector=outline_color_source_vector, ) _cax = _render_ds_image( @@ -702,7 +905,31 @@ def _render_shapes( elif method == "matplotlib": # render outlines separately to ensure they are always underneath the shape - if render_params.outline_alpha[0] > 0 and isinstance(render_params.outline_params.outer_outline_color, Color): + if col_for_outline_color is not None and render_params.outline_alpha[0] > 0: + outline_rgba = _color_vector_to_rgba( + outline_color_vector, + outline_color_source_vector, + render_params.cmap_params, + n_rows=len(shapes), + ) + _cax = _get_collection_shape( + shapes=shapes, + s=render_params.scale, + c=np.array(["white"]), # hack, will be invisible bc fill_alpha=0 + render_params=render_params, + rasterized=sc_settings._vector_friendly, + cmap=None, + norm=None, + fill_alpha=0.0, + outline_alpha=render_params.outline_alpha[0], + outline_color=outline_rgba, + linewidth=render_params.outline_params.outer_outline_linewidth, + zorder=render_params.zorder, + ) + ax.add_collection(_cax) + for path in _cax.get_paths(): + path.vertices = trans.transform(path.vertices) + elif render_params.outline_alpha[0] > 0 and isinstance(render_params.outline_params.outer_outline_color, Color): _cax = _get_collection_shape( shapes=shapes, s=render_params.scale, @@ -798,6 +1025,10 @@ def _render_shapes( colorbar=render_params.colorbar, colorbar_params=render_params.colorbar_params, colorbar_requests=colorbar_requests, + outline_col_for_color=col_for_outline_color, + outline_color_source_vector=outline_color_source_vector, + outline_color_vector=outline_color_vector, + outline_cmap_params=render_params.cmap_params, ) @@ -1758,6 +1989,13 @@ def _render_labels( groups = render_params.groups scale = render_params.scale + # When fill is a literal (no `color=` column) but outline points to an obs column, + # promote the outline table to be the "active" table for instance_id derivation so + # the outline color vector aligns to label IDs by the table's instance_key rather + # than by positional index. + if table_name is None and render_params.outline_table_name is not None: + table_name = render_params.outline_table_name + _check_obs_var_shadow(sdata, element, col_for_color, table_name) # filter_tables=False: match_table_to_element below already filters per @@ -1843,19 +2081,55 @@ def _render_labels( coordinate_system=coordinate_system, ) + # Outline color lookup must run BEFORE any masking so the returned vector aligns to + # the original instance_id. The same masks applied to fill below are then applied + # to the outline vectors to keep lengths consistent. + col_for_outline_color = render_params.col_for_outline_color + outline_table_name = render_params.outline_table_name + outline_color_source_vector: pd.Series | None = None + outline_color_vector: Any = None + if col_for_outline_color is not None: + outline_color_source_vector, outline_color_vector, _ = _set_color_source_vec( + sdata=sdata_filt, + element=label, + element_name=element, + value_to_plot=col_for_outline_color, + groups=None, + palette=palette, + na_color=render_params.cmap_params.na_color, + cmap_params=render_params.cmap_params, + table_name=outline_table_name, + table_layer=table_layer, + render_type="labels", + coordinate_system=coordinate_system, + ) + # Align to instance_id so the rasterize/groups masks (computed against + # instance_id) can be applied without IndexError when the outline table + # annotates a subset of the labels. + outline_color_vector, outline_color_source_vector = _align_outline_vector_to_length( + outline_color_vector, + outline_color_source_vector, + len(instance_id), + ) + # rasterize could have removed labels from label # only problematic if color is specified - if rasterize and col_for_color is not None: + if rasterize and (col_for_color is not None or col_for_outline_color is not None): labels_in_rasterized_image = np.unique(label.values) mask = np.isin(instance_id, labels_in_rasterized_image) instance_id = instance_id[mask] - color_vector = color_vector[mask] - if isinstance(color_vector.dtype, pd.CategoricalDtype): - color_vector = color_vector.remove_unused_categories() - assert color_source_vector is not None - color_source_vector = color_source_vector[mask] - else: - assert color_source_vector is None + if col_for_color is not None: + color_vector = color_vector[mask] + if isinstance(color_vector.dtype, pd.CategoricalDtype): + color_vector = color_vector.remove_unused_categories() + assert color_source_vector is not None # noqa: S101 + color_source_vector = color_source_vector[mask] + else: + assert color_source_vector is None # noqa: S101 + if outline_color_vector is not None: + outline_color_vector, outline_color_source_vector = _apply_mask_to_outline_vectors( + outline_color_vector, outline_color_source_vector, mask + ) _warn_groups_ignored_continuous(groups, color_source_vector, col_for_color) @@ -1881,6 +2155,10 @@ def _render_labels( color_vector = color_vector[keep_vec] if isinstance(color_vector.dtype, pd.CategoricalDtype): color_vector = color_vector.remove_unused_categories() + if outline_color_vector is not None: + outline_color_vector, outline_color_source_vector = _apply_mask_to_outline_vectors( + outline_color_vector, outline_color_source_vector, keep_vec + ) # color_source_vector is None when the values aren't categorical if color_source_vector is None and render_params.transfunc is not None: @@ -1902,6 +2180,8 @@ def _draw_labels( seg_boundaries=seg_boundaries, na_color=na_color, outline_color=outline_color, + outline_color_vector=outline_color_vector if seg_boundaries else None, + outline_color_source_vector=outline_color_source_vector if seg_boundaries else None, ) cax = ax.imshow( @@ -1975,6 +2255,14 @@ def _draw_labels( is_continuous=col_for_color is not None and color_source_vector is None and not categorical, ) + # Auto-title the fill legend only when an outline legend will also be drawn. + outline_legend_will_render = col_for_outline_color is not None and outline_color_source_vector is not None + if legend_params.legend_title is not None: + fill_title: str | None = legend_params.legend_title or None + elif outline_legend_will_render and color_source_vector is not None: + fill_title = "fill" + else: + fill_title = None _ = _decorate_axs( ax=ax, cax=cax, @@ -1998,8 +2286,24 @@ def _draw_labels( render_params.colorbar_params, col_for_color if isinstance(col_for_color, str) else None, ), + legend_title=fill_title, ) + if col_for_outline_color is not None: + _decorate_outline( + ax=ax, + fig_params=fig_params, + outline_col=col_for_outline_color, + outline_color_source_vector=outline_color_source_vector, + outline_color_vector=outline_color_vector, + cmap_params=render_params.cmap_params, + colorbar_params=render_params.colorbar_params, + colorbar_requests=colorbar_requests, + legend_params=legend_params, + fill_has_legend=col_for_color is not None and color_source_vector is not None, + alpha=alpha_to_decorate_ax, + ) + def _normalise_to_range(values: np.ndarray, lo: float, hi: float) -> np.ndarray: """Min-max normalise a 1-D array into ``[lo, hi]``. Constant input → midpoint.""" diff --git a/src/spatialdata_plot/pl/render_params.py b/src/spatialdata_plot/pl/render_params.py index a0414d6b..1d18a725 100644 --- a/src/spatialdata_plot/pl/render_params.py +++ b/src/spatialdata_plot/pl/render_params.py @@ -188,6 +188,11 @@ class LegendParams: legend_fontoutline: int | None = None na_in_legend: bool = True colorbar: bool = True + # Optional explicit titles for the fill / outline categorical legends. When unset, + # both legends are untitled unless both fill and outline are colored by an obs + # column, in which case they default to "fill" / "outline" to disambiguate. + legend_title: str | None = None + outline_legend_title: str | None = None @dataclass @@ -232,6 +237,8 @@ class ShapesRenderParams: element: str color: Color | None = None col_for_color: str | None = None + col_for_outline_color: str | None = None + outline_table_name: str | None = None groups: str | list[str] | None = None palette: ListedColormap | dict[str, str] | list[str] | None = None outline_alpha: tuple[float, float] = (1.0, 1.0) @@ -298,6 +305,8 @@ class LabelsRenderParams: element: str color: Color | None = None col_for_color: str | None = None + col_for_outline_color: str | None = None + outline_table_name: str | None = None groups: str | list[str] | None = None contour_px: int | None = None palette: ListedColormap | dict[str, str] | list[str] | None = None diff --git a/src/spatialdata_plot/pl/utils.py b/src/spatialdata_plot/pl/utils.py index a64b0fd2..66317748 100644 --- a/src/spatialdata_plot/pl/utils.py +++ b/src/spatialdata_plot/pl/utils.py @@ -30,6 +30,7 @@ from geopandas import GeoDataFrame from matplotlib import colors, patheffects, rcParams from matplotlib.axes import Axes +from matplotlib.cm import ScalarMappable from matplotlib.collections import PatchCollection from matplotlib.colors import ( ColorConverter, @@ -61,6 +62,7 @@ get_element_annotators, get_extent, get_values, + join_spatialelement_table, rasterize, ) from spatialdata._core.query.relational_query import _locate_value @@ -431,6 +433,170 @@ def _scale_pathpatch_around_centroid(pathpatch: mpatches.PathPatch, scale_factor pathpatch.get_path().vertices = scaled_vertices +def _join_table_for_element( + sdata: sd.SpatialData, + element: str, + table_name: str, +) -> tuple[Any, AnnData]: + """Inner-join ``element`` with its annotating ``table_name``. + + Wraps the workaround for scverse/spatialdata#1099: ``join_spatialelement_table`` + calls ``table.obs.reset_index()`` which fails when the obs index name matches + an existing column (e.g. "EntityID" in Merfish data). When that collision is + present, the obs index may also be a non-RangeIndex of int dtype, which + AnnData's ``_normalize_index`` rejects when the join indexes back into the + table. Temporarily swap to a clean RangeIndex / drop the conflicting name; + restore on exit. + + Also patches ``joined_table.uns["spatialdata_attrs"]["region"]`` to the + actual unique regions after the join so downstream lookups see consistent + metadata. + """ + _obs = sdata[table_name].obs + _saved_index_name = _obs.index.name + _saved_index: pd.Index | None = None + _name_collides = _saved_index_name is not None and _saved_index_name in _obs.columns + if _name_collides and not isinstance(_obs.index, pd.RangeIndex): + _saved_index = _obs.index + _obs.index = pd.RangeIndex(len(_obs)) + elif _name_collides: + _obs.index.name = None + + try: + element_dict, joined_table = join_spatialelement_table( + sdata, spatial_element_names=element, table_name=table_name, how="inner" + ) + finally: + if _saved_index is not None: + _obs.index = _saved_index + _obs.index.name = _saved_index_name + + joined_table.uns["spatialdata_attrs"]["region"] = ( + joined_table.obs[joined_table.uns["spatialdata_attrs"]["region_key"]].unique().tolist() + ) + return element_dict[element], joined_table + + +def _make_continuous_mappable(vmin: float, vmax: float, cmap: Any) -> ScalarMappable: + """Build a ``ScalarMappable`` for a continuous colorbar, with a ±0.5 fallback when ``vmin == vmax``.""" + if vmin == vmax: + vmin, vmax = vmin - 0.5, vmax + 0.5 + return ScalarMappable(norm=Normalize(vmin=vmin, vmax=vmax), cmap=cmap) + + +def _apply_mask_to_outline_vectors( + outline_color_vector: Any, + outline_color_source_vector: pd.Series | None, + mask: Any, +) -> tuple[Any, pd.Series | None]: + """Apply a boolean ``keep`` mask to outline color vector(s). + + Used to keep outline data aligned with the fill data after a ``groups`` + or rasterize-based filter is applied to the rendered element. + """ + arr = np.asarray(mask) + if outline_color_source_vector is not None: + outline_color_source_vector = outline_color_source_vector[arr] + return outline_color_vector[arr], outline_color_source_vector + + +def _align_outline_vector_to_length( + outline_color_vector: Any, + outline_color_source_vector: pd.Series | None, + n: int, +) -> tuple[Any, pd.Series | None]: + """Pad or truncate the outline color vector(s) to length ``n``. + + Used when the outline column annotates a different row count than the rendered + element (cross-table case, or rasterize-induced label drop). Missing entries + are padded with NaN so downstream code maps them to ``na_color``. + """ + if outline_color_vector is None or len(outline_color_vector) == n: + return outline_color_vector, outline_color_source_vector + if len(outline_color_vector) > n: + if outline_color_source_vector is not None: + outline_color_source_vector = outline_color_source_vector[:n] + return outline_color_vector[:n], outline_color_source_vector + pad = n - len(outline_color_vector) + if outline_color_source_vector is not None: + # Categorical: downstream picks one hex per category from rows that *have* a + # category. NaN-padded rows contribute no category, so the per-row hex pad is + # immaterial; pad with NaN to skip the allocation. + padded_vec = np.concatenate([np.asarray(outline_color_vector), np.full(pad, np.nan, dtype=object)]) + outline_color_source_vector = pd.Categorical( + list(outline_color_source_vector) + [None] * pad, + categories=outline_color_source_vector.categories, + ) + else: + # Continuous: numeric vector, pad with NaN so cmap maps padded rows to na_color. + padded_vec = np.concatenate([np.asarray(outline_color_vector, dtype=float), np.full(pad, np.nan)]) + return padded_vec, outline_color_source_vector + + +def _color_vector_to_rgba( + color_vector: Any | None, + color_source_vector: pd.Series | None, + cmap_params: CmapParams, + n_rows: int, +) -> np.ndarray: + """Convert a fill/outline `color_vector` (categorical hex strings or continuous numerics) to (N, 4) RGBA. + + Mirrors the per-row mapping done inside :func:`_get_collection_shape` so that + callers can pre-materialize an outline-color array. NaN/non-finite entries are + painted with ``cmap_params.na_color``. + """ + na_rgba = colors.to_rgba(cmap_params.na_color.get_hex_with_alpha()) + if color_vector is None: + rgba = np.empty((n_rows, 4), dtype=float) + rgba[:] = na_rgba + return rgba + + if color_source_vector is not None: + # Categorical: color_vector contains hex strings aligned to color_source_vector + return np.asarray(ColorConverter().to_rgba_array(list(color_vector))) + + arr = np.asarray(color_vector) + if arr.ndim == 2 and arr.shape[1] in (3, 4) and np.issubdtype(arr.dtype, np.number): + return np.asarray(ColorConverter().to_rgba_array(arr)) + + rgba = np.empty((len(arr), 4), dtype=float) + rgba[:] = na_rgba + if np.issubdtype(arr.dtype, np.number): + finite_mask = np.isfinite(arr) + if finite_mask.any(): + norm = cmap_params.norm + if norm.vmin is None or norm.vmax is None: + vmin = float(np.nanmin(arr[finite_mask])) + vmax = float(np.nanmax(arr[finite_mask])) + if not np.isfinite(vmin) or not np.isfinite(vmax) or vmin == vmax: + vmin, vmax = 0.0, 1.0 + used_norm = Normalize(vmin=vmin, vmax=vmax, clip=False) + else: + used_norm = norm + rgba[finite_mask] = cmap_params.cmap(used_norm(arr[finite_mask])) + return rgba + + # Object dtype: mix of numerics and color-like specs (apply cmap to the numeric subset only) + series = pd.Series(arr, copy=False) + num = pd.to_numeric(series, errors="coerce").to_numpy() + is_num = np.isfinite(num) + if is_num.any(): + norm = cmap_params.norm + if norm.vmin is None or norm.vmax is None: + vmin = float(np.nanmin(num[is_num])) + vmax = float(np.nanmax(num[is_num])) + if not np.isfinite(vmin) or not np.isfinite(vmax) or vmin == vmax: + vmin, vmax = 0.0, 1.0 + used_norm = Normalize(vmin=vmin, vmax=vmax, clip=False) + else: + used_norm = norm + rgba[is_num] = cmap_params.cmap(used_norm(num[is_num])) + color_mask = (~is_num) & series.notna().to_numpy() + if color_mask.any(): + rgba[color_mask] = ColorConverter().to_rgba_array(series[color_mask].tolist()) + return rgba + + def _get_collection_shape( shapes: list[GeoDataFrame], c: Any, @@ -439,7 +605,7 @@ def _get_collection_shape( render_params: ShapesRenderParams, fill_alpha: None | float = None, outline_alpha: None | float = None, - outline_color: None | str | list[float] = "white", + outline_color: None | str | list[float] | np.ndarray = "white", linewidth: float = 0.0, **kwargs: Any, ) -> PatchCollection: @@ -451,6 +617,11 @@ def _get_collection_shape( - a single color or a list of color specs. Only NaNs are painted with na_color; finite values are mapped via norm+cmap. + + .. note:: + When ``outline_color`` is passed as an ``(N, 4)`` RGBA array of dtype ``float``, + its alpha channel is mutated in place to apply ``outline_alpha``. Pass a copy + if you need to retain the original buffer. """ cmap = kwargs["cmap"] @@ -534,7 +705,13 @@ def _as_rgba_array(x: Any) -> np.ndarray: # Outline handling if outline_alpha and outline_alpha > 0.0: - outline_c_array = _as_rgba_array(outline_color) + outline_arr = np.asarray(outline_color) if not isinstance(outline_color, str) else None + if outline_arr is not None and outline_arr.ndim == 2 and outline_arr.shape == (len(shapes), 4): + # Per-shape RGBA array. Mutate in place when already float so we don't allocate twice + # on the hot path; otherwise upcast to a fresh float buffer. + outline_c_array = outline_arr if outline_arr.dtype == float else outline_arr.astype(float) + else: + outline_c_array = _as_rgba_array(outline_color) outline_c_array[..., -1] = outline_alpha outline_c = outline_c_array.tolist() else: @@ -1315,6 +1492,8 @@ def _map_color_seg( seg_erosionpx: int | None = None, seg_boundaries: bool = False, outline_color: Color | None = None, + outline_color_vector: ArrayLike | pd.Series[CategoricalDtype] | None = None, + outline_color_source_vector: pd.Series[CategoricalDtype] | None = None, ) -> ArrayLike: cell_id = np.array(cell_id) @@ -1358,6 +1537,53 @@ def _map_color_seg( if seg_erosionpx is not None: val_im[val_im == erosion(val_im, footprint_rectangle((seg_erosionpx, seg_erosionpx)))] = 0 + if seg_boundaries and outline_color_vector is not None: + # Column-driven outline: build per-label colors from the outline vector and overlay + # on the eroded ring. Two cases (mirroring _set_color_source_vec's return contract): + # - categorical: outline_color_source_vector is the source Categorical; outline_color_vector + # holds hex strings aligned to cells. + # - continuous: outline_color_source_vector is None; outline_color_vector is numeric. + if outline_color_source_vector is not None: + cat = pd.Categorical(outline_color_source_vector) + cat_codes = cat.codes + outline_val_im: ArrayLike = map_array(seg.copy(), cell_id, cat_codes + 1) + color_arr = np.asarray(outline_color_vector, dtype=object) + # Pick the first per-cell hex for each category in one vectorized pass + # (avoids `K × O(N)` Python loops on large label sets). + cat_colors: list[Any] = [na_color.get_hex_with_alpha()] * len(cat.categories) + unique_codes, first_indices = np.unique(cat_codes, return_index=True) + for code, idx in zip(unique_codes, first_indices, strict=True): + if code >= 0: + cat_colors[code] = color_arr[idx] + outline_cols = colors.to_rgba_array(cat_colors) + else: + # Continuous: numeric values normalized via cmap + ov = ( + outline_color_vector.to_numpy() + if isinstance(outline_color_vector, pd.Series) + else np.asarray(outline_color_vector) + ) + normed = ov.copy().astype(float) + finite = ~np.isnan(normed) + if finite.any(): + normed[finite] = cmap_params.norm(normed[finite]) + outline_cols = cmap_params.cmap(normed) + outline_val_im = map_array(seg.copy(), cell_id, cell_id) + if seg_erosionpx is not None: + outline_val_im[ + outline_val_im == erosion(outline_val_im, footprint_rectangle((seg_erosionpx, seg_erosionpx))) + ] = 0 + outline_seg_im = label2rgb( + label=outline_val_im, + colors=outline_cols, + bg_label=0, + bg_color=(1, 1, 1), + image_alpha=0, + ) + outline_mask = val_im > 0 + alpha_channel = outline_mask.astype(float) + return np.dstack((outline_seg_im, alpha_channel)) + if seg_boundaries and outline_color is not None: # Uniform outline color requested: skip label2rgb, build RGBA directly outline_rgba = colors.to_rgba(outline_color.get_hex_with_alpha()) @@ -1751,6 +1977,7 @@ def _decorate_axs( colorbar_params: dict[str, object] | None = None, colorbar_requests: list[ColorbarSpec] | None = None, colorbar_label: str | None = None, + legend_title: str | None = None, ) -> Axes: if value_to_plot is not None: # if only dots were plotted without an associated value @@ -1786,6 +2013,10 @@ def _decorate_axs( na_in_legend=na_in_legend, multi_panel=fig_params.axs is not None, ) + # scanpy's helper doesn't accept a title; set it post-hoc so the user can + # disambiguate fill vs outline when both legends are drawn. + if legend_title is not None and (legend := ax.get_legend()) is not None: + legend.set_title(legend_title) elif colorbar and colorbar_requests is not None and cax is not None: colorbar_requests.append( ColorbarSpec( @@ -2486,9 +2717,11 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st raise TypeError("Parameter 'outline_alpha' must be numeric and between 0 and 1.") outline_color = param_dict.get("outline_color") + if "outline_color" in param_dict and element_type in {"shapes", "labels"}: + param_dict["col_for_outline_color"] = None if outline_color: if not isinstance(outline_color, str | tuple | list): - raise TypeError("Parameter 'color' must be a string or a tuple/list of floats or colors.") + raise TypeError("Parameter 'outline_color' must be a string or a tuple/list of floats or colors.") if isinstance(outline_color, tuple | list): if len(outline_color) < 1: raise ValueError("Empty tuple is not supported as input for outline_color!") @@ -2505,6 +2738,18 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st f"Tuple/List of length {len(outline_color)} was passed for outline_color. Valid options would be: " "tuple of 2 colors (for 2 outlines) or an RGB(A) array, aka a list/tuple of 3-4 floats." ) + elif isinstance(outline_color, str) and element_type in {"shapes", "labels"}: + if _is_color_like(outline_color): + _check_color_column_collision(param_dict["sdata"], param_dict["element"], outline_color, element_type) + param_dict["outline_color"] = Color(outline_color) + else: + if isinstance(param_dict.get("outline_width"), tuple): + raise ValueError( + "Coloring outlines by a column is not supported with two outlines. " + "Pass a scalar `outline_width` or a literal color for `outline_color`." + ) + param_dict["col_for_outline_color"] = outline_color + param_dict["outline_color"] = None else: param_dict["outline_color"] = Color(outline_color) @@ -2797,6 +3042,20 @@ def _validate_label_render_params( element_params[el]["table_name"] = table_name element_params[el]["col_for_color"] = col_for_color + element_params[el]["col_for_outline_color"] = None + element_params[el]["outline_table_name"] = None + if (col_for_outline_color := param_dict.get("col_for_outline_color")) is not None: + col_for_outline_color, outline_table_name = _validate_col_for_column_table( + sdata, + el, + col_for_outline_color, + param_dict["table_name"], + labels=True, + gene_symbols=gene_symbols, + ) + element_params[el]["col_for_outline_color"] = col_for_outline_color + element_params[el]["outline_table_name"] = outline_table_name + _gate_palette_and_groups(element_params[el], param_dict) element_params[el]["colorbar"] = param_dict["colorbar"] element_params[el]["colorbar_params"] = param_dict["colorbar_params"] @@ -3000,6 +3259,16 @@ def _validate_shape_render_params( element_params[el]["table_name"] = table_name element_params[el]["col_for_color"] = col_for_color + element_params[el]["col_for_outline_color"] = None + element_params[el]["outline_table_name"] = None + col_for_outline_color = param_dict.get("col_for_outline_color") + if col_for_outline_color is not None: + col_for_outline_color, outline_table_name = _validate_col_for_column_table( + sdata, el, col_for_outline_color, param_dict["table_name"], gene_symbols=gene_symbols + ) + element_params[el]["col_for_outline_color"] = col_for_outline_color + element_params[el]["outline_table_name"] = outline_table_name + _gate_palette_and_groups(element_params[el], param_dict) element_params[el]["method"] = param_dict["method"] element_params[el]["ds_reduction"] = param_dict["ds_reduction"] @@ -3591,6 +3860,18 @@ def _create_image_from_datashader_result( return rgba_image, trans_data +_DS_REDUCTION_FUNCS: dict[str, Any] = { + "sum": ds.sum, + "mean": ds.mean, + "any": ds.any, + "count": ds.count, + "std": ds.std, + "var": ds.var, + "max": ds.max, + "min": ds.min, +} + + def _datashader_aggregate_with_function( reduction: Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None, cvs: Canvas, @@ -3615,22 +3896,11 @@ def _datashader_aggregate_with_function( if reduction is None: reduction = "sum" - reduction_function_map = { - "sum": ds.sum, - "mean": ds.mean, - "any": ds.any, - "count": ds.count, - "std": ds.std, - "var": ds.var, - "max": ds.max, - "min": ds.min, - } - try: - reduction_function = reduction_function_map[reduction](column=col_for_color) + reduction_function = _DS_REDUCTION_FUNCS[reduction](column=col_for_color) except KeyError as e: raise ValueError( - f"Reduction '{reduction}' is not supported. Please use one of: {', '.join(reduction_function_map.keys())}." + f"Reduction '{reduction}' is not supported. Please use one of: {', '.join(_DS_REDUCTION_FUNCS.keys())}." ) from e element_function_map = { diff --git a/tests/_images/Labels_outline_color_by_categorical_obs_labels.png b/tests/_images/Labels_outline_color_by_categorical_obs_labels.png new file mode 100644 index 00000000..c3666607 Binary files /dev/null and b/tests/_images/Labels_outline_color_by_categorical_obs_labels.png differ diff --git a/tests/_images/Shapes_fill_and_outline_both_obs_columns.png b/tests/_images/Shapes_fill_and_outline_both_obs_columns.png new file mode 100644 index 00000000..9d9f81ac Binary files /dev/null and b/tests/_images/Shapes_fill_and_outline_both_obs_columns.png differ diff --git a/tests/_images/Shapes_outline_color_by_categorical_obs.png b/tests/_images/Shapes_outline_color_by_categorical_obs.png new file mode 100644 index 00000000..cd61b9af Binary files /dev/null and b/tests/_images/Shapes_outline_color_by_categorical_obs.png differ diff --git a/tests/_images/Shapes_outline_color_by_categorical_obs_datashader.png b/tests/_images/Shapes_outline_color_by_categorical_obs_datashader.png new file mode 100644 index 00000000..96bee1b4 Binary files /dev/null and b/tests/_images/Shapes_outline_color_by_categorical_obs_datashader.png differ diff --git a/tests/_images/Shapes_outline_color_by_continuous_obs.png b/tests/_images/Shapes_outline_color_by_continuous_obs.png new file mode 100644 index 00000000..461569ab Binary files /dev/null and b/tests/_images/Shapes_outline_color_by_continuous_obs.png differ diff --git a/tests/_images/Shapes_outline_color_by_continuous_obs_datashader.png b/tests/_images/Shapes_outline_color_by_continuous_obs_datashader.png new file mode 100644 index 00000000..92d7a023 Binary files /dev/null and b/tests/_images/Shapes_outline_color_by_continuous_obs_datashader.png differ diff --git a/tests/pl/test_render_labels.py b/tests/pl/test_render_labels.py index 4c76b1a1..2024bd24 100644 --- a/tests/pl/test_render_labels.py +++ b/tests/pl/test_render_labels.py @@ -28,6 +28,16 @@ # ".png" is appended to , no need to set it +def _annotate_labels_with_outline_columns(sdata: SpatialData) -> SpatialData: + """Patch the shared blobs fixture so its table annotates ``blobs_labels`` with categorical columns.""" + sdata["table"].obs["region"] = pd.Categorical(["blobs_labels"] * sdata["table"].n_obs) + sdata["table"].uns["spatialdata_attrs"]["region"] = "blobs_labels" + n = sdata["table"].n_obs + sdata["table"].obs["cluster"] = pd.Categorical((["c1", "c2"] * ((n + 1) // 2))[:n]) + sdata["table"].obs["stage"] = pd.Categorical((["s1", "s2"] * ((n + 1) // 2))[:n]) + return sdata + + class TestLabels(PlotTester, metaclass=PlotTesterMeta): def test_plot_can_render_labels(self, sdata_blobs: SpatialData): sdata_blobs.pl.render_labels(element="blobs_labels").pl.show() @@ -120,6 +130,12 @@ def test_plot_outline_uses_data_driven_colors(self, sdata_blobs: SpatialData): "blobs_labels", color="channel_0_sum", outline_alpha=1, fill_alpha=0, contour_px=10 ).pl.show() + def test_plot_outline_color_by_categorical_obs_labels(self, sdata_blobs: SpatialData): + sdata_blobs = _annotate_labels_with_outline_columns(sdata_blobs) + sdata_blobs.pl.render_labels( + "blobs_labels", fill_alpha=0, outline_alpha=1, outline_color="cluster", contour_px=10 + ).pl.show() + def test_plot_can_color_labels_by_continuous_variable(self, sdata_blobs: SpatialData): sdata_blobs.pl.render_labels("blobs_labels", color="channel_0_sum").pl.show() @@ -565,3 +581,17 @@ def test_render_labels_raises_on_3d(scale_factors): sdata.pl.render_labels("lbl3d").pl.show(ax=ax) finally: plt.close(fig) + + +def test_labels_outline_color_groups_filter_aligns(sdata_blobs: SpatialData): + """When `groups` filters the fill labels, the outline vector must be masked alongside it.""" + sdata_blobs = _annotate_labels_with_outline_columns(sdata_blobs) + fig, ax = plt.subplots() + sdata_blobs.pl.render_labels( + "blobs_labels", + color="cluster", + groups=["c1"], + outline_alpha=1.0, + outline_color="stage", + ).pl.show(ax=ax) + plt.close(fig) diff --git a/tests/pl/test_render_shapes.py b/tests/pl/test_render_shapes.py index f77f0ed8..1254bbcb 100644 --- a/tests/pl/test_render_shapes.py +++ b/tests/pl/test_render_shapes.py @@ -33,6 +33,20 @@ # ".png" is appended to , no need to set it +def _annotate_polygons_with_outline_columns(sdata: SpatialData) -> SpatialData: + """Patch the shared blobs fixture so its table annotates ``blobs_polygons``. + + Adds two categorical columns and one continuous column for the outline-color tests. + """ + sdata["table"].obs["region"] = pd.Categorical(["blobs_polygons"] * sdata["table"].n_obs) + sdata["table"].uns["spatialdata_attrs"]["region"] = "blobs_polygons" + n = sdata["table"].n_obs + sdata["table"].obs["cluster"] = pd.Categorical((["c1", "c2"] * ((n + 1) // 2))[:n]) + sdata["table"].obs["stage"] = pd.Categorical((["s1", "s2"] * ((n + 1) // 2))[:n]) + sdata["table"].obs["value"] = np.linspace(0.0, 1.0, n) + return sdata + + class TestShapes(PlotTester, metaclass=PlotTesterMeta): def test_plot_can_render_circles(self, sdata_blobs: SpatialData): sdata_blobs.pl.render_shapes(element="blobs_circles").pl.show() @@ -62,6 +76,46 @@ def test_plot_can_render_polygons_with_rgba_colored_outline(self, sdata_blobs: S element="blobs_polygons", outline_alpha=1, outline_color=(0.0, 1.0, 0.0, 1.0) ).pl.show() + def test_plot_outline_color_by_categorical_obs(self, sdata_blobs: SpatialData): + sdata_blobs = _annotate_polygons_with_outline_columns(sdata_blobs) + sdata_blobs.pl.render_shapes( + "blobs_polygons", color="white", outline_alpha=1, outline_width=3, outline_color="cluster" + ).pl.show() + + def test_plot_outline_color_by_continuous_obs(self, sdata_blobs: SpatialData): + sdata_blobs = _annotate_polygons_with_outline_columns(sdata_blobs) + sdata_blobs.pl.render_shapes( + "blobs_polygons", color="white", outline_alpha=1, outline_width=3, outline_color="value" + ).pl.show() + + def test_plot_outline_color_by_categorical_obs_datashader(self, sdata_blobs: SpatialData): + sdata_blobs = _annotate_polygons_with_outline_columns(sdata_blobs) + sdata_blobs.pl.render_shapes( + "blobs_polygons", + color="white", + outline_alpha=1, + outline_width=3, + outline_color="cluster", + method="datashader", + ).pl.show() + + def test_plot_outline_color_by_continuous_obs_datashader(self, sdata_blobs: SpatialData): + sdata_blobs = _annotate_polygons_with_outline_columns(sdata_blobs) + sdata_blobs.pl.render_shapes( + "blobs_polygons", + color="white", + outline_alpha=1, + outline_width=3, + outline_color="value", + method="datashader", + ).pl.show() + + def test_plot_fill_and_outline_both_obs_columns(self, sdata_blobs: SpatialData): + sdata_blobs = _annotate_polygons_with_outline_columns(sdata_blobs) + sdata_blobs.pl.render_shapes( + "blobs_polygons", color="cluster", outline_alpha=1, outline_width=3, outline_color="stage" + ).pl.show() + def test_plot_can_render_empty_geometry(self, sdata_blobs: SpatialData): sdata_blobs.shapes["blobs_circles"].at[0, "geometry"] = gpd.points_from_xy([None], [None])[0] sdata_blobs.pl.render_shapes().pl.show() @@ -1473,3 +1527,66 @@ def test_render_shapes_datashader_under_bbox_query_does_not_crash(): cropped_sdata.pl.render_shapes("shapes", method="datashader").pl.show(ax=ax) finally: plt.close(fig) + + +def test_outline_color_column_with_two_outlines_raises(sdata_blobs: SpatialData): + sdata_blobs = _annotate_polygons_with_outline_columns(sdata_blobs) + with pytest.raises(ValueError, match="not supported with two outlines"): + sdata_blobs.pl.render_shapes("blobs_polygons", outline_width=(2.0, 0.5), outline_color="cluster") + + +def test_outline_color_column_groups_filter_aligns(sdata_blobs: SpatialData): + """When `groups` filters the fill, the outline vector must be masked alongside it.""" + sdata_blobs = _annotate_polygons_with_outline_columns(sdata_blobs) + fig, ax = plt.subplots() + # This used to raise IndexError when outline vector wasn't filtered with the fill mask + sdata_blobs.pl.render_shapes( + "blobs_polygons", + color="cluster", + groups=["c1"], + outline_width=2, + outline_color="stage", + ).pl.show(ax=ax) + plt.close(fig) + + +def test_outline_color_column_collision_raises(sdata_blobs: SpatialData): + """If `outline_color` is a string that is both a matplotlib color and an obs column, raise.""" + sdata_blobs["table"].obs["region"] = pd.Categorical(["blobs_polygons"] * sdata_blobs["table"].n_obs) + sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_polygons" + n = sdata_blobs["table"].n_obs + # Add an obs column whose name shadows a real color + sdata_blobs["table"].obs["red"] = pd.Categorical((["a", "b"] * ((n + 1) // 2))[:n]) + with pytest.raises(ValueError, match=r"ambiguous|matplotlib color name AND a column"): + sdata_blobs.pl.render_shapes("blobs_polygons", outline_width=2, outline_color="red", outline_alpha=1.0) + + +def test_outline_color_cross_table(sdata_blobs: SpatialData): + """Fill column on table A, outline column on a separate table B.""" + # Patch original table to annotate blobs_polygons with a fill column. + sdata_blobs["table"].obs["region"] = pd.Categorical(["blobs_polygons"] * sdata_blobs["table"].n_obs) + sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_polygons" + n = sdata_blobs["table"].n_obs + sdata_blobs["table"].obs["cluster"] = pd.Categorical((["c1", "c2"] * ((n + 1) // 2))[:n]) + # Build a second table that ALSO annotates blobs_polygons but with a different column. + adata2 = AnnData(get_standard_RNG().normal(size=(n, 2))) + adata2.var = pd.DataFrame({}, index=["g1", "g2"]) + adata2.obs = pd.DataFrame( + { + "instance_id": list(range(n)), + "region": pd.Categorical(["blobs_polygons"] * n), + "stage": pd.Categorical((["s1", "s2"] * ((n + 1) // 2))[:n]), + } + ) + sdata_blobs["table_outline"] = TableModel.parse( + adata=adata2, region_key="region", instance_key="instance_id", region="blobs_polygons" + ) + fig, ax = plt.subplots() + # Don't pin table_name — let validation auto-resolve each column to its annotating table. + sdata_blobs.pl.render_shapes( + "blobs_polygons", + color="cluster", + outline_width=2, + outline_color="stage", + ).pl.show(ax=ax) + plt.close(fig)