Begin testing
dcherian committed Jun 3, 2023
1 parent 9d35bb2 commit 0eff0a7
Showing 6 changed files with 64 additions and 12 deletions.
4 changes: 3 additions & 1 deletion flox/aggregate_flox.py
@@ -14,7 +14,9 @@ def _prepare_for_flox(group_idx, array):
     if issorted:
         ordered_array = array
     else:
-        perm = group_idx.argsort(kind="stable")
+        kind = "stable" if isinstance(group_idx, np.ndarray) else None
+
+        perm = np.argsort(group_idx, kind=kind)
         group_idx = group_idx[..., perm]
         ordered_array = array[..., perm]
     return group_idx, ordered_array
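
The new branch relies on NumPy's __array_function__ dispatch: np.argsort hands the call to the array's own library, and non-NumPy backends such as CuPy may not accept kind="stable", so the keyword is only forwarded for true np.ndarray inputs. A minimal sketch of the idea, using a hypothetical stable_argsort helper (the commented GPU call assumes cupy is installed):

import numpy as np

def stable_argsort(group_idx):
    # Forward kind="stable" only for NumPy arrays; other backends are reached
    # through np.argsort's __array_function__ dispatch and keep their default.
    kind = "stable" if isinstance(group_idx, np.ndarray) else None
    return np.argsort(group_idx, kind=kind)

stable_argsort(np.array([3, 1, 2]))  # array([1, 2, 0])
# import cupy as cp; stable_argsort(cp.array([3, 1, 2]))  # result stays on the GPU
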
2 changes: 2 additions & 0 deletions flox/core.py
@@ -570,6 +570,8 @@ def factorize_(
             else:
                 assert sort
                 groups, idx = np.unique(flat, return_inverse=True)
+                idx[np.isnan(flat)] = -1
+                groups = groups[~np.isnan(groups)]

         found_groups.append(groups)
         factorized.append(idx.reshape(groupvar.shape))
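
The two added lines handle NaN group labels: np.unique treats NaN as a regular value, so NaN positions are remapped to the sentinel -1 and NaN is dropped from the discovered groups. A small worked example of the behaviour being relied on (illustration only, not flox code):

import numpy as np

flat = np.array([1.0, np.nan, 2.0, 1.0])
groups, idx = np.unique(flat, return_inverse=True)
# groups -> [1., 2., nan], idx -> [0, 2, 1, 0]: NaN got its own group
idx[np.isnan(flat)] = -1            # mark missing labels with -1
groups = groups[~np.isnan(groups)]  # drop NaN from the group labels
# groups -> [1., 2.], idx -> [0, -1, 1, 0]
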
2 changes: 0 additions & 2 deletions flox/xrutils.py
@@ -294,7 +294,6 @@ def _select_along_axis(values, idx, axis):
 def nanfirst(values, axis, keepdims=False):
     if isinstance(axis, tuple):
         (axis,) = axis
-    values = np.asarray(values)
     axis = normalize_axis_index(axis, values.ndim)
     idx_first = np.argmax(~pd.isnull(values), axis=axis)
     result = _select_along_axis(values, idx_first, axis)
@@ -307,7 +306,6 @@ def nanfirst(values, axis, keepdims=False):
 def nanlast(values, axis, keepdims=False):
     if isinstance(axis, tuple):
         (axis,) = axis
-    values = np.asarray(values)
     axis = normalize_axis_index(axis, values.ndim)
     rev = (slice(None),) * axis + (slice(None, None, -1),)
     idx_last = -1 - np.argmax(~pd.isnull(values)[rev], axis=axis)
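
Dropping the np.asarray coercion is what lets nanfirst and nanlast receive device arrays unchanged: CuPy refuses implicit conversion to NumPy, so the old call would raise rather than silently copy to the host. A quick check of that assumption (requires a working cupy install):

import cupy as cp
import numpy as np

x = cp.arange(3.0)
try:
    np.asarray(x)        # implicit device-to-host conversion
except TypeError:
    pass                 # CuPy raises; an explicit x.get() is required instead
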
13 changes: 13 additions & 0 deletions tests/__init__.py
@@ -24,6 +24,13 @@
 except ImportError:
     xr_types = ()  # type: ignore

+try:
+    import cupy as cp
+
+    cp_types = (cp.ndarray,)
+except ImportError:
+    cp_types = ()  # type: ignore
+

 def _importorskip(modname, minversion=None):
     try:
@@ -88,6 +95,12 @@ def assert_equal(a, b, tolerance=None):
     if isinstance(b, list):
         b = np.array(b)

+    if isinstance(a, cp_types):
+        a = a.get()
+
+    if isinstance(b, cp_types):
+        b = b.get()
+
     if isinstance(a, pd_types) or isinstance(b, pd_types):
         pd.testing.assert_index_equal(a, b)
         return
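
With the empty-tuple fallback the isinstance(..., cp_types) checks are simply never true when cupy is absent, and when it is installed .get() copies the device array back to the host so the existing NumPy/pandas assertions apply unchanged. A sketch of both pieces together (to_host is a hypothetical helper, not part of the test suite):

import numpy as np

try:
    import cupy as cp

    cp_types = (cp.ndarray,)
except ImportError:
    cp_types = ()  # isinstance(x, ()) is always False

def to_host(x):
    # Move device arrays to the host before comparing; no-op otherwise.
    return x.get() if isinstance(x, cp_types) else x

np.testing.assert_allclose(to_host(np.arange(3)), np.arange(3))
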
15 changes: 15 additions & 0 deletions tests/conftest.py
@@ -9,3 +9,18 @@ def engine(request):
         except ImportError:
             pytest.xfail()
     return request.param
+
+
+@pytest.fixture(scope="module", params=["numpy", "cupy"])
+def array_module(request):
+    if request.param == "cupy":
+        try:
+            import cupy  # noqa
+
+            return cupy
+        except ImportError:
+            pytest.xfail()
+    elif request.param == "numpy":
+        import numpy
+
+        return numpy
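
A test that requests the array_module fixture runs once per backend, and the cupy parametrization xfails on machines without a cupy install. Hypothetical usage:

def test_ones_roundtrip(array_module):
    xp = array_module            # numpy or cupy, depending on the fixture param
    out = xp.ones(4) + xp.zeros(4)
    assert out.shape == (4,)
    assert out.sum() == 4.0      # the 0-d comparison result coerces to bool
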
40 changes: 31 additions & 9 deletions tests/test_core.py
@@ -178,31 +178,53 @@ def test_groupby_reduce(
     assert_equal(expected_result, result)


-def gen_array_by(size, func):
-    by = np.ones(size[-1])
-    rng = np.random.default_rng(12345)
+def maybe_skip_cupy(array_module, func, engine):
+    if array_module is np:
+        return
+
+    import cupy
+
+    assert array_module is cupy
+
+    if engine == "numba":
+        pytest.skip()
+
+    if engine == "numpy" and ("prod" in func or "first" in func or "last" in func):
+        pytest.xfail()
+    elif engine == "flox" and not (
+        "sum" in func or "mean" in func or "std" in func or "var" in func
+    ):
+        pytest.xfail()
+
+
+def gen_array_by(size, func, array_module):
+    xp = array_module
+    by = xp.ones(size[-1])
+    rng = xp.random.default_rng(12345)
     array = rng.random(size)
     if "nan" in func and "nanarg" not in func:
-        array[[1, 4, 5], ...] = np.nan
+        array[[1, 4, 5], ...] = xp.nan
     elif "nanarg" in func and len(size) > 1:
-        array[[1, 4, 5], 1] = np.nan
+        array[[1, 4, 5], 1] = xp.nan
     if func in ["any", "all"]:
        array = array > 0.5
     return array, by


-@pytest.mark.parametrize("chunks", [None, -1, 3, 4])
 @pytest.mark.parametrize("nby", [1, 2, 3])
 @pytest.mark.parametrize("size", ((12,), (12, 9)))
-@pytest.mark.parametrize("add_nan_by", [True, False])
+@pytest.mark.parametrize("chunks", [None, -1, 3, 4])
 @pytest.mark.parametrize("func", ALL_FUNCS)
-def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
+@pytest.mark.parametrize("add_nan_by", [True, False])
+def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine, array_module):
     if chunks is not None and not has_dask:
         pytest.skip()
     if "arg" in func and engine == "flox":
         pytest.skip()

-    array, by = gen_array_by(size, func)
+    maybe_skip_cupy(array_module, func, engine)
+
+    array, by = gen_array_by(size, func, array_module)
     if chunks:
         array = dask.array.from_array(array, chunks=chunks)
     by = (by,) * nby
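
End to end, each parametrized case feeds the chosen backend through groupby_reduce. The sketch below shows the cupy path for one combination the skip logic allows (it assumes cupy is installed and that duck arrays pass through flox as exercised by this test):

import cupy as cp
from flox.core import groupby_reduce

array = cp.random.random((12,))
by = cp.ones(12)                  # a single group, as in gen_array_by
result, groups = groupby_reduce(array, by, func="sum", engine="flox")
# result stays on the GPU; assert_equal moves it to the host via result.get()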
