diff --git a/ci/environment.yml b/ci/environment.yml
index 0510e4b20..d847e56a5 100644
--- a/ci/environment.yml
+++ b/ci/environment.yml
@@ -11,6 +11,7 @@ dependencies:
   - numpy>=1.20
   - lxml  # for mypy coverage report
   - matplotlib
+  - pint
   - pip
   - pytest
   - pytest-cov
diff --git a/flox/aggregations.py b/flox/aggregations.py
index 3d2a085a6..c8eda3349 100644
--- a/flox/aggregations.py
+++ b/flox/aggregations.py
@@ -64,8 +64,6 @@ def generic_aggregate(
             f"Expected engine to be one of ['flox', 'numpy', 'numba']. Received {engine} instead."
         )
 
-    group_idx = np.asarray(group_idx, like=array)
-
     with warnings.catch_warnings():
         warnings.filterwarnings("ignore", r"All-NaN (slice|axis) encountered")
         result = method(
@@ -131,6 +129,7 @@ def __init__(
         dtypes=None,
         final_dtype: DTypeLike | None = None,
         reduction_type="reduce",
+        units_func: Callable | None = None,
     ):
         """
         Blueprint for computing grouped aggregations.
@@ -173,6 +172,8 @@ def __init__(
             per reduction in ``chunk`` as a tuple.
         final_dtype : DType, optional
             DType for output. By default, uses dtype of array being reduced.
+        units_func : callable, optional
+            Function whose output is used to infer the units of the result.
         """
         self.name = name
         # preprocess before blockwise
@@ -206,6 +207,8 @@ def __init__(
         # The following are set by _initialize_aggregation
         self.finalize_kwargs: dict[Any, Any] = {}
         self.min_count: int | None = None
+        self.units_func: Callable | None = units_func
+        self.units = None
 
     def _normalize_dtype_fill_value(self, value, name):
         value = _atleast_1d(value)
@@ -254,17 +257,44 @@ def __repr__(self) -> str:
     final_dtype=np.intp,
 )
 
+
+def identity(x):
+    return x
+
+
+def square(x):
+    return x**2
+
+
+def raise_units_error(x):
+    raise ValueError(
+        "Units cannot be supported for prod in general. "
+        "We can only attach units when there is an "
+        "equal number of members in each group. "
+        "Please strip units and then reattach units "
+        "to the output manually."
+    )
+
+
 # note that the fill values are the result of np.func([np.nan, np.nan])
 # final_fill_value is used for groups that don't exist. This is usually np.nan
-sum_ = Aggregation("sum", chunk="sum", combine="sum", fill_value=0)
-nansum = Aggregation("nansum", chunk="nansum", combine="sum", fill_value=0)
-prod = Aggregation("prod", chunk="prod", combine="prod", fill_value=1, final_fill_value=1)
+sum_ = Aggregation("sum", chunk="sum", combine="sum", fill_value=0, units_func=identity)
+nansum = Aggregation("nansum", chunk="nansum", combine="sum", fill_value=0, units_func=identity)
+prod = Aggregation(
+    "prod",
+    chunk="prod",
+    combine="prod",
+    fill_value=1,
+    final_fill_value=1,
+    units_func=raise_units_error,
+)
 nanprod = Aggregation(
     "nanprod",
     chunk="nanprod",
     combine="prod",
     fill_value=1,
     final_fill_value=dtypes.NA,
+    units_func=raise_units_error,
 )
 
 
@@ -281,6 +311,7 @@ def _mean_finalize(sum_, count):
     fill_value=(0, 0),
     dtypes=(None, np.intp),
     final_dtype=np.floating,
+    units_func=identity,
 )
 nanmean = Aggregation(
     "nanmean",
@@ -290,6 +321,7 @@ def _mean_finalize(sum_, count):
     fill_value=(0, 0),
     dtypes=(None, np.intp),
     final_dtype=np.floating,
+    units_func=identity,
 )
 
 
@@ -315,6 +347,7 @@ def _std_finalize(sumsq, sum_, count, ddof=0):
     final_fill_value=np.nan,
     dtypes=(None, None, np.intp),
     final_dtype=np.floating,
