Skip to content

Commit

Permalink
make dm-tree optional
Browse files Browse the repository at this point in the history
  • Loading branch information
OriolAbril committed Sep 11, 2024
1 parent 5a928fb commit 9de6c2b
Show file tree
Hide file tree
Showing 6 changed files with 31 additions and 11 deletions.
25 changes: 18 additions & 7 deletions arviz/data/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,13 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union

import numpy as np
import tree
import xarray as xr

try:
import tree
except ImportError:
tree = None

try:
import ujson as json
except ImportError:
Expand Down Expand Up @@ -89,6 +93,9 @@ def _yield_flat_up_to(shallow_tree, input_tree, path=()):
input_tree.
"""
# pylint: disable=protected-access
if tree is None:
raise ImportError("Missing optional dependency 'dm-tree'. Use pip or conda to install it")

if isinstance(shallow_tree, tree._TEXT_OR_BYTES) or not (
isinstance(shallow_tree, tree.collections_abc.Mapping)
or tree._is_namedtuple(shallow_tree)
Expand Down Expand Up @@ -299,7 +306,7 @@ def numpy_to_data_array(
return xr.DataArray(ary, coords=coords, dims=dims)


def pytree_to_dataset(
def dict_to_dataset(
data,
*,
attrs=None,
Expand All @@ -312,6 +319,8 @@ def pytree_to_dataset(
):
"""Convert a dictionary or pytree of numpy arrays to an xarray.Dataset.
ArviZ itself supports conversion of flat dictionaries.
Suport for pytrees requires ``dm-tree`` which is an optional dependency.
See https://jax.readthedocs.io/en/latest/pytrees.html for what a pytree is, but
this inclues at least dictionaries and tuple types.
Expand Down Expand Up @@ -386,10 +395,12 @@ def pytree_to_dataset(
"""
if dims is None:
dims = {}
try:
data = {k[0] if len(k) == 1 else k: v for k, v in _flatten_with_path(data)}
except TypeError: # probably unsortable keys -- the function will still work if
pass # it is an honest dictionary.

if tree is not None:
try:
data = {k[0] if len(k) == 1 else k: v for k, v in _flatten_with_path(data)}
except TypeError: # probably unsortable keys -- the function will still work if
pass # it is an honest dictionary.

data_vars = {
key: numpy_to_data_array(
Expand All @@ -406,7 +417,7 @@ def pytree_to_dataset(
return xr.Dataset(data_vars=data_vars, attrs=make_attrs(attrs=attrs, library=library))


dict_to_dataset = pytree_to_dataset
pytree_to_dataset = dict_to_dataset


def make_attrs(attrs=None, library=None):
Expand Down
9 changes: 6 additions & 3 deletions arviz/data/converters.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
"""High level conversion functions."""

import numpy as np
import tree
import xarray as xr
try:
from tree import is_nested
except ImportError:
is_nested = lambda obj: False

from .base import dict_to_dataset
from .inference_data import InferenceData
Expand Down Expand Up @@ -107,7 +110,7 @@ def convert_to_inference_data(obj, *, group="posterior", coords=None, dims=None,
dataset = obj.to_dataset()
elif isinstance(obj, dict):
dataset = dict_to_dataset(obj, coords=coords, dims=dims)
elif tree.is_nested(obj) and not isinstance(obj, (list, tuple)):
elif is_nested(obj) and not isinstance(obj, (list, tuple)):
dataset = dict_to_dataset(obj, coords=coords, dims=dims)
elif isinstance(obj, np.ndarray):
dataset = dict_to_dataset({"x": obj}, coords=coords, dims=dims)
Expand All @@ -122,7 +125,7 @@ def convert_to_inference_data(obj, *, group="posterior", coords=None, dims=None,
"xarray dataarray",
"xarray dataset",
"dict",
"pytree",
"pytree (if 'dm-tree' is installed)",
"netcdf filename",
"numpy array",
"pystan fit",
Expand Down
5 changes: 5 additions & 0 deletions arviz/tests/base_tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@
models,
)

# Check if dm-tree is installed
dm_tree_installed = importlib.util.find_spec("tree") is not None # pylint: disable=invalid-name
skip_tests = (not dm_tree_installed) and ("ARVIZ_REQUIRE_ALL_DEPS" not in os.environ)


@pytest.fixture(autouse=True)
def no_remote_data(monkeypatch, tmpdir):
Expand Down Expand Up @@ -1076,6 +1080,7 @@ def test_dict_to_dataset():
assert set(dataset.b.coords) == {"chain", "draw", "c"}


@pytest.mark.skipif(skip_tests, reason="test requires dm-tree which is not installed")
def test_nested_dict_to_dataset():
datadict = {
"top": {"a": np.random.randn(100), "b": np.random.randn(1, 100, 10)},
Expand Down
1 change: 1 addition & 0 deletions requirements-docs.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@ sphinx_design
sphinx-codeautolink>=0.9.0
jupyter-sphinx
sphinxcontrib-youtube
dm-tree>=0.1.8
1 change: 1 addition & 0 deletions requirements-optional.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ ujson
dask[distributed]
zarr>=2.5.0,<3
xarray-datatree
dm-tree>=0.1.8
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ numpy>=1.23.0
scipy>=1.9.0
packaging
pandas>=1.5.0
dm-tree>=0.1.8
xarray>=2022.6.0
h5netcdf>=1.0.2
typing_extensions>=4.1.0
Expand Down

0 comments on commit 9de6c2b

Please sign in to comment.