-
Notifications
You must be signed in to change notification settings - Fork 69
/
opts.lua
69 lines (59 loc) · 3.15 KB
/
opts.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
-- Command-line definition for training the Visual Dialog model.
-- `cmd` is assigned without `local`, so it stays a global as it was before.
cmd = torch.CmdLine()
cmd:text('Train the Visual Dialog model')
cmd:text()
cmd:text('Options')

do
    -- One entry per option: { flag, default value, help text }.
    -- Entries are registered in the order listed here.
    local optionSpecs = {
        -- Data input settings
        {'-inputImg', 'data/data_img.h5', 'HDF5 file with image features'},
        {'-inputQues', 'data/visdial_data.h5', 'HDF5 file with preprocessed questions'},
        {'-inputJson', 'data/visdial_params.json', 'JSON file with info and vocab'},
        {'-savePath', 'checkpoints/', 'Path to save checkpoints'},
        {'-saveIter', 2, 'Save model checkpoint after every saveIter epochs'},

        -- specify encoder/decoder
        {'-encoder', 'lf-ques-hist', 'Name of the encoder to use'},
        {'-decoder', 'gen', 'Name of the decoder to use (gen/disc)'},
        {'-imgNorm', 1, 'normalize the image feature. 1=yes, 0=no'},

        -- model params
        {'-imgEmbedSize', 300, 'Size of the multimodal embedding'},
        {'-imgFeatureSize', 4096, 'Channel size of the image feature'},
        {'-imgSpatialSize', 14, 'Spatial size of image features (for attention-based encoders).'},
        {'-embedSize', 300, 'Size of input word embeddings'},
        {'-rnnHiddenSize', 512, 'Size of the LSTM state'},
        {'-maxHistoryLen', 60, 'Maximum history to consider when using concatenated QA pairs'},
        {'-numLayers', 2, 'Number of layers in LSTM'},
        {'-commonEmbeddingSize', 512, 'Common embedding size in MN-ATT-QIH'},
        {'-numAttentionLayers', 1, 'No. of attention hops in MN-ATT-QIH'},
        {'-loadPath', '', 'Checkpoint path to load from'},

        -- optimization params
        {'-batchSize', 40, 'Batch size (number of threads) (Adjust base on GPU memory)'},
        {'-learningRate', 1e-3, 'Learning rate'},
        {'-weightInit', 'xavier', 'Weight initialization strategy: xavier|heuristic|kaiming'},
        {'-dropout', 0.5, 'Dropout'},
        {'-numEpochs', 100, 'Epochs'},
        {'-LRateDecay', 10, 'After lr_decay epochs lr reduces to 0.1*lr'},
        {'-lrDecayRate', 0.9997592083, 'Decay for learning rate'},
        {'-minLRate', 5e-5, 'Minimum learning rate'},
        {'-gpuid', 0, 'GPU id to use'},
        {'-backend', 'cudnn', 'nn|cudnn'},
    }

    for _, spec in ipairs(optionSpecs) do
        cmd:option(spec[1], spec[2], spec[3])
    end
end
-- Parse the arguments in `arg` against the options registered on `cmd`.
local opts = cmd:parse(arg)

-- Truthy when the encoder name contains `tag` as a literal substring
-- (plain find, so no pattern characters are interpreted).
local function encoderHas(tag)
    return string.find(opts.encoder, tag, 1, true)
end

-- Build a per-run checkpoint folder name from the current local time plus
-- the encoder/decoder names, so separate runs do not clutter one directory.
local now = os.date('*t', os.time())
local runFolder = string.format(
    'checkpoints/model-%d-%d-%d-%d:%d:%d-%s-%s/',
    now.month, now.day, now.year,
    now.hour, now.min, now.sec,
    opts.encoder, opts.decoder
)

-- Only replace the save path when it still equals the default value;
-- any other -savePath is kept as given.
if opts.savePath == 'checkpoints/' then
    opts.savePath = runFolder
end

-- Flags derived from the encoder name: { substring, field set to true }.
--   'hist' -> history input is required
--   'im'   -> image input is required
--   'lf'   -> history is concatenated (only for the late fusion encoder)
-- A field is left unset when its substring is absent.
local derivedFlags = {
    {'hist', 'useHistory'},
    {'im', 'useIm'},
    {'lf', 'concatHistory'},
}
for _, rule in ipairs(derivedFlags) do
    if encoderHas(rule[1]) then
        opts[rule[2]] = true
    end
end

-- Attention is always on conv features, not fc7: if the image file is still
-- the default, switch it to the pool5 file, and always set imgNorm to 0.
if encoderHas('att') then
    if opts.inputImg == 'data/data_img.h5' then
        opts.inputImg = 'data/data_img_pool5.h5'
    end
    opts.imgNorm = 0
end

return opts