Skip to content

Commit

Permalink
Working selector class before api improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
mpiannucci committed Jul 30, 2024
1 parent abb551a commit 3f0b2d9
Show file tree
Hide file tree
Showing 6 changed files with 734 additions and 726 deletions.
1,214 changes: 534 additions & 680 deletions examples/fvcom.ipynb

Large diffs are not rendered by default.

48 changes: 42 additions & 6 deletions xarray_subset_grid/grids/regular_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,51 @@
import xarray as xr

from xarray_subset_grid.grid import Grid
from xarray_subset_grid.selector import Selector
from xarray_subset_grid.utils import (
normalize_bbox_x_coords,
normalize_polygon_x_coords,
ray_tracing_numpy,
)


class RegularGridPolygonSelector(Selector):
"""Selector for regular lat/lng grids"""

polygon: list[tuple[float, float]] | np.ndarray
_polygon_mask: xr.DataArray

def __init__(self, polygon: list[tuple[float, float]] | np.ndarray, mask: xr.DataArray):
super().__init__()
self.polygon_mask = mask

def select(self, ds: xr.Dataset) -> xr.Dataset:
"""Perform the selection on the dataset"""
ds_subset = ds.cf.isel(
lon=self._polygon_mask,
lat=self._polygon_mask,
)
return ds_subset


class RegularGridBBoxSelector(Selector):
"""Selector for regular lat/lng grids"""

bbox: tuple[float, float, float, float]
_longitude_selection: slice
_latitude_selection: slice

def __init__(self, bbox: tuple[float, float, float, float]):
super().__init__()
self.bbox = bbox
self._longitude_selection = slice(bbox[0], bbox[2])
self._latitude_selection = slice(bbox[1], bbox[3])

def select(self, ds: xr.Dataset) -> xr.Dataset:
"""Perform the selection on the dataset"""
ds.cf.sel(lon=self._longitude_selection, lat=self._latitude_selection)


class RegularGrid(Grid):
"""Grid implementation for regular lat/lng grids"""

Expand Down Expand Up @@ -72,11 +110,8 @@ def subset_polygon(
polygon = normalize_polygon_x_coords(x, polygon)
polygon_mask = ray_tracing_numpy(x, lat.flat, polygon).reshape(lon.shape)

ds_subset = ds.cf.isel(
lon=polygon_mask,
lat=polygon_mask,
)
return ds_subset
selector = RegularGridPolygonSelector(polygon, polygon_mask)
return selector.select(ds)

def subset_bbox(self, ds: xr.Dataset, bbox: tuple[float, float, float, float]) -> xr.Dataset:
"""Subset the dataset to the bounding box
Expand All @@ -85,4 +120,5 @@ def subset_bbox(self, ds: xr.Dataset, bbox: tuple[float, float, float, float]) -
:return: The subsetted dataset
"""
bbox = normalize_bbox_x_coords(ds.cf["longitude"].values, bbox)
return ds.cf.sel(lon=slice(bbox[0], bbox[2]), lat=slice(bbox[1], bbox[3]))
selector = RegularGridBBoxSelector(bbox)
return selector.select(ds)
49 changes: 35 additions & 14 deletions xarray_subset_grid/grids/regular_grid_2d.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,38 @@
import numpy as np
import xarray as xr

from xarray_subset_grid.grid import Grid
from xarray_subset_grid.selector import Selector
from xarray_subset_grid.utils import compute_2d_subset_mask


class RegularGrid2dSelector(Selector):
polygon: list[tuple[float, float]] | np.ndarray
_subset_mask: xr.DataArray

def __init__(self, polygon: list[tuple[float, float]] | np.ndarray, subset_mask: xr.DataArray):
super().__init__()
self.polygon = polygon
self._subset_mask = subset_mask

def select(self, ds: xr.Dataset) -> xr.Dataset:
# First, we need to add the mask as a variable in the dataset
# so that we can use it to mask and drop via xr.where, which requires that
# the mask and data have the same shape and both are DataArrays with matching
# dimensions
ds_subset = ds.assign(subset_mask=self._subset_mask)

# Now we can use the mask to subset the data
ds_subset = ds_subset.where(ds_subset.subset_mask, drop=True).drop_encoding()
ds_subset.drop_vars("subset_mask")
return ds_subset


class RegularGrid2d(Grid):
"""Grid implementation for 2D regular grids"""

@staticmethod
def recognize(ds):
def recognize(ds) -> bool:
"""Recognize if the dataset matches the given grid"""
lat = ds.cf.coordinates.get("latitude", None)
lon = ds.cf.coordinates.get("longitude", None)
Expand All @@ -20,11 +46,11 @@ def recognize(ds):
return lat_dim == lon_dim and ndim == 2

@property
def name(self):
def name(self) -> str:
"""Name of the grid type"""
return "regular_grid_2d"

def grid_vars(self, ds):
def grid_vars(self, ds: xr.Dataset) -> set[str]:
"""Set of grid variables
These variables are used to define the grid and thus should be kept
Expand All @@ -34,7 +60,7 @@ def grid_vars(self, ds):
lon = ds.cf.coordinates["longitude"][0]
return {lat, lon}

def data_vars(self, ds):
def data_vars(self, ds: xr.Dataset) -> set[str]:
"""Set of data variables
These variables exist on the grid and are available to used for
Expand All @@ -51,7 +77,9 @@ def data_vars(self, ds):
and "longitude" in var.cf.coordinates
}

