Skip to content

Commit

Permalink
ROMS updates (#38)
Browse files Browse the repository at this point in the history
* Apply roms masking

* Run black

* Fix GFI for roms

* Maks out GFI requests

* Update typing

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
mpiannucci and pre-commit-ci[bot] authored Oct 31, 2023
1 parent c265cd1 commit e843618
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 47 deletions.
59 changes: 33 additions & 26 deletions xpublish_wms/grid.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Optional, Tuple
from typing import Any, Optional, Tuple, Union

import cartopy.geodesic
import cf_xarray # noqa
Expand Down Expand Up @@ -91,6 +91,13 @@ def select_by_elevation(
da = da.cf.sel({"vertical": elevation}, method="nearest")
return da

def mask(
self,
da: Union[xr.DataArray, xr.Dataset],
) -> Union[xr.DataArray, xr.Dataset]:
"""Mask the given data array"""
return da

@abstractmethod
def project(self, da: xr.DataArray, crs: str) -> Any:
"""Project the given data array from this dataset and grid to the given crs"""
Expand Down Expand Up @@ -148,7 +155,6 @@ def __init__(self, ds: xr.Dataset):

@staticmethod
def recognize(ds: xr.Dataset) -> bool:
print(ds.cf.cf_roles)
return "grid_topology" in ds.cf.cf_roles

@property
Expand All @@ -163,6 +169,18 @@ def render_method(self) -> RenderMethod:
def crs(self) -> str:
return "EPSG:4326"

def mask(
self,
da: Union[xr.DataArray, xr.Dataset],
) -> Union[xr.DataArray, xr.Dataset]:
mask = self.ds[f'mask_{da.cf["latitude"].name.split("_")[1]}']
mask = mask.cf.isel(time=0).squeeze(drop=True).cf.drop_vars("time")
mask[:-1, :] = mask[:-1, :].where(mask[1:, :] == 1, 0)
mask[:, :-1] = mask[:, :-1].where(mask[:, 1:] == 1, 0)
mask[1:, :] = mask[1:, :].where(mask[:-1, :] == 1, 0)
mask[:, 1:] = mask[:, 1:].where(mask[:, :-1] == 1, 0)
return da.where(mask == 1)

def project(self, da: xr.DataArray, crs: str) -> xr.DataArray:
if crs == "EPSG:4326":
da = da.assign_coords({"x": da.cf["longitude"], "y": da.cf["latitude"]})
Expand All @@ -187,6 +205,7 @@ def project(self, da: xr.DataArray, crs: str) -> xr.DataArray:
)

da = da.unify_chunks()

return da

def sel_lat_lng(
Expand All @@ -196,7 +215,7 @@ def sel_lat_lng(
lat,
parameters,
) -> Tuple[xr.Dataset, list, list]:
topology = self.ds.cf["grid_topology"]
topology = self.ds.grid

merged_ds = None
x_axis = None
Expand All @@ -207,6 +226,7 @@ def sel_lat_lng(
lng_coord, lat_coord = topology.attrs[f"{grid_location}_coordinates"].split(
" ",
)

new_selected_ds = sel2d(
subset,
lons=subset.cf[lng_coord],
Expand Down Expand Up @@ -438,7 +458,7 @@ def grid_factory(ds: xr.Dataset) -> Optional[Grid]:
return None


@xr.register_dataset_accessor("grid")
@xr.register_dataset_accessor("gridded")
class GridDatasetAccessor:
_ds: xr.Dataset
_grid: Optional[Grid]
Expand Down Expand Up @@ -515,6 +535,15 @@ def select_by_elevation(
else:
return self._grid.select_by_elevation(da, elevation)

def mask(
self,
da: Union[xr.DataArray, xr.Dataset],
) -> Union[xr.DataArray, xr.Dataset]:
if self._grid is None:
return None
else:
return self._grid.mask(da)

def project(self, da: xr.DataArray, crs: str) -> xr.DataArray:
if self._grid is None:
return None
Expand All @@ -540,28 +569,6 @@ def sel_lat_lng(
return self._grid.sel_lat_lng(subset, lng, lat, parameters)


class GridType(Enum):
REGULAR = 1
NON_DIMENSIONAL = 2
SGRID = 3
UNSUPPORTED = 255

@classmethod
def from_ds(cls, ds: xr.Dataset):
if "grid_topology" in ds.cf.cf_roles:
return cls.SGRID

try:
if len(ds.cf["latitude"].dims) > 1:
return cls.NON_DIMENSIONAL
elif "latitude" in ds.cf["latitude"].dims:
return cls.REGULAR
except Exception:
return cls.UNSUPPORTED

return cls.UNSUPPORTED


def argsel2d(lons, lats, lon0, lat0):
"""Find the indices of coordinate pair closest to another point.
Adapted from https://github.com/xoceanmodel/xroms/blob/main/xroms/utilities.py which is failing to run for some reason
Expand Down
8 changes: 4 additions & 4 deletions xpublish_wms/wms/get_capabilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def get_capabilities(ds: xr.Dataset, request: Request, query_params: dict) -> Re
# can have a different bbox
da.cf["latitude"].persist()
da.cf["longitude"].persist()
bbox = ds.grid.bbox(da)
bbox = ds.gridded.bbox(da)
bounds = {
crs_tag: "EPSG:4326",
"minx": f"{bbox[0]}",
Expand Down Expand Up @@ -220,14 +220,14 @@ def get_capabilities(ds: xr.Dataset, request: Request, query_params: dict) -> Re
# TODO: Add ISO duration specifier
time_dimension_element.text = f"{','.join(times)}"

if ds.grid.has_elevation(da):
elevations_values = ds.grid.elevations(da).persist()
if ds.gridded.has_elevation(da):
elevations_values = ds.gridded.elevations(da).persist()
default_elevation_index = np.abs(elevations_values).argmin().values
default_elevation = elevations_values[default_elevation_index].values.round(
5,
)
elevations = [f"{e}" for e in elevations_values.values.round(5)]
elevation_units = ds.grid.elevation_units(da)
elevation_units = ds.gridded.elevation_units(da)
elevation_dimension_element = ET.SubElement(
layer,
"Dimension",
Expand Down
8 changes: 5 additions & 3 deletions xpublish_wms/wms/get_feature_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def get_feature_info(ds: xr.Dataset, query: dict) -> Response:
else:
elevation = []
has_vertical_axis = [
ds.grid.has_elevation(ds[parameter]) for parameter in parameters
ds.gridded.has_elevation(ds[parameter]) for parameter in parameters
]
any_has_vertical_axis = True in has_vertical_axis

Expand Down Expand Up @@ -169,14 +169,16 @@ def get_feature_info(ds: xr.Dataset, query: dict) -> Response:
selected_ds = selected_ds.cf.sel(vertical=0, method="nearest")

try:
selected_ds, x_axis, y_axis = ds.grid.sel_lat_lng(
# Apply masking if necessary
selected_ds = ds.gridded.mask(selected_ds)
selected_ds, x_axis, y_axis = ds.gridded.sel_lat_lng(
subset=selected_ds,
lng=x_coord[x],
lat=y_coord[y],
parameters=parameters,
)
except ValueError:
raise HTTPException(500, f"Unsupported grid type: {ds.grid.name}")
raise HTTPException(500, f"Unsupported grid type: {ds.gridded.name}")

# When none of the parameters have data, drop it
time_coord_name = selected_ds.cf.coordinates["time"][0]
Expand Down
16 changes: 7 additions & 9 deletions xpublish_wms/wms/get_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import xarray as xr
from fastapi.responses import StreamingResponse

from xpublish_wms.grid import GridType, RenderMethod
from xpublish_wms.grid import RenderMethod

logger = logging.getLogger("uvicorn")

Expand All @@ -41,7 +41,6 @@ class GetMap:
has_elevation: bool

# Grid
grid_type: GridType
crs: str
bbox = List[float]
width: int
Expand Down Expand Up @@ -114,8 +113,6 @@ def ensure_query_types(self, ds: xr.Dataset, query: dict):
:param query:
:return:
"""
self.grid_type = GridType.from_ds(ds)

# Data selection
self.parameter = query["layers"]
self.time_str = query.get("time", None)
Expand Down Expand Up @@ -202,7 +199,7 @@ def select_elevation(self, ds: xr.Dataset, da: xr.DataArray) -> xr.DataArray:
:param da:
:return:
"""
da = ds.grid.select_by_elevation(da, self.elevation)
da = ds.gridded.select_by_elevation(da, self.elevation)
print(da.shape)

return da
Expand Down Expand Up @@ -243,7 +240,8 @@ def render(
# TODO: FVCOM and other grids
# return self.render_quad_grid(da, buffer, minmax_only)
projection_start = time.time()
da = ds.grid.project(da, self.crs)
da = ds.gridded.mask(da)
da = ds.gridded.project(da, self.crs)
logger.debug(f"Projection time: {time.time() - projection_start}")

if minmax_only:
Expand Down Expand Up @@ -284,14 +282,14 @@ def render(
y_range=(self.bbox[1], self.bbox[3]),
)

if ds.grid.render_method == RenderMethod.Quad:
if ds.gridded.render_method == RenderMethod.Quad:
mesh = cvs.quadmesh(
da,
x="x",
y="y",
)
elif ds.grid.render_method == RenderMethod.Triangle:
triangles = ds.grid.tessellate(da)
elif ds.gridded.render_method == RenderMethod.Triangle:
triangles = ds.gridded.tessellate(da)
verts = pd.DataFrame({"x": da.x, "y": da.y, "z": da})
tris = pd.DataFrame(triangles.astype(int), columns=["v0", "v1", "v2"])

Expand Down
10 changes: 5 additions & 5 deletions xpublish_wms/wms/get_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,11 +96,11 @@ def get_layer_details(ds: xr.Dataset, layer_name: str) -> dict:
da = ds[layer_name]
units = da.attrs.get("units", "")
supported_styles = "raster" # TODO: more styles
bbox = ds.grid.bbox(da)
if ds.grid.has_elevation(da):
elevation = ds.grid.elevations(da).values.round(5).tolist()
elevation_positive = ds.grid.elevation_positive_direction(da)
elevation_units = ds.grid.elevation_units(da)
bbox = ds.gridded.bbox(da)
if ds.gridded.has_elevation(da):
elevation = ds.gridded.elevations(da).values.round(5).tolist()
elevation_positive = ds.gridded.elevation_positive_direction(da)
elevation_units = ds.gridded.elevation_units(da)
else:
elevation = None
elevation_positive = None
Expand Down

0 comments on commit e843618

Please sign in to comment.