From 84de5cf71ef8357cfe8d19e012030e3185278dcc Mon Sep 17 00:00:00 2001 From: Jon Hill Date: Fri, 12 Jun 2026 15:43:08 +0100 Subject: [PATCH] Adding experimental periodic interpolation for the main raster for global datasets --- hrds/hrds.py | 7 +++--- hrds/raster.py | 62 ++++++++++++++++++++++++++++++++++---------------- 2 files changed, 47 insertions(+), 22 deletions(-) diff --git a/hrds/hrds.py b/hrds/hrds.py index a265b11..50e3fe0 100644 --- a/hrds/hrds.py +++ b/hrds/hrds.py @@ -89,7 +89,7 @@ class HRDS(): """ def __init__(self, baseRaster, rasters=None, distances=None, - buffers=None, minmax=None, saveBuffers=False): + buffers=None, minmax=None, saveBuffers=False, global_data=False): """ Set up our hrds object @@ -104,6 +104,7 @@ def __init__(self, baseRaster, rasters=None, distances=None, saveBuffers: boolean to save buffers if needed """ + self.global_data = global_data if rasters is None: # single raster only, check everything else is none if (distances is not None and @@ -132,9 +133,9 @@ def __init__(self, baseRaster, rasters=None, distances=None, "and I expected: "+str(len(rasters)+1)) if minmax is None: - self.baseRaster = RasterInterpolator(baseRaster) + self.baseRaster = RasterInterpolator(baseRaster, periodic=global_data) else: - self.baseRaster = RasterInterpolator(baseRaster, minmax[0]) + self.baseRaster = RasterInterpolator(baseRaster, minmax[0], periodic=global_data) self.raster_stack = [] if (rasters is not None): for i, r in enumerate(rasters): diff --git a/hrds/raster.py b/hrds/raster.py index d915406..d7299ea 100644 --- a/hrds/raster.py +++ b/hrds/raster.py @@ -54,7 +54,7 @@ class Interpolator(): may switch bands and hence have to reload the val data. """ - def __init__(self, origin, delta, val, mask=None, minmax=None): + def __init__(self, origin, delta, val, mask=None, minmax=None, periodic=False): """ Init our Interpolator @@ -73,6 +73,11 @@ def __init__(self, origin, delta, val, mask=None, minmax=None): self.val = val self.mask = mask self.minmax = minmax + # Handle periodic input as a tuple/list for [X_periodic, Y_periodic] + if isinstance(periodic, bool): + self.periodic = [periodic, periodic] + else: + self.periodic = periodic def set_mask(self, mask): """ @@ -96,21 +101,40 @@ def get_val(self, point): RasterInterpolatorError: Generic error interpolating data at that point """ + if len(self.val.shape) != 2: + raise RasterInterpolatorError("Field to interpolate should have 2 dimensions") + + shape_i, shape_j = self.val.shape + yhat = ((point[0]+(self.delta[0]/2.0)-self.origin[0])/self.delta[0]) xhat = ((point[1]+(self.delta[1]/2.0)-self.origin[1])/self.delta[1]) + j = int(math.floor(yhat))-1 i = int(math.floor(xhat))-1 # this is not caught as an IndexError below, because of wrapping of # negative indices - if i < 0 or j < 0: - raise CoordinateError("Coordinate out of range", point, i, j) + # Check for out-of-bounds on non-periodic boundaries + if not self.periodic[0]: # X axis (j) + if j < 0 or j >= shape_j - 1: + raise CoordinateError("Coordinate out of range on X axis", point, i, j) + if not self.periodic[1]: # Y axis (i) + if i < 0 or i >= shape_i - 1: + raise CoordinateError("Coordinate out of range on Y axis", point, i, j) + alpha = (xhat) % 1.0 beta = (yhat) % 1.0 neigh_i = i+1 neigh_j = j+1 - if neigh_i < 0 or neigh_j < 0: - raise CoordinateError("Coordinate out of range", point, i, j) + + if self.periodic[0]: # X axis + j = j % shape_j + neigh_j = (j + 1) % shape_j + + if self.periodic[1]: # Y axis + i = i % shape_i + neigh_i = (i + 1) % shape_i + try: if self.mask is not None: # case with a land mask - masks not yet implemented! @@ -118,12 +142,8 @@ def get_val(self, point): w10 = alpha*(1.0-beta)*self.mask[i+1, j] w01 = (1.0-alpha)*beta*self.mask[i, j+1] w11 = alpha*beta*self.mask[i+1, j+1] - if len(self.val.shape) == 2: - value = w00*self.val[i, j] + w10*self.val[i+1, j] \ - + w01*self.val[i, j+1] + w11*self.val[i+1, j+1] - else: - raise RasterInterpolatorError("Field to interpolate," - "should have 2 dimensions") + value = w00*self.val[i, j] + w10*self.val[i+1, j] \ + + w01*self.val[i, j+1] + w11*self.val[i+1, j+1] sumw = w00+w10+w01+w11 if sumw > 0.0: @@ -183,7 +203,7 @@ class RasterInterpolator(object): calls of set_band(). """ - def __init__(self, filename, minmax=None): + def __init__(self, filename, minmax=None, periodic=False): """ Init our RasterInterpolator @@ -205,6 +225,7 @@ def __init__(self, filename, minmax=None): self.dx = 0.0 self.nodata = None self.minmax = minmax + self.periodic = periodic def get_extent(self): """Return list of corner coordinates from a geotransform @@ -249,7 +270,7 @@ def set_band(self, band_no=1): transform = self.ds.GetGeoTransform() self.dx = [transform[1], -transform[5]] self.interpolator = Interpolator(origin, self.dx, self.val, - self.mask, self.minmax) + self.mask, self.minmax, self.periodic) def get_array(self): """ @@ -298,11 +319,14 @@ def point_in(self, point): Boolean. True if point is in the raster. False otherwise. """ - # does this point occur in the raster? + # If both axes are periodic, it covers the wrapped universe + periodic_x = self.periodic if isinstance(self.periodic, bool) else self.periodic[0] + periodic_y = self.periodic if isinstance(self.periodic, bool) else self.periodic[1] + llc = np.amin(self.extent, axis=0)+(self.dx[0]/2) urc = np.amax(self.extent, axis=0)-(self.dx[1]/2) - if ((point[0] <= urc[0] and point[0] >= llc[0]) and - (point[1] <= urc[1] and point[1] >= llc[1])): - return True - else: - return False + + in_x = periodic_x or (point[0] <= urc[0] and point[0] >= llc[0]) + in_y = periodic_y or (point[1] <= urc[1] and point[1] >= llc[1]) + + return in_x and in_y