Skip to content

Commit

Permalink
Avoid rechunking when preferred_method="blockwise" (#394)
Browse files Browse the repository at this point in the history
* Avoid rechunking when preferred_method="blockwise"

* Add test

* fix
  • Loading branch information
dcherian authored Sep 16, 2024
1 parent 7421cb1 commit d65181c
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 1 deletion.
4 changes: 3 additions & 1 deletion flox/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,6 +642,7 @@ def rechunk_for_blockwise(array: DaskArray, axis: T_Axis, labels: np.ndarray) ->
DaskArray
Rechunked array
"""
# TODO: this should be unnecessary?
labels = factorize_((labels,), axes=())[0]
chunks = array.chunks[axis]
newchunks = _get_optimal_chunks_for_groups(chunks, labels)
Expand Down Expand Up @@ -2623,7 +2624,8 @@ def groupby_reduce(

partial_agg = partial(dask_groupby_agg, **kwargs)

if method == "blockwise" and by_.ndim == 1:
# if preferred method is already blockwise, no need to rechunk
if preferred_method != "blockwise" and method == "blockwise" and by_.ndim == 1:
array = rechunk_for_blockwise(array, axis=-1, labels=by_)

result, groups = partial_agg(
Expand Down
9 changes: 9 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1997,3 +1997,12 @@ def test_agg_dtypes(func, engine):
)
expected = _get_array_func(func)(counts, dtype="uint8")
assert actual.dtype == np.uint8 == expected.dtype


@requires_dask
def test_blockwise_avoid_rechunk():
    """Regression test accompanying "Avoid rechunking when preferred_method="blockwise"" (#394).

    Runs a ``"first"`` reduction over an all-zeros dask array grouped by
    single-character string labels (including the empty string) and checks
    both the returned groups and the reduced values.
    """
    # Labels are deliberately unsorted and include "" as a valid group key.
    labels = np.array(["1", "1", "0", "", "0", ""], dtype="<U1")
    # NOTE(review): chunks=(2, 4) presumably splits the single axis into two
    # uneven chunks -- confirm against dask's chunk normalization.
    data = dask.array.zeros((6,), chunks=(2, 4), dtype=np.int64)

    expected_groups = ["", "0", "1"]
    expected_values = [0, 0, 0]

    reduced, found_groups = groupby_reduce(data, labels, func="first")

    assert_equal(found_groups, expected_groups)
    assert_equal(reduced, expected_values)

0 comments on commit d65181c

Please sign in to comment.