+    units_func=square,
 )
 nanvar = Aggregation(
     "nanvar",
@@ -325,6 +358,7 @@ def _std_finalize(sumsq, sum_, count, ddof=0):
     final_fill_value=np.nan,
     dtypes=(None, None, np.intp),
     final_dtype=np.floating,
+    units_func=square,
 )
 std = Aggregation(
     "std",
@@ -335,6 +369,7 @@ def _std_finalize(sumsq, sum_, count, ddof=0):
     final_fill_value=np.nan,
     dtypes=(None, None, np.intp),
     final_dtype=np.floating,
+    units_func=identity,
 )
 nanstd = Aggregation(
     "nanstd",
@@ -345,13 +380,18 @@ def _std_finalize(sumsq, sum_, count, ddof=0):
     final_fill_value=np.nan,
     dtypes=(None, None, np.intp),
     final_dtype=np.floating,
+    units_func=identity,
 )
 
 
-min_ = Aggregation("min", chunk="min", combine="min", fill_value=dtypes.INF)
-nanmin = Aggregation("nanmin", chunk="nanmin", combine="nanmin", fill_value=np.nan)
-max_ = Aggregation("max", chunk="max", combine="max", fill_value=dtypes.NINF)
-nanmax = Aggregation("nanmax", chunk="nanmax", combine="nanmax", fill_value=np.nan)
+min_ = Aggregation("min", chunk="min", combine="min", fill_value=dtypes.INF, units_func=identity)
+nanmin = Aggregation(
+    "nanmin", chunk="nanmin", combine="nanmin", fill_value=np.nan, units_func=identity
+)
+max_ = Aggregation("max", chunk="max", combine="max", fill_value=dtypes.NINF, units_func=identity)
+nanmax = Aggregation(
+    "nanmax", chunk="nanmax", combine="nanmax", fill_value=np.nan, units_func=identity
+)
 
 
 def argreduce_preprocess(array, axis):
@@ -439,10 +479,14 @@ def _pick_second(*x):
     final_dtype=np.intp,
 )
 
