-
Notifications
You must be signed in to change notification settings - Fork 69
/
train.lua
121 lines (101 loc) · 4.24 KB
/
train.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
require 'nn'
require 'nngraph'
require 'rnn'
------------------------------------------------------------------------
-- Input arguments and options
------------------------------------------------------------------------
-- option table returned by the `opts` module; echoed for the run log
local opt = require('opts')
print(opt)

do
    -- one fixed seed for both the CPU and the GPU generators (reproducibility)
    local seed = 1234
    torch.manualSeed(seed)

    -- a non-negative gpuid selects the CUDA path; otherwise stay on the CPU
    local useGpu = opt.gpuid >= 0
    if useGpu then
        require('cutorch')
        require('cunn')
        if opt.backend == 'cudnn' then
            require('cudnn')
        end
        -- gpuid is shifted by one before being passed to cutorch.setDevice
        cutorch.setDevice(opt.gpuid + 1)
        cutorch.manualSeed(seed)
        torch.setdefaulttensortype('torch.CudaTensor')
    else
        torch.setdefaulttensortype('torch.FloatTensor')
    end
end

-- the model parameters start out as the option table itself (same table,
-- not a copy)
local modelParams = opt
------------------------------------------------------------------------
-- Read saved model and parameters
------------------------------------------------------------------------
-- stays `false` unless a checkpoint path was supplied
local savedModel = false
if opt.loadPath ~= '' then
    savedModel = torch.load(opt.loadPath)
    -- from here on the model is configured from the checkpoint's parameters
    modelParams = savedModel.modelParams

    -- these option fields are overwritten with the checkpoint's values
    for _, field in ipairs({'imgNorm', 'encoder', 'decoder'}) do
        opt[field] = modelParams[field]
    end

    -- gpuid and batchSize follow the current run's options instead
    modelParams.gpuid = opt.gpuid
    modelParams.batchSize = opt.batchSize
end
------------------------------------------------------------------------
-- Loading dataset
------------------------------------------------------------------------
-- dataloader.lua is executed from the current working directory and its
-- return value is used as the dataloader object
local dataloader = dofile 'dataloader.lua'
do
    -- only the training split is requested here
    local splits = {'train'}
    dataloader:initialize(opt, splits)
end
-- run a full garbage-collection cycle once initialization is done
collectgarbage('collect')
------------------------------------------------------------------------
-- Setting model parameters
------------------------------------------------------------------------
-- Fields copied from the dataloader into the model parameters.
-- Declared local: this list is only consumed by the loop right below, so it
-- should not leak into the global environment.
local paramNames = {'numTrainThreads', 'numTestThreads', 'numValThreads',
                    'vocabSize', 'maxQuesCount', 'maxQuesLen', 'maxAnsLen'}
-- paramNames is a plain sequence, so ipairs walks it in order
for _, name in ipairs(paramNames) do
    modelParams[name] = dataloader[name]
end

-- directory in which checkpoints are written; created up front
local modelPath = opt.savePath
paths.mkdir(modelPath)

-- Iterations per epoch: number of training threads divided by the batch
-- size, rounded up so a final partial batch still counts as an iteration
modelParams.numIterPerEpoch = math.ceil(modelParams.numTrainThreads /
                                        modelParams.batchSize)
print(string.format('\n%d iter per epoch.', modelParams.numIterPerEpoch))
------------------------------------------------------------------------
-- Setup the model
------------------------------------------------------------------------
-- `model` is required for its side effect of defining the global `Model`
require 'model'
local model = Model(modelParams)

-- When resuming, restore the saved weights and, if present, the learning rate
if opt.loadPath ~= '' then
    model.wrapperW:copy(savedModel.modelW)

    -- Not every checkpoint carries optimizer state: the final model written
    -- at the end of this script stores only modelW and modelParams. Guard the
    -- lookup so such a file can be loaded without indexing a nil `optims`;
    -- in that case the learning rate the model was constructed with is kept.
    local savedOptims = savedModel.optims
    if savedOptims and savedOptims.learningRate then
        model.optims.learningRate = savedOptims.learningRate
    end
end
------------------------------------------------------------------------
-- Training
------------------------------------------------------------------------
print('Training..')
collectgarbage()

-- NOTE(review): intentionally left global. Nothing in this file updates it,
-- so it is presumably maintained by model:trainIteration -- confirm in
-- model.lua before turning it into a local.
runningLoss = 0

-- Returns the full path of a checkpoint file inside the save directory.
-- paths.concat supplies the separator when the save path has no trailing
-- slash, and the directory name is kept out of any string.format pattern so
-- a '%' in the path cannot be misread as a format directive.
local function checkpointPath(fileName)
    return paths.concat(modelPath, fileName)
end

for iter = 1, modelParams.numEpochs * modelParams.numIterPerEpoch do
    -- forward and backward propagation
    model:trainIteration(dataloader)

    -- (possibly fractional) number of epochs completed after this iteration
    local currentEpoch = iter / modelParams.numIterPerEpoch

    -- save model every `saveIter` epochs
    if iter % (modelParams.saveIter * modelParams.numIterPerEpoch) == 0 then
        -- save model and optimization parameters
        local fileName = string.format('model_epoch_%d.t7', currentEpoch)
        torch.save(checkpointPath(fileName),
                   {modelW = model.wrapperW,
                    optims = model.optims,
                    modelParams = modelParams})
        -- validation accuracy
        -- model:retrieve(dataloader, 'val');
    end

    -- print after every 100 iterations
    if iter % 100 == 0 then
        -- print current time, running average, learning rate, iteration, epoch
        print(string.format('[%s][Epoch:%.02f][Iter:%d][Loss:%.05f][lr:%f]',
                            os.date(), currentEpoch, iter, runningLoss,
                            model.optims.learningRate))
    end

    -- force a garbage-collection cycle every 10 iterations
    if iter % 10 == 0 then collectgarbage() end
end

-- Saving the final model: weights are converted with :float() before being
-- written, and no optimizer state (`optims`) is stored in this file.
torch.save(checkpointPath('model_final.t7'),
           {modelW = model.wrapperW:float(),
            modelParams = modelParams})