-
Notifications
You must be signed in to change notification settings - Fork 5
/
train_task_list.py
64 lines (62 loc) · 2.43 KB
/
train_task_list.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
'''The model and optimizee used when training.'''
from torch import optim
import nn_optimizer
import optimizee
# Only the `MNIST attack` task is supported so far.


def _mnist_attack_task(nn_opt_cls, lr, max_epoch, extra_tests=None):
    """Build the configuration dict for one MNIST-attack training task.

    Both tasks in ``tasks`` use the same optimizee, attacked model,
    checkpoint path, batch sizes, step counts and common test settings.
    They differ only in the learned optimizer class, its learning rate,
    the epoch budget and any extra entries under ``'tests'``.

    Args:
        nn_opt_cls: learned ZO optimizer class, stored both as the
            training optimizer (``'nn_optimizer'``) and as the optimizer
            evaluated in ``'tests'['nn_opt']``.
        lr: learning rate stored under ``'lr'``.
        max_epoch: value stored under ``'max_epoch'``.
        extra_tests: optional mapping whose entries are added to the
            ``'tests'`` sub-dict after the common ones.

    Returns:
        A newly built dict. Its nested ``'tests'`` dict and
        ``'test_indexes'`` list are also created fresh on every call.
    """
    test_cfg = {
        'optimizee': optimizee.mnist.MnistAttack,
        'test_indexes': list(range(1, 11)),  # test image indexes
        'test_num': 10,  # number of independent attacks
        'n_steps': 200,
        'test_batch_size': 1,
        'nn_opt': nn_opt_cls,
        'base_opt': nn_optimizer.basezoopt.BaseZOOptimizer,
        'base_lr': 4,
    }
    if extra_tests:
        test_cfg.update(extra_tests)

    return {
        'nn_optimizer': nn_opt_cls,
        'optimizee': optimizee.mnist.MnistAttack,
        'batch_size': 1,
        'test_batch_size': 1,
        'lr': lr,
        'max_epoch': max_epoch,
        'optimizer_steps': 200,
        'test_optimizer_steps': 200,
        'attack_model': optimizee.mnist.MnistConvModel,
        'attack_model_ckpt': "./ckpt/attack_model/mnist_cnn.pt",
        'tests': test_cfg,
    }


tasks = {
    # Train the ZO optimizer (UpdateRNN only) for the MNIST attack.
    'ZOL2L-Attack': _mnist_attack_task(
        nn_optimizer.zoopt.ZOOptimizer,
        lr=1e-3,
        max_epoch=20,
    ),
    # Train the ZO optimizer (both UpdateRNN and QueryRNN) for the MNIST attack.
    'VarReducedZOL2L-Attack': _mnist_attack_task(
        nn_optimizer.zoopt.VarReducedZOOptimizer,
        lr=0.005,
        max_epoch=40,
        extra_tests={
            # Additional baseline ZO optimizers evaluated alongside the
            # learned one.
            'sign_opt': nn_optimizer.basezoopt.SignZOOptimizer,
            'sign_lr': 8,
            'adam_opt': nn_optimizer.basezoopt.AdamZOOptimizer,
            'adam_lr': 8,
            'adam_beta_1': 0.9,
            'adam_beta_2': 0.996,
            # 'nn_opt_no_query': nn_optimizer.zoopt.VarReducedZOOptimizer,
            # 'nn_opt_no_update': nn_optimizer.zoopt.VarReducedZOOptimizer,
            # 'nn_opt_guided': nn_optimizer.zoopt.VarReducedZOOptimizer,
        },
    ),
}