-
Notifications
You must be signed in to change notification settings - Fork 1
/
main.py
66 lines (46 loc) · 2.08 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import os
import sys
import argparse
import random
import logging
from pathlib import Path
from typing import (
Any,
List,
)
import torch
import numpy as np
from core.algos import BaseAlgo
from core.utils.configs.base import BaseConfig
from core.utils.exceptions import GPUNotAvailable
from core.utils.typing import ArgsObject
from core.utils.recorders import TensorboardRecorder
from experiments.tuner.basic import BasicTuner
# Module-level logger named after this module's import path (`__name__`).
logger = logging.getLogger(name=__name__)
def _parse_args(args: List[Any]) -> ArgsObject:
    """Parse the command-line arguments of the experiment runner.

    Args:
        args: Raw command-line tokens without the program name
            (``main`` passes ``sys.argv[1:]``).

    Returns:
        The namespace produced by ``argparse``. ``config_file`` is a
        ``pathlib.Path``, ``verbose`` has already been converted to a
        ``logging`` level by ``_to_logging_level``, and ``workers``/``run``
        are ints.
    """
    parser = argparse.ArgumentParser(description="The main file for running experiments")
    parser.add_argument(
        '--config-file', dest='config_file', type=Path, required=True,
        help='Expect a json file describing the fixed and sweeping parameters',
    )
    parser.add_argument('--gpu', action='store_true', help="Use GPU: if not specified, use CPU")
    # `_to_logging_level` is defined further down the module; that is fine
    # because argparse only invokes it while parsing, i.e. after import.
    parser.add_argument(
        '--verbose', dest='verbose', type=_to_logging_level, required=True,
        help="Logging level: info or debug",
    )
    parser.add_argument(
        '--workers', dest='workers', default=-1, type=int,
        help="Number of workers used to run the experiments. "
             "-1 means that the number of workers is determined automatically",
    )
    parser.add_argument(
        '--run', dest='run', default=1, type=int,
        help="Number of times that each algorithm needs to be evaluated",
    )
    return parser.parse_args(args)
def _to_logging_level(level: str) -> int:
if level == 'info':
return logging.INFO
elif level == 'debug':
return logging.DEBUG
else:
raise ValueError("Unknown logging level: {}".format(level))
def _check_device_availability(flag: bool) -> None:
if flag and not torch.cuda.is_available():
raise ValueError("GPU is not available!")
def main(args: List[str]) -> None:
    """Run an experiment sweep described by a config file.

    Parses the CLI tokens, verifies the requested device, builds a
    ``BasicTuner`` from the absolute path of the config file and runs it.

    Args:
        args: Command-line tokens without the program name (``sys.argv[1:]``).

    Raises:
        ValueError: if ``--gpu`` was given but CUDA is unavailable.
        SystemExit: raised by argparse when the arguments are invalid.
    """
    options = _parse_args(args)
    _check_device_availability(options.gpu)
    # Resolve to an absolute path before handing it to the tuner.
    config_path = options.config_file.absolute()
    tuner = BasicTuner.create_by_config_file(
        config_path,
        verbose=options.verbose,
        workers=options.workers,
        gpu=options.gpu,
        run=options.run,
    )
    tuner.tune()
if __name__ == '__main__':
    # Script entry point: forward every token after the program name to `main`.
    main(args=sys.argv[1:])