From 946a062a3a7b14a9105cd2823ac812d2951ddddc Mon Sep 17 00:00:00 2001 From: Matthew Love Date: Thu, 25 Jun 2026 19:34:02 -0700 Subject: [PATCH] projected raster support, fetch retry --- CHANGELOG.md | 9 ++ CITATION.cff | 4 +- README.md | 2 +- src/transformez/api.py | 86 +++++++++++++++- src/transformez/cli.py | 7 ++ src/transformez/grid_engine.py | 40 +++++++- src/transformez/transform.py | 175 +++++++++++++++++++-------------- src/transformez/utils.py | 39 ++++++++ 8 files changed, 278 insertions(+), 84 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d35d410..e016b70 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,15 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [0.4.4] - 2026-06-25 + +### Added +- support for projected input rasters in the raster command +- add 'save_shift' to transform_raster api/cli + +### Changed +- retry failed downloads, such as FES + ## [0.4.3] - 2026-05-04 ### Added diff --git a/CITATION.cff b/CITATION.cff index 50e8bbd..0eea3be 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -14,6 +14,6 @@ authors: website: https://transformez.readthedocs.io title: "Transformez" -version: 0.4.3 -date-released: 2026-06-04 +version: 0.5.0 +date-released: 2026-06-06 url: "https://github.com/continuous-dems/transformez" diff --git a/README.md b/README.md index 3bea32a..de2cc40 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@

Global vertical datum transformations, simplified.

