-
Notifications
You must be signed in to change notification settings - Fork 176
Model-centric refactoring to reduce dataset creation #2646
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
VeckoTheGecko
wants to merge
48
commits into
Parcels-code:main
Choose a base branch
from
VeckoTheGecko:restructure
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
48 commits
Select commit
Hold shift + click to select a range
f5eb442
Typing
VeckoTheGecko 735e1c0
Restructure to introduce models
VeckoTheGecko e03ccf7
Refactor from_sgrid_conventions to model
VeckoTheGecko f198994
Fix typing
VeckoTheGecko aaf6846
Refactor from_ugrid_conventions to model
VeckoTheGecko 915b0cb
Fix typing
VeckoTheGecko 7787c86
Add FieldSet.models
VeckoTheGecko af1dbf5
Move "time_interval" to model
VeckoTheGecko 035bd3f
Update Model ABC
VeckoTheGecko 9e24a14
Update Field init to take model
VeckoTheGecko b69402a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 915b0b6
Add XGCM adapter
VeckoTheGecko 9685ddf
Remove xgcm constructors
VeckoTheGecko 23bc2d4
Update _transpose_xfield_data_to_tzyx to work with SGRID metadata
VeckoTheGecko 82f2001
Define SGRID data pre-processing
VeckoTheGecko f222d4b
Create grid object within StructuredModel
VeckoTheGecko 108d3b2
Allow for time dimension size 1
VeckoTheGecko 065c96d
Disable assert_all_field_dims_have_axis check
VeckoTheGecko 538477d
New interpolator API
VeckoTheGecko f1799ac
Update interpolators to use new API
VeckoTheGecko 1345f6e
Enable adding of fieldsets
VeckoTheGecko b87fae4
Add assert_compatible_fieldsets
VeckoTheGecko 13644ea
Fix test suite
VeckoTheGecko 5569031
Define how to set interpolators
VeckoTheGecko 54674c6
Fix test suite
VeckoTheGecko f2ef7ce
Merge
VeckoTheGecko 18e76d8
Update test suite
VeckoTheGecko bb8c0a5
Fix test suite
VeckoTheGecko a55fe97
Enable constant field tests
VeckoTheGecko 0b2b147
Disable reprs
VeckoTheGecko 2eaf144
Refactor constant field logic to use dedicated model
VeckoTheGecko cc52d2c
Fix constant field logic
VeckoTheGecko 059f5c5
Enable unstructured tests
VeckoTheGecko d896ad6
Update unstructured grid interpolators
VeckoTheGecko 73077b7
Update unstructured FieldSet ingestion in tests
VeckoTheGecko 941c9b1
Update comments
VeckoTheGecko ad7eb5a
Fix test_time1D_field
VeckoTheGecko f278ed9
Merge remote-tracking branch 'upstream/main' into restructure
VeckoTheGecko da46ba7
Update test after merge with main
VeckoTheGecko 7173dc0
Fixes after merge
VeckoTheGecko cc45cd1
Review feedback
VeckoTheGecko 863328e
Update ValueError messages
VeckoTheGecko 5646f15
Add TODO
VeckoTheGecko 6057468
Review feedback
VeckoTheGecko 9e4e34c
Remove zero interpolators
VeckoTheGecko 9519503
Fix typing error
VeckoTheGecko 4e02d1f
Merge with upstream/main
VeckoTheGecko b025d6d
Rename Model to ModelData
VeckoTheGecko File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,12 +1,11 @@ | ||
| from __future__ import annotations | ||
|
|
||
| import warnings | ||
| from collections.abc import Callable, Sequence | ||
| from collections.abc import Sequence | ||
| from datetime import datetime | ||
| from typing import TYPE_CHECKING | ||
|
|
||
| import numpy as np | ||
| import uxarray as ux | ||
| import xarray as xr | ||
|
|
||
| from parcels._core.index_search import GRID_SEARCH_ERROR, LEFT_OUT_OF_BOUNDS, RIGHT_OUT_OF_BOUNDS, _search_time_index | ||
| from parcels._core.particlesetview import ParticleSetView | ||
|
|
@@ -15,16 +14,14 @@ | |
| StatusCode, | ||
| ) | ||
| from parcels._core.utils.string import _assert_str_and_python_varname | ||
| from parcels._core.utils.time import TimeInterval | ||
| from parcels._core.uxgrid import UxGrid | ||
| from parcels._core.xgrid import XGrid, _transpose_xfield_data_to_tzyx, assert_all_field_dims_have_axis | ||
| from parcels._python import assert_same_function_signature | ||
| from parcels._reprs import field_repr, vectorfield_repr | ||
| from parcels._core.xgrid import XGrid | ||
| from parcels._typing import VectorType | ||
| from parcels.interpolators import ( | ||
| ZeroInterpolator, | ||
| ZeroInterpolator_Vector, | ||
| ) | ||
| from parcels.interpolators._base import ScalarInterpolator, VectorInterpolator | ||
|
|
||
| if TYPE_CHECKING: | ||
| from parcels._core.model import ModelData | ||
|
|
||
|
|
||
| __all__ = ["Field", "VectorField"] | ||
|
|
||
|
|
@@ -86,69 +83,51 @@ class Field: | |
| def __init__( | ||
| self, | ||
| name: str, | ||
| data: xr.DataArray | ux.UxDataArray, | ||
| grid: UxGrid | XGrid, | ||
| interp_method: Callable, | ||
| model: ModelData, | ||
| ): | ||
| if not isinstance(data, (ux.UxDataArray, xr.DataArray)): | ||
| raise ValueError( | ||
| f"Expected `data` to be a uxarray.UxDataArray or xarray.DataArray object, got {type(data)}." | ||
| ) | ||
| # TODO PR: Enable isinstance check once ModelData is moved to abc.ModelData | ||
| # if not isinstance(model, "ModelData"): | ||
| # raise ValueError( | ||
| # f"Expected `model` to be a parcels ModelData object. Got {type(model)}." | ||
| # ) | ||
|
|
||
| _assert_str_and_python_varname(name) | ||
|
|
||
| if not isinstance(grid, (UxGrid, XGrid)): | ||
| raise ValueError(f"Expected `grid` to be a parcels UxGrid, or parcels XGrid object, got {type(grid)}.") | ||
|
|
||
| _assert_compatible_combination(data, grid) | ||
|
|
||
| if isinstance(grid, XGrid): | ||
| assert_all_field_dims_have_axis(data, grid.xgcm_grid) | ||
| data = _transpose_xfield_data_to_tzyx(data, grid.xgcm_grid) | ||
|
|
||
| self.name = name | ||
| self.data = data | ||
| self.grid = grid | ||
|
|
||
| try: | ||
| self.time_interval = _get_time_interval(data) | ||
| except ValueError as e: | ||
| e.add_note( | ||
| f"Error getting time interval for field {name!r}. Are you sure that the time dimension on the xarray dataset is stored as timedelta, datetime or cftime datetime objects?" | ||
| ) | ||
| raise e | ||
| self.model = model | ||
|
|
||
| try: | ||
| if isinstance(data, ux.UxDataArray): | ||
| _assert_valid_uxdataarray(data) | ||
| # TODO: For unstructured grids, validate that `data.uxgrid` is the same as `grid` | ||
| else: | ||
| pass # TODO v4: Add validation for xr.DataArray objects | ||
| except Exception as e: | ||
| e.add_note(f"Error validating field {name!r}.") | ||
| raise e | ||
| self.igrid = -1 # Default the grid index to -1 | ||
|
|
||
| # Setting the interpolation method dynamically | ||
| assert_same_function_signature(interp_method, ref=ZeroInterpolator, context="Interpolation") | ||
| self._interp_method = interp_method | ||
| @property | ||
| def data(self): | ||
| return self.model.data[self.name] | ||
|
|
||
| self.igrid = -1 # Default the grid index to -1 | ||
| @property | ||
| def grid(self): # TODO PR: Remove in favour of referencing model grid directly | ||
| return self.model.grid | ||
|
|
||
| if self.data.shape[0] > 1: | ||
| if "time" not in self.data.coords: | ||
| raise ValueError("Field data is missing a 'time' coordinate.") | ||
| @property | ||
| def time_interval(self): # TODO PR: Remove in favour of referencing model time_interval directly | ||
| return self.model.time_interval | ||
|
|
||
| def __repr__(self): | ||
| return field_repr(self) | ||
| return f"Field(name={self.name}, model={self.model})" | ||
|
|
||
| @property | ||
| def interp_method(self): | ||
| return self._interp_method | ||
| try: | ||
| return self.model.field_to_interpolator[self.name] | ||
| except KeyError as e: | ||
| raise AttributeError( | ||
| f"{type(self).__name__} doesn't have an interp_method defined for it. Use `.interp_method = ...`" | ||
| ) from e | ||
|
|
||
| @interp_method.setter | ||
| def interp_method(self, method: Callable): | ||
| assert_same_function_signature(method, ref=ZeroInterpolator, context="Interpolation") | ||
|
VeckoTheGecko marked this conversation as resolved.
|
||
| self._interp_method = method | ||
| def interp_method(self, value): | ||
| # Setting the interpolation method dynamically | ||
| if not isinstance(value, ScalarInterpolator): | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I like this distinction between ScalarInterpolators and VectorInterpolators! |
||
| raise ValueError(f"interp_method must be a `ScalarInterpolator` object. Got {type(value)=!r}") | ||
| self.model.field_to_interpolator[self.name] = value | ||
|
|
||
| def _check_velocitysampling(self): | ||
| if self.name in ["U", "V", "W"]: | ||
|
|
@@ -193,7 +172,7 @@ def eval(self, time: datetime, z, y, x, particles=None): | |
|
|
||
| particle_positions, grid_positions = _get_positions(self, time, z, y, x, particles, _ei) | ||
|
|
||
| value = self._interp_method(particle_positions, grid_positions, self) | ||
| value = self.interp_method.interp(particle_positions, grid_positions, self) | ||
|
|
||
| _update_particle_states_interp_value(particles, value) | ||
|
|
||
|
|
@@ -219,7 +198,7 @@ def __init__( | |
| U: Field, # noqa: N803 | ||
| V: Field, # noqa: N803 | ||
| W: Field | None = None, # noqa: N803 | ||
| interp_method: Callable | None = None, | ||
| interp_method: VectorInterpolator | None = None, | ||
| ): | ||
| if interp_method is None: | ||
| raise ValueError("interp_method must be provided for VectorField initialization.") | ||
|
|
@@ -244,19 +223,22 @@ def __init__( | |
| else: | ||
| self.vector_type = "2D" | ||
|
|
||
| assert_same_function_signature(interp_method, ref=ZeroInterpolator_Vector, context="Interpolation") | ||
| if not isinstance(interp_method, VectorInterpolator): | ||
| raise ValueError(f"interp_method must be a `VectorInterpolator` object. Got {type(interp_method)=!r}") | ||
|
|
||
| self._interp_method = interp_method | ||
|
|
||
| def __repr__(self): | ||
| return vectorfield_repr(self) | ||
| # def __repr__(self): | ||
| # return vectorfield_repr(self) | ||
|
|
||
| @property | ||
| def interp_method(self): | ||
| return self._interp_method | ||
|
|
||
| @interp_method.setter | ||
| def interp_method(self, method: Callable): | ||
| assert_same_function_signature(method, ref=ZeroInterpolator_Vector, context="Interpolation") | ||
| def interp_method(self, method: VectorInterpolator): | ||
| if not isinstance(method, VectorInterpolator): | ||
| raise ValueError(f"method must be a `VectorInterpolator` object. Got {type(method)=!r}") | ||
| self._interp_method = method | ||
|
|
||
| def eval(self, time: datetime, z, y, x, particles=None): | ||
|
|
@@ -295,7 +277,7 @@ def eval(self, time: datetime, z, y, x, particles=None): | |
|
|
||
| particle_positions, grid_positions = _get_positions(self.U, time, z, y, x, particles, _ei) | ||
|
|
||
| (u, v, w) = self._interp_method(particle_positions, grid_positions, self) | ||
| (u, v, w) = self._interp_method.interp(particle_positions, grid_positions, self) | ||
|
|
||
| for vel in (u, v, w): | ||
| _update_particle_states_interp_value(particles, vel) | ||
|
|
@@ -375,44 +357,6 @@ def _update_particle_states_interp_value(particles, value): | |
| ) | ||
|
|
||
|
|
||
| def _assert_valid_uxdataarray(data: ux.UxDataArray): | ||
| """Verifies that all the required attributes are present in the xarray.DataArray or | ||
| uxarray.UxDataArray object. | ||
| """ | ||
| # Validate dimensions | ||
| if not ("zf" in data.dims or "zc" in data.dims): | ||
| raise ValueError( | ||
| "Field is missing a 'zf' or 'zc' dimension in the field's metadata. " | ||
| "This attribute is required for xarray.DataArray objects." | ||
| ) | ||
|
|
||
| if "time" not in data.dims: | ||
| raise ValueError( | ||
| "Field is missing a 'time' dimension in the field's metadata. " | ||
| "This attribute is required for xarray.DataArray objects." | ||
| ) | ||
|
|
||
|
|
||
| def _assert_compatible_combination(data: xr.DataArray | ux.UxDataArray, grid: UxGrid | XGrid): | ||
| if isinstance(data, ux.UxDataArray): | ||
| if not isinstance(grid, UxGrid): | ||
| raise ValueError( | ||
| f"Incompatible data-grid combination. Data is a uxarray.UxDataArray, expected `grid` to be a UxGrid object, got {type(grid)}." | ||
| ) | ||
| elif isinstance(data, xr.DataArray): | ||
| if not isinstance(grid, XGrid): | ||
| raise ValueError( | ||
| f"Incompatible data-grid combination. Data is a xarray.DataArray, expected `grid` to be a parcels Grid object, got {type(grid)}." | ||
| ) | ||
|
|
||
|
|
||
| def _get_time_interval(data: xr.DataArray | ux.UxDataArray) -> TimeInterval | None: | ||
| if data.shape[0] == 1: | ||
| return None | ||
|
|
||
| return TimeInterval(data.time.values[0], data.time.values[-1]) | ||
|
|
||
|
|
||
| def _assert_same_time_interval(fields: Sequence[Field]) -> None: | ||
| if len(fields) == 0: | ||
| return | ||
|
|
||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why this choice of -1? Don't we run the same risk of accidentally wrapping negative indices as in #2629?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This was already in
main.AFAICT the caching of ei isn't even working (searching for
.igrid =gives no meaningful results). I already suspected this was the case, hence why I was mentioning before that we should have explicit tests for this as part of our release roadmap.For another issue/pr
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm.. I'll look into this grid index caching issue today.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Awesome - thanks for picking this up @fluidnumericsJoe !