
Add -resume_epoch option

nagadomi 2018-10-29 17:25:34 +00:00
parent 6efd7f890e
commit bdafab9e10
2 changed files with 17 additions and 1 deletion

@@ -76,6 +76,7 @@ cmd:option("-oracle_rate", 0.1, '')
 cmd:option("-oracle_drop_rate", 0.5, '')
 cmd:option("-learning_rate_decay", 3.0e-7, 'learning rate decay (learning_rate * 1/(1+num_of_data*patches*epoch))')
 cmd:option("-resume", "", 'resume model file')
+cmd:option("-resume_epoch", 1, 'resume epoch')
 cmd:option("-name", "user", 'model name for user method')
 cmd:option("-gpu", "", 'GPU Device ID or ID lists (comma separated)')
 cmd:option("-loss", "huber", 'loss function (huber|l1|mse|bce)')

@@ -506,9 +506,24 @@ local function train()
    local train_x, valid_x = split_data(x, math.max(math.floor(settings.validation_rate * #x), 1))
    local hist_train = {}
    local hist_valid = {}
+   local adam_config = {
+      xLearningRate = settings.learning_rate,
+      xBatchSize = settings.batch_size,
+      xLearningRateDecay = settings.learning_rate_decay,
+      xInstanceLoss = (settings.oracle_rate > 0)
+   }
    local model
    if settings.resume:len() > 0 then
       model = torch.load(settings.resume, "ascii")
+      adam_config.xEvalCount = math.floor((#train_x * settings.patches) / settings.batch_size) * settings.batch_size * settings.inner_epoch * (settings.resume_epoch - 1)
+      print(string.format("set eval count = %d", adam_config.xEvalCount))
+      if adam_config.xEvalCount > 0 then
+         adam_config.learningRate = adam_config.xLearningRate / (1 + adam_config.xEvalCount * adam_config.xLearningRateDecay)
+         print(string.format("set learning rate = %E", adam_config.learningRate))
+      else
+         adam_config.xEvalCount = 0
+         adam_config.learningRate = adam_config.xLearningRate
+      end
    else
       if stringx.endswith(settings.model, ".lua") then
          local create_model = dofile(settings.model)
@@ -576,7 +591,7 @@ local function train()
    end
    local instance_loss = nil
    local pmodel = w2nn.data_parallel(model, settings.gpu)
-   for epoch = 1, settings.epoch do
+   for epoch = settings.resume_epoch, settings.epoch do
       pmodel:training()
       print("# " .. epoch)
       if adam_config.learningRate then
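
Putting the two hunks together: on resume, the optimizer's evaluation counter is rebuilt from -resume_epoch, the decayed learning rate is recomputed from it, and the epoch loop restarts at settings.resume_epoch instead of 1. A standalone sanity-check sketch (dataset and batch sizes are made up; the two formulas are taken verbatim from the diff):

-- reconstruct optimizer state for a run resumed at epoch 11
local train_size   = 5000      -- stand-in for #train_x
local patches      = 64        -- settings.patches
local batch_size   = 32        -- settings.batch_size
local inner_epoch  = 4         -- settings.inner_epoch
local resume_epoch = 11        -- settings.resume_epoch
local base_lr      = 0.00025   -- settings.learning_rate
local decay        = 3.0e-7    -- settings.learning_rate_decay

-- samples seen in epochs 1..resume_epoch-1, rounded down to whole batches
local eval_count = math.floor((train_size * patches) / batch_size)
                 * batch_size * inner_epoch * (resume_epoch - 1)
local lr = base_lr / (1 + eval_count * decay)
print(string.format("eval count = %d, lr = %E", eval_count, lr))
-- eval count = 12800000, lr = 5.165289E-05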