Skip to content

Commit

Permalink
Merge branch 'main' of github.com:xpublish-community/xpublish-wms
Browse files Browse the repository at this point in the history
  • Loading branch information
mpiannucci committed Jul 27, 2023
2 parents 7f533b7 + 263d2ef commit 64c82da
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 24 deletions.
9 changes: 6 additions & 3 deletions xpublish_wms/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,19 @@

class GridType(Enum):
REGULAR = 1
SGRID = 2
NON_DIMENSIONAL = 2
SGRID = 3
UNSUPPORTED = 255

@classmethod
def from_ds(cls, ds: xr.Dataset):
if "grid" in ds.cf.cf_roles["grid_topology"]:
if "grid_topology" in ds.cf.cf_roles:
return cls.SGRID

try:
if "latitude" in ds.cf["latitude"].dims:
if len(ds.cf["latitude"].dims) > 1:
return cls.NON_DIMENSIONAL
elif "latitude" in ds.cf["latitude"].dims:
return cls.REGULAR
except Exception:
return cls.UNSUPPORTED
Expand Down
67 changes: 46 additions & 21 deletions xpublish_wms/wms/get_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,18 @@

import cachey
import cf_xarray # noqa
import dask.array as dask_array
import datashader as dsh
import datashader.transfer_functions as tf
import matplotlib.cm as cm
import numpy as np
import pandas as pd
import rioxarray # noqa
import xarray as xr
from fastapi.responses import StreamingResponse

from xpublish_wms.grid import GridType
from xpublish_wms.utils import to_lnglat
from xpublish_wms.utils import to_mercator

logger = logging.getLogger("uvicorn")

Expand Down Expand Up @@ -255,30 +257,53 @@ def render_quad_grid(
) -> Union[bool, dict]:
"""
Render the data array into an image buffer when the dataset is using a
regular or staggered (ala ROMS) grid
2d grid
:param da:
:return:
"""

projection_start = time.time()
if self.crs == "EPSG:3857":
bbox_lng, bbox_lat = to_lnglat.transform(
[self.bbox[0], self.bbox[2]],
[self.bbox[1], self.bbox[3]],
)
bbox_ll = [*bbox_lng, *bbox_lat]
if (
self.grid_type == GridType.NON_DIMENSIONAL
or self.grid_type == GridType.SGRID
):
x, y = to_mercator.transform(da.cf["longitude"], da.cf["latitude"])
x_chunks = (
da.cf["longitude"].chunks if da.cf["longitude"].chunks else x.shape
)
y_chunks = (
da.cf["latitude"].chunks if da.cf["latitude"].chunks else y.shape
)

da = da.assign_coords(
{
"x": (
da.cf["longitude"].dims,
dask_array.from_array(x, chunks=x_chunks),
),
"y": (
da.cf["latitude"].dims,
dask_array.from_array(y, chunks=y_chunks),
),
},
)
elif self.grid_type == GridType.REGULAR:
da = da.rio.reproject("EPSG:3857")
else:
bbox_ll = [self.bbox[0], self.bbox[2], self.bbox[1], self.bbox[3]]
da = da.assign_coords({"x": da.cf["longitude"], "y": da.cf["latitude"]})

logger.debug(f"Projection time: {time.time() - projection_start}")

if minmax_only:
da = da.persist()
x = np.array(da.cf["longitude"].values)
y = np.array(da.cf["latitude"].values)
x = np.array(da.x.values)
y = np.array(da.y.values)
data = np.array(da.values)
inds = np.where(
(x >= (bbox_ll[0] - 0.18))
& (x <= (bbox_ll[1] + 0.18))
& (y >= (bbox_ll[2] - 0.18))
& (y <= (bbox_ll[3] + 0.18)),
(x >= (self.bbox[0] - 0.18))
& (x <= (self.bbox[2] + 0.18))
& (y >= (self.bbox[1] - 0.18))
& (y <= (self.bbox[3] + 0.18)),
)
# x_sel = x[inds].flatten()
# y_sel = y[inds].flatten()
Expand All @@ -295,23 +320,23 @@ def render_quad_grid(

start_dask = time.time()
da.persist()
da.cf["latitude"].persist()
da.cf["longitude"].persist()
da.y.persist()
da.x.persist()
logger.debug(f"dask compute: {time.time() - start_dask}")

start_shade = time.time()
cvs = dsh.Canvas(
plot_height=self.height,
plot_width=self.width,
x_range=(bbox_ll[0], bbox_ll[1]),
y_range=(bbox_ll[2], bbox_ll[3]),
x_range=(self.bbox[0], self.bbox[2]),
y_range=(self.bbox[1], self.bbox[3]),
)

shaded = tf.shade(
cvs.quadmesh(
da,
x=da.cf.coords["longitude"].name,
y=da.cf.coords["latitude"].name,
x="x",
y="y",
),
cmap=cm.get_cmap(self.palettename),
how="linear",
Expand Down

0 comments on commit 64c82da

Please sign in to comment.