-
Notifications
You must be signed in to change notification settings - Fork 39
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Move detection (2d/3d filtering, structure splitting) to PyTorch (#440)
* Move detection to PyTorch. Add initial cuda detection. Convert plane filters to torch. Add batch sizes. Add batching. Fix threshold values for float vs uint32. Save alternate way of tiling. Move all filters to pytorch. Don't raise exception, return instead. Conv filter on the fastest dim. Turn ON inference mode. Refactor out the detection settings. Handle input data conversions. Add Wrappers for threads. Add support for single plane batches. Add dtype for detection. To pass soma value without knowing dtype, set largest dtype. Switch splitting to torch. Return the filtered planes... Remove multiprocessing. Use correct z axis. Use as large a batch as possible for splitting. Add back multiprocessing for splitting. Ensure volume is valid. Make tiling optional. Cleanup and docs. Limit cores to mostly prevent contention. Fix division by zero. We only need one version of ball filter now. Parallelize again 2d filtering for CPU. Add kornia as dependency. Use better multiprocessing for cpu plane filters. Allow using scipy for plane filter. Pass buffers ahead to process to prevent torch queue buffer issues. Queue must come from same ctx as process. Reduce max cores. Don't pin memory on cpu. Fix tests and add more. Add more tests and fixes. More tests and include int inputs. Fix coverage multiprocessing. More tests. Add more tests. More docs/tests. Use modules for 2d filters so we can use reflect padding and add 2d filter tests. With correct thresholds, detection dtype can be input dtype size. Add test for tiles generated during 2d filtering. Add testing for 3d filtering. Clean up filtering to detection conversion. Brainmapper passes in str not float. Fix numba cast warning. Pad 2d filter data enough to not require padding for each filter. Add threading test/docs. We must process a plane at a time for parity with previous algo. Add test comparing generated cells. Ensure full parity with scipy for 2d filtering. Fix numba warning.
Don't count values at threshold in planes - brings 2d filtering to scipy parity. Include 3d filter top/bottom padded planes in progress bar. Move more into jit script. Get test data from pooch and use it in benchmarks. Add test for splitting underflow. * Don't let number of threads go below 4. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Use smaller bright brain. * Fix pooch hash for xml. * Make test order insensitive. * Increase timeout to 2 hours. * Apply suggestions from code review Co-authored-by: Alessandro Felder <[email protected]> * Apply suggestions from code review Co-authored-by: Alessandro Felder <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Implement PR code feedback. * Not all queues have a qsize func. * Exclude torch v2.4 due to its Windows issue. * Todo already addressed and checked in tests. * reproduce windows error * preserve excluding torch 2.4.0 --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Alessandro Felder <[email protected]>
- Loading branch information
1 parent
b32cc2a
commit 4691987
Showing
33 changed files
with
4,755 additions
and
1,117 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
from pathlib import Path | ||
|
||
import pooch | ||
import torch | ||
from torch.profiler import ProfilerActivity, profile | ||
from torch.utils.benchmark import Compare, Timer | ||
|
||
from cellfinder.core.tools.IO import fetch_pooch_directory | ||
|
||
|
||
def get_test_data_path(path):
    """
    Fetch a directory of BrainGlobe cellfinder test data through pooch.

    A pooch registry is created for the BrainGlobe test-data repository
    (cached in the OS cache folder "brainglobe_test_data", which can be
    overridden with the BRAINGLOBE_TEST_DATA_DIR environment variable).
    The list of known files is loaded from ``tests/data/pooch_registry.txt``,
    located relative to the parent of this file's directory, and the
    registry is then handed to ``fetch_pooch_directory`` with ``path``.

    Args:
        path: Directory to fetch from the registry, e.g.
            "bright_brain/signal".

    Returns:
        Whatever ``fetch_pooch_directory(registry, path)`` returns -
        presumably the local path of the fetched directory (TODO confirm
        against ``cellfinder.core.tools.IO``). Note that it is NOT the
        registry object itself.
    """
    registry = pooch.create(
        path=pooch.os_cache("brainglobe_test_data"),
        base_url="https://gin.g-node.org/BrainGlobe/test-data/raw/master/cellfinder/",
        env="BRAINGLOBE_TEST_DATA_DIR",
    )

    # the registry file lives in the repository's tests/data folder
    registry_file = (
        Path(__file__).parent.parent / "tests" / "data" / "pooch_registry.txt"
    )
    registry.load_registry(registry_file)

    return fetch_pooch_directory(registry, path)
|
||
|
||
def time_filters(repeat, run, run_args, label, num_threads=4):
    """
    Time ``run(*run_args)`` using the PyTorch benchmark ``Timer``.

    Args:
        repeat: Number of executions used for the measurement (passed as
            ``number`` to ``Timer.timeit``).
        run: The callable to benchmark.
        run_args: Sequence of positional arguments unpacked into ``run``.
        label: Label attached to the measurement.
        num_threads: Thread count handed to ``Timer``. Defaults to 4,
            the value that was previously hard-coded.

    Returns:
        The measurement object produced by ``Timer.timeit``.
    """
    timer = Timer(
        stmt="run(*args)",
        globals={"run": run, "args": run_args},
        label=label,
        num_threads=num_threads,
        description="",  # must be not None due to pytorch bug
    )
    return timer.timeit(number=repeat)
|
||
|
||
def compare_results(*results):
    """
    Print a comparison table of all the timed tests.

    Args:
        *results: The measurements to compare, e.g. the values returned
            by ``Timer.timeit``. They are passed together to the PyTorch
            benchmark ``Compare`` helper.

    Returns:
        None. The table is printed with significant figures trimmed and
        colorized output enabled.
    """
    comparison = Compare(results)
    comparison.trim_significant_figures()
    comparison.colorize()
    comparison.print()
|
||
|
||
def profile_cpu(repeat, run, run_args):
    """
    Profile ``run(*run_args)`` on the CPU and print a summary table.

    The callable is executed ``repeat`` times inside a ``torch.profiler``
    session that records CPU activity together with input shapes, memory
    usage, stacks and modules.

    Args:
        repeat: Number of times ``run`` is called while profiling.
        run: The callable to profile.
        run_args: Sequence of positional arguments unpacked into ``run``.

    Returns:
        None. The 20 top rows of the averaged events (using
        ``group_by_stack_n=1``), sorted by "self_cpu_time_total", are
        printed.
    """
    with profile(
        activities=[ProfilerActivity.CPU],
        record_shapes=True,
        profile_memory=True,
        with_stack=True,
        with_modules=True,
    ) as prof:
        for _ in range(repeat):
            run(*run_args)

    averages = prof.key_averages(group_by_stack_n=1)
    print(averages.table(sort_by="self_cpu_time_total", row_limit=20))
|
||
|
||
def profile_cuda(repeat, run, run_args):
    """
    Profile ``run(*run_args)`` on CPU and CUDA and print a summary table.

    The callable is executed ``repeat`` times inside a ``torch.profiler``
    session that records both CPU and CUDA activity together with input
    shapes, memory usage, stacks and modules. Before the session is
    closed, the "cuda" device is synchronized.

    Args:
        repeat: Number of times ``run`` is called while profiling.
        run: The callable to profile.
        run_args: Sequence of positional arguments unpacked into ``run``.

    Returns:
        None. The 20 top rows of the averaged events (using
        ``group_by_stack_n=1``), sorted by "self_cuda_time_total", are
        printed.
    """
    with profile(
        activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
        record_shapes=True,
        profile_memory=True,
        with_stack=True,
        with_modules=True,
    ) as prof:
        for _ in range(repeat):
            run(*run_args)
        # make sure it's fully done filtering
        torch.cuda.synchronize("cuda")

    averages = prof.key_averages(group_by_stack_n=1)
    print(averages.table(sort_by="self_cuda_time_total", row_limit=20))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,27 +1,131 @@ | ||
import os | ||
import sys | ||
|
||
sys.path.append(os.path.dirname(__file__)) | ||
|
||
import numpy as np | ||
from pyinstrument import Profiler | ||
import torch | ||
from benchmark_tools import ( | ||
compare_results, | ||
get_test_data_path, | ||
profile_cpu, | ||
profile_cuda, | ||
time_filters, | ||
) | ||
from brainglobe_utils.IO.image.load import read_with_dask | ||
|
||
from cellfinder.core.detect.filters.plane import TileProcessor | ||
from cellfinder.core.detect.filters.setup_filters import setup_tile_filtering | ||
from cellfinder.core.detect.filters.setup_filters import DetectionSettings | ||
|
||
# Use random 16-bit integer data for signal plane | ||
shape = (10000, 10000) | ||
|
||
signal_array_plane = np.random.randint( | ||
low=0, high=65536, size=shape, dtype=np.uint16 | ||
) | ||
def setup_filter(
    signal_path,
    batch_size=1,
    num_z=None,
    torch_device="cpu",
    dtype=np.uint16,
    use_scipy=False,
):
    """
    Load signal data and build the tile processor used for benchmarking.

    Args:
        signal_path: Location of the signal data; read with
            ``read_with_dask``.
        batch_size: Batch size, returned unchanged as the last item so
            the result can be unpacked straight into ``run_filter``.
        num_z: Number of leading planes to use. When falsy (``None`` or
            0), all the planes are used.
        torch_device: Device the data is moved to and that the
            ``TileProcessor`` is configured with.
        dtype: Numpy dtype the loaded planes are cast to; it is also
            passed as ``plane_original_np_dtype`` to the settings.
        use_scipy: Forwarded to ``TileProcessor`` as ``use_scipy``.

    Returns:
        Tuple of ``(tile_processor, signal_array, batch_size)`` where
        ``signal_array`` is a torch tensor on ``torch_device``.
    """
    signal_array = read_with_dask(signal_path)
    num_z = num_z or len(signal_array)
    # load the selected planes into memory in the requested dtype
    signal_array = np.asarray(signal_array[:num_z]).astype(dtype)
    plane_shape = signal_array.shape[1:]

    # voxel and soma/ball sizes are fixed values for this benchmark
    settings = DetectionSettings(
        plane_original_np_dtype=dtype,
        plane_shape=plane_shape,
        torch_device=torch_device,
        voxel_sizes=(5.06, 4.5, 4.5),
        soma_diameter_um=30,
        ball_xy_size_um=6,
        ball_z_size_um=15,
    )
    # convert the data with the settings' converter before handing it
    # to torch
    signal_array = settings.filter_data_converter_func(signal_array)
    signal_array = torch.from_numpy(signal_array).to(torch_device)

    tile_processor = TileProcessor(
        plane_shape=plane_shape,
        clipping_value=settings.clipping_value,
        threshold_value=settings.threshold_value,
        soma_diameter=settings.soma_diameter,
        log_sigma_size=settings.log_sigma_size,
        n_sds_above_mean_thresh=settings.n_sds_above_mean_thresh,
        torch_device=torch_device,
        dtype=settings.filtering_dtype.__name__,
        use_scipy=use_scipy,
    )

    return tile_processor, signal_array, batch_size
|
||
|
||
def run_filter(tile_processor, signal_array, batch_size):
    """
    Run the tile processor over the signal data, one batch at a time.

    Args:
        tile_processor: Object whose ``get_tile_mask`` is called with
            each batch.
        signal_array: Sliceable sequence of planes; split along its
            first axis.
        batch_size: Number of planes per batch. The last batch may be
            smaller.

    Returns:
        None. The results of ``get_tile_mask`` are discarded.
    """
    for start in range(0, len(signal_array), batch_size):
        batch = signal_array[start : start + batch_size]
        tile_processor.get_tile_mask(batch)
|
||
|
||
clipping_value, threshold_value = setup_tile_filtering(signal_array_plane) | ||
tile_processor = TileProcessor( | ||
clipping_value=clipping_value, | ||
threshold_value=threshold_value, | ||
soma_diameter=16, | ||
log_sigma_size=0.2, | ||
n_sds_above_mean_thresh=10, | ||
) | ||
if __name__ == "__main__":
    # run the whole benchmark under torch inference mode
    with torch.inference_mode(True):
        n = 5
        batch_size = 2
        signal_path = get_test_data_path("bright_brain/signal")

        # (label, setup_filter keyword arguments) of each benchmarked
        # configuration, in the order they are run
        configs = [
            ("cpu-no_scipy", {"torch_device": "cpu", "use_scipy": False}),
            ("cpu-scipy", {"torch_device": "cpu", "use_scipy": True}),
        ]
        # only benchmark cuda when a device exists, so that cpu-only
        # machines still get the cpu results instead of an error
        if torch.cuda.is_available():
            configs.append(("cuda", {"torch_device": "cuda"}))

        # time every configuration, each with freshly set up data
        compare_results(
            *[
                time_filters(
                    n,
                    run_filter,
                    setup_filter(
                        signal_path, batch_size=batch_size, **kwargs
                    ),
                    label,
                )
                for label, kwargs in configs
            ]
        )

        # then profile every configuration with the matching profiler
        for _, kwargs in configs:
            if kwargs["torch_device"] == "cuda":
                profile_func = profile_cuda
            else:
                profile_func = profile_cpu
            profile_func(
                n,
                run_filter,
                setup_filter(signal_path, batch_size=batch_size, **kwargs),
            )
Oops, something went wrong.