diff --git a/CHANGELOG.md b/CHANGELOG.md index d35d410..e016b70 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,15 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [0.4.4] - 2026-06-25 + +### Added +- support for projected input rasters in the raster command +- add 'save_shift' to transform_raster api/cli + +### Changed +- retry failed downloads, such as FES + ## [0.4.3] - 2026-05-04 ### Added diff --git a/CITATION.cff b/CITATION.cff index 50e8bbd..0eea3be 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -14,6 +14,6 @@ authors: website: https://transformez.readthedocs.io title: "Transformez" -version: 0.4.3 -date-released: 2026-06-04 +version: 0.5.0 +date-released: 2026-06-06 url: "https://github.com/continuous-dems/transformez" diff --git a/README.md b/README.md index 3bea32a..de2cc40 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@
Global vertical datum transformations, simplified.
-
+
diff --git a/src/transformez/api.py b/src/transformez/api.py
index 47063f2..5d6df26 100644
--- a/src/transformez/api.py
+++ b/src/transformez/api.py
@@ -113,6 +113,8 @@ def generate_grid(
increment: Union[str, float],
datum_in: str,
datum_out: str,
+ epoch_in: str = "2010.0",
+ epoch_out: str = "2010.0",
decay_pixels: int = 100,
out_fn: Optional[str] = None,
cache_dir: Optional[str] = None,
@@ -126,6 +128,8 @@ def generate_grid(
increment: Resolution (e.g., '3s' or 0.0008333).
datum_in: Source datum (e.g., 'mllw', '5703').
datum_out: Target datum (e.g., '4979', '6319').
+ epoch_in: Source epoch (e.g., '2010.0').
+ epoch_out: Target epoch (e.g., '2010.0').
decay_pixels: Set the pixel decay in case extrapolation is required.
out_fn: If provided, saves the grid to this file (.tif or .gtx).
cache_dir: Path to store downloaded grids.
@@ -167,6 +171,8 @@ def generate_grid(
epsg_out=epsg_out,
geoid_in=geoid_in,
geoid_out=geoid_out,
+ epoch_in=epoch_in,
+ epoch_out=epoch_out,
decay_pixels=decay_pixels,
cache_dir=cache_dir,
use_stations=use_stations,
@@ -196,6 +202,7 @@ def transform_raster(
z_unit_in: Optional[str] = "auto",
z_unit_out: Optional[str] = "auto",
use_stations: bool = False,
+ save_shift: bool = False,
verbose: bool = False,
) -> Optional[str]:
"""Apply a vertical datum transformation directly to an existing raster file.
@@ -209,6 +216,8 @@ def transform_raster(
cache_dir: Path to store downloaded grids.
z_unit_in: Input DEM z units.
z_unit_out: Output DEM z units.
+ use_stations: Generate the shift grid from available tide stations.
+ save_shift: Save the generated shift raster to disk.
verbose: Enable debug logging.
Returns:
@@ -216,16 +225,48 @@ def transform_raster(
"""
import rasterio
+ from rasterio.warp import transform_bounds, reproject, Resampling
+ from rasterio.transform import from_bounds
+ from fetchez.spatial import Region
if not os.path.exists(input_raster):
logger.error(f"Input raster not found: {input_raster}")
return None
with rasterio.open(input_raster) as src:
- bounds = src.bounds
- region_obj = Region(bounds.left, bounds.right, bounds.bottom, bounds.top)
+ native_crs = src.crs
+ native_bounds = src.bounds
+ native_transform = src.transform
nx, ny = src.width, src.height
+ is_projected = native_crs.is_projected if native_crs else False
+ if is_projected:
+ logger.info(
+ f"Projected CRS detected ({native_crs}). Extracting WGS84 envelope..."
+ )
+ w, s, e, n = transform_bounds(native_crs, "EPSG:4326", *native_bounds)
+
+ buffer = 0.05
+ region_obj = Region(w - buffer, e + buffer, s - buffer, n + buffer)
+ logger.info(f"Using WGS84 region: {region_obj}")
+
+ inc_deg = 3.0 / 3600.0
+ vt_nx = int((region_obj.xmax - region_obj.xmin) / inc_deg)
+ vt_ny = int((region_obj.ymax - region_obj.ymin) / inc_deg)
+ else:
+ region_obj = Region(
+ native_bounds.left,
+ native_bounds.right,
+ native_bounds.bottom,
+ native_bounds.top,
+ )
+ vt_nx, vt_ny = nx, ny
+
+ # with rasterio.open(input_raster) as src:
+ # bounds = src.bounds
+ # region_obj = Region(bounds.left, bounds.right, bounds.bottom, bounds.top)
+ # nx, ny = src.width, src.height
+
epsg_in, geoid_in = _parse_datum(datum_in)
epsg_out, geoid_out = _parse_datum(datum_out)
@@ -266,6 +307,47 @@ def transform_raster(
logger.error("Failed to generate shift array for the raster bounds.")
return None
+ if is_projected:
+ logger.info("Warping shift grid back to native raster projection...")
+ wgs_transform = from_bounds(
+ region_obj.xmin,
+ region_obj.ymin,
+ region_obj.xmax,
+ region_obj.ymax,
+ vt_nx,
+ vt_ny,
+ )
+ native_shift_array = np.zeros((ny, nx), dtype=np.float32)
+
+ reproject(
+ source=shift_array,
+ destination=native_shift_array,
+ src_transform=wgs_transform,
+ src_crs="EPSG:4326",
+ dst_transform=native_transform,
+ dst_crs=native_crs,
+ resampling=Resampling.bilinear,
+ )
+
+ shift_array = native_shift_array
+
+ if save_shift:
+ shift_fn = f"{os.path.splitext(output_raster)[0]}_shiftgrid.tif"
+ logger.info(f"Saving aligned shift grid to {shift_fn}...")
+ with rasterio.open(
+ shift_fn,
+ "w",
+ driver="GTiff",
+ height=ny,
+ width=nx,
+ count=1,
+ dtype=shift_array.dtype,
+ crs=native_crs,
+ transform=native_transform,
+ nodata=-9999.0,
+ ) as dst:
+ dst.write(shift_array, 1)
+
success = GridEngine.apply_vertical_shift(
input_raster,
shift_array,
diff --git a/src/transformez/cli.py b/src/transformez/cli.py
index c27efff..0f4211d 100644
--- a/src/transformez/cli.py
+++ b/src/transformez/cli.py
@@ -302,6 +302,11 @@ def transform_grid(
is_flag=True,
help="Force RBF interpolation using live tide stations instead of global satellite models.",
)
+@click.option(
+ "--save-shift",
+ is_flag=True,
+ help="Save the aligned vertical shift grid to disk alongside the output DEM.",
+)
def transform_raster(
input_file,
input_datum,
@@ -311,6 +316,7 @@ def transform_raster(
out,
decay_pixels,
use_stations,
+ save_shift,
):
"""Apply a vertical datum shift to an existing DEM."""
@@ -326,6 +332,7 @@ def transform_raster(
z_unit_in=in_units,
z_unit_out=out_units,
use_stations=use_stations,
+ save_shift=save_shift,
verbose=True,
)
diff --git a/src/transformez/grid_engine.py b/src/transformez/grid_engine.py
index b75ad6b..484397f 100644
--- a/src/transformez/grid_engine.py
+++ b/src/transformez/grid_engine.py
@@ -27,6 +27,12 @@
logger = logging.getLogger(__name__)
+class GridCorruptionError(Exception):
+ """Raised when a fetched grid is corrupted and needs re-downloading."""
+
+ pass
+
+
def plot_grid(grid_array, region, title="Vertical Shift Preview"):
"""Plot the transformation grid using Matplotlib."""
@@ -168,6 +174,34 @@ def load_and_interpolate(source_files, target_region, nx, ny, decay_pixels=100):
valid_mask = ~np.isnan(temp_buffer)
mosaic[valid_mask] = temp_buffer[valid_mask]
+ except Exception as e:
+ error_msg = str(e)
+
+ if any(
+ err in error_msg
+ for err in [
+ "-101",
+ "HDF error",
+ "not recognized as a supported file format",
+ "RasterioIOError",
+ ]
+ ):
+ logger.error(f" CRITICAL: Corrupted grid chunk detected in {fn}!")
+
+ real_path = fn.split(":")[1] if fn.startswith("netcdf:") else fn
+ if os.path.exists(real_path):
+ logger.warning(
+ f"Auto-deleting corrupted cache file to force re-fetch: {real_path}"
+ )
+ try:
+ os.remove(real_path)
+ except OSError:
+ pass
+
+ raise GridCorruptionError(f"Corrupted file deleted: {real_path}")
+
+ logger.exception(f"Failed to reproject {fn}: {e}")
+ continue
# except Exception as e:
# error_msg = str(e)
@@ -189,9 +223,9 @@ def load_and_interpolate(source_files, target_region, nx, ny, decay_pixels=100):
# # For all other normal errors, log and continue as usual
# logger.exception(f"Failed to reproject {fn}: {e}")
# continue
- except Exception as e:
- logger.exception(f"Failed to reproject {fn}: {e}")
- continue
+ # except Exception as e:
+ # logger.exception(f"Failed to reproject {fn}: {e}")
+ # continue
# Fill inland areas (decaying to 0) before we clear the remaining NaNs
# mosaic = GridEngine.fill_nans(mosaic, decay_pixels=decay_pixels)
diff --git a/src/transformez/transform.py b/src/transformez/transform.py
index c102f3f..4c0d965 100644
--- a/src/transformez/transform.py
+++ b/src/transformez/transform.py
@@ -207,7 +207,9 @@ def fetch_grid_(self, module_name, **kwargs):
return valid
- def _get_grid(self, provider, name):
+ def _get_grid(self, provider, name, max_retries=3):
+
+ from .grid_engine import GridCorruptionError
if not name:
return np.zeros((self.ny, self.nx))
@@ -218,95 +220,116 @@ def _get_grid(self, provider, name):
if "geoid=" in name:
name = name.split("=")[1]
- files = self.fetch_grid(provider, datatype=name, query=name)
- if provider == "vdatum":
- import rasterio
- from datetime import datetime
-
- def get_vdatum_date(gtx_path):
- """Finds and parses the release date from VDatum metadata files."""
+ for attempt in range(max_retries):
+ files = self.fetch_grid(provider, datatype=name, query=name)
+ if provider == "vdatum":
+ import rasterio
+ from datetime import datetime
- dir_name = os.path.dirname(gtx_path)
- meta_files = [
- f for f in os.listdir(dir_name) if f.endswith((".met", ".inf"))
- ]
+ def get_vdatum_date(gtx_path):
+ """Finds and parses the release date from VDatum metadata files."""
- if not meta_files:
- return datetime(1970, 1, 1)
-
- meta_path = os.path.join(dir_name, meta_files[0])
- try:
- with open(meta_path, "r") as f:
- content = f.read().splitlines()
+ dir_name = os.path.dirname(gtx_path)
+ meta_files = [
+ f for f in os.listdir(dir_name) if f.endswith((".met", ".inf"))
+ ]
- if not content:
+ if not meta_files:
return datetime(1970, 1, 1)
- first_line = content[0]
-
- # Parse the first line "#Mon Jul 08 10:27:07 EDT 2019"
- if first_line.startswith("#"):
- parts = first_line.replace("#", "").split()
- if len(parts) >= 6:
- year = int(parts[-1])
- day = int(parts[2])
- month_map = {
- "Jan": 1,
- "Feb": 2,
- "Mar": 3,
- "Apr": 4,
- "May": 5,
- "Jun": 6,
- "Jul": 7,
- "Aug": 8,
- "Sep": 9,
- "Oct": 10,
- "Nov": 11,
- "Dec": 12,
- }
- month = month_map.get(parts[1][:3].title(), 1)
- return datetime(year, month, day)
-
- # Fallback to scanning for "released_date="
- for line in content:
- if "released_date=" in line:
- date_str = line.split("=")[1].strip()
- m, d, y = map(int, date_str.split("/"))
- return datetime(y, m, d)
+ meta_path = os.path.join(dir_name, meta_files[0])
+ try:
+ with open(meta_path, "r") as f:
+ content = f.read().splitlines()
+
+ if not content:
+ return datetime(1970, 1, 1)
+
+ first_line = content[0]
+
+ # Parse the first line "#Mon Jul 08 10:27:07 EDT 2019"
+ if first_line.startswith("#"):
+ parts = first_line.replace("#", "").split()
+ if len(parts) >= 6:
+ year = int(parts[-1])
+ day = int(parts[2])
+ month_map = {
+ "Jan": 1,
+ "Feb": 2,
+ "Mar": 3,
+ "Apr": 4,
+ "May": 5,
+ "Jun": 6,
+ "Jul": 7,
+ "Aug": 8,
+ "Sep": 9,
+ "Oct": 10,
+ "Nov": 11,
+ "Dec": 12,
+ }
+ month = month_map.get(parts[1][:3].title(), 1)
+ return datetime(year, month, day)
+
+ # Fallback to scanning for "released_date="
+ for line in content:
+ if "released_date=" in line:
+ date_str = line.split("=")[1].strip()
+ m, d, y = map(int, date_str.split("/"))
+ return datetime(y, m, d)
+
+ except Exception as e:
+ logger.debug(f"Failed to parse date from {meta_path}: {e}")
- except Exception as e:
- logger.debug(f"Failed to parse date from {meta_path}: {e}")
+ return datetime(1970, 1, 1)
- return datetime(1970, 1, 1)
+ def sort_key(filepath):
+ # Time Sorting (Oldest -> Newest)
+ date_val = get_vdatum_date(filepath)
- def sort_key(filepath):
- # Time Sorting (Oldest -> Newest)
- date_val = get_vdatum_date(filepath)
+ try:
+ with rasterio.open(filepath) as src:
+ b = src.bounds
+ area = (b.right - b.left) * (b.top - b.bottom)
+ except Exception:
+ area = float("inf")
- try:
- with rasterio.open(filepath) as src:
- b = src.bounds
- area = (b.right - b.left) * (b.top - b.bottom)
- except Exception:
- area = float("inf")
+ return (date_val.timestamp(), -area)
- return (date_val.timestamp(), -area)
+ files.sort(key=sort_key, reverse=True)
- files.sort(key=sort_key, reverse=True)
+ if not files:
+ return np.zeros((self.ny, self.nx))
- if not files:
- return np.zeros((self.ny, self.nx))
+ try:
+ if provider == "seanoe" or provider == "fes":
+ var_name = (
+ "lat_elevation" if "lat" in name.lower() else "msl_elevation"
+ )
+ nc_path = f"netcdf:{files[0]}:{var_name}"
+ return GridEngine.load_and_interpolate(
+ [nc_path],
+ self.region,
+ self.nx,
+ self.ny,
+ decay_pixels=self.decay_pixels,
+ )
- if provider == "seanoe" or provider == "fes":
- var_name = "lat_elevation" if "lat" in name.lower() else "msl_elevation"
- nc_path = f"netcdf:{files[0]}:{var_name}"
- return GridEngine.load_and_interpolate(
- [nc_path], self.region, self.nx, self.ny, decay_pixels=self.decay_pixels
- )
+ return GridEngine.load_and_interpolate(
+ files, self.region, self.nx, self.ny, decay_pixels=self.decay_pixels
+ )
+ except GridCorruptionError:
+ if attempt < max_retries - 1:
+ logger.warning(
+ f"Download corruption detected. Retrying fetch (Attempt {attempt + 2}/{max_retries})..."
+ )
+ continue
+ else:
+ logger.error(
+ "Max retries reached. Could not secure an uncorrupted grid."
+ )
+ return np.zeros((self.ny, self.nx))
- return GridEngine.load_and_interpolate(
- files, self.region, self.nx, self.ny, decay_pixels=self.decay_pixels
- )
+ return np.zeros((self.ny, self.nx))
def _get_htdp_shift(self, epsg_from, epsg_to, epoch_from, epoch_to):
"""Calculate Frame Shift via HTDP with Fallback."""
diff --git a/src/transformez/utils.py b/src/transformez/utils.py
index 905eb3d..dbb5c93 100644
--- a/src/transformez/utils.py
+++ b/src/transformez/utils.py
@@ -16,6 +16,7 @@
import logging
import numpy as np
import rasterio
+import shutil
logger = logging.getLogger(__name__)
@@ -95,3 +96,41 @@ def query(self, x, y):
results[valid] = self.data[rows[valid], cols[valid]]
return results
+
+
+def export_cache(cache_dir=None, output_name="transformez_offline_cache"):
+ """Packs the local transformez cache into a ZIP file for offline use."""
+
+ if cache_dir is None:
+ cache_dir = os.path.join(os.getcwd(), "transformez_cache")
+
+ if not os.path.exists(cache_dir):
+ logger.error(f"[EXPORT FATAL] Cache directory not found at: {cache_dir}")
+ logger.error("Run a transformation to populate the cache before exporting.")
+ return None
+
+ # Determine output path
+ out_path = os.path.abspath(output_name)
+
+ logger.info("-" * 60)
+ logger.info(f"Packing offline cache bundle from: {cache_dir}")
+ logger.info(
+ "This may take a minute depending on the size of your downloaded grids..."
+ )
+
+ try:
+ # shutil.make_archive(base_name, format, root_dir)
+ zip_path = shutil.make_archive(out_path, "zip", cache_dir)
+
+ # Get human-readable file size
+ size_mb = os.path.getsize(zip_path) / (1024 * 1024)
+
+ logger.info(
+ f"Successfully exported offline cache bundle: {zip_path} ({size_mb:.1f} MB)"
+ )
+ logger.info("-" * 60)
+ return zip_path
+
+ except Exception as e:
+ logger.error(f"[EXPORT FATAL] Failed to export cache: {e}")
+ return None