Skip to content

Commit

Permalink
Skip receivers w/ invalid DEM data (NaN)
Browse files Browse the repository at this point in the history
  • Loading branch information
liamtoney committed Jul 25, 2023
1 parent bdeca17 commit 5ca2e7f
Showing 1 changed file with 49 additions and 34 deletions.
83 changes: 49 additions & 34 deletions infresnel/infresnel.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from pyproj import Transformer
from rasterio.enums import Resampling
from scipy.interpolate import RectBivariateSpline
from tqdm.contrib import tzip
from tqdm.auto import tqdm

from ._georeference import (
_check_valid_elevation_for_coords,
Expand Down Expand Up @@ -122,10 +122,14 @@ def calculate_paths(
print('Checking that DEM contains source and receivers...')
if not _check_valid_elevation_for_coords(dem_utm, mean_resolution, src_x, src_y):
raise ValueError('Source is not in DEM! Exiting.')
compute_paths = np.ones(rec_xs.size).astype(bool) # By default, compute all paths
for i, (x, y) in enumerate(zip(rec_xs, rec_ys)):
if not _check_valid_elevation_for_coords(dem_utm, mean_resolution, x, y):
raise ValueError(f'Receiver (index {i}) is not in DEM! Exiting.')
print('Done\n')
compute_paths[i] = False # Don't compute this path
if (~compute_paths).any():
print(f'Done — will skip {(~compute_paths).sum()} invalid path(s)\n')
else:
print('Done\n')

# Fit bivariate spline to DEM (slow for very high resolution DEMs!)
print('Fitting spline to DEM...')
Expand All @@ -143,41 +147,48 @@ def calculate_paths(

# Iterate over all receivers (= source-receiver pairs), calculating paths
ds_list = []
print(f'Computing {rec_lats.size} path{"" if rec_lats.size == 1 else "s"}...')
if rec_lats.size == 1:
rec_zip = zip(rec_xs, rec_ys) # Don't create progress bar if only 1 path
else:
rec_zip = tzip(
rec_xs,
rec_ys,
n_valid_paths = rec_xs[compute_paths].size
print(f'Computing {n_valid_paths} path{"" if n_valid_paths == 1 else "s"}...')
if n_valid_paths > 1: # Only creating the progress bar if we have more than 1 path
bar = tqdm(
total=n_valid_paths,
bar_format='{percentage:3.0f}% |{bar}| {n_fmt}/{total_fmt} paths ',
)
for rec_x, rec_y in rec_zip:
# Determine # of points in profile
dist = _norm(src_x - rec_x, src_y - rec_y)
n = max(int(np.ceil(dist / target_spacing)), 2) # Ensure at least 2 points!

# Make profile by evaluating spline
xvec = np.linspace(src_x, rec_x, n)
yvec = np.linspace(src_y, rec_y, n)
profile = xr.DataArray(
spline.ev(xvec, yvec),
dims='distance',
coords=dict(x=('distance', xvec, units), y=('distance', yvec, units)),
attrs=units,
)
profile = profile.assign_coords(
distance=_horizontal_distance(profile.x.values, profile.y.values)
)
profile.distance.attrs = units
for rec_x, rec_y, compute_path in zip(rec_xs, rec_ys, compute_paths):
# If the DEM points were valid, compute the path
if compute_path:
# Determine # of points in profile
dist = _norm(src_x - rec_x, src_y - rec_y)
n = max(int(np.ceil(dist / target_spacing)), 2) # Ensure at least 2 points!

# Make profile by evaluating spline
xvec = np.linspace(src_x, rec_x, n)
yvec = np.linspace(src_y, rec_y, n)
profile = xr.DataArray(
spline.ev(xvec, yvec),
dims='distance',
coords=dict(x=('distance', xvec, units), y=('distance', yvec, units)),
attrs=units,
)
profile = profile.assign_coords(
distance=_horizontal_distance(profile.x.values, profile.y.values)
)
profile.distance.attrs = units

# Compute DIRECT path
direct_path = _direct_path(profile.distance.values, profile.values)
direct_path_len = _path_length(profile.distance.values, direct_path)
# Compute DIRECT path
direct_path = _direct_path(profile.distance.values, profile.values)
direct_path_len = _path_length(profile.distance.values, direct_path)

# Compute SHORTEST DIFFRACTED path
diff_path = _shortest_diffracted_path(profile.distance.values, profile.values)
diff_path_len = _path_length(profile.distance.values, diff_path)
# Compute SHORTEST DIFFRACTED path
diff_path = _shortest_diffracted_path(
profile.distance.values, profile.values
)
diff_path_len = _path_length(profile.distance.values, diff_path)

# Just populate everything with NaNs
else:
profile = direct_path = diff_path = [np.nan]
direct_path_len = diff_path_len = np.nan

# Make nice Dataset of all info
ds = xr.Dataset(
Expand All @@ -199,6 +210,10 @@ def calculate_paths(
ds.rio.write_crs(utm_crs, inplace=True)
ds_list.append(ds)

if n_valid_paths > 1 and compute_path:
bar.update()

bar.close()
print('Done')

# Determine what to output
Expand Down

0 comments on commit 5ca2e7f

Please sign in to comment.