Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 68 additions & 18 deletions src/parcels/interpolators/_uxinterpolators.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from typing import TYPE_CHECKING

import numpy as np
import xarray as xr
from dask import is_dask_collection

if TYPE_CHECKING:
from parcels._core.field import Field, VectorField
Expand All @@ -17,9 +19,19 @@ def UxConstantFaceConstantZC(
field: Field,
):
"""Piecewise constant interpolation kernel for face registered data that is vertically centered (on zc points)"""
return field.data.values[
# Broadcast the per-axis indices to a common (npart,) shape (``ti`` may be scalar for time-constant fields)
ti, zi, fi = np.broadcast_arrays(
grid_positions["T"]["index"], grid_positions["Z"]["index"], grid_positions["FACE"]["index"]
]
)

tdim, zdim, fdim = field.data.dims
selection_dict = {
tdim: xr.DataArray(ti, dims="points"),
zdim: xr.DataArray(zi, dims="points"),
fdim: xr.DataArray(fi, dims="points"),
}
value = field.data.isel(selection_dict, ignore_grid=True).data
return value.compute() if is_dask_collection(value) else value


def UxConstantFaceLinearZF(
Expand All @@ -31,15 +43,28 @@ def UxConstantFaceLinearZF(
Piecewise constant interpolation (lateral) with linear vertical interpolation kernel for face registered data
that is located at vertical interface levels (on zf points)
"""
ti = grid_positions["T"]["index"]
zi, fi = grid_positions["Z"]["index"], grid_positions["FACE"]["index"]
ti, zi, fi = np.broadcast_arrays(
grid_positions["T"]["index"], grid_positions["Z"]["index"], grid_positions["FACE"]["index"]
)
z = particle_positions["z"]

tdim, zdim, fdim = field.data.dims

def _zsample(z_index):
"""Pointwise ``isel`` of the face values at a single vertical interface level."""
selection_dict = {
tdim: xr.DataArray(ti, dims="points"),
zdim: xr.DataArray(z_index, dims="points"),
fdim: xr.DataArray(fi, dims="points"),
}
value = field.data.isel(selection_dict, ignore_grid=True).data
return value.compute() if is_dask_collection(value) else value

# The zi refers to the vertical layer index. The field in this routine are assumed to be defined at the vertical interface levels.
# For interface zi, the interface indices are [zi, zi+1], so we need to use the values at zi and zi+1.
# First, do barycentric interpolation in the lateral direction for each interface level
fzk = field.data.values[ti, zi, fi]
fzkp1 = field.data.values[ti, zi + 1, fi]
fzk = _zsample(zi)
fzkp1 = _zsample(zi + 1)

# Then, do piecewise linear interpolation in the vertical direction
zk = field.grid.z.values[zi]
Expand All @@ -57,13 +82,22 @@ def UxLinearNodeConstantZC(
Effectively, it applies barycentric interpolation in the lateral direction
and piecewise constant interpolation in the vertical direction.
"""
ti = grid_positions["T"]["index"]
zi, fi = grid_positions["Z"]["index"], grid_positions["FACE"]["index"]
bcoords = grid_positions["FACE"]["bcoord"]
ti, zi, fi = np.broadcast_arrays(
grid_positions["T"]["index"], grid_positions["Z"]["index"], grid_positions["FACE"]["index"]
)
bcoords = xr.DataArray(grid_positions["FACE"]["bcoord"], dims=("points", "nodes"))
node_ids = field.grid.uxgrid.face_node_connectivity[fi, :].values
return np.sum(
field.data.values[ti[:, None], zi[:, None], node_ids] * bcoords, axis=-1
) # Linear interpolation in the vertical direction

tdim, zdim, ndim = field.data.dims
selection_dict = {
tdim: xr.DataArray(ti, dims="points"),
zdim: xr.DataArray(zi, dims="points"),
ndim: xr.DataArray(node_ids, dims=("points", "nodes")),
}

node_data = field.data.isel(selection_dict, ignore_grid=True)
value = (node_data * bcoords).sum("nodes").data # Barycentric interpolation in the lateral direction
return value.compute() if is_dask_collection(value) else value


def UxLinearNodeLinearZF(
Expand All @@ -76,21 +110,37 @@ def UxLinearNodeLinearZF(
Effectively, it applies barycentric interpolation in the lateral direction
and piecewise linear interpolation in the vertical direction.
"""
ti = grid_positions["T"]["index"]
zi, fi = grid_positions["Z"]["index"], grid_positions["FACE"]["index"]
ti, zi, fi = np.broadcast_arrays(
grid_positions["T"]["index"], grid_positions["Z"]["index"], grid_positions["FACE"]["index"]
)
z = particle_positions["z"]
bcoords = grid_positions["FACE"]["bcoord"]
bcoords = xr.DataArray(grid_positions["FACE"]["bcoord"], dims=("points", "nodes"))
node_ids = field.grid.uxgrid.face_node_connectivity[fi, :].values

tdim, zdim, ndim = field.data.dims

def _zsample(z_index):
"""Barycentric (lateral) interpolation of the node values at a single vertical interface level."""
selection_dict = {
tdim: xr.DataArray(ti, dims="points"),
zdim: xr.DataArray(z_index, dims="points"),
ndim: xr.DataArray(node_ids, dims=("points", "nodes")),
}
# Reduce over the "nodes" dimension by name so the result is independent of ``isel`` dim order.
node_data = field.data.isel(selection_dict, ignore_grid=True)
return (node_data * bcoords).sum("nodes").data

# The zi refers to the vertical layer index. The field in this routine are assumed to be defined at the vertical interface levels.
# For interface zi, the interface indices are [zi, zi+1], so we need to use the values at zi and zi+1.
# First, do barycentric interpolation in the lateral direction for each interface level
fzk = np.sum(field.data.values[ti[:, None], zi[:, None], node_ids] * bcoords, axis=-1)
fzkp1 = np.sum(field.data.values[ti[:, None], zi[:, None] + 1, node_ids] * bcoords, axis=-1)
fzk = _zsample(zi)
fzkp1 = _zsample(zi + 1)

# Then, do piecewise linear interpolation in the vertical direction
zk = field.grid.z.values[zi]
zkp1 = field.grid.z.values[zi + 1]
return (fzk * (zkp1 - z) + fzkp1 * (z - zk)) / (zkp1 - zk) # Linear interpolation in the vertical direction
value = (fzk * (zkp1 - z) + fzkp1 * (z - zk)) / (zkp1 - zk) # Linear interpolation in the vertical direction
return value.compute() if is_dask_collection(value) else value


def Ux_Velocity(
Expand Down
Loading