Skip to content

Commit

Permalink
Merge pull request #22 from liamtoney/parallelize
Browse files Browse the repository at this point in the history
Implement parallelization
  • Loading branch information
liamtoney authored May 10, 2024
2 parents 3031d79 + 7d477af commit 74af93a
Show file tree
Hide file tree
Showing 7 changed files with 79 additions and 40 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ To cite a previous version of *infresnel*, go to the
in the creation of GeoTIFF path length difference grids)
* [ipympl](https://matplotlib.org/ipympl/)
* [ipywidgets](https://ipywidgets.readthedocs.io/en/stable/)
* [joblib](https://joblib.readthedocs.io/en/stable/) (for parallel processing)
* [JupyterLab](https://jupyterlab.readthedocs.io/en/latest/) (for running the
interactive `.ipynb` notebooks)
* [Matplotlib](https://matplotlib.org/) (for applying colormaps to GeoTIFFs and for
Expand Down
2 changes: 2 additions & 0 deletions doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

autodoc_mock_imports = [
'colorcet',
'joblib',
'matplotlib',
'numba',
'numpy',
Expand All @@ -33,6 +34,7 @@

# These only need to cover the packages we reference from the docstrings
intersphinx_mapping = {
'joblib': ('https://joblib.readthedocs.io/en/latest/', None),
'numpy': ('https://numpy.org/doc/stable/', None),
'python': ('https://docs.python.org/3/', None),
'xarray': ('https://docs.xarray.dev/en/stable/', None),
Expand Down
1 change: 1 addition & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ dependencies:
- colorcet
- ipympl
- ipywidgets
- joblib
- jupyterlab
- matplotlib
- numba
Expand Down
84 changes: 56 additions & 28 deletions infresnel/infresnel.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pandas as pd
import pygmt
import xarray as xr
from joblib import Parallel, delayed
from pyproj import Transformer
from rasterio.enums import Resampling
from scipy.interpolate import RectBivariateSpline
Expand Down Expand Up @@ -34,6 +35,7 @@ def calculate_paths(
dem_file=None,
full_output=False,
return_dem=False,
n_jobs=1,
):
"""Calculate elevation profiles, direct paths, and shortest diffracted paths.
Expand All @@ -59,6 +61,8 @@ def calculate_paths(
direct and shortest diffracted path lengths
return_dem (bool): Toggle additionally returning the UTM-projected DEM used to
compute the profiles
n_jobs (int): Number of parallel jobs to run (default is 1, which means no
parallelization) — this argument is passed on to :class:`joblib.Parallel`
Returns:
If `full_output` is `False` — an :class:`~numpy.ndarray` with shape ``(2,
Expand All @@ -82,6 +86,9 @@ def calculate_paths(
rec_lats = np.atleast_1d(rec_lat)
rec_lons = np.atleast_1d(rec_lon)

# Define number of paths
n_paths = rec_lats.size

print('Loading and projecting DEM...')
if dem_file is not None:
# Load user-provided DEM, first checking if it exists
Expand Down Expand Up @@ -159,23 +166,32 @@ def calculate_paths(
# also must be run if the PyGMT-supplied DEM is not fully within SRTM range (kind of
# unlikely edge case). For most PyGMT-supplied DEMs, this check will not end up
# being run — which is good, since it can be SLOW.
compute_paths = np.ones(rec_xs.size).astype(bool) # By default, compute all paths
if check_for_valid_elevations:
print('Checking that DEM contains source and receivers...')
if not _check_valid_elevation_for_coords(
dem_utm, mean_resolution, src_x, src_y
):
raise ValueError('Source is not in DEM! Exiting.')
for i, (x, y) in enumerate(zip(rec_xs, rec_ys)):
if not _check_valid_elevation_for_coords(dem_utm, mean_resolution, x, y):
compute_paths[i] = False # Don't compute this path

# KEY: Parallel check of which receiver locations have valid DEM elevations
compute_paths = np.array(
Parallel(n_jobs=n_jobs)(
delayed(_check_valid_elevation_for_coords)(
dem_utm, mean_resolution, rec_x, rec_y
)
for rec_x, rec_y in zip(rec_xs, rec_ys)
)
)

n_invalid_paths = (~compute_paths).sum()
if n_invalid_paths > 0:
print(
f'Done — will skip {n_invalid_paths} invalid path{"" if n_invalid_paths == 1 else "s"}\n'
f'Done — {n_invalid_paths} invalid path{"" if n_invalid_paths == 1 else "s"} will be set to NaN\n'
)
else:
print('Done\n')
else:
compute_paths = np.full(n_paths, True) # Compute all paths

# Fit bivariate spline to DEM (slow for very high resolution DEMs!)
print('Fitting spline to DEM...')
Expand All @@ -191,21 +207,8 @@ def calculate_paths(
spline = RectBivariateSpline(x=x, y=y, z=z.T) # x and y are monotonic increasing
print('Done\n')

# Iterate over all receivers (= source-receiver pairs), calculating paths
if full_output:
output_array = np.empty(rec_xs.size, dtype=object) # For storing Datasets
else:
output_array = np.empty((2, rec_xs.size)) # For storing path lengths
n_valid_paths = compute_paths.sum()
print(f'Computing {n_valid_paths} path{"" if n_valid_paths == 1 else "s"}...')
if n_valid_paths > 1: # Only creating the progress bar if we have more than 1 path
bar = tqdm(
total=n_valid_paths,
bar_format='{percentage:3.0f}% |{bar}| {n_fmt}/{total_fmt} paths ',
)
for i, (rec_x, rec_y, compute_path) in enumerate(
zip(rec_xs, rec_ys, compute_paths)
):
# Define function for calculating a single path
def _calculate_single_path(rec_x, rec_y, compute_path):
# If the DEM points were valid, compute the path
if compute_path:
# Determine # of points in profile
Expand Down Expand Up @@ -263,28 +266,50 @@ def calculate_paths(
),
)
ds.rio.write_crs(utm_crs, inplace=True)
output_array[i] = ds
output = ds # KEY: Return a Dataset
else:
# Just include the path lengths
output_array[:, i] = direct_path_len, diff_path_len
output = direct_path_len, diff_path_len # KEY: Return a tuple

if n_valid_paths > 1 and compute_path:
bar.update()
return output

print(f'Computing {n_paths} path{"" if n_paths == 1 else "s"}...')

# Only create the progress bar if we have more than 1 path
iterable = zip(rec_xs, rec_ys, compute_paths)
if n_paths > 1:
iterable = tqdm(
iterable,
total=n_paths,
bar_format='{percentage:3.0f}% |{bar}| {n_fmt}/{total_fmt} paths ',
)

# KEY: Parallel path calculations over all receivers (= source-receiver pairs)
output_array = Parallel(n_jobs=n_jobs)(
delayed(_calculate_single_path)(rec_x, rec_y, compute_path)
for rec_x, rec_y, compute_path in iterable
)

bar.close()
print('Done')

# Determine what to output
if full_output:
output_array = output_array.tolist() # Convert to list of Datasets
if not full_output:
output_array = np.array(output_array).T # Convert list of tuples to array
if return_dem:
return output_array, dem_utm
else:
return output_array


def calculate_paths_grid(
src_lat, src_lon, x_radius, y_radius, spacing, dem_file=None, output_file=None
src_lat,
src_lon,
x_radius,
y_radius,
spacing,
dem_file=None,
output_file=None,
n_jobs=1,
):
"""Calculate paths for a UTM-projected grid surrounding a source location.
Expand All @@ -309,6 +334,8 @@ def calculate_paths_grid(
output_file (str or None): If a string filepath is provided, then an RGB GeoTIFF
file containing the colormapped grid of path length difference values is
exported to this filepath (no export if `None`)
n_jobs (int): Number of parallel jobs to run (default is 1, which means no
parallelization) — this argument is passed on to :class:`joblib.Parallel`
Returns:
tuple: Tuple of the form ``(path_length_differences, dem)`` where
Expand Down Expand Up @@ -358,6 +385,7 @@ def _process_radius(radius):
rec_lon=rec_lon.flatten(),
dem_file=dem_file,
return_dem=True,
n_jobs=n_jobs,
)
toc = time.time()
print(f'\nElapsed time = {toc - tic:.0f} s')
Expand Down
29 changes: 17 additions & 12 deletions notebooks/example_paths_grid.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
'colorcet',
'ipympl',
'ipywidgets',
'joblib',
'jupyterlab',
'matplotlib',
'numba',
Expand Down
1 change: 1 addition & 0 deletions tests/test_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def test_calculate_paths_grid(output=False):
y_radius=Y_RADIUS,
spacing=SPACING,
dem_file=DEM_FILE,
n_jobs=-2, # "...using n_jobs=-2 will result in all CPUs but one being used."
)
if output:
return path_length_differences, dem
Expand Down

0 comments on commit 74af93a

Please sign in to comment.