Allow method="cohorts" when grouping by dask array #294

Draft · wants to merge 8 commits into main
flox/core.py (68 changes: 52 additions & 16 deletions)
@@ -329,7 +329,7 @@ def chunk_unique(labels, slicer, nlabels, label_is_present=None):
rows_array = np.repeat(np.arange(nchunks), tuple(len(col) for col in cols))
cols_array = np.concatenate(cols)

- return make_bitmask(rows_array, cols_array)
+ return make_bitmask(rows_array, cols_array), nlabels, ilabels


# @memoize
@@ -362,10 +362,9 @@ def find_group_cohorts(
cohorts: dict_values
Iterable of cohorts
"""
- # To do this, we must have values in memory so casting to numpy should be safe
- labels = np.asarray(labels)
+ if not is_duck_array(labels):
+ labels = np.asarray(labels)

- shape = tuple(sum(c) for c in chunks)
nchunks = math.prod(len(c) for c in chunks)

# assumes that `labels` are factorized
@@ -378,8 +377,14 @@
if nchunks == 1:
return "blockwise", {(0,): list(range(nlabels))}

- labels = np.broadcast_to(labels, shape[-labels.ndim :])
- bitmask = _compute_label_chunk_bitmask(labels, chunks, nlabels)
+ if is_duck_dask_array(labels):
+ import dask
+
+ ((bitmask, nlabels, ilabels),) = dask.compute(
+ dask.delayed(_compute_label_chunk_bitmask)(labels, chunks, nlabels)
+ )
+ else:
+ bitmask, nlabels, ilabels = _compute_label_chunk_bitmask(labels, chunks, nlabels)

CHUNK_AXIS, LABEL_AXIS = 0, 1
chunks_per_label = bitmask.sum(axis=CHUNK_AXIS)
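The dask branch above wraps the whole bitmask construction in a single `dask.delayed` call, so cohort detection triggers exactly one compute instead of eagerly materializing `labels`. A minimal sketch of that pattern, with a hypothetical stand-in function (`summarize` is not part of flox):

```python
import dask
import dask.array as da
import numpy as np


def summarize(arr, n):
    # hypothetical stand-in for _compute_label_chunk_bitmask; dask has
    # materialized `arr` into a concrete numpy array by the time this runs
    return arr.sum(), n, np.unique(arr)


labels = da.from_array(np.array([0, 1, 1, 2]), chunks=2)

# dask.delayed defers the call; dask.compute runs it once and returns
# one result per delayed object, hence the ((...),) unpacking in the PR
((total, n, uniques),) = dask.compute(dask.delayed(summarize)(labels, 3))
print(total, n, uniques)  # 4 3 [0 1 2]
```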
@@ -726,6 +731,26 @@ def offset_labels(labels: np.ndarray, ngroups: int) -> tuple[np.ndarray, int]:
return offset, size


+ def fast_isin(ar1, ar2, invert):
+ rev_idx, ar1 = pd.factorize(ar1, sort=False)
+
+ ar = np.concatenate((ar1, ar2))
+ # We need this to be a stable sort, so always use 'mergesort'
+ # here. The values from the first array should always come before
+ # the values from the second array.
+ order = ar.argsort(kind="mergesort")
+ sar = ar[order]
+ if invert:
+ bool_ar = sar[1:] != sar[:-1]
+ else:
+ bool_ar = sar[1:] == sar[:-1]
+ flag = np.concatenate((bool_ar, [invert]))
+ ret = np.empty(ar.shape, dtype=bool)
+ ret[order] = flag
+
+ return ret[rev_idx]


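`fast_isin` uses the same stable-mergesort membership trick as `np.isin`, but factorizes `ar1` first so the sort runs over unique values only, then maps the per-unique answer back through the factorize codes. A quick illustrative check that it agrees with `np.isin` (assuming `fast_isin` from the diff above is in scope):

```python
import numpy as np

a = np.array([3, 1, 3, 2])
b = np.array([1, 2])

# membership and its inversion both match numpy's semantics
print(fast_isin(a, b, invert=False))  # [False  True False  True]
print(np.isin(a, b))                  # [False  True False  True]
print(fast_isin(a, b, invert=True))   # [ True False  True False]
```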
@overload
def factorize_(
by: T_Bys,
@@ -821,8 +846,18 @@ def factorize_(
if expect is not None and reindex:
sorter = np.argsort(expect)
groups = expect[(sorter,)] if sort else expect

idx = np.searchsorted(expect, flat, sorter=sorter)
- mask = ~np.isin(flat, expect) | isnull(flat) | (idx == len(expect))
+ mask = fast_isin(flat, expect, invert=True)
+ if not np.issubdtype(flat.dtype, np.integer):
+ mask |= isnull(flat)
+ mask |= idx == len(expect)
+
+ # idx = np.full(flat.shape, -1)
+ # result = np.searchsorted(expect.values, flat[~mask], sorter=sorter)
+ # idx[~mask] = result
+ # idx = np.searchsorted(expect.values, flat, sorter=sorter)
+ # idx[mask] = -1
if not sort:
# idx is the index into the sorted array.
# if we didn't want sorting, unsort it back
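The reindex path above swaps the old one-line mask for `fast_isin` plus a null check that is skipped for integer dtypes (integer arrays cannot hold missing values). A standalone sketch of the masking logic with hypothetical values, using `np.isin` in place of `fast_isin`:

```python
import numpy as np
import pandas as pd

expect = pd.Index([10, 20, 30])
flat = np.array([20, 10, 99, 30])  # 99 is not an expected group

sorter = np.argsort(expect)
idx = np.searchsorted(expect, flat, sorter=sorter)
# unknown labels either fail the membership test or fall off the
# end of the searchsorted result
mask = ~np.isin(flat, expect) | (idx == len(expect))
idx[mask] = -1  # hypothetical sentinel for "not a known group"
print(idx)  # [ 1  0 -1  2]
```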
@@ -2125,11 +2160,10 @@ def _factorize_multiple(
for by_, expect in zip(by, expected_groups):
if expect is None:
if is_duck_dask_array(by_):
- raise ValueError(
- "Please provide expected_groups when grouping by a dask array."
- )
-
- found_group = pd.unique(by_.reshape(-1))
+ # could be a remote dataset; execute remotely in that case
+ found_group = np.unique(by_.reshape(-1)).compute()
+ else:
+ found_group = pd.unique(by_.reshape(-1))
else:
found_group = expect.to_numpy()

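The `_factorize_multiple` change relies on NumPy's `__array_function__` dispatch: `np.unique` on a dask array builds a lazy `dask.array.unique` graph, so the uniques are found where the data lives and only the small result travels back via `.compute()`. A minimal sketch:

```python
import numpy as np
import dask.array as da

by = da.from_array(np.array([[1, 2], [2, 3]]), chunks=(1, 2))

# np.unique stays lazy on a dask array; .compute() materializes
# only the (small) array of unique labels
found_group = np.unique(by.reshape(-1)).compute()
print(found_group)  # [1 2 3]
```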
@@ -2409,9 +2443,6 @@ def groupby_reduce(
"Try engine='numpy' or engine='numba' instead."
)

- if method == "cohorts" and any_by_dask:
- raise ValueError(f"method={method!r} can only be used when grouping by numpy arrays.")
-
reindex = _validate_reindex(
reindex, func, method, expected_groups, any_by_dask, is_duck_dask_array(array)
)
@@ -2439,10 +2470,15 @@

# Don't factorize early only when
# grouping by dask arrays, and not having expected_groups
+ # except for cohorts
factorize_early = not (
# can't do it if we are grouping by dask array but don't have expected_groups
- any(is_dask and ex_ is None for is_dask, ex_ in zip(by_is_dask, expected_groups))
+ any(
+ is_dask and ex_ is None and method != "cohorts"
+ for is_dask, ex_ in zip(by_is_dask, expected_groups)
+ )
)

expected_: pd.RangeIndex | None
if factorize_early:
bys, final_groups, grp_shape = _factorize_multiple(
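With the relaxed condition, grouping by a dask array without `expected_groups` no longer forces late factorization when `method="cohorts"`, since cohort detection already pays for the one compute. An illustrative truth check of the predicate, with hypothetical inputs:

```python
by_is_dask = [True]        # grouping by one dask array
expected_groups = [None]   # ... without expected_groups


def factorize_early(method):
    return not any(
        is_dask and ex_ is None and method != "cohorts"
        for is_dask, ex_ in zip(by_is_dask, expected_groups)
    )


assert factorize_early("cohorts")          # newly allowed
assert not factorize_early("map-reduce")   # still deferred
```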
tests/test_core.py (41 changes: 29 additions & 12 deletions)
@@ -862,15 +862,13 @@ def test_groupby_bins(chunk_labels, kwargs, chunks, engine, method) -> None:
array = [1, 1, 1, 1, 1, 1]
labels = [0.2, 1.5, 1.9, 2, 3, 20]

- if method == "cohorts" and chunk_labels:
- pytest.xfail()
-
if chunks:
array = dask.array.from_array(array, chunks=chunks)
if chunk_labels:
labels = dask.array.from_array(labels, chunks=chunks)

- with raise_if_dask_computes():
+ max_computes = 1 if method == "cohorts" else 0
+ with raise_if_dask_computes(max_computes):
actual, *groups = groupby_reduce(
array, labels, func="count", fill_value=0, engine=engine, method=method, **kwargs
)
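`raise_if_dask_computes` now takes a compute budget: one compute is expected for cohort detection when the labels are chunked. For context, the helper (adapted in flox from xarray's test utilities) is roughly a counting scheduler along these lines; a sketch, not the exact implementation:

```python
import dask


class CountingScheduler:
    """Fail if dask computes more than max_computes times."""

    def __init__(self, max_computes=0):
        self.total_computes = 0
        self.max_computes = max_computes

    def __call__(self, dsk, keys, **kwargs):
        self.total_computes += 1
        if self.total_computes > self.max_computes:
            raise RuntimeError(
                f"Too many computes: {self.total_computes} > {self.max_computes}"
            )
        return dask.get(dsk, keys, **kwargs)


def raise_if_dask_computes(max_computes=0):
    # dask.config.set(scheduler=...) doubles as a context manager
    return dask.config.set(scheduler=CountingScheduler(max_computes))
```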
@@ -1063,27 +1061,33 @@ def test_cohorts_map_reduce_consistent_dtypes(method, dtype, labels_dtype):


@requires_dask
@pytest.mark.parametrize("func", ALL_FUNCS)
@pytest.mark.parametrize("func", ["sum"])
@pytest.mark.parametrize("axis", (-1, None))
@pytest.mark.parametrize("method", ["blockwise", "cohorts", "map-reduce"])
- def test_cohorts_nd_by(func, method, axis, engine):
+ @pytest.mark.parametrize("by_is_dask", [True, False])
+ def test_cohorts_nd_by(by_is_dask, func, method, axis):
+ engine = "numpy"
if (
("arg" in func and (axis is None or engine in ["flox", "numbagg"]))
or (method != "blockwise" and func in BLOCKWISE_FUNCS)
or (axis is None and ("first" in func or "last" in func))
):
pytest.skip()
if axis is not None and method != "map-reduce":
- pytest.xfail()
+ pytest.skip()
+ if by_is_dask and method == "blockwise":
+ pytest.skip()

o = dask.array.ones((3,), chunks=-1)
o2 = dask.array.ones((2, 3), chunks=-1)

array = dask.array.block([[o, 2 * o], [3 * o2, 4 * o2]])
- by = array.compute().astype(np.int64)
+ by = array.astype(np.int64)
by[0, 1] = 30
by[2, 1] = 40
by[0, 4] = 31
+ if not by_is_dask:
+ by = by.compute()
array = np.broadcast_to(array, (2, 3) + array.shape)

if func in ["any", "all"]:
@@ -1092,17 +1096,30 @@ def test_cohorts_nd_by(func, method, axis, engine):
fill_value = -123

kwargs = dict(func=func, engine=engine, method=method, axis=axis, fill_value=fill_value)
+ if by_is_dask and axis is not None and method == "map-reduce":
+ kwargs["expected_groups"] = pd.Index([1, 2, 3, 4, 30, 31, 40])
+
if "quantile" in func:
kwargs["finalize_kwargs"] = {"q": DEFAULT_QUANTILE}
actual, groups = groupby_reduce(array, by, **kwargs)
expected, sorted_groups = groupby_reduce(array.compute(), by, **kwargs)
assert_equal(groups, sorted_groups)
assert_equal(actual, expected)

- actual, groups = groupby_reduce(array, by, sort=False, **kwargs)
- assert_equal(groups, np.array([1, 30, 2, 31, 3, 4, 40], dtype=np.int64))
- reindexed = reindex_(actual, groups, pd.Index(sorted_groups))
- assert_equal(reindexed, expected)
+ if isinstance(by, dask.array.Array):
+ cache.clear()
+ actual_cohorts = find_group_cohorts(by, array.chunks[-by.ndim :])
+ cache.clear()
+ expected_cohorts = find_group_cohorts(by.compute(), array.chunks[-by.ndim :])
+ assert actual_cohorts == expected_cohorts
+ # assert cache.nbytes
+
+ if not isinstance(by, dask.array.Array):
+ # groups are always sorted when grouping by a dask array with cohorts
+ actual, groups = groupby_reduce(array, by, sort=False, **kwargs)
+ assert_equal(groups, np.array([1, 30, 2, 31, 3, 4, 40], dtype=np.int64))
+ reindexed = reindex_(actual, groups, pd.Index(sorted_groups))
+ assert_equal(reindexed, expected)
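The dask branch above checks that cohort detection is identical whether `by` is held in memory or chunked, using the single compute budgeted for it. A small illustrative version of that assertion, with hypothetical labels:

```python
import numpy as np
import dask.array as da
from flox.core import find_group_cohorts

labels = np.repeat([0, 1, 2], 4)  # factorized labels, one group per chunk
chunks = ((4, 4, 4),)

expected = find_group_cohorts(labels, chunks)
# with this PR, chunked labels take the delayed-bitmask path
actual = find_group_cohorts(da.from_array(labels, chunks=4), chunks)
assert actual == expected
```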


@pytest.mark.parametrize("func", ["sum", "count"])