From f1127b02152b5a782ac9fd6e5bc45eff8de17275 Mon Sep 17 00:00:00 2001
From: Rajeev Jain
Date: Wed, 13 May 2026 11:09:50 -0500
Subject: [PATCH] add zonal_anomaly with shared face-band weight kernel

---
Reviewer revision notes (this section is not part of the commit message):
- _compute_zonal_anomaly: band means are now taken per slice along the face
  axis instead of over the whole array; NaN band means are no longer replaced
  by 0 when blending; faces that belong to no band are NaN in both modes;
  band edges are validated like in _compute_conservative_zonal_mean_bands.
- zonal_anomaly: the output name is built with an f-string so a non-string
  name does not raise.
- The "index" hashes below are carried over from the original patch.

 uxarray/core/dataarray.py |  66 +++++++++++
 uxarray/core/zonal.py     | 227 ++++++++++++++++++++++++++------------
 2 files changed, 222 insertions(+), 71 deletions(-)

diff --git a/uxarray/core/dataarray.py b/uxarray/core/dataarray.py
index b75199cc7..8ae2315a1 100644
--- a/uxarray/core/dataarray.py
+++ b/uxarray/core/dataarray.py
@@ -24,6 +24,7 @@
 from uxarray.core.zonal import (
     _compute_conservative_zonal_mean_bands,
     _compute_non_conservative_zonal_mean,
+    _compute_zonal_anomaly,
 )
 from uxarray.cross_sections import UxDataArrayCrossSectionAccessor
 from uxarray.formatting_html import array_repr
@@ -767,6 +768,71 @@ def zonal_average(self, lat=(-90, 90, 10), conservative: bool = False, **kwargs)
         """Alias of zonal_mean; prefer `zonal_mean` for primary API."""
         return self.zonal_mean(lat=lat, conservative=conservative, **kwargs)
 
