From c89fd7249ab3d87020978a608b0a490c2c2d7889 Mon Sep 17 00:00:00 2001
From: nagadomi
Date: Thu, 2 Jun 2016 10:11:15 +0900
Subject: [PATCH] Add learning_rate_decay

---
 lib/minibatch_adam.lua |  8 ++++++++
 lib/settings.lua       |  1 +
 train.lua              | 21 +++++++--------------
 3 files changed, 16 insertions(+), 14 deletions(-)

diff --git a/lib/minibatch_adam.lua b/lib/minibatch_adam.lua
index 66daa2e..dbe70f4 100644
--- a/lib/minibatch_adam.lua
+++ b/lib/minibatch_adam.lua
@@ -7,6 +7,11 @@ local function minibatch_adam(model, criterion, eval_metric,
                               config)
    local parameters, gradParameters = model:getParameters()
    config = config or {}
+   if config.xEvalCount == nil then
+      config.xEvalCount = 0
+      config.learningRate = config.xLearningRate
+   end
+
    local sum_loss = 0
    local sum_eval = 0
    local count_loss = 0
@@ -52,11 +57,14 @@ local function minibatch_adam(model, criterion, eval_metric,
          return f, gradParameters
       end
       optim.adam(feval, parameters, config)
+      config.xEvalCount = config.xEvalCount + batch_size
+      config.learningRate = config.xLearningRate / (1 + config.xEvalCount * config.xLearningRateDecay)
       c = c + 1
       if c % 50 == 0 then
          collectgarbage()
          xlua.progress(t, train_x:size(1))
       end
+
    end
    xlua.progress(train_x:size(1), train_x:size(1))
    return { loss = sum_loss / count_loss, MSE = sum_eval / count_loss, PSNR = 10 * math.log10(1 / (sum_eval / count_loss))}, instance_loss
diff --git a/lib/settings.lua b/lib/settings.lua
index f7a37ba..1ed8552 100644
--- a/lib/settings.lua
+++ b/lib/settings.lua
@@ -58,6 +58,7 @@ cmd:option("-resize_blur_min", 0.85, 'min blur parameter for ResizeImage')
 cmd:option("-resize_blur_max", 1.05, 'max blur parameter for ResizeImage')
 cmd:option("-oracle_rate", 0.0, '')
 cmd:option("-oracle_drop_rate", 0.5, '')
+cmd:option("-learning_rate_decay", 3.0e-7, 'learning rate decay (learning_rate * 1/(1+num_of_data*patches*epoch))')
 
 local function to_bool(settings, name)
    if settings[name] == 1 then
diff --git a/train.lua b/train.lua
index 36a69b8..b1e8d28 100644
--- a/train.lua
+++ b/train.lua
@@ -100,7 +100,6 @@ local function create_criterion(model)
       local offset = reconstruct.offset_size(model)
       local output_w = settings.crop_size - offset * 2
       local weight = torch.Tensor(3, output_w * output_w)
-
       weight[1]:fill(0.29891 * 3) -- R
       weight[2]:fill(0.58661 * 3) -- G
       weight[3]:fill(0.11448 * 3) -- B
@@ -223,8 +222,8 @@ local function remove_small_image(x)
    local new_x = {}
    for i = 1, #x do
       local x_s = compression.size(x[i])
-      if x_s[2] / settings.scale > settings.crop_size + 16 and
-         x_s[3] / settings.scale > settings.crop_size + 16 then
+      if x_s[2] / settings.scale > settings.crop_size + 32 and
+         x_s[3] / settings.scale > settings.crop_size + 32 then
          table.insert(new_x, x[i])
       end
       if i % 100 == 0 then
@@ -253,10 +252,10 @@ local function train()
    local x = remove_small_image(torch.load(settings.images))
    local train_x, valid_x = split_data(x, math.max(math.floor(settings.validation_rate * #x), 1))
    local adam_config = {
-      learningRate = settings.learning_rate,
+      xLearningRate = settings.learning_rate,
       xBatchSize = settings.batch_size,
+      xLearningRateDecay = settings.learning_rate_decay
    }
-   local lrd_count = 0
    local ch = nil
    if settings.color == "y" then
       ch = 1
@@ -285,10 +284,12 @@ local function train()
                                ch, settings.crop_size, settings.crop_size)
    end
    local instance_loss = nil
-
    for epoch = 1, settings.epoch do
       model:training()
       print("# " .. epoch)
+      if adam_config.learningRate then
+         print("learning rate: " .. adam_config.learningRate)
+      end
       print("## resampling")
       if instance_loss then
         -- active learning
@@ -323,7 +324,6 @@ local function train()
       end
       if score.loss < best_score then
         local test_image = image_loader.load_float(settings.test) -- reload
-        lrd_count = 0
         best_score = score.loss
         print("* update best model")
         if settings.save_history then
@@ -351,13 +351,6 @@ local function train()
               save_test_scale(model, test_image, log)
            end
         end
-      else
-        lrd_count = lrd_count + 1
-        if lrd_count > 2 then
-           adam_config.learningRate = adam_config.learningRate * 0.874
-           print("* learning rate decay: " .. adam_config.learningRate)
-           lrd_count = 0
-        end
      end
      print("PSNR: " .. score.PSNR .. ", loss: " .. score.loss .. ", Minimum loss: " .. best_score)
      collectgarbage()
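
Note (not part of the commit): this patch replaces the old plateau-based decay (multiply by 0.874 after 3 epochs without improvement) with inverse-time decay driven by the number of training examples seen. After every optim.adam step, xEvalCount grows by batch_size and the effective rate becomes learning_rate / (1 + xEvalCount * learning_rate_decay). The x-prefixed config keys appear to be bookkeeping fields that optim.adam itself ignores, which is presumably why learningRate is recomputed from them by hand. A minimal standalone Lua sketch of the schedule follows; the base rate and batch size are illustrative values, not the project defaults:

    -- Inverse-time decay as in lib/minibatch_adam.lua above:
    -- lr_t = base_lr / (1 + examples_seen * decay)
    local base_lr = 0.00025   -- hypothetical base learning rate
    local decay = 3.0e-7      -- default of -learning_rate_decay
    local batch_size = 32     -- hypothetical minibatch size

    local eval_count = 0      -- plays the role of config.xEvalCount
    for step = 1, 3 do
       eval_count = eval_count + batch_size
       local lr = base_lr / (1 + eval_count * decay)
       print(string.format("step %d: lr = %.8f", step, lr))
    end

With the default decay of 3.0e-7, the rate halves only after roughly 3.3 million examples, so the schedule stays gentle over a typical training run.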