Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
dcherian committed Aug 2, 2024
1 parent a5f208a commit 9fdae45
Showing 1 changed file with 8 additions and 12 deletions.
20 changes: 8 additions & 12 deletions flox/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,16 +365,6 @@ def find_group_cohorts(
if not is_duck_array(labels):
labels = np.asarray(labels)

if is_duck_dask_array(labels):
import dask

((bitmask, nlabels, ilabels),) = dask.compute(
dask.delayed(_compute_label_chunk_bitmask)(labels, chunks)
)
else:
bitmask, nlabels, ilabels = _compute_label_chunk_bitmask(labels, chunks)

shape = tuple(sum(c) for c in chunks)
nchunks = math.prod(len(c) for c in chunks)

# assumes that `labels` are factorized
Expand All @@ -387,8 +377,14 @@ def find_group_cohorts(
if nchunks == 1:
return "blockwise", {(0,): list(range(nlabels))}

labels = np.broadcast_to(labels, shape[-labels.ndim :])
bitmask = _compute_label_chunk_bitmask(labels, chunks, nlabels)
if is_duck_dask_array(labels):
import dask

((bitmask, nlabels, ilabels),) = dask.compute(
dask.delayed(_compute_label_chunk_bitmask)(labels, chunks, nlabels)
)
else:
bitmask, nlabels, ilabels = _compute_label_chunk_bitmask(labels, chunks, nlabels)

CHUNK_AXIS, LABEL_AXIS = 0, 1
chunks_per_label = bitmask.sum(axis=CHUNK_AXIS)
Expand Down

0 comments on commit 9fdae45

Please sign in to comment.