- Version + Version License Python PyPI version diff --git a/src/transformez/api.py b/src/transformez/api.py index 47063f2..5d6df26 100644 --- a/src/transformez/api.py +++ b/src/transformez/api.py @@ -113,6 +113,8 @@ def generate_grid( increment: Union[str, float], datum_in: str, datum_out: str, + epoch_in: str = "2010.0", + epoch_out: str = "2010.0", decay_pixels: int = 100, out_fn: Optional[str] = None, cache_dir: Optional[str] = None, @@ -126,6 +128,8 @@ def generate_grid( increment: Resolution (e.g., '3s' or 0.0008333). datum_in: Source datum (e.g., 'mllw', '5703'). datum_out: Target datum (e.g., '4979', '6319'). + epoch_in: Source epoch (e.g., '2010.0') + epoch_out: Target epoch (e.g., '2010.0') decay_pixels: Set the pixel decay in case extrapolation is required. out_fn: If provided, saves the grid to this file (.tif or .gtx). cache_dir: Path to store downloaded grids. @@ -167,6 +171,8 @@ def generate_grid( epsg_out=epsg_out, geoid_in=geoid_in, geoid_out=geoid_out, + epoch_in=epoch_in, + epoch_out=epoch_out, decay_pixels=decay_pixels, cache_dir=cache_dir, use_stations=use_stations, @@ -196,6 +202,7 @@ def transform_raster( z_unit_in: Optional[str] = "auto", z_unit_out: Optional[str] = "auto", use_stations: bool = False, + save_shift: bool = False, verbose: bool = False, ) -> Optional[str]: """Apply a vertical datum transformation directly to an existing raster file. @@ -209,6 +216,8 @@ def transform_raster( cache_dir: Path to store downloaded grids. z_unit_in: Input DEM z units. z_unit_out: Output DEM z units. + use_stations: Generate the shift grid from available tide stations, + safe_shift: Save the generated shift raster to disk. verbose: Enable debug logging. 
Returns: @@ -216,16 +225,48 @@ def transform_raster( """ import rasterio + from rasterio.warp import transform_bounds, reproject, Resampling + from rasterio.transform import from_bounds + from fetchez.spatial import Region if not os.path.exists(input_raster): logger.error(f"Input raster not found: {input_raster}") return None with rasterio.open(input_raster) as src: - bounds = src.bounds - region_obj = Region(bounds.left, bounds.right, bounds.bottom, bounds.top) + native_crs = src.crs + native_bounds = src.bounds + native_transform = src.transform nx, ny = src.width, src.height + is_projected = native_crs.is_projected if native_crs else False + if is_projected: + logger.info( + f"Projected CRS detected ({native_crs}). Extracting WGS84 envelope..." + ) + w, s, e, n = transform_bounds(native_crs, "EPSG:4326", *native_bounds) + + buffer = 0.05 + region_obj = Region(w - buffer, e + buffer, s - buffer, n + buffer) + logger.info(f"Using WGS84 region: {region_obj}") + + inc_deg = 3.0 / 3600.0 + vt_nx = int((region_obj.xmax - region_obj.xmin) / inc_deg) + vt_ny = int((region_obj.ymax - region_obj.ymin) / inc_deg) + else: + region_obj = Region( + native_bounds.left, + native_bounds.right, + native_bounds.bottom, + native_bounds.top, + ) + vt_nx, vt_ny = nx, ny + + # with rasterio.open(input_raster) as src: + # bounds = src.bounds + # region_obj = Region(bounds.left, bounds.right, bounds.bottom, bounds.top) + # nx, ny = src.width, src.height + epsg_in, geoid_in = _parse_datum(datum_in) epsg_out, geoid_out = _parse_datum(datum_out) @@ -266,6 +307,47 @@ def transform_raster( logger.error("Failed to generate shift array for the raster bounds.") return None + if is_projected: + logger.info("Warping shift grid back to native raster projection...") + wgs_transform = from_bounds( + region_obj.xmin, + region_obj.ymin, + region_obj.xmax, + region_obj.ymax, + vt_nx, + vt_ny, + ) + native_shift_array = np.zeros((ny, nx), dtype=np.float32) + + reproject( + source=shift_array, + 
destination=native_shift_array, + src_transform=wgs_transform, + src_crs="EPSG:4326", + dst_transform=native_transform, + dst_crs=native_crs, + resampling=Resampling.bilinear, + ) + + shift_array = native_shift_array + + if save_shift: + shift_fn = f"{os.path.splitext(output_raster)[0]}_shiftgrid.tif" + logger.info(f"Saving aligned shift grid to {shift_fn}...") + with rasterio.open( + shift_fn, + "w", + driver="GTiff", + height=ny, + width=nx, + count=1, + dtype=shift_array.dtype, + crs=native_crs, + transform=native_transform, + nodata=-9999.0, + ) as dst: + dst.write(shift_array, 1) + success = GridEngine.apply_vertical_shift( input_raster, shift_array, diff --git a/src/transformez/cli.py b/src/transformez/cli.py index c27efff..0f4211d 100644 --- a/src/transformez/cli.py +++ b/src/transformez/cli.py @@ -302,6 +302,11 @@ def transform_grid( is_flag=True, help="Force RBF interpolation using live tide stations instead of global satellite models.", ) +@click.option( + "--save-shift", + is_flag=True, + help="Save the aligned vertical shift grid to disk alongside the output DEM.", +) def transform_raster( input_file, input_datum, @@ -311,6 +316,7 @@ def transform_raster( out, decay_pixels, use_stations, + save_shift, ): """Apply a vertical datum shift to an existing DEM.""" @@ -326,6 +332,7 @@ def transform_raster( z_unit_in=in_units, z_unit_out=out_units, use_stations=use_stations, + save_shift=save_shift, verbose=True, ) diff --git a/src/transformez/grid_engine.py b/src/transformez/grid_engine.py index b75ad6b..484397f 100644 --- a/src/transformez/grid_engine.py +++ b/src/transformez/grid_engine.py @@ -27,6 +27,12 @@ logger = logging.getLogger(__name__) +class GridCorruptionError(Exception): + """Raised when a fetched grid is corrupted and needs re-downloading.""" + + pass + + def plot_grid(grid_array, region, title="Vertical Shift Preview"): """Plot the transformation grid using Matplotlib.""" @@ -168,6 +174,34 @@ def load_and_interpolate(source_files, 
target_region, nx, ny, decay_pixels=100): valid_mask = ~np.isnan(temp_buffer) mosaic[valid_mask] = temp_buffer[valid_mask] + except Exception as e: + error_msg = str(e) + + if any( + err in error_msg + for err in [ + "-101", + "HDF error", + "not recognized as a supported file format", + "RasterioIOError", + ] + ): + logger.error(f" CRITICAL: Corrupted grid chunk detected in {fn}!") + + real_path = fn.split(":")[1] if fn.startswith("netcdf:") else fn + if os.path.exists(real_path): + logger.warning( + f"Auto-deleting corrupted cache file to force re-fetch: {real_path}" + ) + try: + os.remove(real_path) + except OSError: + pass + + raise GridCorruptionError(f"Corrupted file deleted: {real_path}") + + logger.exception(f"Failed to reproject {fn}: {e}") + continue # except Exception as e: # error_msg = str(e) @@ -189,9 +223,9 @@ def load_and_interpolate(source_files, target_region, nx, ny, decay_pixels=100): # # For all other normal errors, log and continue as usual # logger.exception(f"Failed to reproject {fn}: {e}") # continue - except Exception as e: - logger.exception(f"Failed to reproject {fn}: {e}") - continue + # except Exception as e: + # logger.exception(f"Failed to reproject {fn}: {e}") + # continue # Fill inland areas (decaying to 0) before we clear the remaining NaNs # mosaic = GridEngine.fill_nans(mosaic, decay_pixels=decay_pixels) diff --git a/src/transformez/transform.py b/src/transformez/transform.py index c102f3f..4c0d965 100644 --- a/src/transformez/transform.py +++ b/src/transformez/transform.py @@ -207,7 +207,9 @@ def fetch_grid_(self, module_name, **kwargs): return valid - def _get_grid(self, provider, name): + def _get_grid(self, provider, name, max_retries=3): + + from .grid_engine import GridCorruptionError if not name: return np.zeros((self.ny, self.nx)) @@ -218,95 +220,116 @@ def _get_grid(self, provider, name): if "geoid=" in name: name = name.split("=")[1] - files = self.fetch_grid(provider, datatype=name, query=name) - if provider == 
"vdatum": - import rasterio - from datetime import datetime - - def get_vdatum_date(gtx_path): - """Finds and parses the release date from VDatum metadata files.""" + for attempt in range(max_retries): + files = self.fetch_grid(provider, datatype=name, query=name) + if provider == "vdatum": + import rasterio + from datetime import datetime - dir_name = os.path.dirname(gtx_path) - meta_files = [ - f for f in os.listdir(dir_name) if f.endswith((".met", ".inf")) - ] + def get_vdatum_date(gtx_path): + """Finds and parses the release date from VDatum metadata files.""" - if not meta_files: - return datetime(1970, 1, 1) - - meta_path = os.path.join(dir_name, meta_files[0]) - try: - with open(meta_path, "r") as f: - content = f.read().splitlines() + dir_name = os.path.dirname(gtx_path) + meta_files = [ + f for f in os.listdir(dir_name) if f.endswith((".met", ".inf")) + ] - if not content: + if not meta_files: return datetime(1970, 1, 1) - first_line = content[0] - - # Parse the first line "#Mon Jul 08 10:27:07 EDT 2019" - if first_line.startswith("#"): - parts = first_line.replace("#", "").split() - if len(parts) >= 6: - year = int(parts[-1]) - day = int(parts[2]) - month_map = { - "Jan": 1, - "Feb": 2, - "Mar": 3, - "Apr": 4, - "May": 5, - "Jun": 6, - "Jul": 7, - "Aug": 8, - "Sep": 9, - "Oct": 10, - "Nov": 11, - "Dec": 12, - } - month = month_map.get(parts[1][:3].title(), 1) - return datetime(year, month, day) - - # Fallback to scanning for "released_date=" - for line in content: - if "released_date=" in line: - date_str = line.split("=")[1].strip() - m, d, y = map(int, date_str.split("/")) - return datetime(y, m, d) + meta_path = os.path.join(dir_name, meta_files[0]) + try: + with open(meta_path, "r") as f: + content = f.read().splitlines() + + if not content: + return datetime(1970, 1, 1) + + first_line = content[0] + + # Parse the first line "#Mon Jul 08 10:27:07 EDT 2019" + if first_line.startswith("#"): + parts = first_line.replace("#", "").split() + if len(parts) 
>= 6: + year = int(parts[-1]) + day = int(parts[2]) + month_map = { + "Jan": 1, + "Feb": 2, + "Mar": 3, + "Apr": 4, + "May": 5, + "Jun": 6, + "Jul": 7, + "Aug": 8, + "Sep": 9, + "Oct": 10, + "Nov": 11, + "Dec": 12, + } + month = month_map.get(parts[1][:3].title(), 1) + return datetime(year, month, day) + + # Fallback to scanning for "released_date=" + for line in content: + if "released_date=" in line: + date_str = line.split("=")[1].strip() + m, d, y = map(int, date_str.split("/")) + return datetime(y, m, d) + + except Exception as e: + logger.debug(f"Failed to parse date from {meta_path}: {e}") - except Exception as e: - logger.debug(f"Failed to parse date from {meta_path}: {e}") + return datetime(1970, 1, 1) - return datetime(1970, 1, 1) + def sort_key(filepath): + # Time Sorting (Oldest -> Newest) + date_val = get_vdatum_date(filepath) - def sort_key(filepath): - # Time Sorting (Oldest -> Newest) - date_val = get_vdatum_date(filepath) + try: + with rasterio.open(filepath) as src: + b = src.bounds + area = (b.right - b.left) * (b.top - b.bottom) + except Exception: + area = float("inf") - try: - with rasterio.open(filepath) as src: - b = src.bounds - area = (b.right - b.left) * (b.top - b.bottom) - except Exception: - area = float("inf") + return (date_val.timestamp(), -area) - return (date_val.timestamp(), -area) + files.sort(key=sort_key, reverse=True) - files.sort(key=sort_key, reverse=True) + if not files: + return np.zeros((self.ny, self.nx)) - if not files: - return np.zeros((self.ny, self.nx)) + try: + if provider == "seanoe" or provider == "fes": + var_name = ( + "lat_elevation" if "lat" in name.lower() else "msl_elevation" + ) + nc_path = f"netcdf:{files[0]}:{var_name}" + return GridEngine.load_and_interpolate( + [nc_path], + self.region, + self.nx, + self.ny, + decay_pixels=self.decay_pixels, + ) - if provider == "seanoe" or provider == "fes": - var_name = "lat_elevation" if "lat" in name.lower() else "msl_elevation" - nc_path = 
f"netcdf:{files[0]}:{var_name}" - return GridEngine.load_and_interpolate( - [nc_path], self.region, self.nx, self.ny, decay_pixels=self.decay_pixels - ) + return GridEngine.load_and_interpolate( + files, self.region, self.nx, self.ny, decay_pixels=self.decay_pixels + ) + except GridCorruptionError: + if attempt < max_retries - 1: + logger.warning( + f"Download corruption detected. Retrying fetch (Attempt {attempt + 2}/{max_retries})..." + ) + continue + else: + logger.error( + "Max retries reached. Could not secure an uncorrupted grid." + ) + return np.zeros((self.ny, self.nx)) - return GridEngine.load_and_interpolate( - files, self.region, self.nx, self.ny, decay_pixels=self.decay_pixels - ) + return np.zeros((self.ny, self.nx)) def _get_htdp_shift(self, epsg_from, epsg_to, epoch_from, epoch_to): """Calculate Frame Shift via HTDP with Fallback.""" diff --git a/src/transformez/utils.py b/src/transformez/utils.py index 905eb3d..dbb5c93 100644 --- a/src/transformez/utils.py +++ b/src/transformez/utils.py @@ -16,6 +16,7 @@ import logging import numpy as np import rasterio +import shutil logger = logging.getLogger(__name__) @@ -95,3 +96,41 @@ def query(self, x, y): results[valid] = self.data[rows[valid], cols[valid]] return results + + +def export_cache(cache_dir=None, output_name="transformez_offline_cache"): + """Packs the local transformez cache into a ZIP file for offline use.""" + + if cache_dir is None: + cache_dir = os.path.join(os.getcwd(), "transformez_cache") + + if not os.path.exists(cache_dir): + logger.error(f"[EXPORT FATAL] Cache directory not found at: {cache_dir}") + logger.error("Run a transformation to populate the cache before exporting.") + return None + + # Determine output path + out_path = os.path.abspath(output_name) + + logger.info("-" * 60) + logger.info(f"Packing offline cache bundle from: {cache_dir}") + logger.info( + "This may take a minute depending on the size of your downloaded grids..." 
+ ) + + try: + # shutil.make_archive(base_name, format, root_dir) + zip_path = shutil.make_archive(out_path, "zip", cache_dir) + + # Get human-readable file size + size_mb = os.path.getsize(zip_path) / (1024 * 1024) + + logger.info( + f"Successfully exported offline cache bundle: {zip_path} ({size_mb:.1f} MB)" + ) + logger.info("-" * 60) + return zip_path + + except Exception as e: + logger.error(f"[EXPORT FATAL] Failed to export cache: {e}") + return None