Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions hrds/hrds.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ class HRDS():

"""
def __init__(self, baseRaster, rasters=None, distances=None,
buffers=None, minmax=None, saveBuffers=False):
buffers=None, minmax=None, saveBuffers=False, global_data=False):
"""
Set up our hrds object

Expand All @@ -104,6 +104,7 @@ def __init__(self, baseRaster, rasters=None, distances=None,
saveBuffers: boolean to save buffers if needed
"""

self.global_data = global_data
if rasters is None:
# single raster only, check everything else is none
if (distances is not None and
Expand Down Expand Up @@ -132,9 +133,9 @@ def __init__(self, baseRaster, rasters=None, distances=None,
"and I expected: "+str(len(rasters)+1))

if minmax is None:
self.baseRaster = RasterInterpolator(baseRaster)
self.baseRaster = RasterInterpolator(baseRaster, periodic=global_data)
else:
self.baseRaster = RasterInterpolator(baseRaster, minmax[0])
self.baseRaster = RasterInterpolator(baseRaster, minmax[0], periodic=global_data)
self.raster_stack = []
if (rasters is not None):
for i, r in enumerate(rasters):
Expand Down
62 changes: 43 additions & 19 deletions hrds/raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ class Interpolator():
may switch bands and hence have to reload the val data.
"""

def __init__(self, origin, delta, val, mask=None, minmax=None):
def __init__(self, origin, delta, val, mask=None, minmax=None, periodic=False):
"""
Init our Interpolator

Expand All @@ -73,6 +73,11 @@ def __init__(self, origin, delta, val, mask=None, minmax=None):
self.val = val
self.mask = mask
self.minmax = minmax
# Handle periodic input as a tuple/list for [X_periodic, Y_periodic]
if isinstance(periodic, bool):
self.periodic = [periodic, periodic]
else:
self.periodic = periodic

def set_mask(self, mask):
"""
Expand All @@ -96,34 +101,49 @@ def get_val(self, point):
RasterInterpolatorError: Generic error interpolating data at
that point
"""
if len(self.val.shape) != 2:
raise RasterInterpolatorError("Field to interpolate should have 2 dimensions")

shape_i, shape_j = self.val.shape


yhat = ((point[0]+(self.delta[0]/2.0)-self.origin[0])/self.delta[0])
xhat = ((point[1]+(self.delta[1]/2.0)-self.origin[1])/self.delta[1])

j = int(math.floor(yhat))-1
i = int(math.floor(xhat))-1
# this is not caught as an IndexError below, because of wrapping of
# negative indices
if i < 0 or j < 0:
raise CoordinateError("Coordinate out of range", point, i, j)
# Check for out-of-bounds on non-periodic boundaries
if not self.periodic[0]: # X axis (j)
if j < 0 or j >= shape_j - 1:
raise CoordinateError("Coordinate out of range on X axis", point, i, j)
if not self.periodic[1]: # Y axis (i)
if i < 0 or i >= shape_i - 1:
raise CoordinateError("Coordinate out of range on Y axis", point, i, j)

alpha = (xhat) % 1.0
beta = (yhat) % 1.0
neigh_i = i+1
neigh_j = j+1
if neigh_i < 0 or neigh_j < 0:
raise CoordinateError("Coordinate out of range", point, i, j)

if self.periodic[0]: # X axis
j = j % shape_j
neigh_j = (j + 1) % shape_j

if self.periodic[1]: # Y axis
i = i % shape_i
neigh_i = (i + 1) % shape_i

try:
if self.mask is not None:
# case with a land mask - masks not yet implemented!
w00 = (1.0-alpha)*(1.0-beta)*self.mask[i, j]
w10 = alpha*(1.0-beta)*self.mask[i+1, j]
w01 = (1.0-alpha)*beta*self.mask[i, j+1]
w11 = alpha*beta*self.mask[i+1, j+1]
if len(self.val.shape) == 2:
value = w00*self.val[i, j] + w10*self.val[i+1, j] \
+ w01*self.val[i, j+1] + w11*self.val[i+1, j+1]
else:
raise RasterInterpolatorError("Field to interpolate,"
"should have 2 dimensions")
value = w00*self.val[i, j] + w10*self.val[i+1, j] \
+ w01*self.val[i, j+1] + w11*self.val[i+1, j+1]
sumw = w00+w10+w01+w11

if sumw > 0.0:
Expand Down Expand Up @@ -183,7 +203,7 @@ class RasterInterpolator(object):
calls of set_band().

"""
def __init__(self, filename, minmax=None):
def __init__(self, filename, minmax=None, periodic=False):
"""
Init our RasterInterpolator

Expand All @@ -205,6 +225,7 @@ def __init__(self, filename, minmax=None):
self.dx = 0.0
self.nodata = None
self.minmax = minmax
self.periodic = periodic

def get_extent(self):
"""Return list of corner coordinates from a geotransform
Expand Down Expand Up @@ -249,7 +270,7 @@ def set_band(self, band_no=1):
transform = self.ds.GetGeoTransform()
self.dx = [transform[1], -transform[5]]
self.interpolator = Interpolator(origin, self.dx, self.val,
self.mask, self.minmax)
self.mask, self.minmax, self.periodic)

def get_array(self):
"""
Expand Down Expand Up @@ -298,11 +319,14 @@ def point_in(self, point):
Boolean. True if point is in the raster. False otherwise.
"""

# does this point occur in the raster?
# If both axes are periodic, it covers the wrapped universe
periodic_x = self.periodic if isinstance(self.periodic, bool) else self.periodic[0]
periodic_y = self.periodic if isinstance(self.periodic, bool) else self.periodic[1]

llc = np.amin(self.extent, axis=0)+(self.dx[0]/2)
urc = np.amax(self.extent, axis=0)-(self.dx[1]/2)
if ((point[0] <= urc[0] and point[0] >= llc[0]) and
(point[1] <= urc[1] and point[1] >= llc[1])):
return True
else:
return False

in_x = periodic_x or (point[0] <= urc[0] and point[0] >= llc[0])
in_y = periodic_y or (point[1] <= urc[1] and point[1] >= llc[1])

return in_x and in_y
Loading