Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

expects decorator #143

Open
wants to merge 55 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
55 commits
Select commit Hold shift + click to select a range
044d59a
draft implementation of @expects
TomNicholas Oct 20, 2021
0754f22
sketch of different tests needed
TomNicholas Oct 20, 2021
e879ef9
idea for test
TomNicholas Oct 22, 2021
aad7936
upgrade check then convert function to optionally take magnitude
TomNicholas Oct 28, 2021
e354f4e
removed magnitude option
TomNicholas Nov 29, 2021
7727d8e
works for single return value
TomNicholas Nov 29, 2021
1379779
works for single kwarg
TomNicholas Nov 30, 2021
77f5d02
works for multiple return values
TomNicholas Nov 30, 2021
71f4200
allow passing through arguments unchecked
TomNicholas Nov 30, 2021
a710741
check types of units
TomNicholas Nov 30, 2021
497e97f
remove uneeded option to specify a lack of return value
TomNicholas Nov 30, 2021
00219bc
check number of inputs and return values
TomNicholas Nov 30, 2021
9e92f21
removed nonlocal keyword
TomNicholas Nov 30, 2021
86f7e58
generalised to handle specifying dicts of units
TomNicholas Nov 30, 2021
2141c6c
type hint for func
TomNicholas Nov 30, 2021
a94a6ae
type hint for args_units
TomNicholas Nov 30, 2021
a2cc63f
Merge branch 'expects_decorator' of https://github.com/TomNicholas/pi…
TomNicholas Nov 30, 2021
7103483
numpy-style type hints for all arguments
TomNicholas Nov 30, 2021
59ddf86
whats new
TomNicholas Nov 30, 2021
a5a2493
add to API docs
TomNicholas Nov 30, 2021
e5e84fb
use always_iterable
TomNicholas Dec 14, 2021
3f59414
hashable
TomNicholas Dec 14, 2021
b281674
hashable
TomNicholas Dec 14, 2021
c669105
hashable
TomNicholas Dec 14, 2021
9ac8887
dict comprehension
TomNicholas Dec 14, 2021
0a6447d
list comprehension
TomNicholas Dec 14, 2021
c29f935
unindent if/else
TomNicholas Dec 14, 2021
81913a6
missing parenthesis
TomNicholas Dec 14, 2021
4de6f4d
simplify if/else logic for checking there were actually results
TomNicholas Dec 14, 2021
83e422f
return results immediately if a tuple
TomNicholas Dec 14, 2021
37c3fbc
allow for returning Datasets from wrapped funciton
TomNicholas Dec 14, 2021
9c19af0
Update docs/api.rst
TomNicholas Jan 14, 2022
0b5c7c0
correct indentation of docstring
TomNicholas Jan 14, 2022
0f50305
use inspects to check number of arguments passed to decorated function
TomNicholas Jan 14, 2022
57d341e
reformat the docstring
keewis Jan 15, 2022
8845b77
update the definition of unit-like
keewis Jan 16, 2022
bc41425
simplify if/else statement
TomNicholas Jan 18, 2022
aba2d11
Merge branch 'expects_decorator' of https://github.com/TomNicholas/pi…
TomNicholas Jan 18, 2022
0350308
check units in .to instead
TomNicholas Jan 19, 2022
3a24a73
remove extra xfailed test
TomNicholas Jan 19, 2022
19fd6e0
test raises on unquantified input
TomNicholas Jan 20, 2022
d2d74e4
add example of function which optionally accepts dimensionless weights
TomNicholas Jan 20, 2022
1c4feb4
Merge branch 'main' into expects_decorator
keewis Mar 11, 2022
7a6f2cb
Merge branch 'main' into expects_decorator
keewis Sep 20, 2022
85b982c
rewrite using inspect.Signature's bind and bind_partial
keewis Sep 20, 2022
5ea484b
also allow converting and stripping Variable objects
keewis Sep 21, 2022
b7c71c1
implement the conversion functions
keewis Sep 21, 2022
b39fff3
simplify the return construct
keewis Sep 21, 2022
61e0299
code reorganization
keewis Sep 21, 2022
63d8aeb
black
keewis Sep 21, 2022
32a57b2
fix a test
keewis Sep 21, 2022
91b4826
remove the note about coordinates not being checked [skip-ci]
keewis Sep 21, 2022
a43dd13
reword the error message raised when there's no units for some parame…
keewis Sep 21, 2022
7ea921c
move the changelog to a new section
keewis Sep 21, 2022
b92087a
Merge branch 'main' into expects_decorator
keewis Sep 21, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pint_xarray/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from . import accessors, formatting, testing # noqa: F401
from .accessors import default_registry as unit_registry
from .accessors import setup_registry
from .checking import expects # noqa: F401

try:
__version__ = version("pint-xarray")
Expand Down
118 changes: 118 additions & 0 deletions pint_xarray/checking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
import functools

from pint import Quantity
from xarray import DataArray

from .accessors import PintDataArrayAccessor


