Skip to content

Commit

Permalink
Auto rechunk to enable blockwise reduction
Browse files Browse the repository at this point in the history
Done when
1. `method` is None
2. Grouping and reducing along a single axis by a sorted 1D array that is not a dask array

We gate this on fractional change in number of chunks and change in size
of largest chunk.

Closes #359
  • Loading branch information
dcherian committed Aug 2, 2024
1 parent f8f34b9 commit b3ac2c2
Showing 1 changed file with 20 additions and 1 deletion.
21 changes: 20 additions & 1 deletion flox/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,12 @@
# _simple_combine.
DUMMY_AXIS = -2

# Thresholds gating automatic rechunking for a blockwise reduction. The rechunk
# is applied only when BOTH fractional changes below are strictly less than
# their threshold (see rechunk_for_blockwise).
# 1. Fractional change in the number of chunks along the axis after rechunking:
#    abs(len(newchunks) - len(chunks)) / len(chunks)
BLOCKWISE_RECHUNK_NUM_CHUNKS_THRESHOLD = 0.25
# 2. Fractional change in the largest chunk size after rechunking:
#    abs(max(newchunks) - max(chunks)) / max(chunks)
BLOCKWISE_RECHUNK_CHUNK_SIZE_THRESHOLD = 0.15

logger = logging.getLogger("flox")


Expand Down Expand Up @@ -230,6 +236,8 @@ def _get_optimal_chunks_for_groups(chunks, labels):
Δl = abs(c - l)
if c == 0 or newchunkidx[-1] > l:
continue
f = f.item() # noqa
l = l.item() # noqa
if Δf < Δl and f > newchunkidx[-1]:
newchunkidx.append(f)
else:
Expand Down Expand Up @@ -654,9 +662,15 @@ def rechunk_for_blockwise(array: DaskArray, axis: T_Axis, labels: np.ndarray) ->
labels = factorize_((labels,), axes=())[0]
chunks = array.chunks[axis]
newchunks = _get_optimal_chunks_for_groups(chunks, labels)

if newchunks == chunks:
return array
else:

Δn = abs(len(newchunks) - len(chunks))
if (Δn / len(chunks) < BLOCKWISE_RECHUNK_NUM_CHUNKS_THRESHOLD) and (
abs(max(newchunks) - max(chunks)) / max(chunks) < BLOCKWISE_RECHUNK_CHUNK_SIZE_THRESHOLD
):
# Both the number of chunks and the largest chunk size change by less than their thresholds, so rechunk.
return array.rechunk({axis: newchunks})


Expand Down Expand Up @@ -2468,6 +2482,11 @@ def groupby_reduce(
has_dask = is_duck_dask_array(array) or is_duck_dask_array(by_)
has_cubed = is_duck_cubed_array(array) or is_duck_cubed_array(by_)

if method is None and nax == 1 and not any_by_dask and by_.ndim == 1 and _issorted(by_):
# Let's try rechunking for sorted 1D by.
(single_axis,) = axis_
array = rechunk_for_blockwise(array, single_axis, by_)

if _is_first_last_reduction(func):
if has_dask and nax != 1:
raise ValueError(
Expand Down

0 comments on commit b3ac2c2

Please sign in to comment.