+    def zonal_anomaly(self, lat=(-90, 90, 10), conservative: bool = False):
+        """Compute the zonal anomaly: each face value minus the mean of its latitude band.
+
+        Returns a new ``UxDataArray`` with the same dimensions as the input,
+        where each face holds its original value minus the zonal mean of the
+        latitude band it belongs to.
+
+        Parameters
+        ----------
+        lat : tuple or array-like, default=(-90, 90, 10)
+            Latitude band specification:
+            - tuple (start, end, step): band edges via np.linspace(start, end, n)
+            - array-like: explicit band edges in degrees
+        conservative : bool, default=False
+            If True, uses area-weighted band means and blends across bands for
+            faces that straddle a band boundary, reusing the face-band weight
+            matrix computed for zonal_mean so no geometry is duplicated.
+            If False, assigns each face to a band by its centroid latitude.
+
+        Returns
+        -------
+        UxDataArray
+            Same dimensions as input with per-face band mean subtracted.
+            Faces that belong to no latitude band are NaN.
+
+        Examples
+        --------
+        >>> uxds["var"].zonal_anomaly()
+        >>> uxds["var"].zonal_anomaly(lat=(-60, 60, 5), conservative=True)
+        """
+        if not self._face_centered():
+            raise ValueError(
+                "Zonal anomaly is only supported for face-centered data variables."
+            )
+
+        if isinstance(lat, tuple):
+            start, end, step = lat
+            if step <= 0:
+                raise ValueError("Step size must be positive.")
+            num_points = int(round((end - start) / step)) + 1
+            edges = np.linspace(start, end, num_points)
+            edges = np.clip(edges, -90, 90)
+        elif isinstance(lat, (list, np.ndarray)):
+            edges = np.asarray(lat, dtype=float)
+        else:
+            raise ValueError(
+                "Invalid value for 'lat'. Must be a tuple (start, end, step) or array-like band edges."
+            )
+
+        if edges.ndim != 1 or edges.size < 2:
+            raise ValueError("Band edges must be 1D with at least two values.")
+
+        res = _compute_zonal_anomaly(self, edges, conservative=conservative)
+
+        return UxDataArray(
+            res,
+            dims=self.dims,
+            coords=self.coords,
+            name=f"{self.name}_zonal_anomaly"
+            if self.name is not None
+            else "zonal_anomaly",
+            attrs={"zonal_anomaly": True, "conservative": conservative},
+            uxgrid=self.uxgrid,
+        )
+
     def azimuthal_mean(
         self,
         center_coord,
diff --git a/uxarray/core/zonal.py b/uxarray/core/zonal.py
index 173bb3f30..8569213fa 100644
--- a/uxarray/core/zonal.py
+++ b/uxarray/core/zonal.py
@@ -225,31 +225,25 @@ def _compute_band_overlap_area(
     return area
 
 
-def _compute_conservative_zonal_mean_bands(uxda, bands):
-    """
-    Compute conservative zonal mean over latitude bands.
+def _compute_face_band_weights(uxgrid, bands):
+    """Compute overlap area between every face and every latitude band.
 
-    Uses get_faces_between_latitudes to optimize computation by avoiding
-    overlap area calculations for fully contained faces.
+    Shared geometry kernel used by both zonal_mean and zonal_anomaly so the
+    expensive intersection calculations are never duplicated.
 
     Parameters
     ----------
-    uxda : UxDataArray
-        The data array to compute zonal means for
+    uxgrid : Grid
     bands : array-like
-        Latitude band edges in degrees
+        Latitude band edges in degrees, shape (n_bands + 1,)
 
     Returns
     -------
-    result : array
-        Zonal means for each band
+    W : ndarray, shape (n_face, n_bands)
+        W[f, b] is the overlap area between face f and band b.
+        Fully-contained faces carry their full face area; partially-overlapping
+        faces carry the exact intersection area.
     """
-    import dask.array as da
-
-    uxgrid = uxda.uxgrid
-    face_axis = uxda.get_axis_num("n_face")
-
-    # Pre-compute face properties
     faces_edge_nodes_xyz = _get_cartesian_face_edge_nodes_array(
         uxgrid.face_node_connectivity.values,
         uxgrid.n_face,
@@ -263,24 +257,12 @@ def _compute_conservative_zonal_mean_bands(uxda, bands):
     face_areas = uxgrid.face_areas.values
 
     bands = np.asarray(bands, dtype=float)
-    if bands.ndim != 1 or bands.size < 2:
-        raise ValueError("bands must be 1D with at least two edges")
-
     nb = bands.size - 1
-
-    # Initialize result array
-    shape = list(uxda.shape)
-    shape[face_axis] = nb
-    if isinstance(uxda.data, da.Array):
-        result = da.zeros(shape, dtype=uxda.dtype)
-    else:
-        result = np.zeros(shape, dtype=uxda.dtype)
+    W = np.zeros((uxgrid.n_face, nb), dtype=float)
 
     for bi in range(nb):
         lat0 = float(np.clip(bands[bi], -90.0, 90.0))
         lat1 = float(np.clip(bands[bi + 1], -90.0, 90.0))
-
-        # Ensure lat0 <= lat1
         if lat0 > lat1:
             lat0, lat1 = lat1, lat0
 
@@ -288,55 +270,158 @@ def _compute_conservative_zonal_mean_bands(uxda, bands):
         z1 = np.sin(np.deg2rad(lat1))
         zmin, zmax = (z0, z1) if z0 <= z1 else (z1, z0)
 
-        # Step 1: Get fully contained faces
-        fully_contained_faces = uxgrid.get_faces_between_latitudes((lat0, lat1))
-
-        # Step 2: Get all overlapping faces (including partial)
+        fully_contained = uxgrid.get_faces_between_latitudes((lat0, lat1))
         mask = ~((face_bounds_lat[:, 1] < lat0) | (face_bounds_lat[:, 0] > lat1))
-        all_overlapping_faces = np.nonzero(mask)[0]
+        all_overlapping = np.nonzero(mask)[0]
 
-        if all_overlapping_faces.size == 0:
-            # No faces in this band
-            idx = [slice(None)] * result.ndim
-            idx[face_axis] = bi
-            result[tuple(idx)] = np.nan
+        if all_overlapping.size == 0:
             continue
 
-        # Step 3: Partition faces into fully contained vs partially overlapping
-        is_fully_contained = np.isin(all_overlapping_faces, fully_contained_faces)
-        partially_overlapping_faces = all_overlapping_faces[~is_fully_contained]
-
-        # Step 4: Compute weights
-        all_weights = np.zeros(all_overlapping_faces.size, dtype=float)
-
-        # For fully contained faces, use their full area
-        if fully_contained_faces.size > 0:
-            fully_contained_indices = np.where(is_fully_contained)[0]
-            all_weights[fully_contained_indices] = face_areas[fully_contained_faces]
-
-        # For partially overlapping faces, compute fractional area
-        if partially_overlapping_faces.size > 0:
-            partial_indices = np.where(~is_fully_contained)[0]
-            for i, face_idx in enumerate(partially_overlapping_faces):
-                nedge = n_nodes_per_face[face_idx]
-                face_edges = faces_edge_nodes_xyz[face_idx, :nedge]
-                overlap_area = _compute_band_overlap_area(face_edges, zmin, zmax)
-                all_weights[partial_indices[i]] = overlap_area
-
-        # Step 5: Compute weighted average
-        data_slice = uxda.isel(n_face=all_overlapping_faces, ignore_grid=True).data
-        total_weight = all_weights.sum()
-
-        if total_weight == 0.0:
-            weighted = np.nan * data_slice[..., 0]
-        else:
-            w_shape = [1] * data_slice.ndim
-            w_shape[face_axis] = all_weights.size
-            w_reshaped = all_weights.reshape(w_shape)
-            weighted = (data_slice * w_reshaped).sum(axis=face_axis) / total_weight
+        is_fully_contained = np.isin(all_overlapping, fully_contained)
+
+        fc = all_overlapping[is_fully_contained]
+        W[fc, bi] = face_areas[fc]
+
+        for f in all_overlapping[~is_fully_contained]:
+            nedge = n_nodes_per_face[f]
+            W[f, bi] = _compute_band_overlap_area(
+                faces_edge_nodes_xyz[f, :nedge], zmin, zmax
+            )
+
+    return W
+
+
+def _compute_conservative_zonal_mean_bands(uxda, bands):
+    """Compute conservative zonal mean over latitude bands.
+
+    Parameters
+    ----------
+    uxda : UxDataArray
+    bands : array-like
+        Latitude band edges in degrees
+
+    Returns
+    -------
+    result : array
+        Zonal means for each band, with n_face axis replaced by n_bands
+    """
+    import dask.array as da
+
+    bands = np.asarray(bands, dtype=float)
+    if bands.ndim != 1 or bands.size < 2:
+        raise ValueError("bands must be 1D with at least two edges")
+
+    W = _compute_face_band_weights(uxda.uxgrid, bands)  # (n_face, n_bands)
+    nb = W.shape[1]
+    face_axis = uxda.get_axis_num("n_face")
+
+    shape = list(uxda.shape)
+    shape[face_axis] = nb
+    if isinstance(uxda.data, da.Array):
+        result = da.full(shape, np.nan, dtype=float)
+    else:
+        result = np.full(shape, np.nan, dtype=float)
+
+    for bi in range(nb):
+        overlapping = np.nonzero(W[:, bi] > 0)[0]
+        if overlapping.size == 0:
+            continue
+
+        w = W[overlapping, bi]
+        total = w.sum()
+        if total == 0.0:
+            continue
+
+        data_slice = uxda.isel(n_face=overlapping, ignore_grid=True).data
+        w_shape = [1] * data_slice.ndim
+        w_shape[face_axis] = w.size
+        weighted = (data_slice * w.reshape(w_shape)).sum(axis=face_axis) / total
 
         idx = [slice(None)] * result.ndim
         idx[face_axis] = bi
         result[tuple(idx)] = weighted
 
     return result
+
+
+def _compute_zonal_anomaly(uxda, bands, conservative=False):
+    """Compute zonal anomaly: each face value minus the mean of its latitude band.
+
+    Band means are computed independently for every slice along the non-face
+    dimensions, so multi-dimensional input is supported wherever the
+    ``n_face`` axis sits.
+
+    Parameters
+    ----------
+    uxda : UxDataArray
+        Face-centered data. ``uxda.values`` is read, so lazily evaluated data
+        is loaded into memory.
+    bands : array-like
+        Latitude band edges in degrees, 1D with at least two values
+    conservative : bool
+        If True, uses area-weighted band means and blends across bands for
+        faces that straddle a boundary, using the weight matrix from
+        ``_compute_face_band_weights``.
+        If False, assigns each face to a band by centroid latitude and uses
+        the unweighted mean of the faces in that band.
+
+    Returns
+    -------
+    ndarray
+        Same shape as uxda, with the per-face band mean subtracted. Faces
+        that belong to no band are NaN.
+
+    Raises
+    ------
+    ValueError
+        If ``bands`` is not 1D or has fewer than two edges.
+    """
+    bands = np.asarray(bands, dtype=float)
+    if bands.ndim != 1 or bands.size < 2:
+        raise ValueError("bands must be 1D with at least two edges")
+    nb = bands.size - 1
+    face_axis = uxda.get_axis_num("n_face")
+
+    # Put the face axis last so the per-band reductions collapse faces only
+    # and leave every other dimension of the data intact.
+    data = np.moveaxis(np.asarray(uxda.values, dtype=float), face_axis, -1)
+
+    if conservative:
+        W = _compute_face_band_weights(uxda.uxgrid, bands)  # (n_face, n_bands)
+
+        # Accumulate sum_b W[f, b] * mean_b per face; dividing by the face's
+        # total overlap afterwards gives the area-weighted blend for faces
+        # that straddle a band boundary.
+        blended = np.zeros(data.shape, dtype=float)
+        for bi in range(nb):
+            faces = np.nonzero(W[:, bi] > 0)[0]
+            if faces.size == 0:
+                continue
+            w = W[faces, bi]
+            band_mean = (data[..., faces] * w).sum(axis=-1, keepdims=True) / w.sum()
+            # A NaN band mean reaches only the faces overlapping that band.
+            blended[..., faces] += band_mean * w
+
+        face_totals = W.sum(axis=1)
+        covered = face_totals > 0
+        face_means = np.where(
+            covered, blended / np.where(covered, face_totals, 1.0), np.nan
+        )
+    else:
+        # Centroid-based: fast, no intersection geometry needed
+        edges = np.sort(bands)
+        face_lats = uxda.uxgrid.face_lat.values
+        band_indices = np.digitize(face_lats, edges) - 1
+        # np.digitize bins are half-open: fold centroids lying exactly on the
+        # last edge into the final band instead of dropping them.
+        band_indices[face_lats == edges[-1]] = nb - 1
+
+        # Faces outside [edges[0], edges[-1]] stay NaN rather than being
+        # clipped into (and biasing) the first or last band.
+        face_means = np.full(data.shape, np.nan)
+        for bi in range(nb):
+            faces = np.nonzero(band_indices == bi)[0]
+            if faces.size:
+                face_means[..., faces] = data[..., faces].mean(axis=-1, keepdims=True)
+
+    return np.moveaxis(data - face_means, -1, face_axis)