update training script
This commit is contained in:
parent
3b1a71883d
commit
2231423056
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -1,3 +1,4 @@
|
|||
*~
|
||||
cache/*.png
|
||||
models/*.png
|
||||
waifu2x.log
|
||||
|
|
41
convert_data.lua
Normal file
41
convert_data.lua
Normal file
|
@ -0,0 +1,41 @@
|
|||
require 'torch'
|
||||
local settings = require './lib/settings'
|
||||
local image_loader = require './lib/image_loader'
|
||||
|
||||
local function count_lines(file)
|
||||
local fp = io.open(file, "r")
|
||||
local count = 0
|
||||
for line in fp:lines() do
|
||||
count = count + 1
|
||||
end
|
||||
fp:close()
|
||||
|
||||
return count
|
||||
end
|
||||
|
||||
local function load_images(list)
|
||||
local count = count_lines(list)
|
||||
local fp = io.open(list, "r")
|
||||
local x = {}
|
||||
local c = 0
|
||||
for line in fp:lines() do
|
||||
local im = image_loader.load_byte(line)
|
||||
if im then
|
||||
if im:size(2) > settings.crop_size * 2 and im:size(3) > settings.crop_size * 2 then
|
||||
table.insert(x, im)
|
||||
end
|
||||
else
|
||||
print("error:" .. line)
|
||||
end
|
||||
c = c + 1
|
||||
xlua.progress(c, count)
|
||||
if c % 10 == 0 then
|
||||
collectgarbage()
|
||||
end
|
||||
end
|
||||
return x
|
||||
end
|
||||
print(settings)
|
||||
local x = load_images(settings.image_list)
|
||||
torch.save(settings.images, x)
|
||||
|
|
@ -2,10 +2,10 @@ require 'optim'
|
|||
require 'cutorch'
|
||||
require 'xlua'
|
||||
|
||||
local function minibatch_sgd(model, criterion,
|
||||
train_x,
|
||||
config, transformer,
|
||||
input_size, target_size)
|
||||
local function minibatch_adam(model, criterion,
|
||||
train_x,
|
||||
config, transformer,
|
||||
input_size, target_size)
|
||||
local parameters, gradParameters = model:getParameters()
|
||||
config = config or {}
|
||||
local sum_loss = 0
|
||||
|
@ -47,7 +47,6 @@ local function minibatch_sgd(model, criterion,
|
|||
model:backward(inputs, criterion:backward(output, targets))
|
||||
return f, gradParameters
|
||||
end
|
||||
-- must use Adam!!
|
||||
optim.adam(feval, parameters, config)
|
||||
|
||||
c = c + 1
|
||||
|
@ -60,4 +59,4 @@ local function minibatch_sgd(model, criterion,
|
|||
return { mse = sum_loss / count_loss}
|
||||
end
|
||||
|
||||
return minibatch_sgd
|
||||
return minibatch_adam
|
|
@ -6,7 +6,7 @@ local pairwise_transform = {}
|
|||
|
||||
function pairwise_transform.scale(src, scale, size, offset, options)
|
||||
options = options or {}
|
||||
local yi = torch.radom(0, src:size(2) - size - 1)
|
||||
local yi = torch.random(0, src:size(2) - size - 1)
|
||||
local xi = torch.random(0, src:size(3) - size - 1)
|
||||
local down_scale = 1.0 / scale
|
||||
local y = image.crop(src, xi, yi, xi + size, yi + size)
|
||||
|
|
|
@ -51,7 +51,7 @@ torch.setnumthreads(settings.core)
|
|||
settings.images = string.format("%s/images.t7", settings.data_dir)
|
||||
settings.image_list = string.format("%s/image_list.txt", settings.data_dir)
|
||||
|
||||
settings.validation_ratio = 01
|
||||
settings.validation_ratio = 0.1
|
||||
settings.validation_crops = 40
|
||||
settings.block_offset = 7 -- see srcnn.lua
|
||||
|
||||
|
|
16
train.lua
16
train.lua
|
@ -5,7 +5,7 @@ require 'xlua'
|
|||
require 'pl'
|
||||
|
||||
local settings = require './lib/settings'
|
||||
local minibatch_sgd = require './lib/minibatch_sgd'
|
||||
local minibatch_adam = require './lib/minibatch_adam'
|
||||
local iproc = require './lib/iproc'
|
||||
local create_model = require './lib/srcnn'
|
||||
local reconstract, reconstract_ch = require './lib/reconstract'
|
||||
|
@ -77,10 +77,6 @@ local function train()
|
|||
learningRate = settings.learning_rate,
|
||||
xBatchSize = settings.batch_size,
|
||||
}
|
||||
local denoise_model = nil
|
||||
if settings.method == "scale" and path.exists(settings.denoise_model_file) then
|
||||
denoise_model = torch.load(settings.denoise_model_file)
|
||||
end
|
||||
local transformer = function(x, is_validation)
|
||||
if is_validation == nil then is_validation = false end
|
||||
if settings.method == "scale" then
|
||||
|
@ -109,11 +105,11 @@ local function train()
|
|||
for epoch = 1, settings.epoch do
|
||||
model:training()
|
||||
print("# " .. epoch)
|
||||
print(minibatch_sgd(model, criterion, train_x, adam_config,
|
||||
transformer,
|
||||
{1, settings.crop_size, settings.crop_size},
|
||||
{1, settings.crop_size - offset * 2, settings.crop_size - offset * 2}
|
||||
))
|
||||
print(minibatch_adam(model, criterion, train_x, adam_config,
|
||||
transformer,
|
||||
{1, settings.crop_size, settings.crop_size},
|
||||
{1, settings.crop_size - offset * 2, settings.crop_size - offset * 2}
|
||||
))
|
||||
if epoch % 1 == 0 then
|
||||
collectgarbage()
|
||||
model:evaluate()
|
||||
|
|
10
train.sh
Executable file
10
train.sh
Executable file
|
@ -0,0 +1,10 @@
|
|||
#!/bin/sh
|
||||
|
||||
th train.lua -method noise -noise_level 1 -test images/miku_noise.png
|
||||
th cleanup_model.lua -model models/noise1_model.t7 -oformat ascii
|
||||
|
||||
th train.lua -method noise -noise_level 2 -test images/miku_noise.png
|
||||
th cleanup_model.lua -model models/noise2_model.t7 -oformat ascii
|
||||
|
||||
th train.lua -method scale -scale 2 -test images/miku_small.png
|
||||
th cleanup_model.lua -model models/scale2.0x_model.t7 -oformat ascii
|
Loading…
Reference in a new issue