diff --git a/docs/api.rst b/docs/api.rst
index bb7cb2f9..55d72379 100644
--- a/docs/api.rst
+++ b/docs/api.rst
@@ -65,6 +65,14 @@ DataArray
    xarray.DataArray.pint.bfill
    xarray.DataArray.pint.interpolate_na
 
+Checking
+--------
+
+.. autosummary::
+   :toctree: generated/
+
+   pint_xarray.expects
+
 Testing
 -------
 
diff --git a/docs/terminology.rst b/docs/terminology.rst
index 4fe6534d..4532063e 100644
--- a/docs/terminology.rst
+++ b/docs/terminology.rst
@@ -5,6 +5,7 @@ Terminology
 
     unit-like
         A `pint`_ unit definition, as accepted by :py:class:`pint.Unit`.
-        May be either a :py:class:`str` or a :py:class:`pint.Unit` instance.
+        May be either a :py:class:`str`, a :py:class:`pint.Unit`
+        instance, or :py:obj:`None`.
 
 .. _pint: https://pint.readthedocs.io/en/stable
diff --git a/docs/whats-new.rst b/docs/whats-new.rst
index 229b4e0c..beb1a468 100644
--- a/docs/whats-new.rst
+++ b/docs/whats-new.rst
@@ -5,6 +5,8 @@ What's new
 
 0.4 (*unreleased*)
 ------------------
+- Added the :py:func:`pint_xarray.expects` decorator (:pull:`143`).
+  By `Tom Nicholas <https://github.com/TomNicholas>`_ and `Justus Magin <https://github.com/keewis>`_.
 
 0.3 (27 Jul 2022)
 -----------------
diff --git a/pint_xarray/__init__.py b/pint_xarray/__init__.py
index 3ce42d86..f6276b67 100644
--- a/pint_xarray/__init__.py
+++ b/pint_xarray/__init__.py
@@ -5,6 +5,7 @@
 from . import accessors, formatting, testing  # noqa: F401
 from .accessors import default_registry as unit_registry
 from .accessors import setup_registry
+from .checking import expects  # noqa: F401
 
 try:
     __version__ = version("pint-xarray")
