require 'cutorch'
require 'cunn'
require 'optim'
require 'xlua'
require 'pl'
-- `image.save` is called below; load the package explicitly instead of
-- relying on another module having installed the `image` global.
require 'image'

local settings = require './lib/settings'
local minibatch_sgd = require './lib/minibatch_sgd'
local iproc = require './lib/iproc'
local create_model = require './lib/srcnn'
local reconstract = require './lib/reconstract'
local pairwise_transform = require './lib/pairwise_transform'
local image_loader = require './lib/image_loader'

--- Upscale `rgb` by `settings.scale`, run it through `model` and save it.
-- The image is first resized with `iproc.scale`, then refined by
-- `reconstract` using `settings.block_offset`.
-- @param model network being trained
-- @param rgb test image tensor (size(2)/size(3) are used as height/width —
--   presumably a (channels, height, width) layout; TODO confirm)
-- @param file output image path
local function save_test_scale(model, rgb, file)
   local input = iproc.scale(rgb,
                             rgb:size(3) * settings.scale,
                             rgb:size(2) * settings.scale)
   local up = reconstract(model, input, settings.block_offset)
   image.save(file, up)
end

--- Run `rgb` through `model` (via `reconstract`) and save the result.
-- @param model network being trained
-- @param rgb test image tensor
-- @param file output image path
local function save_test_jpeg(model, rgb, file)
   local im = reconstract(model, rgb, settings.block_offset)
   image.save(file, im)
end

--- Randomly split a sequence into a training part and a validation part.
-- @param x sequence of samples
-- @param test_size number of samples placed in the validation part
-- @return train_x (the first #x - test_size shuffled samples), valid_x (the rest)
local function split_data(x, test_size)
   local index = torch.randperm(#x)
   local train_size = #x - test_size
   local train_x = {}
   local valid_x = {}
   for i = 1, train_size do
      train_x[i] = x[index[i]]
   end
   for i = 1, test_size do
      valid_x[i] = x[index[train_size + i]]
   end
   return train_x, valid_x
end

--- Build a fixed validation set by sampling `n` pairs per source sample.
-- Each pair is produced by `transformer(sample, true)` and reshaped to a
-- 4-D batch of one so it can be fed to the model directly.
-- @param x sequence of source samples
-- @param transformer function(sample, is_validation) -> input, target
-- @param n pairs per sample (default 4)
-- @return sequence of { x = input, y = target }
local function make_validation_set(x, transformer, n)
   n = n or 4
   local data = {}
   for i = 1, #x do
      for _ = 1, n do
         local input, target = transformer(x[i], true)
         data[#data + 1] = {
            x = input:reshape(1, input:size(1), input:size(2), input:size(3)),
            y = target:reshape(1, target:size(1), target:size(2), target:size(3)),
         }
      end
      xlua.progress(i, #x)
      collectgarbage()
   end
   return data
end

--- Mean criterion loss of `model` over a validation set.
-- Inputs and targets are moved to the GPU with :cuda() per entry.
-- Note: returns NaN (0/0) for an empty set; `train` guards against that.
-- @return average loss
local function validate(model, criterion, data)
   local loss = 0
   for i = 1, #data do
      local z = model:forward(data[i].x:cuda())
      loss = loss + criterion:forward(z, data[i].y:cuda())
      xlua.progress(i, #data)
      if i % 10 == 0 then
         collectgarbage()
      end
   end
   return loss / #data
end

--- Train the model, validating after every epoch.
-- Whenever the validation loss improves, the model is saved to
-- `settings.model_file` and a test output image is written to
-- `settings.model_dir`.
local function train()
   local model, offset = create_model()
   assert(offset == settings.block_offset)
   local criterion = nn.MSECriterion():cuda()
   local x = torch.load(settings.images)
   local train_x, valid_x = split_data(x, math.floor(settings.validation_ratio * #x))
   local test = image_loader.load_float(settings.test)
   local adam_config = {
      learningRate = settings.learning_rate,
      xBatchSize = settings.batch_size,
   }
   local denoise_model = nil
   if settings.method == "scale" and path.exists(settings.denoise_model_file) then
      denoise_model = torch.load(settings.denoise_model_file)
   end
   -- Produces an (input, target) training pair from one source sample.
   -- Augmentation is disabled when `is_validation` is true.
   local transformer = function(src, is_validation)
      if is_validation == nil then
         is_validation = false
      end
      if settings.method == "scale" then
         -- Forward the optionally loaded denoise model; it was previously
         -- loaded from disk and then dropped.
         return pairwise_transform.scale(src, settings.scale, settings.crop_size, offset,
                                         {color_augment = not is_validation,
                                          noise = false,
                                          denoise_model = denoise_model})
      elseif settings.method == "noise" then
         return pairwise_transform.jpeg(src, settings.noise_level, settings.crop_size,
                                        offset, not is_validation)
      else
         error("unknown method: " .. tostring(settings.method))
      end
   end
   -- math.huge guarantees the first finite validation score is saved.
   local best_score = math.huge
   print("# make validation-set")
   -- settings.validation_crops used to be passed to split_data, which ignores
   -- it; it belongs here. Fall back to the previous hard-coded 20.
   local valid_xy = make_validation_set(valid_x, transformer, settings.validation_crops or 20)
   -- An empty set makes validate() return NaN, so no model would ever be saved.
   assert(#valid_xy > 0,
          "validation set is empty; increase validation_ratio or add more images")
   valid_x = nil
   collectgarbage()
   model:cuda()
   print("load .. " .. #train_x)
   for epoch = 1, settings.epoch do
      model:training()
      print("# " .. epoch)
      print(minibatch_sgd(model, criterion, train_x, adam_config,
                          transformer,
                          {1, settings.crop_size, settings.crop_size},
                          {1, settings.crop_size - offset * 2, settings.crop_size - offset * 2}))
      collectgarbage()
      model:evaluate()
      print("# validation")
      local score = validate(model, criterion, valid_xy)
      if score < best_score then
         best_score = score
         print("* update best model")
         torch.save(settings.model_file, model)
         if settings.method == "noise" then
            local log = path.join(settings.model_dir,
                                  ("noise%d_best.png"):format(settings.noise_level))
            save_test_jpeg(model, test, log)
         elseif settings.method == "scale" then
            local log = path.join(settings.model_dir,
                                  ("scale%.1f_best.png"):format(settings.scale))
            save_test_scale(model, test, log)
         end
      end
      print("current: " .. score .. ", best: " .. best_score)
   end
end

torch.manualSeed(settings.seed)
cutorch.manualSeed(settings.seed)
print(settings)
train()