diff --git a/examples/13_mps/channel_demo.py b/examples/13_mps/channel_demo.py new file mode 100644 index 00000000..e8dadfdf --- /dev/null +++ b/examples/13_mps/channel_demo.py @@ -0,0 +1,81 @@ +""" +Direct Sampling demo — Strebelle (2002) channelized fluvial TI. +Adapted for GSTools. +""" + +import os +import urllib.request + +import matplotlib.pyplot as plt +import numpy as np +from matplotlib.colors import ListedColormap + +from gstools import mps + +# 1. Load TI +TI_URL = ( + "https://raw.githubusercontent.com/GeostatsGuy/" + "GeoDataSets/master/MPS_Training_image_and_Realizations_500.npz" +) +CACHE = "mps_strebelle.npz" +if not os.path.exists(CACHE): + print("Downloading Strebelle TI ...") + urllib.request.urlretrieve(TI_URL, CACHE) + +ti_arr = np.load(CACHE)["array1"].astype(int) # (256, 256) +ti_model = mps.TrainingImage(ti_arr, categorical=True) +print(f"TI shape: {ti_model.shape} sand={ti_arr.mean():.3f}") + +# 2. Conditioning: 100 random hard-data points from the TI +SG_SIZE = 100 +N_COND = 100 +rng = np.random.default_rng(0) +cond_row = rng.integers(0, SG_SIZE, N_COND) +cond_col = rng.integers(0, SG_SIZE, N_COND) +cond_pos = [cond_row.astype(float), cond_col.astype(float)] +cond_val = ti_arr[cond_row, cond_col].astype(float) +print(f"Conditioning: {N_COND} pts sand={cond_val.mean():.3f}") + +# 3. Simulate (n=30, f=1.0, t=0.01) +N_NEIGH = 30 +SCAN_F = 1.0 +THRESH = 0.01 + +ds = mps.DirectSampling( + ti_model, n_neighbors=N_NEIGH, scan_fraction=SCAN_F, threshold=THRESH +) +ds.set_condition(cond_pos, cond_val) + +print("Starting simulation (this may take a moment)...") +x = np.arange(SG_SIZE, dtype=float) +y = np.arange(SG_SIZE, dtype=float) +sg = ds([x, y], seed=42).astype(int) + +honored = (sg[cond_row, cond_col] == cond_val.astype(int)).sum() +print(f"Simulation complete. Conditioning: {honored}/{N_COND} honored") + +# 4. 
Plot +cmap = ListedColormap(["#c9a96e", "#2b6cb0"]) +fig, ax = plt.subplots(1, 2, figsize=(12, 6)) + +ax[0].imshow(ti_arr[:SG_SIZE, :SG_SIZE], cmap=cmap, origin="upper") +ax[0].set_title("Training Image (Crop)") + +im = ax[1].imshow(sg, cmap=cmap, origin="upper") +ax[1].scatter( + cond_col, + cond_row, + c=cond_val, + cmap=cmap, + edgecolors="k", + s=20, + label="Cond. Points", +) +ax[1].set_title(f"DS Realization (n={N_NEIGH}, f={SCAN_F}, t={THRESH})") +ax[1].legend() + +plt.colorbar(im, ax=ax.ravel().tolist(), ticks=[0.25, 0.75]).set_ticklabels( + ["shale", "sand"] +) +plt.savefig("channel_demo_mps.png", dpi=150, bbox_inches="tight") +print("Saved plot to channel_demo_mps.png") diff --git a/src/gstools/__init__.py b/src/gstools/__init__.py index 4d12007c..67c78bad 100644 --- a/src/gstools/__init__.py +++ b/src/gstools/__init__.py @@ -23,10 +23,21 @@ tools transform normalizer + mps Classes ======= +Multiple Point Statistics +^^^^^^^^^^^^^^^^^^^^^^^^^ +Classes for Multiple Point Statistics (MPS) simulations + +.. currentmodule:: gstools.mps + +.. autosummary:: + DirectSampling + TrainingImage + Kriging ^^^^^^^ Swiss-Army-Knife for Kriging. 
For short cut classes see: :any:`gstools.krige` @@ -139,6 +150,7 @@ covmodel, field, krige, + mps, normalizer, random, tools, @@ -169,6 +181,7 @@ ) from gstools.field import PGS, SRF, CondSRF from gstools.krige import Krige +from gstools.mps import DirectSampling, TrainingImage from gstools.tools import ( DEGREE_SCALE, EARTH_RADIUS, @@ -200,7 +213,7 @@ __all__ = ["__version__"] __all__ += ["covmodel", "field", "variogram", "krige", "random", "tools"] -__all__ += ["transform", "normalizer", "config"] +__all__ += ["transform", "normalizer", "config", "mps"] __all__ += [ "CovModel", "SumModel", @@ -237,6 +250,8 @@ "SRF", "CondSRF", "PGS", + "DirectSampling", + "TrainingImage", "rotated_main_axes", "generate_grid", "generate_st_grid", diff --git a/src/gstools/mps/__init__.py b/src/gstools/mps/__init__.py new file mode 100644 index 00000000..677456af --- /dev/null +++ b/src/gstools/mps/__init__.py @@ -0,0 +1,18 @@ +""" +GStools subpackage for Multiple Point Statistics (MPS). + +.. currentmodule:: gstools.mps + +Multiple Point Statistics +^^^^^^^^^^^^^^^^^^^^^^^^^ +.. autosummary:: + :toctree: + + DirectSampling + TrainingImage +""" + +from gstools.mps.direct_sampling import DirectSampling +from gstools.mps.training_image import TrainingImage + +__all__ = ["DirectSampling", "TrainingImage"] diff --git a/src/gstools/mps/direct_sampling.py b/src/gstools/mps/direct_sampling.py new file mode 100644 index 00000000..c0b52f13 --- /dev/null +++ b/src/gstools/mps/direct_sampling.py @@ -0,0 +1,606 @@ +""" +GStools subpackage providing the Direct Sampling MPS simulation class. + +.. currentmodule:: gstools.mps + +The following classes and functions are provided + +.. 
autosummary:: + DirectSampling +""" + +from concurrent.futures import ThreadPoolExecutor + +import numpy as np + +from gstools import config +from gstools.field.base import Field +from gstools.random.rng import RNG + +__all__ = ["DirectSampling"] + +_VALID_BOUNDARY = ("strict", "partial") + + +def _precompute_offsets(shape, max_offset=None): + """Neighbour offsets from the origin, sorted by Euclidean distance. + + Parameters + ---------- + shape : tuple + Simulation grid shape. + max_offset : int, optional + Maximum offset in any dimension. + Default: ``max(shape)``. + + Returns + ------- + numpy.ndarray, shape (N, dim) + """ + dim = len(shape) + if max_offset is None: + max_offset = max(shape) + rng_vals = np.arange(-max_offset, max_offset + 1) + grid = np.array(np.meshgrid(*[rng_vals] * dim, indexing="ij")) + offsets = grid.reshape(dim, -1).T + offsets = offsets[np.any(offsets != 0, axis=1)] + idx = np.argsort(np.sum(offsets**2, axis=1)) + return offsets[idx] + + +def _build_dag( + path, + n_neighbors, + sim_shape, + offset_arr, + informed_init, + path_pos_map, + max_radius=None, +): + N = len(path) + sim_shape_arr = np.array(sim_shape) + + informed = informed_init.copy() + indegree = np.zeros(N, dtype=np.int32) + out_edges = [[] for _ in range(N)] + + for i, x_i in enumerate(path): + found = 0 + for offset in offset_arr: + if found >= n_neighbors: + break + if ( + max_radius is not None + and np.linalg.norm(offset.astype(float)) > max_radius + ): + break # offset_arr is distance-sorted; all remaining also exceed max_radius + nb = x_i + offset + if np.any(nb < 0) or np.any(nb >= sim_shape_arr): + continue + if not informed[tuple(nb)]: + continue + found += 1 + j = path_pos_map[int(np.ravel_multi_index(tuple(nb), sim_shape))] + if j >= 0: # conditioning nodes have j == -1; skip them + indegree[i] += 1 + out_edges[j].append(i) + informed[tuple(x_i)] = True # mark after processing, not before + + return indegree, out_edges + + +def ds_simulate( + training_image, + 
sim_shape, + n_neighbors, + threshold, + scan_fraction, + seed, + conditions=None, + cond_weight=1.0, + boundary="strict", + max_radius=None, + num_threads=None, +): + """Direct Sampling univariate simulation (Mariethoz2010, Juda2022). + + Parameters + ---------- + training_image : TrainingImage + Training image; provides ``training_image.distance()`` and + ``training_image.adjust_value()``. + sim_shape : tuple + Simulation grid shape. + n_neighbors : int + Maximum number of neighbours in the data event (Juda2022 §2). + threshold : float + Distance threshold for early acceptance (Juda2022 §2). + ``0.0`` → DSBC mode. + scan_fraction : float + Fraction of the per-node search window to scan (Mariethoz2010 §3 ¶24). + Evaluates at most ``floor(f · |window|)`` candidates per node. + ``1.0`` → full window scan. + seed : int + RNG seed. + conditions : dict, optional + ``{tuple_index: value}`` mapping of conditioning data. + cond_weight : float, optional + Weight δ for conditioning nodes (Mariethoz2010 §3 ¶26). + boundary : str, optional + Search-window strategy: ``"strict"`` (default) or ``"partial"``. + max_radius : float, optional + If set, SG neighbours beyond this Euclidean distance are excluded + from the data event (Mariethoz2010 §3 ¶19). 
+ + Returns + ------- + numpy.ndarray + """ + rng = np.random.default_rng(seed) + ti_data = training_image.data + ti_shape = np.array(ti_data.shape) + sim_shape_arr = np.array(sim_shape) + sg = np.full(sim_shape, np.nan) + is_cond = np.zeros(sim_shape, dtype=bool) + informed = np.zeros(sim_shape, dtype=bool) + + if conditions: + for idx, val in conditions.items(): + sg[idx] = val + is_cond[idx] = True + informed[idx] = True + + n_threads = ( + num_threads if num_threads is not None else (config.NUM_THREADS or 1) + ) + executor = ( + ThreadPoolExecutor(max_workers=n_threads) if n_threads > 1 else None + ) + use_parallel_outer = executor is not None + + max_off_int = int(np.ceil(max_radius)) if max_radius is not None else None + offset_arr = _precompute_offsets(sim_shape, max_off_int) + + path = np.argwhere(np.isnan(sg)) + path = path[rng.permutation(len(path))] + node_seeds = rng.integers(0, 2**63, size=len(path)) + + path_flat = np.ravel_multi_index(path.T, sim_shape) + path_pos_map = np.full(int(np.prod(sim_shape)), -1, dtype=np.intp) + path_pos_map[path_flat] = np.arange(len(path_flat)) + + def _rand_ti(node_rng): + return ti_data[tuple(node_rng.integers(0, s) for s in ti_shape)] + + def _get_neighbors(x_i, informed_in): + cands = x_i + offset_arr + valid = cands[np.all((cands >= 0) & (cands < sim_shape_arr), axis=1)] + valid = valid[informed_in[tuple(valid.T)]] + + curr_idx = path_pos_map[ + int(np.ravel_multi_index(tuple(x_i), sim_shape)) + ] + valid_idx = path_pos_map[np.ravel_multi_index(valid.T, sim_shape)] + valid = valid[(valid_idx < curr_idx) | (valid_idx == -1)] + + if max_radius is not None: + dists = np.linalg.norm((valid - x_i).astype(np.float64), axis=1) + valid = valid[dists <= max_radius] + return valid[:n_neighbors] + + def _simulate_node(x_i, node_rng, sg_in, informed_in): + nbrs = _get_neighbors(x_i, informed_in) + if len(nbrs) == 0: + return _rand_ti(node_rng) + + lags = (nbrs - x_i).astype(np.float64) # (k, dim) + data_event_sim = 
sg_in[tuple(nbrs.T)] # (k,) + cond_mask = is_cond[tuple(nbrs.T)] # (k,) + lag_norms = np.linalg.norm(lags, axis=1) # (k,) + + best_d, best_v = np.inf, None + + if boundary == "strict": + # Search window Y(L_i) — Juda2022 Eq. 5, Mariethoz2010 §3 ¶19 + win_lo = np.maximum(0, np.ceil(-lags.min(axis=0))).astype(int) + win_hi = np.minimum( + ti_shape - 1, np.floor(ti_shape - 1 - lags.max(axis=0)) + ).astype(int) + if np.any(win_lo > win_hi): + return _rand_ti(node_rng) + + win_shape = tuple(win_hi - win_lo + 1) + win_size = int(np.prod(win_shape)) + max_scan = max(1, int(scan_fraction * win_size)) + start = int(node_rng.integers(0, win_size)) + + best_data_event_ti = None + for k in range(max_scan): + y = win_lo + np.array( + np.unravel_index((start + k) % win_size, win_shape) + ) + ti_coords = np.round(y + lags).astype(int) + data_event_ti = ti_data[tuple(ti_coords.T)] + dist_val = training_image.distance( + data_event_sim, + data_event_ti, + cond_mask, + cond_weight, + lag_norms, + ) + if dist_val < best_d: + best_d, best_v, best_data_event_ti = ( + dist_val, + ti_data[tuple(y)], + data_event_ti, + ) + if dist_val <= threshold: + break + + return training_image.adjust_value( + best_v, data_event_sim, best_data_event_ti + ) + + else: # "partial" — Mariethoz2010 §6.2: global template reduction + # Drop lags permanently outside TI (|h[d]| >= ti_shape[d] for any d). + # These can never be satisfied for any anchor y; dropping them restores + # a valid strict window when rotation/affinity stretches the template. 
+ placeable = np.all(np.abs(lags) < ti_shape, axis=1) + if not np.any(placeable): + return _rand_ti(node_rng) + lags_p = lags[placeable] + de_sg_p = data_event_sim[placeable] + cm_p = cond_mask[placeable] + ln_p = lag_norms[placeable] + + sw_lo = np.maximum(0, np.ceil(-lags_p.min(axis=0))).astype(int) + sw_hi = np.minimum( + ti_shape - 1, np.floor(ti_shape - 1 - lags_p.max(axis=0)) + ).astype(int) + if np.any(sw_lo > sw_hi): + return _rand_ti(node_rng) + + sw_shape = tuple(sw_hi - sw_lo + 1) + sw_size = int(np.prod(sw_shape)) + max_scan = max(1, int(scan_fraction * sw_size)) + start = int(node_rng.integers(0, sw_size)) + + best_data_event_ti_p = None + for k in range(max_scan): + y = sw_lo + np.array( + np.unravel_index((start + k) % sw_size, sw_shape) + ) + ti_coords = np.round(y + lags_p).astype(int) + data_event_ti_p = ti_data[tuple(ti_coords.T)] + dist_val = training_image.distance( + de_sg_p, data_event_ti_p, cm_p, cond_weight, ln_p + ) + if dist_val < best_d: + best_d, best_v, best_data_event_ti_p = ( + dist_val, + ti_data[tuple(y)], + data_event_ti_p, + ) + if dist_val <= threshold: + break + + return training_image.adjust_value( + best_v, de_sg_p, best_data_event_ti_p + ) + + try: + if use_parallel_outer: + indegree, out_edges = _build_dag( + path, + n_neighbors, + sim_shape, + offset_arr, + informed, + path_pos_map, + max_radius, + ) + ready = [i for i in range(len(path)) if indegree[i] == 0] + + while ready: + batch = ready + ready = [] + sg_snap = sg.copy() + informed_snap = informed.copy() + + futures = [ + executor.submit( + _simulate_node, + path[i], + np.random.default_rng(int(node_seeds[i])), + sg_snap, + informed_snap, + ) + for i in batch + ] + for i, fut in zip(batch, futures): + val = fut.result() + x_i_t = tuple(path[i]) + if np.isnan(val): + raise ValueError( + f"Simulation produced NaN at {path[i]}. Check TI data." 
+ ) + sg[x_i_t] = val + informed[x_i_t] = True + for j in out_edges[i]: + indegree[j] -= 1 + if indegree[j] == 0: + ready.append(j) + else: + for i, x_i in enumerate(path): + x_i_t = tuple(x_i) + val = _simulate_node( + x_i, + np.random.default_rng(int(node_seeds[i])), + sg, + informed, + ) + if np.isnan(val): + raise ValueError( + f"Simulation produced NaN at {x_i}. Check TI data." + ) + sg[x_i_t] = val + informed[x_i_t] = True + finally: + if executor is not None: + executor.shutdown(wait=True) + + return sg + + +class DirectSampling(Field): + """Multiple Point Statistics simulation using Direct Sampling. + + Subclasses :class:`gstools.field.base.Field`. Takes a :class:`TrainingImage` + (analogous to :class:`CovModel`) and produces fields on structured grids. + + Parameters + ---------- + ti : TrainingImage + The training image (the MPS model). + n_neighbors : int, optional + Maximum neighbors in data event. Default: 32. + scan_fraction : float, optional + Fraction of the per-node search window to scan. Default: 1. + threshold : float, optional + Distance threshold. 0.0 -> DSBC mode. Default: 0.0. + cond_weight : float, optional + Weight for conditioning nodes in distance. Default: 1.0. + boundary : str, optional + Search-window strategy: ``"strict"`` (default) or ``"partial"``. + max_radius : float, optional + Exclude SG neighbours beyond this Euclidean distance from the + data event. Default: ``None`` (no limit). + seed : int or nan, optional + Master RNG seed. Default: nan. 
+ """ + + default_field_names = ["field"] + + def __init__( + self, + ti, + n_neighbors=32, + scan_fraction=1, + threshold=0.0, + cond_weight=1.0, + boundary="strict", + max_radius=None, + num_threads=None, + seed=np.nan, + ): + if boundary not in _VALID_BOUNDARY: + raise ValueError( + f"DirectSampling: boundary must be one of {_VALID_BOUNDARY!r}, " + f"got {boundary!r}" + ) + if max_radius is not None and float(max_radius) <= 0: + raise ValueError( + f"DirectSampling: max_radius must be a positive float, " + f"got {max_radius!r}" + ) + super().__init__(model=None, dim=ti.ndim, value_type="scalar") + self._ti = ti + self._n_neighbors = int(n_neighbors) + self._scan_fraction = float(scan_fraction) + self._threshold = float(threshold) + self._cond_weight = float(cond_weight) + self._boundary = boundary + self._max_radius = ( + float(max_radius) if max_radius is not None else None + ) + self._num_threads = num_threads + self._cond_pos = None + self._cond_val = None + self.rng = RNG(None if np.isnan(seed) else int(seed)) + + def __call__( + self, + pos=None, + seed=np.nan, + mesh_type="structured", + post_process=True, + store=True, + ): + """Generate the spatial random field via Direct Sampling. + + The field is saved as ``self.field`` and is also returned. + + Parameters + ---------- + pos : :class:`list`, optional + The position tuple, containing main direction and transversal + directions. Only structured grids are supported. + seed : :class:`int`, optional + Seed for the RNG. If ``np.nan``, the current seed is kept. + Default: ``np.nan`` + mesh_type : :class:`str`, optional + Grid type. Must be ``"structured"``. + Default: ``"structured"`` + post_process : :class:`bool`, optional + Whether to apply post-processing transformations (mean, + normalizer, trend) to the field. Default: :any:`True` + store : :class:`bool` or :class:`str`, optional + Whether to store the field (``True``), not store it (``False``), + or store it under a custom name (string). 
+ Default: :any:`True` + + Returns + ------- + field : :class:`numpy.ndarray` + The simulated field. + """ + if mesh_type != "structured": + raise ValueError( + "DirectSampling: only structured grids are supported." + ) + name, save = self.get_store_config(store) + pos, shape = self.pre_pos(pos, mesh_type) + conditions = self._conditions_to_grid(self.pos) + if not np.isnan(seed): + self.rng.seed = int(seed) + iseed = int(self.rng.random.randint(0, 2**31)) + field = ds_simulate( + training_image=self._ti, + sim_shape=shape, + n_neighbors=self._n_neighbors, + threshold=self._threshold, + scan_fraction=self._scan_fraction, + seed=iseed, + conditions=conditions, + cond_weight=self._cond_weight, + boundary=self._boundary, + max_radius=self._max_radius, + num_threads=self._num_threads, + ) + return self.post_field(field, name, post_process, save) + + def _conditions_to_grid(self, axes): + """Smart snapping: Mariethoz 2010 collision rule.""" + if self._cond_pos is None: + return {} + candidates = {} # idx -> (val, dist_sq) + for k in range(self._cond_val.shape[0]): + idx = tuple( + int(np.argmin(np.abs(axes[d] - self._cond_pos[d][k]))) + for d in range(self.dim) + ) + dist_sq = sum( + (axes[d][idx[d]] - self._cond_pos[d][k]) ** 2 + for d in range(self.dim) + ) + if idx not in candidates or dist_sq < candidates[idx][1]: + candidates[idx] = (self._cond_val[k], dist_sq) + return {idx: val for idx, (val, _) in candidates.items()} + + def set_condition(self, cond_pos, cond_val, cond_weight=None): + """Set the conditioning data for the simulation. + + Parameters + ---------- + cond_pos : :class:`list` + The position tuple of the conditioning data ``(x, [y, z])``. + cond_val : :class:`numpy.ndarray` + The values at the conditioning positions. + cond_weight : :class:`float`, optional + Conditioning weight δ. If given, overrides the ``cond_weight`` + set at construction. 
Default: :any:`None` (keep existing weight) + """ + from gstools.krige.tools import set_condition as _gs_set_condition + + self._cond_pos, self._cond_val = _gs_set_condition( + cond_pos, cond_val, self.dim + ) + if cond_weight is not None: + self._cond_weight = float(cond_weight) + + @property + def ti(self): + """TrainingImage: The training image model.""" + return self._ti + + @property + def n_neighbors(self): + """:class:`int`: Maximum neighbours in the data event.""" + return self._n_neighbors + + @n_neighbors.setter + def n_neighbors(self, value): + self._n_neighbors = int(value) + + @property + def scan_fraction(self): + """:class:`float`: Fraction of the per-node search window to scan.""" + return self._scan_fraction + + @scan_fraction.setter + def scan_fraction(self, value): + self._scan_fraction = float(value) + + @property + def threshold(self): + """:class:`float`: Distance threshold (0.0 → DSBC mode).""" + return self._threshold + + @threshold.setter + def threshold(self, value): + self._threshold = float(value) + + @property + def cond_weight(self): + """:class:`float`: Weight for conditioning nodes in distance.""" + return self._cond_weight + + @cond_weight.setter + def cond_weight(self, value): + self._cond_weight = float(value) + + @property + def boundary(self): + """:class:`str`: Search-window strategy (``"strict"`` or ``"partial"``).""" + return self._boundary + + @boundary.setter + def boundary(self, value): + if value not in _VALID_BOUNDARY: + raise ValueError( + f"DirectSampling: boundary must be one of {_VALID_BOUNDARY!r}, " + f"got {value!r}" + ) + self._boundary = value + + @property + def max_radius(self): + """:class:`float` or :any:`None`: Euclidean cap on SG neighbour selection.""" + return self._max_radius + + @max_radius.setter + def max_radius(self, value): + if value is not None and float(value) <= 0: + raise ValueError( + f"DirectSampling: max_radius must be a positive float, " + f"got {value!r}" + ) + self._max_radius = 
float(value) if value is not None else None + + @property + def num_threads(self): + """:class:`int` or :any:`None`: Number of threads for outer DAG parallelism.""" + return self._num_threads + + @num_threads.setter + def num_threads(self, value): + self._num_threads = None if value is None else int(value) + + def __repr__(self): + return ( + f"DirectSampling(dim={self.dim}, " + f"n_neighbors={self.n_neighbors}, " + f"scan_fraction={self.scan_fraction}, " + f"threshold={self.threshold}, " + f"boundary={self.boundary!r})" + ) diff --git a/src/gstools/mps/distance.py b/src/gstools/mps/distance.py new file mode 100644 index 00000000..b0c86d10 --- /dev/null +++ b/src/gstools/mps/distance.py @@ -0,0 +1,180 @@ +"""Pure distance functions for MPS pattern comparison. + +No class state — takes arrays and scalars, returns floats. +``TrainingImage.distance()`` uses these internally; other algorithms +can import them directly. +""" + +import numpy as np + +__all__ = [ + "compute_node_weights", + "categorical_dist", + "l1_dist", + "l2_dist", + "lp_dist", + "variation_dist", +] + + +def compute_node_weights( + n, lag_norms, distance_power, cond_mask=None, cond_weight=1.0 +): + """Compute normalized spatial-decay weights for a data event. + + Combines spatial decay (Mariethoz2010 Eq. 5) with conditioning data + multipliers (Mariethoz2010 §3 ¶26). + + Parameters + ---------- + n : int + Number of neighbours in the data event. + lag_norms : array-like or None, shape (n,) + Euclidean norms ``‖h_i‖`` of each lag vector. ``None`` or + ``distance_power == 0`` → uniform spatial weights. + distance_power : float + Exponent δ. ``0.0`` → uniform. + cond_mask : array-like of bool, optional + ``True`` where the neighbour is a conditioning datum. + cond_weight : float, optional + Bonus weight multiplier for conditioning nodes. + + Returns + ------- + numpy.ndarray, shape (n,) + Node weights normalized to sum to 1. 
+ """ + if lag_norms is not None and distance_power != 0.0: + norms = np.asarray(lag_norms, dtype=np.float64) + norms = np.where(norms == 0.0, 1e-10, norms) + raw_w = norms ** (-distance_power) + else: + raw_w = np.ones(n, dtype=np.float64) + + if cond_mask is not None: + raw_w = raw_w.copy() + raw_w[np.asarray(cond_mask, dtype=bool)] *= cond_weight + + return raw_w / raw_w.sum() + + +def categorical_dist(data_event_sim, data_event_ti, node_weights): + """Weighted categorical distance (Mariethoz2010 Eq. 3). + + Parameters + ---------- + data_event_sim : numpy.ndarray, shape (n,) + data_event_ti : numpy.ndarray, shape (n,) + node_weights : numpy.ndarray, shape (n,) + Normalized spatial and conditioning weights. + + Returns + ------- + float + Distance in [0, 1]. + """ + return float( + np.dot( + node_weights, + (data_event_sim != data_event_ti).astype(np.float64), + ) + ) + + +def l1_dist(data_event_sim, data_event_ti, node_weights, d_max): + """Weighted L1 distance / Manhattan (Mariethoz2010 Eq. 6). + + Parameters + ---------- + data_event_sim : numpy.ndarray, shape (n,) + data_event_ti : numpy.ndarray, shape (n,) + node_weights : numpy.ndarray, shape (n,) + Normalized spatial and conditioning weights. + d_max : float + Data range for normalization. + + Returns + ------- + float + Distance in [0, 1]. + """ + return float( + np.dot(node_weights, np.abs(data_event_sim - data_event_ti) / d_max) + ) + + +def l2_dist(data_event_sim, data_event_ti, node_weights, d_max): + """Weighted L2 / RMS distance (Mariethoz2010 Eq. 4–5). + + Parameters + ---------- + data_event_sim : numpy.ndarray, shape (n,) + data_event_ti : numpy.ndarray, shape (n,) + node_weights : numpy.ndarray, shape (n,) + Normalized spatial and conditioning weights. + d_max : float + Data range for normalization. + + Returns + ------- + float + Distance in [0, 1]. 
+ """ + return float( + np.sqrt( + np.dot( + node_weights, + ((data_event_sim - data_event_ti) / d_max) ** 2, + ) + ) + ) + + +def lp_dist(data_event_sim, data_event_ti, node_weights, d_max, p): + """Weighted Lp (Minkowski) distance. + + Warning: Computationally heavier than l1_dist or l2_dist due to + the generic C-level pow() evaluation. Use only when p != 1.0 or 2.0. + + Parameters + ---------- + data_event_sim : numpy.ndarray, shape (n,) + data_event_ti : numpy.ndarray, shape (n,) + node_weights : numpy.ndarray, shape (n,) + Normalized spatial and conditioning weights. + d_max : float + Data range for normalization. + p : float + The Minkowski exponent (e.g., 1.5, 3.0, 5.0). + + Returns + ------- + float + Distance in [0, 1]. + """ + diffs = np.abs(data_event_sim - data_event_ti) / d_max + return float(np.sum(node_weights * (diffs**p)) ** (1.0 / p)) + + +def variation_dist(data_event_sim, data_event_ti, node_weights, d_max): + """Weighted variation distance (Mariethoz2010 Eq. 9, de-meaned). + + Parameters + ---------- + data_event_sim : numpy.ndarray, shape (n,) + data_event_ti : numpy.ndarray, shape (n,) + node_weights : numpy.ndarray, shape (n,) + Normalized spatial and conditioning weights. + d_max : float + Data range for normalization. + + Returns + ------- + float + Distance in [0, 1]. + """ + diffs = (data_event_sim - data_event_sim.mean()) - ( + data_event_ti - data_event_ti.mean() + ) + # 2*d_max: |diffs_i| ≤ 2*d_max always → each squared term ≤ 1 → d ∈ [0, 1] + return float(np.sqrt(np.dot(node_weights, (diffs / (2 * d_max)) ** 2))) diff --git a/src/gstools/mps/training_image.py b/src/gstools/mps/training_image.py new file mode 100644 index 00000000..03a2684c --- /dev/null +++ b/src/gstools/mps/training_image.py @@ -0,0 +1,212 @@ +""" +GStools subpackage providing the TrainingImage class for MPS simulations. + +.. currentmodule:: gstools.mps + +The following classes and functions are provided + +.. 
autosummary:: + TrainingImage +""" + +import numpy as np + +from gstools.mps.distance import ( + categorical_dist, + compute_node_weights, + l1_dist, + l2_dist, + lp_dist, + variation_dist, +) + +__all__ = ["TrainingImage"] + + +class TrainingImage: + """Training image for multiple point statistics simulation. + + The MPS analogue of :class:`gstools.CovModel`: encapsulates training + data and the distance function for comparing data events. + + Parameters + ---------- + data : numpy.ndarray + Training image data (n-d array). + categorical : bool, optional + Whether the variable is categorical. Default: ``True``. + distance : str, optional + Distance metric for continuous variables: ``"l1"`` (Juda2022 + Eq. 7, default), ``"l2"`` (Mariethoz2010 Eq. 4–5), or + ``"variation"`` (Mariethoz2010 Eq. 9). Ignored when categorical. + distance_power : float, optional + Exponent δ for spatial-decay weighting of neighbours + (Mariethoz2010 Eq. 3). Applied to **all** distance types. + ``0.0`` → uniform weights (oracle-compatible default). + ``1.0`` → closer neighbours weighted more heavily. + """ + + def __init__( + self, data, categorical=True, distance="l1", distance_power=0.0 + ): + self._data = np.asarray(data) + self._categorical = bool(categorical) + self._distance_type = distance + self._distance_power = float(distance_power) + self._p_norm = None + distance_lower = str(distance).lower() + if distance_lower.startswith("l"): + try: + p_val = float(distance_lower[1:]) + except ValueError: + raise ValueError( + f"TrainingImage: distance starting with 'l' must be followed by " + f"a positive number (e.g. 'l1', 'l2', 'l3.5'). Got {distance!r}" + ) + if p_val <= 0: + raise ValueError( + f"TrainingImage: Lp norm exponent must be > 0, got {p_val}." + ) + self._p_norm = p_val + elif distance_lower == "variation": + self._p_norm = None + else: + raise ValueError( + f"TrainingImage: distance must be 'l
' (e.g. 'l1', 'l2') " + f"or 'variation'. Got {distance!r}" + ) + + if not self._categorical: + dmax = float(self._data.max() - self._data.min()) + self._d_max = dmax if dmax > 0 else 1.0 + else: + self._d_max = None + + # ------------------------------------------------------------------ + # Properties + # ------------------------------------------------------------------ + + @property + def data(self): + """numpy.ndarray: Raw training image data.""" + return self._data + + @property + def ndim(self): + """int: Number of spatial dimensions.""" + return self._data.ndim + + @property + def shape(self): + """tuple: Shape of the training image.""" + return self._data.shape + + @property + def categorical(self): + """bool: Whether the variable is categorical.""" + return self._categorical + + @property + def distance_type(self): + """str: Distance metric (e.g. ``"l1"``, ``"l2"``, ``"l3.5"``, or ``"variation"``).""" + return self._distance_type + + @property + def distance_power(self): + """float: Spatial-decay exponent δ for node weighting.""" + return self._distance_power + + # ------------------------------------------------------------------ + # Distance + # ------------------------------------------------------------------ + + def distance( + self, + data_event_sim, + data_event_ti, + cond_mask=None, + cond_weight=1.0, + lag_norms=None, + ): + """Distance between two data events. + + Applies spatial-decay weights (Mariethoz2010 Eq. 3) to all + distance types when ``distance_power > 0``. + + Parameters + ---------- + data_event_sim : array-like, shape (n,) + Values at SG neighbourhood nodes. + data_event_ti : array-like, shape (n,) + Values at TI neighbourhood nodes. + cond_mask : array-like of bool, optional + True where the neighbour is a conditioning datum. + cond_weight : float, optional + Weight multiplier δ for conditioning nodes + (Mariethoz2010 §3 ¶26). Default: ``1.0``. 
+ lag_norms : array-like, shape (n,), optional + Euclidean norms ``‖h_i‖`` of each lag vector. Required for + spatial-decay weighting (``distance_power > 0``). + + Returns + ------- + float + Distance in [0, 1]. + """ + data_event_sim = np.asarray(data_event_sim, dtype=np.float64) + data_event_ti = np.asarray(data_event_ti, dtype=np.float64) + n = len(data_event_sim) + if n == 0: + return 0.0 + + w = compute_node_weights( + n, lag_norms, self._distance_power, cond_mask, cond_weight + ) + + if self._categorical: + return categorical_dist(data_event_sim, data_event_ti, w) + if self._p_norm == 1.0: + return l1_dist(data_event_sim, data_event_ti, w, self._d_max) + if self._p_norm == 2.0: + return l2_dist(data_event_sim, data_event_ti, w, self._d_max) + if self._p_norm is not None: + return lp_dist( + data_event_sim, data_event_ti, w, self._d_max, self._p_norm + ) + else: # _p_norm is None → distance="variation" + return variation_dist( + data_event_sim, data_event_ti, w, self._d_max + ) + + def adjust_value(self, ti_val, data_event_sim, data_event_ti): + """Adjust matched TI value before assignment to SG. + + For ``distance="variation"``, applies the mean-shift correction + (Mariethoz2010 Eq. 9): Z(x_i) = Z(y) − Z̄(y) + Z̄(x_i). + For all other metrics returns *ti_val* unchanged. + + Parameters + ---------- + ti_val : float + Raw value at the matched TI node. + data_event_sim : array-like + SG data event (used to compute Z̄(x_i)). + data_event_ti : array-like + TI data event (used to compute Z̄(y)). + + Returns + ------- + float + """ + if self._p_norm is not None or self._categorical: + return ti_val + data_event_sim = np.asarray(data_event_sim, dtype=np.float64) + data_event_ti = np.asarray(data_event_ti, dtype=np.float64) + return float(ti_val - data_event_ti.mean() + data_event_sim.mean()) + + def __repr__(self): + return ( + f"TrainingImage(shape={self.shape}, " + f"categorical={self._categorical}, " + f"distance={self._distance_type!r})" + )