diff --git a/pint_xarray/checking.py b/pint_xarray/checking.py
new file mode 100644
index 00000000..cc7b7e23
--- /dev/null
+++ b/pint_xarray/checking.py
@@ -0,0 +1,328 @@
+import functools
+import inspect
+from inspect import Parameter
+
+from pint import Quantity
+from xarray import DataArray, Dataset, Variable
+
+from . import conversion
+from .accessors import PintDataArrayAccessor  # noqa
+
+
+def detect_missing_params(params, units):
+    """Detect parameters for which no units were specified.
+
+    Parameters
+    ----------
+    params : mapping of str to inspect.Parameter
+        The parameters of the decorated function.
+    units : inspect.BoundArguments
+        The units, bound to the signature of the decorated function.
+
+    Returns
+    -------
+    missing : set of str
+        The names of all parameters without units. Variadic parameters
+        (``*args`` and ``**kwargs``) are never reported.
+    """
+    variable_params = {
+        Parameter.VAR_POSITIONAL,
+        Parameter.VAR_KEYWORD,
+    }
+
+    return {
+        name
+        for name, param in params.items()
+        if name not in units.arguments and param.kind not in variable_params
+    }
+
+
+def convert_and_strip(obj, units):
+    """Convert an object to the given units, then strip the units.
+
+    Parameters
+    ----------
+    obj : DataArray or Dataset or Variable or pint.Quantity or Any
+        The object to convert.
+    units : unit-like or mapping of hashable to unit-like
+        The units to convert to. For xarray objects, anything but a
+        mapping is applied as ``{None: units}``. With ``None``,
+        quantities are reduced to their magnitude without converting
+        and non-xarray objects are passed through unchanged.
+
+    Returns
+    -------
+    stripped
+        The converted object without units.
+
+    Raises
+    ------
+    ValueError
+        If units were specified for an object of unsupported type.
+    """
+    if isinstance(obj, (DataArray, Dataset, Variable)):
+        if not isinstance(units, dict):
+            units = {None: units}
+        return conversion.strip_units(conversion.convert_units(obj, units))
+    elif isinstance(obj, Quantity):
+        # ``None`` means "don't check", so skip the conversion
+        return obj.magnitude if units is None else obj.m_as(units)
+    elif units is None:
+        return obj
+    else:
+        raise ValueError(f"unknown type: {type(obj)}")
+
+
+def convert_and_strip_args(args, units):
+    """Convert and strip positional arguments, pairing them with units by position.
+
+    Arguments beyond the last unit are passed through unchanged instead
+    of being dropped.
+    """
+    args = tuple(args)
+    units = tuple(units)
+
+    converted = [convert_and_strip(obj, units_) for obj, units_ in zip(args, units)]
+    converted.extend(args[len(units) :])
+    return converted
+
+
+def convert_and_strip_kwargs(kwargs, units):
+    """Convert and strip keyword arguments, looking up the units by name.
+
+    Arguments without an entry in ``units`` are passed through unchanged.
+    """
+    return {
+        name: convert_and_strip(value, units[name]) if name in units else value
+        for name, value in kwargs.items()
+    }
+
+
+def always_iterable(obj, base_type=(str, bytes)):
+    """
+    If *obj* is iterable, return an iterator over its items.
+    If *obj* is not iterable, return a one-item iterable containing *obj*.
+    If *obj* is ``None``, return an empty iterable.
+    If *base_type* is set, objects for which ``isinstance(obj, base_type)``
+    returns ``True`` won't be considered iterable.
+
+    Copied from more_itertools.
+    """
+
+    if obj is None:
+        return iter(())
+
+    if (base_type is not None) and isinstance(obj, base_type):
+        return iter((obj,))
+
+    try:
+        return iter(obj)
+    except TypeError:
+        return iter((obj,))
+
+
+def attach_return_units(results, units):
+    """Attach the expected units to the return values of a function.
+
+    Parameters
+    ----------
+    results : Any
+        The value returned by the wrapped function. Only tuples and
+        lists are treated as multiple return values.
+    units : unit-like or mapping or list of those, or None
+        The units to attach. ``None`` returns ``results`` unchanged,
+        a ``None`` entry leaves the matching result unchanged.
+
+    Returns
+    -------
+    quantified
+        A single quantified object, or a tuple of them if ``results``
+        was a tuple or list.
+
+    Raises
+    ------
+    TypeError
+        If ``results`` is None although units were expected, or if the
+        number of results does not match the number of units.
+    """
+    if units is None:
+        # ignore types and units of return values
+        return results
+    if results is None:
+        raise TypeError(
+            "Expected function to return something, but function returned None"
+        )
+
+    # promote a single unit to a 1-element tuple
+    return_units_iterable = tuple(always_iterable(units, base_type=(str, dict)))
+
+    # array-likes (e.g. a non-scalar DataArray) are iterable but still a
+    # single return value, so don't rely on iterability here
+    multiple_results = isinstance(results, (tuple, list))
+    results_iterable = tuple(results) if multiple_results else (results,)
+
+    # check same number of things were returned as expected
+    n_results = len(results_iterable)
+    n_units = len(return_units_iterable)
+    if n_results != n_units:
+        raise TypeError(
+            f"{n_results} return values were received, but {n_units} "
+            "return values were expected"
+        )
+
+    converted_results = _attach_multiple_units(
+        results_iterable, return_units_iterable
+    )
+
+    return tuple(converted_results) if multiple_results else converted_results[0]
+
+
+def _attach_units(obj, units):
+    """Attaches units, but can also create pint.Quantity objects from numpy scalars"""
+    if units is None:
+        # no units expected for this value (e.g. a flag): leave it untouched
+        return obj
+    elif isinstance(obj, (DataArray, Dataset)):
+        return obj.pint.quantify(units)
+    else:
+        return Quantity(obj, units=units)
+
+
+def _attach_multiple_units(objects, units):
+    """Attaches list of units to list of objects elementwise"""
+    converted_objects = [_attach_units(obj, unit) for obj, unit in zip(objects, units)]
+    return converted_objects
+
+
+def _convert_and_strip_bound(bound, units, params):
+    """Convert and strip bound arguments in place, matching units by parameter name.
+
+    Matching by name (instead of by position in ``args`` / ``kwargs``) keeps
+    working if an optional parameter in the middle of the signature is
+    skipped, which moves the later arguments from ``args`` to ``kwargs``.
+    """
+    for name, value in list(bound.arguments.items()):
+        kind = params[name].kind
+        if kind is Parameter.VAR_POSITIONAL:
+            converted = tuple(convert_and_strip_args(value, units.get(name, ())))
+        elif kind is Parameter.VAR_KEYWORD:
+            converted = convert_and_strip_kwargs(value, units.get(name, {}))
+        else:
+            converted = convert_and_strip(value, units[name])
+        bound.arguments[name] = converted
+
+
+def expects(*args_units, return_units=None, **kwargs_units):
+    """
+    Decorator which ensures the inputs and outputs of the decorated
+    function are expressed in the expected units.
+
+    Arguments to the decorated function are checked for the specified
+    units, converting to those units if necessary, and then stripped
+    of their units before being passed into the undecorated
+    function. Therefore the undecorated function should expect
+    unquantified DataArrays, Datasets, or numpy-like arrays, but with
+    the values expressed in specific units.
+
+    Parameters
+    ----------
+    *args_units : unit-like or mapping of hashable to unit-like, optional
+        Units to expect for each positional argument given to func.
+
+        The decorator will first check that arguments passed to the
+        decorated function possess these specific units (or will
+        attempt to convert the argument to these units), then will
+        strip the units before passing the magnitude to the wrapped
+        function.
+
+        A value of None indicates not to check that argument for units
+        (suitable for flags and other non-data arguments).
+    return_units : unit-like or list of unit-like or mapping of hashable to unit-like \
+                   or list of mapping of hashable to unit-like, optional
+        The expected units of the returned value(s), either as a
+        single unit or as a list of units. The decorator will attach
+        these units to the variables returned from the function.
+
+        A value of None indicates not to attach any units to that
+        return value (suitable for flags and other non-data results).
+    **kwargs_units : mapping of hashable to unit-like, optional
+        Unit to expect for each keyword argument given to func.
+
+        The decorator will first check that arguments passed to the
+        decorated function possess these specific units (or will
+        attempt to convert the argument to these units), then will
+        strip the units before passing the magnitude to the wrapped
+        function.
+
+        A value of None indicates not to check that argument for units
+        (suitable for flags and other non-data arguments).
+
+    Returns
+    -------
+    decorator : callable
+        Decorator for functions which accept zero or more
+        xarray.DataArrays, xarray.Datasets or pint.Quantity objects
+        as inputs. The decorated function returns the results of
+        the wrapped function, either a single value or a tuple of
+        values, with units attached according to ``return_units``.
+
+    Raises
+    ------
+    TypeError
+        When decorating, if units are missing for a parameter of the
+        decorated function. When calling the decorated function, if
+        it returned None although units were expected, or if the
+        number of return values does not match the number of units
+        specified.
+    ValueError
+        When calling the decorated function, if units were specified
+        for an argument which is neither an xarray object nor a
+        pint.Quantity.
+
+    Examples
+    --------
+
+    Decorating a function which takes one quantified input, but
+    returns a non-data value (in this case a boolean).
+
+    >>> @expects("degC")
+    ... def above_freezing(temp):
+    ...     return temp > 0
+
+    Decorating a function which allows any dimensions for the array, but also
+    accepts an optional `weights` keyword argument, which must be dimensionless.
+
+    >>> @expects(None, weights="dimensionless")
+    ... def mean(da, weights=None):
+    ...     if weights is not None:
+    ...         return da.weighted(weights=weights).mean()
+    ...     else:
+    ...         return da.mean()
+
+    """
+
+    def _expects_decorator(func):
+        # bind the units to the signature to match them to parameters by name
+        sig = inspect.signature(func)
+        params = sig.parameters
+        bound_units = sig.bind_partial(*args_units, **kwargs_units)
+
+        missing_params = detect_missing_params(params, bound_units)
+        if missing_params:
+            raise TypeError(
+                "Some parameters of the decorated function are missing units:"
+                f" {', '.join(sorted(missing_params))}"
+            )
+
+        @functools.wraps(func)
+        def _unit_checking_wrapper(*args, **kwargs):
+            bound = sig.bind(*args, **kwargs)
+            _convert_and_strip_bound(bound, bound_units.arguments, params)
+
+            results = func(*bound.args, **bound.kwargs)
+
+            return attach_return_units(results, return_units)
+
+        return _unit_checking_wrapper
+
+    return _expects_decorator
diff --git a/pint_xarray/conversion.py b/pint_xarray/conversion.py
index b18b8a63..c0860b26 100644
--- a/pint_xarray/conversion.py
+++ b/pint_xarray/conversion.py
@@ -223,7 +223,9 @@ def convert_units_dataset(obj, units):
 
 
 def convert_units(obj, units):
