Skip to content

Commit

Permalink
agg collect cli test
Browse files Browse the repository at this point in the history
  • Loading branch information
bnb32 committed Jul 5, 2024
1 parent 074f0fa commit d83845a
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 16 deletions.
18 changes: 10 additions & 8 deletions nsrdb/aggregation/aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1524,7 +1524,7 @@ def _agg_var_serial(self, var, method):

return arr

def _agg_var_parallel(self, var, method):
def _agg_var_parallel(self, var, method, max_workers=None):
"""Aggregate one var for all sites in this chunk in parallel.
Parameters
Expand All @@ -1545,7 +1545,7 @@ def _agg_var_parallel(self, var, method):
arr = self._init_arr(var)

loggers = ['farms', 'nsrdb']
with SpawnProcessPool(loggers=loggers) as exe:
with SpawnProcessPool(loggers=loggers, max_workers=max_workers) as exe:
logger.debug('Submitting futures...')
for i in range(len(self.meta_chunk)):
args = self._get_args(var, i)
Expand Down Expand Up @@ -1590,7 +1590,7 @@ def run_chunk(
n_chunks,
year=2018,
ignore_dsets=None,
parallel=True,
max_workers=None,
log_file='run_agg_chunk.log',
log_level='DEBUG',
):
Expand All @@ -1613,8 +1613,8 @@ def run_chunk(
Year being analyzed.
ignore_dsets : list | None
Source datasets to ignore (not aggregate). Optional.
parallel : bool
Flag to use parallel compute.
max_workers : int | None
        Number of workers to use. Runs serially if max_workers == 1.
log_file : str
File to use for logging
log_level : str | bool
Expand Down Expand Up @@ -1662,10 +1662,12 @@ def run_chunk(
var, i_var, n_var, method
)
)
if parallel:
arr = m._agg_var_parallel(var, method)
else:
if max_workers == 1:
arr = m._agg_var_serial(var, method)
else:
arr = m._agg_var_parallel(
var, method, max_workers=max_workers
)

m.write_output(arr, var)

Expand Down
9 changes: 7 additions & 2 deletions nsrdb/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,14 +105,17 @@ def main(ctx, config, verbose):
\b
{
"logging": {"log_level": "DEBUG"},
"<command name>": {kwargs},
"<command name>": {'run_name': ...,
**kwargs},
"direct": {more kwargs},
            "execution_control": {"option": "kestrel", ...},
            "another command": {...},
            ...
}
    The "run_name" key will be prepended to each kicked-off job, e.g.
    <run_name>_0, <run_name>_1, ... for multiple jobs from the same cli module.
The "direct" key is used to provide arguments to multiple commands. This
removes the need for duplication in the case of multiple commands having
the same argument values. "execution_control" is used to provide arguments
Expand Down Expand Up @@ -764,7 +767,9 @@ def aggregate(ctx, config, verbose=False, pipeline_step=None, collect=False):
match resolution of low-resolution years (pre 2018)
"""
func = Collector.collect_dir if collect else Manager.run_chunk
mod_name = ModuleName.COLLECT_AGG if collect else ModuleName.AGGREGATE
mod_name = (
ModuleName.COLLECT_AGGREGATE if collect else ModuleName.AGGREGATE
)
kickoff_func = (
BaseCLI.kickoff_single if collect else BaseCLI.kickoff_multichunk
)
Expand Down
7 changes: 3 additions & 4 deletions nsrdb/file_handlers/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -882,10 +882,9 @@ def collect_dir(
Desired log level, None will not initialize logging.
"""

if log_level is not None:
init_logger(
'nsrdb.file_handlers', log_file=log_file, log_level=log_level
)
init_logger(
'nsrdb.file_handlers', log_file=log_file, log_level=log_level
)

if isinstance(meta_final, str):
meta_final = pd.read_csv(meta_final, index_col=0)
Expand Down
2 changes: 1 addition & 1 deletion nsrdb/utilities/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class ModuleName(str, Enum):
BLEND = 'blend'
COLLECT_BLEND = 'collect-blend'
AGGREGATE = 'aggregate'
COLLECT_AGG = 'collect-agg'
COLLECT_AGGREGATE = 'collect-aggregate'
COLLECT_DATA_MODEL = 'collect-data-model'
COLLECT_FINAL = 'collect-final'
TMY = 'tmy'
Expand Down
27 changes: 26 additions & 1 deletion tests/cli/test_agg_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,10 @@ def test_agg_cli(runner):
'data': TESTJOB3,
'data_dir': td,
'meta_dir': meta_dir,
'n_chunks': 2,
'year': 2018,
'ignore_dsets': IGNORE_DSETS,
'parallel': False,
'max_workers': 1,
},
}

Expand All @@ -100,6 +101,30 @@ def test_agg_cli(runner):
dsets = ('dni', 'aod', 'cloud_type', 'cld_opd_dcomp')
assert all(d in f for d in dsets)

fout = os.path.join(td, 'final_agg.h5')
config = {
'collect-aggregate': {
'collect_dir': os.path.join(td, 'agg_out'),
'meta_final': os.path.join(meta_dir, 'test_meta_agg.csv'),
'collect_tag': 'agg_out_',
'fout': fout,
'max_workers': 1,
},
}

config_file = os.path.join(td, 'config.json')
with open(config_file, 'w') as f:
f.write(json.dumps(config))

result = runner.invoke(cli.aggregate, ['-c', config_file, '--collect'])
assert result.exit_code == 0, traceback.print_exception(
*result.exc_info
)

with NSRDB(fout, mode='r') as f:
dsets = ('dni', 'aod', 'cloud_type', 'cld_opd_dcomp')
assert all(d in f for d in dsets)


if __name__ == '__main__':
execute_pytest(__file__)

0 comments on commit d83845a

Please sign in to comment.