diff --git a/src/spatialdata_plot/pl/render.py b/src/spatialdata_plot/pl/render.py index eb39824d..1c9b97eb 100644 --- a/src/spatialdata_plot/pl/render.py +++ b/src/spatialdata_plot/pl/render.py @@ -411,9 +411,11 @@ def _render_shapes( _check_obs_var_shadow(sdata, element, col_for_color, render_params.table_name) + # filter_tables=False: join_spatialelement_table below overwrites the table, + # so the cs-level sparse copy is wasted work. sdata_filt = sdata.filter_by_coordinate_system( coordinate_system=coordinate_system, - filter_tables=bool(render_params.table_name), + filter_tables=False, ) table_name = render_params.table_name @@ -425,11 +427,19 @@ def _render_shapes( # Workaround for upstream spatialdata bug (scverse/spatialdata#1099): # join_spatialelement_table calls table.obs.reset_index() which fails when - # the obs index name matches an existing column (e.g. "EntityID" in Merfish data). - # Temporarily drop the conflicting index name for the join, then restore it. + # the obs index name matches an existing column (e.g. "EntityID" in Merfish + # data). When that collision is present, the obs index may also be a + # non-RangeIndex of int dtype, which AnnData's `_normalize_index` rejects + # when the join indexes back into the table. Temporarily swap to a clean + # RangeIndex / drop the conflicting name; restore on exit. _obs = sdata[table_name].obs _saved_index_name = _obs.index.name - if _saved_index_name is not None and _saved_index_name in _obs.columns: + _saved_index: pd.Index | None = None + _name_collides = _saved_index_name is not None and _saved_index_name in _obs.columns + if _name_collides and not isinstance(_obs.index, pd.RangeIndex): + _saved_index = _obs.index + _obs.index = pd.RangeIndex(len(_obs)) + elif _name_collides: _obs.index.name = None try: @@ -437,6 +447,8 @@ def _render_shapes( sdata, spatial_element_names=element, table_name=table_name, how="inner" ) finally: + if _saved_index is not None: + _obs.index = _saved_index _obs.index.name = _saved_index_name sdata_filt[element] = shapes = element_dict[element] joined_table.uns["spatialdata_attrs"]["region"] = ( @@ -1748,9 +1760,11 @@ def _render_labels( _check_obs_var_shadow(sdata, element, col_for_color, table_name) + # filter_tables=False: match_table_to_element below already filters per + # element, so the cs-level sparse copy is wasted work. sdata_filt = sdata.filter_by_coordinate_system( coordinate_system=coordinate_system, - filter_tables=bool(table_name), + filter_tables=False, ) label = sdata_filt.labels[element]