-    if not isinstance(obj, (DataArray, Dataset)):
+    if isinstance(obj, Variable):
+        return convert_units_variable(obj, units)
+    elif not isinstance(obj, (DataArray, Dataset)):
         raise ValueError(f"cannot convert object: {obj!r}: unknown type")
 
     if isinstance(obj, DataArray):
@@ -299,7 +301,9 @@ def strip_units_dataset(obj):
 
 
 def strip_units(obj):
-    if not isinstance(obj, (DataArray, Dataset)):
-        raise ValueError("cannot strip units from {obj!r}: unknown type")
+    if isinstance(obj, Variable):
+        return strip_units_variable(obj)
+    elif not isinstance(obj, (DataArray, Dataset)):
+        raise ValueError(f"cannot strip units from {obj!r}: unknown type")
 
     return call_on_dataset(strip_units_dataset, obj, name=temporary_name)
diff --git a/pint_xarray/tests/test_checking.py b/pint_xarray/tests/test_checking.py
new file mode 100644
index 00000000..068729d2
--- /dev/null
+++ b/pint_xarray/tests/test_checking.py
@@ -0,0 +1,233 @@
+import pint
+import pytest
+import xarray as xr
+from pint import UnitRegistry
+
+from ..checking import expects
+
+ureg = UnitRegistry()
+
+
+class TestExpects:
+    def test_single_arg(self):
+        @expects("degC")
+        def above_freezing(temp):
+            return temp > 0
+
+        f_q = pint.Quantity(20, units="degF")
+        assert not above_freezing(f_q)
+
+        c_q = pint.Quantity(-2, units="degC")
+        assert not above_freezing(c_q)
+
+        f_da = xr.DataArray(20).pint.quantify(units="degF")
+        assert not above_freezing(f_da)
+
+        c_da = xr.DataArray(-2).pint.quantify(units="degC")
+        assert not above_freezing(c_da)
+
+    def test_single_kwarg(self):
+        @expects("meters", c="meters / second", return_units="Hz")
+        def freq(wavelength, c=None):
+            if c is None:
+                c = (1 * ureg.speed_of_light).to("meters / second").magnitude
+
+            return c / wavelength
+
+        w_q = pint.Quantity(0.02, units="inches")
+        c_q = pint.Quantity(1e6, units="feet / second")
+        f_q = freq(w_q, c=c_q)
+        assert f_q.units == pint.Unit("hertz")
+        f_q = freq(w_q)
+        assert f_q.units == pint.Unit("hertz")
+
+        w_da = xr.DataArray(0.02).pint.quantify(units="inches")
+        c_da = xr.DataArray(1e6).pint.quantify(units="feet / second")
+        f_da = freq(w_da, c=c_da)
+        assert f_da.pint.units == pint.Unit("hertz")
+        f_da = freq(w_da)
+        assert f_da.pint.units == pint.Unit("hertz")
+
+    def test_weighted_kwarg(self):
+        @expects(None, weights="dimensionless", return_units="metres")
+        def mean(da, weights=None):
+            if weights is not None:
+                return da.weighted(weights=weights).mean()
+            else:
+                return da.mean()
+
+        d = xr.DataArray([1, 2, 3]).pint.quantify(units="metres")
+        w = xr.DataArray([0.1, 0.7, 0.2]).pint.quantify(units="dimensionless")
+
+        result = mean(d, weights=w)
+        expected = xr.DataArray(0.21).pint.quantify("metres")
+        assert result.pint.units == expected.pint.units
+
+    def test_single_return_value(self):
+        @expects("kg", "m / s^2", return_units="newtons")
+        def second_law(m, a):
+            return m * a
+
+        m_q = pint.Quantity(0.1, units="tons")
+        a_q = pint.Quantity(10, units="feet / second^2")
+        expected_q = (m_q * a_q).to("newtons")
+        result_q = second_law(m_q, a_q)
+        assert result_q == expected_q
+
+        m_da = xr.DataArray(0.1).pint.quantify(units="tons")
+        a_da = xr.DataArray(10).pint.quantify(units="feet / second^2")
+        expected_da = (m_da * a_da).pint.to("newtons")
+        result_da = second_law(m_da, a_da)
+        assert result_da == expected_da
+
+    def test_multiple_return_values(self):
+        @expects("kg", "m / s", return_units=["J", "newton seconds"])
+        def energy_and_momentum(m, v):
+            ke = 0.5 * m * v**2
+            p = m * v
+            return ke, p
+
+        m = pint.Quantity(0.1, units="tons")
+        v = pint.Quantity(10, units="feet / second")
+        expected_ke = (0.5 * m * v**2).to("J")
+        expected_p = (m * v).to("newton seconds")
+        result_ke, result_p = energy_and_momentum(m, v)
+        assert result_ke.units == expected_ke.units
+        assert result_p.units == expected_p.units
+
+        m = xr.DataArray(0.1).pint.quantify(units="tons")
+        v = xr.DataArray(10).pint.quantify(units="feet / second")
+        expected_ke = (0.5 * m * v**2).pint.to("J")
+        expected_p = (m * v).pint.to("newton seconds")
+        result_ke, result_p = energy_and_momentum(m, v)
+        assert result_ke.pint.units == expected_ke.pint.units
+        assert result_p.pint.units == expected_p.pint.units
+
+    def test_dont_check_arg_units(self):
+        @expects("seconds", None, return_units=None)
+        def finite_difference(a, type):
+            return ...
+
+        t = pint.Quantity(0.1, units="seconds")
+        finite_difference(t, "centered")
+
+    @pytest.mark.parametrize(
+        "arg_units, return_units",
+        [("nonsense", "Hertz"), ("seconds", 6), ("seconds", [6])],
+    )
+    def test_invalid_unit_types(self, arg_units, return_units):
+        @expects(arg_units, return_units=return_units)
+        def freq(period):
+            return 1 / period
+
+        q = pint.Quantity(1.0, units="seconds")
+
+        with pytest.raises((TypeError, pint.errors.UndefinedUnitError)):
+            freq(q)
+
+    def test_unquantified_arrays(self):
+        @expects("seconds", return_units="Hertz")
+        def freq(period):
+            return 1 / period
+
+        da = xr.DataArray(10)
+
+        with pytest.raises(
+            ValueError,
+            match="cannot convert a non-quantity",
+        ):
+            freq(da)
+
+    def test_wrong_number_of_args(self):
+        with pytest.raises(
+            TypeError,
+            match="Some parameters of the decorated function are missing units",
+        ):
+
+            @expects("kg", return_units="newtons")
+            def second_law(m, a):
+                return m * a
+
+    def test_wrong_number_of_return_values(self):
+        @expects("kg", "m / s^2", return_units=["newtons", "joules"])
+        def second_law(m, a):
+            return m * a
+
+        m_q = pint.Quantity(0.1, units="tons")
+        a_q = pint.Quantity(10, units="feet / second^2")
+
+        with pytest.raises(TypeError, match="2 return values were expected"):
+            second_law(m_q, a_q)
+
+    def test_expected_return_value(self):
+        @expects("seconds", return_units="Hz")
+        def freq(period):
+            return None
+
+        p = pint.Quantity(2, units="seconds")
+
+        with pytest.raises(TypeError, match="function returned None"):
+            freq(p)
+
+    def test_input_unit_dict(self):
+        @expects({"m": "kg", "a": "m / s^2"}, return_units="newtons")
+        def second_law(ds):
+            return ds["m"] * ds["a"]
+
+        m_da = xr.DataArray(0.1).pint.quantify(units="tons")
+        a_da = xr.DataArray(10).pint.quantify(units="feet / second^2")
+        ds = xr.Dataset({"m": m_da, "a": a_da})
+
+        expected_da = (m_da * a_da).pint.to("newtons")
+        result_da = second_law(ds)
+        assert result_da == expected_da
+
+    def test_return_dataset(self):
+        @expects({"m": "kg", "a": "m / s^2"}, return_units=[{"f": "newtons"}])
+        def second_law(ds):
+            f_da = ds["m"] * ds["a"]
+            return xr.Dataset({"f": f_da})
+
+        m_da = xr.DataArray(0.1).pint.quantify(units="tons")
+        a_da = xr.DataArray(10).pint.quantify(units="feet / second^2")
+        ds = xr.Dataset({"m": m_da, "a": a_da})
+
+        expected_da = m_da * a_da
+        expected_ds = xr.Dataset({"f": expected_da}).pint.to({"f": "newtons"})
+        result_ds = second_law(ds)
+        assert result_ds == expected_ds
+
+    def test_none_in_return_units(self):
+        @expects("seconds", return_units=["Hz", None])
+        def freq_and_flag(period):
+            return 1 / period, "ok"
+
+        p = pint.Quantity(2, units="seconds")
+
+        f, flag = freq_and_flag(p)
+        assert f.units == pint.Unit("hertz")
+        assert flag == "ok"
+
+    def test_array_return_value(self):
+        @expects("m", return_units="m")
+        def double(da):
+            return 2 * da
+
+        da = xr.DataArray([1.0, 2.0], dims="x").pint.quantify(units="m")
+
+        result = double(da)
+        assert result.dims == ("x",)
+        assert result.pint.units == pint.Unit("m")
+
+    def test_skipped_optional_arg(self):
+        @expects("m", "s", "kg", return_units=None)
+        def func(a, b=None, c=None):
+            return a, b, c
+
+        a = pint.Quantity(1, units="km")
+        c = pint.Quantity(1, units="g")
+
+        result_a, result_b, result_c = func(a, c=c)
+        assert result_a == pytest.approx(1000)
+        assert result_b is None
+        assert result_c == pytest.approx(0.001)