Commit 7a0180e: collect blend test

bnb32 committed Jul 5, 2024
1 parent d83845a
Showing 3 changed files with 118 additions and 70 deletions.
18 changes: 7 additions & 11 deletions nsrdb/cli.py
@@ -700,7 +700,6 @@ def blend(ctx, config, verbose=False, pipeline_step=None, collect=False):
"""Blend files from separate domains (e.g. east / west) into a single
domain."""

func = Collector.collect_dir if collect else Blender.run_full
mod_name = ModuleName.COLLECT_BLEND if collect else ModuleName.BLEND

config = BaseCLI.from_config_preflight(
@@ -711,30 +710,27 @@ def blend(ctx, config, verbose=False, pipeline_step=None, collect=False):
pipeline_step=pipeline_step,
)

file_tags = config.get(
'file_tag', ['_'.join(k.split('_')[1:-1]) for k in NSRDB.OUTS]
)
file_tags = file_tags if isinstance(file_tags, list) else [file_tags]

if collect:
BaseCLI.kickoff_single(
BaseCLI.kickoff_job(
ctx=ctx,
module_name=mod_name,
func=func,
func=Collector.collect_dir,
config=config,
verbose=verbose,
pipeline_step=pipeline_step,
)

else:
file_tags = config.get(
'file_tag', ['_'.join(k.split('_')[1:-1]) for k in NSRDB.OUTS]
)
file_tags = file_tags if isinstance(file_tags, list) else [file_tags]
for file_tag in file_tags:
log_id = file_tag
config['job_name'] = f'{ctx.obj["RUN_NAME"]}_{log_id}'
config['file_tag'] = file_tag
BaseCLI.kickoff_job(
ctx=ctx,
module_name=mod_name,
func=func,
func=Blender.run_full,
config=config,
log_id=log_id,
)
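The reworked command routes a plain `blend` config into one Blender.run_full kickoff per file tag, while the `--collect` flag sends a `collect-blend` config to a single Collector.collect_dir kickoff. Below is a minimal sketch of driving both paths through click's test runner, using the config keys from the new test at the bottom of this commit; the file paths, the lon_seam value, and the `from nsrdb import cli` import style are assumptions:

import json

from click.testing import CliRunner

from nsrdb import cli  # assumed import path for the click group

runner = CliRunner()

# Blend step: one Blender.run_full job is kicked off per file tag.
blend_config = {'blend': {
    'meta': './meta.csv',                    # placeholder paths
    'out_dir': './blended',
    'east_dir': './east',
    'west_dir': './west',
    'file_tag': ['irradiance', 'clearsky'],  # a single tag also works
    'map_col': 'gid_full_map',
    'lon_seam': -105.0,                      # placeholder seam longitude
}}
with open('config_blend.json', 'w') as f:
    json.dump(blend_config, f)
assert runner.invoke(cli.blend, ['-c', 'config_blend.json']).exit_code == 0

# Collect step: --collect dispatches a single Collector.collect_dir job.
collect_config = {'collect-blend': {
    'collect_dir': './blended',
    'meta_final': './meta.csv',
    'collect_tag': 'nsrdb_conus_',
    'fout': './final_blend.h5',
    'max_workers': 1,
}}
with open('config_collect_blend.json', 'w') as f:
    json.dump(collect_config, f)
assert runner.invoke(
    cli.blend, ['-c', 'config_collect_blend.json', '--collect']
).exit_code == 0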
75 changes: 43 additions & 32 deletions nsrdb/file_handlers/collection.py
@@ -34,6 +34,7 @@ def __init__(self, collect_dir, dset):
Dataset/var name that is searched for in file names in collect_dir.
"""

self.collect_dir = collect_dir
self.flist = self.get_flist(collect_dir, dset)

if not any(self.flist):
@@ -109,7 +110,20 @@ def verify_flist(flist, d, var):
)

@staticmethod
def get_flist(d, var):
def filter_flist(flist, collect_dir, dset):
"""Filter file list so that only remaining files have given dset."""
filt_list = []
for fn in flist:
fp = os.path.join(collect_dir, fn)
with Outputs(fp, mode='r') as fobj:
if dset in fobj.dsets:
filt_list.append(fn)

logger.debug(f'Found files for "{dset}": {filt_list}')
return filt_list

@staticmethod
def get_flist(d, dset):
"""Get a date-sorted .h5 file list for a given var.
Filename requirements:
@@ -121,7 +135,7 @@ def get_flist(d, var):
----------
d : str
Directory to get file list from.
var : str
dset : str
Variable name that is searched for in files in d.
Returns
@@ -131,20 +145,10 @@ def get_flist(d, var):
Sorted by integer before the first underscore in the filename.
"""

flist = []
temp = os.listdir(d)
temp = [f for f in temp if f.endswith('.h5') and var in f]

for fn in temp:
fp = os.path.join(d, fn)
with Outputs(fp, mode='r') as fobj:
if var in fobj.dsets:
flist.append(fn)

flist = sorted(flist, key=lambda x: int(x.split('_')[0]))
logger.debug('Found files for "{}": {}'.format(var, flist))

return flist
temp = [f for f in temp if f.endswith('.h5') and dset in f]
flist = Collector.filter_flist(temp, collect_dir=d, dset=dset)
return sorted(flist, key=lambda x: int(x.split('_')[0]))

@staticmethod
def get_slices(final_time_index, final_meta, new_time_index, new_meta):
@@ -308,7 +312,7 @@ def _special_attrs(dset, dset_attrs):

@staticmethod
def _get_collection_attrs(
flist, collect_dir, dset, sites=None, sort=True, sort_key=None
flist, collect_dir, sites=None, sort=True, sort_key=None
):
"""Get important dataset attributes from a file list to be collected.
@@ -340,8 +344,6 @@ def _get_collection_attrs(
collected
shape : tuple
Output (collected) dataset shape
dtype : str
Dataset output (collected on disk) dataset data type.
"""

if sort:
@@ -373,11 +375,7 @@ def _get_collection_attrs(

shape = (len(time_index), len(meta))

fp0 = os.path.join(collect_dir, flist[0])
with Outputs(fp0, mode='r') as fin:
dtype = fin.get_dset_properties(dset)[1]

return time_index, meta, shape, dtype
return time_index, meta, shape

@staticmethod
def _init_collected_h5(f_out, time_index, meta):
@@ -434,8 +432,9 @@ def _ensure_dset_in_output(f_out, dset, var_meta=None, data=None):
dset, f.shape, dtype, chunks=chunks, attrs=attrs, data=data
)

@staticmethod
@classmethod
def collect_flist(
cls,
flist,
collect_dir,
f_out,
@@ -448,6 +447,8 @@ def collect_flist(
):
"""Collect a dataset from a file list with data pre-init.
Note
----
Collects data that can be chunked in both space and time.
Parameters
@@ -475,8 +476,11 @@ def collect_flist(
None uses all available.
"""

time_index, meta, shape, _ = Collector._get_collection_attrs(
flist, collect_dir, dset, sites=sites, sort=sort, sort_key=sort_key
flist = cls.filter_flist(
flist=flist, collect_dir=collect_dir, dset=dset
)
time_index, meta, shape = Collector._get_collection_attrs(
flist, collect_dir, sites=sites, sort=sort, sort_key=sort_key
)

attrs, _, final_dtype = VarFactory.get_dset_attrs(
@@ -628,8 +632,8 @@ def collect_flist_lowmem(
)

if not os.path.exists(f_out):
time_index, meta, _, _ = Collector._get_collection_attrs(
flist, collect_dir, dset, sort=sort, sort_key=sort_key
time_index, meta, _ = Collector._get_collection_attrs(
flist, collect_dir, sort=sort, sort_key=sort_key
)

Collector._init_collected_h5(f_out, time_index, meta)
@@ -744,8 +748,8 @@ def collect_daily(
raise ValueError(e)

if not os.path.exists(f_out):
time_index, meta, _, _ = collector._get_collection_attrs(
collector.flist, collect_dir, dset, sites=sites
time_index, meta, _ = collector._get_collection_attrs(
collector.flist, collect_dir, sites=sites
)
collector._init_collected_h5(f_out, time_index, meta)

@@ -898,7 +902,10 @@ def collect_dir(
and os.path.join(collect_dir, fn) != fout
]
flist = sorted(
flist, key=lambda x: int(x.replace('.h5', '').split('_')[-1])
flist,
key=lambda x: int(x.replace('.h5', '').split('_')[-1])
if x.replace('.h5', '').split('_')[-1].isdigit()
else x,
)

logger.info(f'Collecting chunks from {len(flist)} files to: {fout}')
@@ -909,5 +916,9 @@

for dset in dsets:
cls.collect_flist(
flist, collect_dir, fout, dset, max_workers=max_workers
flist,
collect_dir,
fout,
dset,
max_workers=max_workers,
)
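For orientation, the new listing/filtering split in get_flist behaves roughly as sketched below; plain h5py stands in for the project's Outputs handler so the snippet is self-contained, and the example filename in the comment is an assumption:

import os

import h5py  # standing in for nsrdb's Outputs handler


def filter_flist(flist, collect_dir, dset):
    """Keep only the files that actually contain the dataset dset."""
    kept = []
    for fn in flist:
        with h5py.File(os.path.join(collect_dir, fn), 'r') as fobj:
            if dset in fobj:
                kept.append(fn)
    return kept


def get_flist(d, dset):
    """Date-sorted .h5 files in d whose names and contents match dset."""
    temp = [f for f in os.listdir(d) if f.endswith('.h5') and dset in f]
    flist = filter_flist(temp, collect_dir=d, dset=dset)
    # Filenames are expected to lead with an integer date key,
    # e.g. "20240705_irradiance.h5", so this sorts chronologically.
    return sorted(flist, key=lambda x: int(x.split('_')[0]))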
95 changes: 68 additions & 27 deletions tests/cli/test_blend_cli.py
@@ -37,43 +37,64 @@ def test_blend_cli(runner):
os.mkdir(east_dir)
os.mkdir(west_dir)

east_fp = os.path.join(east_dir, 'nsrdb_conus_east_irradiance.h5')
west_fp = os.path.join(west_dir, 'nsrdb_conus_west_irradiance.h5')
out_fp = os.path.join(out_dir, 'nsrdb_conus_irradiance.h5')

dsets = ['dni', 'dhi', 'ghi']
attrs = {d: {'scale_factor': 1, 'units': 'unitless'} for d in dsets}
chunks = dict.fromkeys(dsets)
dtypes = dict.fromkeys(dsets, 'uint16')

Outputs.init_h5(
east_fp, dsets, attrs, chunks, dtypes, time_index, meta_east
)
Outputs.init_h5(
west_fp, dsets, attrs, chunks, dtypes, time_index, meta_west
)

with Outputs(east_fp, mode='a') as f:
for dset in dsets:
f[dset] = np.zeros((8760, len(meta_out)))

with Outputs(west_fp, mode='a') as f:
for dset in dsets:
f[dset] = np.ones((8760, len(meta_out)))
east_fps = [
os.path.join(east_dir, 'nsrdb_conus_east_irradiance.h5'),
os.path.join(east_dir, 'nsrdb_conus_east_clearsky.h5'),
]
west_fps = [
os.path.join(west_dir, 'nsrdb_conus_west_irradiance.h5'),
os.path.join(west_dir, 'nsrdb_conus_west_clearsky.h5'),
]

dsets_cld = ['dni', 'dhi', 'ghi']
dsets_clr = [f'clearsky_{dset}' for dset in dsets_cld]

for i, dsets in enumerate([dsets_cld, dsets_clr]):
attrs = {
d: {'scale_factor': 1, 'units': 'unitless'} for d in dsets
}
chunks = dict.fromkeys(dsets)
dtypes = dict.fromkeys(dsets, 'uint16')
Outputs.init_h5(
east_fps[i],
dsets,
attrs,
chunks,
dtypes,
time_index,
meta_east,
)
Outputs.init_h5(
west_fps[i],
dsets,
attrs,
chunks,
dtypes,
time_index,
meta_west,
)

with Outputs(east_fps[i], mode='a') as f:
for dset in dsets:
f[dset] = np.zeros((8760, len(meta_out)))

with Outputs(west_fps[i], mode='a') as f:
for dset in dsets:
f[dset] = np.ones((8760, len(meta_out)))

config = {
'blend': {
'meta': meta_path,
'out_dir': out_dir,
'east_dir': east_dir,
'west_dir': west_dir,
'file_tag': 'nsrdb_conus_',
'file_tag': ['irradiance', 'clearsky'],
'map_col': 'gid_full_map',
'lon_seam': lon_seam,
},
}

config_file = os.path.join(td, 'config.json')
config_file = os.path.join(td, 'config_blend.json')
with open(config_file, 'w') as f:
f.write(json.dumps(config))

@@ -82,10 +103,30 @@
*result.exc_info
)

with Outputs(out_fp) as out:
fout = os.path.join(td, 'final_blend.h5')
config = {
'collect-blend': {
'collect_dir': out_dir,
'meta_final': meta_path,
'collect_tag': 'nsrdb_conus_',
'fout': fout,
'max_workers': 1,
},
}

config_file = os.path.join(td, 'config_collect_blend.json')
with open(config_file, 'w') as f:
f.write(json.dumps(config))

result = runner.invoke(cli.blend, ['-c', config_file, '--collect'])
assert result.exit_code == 0, traceback.print_exception(
*result.exc_info
)

with Outputs(fout) as out:
west_mask = out.meta.longitude < lon_seam
east_mask = out.meta.longitude >= lon_seam
for dset in dsets:
for dset in dsets_cld + dsets_clr:
data = out[dset]
assert (data[:, west_mask] == 1).all()
assert (data[:, east_mask] == 0).all()
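As a quick sanity check outside the test harness, the collected file can be opened directly. A sketch assuming the standard rex-style layout (a structured meta dataset with a longitude field) plus a placeholder path and seam longitude:

import h5py

# Hypothetical output path and seam longitude; adjust to a real run.
with h5py.File('final_blend.h5', 'r') as out:
    print(list(out))  # datasets should include dni, dhi, ghi and clearsky_* twins
    lon = out['meta']['longitude']
    west = lon < -105.0
    print('west sites:', west.sum(), 'east sites:', (~west).sum())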