diff --git a/flox/aggregate_numbagg.py b/flox/aggregate_numbagg.py index 2bef93a5..c2b718e8 100644 --- a/flox/aggregate_numbagg.py +++ b/flox/aggregate_numbagg.py @@ -30,8 +30,6 @@ "nanmean": {np.int_: np.float64}, "nanvar": {np.int_: np.float64}, "nanstd": {np.int_: np.float64}, - "nanfirst": {np.datetime64: np.int64, np.timedelta64: np.int64}, - "nanlast": {np.datetime64: np.int64, np.timedelta64: np.int64}, } @@ -53,7 +51,7 @@ def _numbagg_wrapper( if cast_to: for from_, to_ in cast_to.items(): if np.issubdtype(array.dtype, from_): - array = array.astype(to_, copy=False) + array = array.astype(to_) func_ = getattr(numbagg.grouped, f"group_{func}") diff --git a/flox/core.py b/flox/core.py index 7f15ae2f..91903ded 100644 --- a/flox/core.py +++ b/flox/core.py @@ -45,10 +45,6 @@ ) from .cache import memoize from .xrutils import ( - _contains_cftime_datetimes, - _datetime_nanmin, - _to_pytimedelta, - datetime_to_numeric, is_chunked_array, is_duck_array, is_duck_cubed_array, @@ -2477,8 +2473,7 @@ def groupby_reduce( has_dask = is_duck_dask_array(array) or is_duck_dask_array(by_) has_cubed = is_duck_cubed_array(array) or is_duck_cubed_array(by_) - is_first_last = _is_first_last_reduction(func) - if is_first_last: + if _is_first_last_reduction(func): if has_dask and nax != 1: raise ValueError( "For dask arrays: first, last, nanfirst, nanlast reductions are " @@ -2491,24 +2486,6 @@ def groupby_reduce( "along a single axis or when reducing across all dimensions of `by`." ) - # Flox's count works with non-numeric and its faster than converting. - is_npdatetime = array.dtype.kind in "Mm" - is_cftime = _contains_cftime_datetimes(array) - requires_numeric = ( - (func not in ["count", "any", "all"] and not is_first_last) - or (func == "count" and engine != "flox") - or (is_first_last and is_cftime) - ) - if requires_numeric: - if is_npdatetime: - offset = _datetime_nanmin(array) - # xarray always uses np.datetime64[ns] for np.datetime64 data - dtype = "timedelta64[ns]" - array = datetime_to_numeric(array, offset) - elif is_cftime: - offset = array.min() - array = datetime_to_numeric(array, offset, datetime_unit="us") - if nax == 1 and by_.ndim > 1 and expected_ is None: # When we reduce along all axes, we are guaranteed to see all # groups in the final combine stage, so everything works. @@ -2694,14 +2671,6 @@ def groupby_reduce( if is_bool_array and (_is_minmax_reduction(func) or _is_first_last_reduction(func)): result = result.astype(bool) - - # Output of count has an int dtype. - if requires_numeric and func != "count": - if is_npdatetime: - return result.astype(dtype) + offset - elif is_cftime: - return _to_pytimedelta(result, unit="us") + offset - return (result, *groups) diff --git a/flox/xarray.py b/flox/xarray.py index fbeeedba..1562acc8 100644 --- a/flox/xarray.py +++ b/flox/xarray.py @@ -7,6 +7,7 @@ import pandas as pd import xarray as xr from packaging.version import Version +from xarray.core.duck_array_ops import _datetime_nanmin from .aggregations import Aggregation, Dim, _atleast_1d, quantile_new_dims_func from .core import ( @@ -17,6 +18,7 @@ ) from .core import rechunk_for_blockwise as rechunk_array_for_blockwise from .core import rechunk_for_cohorts as rechunk_array_for_cohorts +from .xrutils import _contains_cftime_datetimes, _to_pytimedelta, datetime_to_numeric if TYPE_CHECKING: from xarray.core.types import T_DataArray, T_Dataset @@ -364,6 +366,22 @@ def wrapper(array, *by, func, skipna, core_dims, **kwargs): if "nan" not in func and func not in ["all", "any", "count"]: func = f"nan{func}" + # Flox's count works with non-numeric and its faster than converting. + requires_numeric = func not in ["count", "any", "all"] or ( + func == "count" and kwargs["engine"] != "flox" + ) + if requires_numeric: + is_npdatetime = array.dtype.kind in "Mm" + is_cftime = _contains_cftime_datetimes(array) + if is_npdatetime: + offset = _datetime_nanmin(array) + # xarray always uses np.datetime64[ns] for np.datetime64 data + dtype = "timedelta64[ns]" + array = datetime_to_numeric(array, offset) + elif is_cftime: + offset = array.min() + array = datetime_to_numeric(array, offset, datetime_unit="us") + result, *groups = groupby_reduce(array, *by, func=func, **kwargs) # Transpose the new quantile dimension to the end. This is ugly. @@ -377,6 +395,13 @@ def wrapper(array, *by, func, skipna, core_dims, **kwargs): # output dim order: (*broadcast_dims, *group_dims, quantile_dim) result = np.moveaxis(result, 0, -1) + # Output of count has an int dtype. + if requires_numeric and func != "count": + if is_npdatetime: + return result.astype(dtype) + offset + elif is_cftime: + return _to_pytimedelta(result, unit="us") + offset + return result # These data variables do not have any of the core dimension, diff --git a/flox/xrutils.py b/flox/xrutils.py index 9ae0ae02..ba8a5672 100644 --- a/flox/xrutils.py +++ b/flox/xrutils.py @@ -345,28 +345,6 @@ def _contains_cftime_datetimes(array) -> bool: return False -def _datetime_nanmin(array): - """nanmin() function for datetime64. - - Caveats that this function deals with: - - - In numpy < 1.18, min() on datetime64 incorrectly ignores NaT - - numpy nanmin() don't work on datetime64 (all versions at the moment of writing) - - dask min() does not work on datetime64 (all versions at the moment of writing) - """ - from .xrdtypes import is_datetime_like - - dtype = array.dtype - assert is_datetime_like(dtype) - # (NaT).astype(float) does not produce NaN... - array = np.where(pd.isnull(array), np.nan, array.astype(float)) - array = min(array, skipna=True) - if isinstance(array, float): - array = np.array(array) - # ...but (NaN).astype("M8") does produce NaT - return array.astype(dtype) - - def _select_along_axis(values, idx, axis): other_ind = np.ix_(*[np.arange(s) for s in idx.shape]) sl = other_ind[:axis] + (idx,) + other_ind[axis:] diff --git a/tests/test_core.py b/tests/test_core.py index 055c641c..164f87b3 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -2006,19 +2006,3 @@ def test_blockwise_avoid_rechunk(): actual, groups = groupby_reduce(array, by, func="first") assert_equal(groups, ["", "0", "1"]) assert_equal(actual, np.array([0, 0, 0], dtype=np.int64)) - - -@pytest.mark.parametrize("func", ["first", "last", "nanfirst", "nanlast"]) -def test_datetime_timedelta_first_last(engine, func): - import flox - - idx = 0 if "first" in func else -1 - - dt = pd.date_range("2001-01-01", freq="d", periods=5).values - by = np.ones(dt.shape, dtype=int) - actual, _ = flox.groupby_reduce(dt, by, func=func, engine=engine) - assert_equal(actual, dt[[idx]]) - - dt = dt - dt[0] - actual, _ = flox.groupby_reduce(dt, by, func=func, engine=engine) - assert_equal(actual, dt[[idx]])