def subset_polygon(self, ds, polygon):
def subset_polygon(
self, ds: xr.Dataset, polygon: list[tuple[float, float]] | np.ndarray
) -> xr.Dataset:
"""Subset the dataset to the grid
:param ds: The dataset to subset
:param polygon: The polygon to subset to
Expand All @@ -61,12 +89,5 @@ def subset_polygon(self, ds, polygon):
lon = ds.cf["longitude"]
subset_mask = compute_2d_subset_mask(lat=lat, lon=lon, polygon=polygon)

# First, we need to add the mask as a variable in the dataset
# so that we can use it to mask and drop via xr.where, which requires that
# the mask and data have the same shape and both are DataArrays with matching
# dimensions
ds_subset = ds.assign(subset_mask=subset_mask)

# Now we can use the mask to subset the data
ds_subset = ds_subset.where(ds_subset.subset_mask, drop=True).drop_encoding()
return ds_subset
selector = RegularGrid2dSelector(polygon=polygon, subset_mask=subset_mask)
return selector.select(ds)
63 changes: 44 additions & 19 deletions xarray_subset_grid/grids/sgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,49 @@
import xarray as xr

from xarray_subset_grid.grid import Grid
from xarray_subset_grid.selector import Selector
from xarray_subset_grid.utils import compute_2d_subset_mask


class SGridSelector(Selector):
polygon: list[tuple[float, float]] | np.ndarray

_grid_topology_key: str
_grid_topology: xr.DataArray
_subset_masks: list[tuple[list[str], xr.DataArray]]

def __init__(
self,
polygon: list[tuple[float, float]] | np.ndarray,
subset_masks: list[tuple[list[str], xr.DataArray]],
):
super().__init__()
self.polygon = polygon
self._subset_masks = subset_masks

def select(self, ds: xr.Dataset) -> xr.Dataset:
ds_out = []
for mask in self._subset_masks:
# First, we need to add the mask as a variable in the dataset
# so that we can use it to mask and drop via xr.where, which requires that
# the mask and data have the same shape and both are DataArrays with matching
# dimensions
ds_subset = ds.assign(subset_mask=mask[1])

# Now we can use the mask to subset the data
ds_subset = ds_subset[mask[0]].where(ds_subset.subset_mask, drop=True).drop_encoding()
ds_subset = ds_subset.drop_vars("subset_mask")

# Add the subsetted dataset to the list for merging
ds_out.append(ds_subset)

# Merge the subsetted datasets
ds_out = xr.merge(ds_out)

ds_out = ds_out.assign({self._grid_topology_key: self._grid_topology})
return ds_out


class SGrid(Grid):
"""Grid implementation for SGRID datasets"""

Expand Down Expand Up @@ -66,8 +106,7 @@ def subset_polygon(
grid_topology = ds[grid_topology_key]
dims = _get_sgrid_dim_coord_names(grid_topology)

ds_out = []

subset_masks: list[tuple[list[str], xr.DataArray]] = []
for dim, coord in dims:
# Get the variables that have the dimensions
unique_dims = set(dim)
Expand All @@ -92,24 +131,10 @@ def subset_polygon(

subset_mask = compute_2d_subset_mask(lat=lat, lon=lon, polygon=polygon)

# First, we need to add the mask as a variable in the dataset
# so that we can use it to mask and drop via xr.where, which requires that
# the mask and data have the same shape and both are DataArrays with matching
# dimensions
ds_subset = ds.assign(subset_mask=subset_mask)

# Now we can use the mask to subset the data
ds_subset = ds_subset[vars].where(ds_subset.subset_mask, drop=True).drop_encoding()
subset_masks.append((vars, subset_mask))

# Add the subsetted dataset to the list for merging
ds_out.append(ds_subset)

# Merge the subsetted datasets
ds_out = xr.merge(ds_out)

ds_out = ds_out.assign({grid_topology_key: grid_topology})

return ds_out
selector = SGridSelector(polygon=polygon, subset_masks=subset_masks)
return selector.select(ds)


def _get_sgrid_dim_coord_names(
Expand Down
79 changes: 72 additions & 7 deletions xarray_subset_grid/grids/ugrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import xarray as xr

from xarray_subset_grid.grid import Grid
from xarray_subset_grid.selector import Selector
from xarray_subset_grid.utils import (
normalize_polygon_x_coords,
ray_tracing_numpy,
Expand All @@ -22,6 +23,62 @@
)


class UGridSelector(Selector):
polygon: list[tuple[float, float]] | np.ndarray

_node_dimension: str
_selected_nodes: np.ndarray

_face_dimension: str
_selected_elements: np.ndarray

_face_node_connectivity_key: str
_face_node_connectivity: np.ndarray

_face_face_connectivity_key: str | None
_face_face_connectivity: np.ndarray | None

def __init__(
self,
polygon: list[tuple[float, float]] | np.ndarray,
node_dimension: str,
selected_nodes: np.ndarray,
face_dimension: str,
selected_elements: np.ndarray,
face_node_connectivity_key: str,
face_node_connectivity: np.ndarray,
face_face_connectivity_key: str | None = None,
face_face_connectivity: np.ndarray | None = None,
):
super().__init__()
self.polygon = polygon
self._node_dimension = node_dimension
self._selected_nodes = selected_nodes
self._face_dimension = face_dimension
self._selected_elements = selected_elements
self._face_node_connectivity_key = face_node_connectivity_key
self._face_node_connectivity = face_node_connectivity
self._face_face_connectivity_key = face_face_connectivity_key
self._face_face_connectivity = face_face_connectivity

def select(self, ds: xr.Dataset) -> xr.Dataset:
# Subset using xarrays select indexing, and overwrite the face_node_connectivity
# and face_face_connectivity (if available) with the new indices
ds_subset = ds.sel(
{
self._node_dimension: self._selected_nodes,
self._face_dimension: self._selected_elements,
}
)
ds_subset[self._face_node_connectivity_key][:] = self._face_node_connectivity
if (
self._face_face_connectivity is not None
and self._face_face_connectivity_key is not None
):
ds_subset[self._face_face_connectivity_key][:] = self._face_face_connectivity
return ds_subset


class UGrid(Grid):
"""Grid implementation for UGRID datasets
Expand Down Expand Up @@ -202,13 +259,21 @@ def subset_polygon(
if transpose_face_face_connectivity:
face_face_new = face_face_new.T

# Subset using xarrays select indexing, and overwrite the face_node_connectivity
# and face_face_connectivity (if available) with the new indices
ds_subset = ds.sel({node_dimension: selected_nodes, face_dimension: selected_elements})
ds_subset[mesh.face_node_connectivity][:] = face_node_new
if has_face_face_connectivity:
ds_subset[mesh.face_face_connectivity][:] = face_face_new
return ds_subset
selector = UGridSelector(
polygon=polygon,
node_dimension=node_dimension,
selected_nodes=selected_nodes,
face_dimension=face_dimension,
selected_elements=selected_elements,
face_node_connectivity_key=mesh.face_node_connectivity,
face_node_connectivity=face_node_new,
face_face_connectivity_key=mesh.face_face_connectivity
if has_face_face_connectivity
else None,
face_face_connectivity=face_face_new if has_face_face_connectivity else None,
)

return selector.select(ds)


def assign_ugrid_topology(
Expand Down
7 changes: 7 additions & 0 deletions xarray_subset_grid/selector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import xarray as xr


class Selector:
def select(self, ds: xr.Dataset) -> xr.Dataset:
"""Perform the selection on the dataset"""
return ds

0 comments on commit 3f0b2d9

Please sign in to comment.