-first = Aggregation("first", chunk=None, combine=None, fill_value=0)
-last = Aggregation("last", chunk=None, combine=None, fill_value=0)
-nanfirst = Aggregation("nanfirst", chunk="nanfirst", combine="nanfirst", fill_value=np.nan)
-nanlast = Aggregation("nanlast", chunk="nanlast", combine="nanlast", fill_value=np.nan)
+first = Aggregation("first", chunk=None, combine=None, fill_value=0, units_func=identity)
+last = Aggregation("last", chunk=None, combine=None, fill_value=0, units_func=identity)
+nanfirst = Aggregation(
+    "nanfirst", chunk="nanfirst", combine="nanfirst", fill_value=np.nan, units_func=identity
+)
+nanlast = Aggregation(
+    "nanlast", chunk="nanlast", combine="nanlast", fill_value=np.nan, units_func=identity
+)
 
 all_ = Aggregation(
     "all",
@@ -502,6 +546,7 @@ def _initialize_aggregation(
     dtype,
     array_dtype,
     fill_value,
+    array_units,
     min_count: int | None,
     finalize_kwargs: dict[Any, Any] | None,
 ) -> Aggregation:
@@ -572,4 +617,8 @@ def _initialize_aggregation(
         agg.dtype["intermediate"] += (np.intp,)
         agg.dtype["numpy"] += (np.intp,)
 
+    if array_units is not None and agg.units_func is not None:
+        import pint
+
+        agg.units = agg.units_func(pint.Quantity([1], units=array_units))
     return agg
diff --git a/flox/core.py b/flox/core.py
index d69ac51ce..cc4fb4985 100644
--- a/flox/core.py
+++ b/flox/core.py
@@ -35,6 +35,7 @@
     generic_aggregate,
 )
 from .cache import memoize
+from .pint_compat import _reattach_units, _strip_units
 from .xrutils import is_duck_array, is_duck_dask_array, isnull
 
 if TYPE_CHECKING:
@@ -1799,6 +1800,8 @@ def groupby_reduce(
    by_is_dask = tuple(is_duck_dask_array(b) for b in bys)
    any_by_dask = any(by_is_dask)
 
+    array, bys, units = _strip_units(array, *bys)
+
    if method in ["split-reduce", "cohorts"] and any_by_dask:
        raise ValueError(f"method={method!r} can only be used when grouping by numpy arrays.")
 
@@ -1904,7 +1907,9 @@ def groupby_reduce(
        fill_value = np.nan
 
    kwargs = dict(axis=axis_, fill_value=fill_value, engine=engine)
-    agg = _initialize_aggregation(func, dtype, array.dtype, fill_value, min_count, finalize_kwargs)
+    agg = _initialize_aggregation(
+        func, dtype, array.dtype, fill_value, units[0], min_count, finalize_kwargs
+    )
 
    groups: tuple[np.ndarray | DaskArray, ...]
    if not has_dask:
@@ -1964,4 +1969,7 @@ def groupby_reduce(
 
    if _is_minmax_reduction(func) and is_bool_array:
        result = result.astype(bool)
+
+    units[0] = agg.units
+    result, *groups = _reattach_units(result, *groups, units=units)
    return (result, *groups)  # type: ignore[return-value]  # Unpack not in mypy yet
diff --git a/flox/pint_compat.py b/flox/pint_compat.py
new file mode 100644
index 000000000..eddf58654
--- /dev/null
+++ b/flox/pint_compat.py
@@ -0,0 +1,25 @@
+def _strip_units(*arrays):
+    try:
+        import pint
+
+        pint_quantity = (pint.Quantity,)
+
+    except ImportError:
+        pint_quantity = ()
+
+    bare = tuple(array.magnitude if isinstance(array, pint_quantity) else array for array in arrays)
+    units = [array.units if isinstance(array, pint_quantity) else None for array in arrays]
+
+    return bare[0], bare[1:], units
+
+
+def _reattach_units(*arrays, units):
+    try:
+        import pint
+
+        return tuple(
+            pint.Quantity(array, unit) if unit is not None else array
+            for array, unit in zip(arrays, units)
+        )
+    except ImportError:
+        return arrays
diff --git a/tests/__init__.py b/tests/__init__.py
index 4c04a0fc8..b92e68e57 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -24,6 +24,13 @@
 except ImportError:
     xr_types = ()  # type: ignore
 
+try:
+    import pint
+
+    pint_types = pint.Quantity
+except ImportError:
+    pint_types = ()  # type: ignore
+
 
 def _importorskip(modname, minversion=None):
     try:
@@ -46,6 +53,7 @@ def LooseVersion(vstring):
 
 
 has_dask, requires_dask = _importorskip("dask")
+has_pint, requires_pint = _importorskip("pint")
 has_xarray, requires_xarray = _importorskip("xarray")
 
 
@@ -95,6 +103,14 @@ def assert_equal(a, b, tolerance=None):
         xr.testing.assert_identical(a, b)
         return
 
+    if has_pint and (isinstance(a, pint_types) or isinstance(b, pint_types)):
+        assert isinstance(a, pint_types)
+        assert isinstance(b, pint_types)
+        assert a.units == b.units
+
+        a = a.magnitude
+        b = b.magnitude
+
     if tolerance is None and (
         np.issubdtype(a.dtype, np.float64) | np.issubdtype(b.dtype, np.float64)
     ):
diff --git a/tests/test_core.py b/tests/test_core.py
index e83a69da5..88eaa13cd 100644
--- a/tests/test_core.py
+++ b/tests/test_core.py
@@ -30,6 +30,7 @@
     has_dask,
     raise_if_dask_computes,
     requires_dask,
+    requires_pint,
 )
 
 labels = np.array([0, 0, 2, 2, 2, 1, 1, 2, 2, 1, 1, 0])
@@ -1339,6 +1340,45 @@ def test_negative_index_factorize_race_condition():
     [dask.compute(out, scheduler="threads") for _ in range(5)]
 
 
+@requires_pint
+@pytest.mark.parametrize("func", ALL_FUNCS)
+@pytest.mark.parametrize("chunk", [True, False])
+def test_pint(chunk, func, engine):
+    import pint
+
+    if func in ["prod", "nanprod"]:
+        pytest.skip()
+
+    if chunk:
+        d = dask.array.array([1, 2, 3])
+    else:
+        d = np.array([1, 2, 3])
+    q = pint.Quantity(d, units="m")
+
+    actual, _ = groupby_reduce(q, [0, 0, 1], func=func)
+    expected, _ = groupby_reduce(q.magnitude, [0, 0, 1], func=func)
+
+    units = None if func in ["count", "all", "any"] or "arg" in func else getattr(np, func)(q).units
+    if units is not None:
+        expected = pint.Quantity(expected, units=units)
+    assert_equal(expected, actual)
+
+
+@requires_pint
+@pytest.mark.parametrize("chunk", [True, False])
+def test_pint_prod_error(chunk):
+    import pint
+
+    if chunk:
+        d = dask.array.array([1, 2, 3])
+    else:
+        d = np.array([1, 2, 3])
+    q = pint.Quantity(d, units="m")
+
+    with pytest.raises(ValueError):
+        groupby_reduce(q, [0, 0, 1], func="prod")
+
+
 @pytest.mark.parametrize("sort", [True, False])
 def test_expected_index_conversion_passthrough_range_index(sort):
     index = pd.RangeIndex(100)