Move detection (2d/3d filtering, structure splitting) to PyTorch (#440)
* Move detection to PyTorch.

Add initial cuda detection.

Convert plane filters to torch.

Add batch sizes.

Add batching.

Fix threshold values for float vs uint32.

Save alternate way of tiling.

Move all filters to pytorch.

Don't raise exception, return instead.

Conv filter on the fastest dim.
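
For context, a minimal sketch of this idea (not the project's actual kernels): run a separable 1D convolution along the last, contiguous ("fastest") dimension so memory access stays sequential, and handle the other axis via a transpose. The 3-tap kernel and function name here are illustrative assumptions.

import torch
import torch.nn.functional as F


def conv_last_dim(x: torch.Tensor, kernel: torch.Tensor) -> torch.Tensor:
    # Flatten everything except the last (contiguous) axis, run a 1D
    # convolution along it, then restore the original shape.
    b, c, h, w = x.shape
    out = F.conv1d(
        x.reshape(b * c * h, 1, w), kernel, padding=kernel.shape[-1] // 2
    )
    return out.reshape(b, c, h, w)


planes = torch.rand(4, 1, 512, 512)           # (batch, channel, y, x)
kernel = torch.tensor([[[0.25, 0.5, 0.25]]])  # simple 3-tap smoothing kernel

smoothed_x = conv_last_dim(planes, kernel)    # along x, the fastest dim
smoothed_xy = conv_last_dim(
    smoothed_x.transpose(-1, -2), kernel
).transpose(-1, -2)                           # along y via transpose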

Turn ON inference mode.
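
The benchmark script added in this commit wraps its runs in torch.inference_mode; a minimal example of what that buys for forward-only filtering (toy tensor and conv, not the real pipeline):

import torch

x = torch.rand(2, 1, 64, 64)
conv = torch.nn.Conv2d(1, 1, kernel_size=3, padding=1)

# Inference mode disables autograd tracking entirely, removing per-op
# bookkeeping for code that never needs gradients.
with torch.inference_mode():
    y = conv(x)

assert not y.requires_grad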

Refactor out the detection settings.

Handle input data conversions.

Add wrappers for threads.

Add support for single plane batches.

Add dtype for detection.

To pass the soma value without knowing the dtype, use the largest dtype.
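
One way to illustrate the idea (a hypothetical sketch, not the committed code; SOMA_MARKER and marker_for are made-up names): express the marker in the widest unsigned integer type and narrow it once the working dtype is known.

import numpy as np

# Hypothetical: the caller does not yet know the working dtype, so the
# marker is defined in the widest unsigned integer type...
SOMA_MARKER = np.iinfo(np.uint64).max


def marker_for(dtype) -> int:
    # ...and clamped to the largest value the actual dtype can hold.
    return int(min(SOMA_MARKER, np.iinfo(dtype).max))


print(marker_for(np.uint16), marker_for(np.uint32))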

Switch splitting to torch.

Return the filtered planes.

Remove multiprocessing.

Use correct z axis.

Use as large a batch as possible for splitting.

Add back multiprocessing for splitting.

Ensure volume is valid.

Make tiling optional.

Cleanup and docs.

Limit cores to mostly prevent contention.

Fix division by zero.

We only need one version of ball filter now.

Parallelize 2d filtering again for CPU.

Add kornia as dependency.

Use better multiprocessing for cpu plane filters.

Allow using scipy for plane filter.

Pass buffers ahead to the process to prevent torch queue buffer issues.

Queue must come from same ctx as process.
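
A generic sketch of this pattern (not the cellfinder worker; worker and big_buffer are made-up names): build the queue from the same multiprocessing context that spawns the process, and hand large buffers over at process creation instead of sending them through the queue.

import multiprocessing as mp

import torch


def worker(buffer, queue):
    # The large buffer was handed over at Process creation time, so only
    # small results travel through the queue afterwards.
    queue.put(float(buffer.sum()))


if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    big_buffer = torch.zeros(1024, 1024)

    # Create the queue from the same context that creates the process;
    # mixing contexts is a common source of hangs and errors.
    queue = ctx.Queue()
    proc = ctx.Process(target=worker, args=(big_buffer, queue))
    proc.start()
    print(queue.get())
    proc.join()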

Reduce max cores.

Don't pin memory on cpu.

Fix tests and add more.

Add more tests and fixes.

More tests and include int inputs.

Fix coverage multiprocessing.

More tests.

Add more tests.

More docs/tests.

Use modules for 2d filters so we can use reflect padding and add 2d filter tests.
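
A minimal sketch of the module pattern referred to here, assuming a stand-in box kernel rather than the real Gaussian/LoG kernels (class and variable names are illustrative):

import torch
from torch import nn


class ReflectConv2d(nn.Module):
    """Minimal sketch: a fixed-kernel 2D filter module with reflect padding."""

    def __init__(self, kernel: torch.Tensor):
        super().__init__()
        k = kernel.shape[-1]
        self.conv = nn.Conv2d(
            1, 1, kernel_size=k, padding=k // 2,
            padding_mode="reflect", bias=False,
        )
        with torch.no_grad():
            self.conv.weight.copy_(kernel.reshape(1, 1, k, k))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv(x)


box_kernel = torch.ones(3, 3) / 9.0  # stand-in for the real filter kernel
filt = ReflectConv2d(box_kernel)
out = filt(torch.rand(2, 1, 128, 128))  # (batch, channel, y, x)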

With correct thresholds, the detection dtype can be the same size as the input dtype.

Add test for tiles generated during 2d filtering.

Add testing for 3d filtering.

Clean up filtering to detection conversion.

Brainmapper passes in str not float.

Fix numba cast warning.

Pad 2d filter data enough to not require padding for each filter.
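
Illustratively (the kernels below are assumptions, not the committed ones), padding once for the largest kernel means every filter can then run with padding=0:

import torch
import torch.nn.functional as F

plane = torch.rand(1, 1, 256, 256)
kernels = [torch.ones(1, 1, 3, 3) / 9, torch.ones(1, 1, 5, 5) / 25]

# Reflect-pad once, using the radius of the largest kernel, so each filter
# can run without re-padding per filter.
pad = max(k.shape[-1] for k in kernels) // 2
padded = F.pad(plane, (pad, pad, pad, pad), mode="reflect")

outputs = []
for k in kernels:
    y = F.conv2d(padded, k, padding=0)
    extra = pad - k.shape[-1] // 2
    if extra:
        # Smaller kernels leave an unused border; trim it so every output
        # matches the original plane size.
        y = y[..., extra:-extra, extra:-extra]
    outputs.append(y)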

Add threading test/docs.

We must process one plane at a time for parity with the previous algo.

Add test comparing generated cells.

Ensure full parity with scipy for 2d filtering.
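
A parity check in the spirit of this change could look like the following sketch; a 3x3 uniform filter stands in for the real 2d filters, and scipy's "mirror" mode is the boundary handling that matches torch's reflect padding.

import numpy as np
import torch
import torch.nn.functional as F
from scipy.ndimage import uniform_filter

plane = np.random.rand(128, 128).astype(np.float32)

# scipy reference: 3x3 mean filter; "mirror" matches torch reflect padding.
expected = uniform_filter(plane, size=3, mode="mirror")

# torch equivalent: reflect-pad, then convolve with a 3x3 box kernel.
x = F.pad(torch.from_numpy(plane)[None, None], (1, 1, 1, 1), mode="reflect")
result = F.conv2d(x, torch.full((1, 1, 3, 3), 1.0 / 9.0))[0, 0].numpy()

np.testing.assert_allclose(result, expected, atol=1e-5)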

Fix numba warning.

Don't count values at threshold in planes - brings 2d filtering to scipy parity.
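
The behavioural difference is just strict versus non-strict comparison, e.g.:

import torch

plane = torch.tensor([[1.0, 2.0, 3.0], [2.0, 2.0, 4.0]])
threshold = 2.0

# Counting only values strictly above the threshold, not equal to it.
above = (plane > threshold).sum()          # 2 values: 3.0 and 4.0
at_or_above = (plane >= threshold).sum()   # 5 values

print(int(above), int(at_or_above))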

Include 3d filter top/bottom padded planes in progress bar.

Move more into jit script.
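
For reference, a toy example of TorchScript scripting (not one of the committed functions); the CI workflow in this commit also sets PYTORCH_JIT=0 so scripted code can be exercised uncompiled.

import torch


@torch.jit.script
def count_bright(plane: torch.Tensor, threshold: float) -> torch.Tensor:
    # Compiled by TorchScript unless PYTORCH_JIT=0 is set in the environment.
    return (plane > threshold).sum()


print(count_bright(torch.rand(64, 64), 0.5))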

Get test data from pooch and use it in benchmarks.

Add test for splitting underflow.

* Don't let the number of threads go below 4.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Use smaller bright brain.

* Fix pooch hash for xml.

* Make test order insensitive.

* Increase timeout to 2 hours.

* Apply suggestions from code review

Co-authored-by: Alessandro Felder <[email protected]>

* Apply suggestions from code review

Co-authored-by: Alessandro Felder <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Implement PR code feedback.

* Not all queues have a qsize func.

* Exclude torch v2.4 due to its Windows issue.

* Todo already addressed and checked in tests.

* reproduce windows error

* preserve excluding torch 2.4.0

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Alessandro Felder <[email protected]>
3 people authored Oct 31, 2024
1 parent b32cc2a commit 4691987
Showing 33 changed files with 4,755 additions and 1,117 deletions.
2 changes: 2 additions & 0 deletions .codecov.yml
@@ -10,6 +10,8 @@ flags:
   numba:
     paths:
       - cellfinder/core/detect/filters/plane/tile_walker.py
+      - cellfinder/core/detect/filters/plane/classical_filter.py
+      - cellfinder/core/detect/filters/plane/plane_filter.py
       - cellfinder/core/detect/filters/volume/ball_filter.py
       - cellfinder/core/detect/filters/volume/structure_detection.py
     carryforward: true
30 changes: 26 additions & 4 deletions .github/workflows/test_and_deploy.yml
@@ -38,11 +38,14 @@ jobs:
   test:
     needs: [linting, manifest]
     name: Run package tests
-    timeout-minutes: 60
+    timeout-minutes: 120
     runs-on: ${{ matrix.os }}
     env:
       KERAS_BACKEND: torch
       CELLFINDER_TEST_DEVICE: cpu
+      # pooch cache dir
+      BRAINGLOBE_TEST_DATA_DIR: "~/.pooch_cache"

     strategy:
       matrix:
         # Run all supported Python versions on linux
@@ -56,6 +59,14 @@
           python-version: "3.12"

     steps:
+      - uses: actions/checkout@v4
+      - name: Cache pooch data
+        uses: actions/cache@v4
+        with:
+          path: "~/.pooch_cache"
+          # hash on conftest in case url changes
+          key: ${{ runner.os }}-${{ matrix.python-version }}-${{ hashFiles('**/pooch_registry.txt') }}
       # Cache the tensorflow model so we don't have to remake it every time
       - name: Cache brainglobe directory
         uses: actions/cache@v3
         with:
@@ -78,19 +89,30 @@
   test_numba_disabled:
     needs: [linting, manifest]
     name: Run tests with numba disabled
-    timeout-minutes: 60
+    timeout-minutes: 120
     runs-on: ubuntu-latest
     env:
-      NUMBA_DISABLE_JIT: "1"
+      NUMBA_DISABLE_JIT: "1"
+      PYTORCH_JIT: "0"
+      # pooch cache dir
+      BRAINGLOBE_TEST_DATA_DIR: "~/.pooch_cache"

     steps:
+      - uses: actions/checkout@v4
       - name: Cache brainglobe directory
         uses: actions/cache@v3
         with:
           path: | # ensure we don't cache any interrupted atlas download and extraction, if e.g. we cancel the workflow manually
             ~/.brainglobe
             !~/.brainglobe/atlas.tar.gz
           key: brainglobe

+      - name: Cache pooch data
+        uses: actions/cache@v4
+        with:
+          path: "~/.pooch_cache"
+          key: ${{ runner.os }}-3.10-${{ hashFiles('**/pooch_registry.txt') }}

       # Setup pyqt libraries
       - name: Setup qtpy libraries
         uses: tlambert03/setup-qt-libs@v1
@@ -108,7 +130,7 @@
   test_brainmapper_cli:
     needs: [linting, manifest]
     name: Run brainmapper tests to check for breakages
-    timeout-minutes: 60
+    timeout-minutes: 120
     runs-on: ubuntu-latest
     env:
       KERAS_BACKEND: torch
86 changes: 86 additions & 0 deletions benchmarks/benchmark_tools.py
@@ -0,0 +1,86 @@
from pathlib import Path

import pooch
import torch
from torch.profiler import ProfilerActivity, profile
from torch.utils.benchmark import Compare, Timer

from cellfinder.core.tools.IO import fetch_pooch_directory


def get_test_data_path(path):
    """
    Create a test data registry for BrainGlobe and fetch the given
    directory from it.

    Returns:
        The local path to the fetched test data directory.
    """
    registry = pooch.create(
        path=pooch.os_cache("brainglobe_test_data"),
        base_url="https://gin.g-node.org/BrainGlobe/test-data/raw/master/cellfinder/",
        env="BRAINGLOBE_TEST_DATA_DIR",
    )

    registry.load_registry(
        Path(__file__).parent.parent / "tests" / "data" / "pooch_registry.txt"
    )

    return fetch_pooch_directory(registry, path)


def time_filters(repeat, run, run_args, label):
    timer = Timer(
        stmt="run(*args)",
        globals={"run": run, "args": run_args},
        label=label,
        num_threads=4,
        description="",  # must be not None due to pytorch bug
    )
    return timer.timeit(number=repeat)


def compare_results(*results):
    # prints the results of all the timed tests
    compare = Compare(results)
    compare.trim_significant_figures()
    compare.colorize()
    compare.print()


def profile_cpu(repeat, run, run_args):
    with profile(
        activities=[ProfilerActivity.CPU],
        record_shapes=True,
        profile_memory=True,
        with_stack=True,
        with_modules=True,
    ) as prof:
        for _ in range(repeat):
            run(*run_args)

    print(
        prof.key_averages(group_by_stack_n=1).table(
            sort_by="self_cpu_time_total", row_limit=20
        )
    )


def profile_cuda(repeat, run, run_args):
    with profile(
        activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
        record_shapes=True,
        profile_memory=True,
        with_stack=True,
        with_modules=True,
    ) as prof:
        for _ in range(repeat):
            run(*run_args)
            # make sure it's fully done filtering
            torch.cuda.synchronize("cuda")

    print(
        prof.key_averages(group_by_stack_n=1).table(
            sort_by="self_cuda_time_total", row_limit=20
        )
    )
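
As a quick usage sketch of these helpers (toy_run is a made-up workload; the real caller is benchmarks/filter_2d.py below):

import torch

from benchmark_tools import compare_results, time_filters


def toy_run(x):
    # Stand-in workload: repeated elementwise ops on a batch of planes.
    for _ in range(10):
        x = x * 1.0001 + 0.0001


results = [
    time_filters(3, toy_run, (torch.rand(16, 1, 512, 512),), "toy-cpu"),
]
compare_results(*results)
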
144 changes: 124 additions & 20 deletions benchmarks/filter_2d.py
@@ -1,27 +1,131 @@
+import os
+import sys
+
+sys.path.append(os.path.dirname(__file__))
+
 import numpy as np
-from pyinstrument import Profiler
+import torch
+from benchmark_tools import (
+    compare_results,
+    get_test_data_path,
+    profile_cpu,
+    profile_cuda,
+    time_filters,
+)
+from brainglobe_utils.IO.image.load import read_with_dask

 from cellfinder.core.detect.filters.plane import TileProcessor
-from cellfinder.core.detect.filters.setup_filters import setup_tile_filtering
+from cellfinder.core.detect.filters.setup_filters import DetectionSettings

-# Use random 16-bit integer data for signal plane
-shape = (10000, 10000)

-signal_array_plane = np.random.randint(
-    low=0, high=65536, size=shape, dtype=np.uint16
-)
+def setup_filter(
+    signal_path,
+    batch_size=1,
+    num_z=None,
+    torch_device="cpu",
+    dtype=np.uint16,
+    use_scipy=False,
+):
+    signal_array = read_with_dask(signal_path)
+    num_z = num_z or len(signal_array)
+    signal_array = np.asarray(signal_array[:num_z]).astype(dtype)
+    shape = signal_array.shape
+
+    settings = DetectionSettings(
+        plane_original_np_dtype=dtype,
+        plane_shape=shape[1:],
+        torch_device=torch_device,
+        voxel_sizes=(5.06, 4.5, 4.5),
+        soma_diameter_um=30,
+        ball_xy_size_um=6,
+        ball_z_size_um=15,
+    )
+    signal_array = settings.filter_data_converter_func(signal_array)
+    signal_array = torch.from_numpy(signal_array).to(torch_device)
+
+    tile_processor = TileProcessor(
+        plane_shape=shape[1:],
+        clipping_value=settings.clipping_value,
+        threshold_value=settings.threshold_value,
+        soma_diameter=settings.soma_diameter,
+        log_sigma_size=settings.log_sigma_size,
+        n_sds_above_mean_thresh=settings.n_sds_above_mean_thresh,
+        torch_device=torch_device,
+        dtype=settings.filtering_dtype.__name__,
+        use_scipy=use_scipy,
+    )
+
+    return tile_processor, signal_array, batch_size
+
+
+def run_filter(tile_processor, signal_array, batch_size):
+    for i in range(0, len(signal_array), batch_size):
+        tile_processor.get_tile_mask(signal_array[i : i + batch_size])


-clipping_value, threshold_value = setup_tile_filtering(signal_array_plane)
-tile_processor = TileProcessor(
-    clipping_value=clipping_value,
-    threshold_value=threshold_value,
-    soma_diameter=16,
-    log_sigma_size=0.2,
-    n_sds_above_mean_thresh=10,
-)
 if __name__ == "__main__":
-    profiler = Profiler()
-    profiler.start()
-    plane, tiles = tile_processor.get_tile_mask(signal_array_plane)
-    profiler.stop()
-    profiler.print(show_all=True)
+    with torch.inference_mode(True):
+        n = 5
+        batch_size = 2
+        signal_path = get_test_data_path("bright_brain/signal")
+
+        compare_results(
+            time_filters(
+                n,
+                run_filter,
+                setup_filter(
+                    signal_path,
+                    batch_size=batch_size,
+                    torch_device="cpu",
+                    use_scipy=False,
+                ),
+                "cpu-no_scipy",
+            ),
+            time_filters(
+                n,
+                run_filter,
+                setup_filter(
+                    signal_path,
+                    batch_size=batch_size,
+                    torch_device="cpu",
+                    use_scipy=True,
+                ),
+                "cpu-scipy",
+            ),
+            time_filters(
+                n,
+                run_filter,
+                setup_filter(
+                    signal_path, batch_size=batch_size, torch_device="cuda"
+                ),
+                "cuda",
+            ),
+        )
+
+        profile_cpu(
+            n,
+            run_filter,
+            setup_filter(
+                signal_path,
+                batch_size=batch_size,
+                torch_device="cpu",
+                use_scipy=False,
+            ),
+        )
+        profile_cpu(
+            n,
+            run_filter,
+            setup_filter(
+                signal_path,
+                batch_size=batch_size,
+                torch_device="cpu",
+                use_scipy=True,
+            ),
+        )
+        profile_cuda(
+            n,
+            run_filter,
+            setup_filter(
+                signal_path, batch_size=batch_size, torch_device="cuda"
+            ),
+        )