def expects(*args_units, return_units=None, **kwargs_units):
"""
Decorator which checks the inputs and outputs of the decorated function have certain units.

Arguments

Note that the coordinates of input DataArrays are not checked, only the data.
So if your decorated function uses coordinates and you wish to check their units,
you should pass the coordinates of interest as separate arguments.
keewis marked this conversation as resolved.
Show resolved Hide resolved

Parameters
----------
func: function
TomNicholas marked this conversation as resolved.
Show resolved Hide resolved
Decorated function, which accepts zero or more xarray.DataArrays or pint.Quantitys as inputs,
and may optionally return one or more xarray.DataArrays or pint.Quantitys.
args_units : Union[str, pint.Unit, None]
TomNicholas marked this conversation as resolved.
Show resolved Hide resolved
Unit to expect for each positional argument given to func.

A value of None indicates not to check that argument for units (suitable for flags
and other non-data arguments).
return_units : Union[Union[str, pint.Unit, None, False], Sequence[Union[str, pint.Unit, None]], Optional
The expected units of the returned value(s), either as a single unit or as an iterable of units.

A value of None indicates not to check that return value for units (suitable for flags and other
non-data arguments). Passing False means that no return value is expected from the function at all,
and an error will be raised if a return value is found.
kwargs_units : Dict[str, Union[str, pint.Unit, None]], Optional
Unit to expect for each keyword argument given to func.

A value of None indicates not to check that argument for units (suitable for flags
and other non-data arguments).

Returns
-------
return_values
Return values of the wrapped function, either a single value or a tuple of values.

Raises
------
TypeError
If an argument or return value has a specified unit, but is not an xarray.DataArray.


Examples
--------

Decorating a function which takes one quantified input, but returns a non-data value (in this case a boolean).

>>> @expects('deg C')
... def above_freezing(temp):
... return temp > 0
...


TODO: example where we check units of an optional weighted kwarg
"""

# TODO: Check args_units, kwargs_units, and return_units types
# TODO: Check number of arguments line up

def _expects_decorator(func):

@functools.wraps(func)
def _unit_checking_wrapper(*args, **kwargs):

converted_args = []
for arg, arg_unit in zip(args, args_units):
converted_arg = _check_then_convert_to(arg, arg_unit)
converted_args.append(converted_arg)

converted_kwargs = {}
for key, val in kwargs.items():
kwarg_unit = kwargs_units[key]
converted_kwargs[key] = _check_then_convert_to(val, kwarg_unit)

results = func(*converted_args, **converted_kwargs)

if results is not None:
if return_units is False:
raise ValueError("Did not expect function to return anything")
elif return_units is not None:
# TODO check something was actually returned
# TODO check same number of things were returned as expected
# TODO handle single return value vs tuple of return values

converted_results = []
for return_unit, return_value in zip(return_units, results):
converted_result = _check_then_convert_to(return_value, return_unit)
converted_results.append(converted_result)

return tuple(converted_results)
else:
return results
else:
if return_units:
raise ValueError("Expected function to return something")

return _unit_checking_wrapper

return _expects_decorator


def _check_then_convert_to(obj, units):
if isinstance(obj, Quantity):
return obj.to(units)
elif isinstance(obj, DataArray):
return obj.pint.to(units)
else:
raise TypeError("Can only expect units for arguments of type xarray.DataArray or pint.Quantity,"
f"not {type(obj)}")
97 changes: 97 additions & 0 deletions pint_xarray/tests/test_checking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import pytest
import pint
import xarray as xr

from pint import UnitRegistry

from ..checking import expects

ureg = UnitRegistry()


class TestExpects:
def test_single_arg(self):

@expects('degC')
def above_freezing(temp : pint.Quantity):
return temp.magnitude > 0

f_q = pint.Quantity(20, units='degF')
assert above_freezing(f_q) == False

c_q = pint.Quantity(-2, units='degC')
assert above_freezing(c_q) == False

@expects('degC')
def above_freezing(temp : xr.DataArray):
return temp.pint.magnitude > 0

f_da = xr.DataArray(20).pint.quantify(units='degF')
assert above_freezing(f_da) == False

c_da = xr.DataArray(-2).pint.quantify(units='degC')
assert above_freezing(c_da) == False

def test_single_kwarg(self):

@expects('meters', c='meters / second')
def freq(wavelength, c=None):
if c is None:
c = ureg.speed_of_light

return c / wavelength

def test_single_return_value(self):

@expects('Hz')
def period(freq):
return 1 / freq

f = pint.Quantity(10, units='Hz')

# test conversion
T = period(f)
assert f.units == 'seconds'

# test wrong dimensions for conversion
...

@pytest.mark.xfail
def test_multiple_return_values(self):
raise NotImplementedError

@pytest.mark.xfail
def test_mixed_args_kwargs_return_values(self):
raise NotImplementedError

@pytest.mark.xfail
def test_invalid_input_types(self):
raise NotImplementedError

@pytest.mark.xfail
def test_invalid_return_types(self):
raise NotImplementedError

@pytest.mark.xfail
def test_unquantified_arrays(self):
raise NotImplementedError

@pytest.mark.xfail
def test_wrong_number_of_args(self):
raise NotImplementedError

@pytest.mark.xfail
def test_nonexistent_kwarg(self):
raise NotImplementedError

@pytest.mark.xfail
def test_unexpected_return_value(self):
raise NotImplementedError

@pytest.mark.xfail
def test_expected_return_value(self):
raise NotImplementedError

@pytest.mark.xfail
def test_wrong_number_of_return_values(self):
raise NotImplementedError