-
Notifications
You must be signed in to change notification settings - Fork 27
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Batches to zarr #40
base: main
Are you sure you want to change the base?
Batches to zarr #40
Changes from 6 commits
0e7b538
45e20ec
9083881
7b2341b
66a9ec5
dd4108c
e7dca77
1ce312f
08a9e94
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,6 @@ | ||
pytest | ||
coverage | ||
zarr | ||
pytest-cov | ||
adlfs | ||
-r requirements.txt |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -144,3 +144,54 @@ def _iterate_batch_dims(self, ds): | |
|
||
def _iterate_input_dims(self, ds):
    # Delegate to the module-level helper, iterating `ds` with the
    # configured `input_dims` / `input_overlap` settings.
    # NOTE(review): exact slicing semantics live in
    # `_iterate_through_dataset`, which is defined elsewhere in this file.
    return _iterate_through_dataset(ds, self.input_dims, self.input_overlap)
|
||
def to_zarr(self, path, chunks=None):
    """
    Store batches into a zarr datastore in `path`.

    Parameters
    ----------
    path : str
        Location of the zarr store to write to.
    chunks : dict, optional
        Chunking applied across batches before writing; defaults to
        ``{'batch': '1Gb'}``. To speed up loading of batches it is
        recommended that the chunking across batches is set close to the
        available RAM on the computer where you are doing ML model
        training.
    """
    # Avoid the mutable-default-argument pitfall, and copy any
    # caller-supplied dict so the `pop` below can't mutate it.
    if chunks is None:
        chunks = {'batch': '1Gb'}
    else:
        chunks = dict(chunks)
    batch_datasets = list(self)
    # can't call the batch dimension `batch` because Dataset.batch is used
    # for the batch accessor. Instead we'll call it `batch_number`
    ds_all = xr.concat(batch_datasets, dim='batch_number').reset_index(
        'sample'
    )
    if 'batch' in chunks:
        chunks['batch_number'] = chunks.pop('batch')

    if len(chunks) > 0:
        ds_all = ds_all.chunk(chunks)
    ds_all.to_zarr(path)
|
||
@staticmethod
def from_zarr(path):
    """
    Re-open batches previously written with `BatchGenerator.to_zarr`.

    Parameters
    ----------
    path : str
        Location of the zarr datastore to load batches from.

    Returns
    -------
    StoredBatchesGenerator
        A generator yielding the stored batches.
    """
    loader = StoredBatchesGenerator(path=path)
    return loader
|
||
|
||
class StoredBatchesGenerator:
    """
    Generator that mimics the iteration behaviour of BatchGenerator but
    loads the batches from a zarr store that was previously created with
    `BatchGenerator.to_zarr`
    """

    def __init__(self, path):
        # Open the store lazily and remember where it came from.
        self.ds_batches = xr.open_zarr(path)
        self.path = path

    def __iter__(self):
        # `sample` and `batch_number` are bookkeeping dimensions added at
        # store time — everything else in coords was a level of the
        # original stacked `sample` MultiIndex.
        bookkeeping = ('sample', 'batch_number')
        for batch_id in self.ds_batches.batch_number.values:
            batch = self.ds_batches.sel(batch_number=batch_id)
            # create a MultiIndex like we had before storing the batches
            index_levels = [
                name for name in batch.coords if name not in bookkeeping
            ]
            yield batch.set_index(sample=index_levels)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
import os
import tempfile

import numpy as np
import xarray as xr

import xbatcher
|
||
|
||
def test_to_zarr():
    """Round-trip the first batch through `to_zarr`/`from_zarr` and check
    the loaded batch matches the original."""
    da = xr.DataArray(
        np.random.rand(1000, 100, 100), name='foo', dims=['time', 'y', 'x']
    ).chunk({'time': 1})

    bgen = xbatcher.BatchGenerator(da, {'time': 10}, preload_batch=False)

    # only the first batch is needed for the comparison below
    ds_first_batch = next(iter(bgen))

    # Use the context manager so the store is reliably created and cleaned
    # up: the original `tempfile.TemporaryDirectory().name` dropped its
    # only reference, so the directory was deleted by the finalizer at an
    # arbitrary GC point and the data written there was never cleaned up.
    with tempfile.TemporaryDirectory() as tempdir:
        path = os.path.join(tempdir, 'batches.zarr')
        bgen.to_zarr(path)

        bgen_loaded = xbatcher.BatchGenerator.from_zarr(path)
        loaded_first_batch = next(iter(bgen_loaded))

        # DataArray.equals doesn't work while the DataArray's are still
        # stacked; compare while the store still exists since the loaded
        # data is lazy
        da_first_batch = ds_first_batch.unstack()
        da_loaded_first_batch = loaded_first_batch.unstack()
        # For some reason DataArray.equals doesn't work here, but
        # DataArray.broadcast_equals did
        assert da_loaded_first_batch.broadcast_equals(da_first_batch)
        # I think this should mean that DataArray.equals should work
        assert (da_loaded_first_batch - da_first_batch).max() == 0.0
Reviewer note: please add a test covering the case when `'batch'` is not in `chunks`, so the `ds_all.chunk(chunks)` branch is exercised with the renamed `batch_number` key absent.