From 22314230562dd534964d718fdadd5bc4039c58bc Mon Sep 17 00:00:00 2001 From: nagadomi Date: Sun, 17 May 2015 14:42:53 +0900 Subject: [PATCH] update training script --- .gitignore | 1 + convert_data.lua | 41 +++++++++++++++++++ lib/{minibatch_sgd.lua => minibatch_adam.lua} | 11 +++-- lib/pairwise_transform.lua | 2 +- lib/settings.lua | 2 +- train.lua | 16 +++----- train.sh | 10 +++++ 7 files changed, 65 insertions(+), 18 deletions(-) create mode 100644 convert_data.lua rename lib/{minibatch_sgd.lua => minibatch_adam.lua} (90%) create mode 100755 train.sh diff --git a/.gitignore b/.gitignore index 9c418f8..25859df 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ *~ cache/*.png +models/*.png waifu2x.log diff --git a/convert_data.lua b/convert_data.lua new file mode 100644 index 0000000..b81869a --- /dev/null +++ b/convert_data.lua @@ -0,0 +1,41 @@ +require 'torch' +local settings = require './lib/settings' +local image_loader = require './lib/image_loader' + +local function count_lines(file) + local fp = io.open(file, "r") + local count = 0 + for line in fp:lines() do + count = count + 1 + end + fp:close() + + return count +end + +local function load_images(list) + local count = count_lines(list) + local fp = io.open(list, "r") + local x = {} + local c = 0 + for line in fp:lines() do + local im = image_loader.load_byte(line) + if im then + if im:size(2) > settings.crop_size * 2 and im:size(3) > settings.crop_size * 2 then + table.insert(x, im) + end + else + print("error:" .. line) + end + c = c + 1 + xlua.progress(c, count) + if c % 10 == 0 then + collectgarbage() + end + end + return x +end +print(settings) +local x = load_images(settings.image_list) +torch.save(settings.images, x) + diff --git a/lib/minibatch_sgd.lua b/lib/minibatch_adam.lua similarity index 90% rename from lib/minibatch_sgd.lua rename to lib/minibatch_adam.lua index b5dad8f..49e16eb 100644 --- a/lib/minibatch_sgd.lua +++ b/lib/minibatch_adam.lua @@ -2,10 +2,10 @@ require 'optim' require 'cutorch' require 'xlua' -local function minibatch_sgd(model, criterion, - train_x, - config, transformer, - input_size, target_size) +local function minibatch_adam(model, criterion, + train_x, + config, transformer, + input_size, target_size) local parameters, gradParameters = model:getParameters() config = config or {} local sum_loss = 0 @@ -47,7 +47,6 @@ local function minibatch_sgd(model, criterion, model:backward(inputs, criterion:backward(output, targets)) return f, gradParameters end - -- must use Adam!! optim.adam(feval, parameters, config) c = c + 1 @@ -60,4 +59,4 @@ local function minibatch_sgd(model, criterion, return { mse = sum_loss / count_loss} end -return minibatch_sgd +return minibatch_adam diff --git a/lib/pairwise_transform.lua b/lib/pairwise_transform.lua index 24de37f..e3c41c3 100644 --- a/lib/pairwise_transform.lua +++ b/lib/pairwise_transform.lua @@ -6,7 +6,7 @@ local pairwise_transform = {} function pairwise_transform.scale(src, scale, size, offset, options) options = options or {} - local yi = torch.radom(0, src:size(2) - size - 1) + local yi = torch.random(0, src:size(2) - size - 1) local xi = torch.random(0, src:size(3) - size - 1) local down_scale = 1.0 / scale local y = image.crop(src, xi, yi, xi + size, yi + size) diff --git a/lib/settings.lua b/lib/settings.lua index 1678236..2d6522c 100644 --- a/lib/settings.lua +++ b/lib/settings.lua @@ -51,7 +51,7 @@ torch.setnumthreads(settings.core) settings.images = string.format("%s/images.t7", settings.data_dir) settings.image_list = string.format("%s/image_list.txt", settings.data_dir) -settings.validation_ratio = 01 +settings.validation_ratio = 0.1 settings.validation_crops = 40 settings.block_offset = 7 -- see srcnn.lua diff --git a/train.lua b/train.lua index 120119d..95f9f28 100644 --- a/train.lua +++ b/train.lua @@ -5,7 +5,7 @@ require 'xlua' require 'pl' local settings = require './lib/settings' -local minibatch_sgd = require './lib/minibatch_sgd' +local minibatch_adam = require './lib/minibatch_adam' local iproc = require './lib/iproc' local create_model = require './lib/srcnn' local reconstract, reconstract_ch = require './lib/reconstract' @@ -77,10 +77,6 @@ local function train() learningRate = settings.learning_rate, xBatchSize = settings.batch_size, } - local denoise_model = nil - if settings.method == "scale" and path.exists(settings.denoise_model_file) then - denoise_model = torch.load(settings.denoise_model_file) - end local transformer = function(x, is_validation) if is_validation == nil then is_validation = false end if settings.method == "scale" then @@ -109,11 +105,11 @@ local function train() for epoch = 1, settings.epoch do model:training() print("# " .. epoch) - print(minibatch_sgd(model, criterion, train_x, adam_config, - transformer, - {1, settings.crop_size, settings.crop_size}, - {1, settings.crop_size - offset * 2, settings.crop_size - offset * 2} - )) + print(minibatch_adam(model, criterion, train_x, adam_config, + transformer, + {1, settings.crop_size, settings.crop_size}, + {1, settings.crop_size - offset * 2, settings.crop_size - offset * 2} + )) if epoch % 1 == 0 then collectgarbage() model:evaluate() diff --git a/train.sh b/train.sh new file mode 100755 index 0000000..f6cf254 --- /dev/null +++ b/train.sh @@ -0,0 +1,10 @@ +#!/bin/sh + +th train.lua -method noise -noise_level 1 -test images/miku_noise.png +th cleanup_model.lua -model models/noise1_model.t7 -oformat ascii + +th train.lua -method noise -noise_level 2 -test images/miku_noise.png +th cleanup_model.lua -model models/noise2_model.t7 -oformat ascii + +th train.lua -method scale -scale 2 -test images/miku_small.png +th cleanup_model.lua -model models/scale2.0x_model.t7 -oformat ascii