Skip to content

Commit

Permalink
Perf improvements (#25)
Browse files Browse the repository at this point in the history
* perf impvmts

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Cleanup

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
mpiannucci and pre-commit-ci[bot] authored Jul 26, 2023
1 parent 3d207ea commit bf8a8fc
Show file tree
Hide file tree
Showing 6 changed files with 40 additions and 19 deletions.
2 changes: 1 addition & 1 deletion xpublish_wms/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class GridType(Enum):

@classmethod
def from_ds(cls, ds: xr.Dataset):
if f"{ds.cf}".startswith("SGRID"):
if "grid" in ds.cf.cf_roles["grid_topology"]:
return cls.SGRID

try:
Expand Down
4 changes: 2 additions & 2 deletions xpublish_wms/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,11 @@ def dataset_router(self, deps: Dependencies) -> APIRouter:
)

@router.get("/")
async def wms_root(
def wms_root(
request: Request,
dataset: xr.Dataset = Depends(deps.dataset),
cache: cachey.Cache = Depends(deps.cache),
):
return await wms_handler(request, dataset, cache)
return wms_handler(request, dataset, cache)

return router
10 changes: 5 additions & 5 deletions xpublish_wms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,17 +78,17 @@ def lnglat_to_cartesian(longitude, latitude):
to_mercator = Transformer.from_crs(4326, 3857, always_xy=True)


def da_bbox(da: xr.DataArray) -> Tuple[float, float, float, float]:
def da_bbox(lat: xr.DataArray, lon: xr.DataArray) -> Tuple[float, float, float, float]:
"""
Return the bounding box of the dataarray
:param ds:
:return:
"""
bbox = [
da.cf.coords["longitude"].min().values.item(),
da.cf.coords["latitude"].min().values.item(),
da.cf.coords["longitude"].max().values.item(),
da.cf.coords["latitude"].max().values.item(),
lon.min().values.item(),
lat.min().values.item(),
lon.max().values.item(),
lat.max().values.item(),
]

return bbox
2 changes: 1 addition & 1 deletion xpublish_wms/wms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
logger = logging.getLogger("uvicorn")


async def wms_handler(
def wms_handler(
request: Request,
dataset: xr.Dataset = Depends(get_dataset),
cache: cachey.Cache = Depends(get_cache),
Expand Down
12 changes: 10 additions & 2 deletions xpublish_wms/wms/get_capabilities.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import datetime
import xml.etree.ElementTree as ET
from typing import List

Expand Down Expand Up @@ -139,6 +140,8 @@ def get_capabilities(ds: xr.Dataset, request: Request, query_params: dict) -> Re
create_text_element(layer_tag, crs_tag, "EPSG:3857")
create_text_element(layer_tag, crs_tag, "CRS:84")

current_date = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)

for var in ds.data_vars:
da = ds[var]

Expand All @@ -148,7 +151,9 @@ def get_capabilities(ds: xr.Dataset, request: Request, query_params: dict) -> Re

# TODO: Cache this based on variable names fetched. for now we assume every dataarray
# can have a different bbox
bbox = da_bbox(da)
lat = da.cf["latitude"].persist()
lon = da.cf["longitude"].persist()
bbox = da_bbox(lat, lon)
bounds = {
crs_tag: "EPSG:4326",
"minx": f"{bbox[0]}",
Expand Down Expand Up @@ -199,14 +204,17 @@ def get_capabilities(ds: xr.Dataset, request: Request, query_params: dict) -> Re

if "time" in da.cf.coords:
times = format_timestamp(da.cf["time"])
default_time = format_timestamp(
da.cf["time"].cf.sel(time=current_date, method="nearest"),
).item()

time_dimension_element = ET.SubElement(
layer,
"Dimension",
attrib={
"name": "time",
"units": "ISO8601",
"default": times[-1],
"default": default_time,
},
)
# TODO: Add ISO duration specifier
Expand Down
29 changes: 21 additions & 8 deletions xpublish_wms/wms/get_map.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import io
import logging
import time
from datetime import datetime
from typing import List, Union

Expand Down Expand Up @@ -118,6 +119,7 @@ def ensure_query_types(self, ds: xr.Dataset, query: dict):
# Data selection
self.parameter = query["layers"]
self.time_str = query.get("time", None)

if self.time_str:
self.time = pd.to_datetime(self.time_str).tz_localize(None)
else:
Expand Down Expand Up @@ -257,6 +259,7 @@ def render_quad_grid(
:param da:
:return:
"""

if self.crs == "EPSG:3857":
bbox_lng, bbox_lat = to_lnglat.transform(
[self.bbox[0], self.bbox[2]],
Expand All @@ -267,6 +270,7 @@ def render_quad_grid(
bbox_ll = [self.bbox[0], self.bbox[2], self.bbox[1], self.bbox[3]]

if minmax_only:
da = da.persist()
x = np.array(da.cf["longitude"].values)
y = np.array(da.cf["latitude"].values)
data = np.array(da.values)
Expand All @@ -284,19 +288,26 @@ def render_quad_grid(
"max": float(np.nanmax(data_sel)),
}

if not self.autoscale:
vmin, vmax = self.colorscalerange
else:
vmin, vmax = [None, None]

start_dask = time.time()
da.persist()
da.cf["latitude"].persist()
da.cf["longitude"].persist()
logger.debug(f"dask compute: {time.time() - start_dask}")

start_shade = time.time()
cvs = dsh.Canvas(
plot_height=self.height,
plot_width=self.width,
x_range=(bbox_ll[0], bbox_ll[1]),
y_range=(bbox_ll[2], bbox_ll[3]),
)

if not self.autoscale:
vmin, vmax = self.colorscalerange
else:
vmin, vmax = [None, None]

im = tf.shade(
shaded = tf.shade(
cvs.quadmesh(
da,
x=da.cf.coords["longitude"].name,
Expand All @@ -305,7 +316,9 @@ def render_quad_grid(
cmap=cm.get_cmap(self.palettename),
how="linear",
span=(vmin, vmax),
).to_pil()
im.save(buffer, format="PNG")
)
logger.debug(f"Shade time: {time.time() - start_shade}")

im = shaded.to_pil()
im.save(buffer, format="PNG")
return True

0 comments on commit bf8a8fc

Please sign in to comment.