
update training script

nagadomi 2015-05-17 14:42:53 +09:00
parent 3b1a71883d
commit 2231423056
7 changed files with 65 additions and 18 deletions

.gitignore

@@ -1,3 +1,4 @@
 *~
 cache/*.png
 models/*.png
+waifu2x.log

convert_data.lua (new file)

@@ -0,0 +1,41 @@
+require 'torch'
+local settings = require './lib/settings'
+local image_loader = require './lib/image_loader'
+
+local function count_lines(file)
+   local fp = io.open(file, "r")
+   local count = 0
+   for line in fp:lines() do
+      count = count + 1
+   end
+   fp:close()
+   return count
+end
+
+local function load_images(list)
+   local count = count_lines(list)
+   local fp = io.open(list, "r")
+   local x = {}
+   local c = 0
+   for line in fp:lines() do
+      local im = image_loader.load_byte(line)
+      if im then
+         if im:size(2) > settings.crop_size * 2 and im:size(3) > settings.crop_size * 2 then
+            table.insert(x, im)
+         end
+      else
+         print("error:" .. line)
+      end
+      c = c + 1
+      xlua.progress(c, count)
+      if c % 10 == 0 then
+         collectgarbage()
+      end
+   end
+   return x
+end
+
+print(settings)
+local x = load_images(settings.image_list)
+torch.save(settings.images, x)
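A minimal usage sketch (not part of this commit): images.t7 holds a plain Lua table of image tensors, so the packed dataset can be loaded back with torch.load, assuming the same settings module shown in this diff:

require 'torch'
local settings = require './lib/settings'
-- images.t7 is the table of image tensors saved by convert_data.lua above
local train_x = torch.load(settings.images)
print(#train_x .. " images loaded")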

lib/minibatch_sgd.lua → lib/minibatch_adam.lua (renamed)

@@ -2,10 +2,10 @@ require 'optim'
 require 'cutorch'
 require 'xlua'
-local function minibatch_sgd(model, criterion,
-                             train_x,
-                             config, transformer,
-                             input_size, target_size)
+local function minibatch_adam(model, criterion,
+                              train_x,
+                              config, transformer,
+                              input_size, target_size)
    local parameters, gradParameters = model:getParameters()
    config = config or {}
    local sum_loss = 0
@@ -47,7 +47,6 @@ local function minibatch_sgd(model, criterion,
          model:backward(inputs, criterion:backward(output, targets))
          return f, gradParameters
       end
-      -- must use Adam!!
       optim.adam(feval, parameters, config)
       c = c + 1
@@ -60,4 +59,4 @@ local function minibatch_sgd(model, criterion,
    return { mse = sum_loss / count_loss}
 end
-return minibatch_sgd
+return minibatch_adam
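For context, a minimal standalone sketch of the optim.adam call pattern this helper wraps (illustrative only, not code from the repository): feval returns the loss and the gradient for the current parameters, and optim.adam updates the parameter tensor in place using the config table.

require 'torch'
require 'optim'
-- toy problem: minimize (x - 3)^2 with the same feval/config pattern as minibatch_adam
local x = torch.zeros(1)
local config = { learningRate = 0.1 }
local function feval(params)
   local loss = (params[1] - 3)^2
   local grad = torch.Tensor({ 2 * (params[1] - 3) })
   return loss, grad
end
for i = 1, 200 do
   optim.adam(feval, x, config)
end
print(x[1]) -- converges toward 3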

lib/pairwise_transform.lua

@@ -6,7 +6,7 @@ local pairwise_transform = {}
 function pairwise_transform.scale(src, scale, size, offset, options)
    options = options or {}
-   local yi = torch.radom(0, src:size(2) - size - 1)
+   local yi = torch.random(0, src:size(2) - size - 1)
    local xi = torch.random(0, src:size(3) - size - 1)
    local down_scale = 1.0 / scale
    local y = image.crop(src, xi, yi, xi + size, yi + size)

lib/settings.lua

@@ -51,7 +51,7 @@ torch.setnumthreads(settings.core)
 settings.images = string.format("%s/images.t7", settings.data_dir)
 settings.image_list = string.format("%s/image_list.txt", settings.data_dir)
-settings.validation_ratio = 01
+settings.validation_ratio = 0.1
 settings.validation_crops = 40
 settings.block_offset = 7 -- see srcnn.lua

train.lua

@@ -5,7 +5,7 @@ require 'xlua'
 require 'pl'
 local settings = require './lib/settings'
-local minibatch_sgd = require './lib/minibatch_sgd'
+local minibatch_adam = require './lib/minibatch_adam'
 local iproc = require './lib/iproc'
 local create_model = require './lib/srcnn'
 local reconstract, reconstract_ch = require './lib/reconstract'
@@ -77,10 +77,6 @@ local function train()
       learningRate = settings.learning_rate,
       xBatchSize = settings.batch_size,
    }
-   local denoise_model = nil
-   if settings.method == "scale" and path.exists(settings.denoise_model_file) then
-      denoise_model = torch.load(settings.denoise_model_file)
-   end
    local transformer = function(x, is_validation)
       if is_validation == nil then is_validation = false end
       if settings.method == "scale" then
@@ -109,11 +105,11 @@ local function train()
    for epoch = 1, settings.epoch do
       model:training()
       print("# " .. epoch)
-      print(minibatch_sgd(model, criterion, train_x, adam_config,
-                          transformer,
-                          {1, settings.crop_size, settings.crop_size},
-                          {1, settings.crop_size - offset * 2, settings.crop_size - offset * 2}
-                          ))
+      print(minibatch_adam(model, criterion, train_x, adam_config,
+                           transformer,
+                           {1, settings.crop_size, settings.crop_size},
+                           {1, settings.crop_size - offset * 2, settings.crop_size - offset * 2}
+                           ))
       if epoch % 1 == 0 then
          collectgarbage()
          model:evaluate()

train.sh (new file, executable)

@@ -0,0 +1,10 @@
+#!/bin/sh
+th train.lua -method noise -noise_level 1 -test images/miku_noise.png
+th cleanup_model.lua -model models/noise1_model.t7 -oformat ascii
+th train.lua -method noise -noise_level 2 -test images/miku_noise.png
+th cleanup_model.lua -model models/noise2_model.t7 -oformat ascii
+th train.lua -method scale -scale 2 -test images/miku_small.png
+th cleanup_model.lua -model models/scale2.0x_model.t7 -oformat ascii