diff --git a/benchmarks/README.md b/benchmarks/README.md index 9f8903620..6ae1d7d03 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -65,6 +65,32 @@ git checkout - && git stash pop asv compare main HEAD ``` +### Dataloader benchmarks + +Dataloader benchmarks live in `benchmarks/benchmark_dataloader.py`. They use a synthetic in-memory `SpatialData` (2048×2048 image, 500 circle regions) and compute two metrics: + +- `time_init` — constructing `ImageTilesDataset` (includes bounding-box pre-computation). +- `time_fetch` — iterating over all 500 tiles once (pure `__getitem__` calls, no `DataLoader` overhead). + +Run both in your current environment: + +```bash +asv run --python=same --show-stderr -b TimeDataloader +``` + +Or a single method: + +```bash +asv run --python=same --show-stderr -b TimeDataloader.time_init +asv run --python=same --show-stderr -b TimeDataloader.time_fetch +``` + +Compare against `main` in one shot: + +```bash +asv continuous --show-stderr -v -b TimeDataloader main HEAD +``` + ### Querying benchmarks Querying using a bounding box without a spatial index is highly impacted by large amounts of points (transcripts), more than table rows (cells). 
diff --git a/benchmarks/benchmark_dataloader.py b/benchmarks/benchmark_dataloader.py
new file mode 100644
index 000000000..474b658ba
--- /dev/null
+++ b/benchmarks/benchmark_dataloader.py
@@ -0,0 +1,91 @@
+# type: ignore
+"""Benchmarks for ImageTilesDataset: init time and iteration time."""
+
+from __future__ import annotations
+
+import anndata as ad
+import geopandas as gpd
+import numpy as np
+import pandas as pd
+from shapely.geometry import Point
+
+import spatialdata as sd
+from spatialdata.dataloader import ImageTilesDataset
+from spatialdata.models import Image2DModel, ShapesModel, TableModel
+from spatialdata.transformations import Identity
+
+# Seed for the synthetic data. A fresh generator is created on every
+# `_make_sdata()` call so that each asv `setup()` run (one per benchmark and
+# per repeat) builds identical inputs instead of advancing a shared RNG.
+_SEED = 42
+
+_IMAGE_SIZE = 2048
+_N_CIRCLES = 500
+_N_CHANNELS = 3
+
+_DATASET_KWARGS = {
+    "regions_to_images": {"circles": "image"},
+    "regions_to_coordinate_systems": {"circles": "global"},
+    "table_name": "table",
+    "return_annotations": "instance_id",
+}
+
+
+def _make_sdata() -> sd.SpatialData:
+    """Build the synthetic in-memory SpatialData used by the benchmarks.
+
+    The object holds one float32 image of shape
+    ``(_N_CHANNELS, _IMAGE_SIZE, _IMAGE_SIZE)``, ``_N_CIRCLES`` circles of
+    radius 32 whose centers lie at least one radius away from the image border,
+    and a table with one row per circle. All elements use the identity
+    transformation to the "global" coordinate system.
+    """
+    rng = np.random.default_rng(_SEED)
+
+    img_data = rng.integers(0, 256, size=(_N_CHANNELS, _IMAGE_SIZE, _IMAGE_SIZE), dtype=np.uint8).astype(np.float32)
+    image = Image2DModel.parse(img_data, dims=["c", "y", "x"], transformations={"global": Identity()})
+
+    radius = 32.0
+    cx = rng.uniform(radius, _IMAGE_SIZE - radius, size=_N_CIRCLES)
+    cy = rng.uniform(radius, _IMAGE_SIZE - radius, size=_N_CIRCLES)
+    geom = gpd.GeoDataFrame({"geometry": [Point(x, y) for x, y in zip(cx, cy, strict=True)]})
+    geom["radius"] = radius
+    circles = ShapesModel.parse(geom, transformations={"global": Identity()})
+
+    table = ad.AnnData(
+        rng.random((_N_CIRCLES, 10)).astype(np.float32),
+        obs=pd.DataFrame(
+            {
+                "region": pd.Categorical(["circles"] * _N_CIRCLES),
+                "instance_id": np.arange(_N_CIRCLES, dtype=np.int64),
+            },
+            index=[str(i) for i in range(_N_CIRCLES)],
+        ),
+    )
+    table = TableModel.parse(table, region="circles", region_key="region", instance_key="instance_id")
+
+    return sd.SpatialData(images={"image": image}, shapes={"circles": circles}, tables={"table": table})
+
+
+class TimeDataloader:
+    """Time ImageTilesDataset construction and tile iteration."""
+
+    def setup(self):
+        """Build the synthetic data and a ready-to-index dataset (not timed)."""
+        self.sdata = _make_sdata()
+        self.ds = ImageTilesDataset(sdata=self.sdata, **_DATASET_KWARGS)
+
+    def teardown(self):
+        """Drop the references created in `setup`."""
+        del self.ds
+        del self.sdata
+
+    def time_init(self):
+        """Time constructing ImageTilesDataset (bounding-box pre-computation)."""
+        ImageTilesDataset(sdata=self.sdata, **_DATASET_KWARGS)
+
+    def time_fetch(self):
+        """Time iterating over every tile once."""
+        # Index explicitly so that only `__getitem__` is exercised.
+        for i in range(len(self.ds)):
+            _ = self.ds[i]
diff --git a/src/spatialdata/_core/query/_utils.py b/src/spatialdata/_core/query/_utils.py
index f66679ca2..d21ee63e2 100644
--- a/src/spatialdata/_core/query/_utils.py
+++ b/src/spatialdata/_core/query/_utils.py
@@ -91,11 +91,11 @@ def get_bounding_box_corners(
     return output.squeeze().drop_vars("box")
 
 
-@nb.jit(parallel=False, nopython=True)
+@nb.njit(parallel=False)
 def _create_slices_and_translation(
-    min_values: nb.types.Array,
-    max_values: nb.types.Array,
-) -> tuple[nb.types.Array, nb.types.Array]:
+    min_values: np.ndarray,
+    max_values: np.ndarray,
+) -> tuple[np.ndarray, np.ndarray]:
     n_boxes, n_dims = min_values.shape
     slices = np.empty((n_boxes, n_dims, 2), dtype=np.float64)  # (n_boxes, n_dims, [min, max])
     translation_vectors = np.empty((n_boxes, n_dims), dtype=np.float64)  # (n_boxes, n_dims)
diff --git a/src/spatialdata/_core/query/spatial_query.py b/src/spatialdata/_core/query/spatial_query.py
index 465fdf665..475c36f4f 100644
--- a/src/spatialdata/_core/query/spatial_query.py
+++ b/src/spatialdata/_core/query/spatial_query.py
@@ -556,7 +556,7 @@ def _(
     min_coordinate = _parse_list_into_array(min_coordinate)
     max_coordinate = _parse_list_into_array(max_coordinate)
 
-    # for triggering validation
+    # for triggering validation (handles both 1-D single-box and 2-D multi-box arrays)
     _ = BoundingBoxRequest(
         target_coordinate_system=target_coordinate_system,
         axes=axes,
diff --git
a/src/spatialdata/dataloader/datasets.py b/src/spatialdata/dataloader/datasets.py index 6a105b681..03879abc8 100644 --- a/src/spatialdata/dataloader/datasets.py +++ b/src/spatialdata/dataloader/datasets.py @@ -127,7 +127,9 @@ def __init__( from spatialdata import bounding_box_query from spatialdata._core.operations.rasterize import rasterize as rasterize_fn - self._validate(sdata, regions_to_images, regions_to_coordinate_systems, return_annotations, table_name) + self.sdata = sdata + self._rasterize = rasterize + self._validate(regions_to_images, regions_to_coordinate_systems, return_annotations, table_name) self._preprocess(tile_scale, tile_dim_in_units, rasterize, table_name) if rasterize_kwargs is not None and len(rasterize_kwargs) > 0 and rasterize is False: @@ -151,14 +153,12 @@ def __init__( def _validate( self, - sdata: SpatialData, regions_to_images: dict[str, str], regions_to_coordinate_systems: dict[str, str], return_annotations: str | list[str] | None, table_name: str | None, ) -> None: """Validate input parameters.""" - self.sdata = sdata if return_annotations is not None and table_name is None: raise ValueError("`table_name` must be provided if `return_annotations` is not `None`.") @@ -173,8 +173,8 @@ def _validate( image_name = regions_to_images[region_name] # get elements - region_elem = sdata[region_name] - image_elem = sdata[image_name] + region_elem = self.sdata[region_name] + image_elem = self.sdata[image_name] # check that the elements are supported if get_model(region_elem) == PointsModel: @@ -199,13 +199,13 @@ def _validate( ) if table_name is not None: - _, region_key, instance_key = get_table_keys(sdata.tables[table_name]) + _, region_key, instance_key = get_table_keys(self.sdata.tables[table_name]) if get_model(region_elem) in [Labels2DModel, Labels3DModel]: indices = get_element_instances(region_elem).tolist() else: indices = region_elem.index.tolist() - table = sdata.tables[table_name] - if not 
isinstance(sdata.tables[table_name].obs[region_key].dtype, CategoricalDtype): + table = self.sdata.tables[table_name] + if not isinstance(self.sdata.tables[table_name].obs[region_key].dtype, CategoricalDtype): raise TypeError( f"The `regions_element` column `{region_key}` in the table must be a categorical dtype. " f"Please convert it." @@ -228,8 +228,10 @@ def _preprocess( table_name: str | None, ) -> None: """Preprocess the dataset.""" + from spatialdata import bounding_box_query + if table_name is not None: - _, region_key, instance_key = get_table_keys(self.sdata.tables[table_name]) + _, region_key, _ = get_table_keys(self.sdata.tables[table_name]) filtered_table = self.sdata.tables[table_name][ self.sdata.tables[table_name].obs[region_key].isin(self.regions) ] # filtered table for the data loader @@ -249,6 +251,18 @@ def _preprocess( tile_scale=tile_scale, tile_dim_in_units=tile_dim_in_units, ) + if not rasterize: + # Pre-compute all per-tile slice selections in a single vectorized call. + # Passing 2-D min/max arrays triggers the multi-box path in bounding_box_query, + # which returns a list of {axis: slice} dicts — one per tile. 
+ tile_coords["selection"] = bounding_box_query( + self.sdata[image_name], + ("x", "y"), + min_coordinate=tile_coords[["minx", "miny"]].values, + max_coordinate=tile_coords[["maxx", "maxy"]].values, + target_coordinate_system=cs, + return_request_only=True, + ) tile_coords_df.append(tile_coords) inst = circles.index.values @@ -276,7 +290,7 @@ def _preprocess( self.dataset_index = pd.concat(index_df).reset_index(drop=True) assert len(self.tiles_coords) == len(self.dataset_index) if table_name: - self.dataset_table = ad.concat(*tables_l) + self.dataset_table = ad.concat(tables_l) assert len(self.tiles_coords) == len(self.dataset_table) dims_ = set(chain(*dims_l)) @@ -356,13 +370,17 @@ def __getitem__(self, idx: int) -> Any | SpatialData: t_coords = self.tiles_coords.iloc[idx] image = self.sdata[row["image"]] - tile = self._crop_image( - image, - axes=tuple(self.dims), - min_coordinate=t_coords[[f"min{i}" for i in self.dims]].values, - max_coordinate=t_coords[[f"max{i}" for i in self.dims]].values, - target_coordinate_system=row["cs"], - ) + if self._rasterize: + tile = self._crop_image( + image, + axes=tuple(self.dims), + min_coordinate=t_coords[[f"min{i}" for i in self.dims]].values, + max_coordinate=t_coords[[f"max{i}" for i in self.dims]].values, + target_coordinate_system=row["cs"], + ) + else: + # Use pre-computed slice selection (vectorized at init time). + tile = image.sel(t_coords["selection"]) if self.transform is not None: out = self._return(idx, tile) return self.transform(out)