Add -resume_epoch option
This commit is contained in:
parent
6efd7f890e
commit
bdafab9e10
|
@ -76,6 +76,7 @@ cmd:option("-oracle_rate", 0.1, '')
|
|||
cmd:option("-oracle_drop_rate", 0.5, '')
|
||||
cmd:option("-learning_rate_decay", 3.0e-7, 'learning rate decay (learning_rate * 1/(1+num_of_data*patches*epoch))')
|
||||
cmd:option("-resume", "", 'resume model file')
|
||||
cmd:option("-resume_epoch", 1, 'resume epoch')
|
||||
cmd:option("-name", "user", 'model name for user method')
|
||||
cmd:option("-gpu", "", 'GPU Device ID or ID lists (comma seprated)')
|
||||
cmd:option("-loss", "huber", 'loss function (huber|l1|mse|bce)')
|
||||
|
|
17
train.lua
17
train.lua
|
@ -506,9 +506,24 @@ local function train()
|
|||
local train_x, valid_x = split_data(x, math.max(math.floor(settings.validation_rate * #x), 1))
|
||||
local hist_train = {}
|
||||
local hist_valid = {}
|
||||
local adam_config = {
|
||||
xLearningRate = settings.learning_rate,
|
||||
xBatchSize = settings.batch_size,
|
||||
xLearningRateDecay = settings.learning_rate_decay,
|
||||
xInstanceLoss = (settings.oracle_rate > 0)
|
||||
}
|
||||
local model
|
||||
if settings.resume:len() > 0 then
|
||||
model = torch.load(settings.resume, "ascii")
|
||||
adam_config.xEvalCount = math.floor((#train_x * settings.patches) / settings.batch_size) * settings.batch_size * settings.inner_epoch * (settings.resume_epoch - 1)
|
||||
print(string.format("set eval count = %d", adam_config.xEvalCount))
|
||||
if adam_config.xEvalCount > 0 then
|
||||
adam_config.learningRate = adam_config.xLearningRate / (1 + adam_config.xEvalCount * adam_config.xLearningRateDecay)
|
||||
print(string.format("set learning rate = %E", adam_config.learningRate))
|
||||
else
|
||||
adam_config.xEvalCount = 0
|
||||
adam_config.learningRate = adam_config.xLearningRate
|
||||
end
|
||||
else
|
||||
if stringx.endswith(settings.model, ".lua") then
|
||||
local create_model = dofile(settings.model)
|
||||
|
@ -576,7 +591,7 @@ local function train()
|
|||
end
|
||||
local instance_loss = nil
|
||||
local pmodel = w2nn.data_parallel(model, settings.gpu)
|
||||
for epoch = 1, settings.epoch do
|
||||
for epoch = settings.resume_epoch, settings.epoch do
|
||||
pmodel:training()
|
||||
print("# " .. epoch)
|
||||
if adam_config.learningRate then
|
||||
|
|
Loading…
Reference in a new issue