Add -resume_epoch option
parent 6efd7f890e
commit bdafab9e10
@@ -76,6 +76,7 @@ cmd:option("-oracle_rate", 0.1, '')
 cmd:option("-oracle_drop_rate", 0.5, '')
 cmd:option("-learning_rate_decay", 3.0e-7, 'learning rate decay (learning_rate * 1/(1+num_of_data*patches*epoch))')
 cmd:option("-resume", "", 'resume model file')
+cmd:option("-resume_epoch", 1, 'resume epoch')
 cmd:option("-name", "user", 'model name for user method')
 cmd:option("-gpu", "", 'GPU Device ID or ID lists (comma separated)')
 cmd:option("-loss", "huber", 'loss function (huber|l1|mse|bce)')
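With both flags an interrupted run can be restarted from a saved snapshot. A hypothetical invocation (the model path and epoch number below are illustrative only, not from this commit):

    th train.lua -resume models/snapshot.t7 -resume_epoch 11

-resume_epoch defaults to 1, so the epoch loop still starts from the first epoch when the flag is omitted.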
train.lua (17 changed lines)
@@ -506,9 +506,24 @@ local function train()
    local train_x, valid_x = split_data(x, math.max(math.floor(settings.validation_rate * #x), 1))
    local hist_train = {}
    local hist_valid = {}
+   local adam_config = {
+      xLearningRate = settings.learning_rate,
+      xBatchSize = settings.batch_size,
+      xLearningRateDecay = settings.learning_rate_decay,
+      xInstanceLoss = (settings.oracle_rate > 0)
+   }
    local model
    if settings.resume:len() > 0 then
       model = torch.load(settings.resume, "ascii")
+      adam_config.xEvalCount = math.floor((#train_x * settings.patches) / settings.batch_size) * settings.batch_size * settings.inner_epoch * (settings.resume_epoch - 1)
+      print(string.format("set eval count = %d", adam_config.xEvalCount))
+      if adam_config.xEvalCount > 0 then
+         adam_config.learningRate = adam_config.xLearningRate / (1 + adam_config.xEvalCount * adam_config.xLearningRateDecay)
+         print(string.format("set learning rate = %E", adam_config.learningRate))
+      else
+         adam_config.xEvalCount = 0
+         adam_config.learningRate = adam_config.xLearningRate
+      end
    else
       if stringx.endswith(settings.model, ".lua") then
          local create_model = dofile(settings.model)
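The resumed optimizer state is approximated by recomputing how many sample evaluations the previous run performed and re-deriving the decayed learning rate from that count. A minimal standalone sketch of the arithmetic; every concrete value below is a hypothetical example except the decay constant, which matches the default above:

    -- Standalone sketch of the eval-count reconstruction in the hunk above.
    -- All concrete values are hypothetical examples, not from the commit.
    local learning_rate       = 0.00025  -- settings.learning_rate
    local learning_rate_decay = 3.0e-7   -- settings.learning_rate_decay (default)
    local batch_size          = 8        -- settings.batch_size
    local patches             = 64       -- settings.patches
    local inner_epoch         = 4        -- settings.inner_epoch
    local n_train             = 5000     -- #train_x
    local resume_epoch        = 11       -- settings.resume_epoch

    -- Evaluations per epoch: only full minibatches count, hence the floor
    -- followed by multiplying batch_size back in.
    local evals_per_epoch = math.floor((n_train * patches) / batch_size)
                            * batch_size * inner_epoch
    -- resume_epoch - 1 epochs were already completed by the previous run.
    local eval_count = evals_per_epoch * (resume_epoch - 1)

    -- Same schedule as the -learning_rate_decay help text:
    -- learning_rate * 1/(1 + num_of_data*patches*epoch * decay)
    local lr = learning_rate / (1 + eval_count * learning_rate_decay)
    print(string.format("eval count = %d, learning rate = %E", eval_count, lr))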
@@ -576,7 +591,7 @@ local function train()
    end
    local instance_loss = nil
    local pmodel = w2nn.data_parallel(model, settings.gpu)
-   for epoch = 1, settings.epoch do
+   for epoch = settings.resume_epoch, settings.epoch do
       pmodel:training()
       print("# " .. epoch)
       if adam_config.learningRate then
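Note that -resume_epoch names the first epoch to run, not the last completed one: the loop starts at settings.resume_epoch, while the eval count above is scaled by resume_epoch - 1 completed epochs. A trivial sketch with hypothetical values:

    -- Hypothetical: a 30-epoch run interrupted after epoch 10 is resumed
    -- with -resume_epoch 11, so epochs 1..10 are not repeated.
    local settings = { resume_epoch = 11, epoch = 30 }
    for epoch = settings.resume_epoch, settings.epoch do
       print("# " .. epoch)  -- same per-epoch marker train.lua prints
    end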