From aaac6ed6e570e85c0c6bfeed30c4a3e662942d32 Mon Sep 17 00:00:00 2001
From: nagadomi
Date: Mon, 30 Nov 2015 17:18:52 +0900
Subject: [PATCH] Refactor training loop for more thorough shuffling

Sample a fixed number of augmented patches per image into flat x/y
tensors once per outer epoch (resampling), then run several inner
epochs of Adam over a patch-level shuffle of that pool, instead of
shuffling whole images and re-running the transformer for every
minibatch. Retune the unsharp-mask augmentation ranges, raise the
default batch size and learning rate, and make the learning-rate
decay faster with a lower bound.
---
 lib/data_augmentation.lua |   4 +-
 lib/minibatch_adam.lua    |  38 +++++++-------
 lib/settings.lua          |   8 +--
 train.lua                 | 101 +++++++++++++++++++++++---------------
 4 files changed, 89 insertions(+), 62 deletions(-)

diff --git a/lib/data_augmentation.lua b/lib/data_augmentation.lua
index e9895fd..d9c0143 100644
--- a/lib/data_augmentation.lua
+++ b/lib/data_augmentation.lua
@@ -54,8 +54,8 @@ end
 function data_augmentation.unsharp_mask(src, p)
    if torch.uniform() < p then
       local radius = 0 -- auto
-      local sigma = torch.uniform(0.7, 1.4)
-      local amount = torch.uniform(0.5, 1.0)
+      local sigma = torch.uniform(0.7, 3.0)
+      local amount = torch.uniform(0.25, 0.75)
       local threshold = torch.uniform(0.0, 0.05)
       local unsharp = gm.Image(src, "RGB", "DHW"):
          unsharpMask(radius, sigma, amount, threshold):
diff --git a/lib/minibatch_adam.lua b/lib/minibatch_adam.lua
index 915f1b4..221d82c 100644
--- a/lib/minibatch_adam.lua
+++ b/lib/minibatch_adam.lua
@@ -3,30 +3,32 @@ require 'cutorch'
 require 'xlua'
 
 local function minibatch_adam(model, criterion,
-                              train_x,
-                              config, transformer,
-                              input_size, target_size)
+                              train_x, train_y,
+                              config)
    local parameters, gradParameters = model:getParameters()
    config = config or {}
    local sum_loss = 0
    local count_loss = 0
    local batch_size = config.xBatchSize or 32
-   local shuffle = torch.randperm(#train_x)
+   local shuffle = torch.randperm(train_x:size(1))
    local c = 1
-   local inputs = torch.Tensor(batch_size,
-                               input_size[1], input_size[2], input_size[3]):cuda()
-   local targets = torch.Tensor(batch_size,
-                                target_size[1] * target_size[2] * target_size[3]):cuda()
    local inputs_tmp = torch.Tensor(batch_size,
-                                   input_size[1], input_size[2], input_size[3])
+                                   train_x:size(2), train_x:size(3), train_x:size(4)):zero()
    local targets_tmp = torch.Tensor(batch_size,
-                                    target_size[1] * target_size[2] * target_size[3])
-   for t = 1, #train_x do
-      xlua.progress(t, #train_x)
-      local xy = transformer(train_x[shuffle[t]], false, batch_size)
-      for i = 1, #xy do
-         inputs_tmp[i]:copy(xy[i][1])
-         targets_tmp[i]:copy(xy[i][2])
+                                    train_y:size(2)):zero()
+   local inputs = inputs_tmp:clone():cuda()
+   local targets = targets_tmp:clone():cuda()
+
+   print("## update")
+   for t = 1, train_x:size(1), batch_size do
+      if t + batch_size - 1 > train_x:size(1) then
+         break
+      end
+      xlua.progress(t, train_x:size(1))
+
+      for i = 1, batch_size do
+         inputs_tmp[i]:copy(train_x[shuffle[t + i - 1]])
+         targets_tmp[i]:copy(train_y[shuffle[t + i - 1]])
       end
       inputs:copy(inputs_tmp)
       targets:copy(targets_tmp)
@@ -43,13 +45,13 @@ local function minibatch_adam(model, criterion,
          return f, gradParameters
       end
      optim.adam(feval, parameters, config)
       c = c + 1
-      if c % 20 == 0 then
+      if c % 50 == 0 then
         collectgarbage()
      end
   end
-   xlua.progress(#train_x, #train_x)
+   xlua.progress(train_x:size(1), train_x:size(1))
 
    return { loss = sum_loss / count_loss}
 end
diff --git a/lib/settings.lua b/lib/settings.lua
index c086a0a..5d5cbde 100644
--- a/lib/settings.lua
+++ b/lib/settings.lua
@@ -32,11 +32,13 @@ cmd:option("-random_overlay_rate", 0.0, 'data augmentation using flipped image overlay (0.0-1.0)')
 cmd:option("-random_half_rate", 0.0, 'data augmentation using half resolution image (0.0-1.0)')
 cmd:option("-random_unsharp_mask_rate", 0.0, 'data augmentation using unsharp mask (0.0-1.0)')
 cmd:option("-scale", 2.0, 'scale factor (2)')
-cmd:option("-learning_rate", 0.00025, 'learning rate for adam')
+cmd:option("-learning_rate", 0.001, 'learning rate for adam')
 cmd:option("-crop_size", 46, 'crop size')
 cmd:option("-max_size", 256, 'if image is larger than max_size, image will be crop to max_size randomly')
-cmd:option("-batch_size", 8, 'mini batch size')
-cmd:option("-epoch", 200, 'number of total epochs to run')
+cmd:option("-batch_size", 32, 'mini batch size')
+cmd:option("-patches", 16, 'number of patch samples per image')
+cmd:option("-inner_epoch", 4, 'number of inner epochs per resampling')
+cmd:option("-epoch", 30, 'number of epochs to run')
 cmd:option("-thread", -1, 'number of CPU threads')
 cmd:option("-jpeg_chroma_subsampling_rate", 0.0, 'the rate of YUV 4:2:0/YUV 4:4:4 in denoising training (0.0-1.0)')
 cmd:option("-validation_rate", 0.05, 'validation-set rate (number_of_training_images * validation_rate > 1)')
diff --git a/train.lua b/train.lua
index ee15f5a..16a9dee 100644
--- a/train.lua
+++ b/train.lua
@@ -35,14 +35,14 @@ local function split_data(x, test_size)
    end
    return train_x, valid_x
 end
-local function make_validation_set(x, transformer, n, batch_size)
+local function make_validation_set(x, transformer, n, patches)
    n = n or 4
    local data = {}
    for i = 1, #x do
-      for k = 1, math.max(n / batch_size, 1) do
-         local xy = transformer(x[i], true, batch_size)
-         local tx = torch.Tensor(batch_size, xy[1][1]:size(1), xy[1][1]:size(2), xy[1][1]:size(3))
-         local ty = torch.Tensor(batch_size, xy[1][2]:size(1), xy[1][2]:size(2), xy[1][2]:size(3))
+      for k = 1, math.max(n / patches, 1) do
+         local xy = transformer(x[i], true, patches)
+         local tx = torch.Tensor(patches, xy[1][1]:size(1), xy[1][1]:size(2), xy[1][1]:size(3))
+         local ty = torch.Tensor(patches, xy[1][2]:size(1), xy[1][2]:size(2), xy[1][2]:size(3))
          for j = 1, #xy do
             tx[j]:copy(xy[j][1])
             ty[j]:copy(xy[j][2])
@@ -83,14 +83,15 @@ local function create_criterion(model)
 end
 local function transformer(x, is_validation, n, offset)
    x = compression.decompress(x)
-   n = n or settings.batch_size;
+   n = n or settings.patches
+   if is_validation == nil then is_validation = false end
    local random_color_noise_rate = nil
    local random_overlay_rate = nil
    local active_cropping_rate = nil
    local active_cropping_tries = nil
    if is_validation then
-      active_cropping_rate = 0.0
+      active_cropping_rate = 0
       active_cropping_tries = 0
       random_color_noise_rate = 0.0
       random_overlay_rate = 0.0
@@ -137,7 +138,24 @@ local function transformer(x, is_validation, n, offset)
    end
 end
 
+local function resampling(x, y, train_x, transformer)
+   print("## resampling")
+   for t = 1, #train_x do
+      xlua.progress(t, #train_x)
+      local xy = transformer(train_x[t], false, settings.patches)
+      for i = 1, #xy do
+         local index = (t - 1) * settings.patches + i
+         x[index]:copy(xy[i][1])
+         y[index]:copy(xy[i][2])
+      end
+      if t % 50 == 0 then
+         collectgarbage()
+      end
+   end
+end
+
 local function train()
+   local LR_MIN = 1.0e-5
    local model = srcnn.create(settings.method, settings.backend, settings.color)
    local offset = reconstruct.offset_size(model)
    local pairwise_func = function(x, is_validation, n)
@@ -145,12 +163,12 @@ local function train()
    end
    local criterion = create_criterion(model)
    local x = torch.load(settings.images)
-   local lrd_count = 0
    local train_x, valid_x = split_data(x, math.floor(settings.validation_rate * #x))
    local adam_config = {
       learningRate = settings.learning_rate,
       xBatchSize = settings.batch_size,
    }
+   local lrd_count = 0
    local ch = nil
    if settings.color == "y" then
       ch = 1
@@ -161,48 +179,53 @@ local function train()
    print("# make validation-set")
    local valid_xy = make_validation_set(valid_x, pairwise_func, settings.validation_crops,
-                                        settings.batch_size)
+                                        settings.patches)
    valid_x = nil
    collectgarbage()
    model:cuda()
    print("load .. " .. #train_x)
+
+   local x = torch.Tensor(settings.patches * #train_x,
+                          ch, settings.crop_size, settings.crop_size)
+   local y = torch.Tensor(settings.patches * #train_x,
+                          ch * (settings.crop_size - offset * 2) * (settings.crop_size - offset * 2)):zero()
+
    for epoch = 1, settings.epoch do
       model:training()
       print("# " .. epoch)
-      print(minibatch_adam(model, criterion, train_x, adam_config,
-                           pairwise_func,
-                           {ch, settings.crop_size, settings.crop_size},
-                           {ch, settings.crop_size - offset * 2, settings.crop_size - offset * 2}
-                ))
-      model:evaluate()
-      print("# validation")
-      local score = validate(model, criterion, valid_xy)
-      if score < best_score then
-         local test_image = image_loader.load_float(settings.test) -- reload
-         lrd_count = 0
-         best_score = score
-         print("* update best model")
-         torch.save(settings.model_file, model)
-         if settings.method == "noise" then
-            local log = path.join(settings.model_dir,
-                                  ("noise%d_best.png"):format(settings.noise_level))
-            save_test_jpeg(model, test_image, log)
-         elseif settings.method == "scale" then
-            local log = path.join(settings.model_dir,
-                                  ("scale%.1f_best.png"):format(settings.scale))
-            save_test_scale(model, test_image, log)
-         end
-      else
-         lrd_count = lrd_count + 1
-         if lrd_count > 5 then
-            lrd_count = 0
-            adam_config.learningRate = adam_config.learningRate * 0.9
-            print("* learning rate decay: " .. adam_config.learningRate)
-         end
-      end
-      print("current: " .. score .. ", best: " .. best_score)
-      collectgarbage()
+      resampling(x, y, train_x, pairwise_func)
+      for i = 1, settings.inner_epoch do
+         print(minibatch_adam(model, criterion, x, y, adam_config))
+         model:evaluate()
+         print("# validation")
+         local score = validate(model, criterion, valid_xy)
+         if score < best_score then
+            local test_image = image_loader.load_float(settings.test) -- reload
+            lrd_count = 0
+            best_score = score
+            print("* update best model")
+            torch.save(settings.model_file, model)
+            if settings.method == "noise" then
+               local log = path.join(settings.model_dir,
+                                     ("noise%d_best.png"):format(settings.noise_level))
+               save_test_jpeg(model, test_image, log)
+            elseif settings.method == "scale" then
+               local log = path.join(settings.model_dir,
+                                     ("scale%.1f_best.png"):format(settings.scale))
+               save_test_scale(model, test_image, log)
+            end
+         else
+            lrd_count = lrd_count + 1
+            if lrd_count > 2 and adam_config.learningRate > LR_MIN then
+               adam_config.learningRate = adam_config.learningRate * 0.8
+               print("* learning rate decay: " .. adam_config.learningRate)
+               lrd_count = 0
+            end
+         end
+         print("current: " .. score .. ", best: " .. best_score)
+         collectgarbage()
+      end
    end
 end
 if settings.gpu > 0 then
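
Notes (illustrative sketches, not part of the commit itself):

The heart of the change is that minibatch_adam no longer invokes the
transformer per image: it consumes a flat patch pool (train_x, train_y)
and shuffles patch indices, so every minibatch mixes patches from many
images. A minimal runnable sketch of that loop in plain Torch7 (pool
size, shapes, and names here are illustrative, not the patch's):

    require 'torch'

    local n, batch_size = 10, 4
    local pool = torch.randn(n, 3, 8, 8)  -- stands in for the patch pool
    local shuffle = torch.randperm(n)     -- fresh random order per pass

    for t = 1, n, batch_size do
       if t + batch_size - 1 > n then
          break                           -- drop the ragged tail batch
       end
       local batch = torch.Tensor(batch_size, 3, 8, 8)
       for i = 1, batch_size do
          -- gather rows in shuffled order, as minibatch_adam now does
          batch[i]:copy(pool[shuffle[t + i - 1]])
       end
       -- feval/optim.adam would consume `batch` here
    end

Under the old loop a gradient step only ever saw patches cropped from a
single image; shuffling the pooled rows is what decorrelates minibatches.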
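The schedule now decays the learning rate by a factor of 0.8 after three
consecutive validations without improvement (lrd_count > 2) and stops at
the LR_MIN floor of 1e-5, replacing the old patience-5, factor-0.9 rule
that had no floor. A standalone sketch of the new rule in plain Lua
(function and variable names are illustrative):

    local LR_MIN = 1.0e-5
    local lr = 0.001                 -- new default learning rate for adam
    local lrd_count = 0

    local function after_validation(improved)
       if improved then
          lrd_count = 0              -- progress: reset the patience counter
       else
          lrd_count = lrd_count + 1
          if lrd_count > 2 and lr > LR_MIN then
             lr = lr * 0.8           -- decay, but never below LR_MIN
             lrd_count = 0
          end
       end
       return lr
    end

    for i = 1, 6 do print(after_validation(false)) end
    -- prints: 0.001, 0.001, 0.0008, 0.0008, 0.0008, 0.00064

With only 30 outer epochs of 4 inner passes each, the tighter patience
reacts to plateaus much sooner than the old 200-epoch schedule could.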
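resampling() lays the pool out with the flat index
(t - 1) * settings.patches + i, so image t owns a contiguous block of
rows that the shuffle above then breaks apart. A tiny plain-Lua demo of
that index arithmetic (the variable names are stand-ins for
settings.patches and #train_x):

    local patches, n_images = 4, 3
    for t = 1, n_images do
       for i = 1, patches do
          local row = (t - 1) * patches + i
          io.write(("image %d, patch %d -> row %d\n"):format(t, i, row))
       end
    end

With the new defaults, one outer epoch draws 16 patches per training
image, runs 4 inner epochs of shuffled Adam over that pool, then
resamples, e.g. th train.lua -patches 16 -inner_epoch 4 -batch_size 32
-epoch 30 together with the same method and image-list flags the script
already required before this patch.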