1
0
Fork 0
mirror of synced 2024-05-17 03:12:18 +12:00
This commit is contained in:
nagadomi 2018-01-26 08:09:40 +09:00
commit fa6ee00624
14 changed files with 516 additions and 140 deletions

View file

@ -63,7 +63,24 @@ local function crop_if_large_pair(x, y, max_size)
return x, y
end
end
-- Pad image `x` by `pad` pixels on every side using iproc.padding.
-- Returns `x` untouched when `pad` is not positive.
local function padding_x(x, pad)
   if not (pad > 0) then
      return x
   end
   x = iproc.padding(x, pad, pad, pad, pad)
   return x
end
-- Pad an (x, y) training pair.
-- `x` is padded by `pad` pixels per side with iproc.padding. `y` is padded by
-- `pad * scale`, where scale is the ratio of the second dimension of `y` to
-- that of `x`. When `y_zero` is truthy `y` is padded with iproc.zero_padding,
-- otherwise with iproc.padding. Nothing is padded when `pad` is not positive.
-- Returns x, y.
local function padding_xy(x, y, pad, y_zero)
   local scale = y:size(2) / x:size(2)
   if not (pad > 0) then
      return x, y
   end
   x = iproc.padding(x, pad, pad, pad, pad)
   local y_pad = pad * scale
   if y_zero then
      y = iproc.zero_padding(y, y_pad, y_pad, y_pad, y_pad)
   else
      y = iproc.padding(y, y_pad, y_pad, y_pad, y_pad)
   end
   return x, y
end
local function load_images(list)
local MARGIN = 32
local csv = csvigo.load({path = list, verbose = false, mode = "raw"})
@ -78,6 +95,7 @@ local function load_images(list)
if csv_meta and csv_meta.filters then
filters = csv_meta.filters
end
local basename_y = path.basename(filename)
local im, meta = image_loader.load_byte(filename)
local skip = false
local alpha_color = torch.random(0, 1)
@ -100,25 +118,38 @@ local function load_images(list)
-- method == user
local yy = im
local xx, meta2 = image_loader.load_byte(csv_meta.x)
if settings.invert_x then
xx = (-(xx:long()) + 255):byte()
end
if xx then
if meta2 and meta2.alpha then
xx = alpha_util.fill(xx, meta2.alpha, alpha_color)
end
xx, yy = crop_if_large_pair(xx, yy, settings.max_training_image_size)
xx, yy = padding_xy(xx, yy, settings.padding, settings.padding_y_zero)
if settings.grayscale then
xx = iproc.rgb2y(xx)
yy = iproc.rgb2y(yy)
end
table.insert(x, {{y = compression.compress(yy), x = compression.compress(xx)},
{data = {filters = filters, has_x = true}}})
{data = {filters = filters, has_x = true, basename = basename_y}}})
else
io.stderr:write(string.format("\n%s: skip: load error.\n", csv_meta.x))
end
else
im = crop_if_large(im, settings.max_training_image_size)
im = iproc.crop_mod4(im)
im = padding_x(im, settings.padding)
local scale = 1.0
if settings.random_half_rate > 0.0 then
scale = 2.0
end
if im:size(2) > (settings.crop_size * scale + MARGIN) and im:size(3) > (settings.crop_size * scale + MARGIN) then
table.insert(x, {compression.compress(im), {data = {filters = filters}}})
if settings.grayscale then
im = iproc.rgb2y(im)
end
table.insert(x, {compression.compress(im), {data = {filters = filters, basename = basename_y}}})
else
io.stderr:write(string.format("\n%s: skip: image is too small (%d > size).\n", filename, settings.crop_size * scale + MARGIN))
end

42
lib/ShakeShakeTable.lua Normal file
View file

@ -0,0 +1,42 @@
local ShakeShakeTable, parent = torch.class('w2nn.ShakeShakeTable','nn.Module')

-- Create the module with four empty work tensors:
--   alpha / beta   : per-sample random coefficients (filled in updateOutput)
--   first / second : scratch buffers reused by the forward and backward passes
-- The module starts in training mode.
function ShakeShakeTable:__init()
   parent.__init(self)
   for _, field in ipairs({"alpha", "beta", "first", "second"}) do
      self[field] = torch.Tensor()
   end
   self.train = true
end
-- Forward pass over a table of two tensors, {input[1], input[2]}.
-- Training: draws a uniform alpha (and a beta, kept for the backward pass)
-- per sample along the first dimension and outputs
--   2 * ((1 - alpha[i]) * input[1][i] + alpha[i] * input[2][i]).
-- Evaluation: outputs input[1] + input[2].
function ShakeShakeTable:updateOutput(input)
   local branch1, branch2 = input[1], input[2]
   local batch_size = branch1:size(1)
   if not self.train then
      self.output:resizeAs(branch1):copy(branch1):add(branch2)
      return self.output
   end
   self.alpha:resize(batch_size):uniform()
   self.beta:resize(batch_size):uniform()
   self.second:resizeAs(branch1):copy(branch2)
   self.output:resizeAs(branch1):copy(branch1)
   for i = 1, batch_size do
      local a = self.alpha[i]
      self.second[i]:mul(a)
      self.output[i]:mul(1.0 - a)
   end
   self.output:add(self.second):mul(2)
   return self.output
end
-- Backward pass: splits `gradOutput` between the two input branches using the
-- per-sample `beta` coefficients drawn in the last training-mode updateOutput:
--   gradInput[1][i] = beta[i] * gradOutput[i]
--   gradInput[2][i] = (1 - beta[i]) * gradOutput[i]
-- The result is stored in `self.gradInput` (not `self.gradOutput`) because
-- nn.Module:backward returns `self.gradInput`; storing it elsewhere makes
-- containers propagate the module's initial empty tensor instead.
-- NOTE(review): the forward pass multiplies by 2 while the backward pass does
-- not; left unchanged here -- confirm whether that asymmetry is intended.
-- Returns the table {gradInput[1], gradInput[2]}.
function ShakeShakeTable:updateGradInput(input, gradOutput)
   local batch_size = input[1]:size(1)
   self.first:resizeAs(gradOutput):copy(gradOutput)
   self.second:resizeAs(gradOutput):copy(gradOutput)
   for i = 1, batch_size do
      local b = self.beta[i]
      self.first[i]:mul(b)
      self.second[i]:mul(1.0 - b)
   end
   self.gradInput = {self.first, self.second}
   return self.gradInput
end

View file

@ -102,7 +102,9 @@ function data_augmentation.pairwise_scale(x, y, p, scale_min, scale_max)
local scale = torch.uniform(scale_min, scale_max)
local h = math.floor(x:size(2) * scale)
local w = math.floor(x:size(3) * scale)
x = iproc.scale(x, w, h, "Triangle")
local filters = {"Lanczos", "Catrom"}
local x_filter = filters[torch.random(1, 2)]
x = iproc.scale(x, w, h, x_filter)
y = iproc.scale(y, w, h, "Triangle")
return x, y
else
@ -139,6 +141,36 @@ function data_augmentation.pairwise_negate_x(x, y, p)
return x, y
end
end
-- Apply the same random transpose/flip to an (x, y) image pair.
-- Two random draws are made, in this order: the flip mode (1..4) and the
-- transpose mode (1..2).
--   transpose mode 2 swaps dimensions 2 and 3 of both images
--   flip mode 1: hflip, 2: vflip, 3: vflip then hflip, 4: no flip
-- Both images go through iproc.byte2float first; if that reports a conversion
-- for `x`, both results are converted back with iproc.float2byte.
-- Returns x, y.
function data_augmentation.pairwise_flip(x, y)
   local flip_mode = torch.random(1, 4)
   local transpose_mode = torch.random(1, 2)
   local was_converted
   x, was_converted = iproc.byte2float(x)
   y = iproc.byte2float(y)
   x = x:contiguous()
   y = y:contiguous()
   if transpose_mode == 2 then
      x = x:transpose(2, 3):contiguous()
      y = y:transpose(2, 3):contiguous()
   end
   if flip_mode == 1 then
      x = iproc.hflip(x)
      y = iproc.hflip(y)
   elseif flip_mode == 2 then
      x = iproc.vflip(x)
      y = iproc.vflip(y)
   elseif flip_mode == 3 then
      x = iproc.hflip(iproc.vflip(x))
      y = iproc.hflip(iproc.vflip(y))
   end
   if was_converted then
      x = iproc.float2byte(x)
      y = iproc.float2byte(y)
   end
   return x, y
end
function data_augmentation.shift_1px(src)
-- reducing the even/odd issue in nearest neighbor scaler.
local direction = torch.random(1, 4)

View file

@ -80,6 +80,8 @@ function iproc.scale_with_gamma22(src, width, height, filter, blur)
return dest
end
function iproc.padding(img, w1, w2, h1, h2)
local conversion
img, conversion = iproc.byte2float(img)
image = image or require 'image'
local dst_height = img:size(2) + h1 + h2
local dst_width = img:size(3) + w1 + w2
@ -88,9 +90,15 @@ function iproc.padding(img, w1, w2, h1, h2)
flow[2] = torch.ger(torch.ones(dst_height), torch.linspace(0, dst_width - 1, dst_width))
flow[1]:add(-h1)
flow[2]:add(-w1)
return image.warp(img, flow, "simple", false, "clamp")
local dest = image.warp(img, flow, "simple", false, "clamp")
if conversion then
dest = iproc.float2byte(dest)
end
return dest
end
function iproc.zero_padding(img, w1, w2, h1, h2)
local conversion
img, conversion = iproc.byte2float(img)
image = image or require 'image'
local dst_height = img:size(2) + h1 + h2
local dst_width = img:size(3) + w1 + w2
@ -99,7 +107,11 @@ function iproc.zero_padding(img, w1, w2, h1, h2)
flow[2] = torch.ger(torch.ones(dst_height), torch.linspace(0, dst_width - 1, dst_width))
flow[1]:add(-h1)
flow[2]:add(-w1)
return image.warp(img, flow, "simple", false, "pad", 0)
local dest = image.warp(img, flow, "simple", false, "pad", 0)
if conversion then
dest = iproc.float2byte(dest)
end
return dest
end
function iproc.white_noise(src, std, rgb_weights, gamma)
gamma = gamma or 0.454545
@ -217,6 +229,7 @@ function iproc.rgb2y(src)
src, conversion = iproc.byte2float(src)
local dest = torch.FloatTensor(1, src:size(2), src:size(3)):zero()
dest:add(0.299, src[1]):add(0.587, src[2]):add(0.114, src[3])
dest:clamp(0, 1)
if conversion then
dest = iproc.float2byte(dest)
end

View file

@ -43,8 +43,10 @@ function pairwise_transform.jpeg_(src, quality, size, offset, n, options)
yc = iproc.byte2float(yc)
if options.rgb then
else
yc = iproc.rgb2y(yc)
xc = iproc.rgb2y(xc)
if xc:size(1) > 1 then
yc = iproc.rgb2y(yc)
xc = iproc.rgb2y(xc)
end
end
if torch.uniform() < options.nr_rate then
-- reducing noise

View file

@ -51,8 +51,10 @@ function pairwise_transform.scale(src, scale, size, offset, n, options)
yc = iproc.byte2float(yc)
if options.rgb then
else
yc = iproc.rgb2y(yc)
xc = iproc.rgb2y(xc)
if xc:size(1) > 1 then
yc = iproc.rgb2y(yc)
xc = iproc.rgb2y(xc)
end
end
table.insert(batch, {xc, iproc.crop(yc, offset, offset, size - offset, size - offset)})
end

View file

@ -1,4 +1,5 @@
local pairwise_utils = require 'pairwise_transform_utils'
local data_augmentation = require 'data_augmentation'
local iproc = require 'iproc'
local gm = {}
gm.Image = require 'graphicsmagick.Image'
@ -21,12 +22,15 @@ function pairwise_transform.user(x, y, size, offset, n, options)
if options.active_cropping_rate > 0 then
lowres_y = pairwise_utils.low_resolution(y)
end
if options.pairwise_flip then
if options.pairwise_flip and n == 1 then
xs[1], ys[1] = data_augmentation.pairwise_flip(xs[1], ys[1])
elseif options.pairwise_flip then
xs, ys, ls = pairwise_utils.flip_augmentation(x, y, lowres_y)
end
assert(#xs == #ys)
local perm = torch.randperm(#xs)
for i = 1, n do
local t = (i % #xs) + 1
local t = perm[(i % #xs) + 1]
local xc, yc = pairwise_utils.active_cropping(xs[t], ys[t], ls[t], size, scale_y,
options.active_cropping_rate,
options.active_cropping_tries)
@ -34,8 +38,10 @@ function pairwise_transform.user(x, y, size, offset, n, options)
yc = iproc.byte2float(yc)
if options.rgb then
else
yc = iproc.rgb2y(yc)
xc = iproc.rgb2y(xc)
if xc:size(1) > 1 then
yc = iproc.rgb2y(yc)
xc = iproc.rgb2y(xc)
end
end
if options.gcn then
local mean = xc:mean()
@ -46,7 +52,12 @@ function pairwise_transform.user(x, y, size, offset, n, options)
xc:add(-mean)
end
end
table.insert(batch, {xc, iproc.crop(yc, offset, offset, size - offset, size - offset)})
yc = iproc.crop(yc, offset, offset, size - offset, size - offset)
if options.pairwise_y_binary then
yc[torch.lt(yc, 0.5)] = 0
yc[torch.gt(yc, 0)] = 1
end
table.insert(batch, {xc, yc})
end
return batch

View file

@ -108,12 +108,6 @@ function pairwise_transform_utils.preprocess_user(x, y, scale_y, size, options)
x = iproc.crop_mod4(x)
y = iproc.crop_mod4(y)
if options.pairwise_y_binary then
y[torch.lt(y, 128)] = 0
y[torch.gt(y, 0)] = 255
end
return x, y
end
function pairwise_transform_utils.active_cropping(x, y, lowres_y, size, scale, p, tries)
@ -125,8 +119,14 @@ function pairwise_transform_utils.active_cropping(x, y, lowres_y, size, scale, p
t = "byte"
end
if p < r then
local xi = torch.random(1, x:size(3) - (size + 1)) * scale
local yi = torch.random(1, x:size(2) - (size + 1)) * scale
local xi = 0
local yi = 0
if x:size(3) > size + 1 then
xi = torch.random(0, x:size(3) - (size + 1)) * scale
end
if x:size(2) > size + 1 then
yi = torch.random(0, x:size(2) - (size + 1)) * scale
end
local yc = iproc.crop(y, xi, yi, xi + size, yi + size)
local xc = iproc.crop(x, xi / scale, yi / scale, xi / scale + size / scale, yi / scale + size / scale)
return xc, yc
@ -273,10 +273,17 @@ function pairwise_transform_utils.low_resolution(src)
toTensor("byte", "RGB", "DHW")
end
--]]
return gm.Image(src, "RGB", "DHW"):
size(src:size(3) * 0.5, src:size(2) * 0.5, "Box"):
size(src:size(3), src:size(2), "Box"):
toTensor("byte", "RGB", "DHW")
if src:size(1) == 1 then
return gm.Image(src, "I", "DHW"):
size(src:size(3) * 0.5, src:size(2) * 0.5, "Box"):
size(src:size(3), src:size(2), "Box"):
toTensor("byte", "I", "DHW")
else
return gm.Image(src, "RGB", "DHW"):
size(src:size(3) * 0.5, src:size(2) * 0.5, "Box"):
size(src:size(3), src:size(2), "Box"):
toTensor("byte", "RGB", "DHW")
end
end
return pairwise_transform_utils

View file

@ -18,7 +18,6 @@ local cmd = torch.CmdLine()
cmd:text()
cmd:text("waifu2x-training")
cmd:text("Options:")
cmd:option("-gpu", -1, 'GPU Device ID')
cmd:option("-seed", 11, 'RNG seed (note: it only able to reproduce the training results with `-thread 1`)')
cmd:option("-data_dir", "./data", 'path to data directory')
cmd:option("-backend", "cunn", '(cunn|cudnn)')
@ -74,9 +73,14 @@ cmd:option("-oracle_drop_rate", 0.5, '')
cmd:option("-learning_rate_decay", 3.0e-7, 'learning rate decay (learning_rate * 1/(1+num_of_data*patches*epoch))')
cmd:option("-resume", "", 'resume model file')
cmd:option("-name", "user", 'model name for user method')
cmd:option("-gpu", 1, 'Device ID')
cmd:option("-loss", "huber", 'loss function (huber|l1|mse)')
cmd:option("-gpu", "", 'GPU Device ID or ID lists (comma seprated)')
cmd:option("-loss", "huber", 'loss function (huber|l1|mse|bce)')
cmd:option("-update_criterion", "mse", 'mse|loss')
cmd:option("-padding", 0, 'replication padding size')
cmd:option("-padding_y_zero", 0, 'zero padding y for segmentation (0|1)')
cmd:option("-grayscale", 0, 'grayscale x&y (0|1)')
cmd:option("-validation_filename_split", 0, 'make validation-set based on filename(basename)')
cmd:option("-invert_x", 0, 'invert x image in convert_lua')
local function to_bool(settings, name)
if settings[name] == 1 then
@ -95,6 +99,10 @@ to_bool(settings, "save_history")
to_bool(settings, "use_transparent_png")
to_bool(settings, "pairwise_y_binary")
to_bool(settings, "pairwise_flip")
to_bool(settings, "padding_y_zero")
to_bool(settings, "grayscale")
to_bool(settings, "validation_filename_split")
to_bool(settings, "invert_x")
if settings.plot then
require 'gnuplot'
@ -168,10 +176,20 @@ end
settings.images = string.format("%s/images.t7", settings.data_dir)
settings.image_list = string.format("%s/image_list.txt", settings.data_dir)
cutorch.setDevice(opt.gpu)
-- patch for lua52
if not math.log10 then
math.log10 = function(x) return math.log(x, 10) end
end
-- Turn the `-gpu` option (a device ID or a comma separated list of IDs) into
-- a list of numbers in `settings.gpu`; an empty option selects device 1.
-- Empty tokens (e.g. a trailing comma) are ignored. A non-numeric token is
-- reported explicitly; previously tonumber() returned nil for it and the ID
-- was silently dropped, which could leave `settings.gpu[1]` nil.
if settings.gpu:len() > 0 then
   local gpus = {}
   local gpu_string = utils.split(settings.gpu, ",")
   for i = 1, #gpu_string do
      local token = tostring(gpu_string[i]):match("^%s*(.-)%s*$")
      if token:len() > 0 then
         local id = tonumber(token)
         if not id then
            error(string.format("invalid -gpu option: `%s` (`%s` is not a device ID)",
                                settings.gpu, token))
         end
         table.insert(gpus, id)
      end
   end
   if #gpus == 0 then
      error(string.format("invalid -gpu option: `%s` (no device ID found)", settings.gpu))
   end
   settings.gpu = gpus
else
   settings.gpu = {1}
end
-- The first listed device becomes the current CUDA device.
cutorch.setDevice(settings.gpu[1])
return settings

View file

@ -4,34 +4,52 @@ require 'w2nn'
-- ref: http://arxiv.org/abs/1501.00092
local srcnn = {}
function nn.SpatialConvolutionMM:reset(stdv)
local fin = self.kW * self.kH * self.nInputPlane
local fout = self.kW * self.kH * self.nOutputPlane
local function msra_filler(mod)
local fin = mod.kW * mod.kH * mod.nInputPlane
local fout = mod.kW * mod.kH * mod.nOutputPlane
stdv = math.sqrt(4 / ((1.0 + 0.1 * 0.1) * (fin + fout)))
self.weight:normal(0, stdv)
self.bias:zero()
mod.weight:normal(0, stdv)
mod.bias:zero()
end
-- Initialise a convolution module close to an identity mapping.
-- Weights are drawn from N(0, 0.01) and the bias is zeroed. The input planes
-- are then treated as `nInputPlane` groups: each input plane is tied to a
-- consecutive run of floor(nOutputPlane / nInputPlane) output planes by
-- writing nInputPlane / nOutputPlane at the kernel position
-- (floor(kH / 2) + 1, floor(kW / 2) + 1).
-- Requires nInputPlane <= nOutputPlane.
local function identity_filler(mod)
   assert(mod.nInputPlane <= mod.nOutputPlane)
   mod.weight:normal(0, 0.01)
   mod.bias:zero()
   local groups = mod.nInputPlane -- fixed
   local center_value = groups / mod.nOutputPlane
   local in_per_group = math.floor(mod.nInputPlane / groups)
   local out_per_group = math.floor(mod.nOutputPlane / groups)
   -- 1-based indices of the kernel position that receives the value
   local cx = math.floor(mod.kW / 2) + 1
   local cy = math.floor(mod.kH / 2) + 1
   for g = 1, groups do
      local out_first = (g - 1) * out_per_group + 1
      local in_first = (g - 1) * in_per_group + 1
      for out_ch = out_first, out_first + out_per_group - 1 do
         for in_ch = in_first, in_first + in_per_group - 1 do
            mod.weight[out_ch][in_ch][cy][cx] = center_value
         end
      end
   end
end
function nn.SpatialConvolutionMM:reset(stdv)
msra_filler(self)
end
function nn.SpatialFullConvolution:reset(stdv)
local fin = self.kW * self.kH * self.nInputPlane
local fout = self.kW * self.kH * self.nOutputPlane
stdv = math.sqrt(4 / ((1.0 + 0.1 * 0.1) * (fin + fout)))
self.weight:normal(0, stdv)
self.bias:zero()
msra_filler(self)
end
function nn.SpatialDilatedConvolution:reset(stdv)
identity_filler(self)
end
if cudnn and cudnn.SpatialConvolution then
function cudnn.SpatialConvolution:reset(stdv)
local fin = self.kW * self.kH * self.nInputPlane
local fout = self.kW * self.kH * self.nOutputPlane
stdv = math.sqrt(4 / ((1.0 + 0.1 * 0.1) * (fin + fout)))
self.weight:normal(0, stdv)
self.bias:zero()
msra_filler(self)
end
function cudnn.SpatialFullConvolution:reset(stdv)
local fin = self.kW * self.kH * self.nInputPlane
local fout = self.kW * self.kH * self.nOutputPlane
stdv = math.sqrt(4 / ((1.0 + 0.1 * 0.1) * (fin + fout)))
self.weight:normal(0, stdv)
self.bias:zero()
msra_filler(self)
end
if cudnn.SpatialDilatedConvolution then
function cudnn.SpatialDilatedConvolution:reset(stdv)
identity_filler(self)
end
end
end
function nn.SpatialConvolutionMM:clearState()
@ -127,6 +145,8 @@ local function SpatialConvolution(backend, nInputPlane, nOutputPlane, kW, kH, dW
error("unsupported backend:" .. backend)
end
end
srcnn.SpatialConvolution = SpatialConvolution
local function SpatialFullConvolution(backend, nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, adjW, adjH)
if backend == "cunn" then
return nn.SpatialFullConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, adjW, adjH)
@ -136,6 +156,8 @@ local function SpatialFullConvolution(backend, nInputPlane, nOutputPlane, kW, kH
error("unsupported backend:" .. backend)
end
end
srcnn.SpatialFullConvolution = SpatialFullConvolution
local function ReLU(backend)
if backend == "cunn" then
return nn.ReLU(true)
@ -145,6 +167,8 @@ local function ReLU(backend)
error("unsupported backend:" .. backend)
end
end
srcnn.ReLU = ReLU
local function SpatialMaxPooling(backend, kW, kH, dW, dH, padW, padH)
if backend == "cunn" then
return nn.SpatialMaxPooling(kW, kH, dW, dH, padW, padH)
@ -154,6 +178,35 @@ local function SpatialMaxPooling(backend, kW, kH, dW, dH, padW, padH)
error("unsupported backend:" .. backend)
end
end
srcnn.SpatialMaxPooling = SpatialMaxPooling
-- Build an average pooling layer for the requested backend:
-- "cunn" -> nn.SpatialAveragePooling, "cudnn" -> cudnn.SpatialAveragePooling.
-- Any other backend raises an error.
local function SpatialAveragePooling(backend, kW, kH, dW, dH, padW, padH)
   if backend == "cunn" then
      return nn.SpatialAveragePooling(kW, kH, dW, dH, padW, padH)
   end
   if backend == "cudnn" then
      return cudnn.SpatialAveragePooling(kW, kH, dW, dH, padW, padH)
   end
   error("unsupported backend:" .. backend)
end
srcnn.SpatialAveragePooling = SpatialAveragePooling
-- Build a dilated convolution layer for the requested backend.
-- "cunn" always uses nn.SpatialDilatedConvolution. "cudnn" uses
-- cudnn.SpatialDilatedConvolution when the loaded cudnn binding provides it
-- (cudnn v 6) and otherwise falls back to the nn implementation.
-- Any other backend raises an error.
local function SpatialDilatedConvolution(backend, nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, dilationW, dilationH)
   if backend ~= "cunn" and backend ~= "cudnn" then
      error("unsupported backend:" .. backend)
   end
   local constructor
   if backend == "cudnn" and cudnn.SpatialDilatedConvolution then
      -- cudnn v 6
      constructor = cudnn.SpatialDilatedConvolution
   else
      constructor = nn.SpatialDilatedConvolution
   end
   return constructor(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, dilationW, dilationH)
end
srcnn.SpatialDilatedConvolution = SpatialDilatedConvolution
-- VGG style net(7 layers)
function srcnn.vgg_7(backend, ch)
@ -548,6 +601,7 @@ function srcnn.create(model_name, backend, color)
error("unsupported model_name: " .. model_name)
end
end
--[[
local model = srcnn.fcn_v1("cunn", 3):cuda()
print(model:forward(torch.Tensor(1, 3, 108, 108):zero():cuda()):size())

View file

@ -9,6 +9,40 @@ end
local function load_cudnn()
cudnn = require('cudnn')
end
-- Wrap `model` in an nn.DataParallelTable registered for the devices listed
-- in `gpus`, move the wrapper to CUDA and return it.
-- The function passed to :threads() prepends "<dir of this file>/../lib/?.lua;"
-- to package.path and loads torch, cunn and w2nn. When the global `cudnn` is
-- set, it additionally loads cudnn and re-applies the caller's
-- cudnn.fastest / cudnn.benchmark flags.
-- NOTE(review): the constructor arguments (1, true, true) are presumably
-- (split dimension, flattenParams, usenccl) -- confirm against the cunn docs.
local function make_data_parallel_table(model, gpus)
if cudnn then
-- captured here so the thread init function can restore the same flags
local fastest, benchmark = cudnn.fastest, cudnn.benchmark
local dpt = nn.DataParallelTable(1, true, true)
:add(model, gpus)
:threads(function()
-- NOTE(review): `path` is presumably provided as a global by `require 'pl'` -- confirm
require 'pl'
local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
package.path = path.join(path.dirname(__FILE__), "..", "lib", "?.lua;") .. package.path
require 'torch'
require 'cunn'
require 'w2nn'
local cudnn = require 'cudnn'
cudnn.fastest, cudnn.benchmark = fastest, benchmark
end)
-- NOTE(review): gradInput is dropped on the wrapper; the reason is not visible here
dpt.gradInput = nil
model = dpt:cuda()
else
-- same setup as above, without loading cudnn in the thread init function
local dpt = nn.DataParallelTable(1, true, true)
:add(model, gpus)
:threads(function()
require 'pl'
local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
package.path = path.join(path.dirname(__FILE__), "..", "lib", "?.lua;") .. package.path
require 'torch'
require 'cunn'
require 'w2nn'
end)
dpt.gradInput = nil
model = dpt:cuda()
end
return model
end
if w2nn then
return w2nn
else
@ -27,11 +61,19 @@ else
model:cuda():evaluate()
return model
end
-- Return `model` wrapped for multi-GPU execution when more than one device
-- is listed in `gpus`; with zero or one device the model is returned as is.
function w2nn.data_parallel(model, gpus)
   if not (#gpus > 1) then
      return model
   end
   return make_data_parallel_table(model, gpus)
end
require 'LeakyReLU'
require 'ClippedWeightedHuberCriterion'
require 'ClippedMSECriterion'
require 'SSIMCriterion'
require 'InplaceClip01'
require 'L1Criterion'
require 'ShakeShakeTable'
return w2nn
end

View file

@ -47,6 +47,7 @@ cmd:option("-y_dir", "", 'groundtruth image for user method. filename must be th
cmd:option("-x_file", "", 'input image for user method')
cmd:option("-y_file", "", 'groundtruth image for user method. filename must be the same as x_file')
cmd:option("-border", 0, 'border px that will removed')
cmd:option("-metric", "", '(jaccard)')
local function to_bool(settings, name)
if settings[name] == 1 then
@ -203,8 +204,34 @@ local function remove_border(x, border)
x:size(3) - border,
x:size(2) - border)
end
-- Build the optional accuracy metric selected by the `-metric` option.
-- Returns nil when `metric` is nil or empty, otherwise a table
-- {name = <string>, func = function(a, b) -> number}.
-- Raises an error for an unknown metric name.
--
-- "jaccard": both images go through iproc.rgb2y and are binarised with a
-- `> 0.5` threshold; the score is |A and B| / |A or B|.
-- When both masks are empty the union is 0; the score is then defined as 1.0
-- (the masks agree) instead of 0/0 = NaN, which would poison the running sum
-- and make every later comparison false.
local function create_metric(metric)
   if metric and metric:len() > 0 then
      if metric == "jaccard" then
         return {
            name = "jaccard",
            func = function (a, b)
               local ga = iproc.rgb2y(a)
               local gb = iproc.rgb2y(b)
               local ba = torch.Tensor():resizeAs(ga)
               local bb = torch.Tensor():resizeAs(gb)
               ba:zero()
               bb:zero()
               ba[torch.gt(ga, 0.5)] = 1.0
               bb[torch.gt(gb, 0.5)] = 1.0
               -- the sums must be taken before cmul, which overwrites `ba`
               local num_a = ba:sum()
               local num_b = bb:sum()
               local a_and_b = ba:cmul(bb):sum()
               local a_or_b = num_a + num_b - a_and_b
               if a_or_b == 0 then
                  return 1.0
               end
               return (a_and_b / a_or_b)
            end}
      else
         error("unknown metric: " .. metric)
      end
   else
      return nil
   end
end
local function benchmark(opt, x, model1, model2)
local mse1, mse2
local mse1, mse2, am1, am2
local won = {0, 0}
local model1_mse = 0
local model2_mse = 0
@ -217,6 +244,13 @@ local function benchmark(opt, x, model1, model2)
local scale_f = reconstruct.scale
local image_f = reconstruct.image
local detail_fp = nil
local am = nil
local model1_am = 0
local model2_am = 0
if opt.method == "user" or opt.method == "diff" then
am = create_metric(opt.metric)
end
if opt.save_info then
detail_fp = io.open(path.join(opt.output_dir, "benchmark_details.txt"), "w")
end
@ -406,32 +440,57 @@ local function benchmark(opt, x, model1, model2)
ground_truth = remove_border(ground_truth, opt.border)
model1_output = remove_border(model1_output, opt.border)
end
mse1 = MSE(ground_truth, model1_output, opt.color)
model1_mse = model1_mse + mse1
model1_psnr = model1_psnr + MSE2PSNR(mse1)
if am then
am1 = am.func(ground_truth, model1_output)
model1_am = model1_am + am1
else
mse1 = MSE(ground_truth, model1_output, opt.color)
model1_mse = model1_mse + mse1
model1_psnr = model1_psnr + MSE2PSNR(mse1)
end
local won_model = 1
if model2 then
if opt.border > 0 then
model2_output = remove_border(model2_output, opt.border)
end
mse2 = MSE(ground_truth, model2_output, opt.color)
model2_mse = model2_mse + mse2
model2_psnr = model2_psnr + MSE2PSNR(mse2)
if mse1 < mse2 then
won[1] = won[1] + 1
elseif mse1 > mse2 then
won[2] = won[2] + 1
won_model = 2
if am then
am2 = am.func(ground_truth, model2_output)
model2_am = model2_am + am2
else
mse2 = MSE(ground_truth, model2_output, opt.color)
model2_mse = model2_mse + mse2
model2_psnr = model2_psnr + MSE2PSNR(mse2)
end
if am then
   -- The accuracy metric (Jaccard) is a similarity score: the HIGHER value
   -- wins, the opposite of the MSE comparison below.
   if am1 > am2 then
      won[1] = won[1] + 1
   elseif am1 < am2 then
      won[2] = won[2] + 1
      won_model = 2
   end
else
   -- MSE is an error: the lower value wins.
   if mse1 < mse2 then
      won[1] = won[1] + 1
   elseif mse1 > mse2 then
      won[2] = won[2] + 1
      won_model = 2
   end
end
if detail_fp then
detail_fp:write(string.format("%s,%f,%f,%d\n", x[i].basename,
MSE2PSNR(mse1), MSE2PSNR(mse2), won_model))
if am then
   -- one specifier per argument: name, model1 score, model2 score, winner
   -- (the previous format had three specifiers and printed am2 with %d)
   detail_fp:write(string.format("%s,%f,%f,%d\n", x[i].basename, am1, am2, won_model))
else
   detail_fp:write(string.format("%s,%f,%f,%d\n", x[i].basename,
                                 MSE2PSNR(mse1), MSE2PSNR(mse2), won_model))
end
end
else
if detail_fp then
detail_fp:write(string.format("%s,%f\n", x[i].basename, MSE2PSNR(mse1)))
if am then
detail_fp:write(string.format("%s,%f\n", x[i].basename, am1))
else
detail_fp:write(string.format("%s,%f\n", x[i].basename, MSE2PSNR(mse1)))
end
end
end
if baseline_output then
@ -455,46 +514,65 @@ local function benchmark(opt, x, model1, model2)
end
end
if opt.show_progress or i == #x then
if model2 then
if baseline_output then
if am then
if model2 then
io.stdout:write(
string.format("%d/%d; model1_time=%.2f, model2_time=%.2f, baseline_rmse=%.3f, model1_rmse=%.3f, model2_rmse=%.3f, baseline_psnr=%.3f, model1_psnr=%.3f, model2_psnr=%.3f, model1_won=%d, model2_won=%d \r",
string.format("%d/%d; model1_time=%.2f, model2_time=%.2f, model1_%s=%.3f, model2_%s=%.3f \r",
i, #x,
model1_time,
model2_time,
math.sqrt(baseline_mse / i),
math.sqrt(model1_mse / i), math.sqrt(model2_mse / i),
baseline_psnr / i,
model1_psnr / i, model2_psnr / i,
won[1], won[2]
))
am.name, model1_am / i, am.name, model2_am / i
))
else
io.stdout:write(
string.format("%d/%d; model1_time=%.2f, model2_time=%.2f, model1_rmse=%.3f, model2_rmse=%.3f, model1_psnr=%.3f, model2_psnr=%.3f, model1_own=%d, model2_won=%d \r",
string.format("%d/%d; model1_time=%.2f, model1_%s=%.3f \r",
i, #x,
model1_time,
model2_time,
math.sqrt(model1_mse / i), math.sqrt(model2_mse / i),
model1_psnr / i, model2_psnr / i,
won[1], won[2]
))
am.name, model1_am / i
))
end
else
if baseline_output then
io.stdout:write(
string.format("%d/%d; model1_time=%.2f, baseline_rmse=%.3f, model1_rmse=%.3f, baseline_psnr=%.3f, model1_psnr=%.3f \r",
i, #x,
model1_time,
math.sqrt(baseline_mse / i), math.sqrt(model1_mse / i),
baseline_psnr / i, model1_psnr / i
if model2 then
if baseline_output then
io.stdout:write(
string.format("%d/%d; model1_time=%.2f, model2_time=%.2f, baseline_rmse=%.3f, model1_rmse=%.3f, model2_rmse=%.3f, baseline_psnr=%.3f, model1_psnr=%.3f, model2_psnr=%.3f, model1_won=%d, model2_won=%d \r",
i, #x,
model1_time,
model2_time,
math.sqrt(baseline_mse / i),
math.sqrt(model1_mse / i), math.sqrt(model2_mse / i),
baseline_psnr / i,
model1_psnr / i, model2_psnr / i,
won[1], won[2]
))
else
io.stdout:write(
string.format("%d/%d; model1_time=%.2f, model2_time=%.2f, model1_rmse=%.3f, model2_rmse=%.3f, model1_psnr=%.3f, model2_psnr=%.3f, model1_own=%d, model2_won=%d \r",
i, #x,
model1_time,
model2_time,
math.sqrt(model1_mse / i), math.sqrt(model2_mse / i),
model1_psnr / i, model2_psnr / i,
won[1], won[2]
))
end
else
io.stdout:write(
string.format("%d/%d; model1_time=%.2f, model1_rmse=%.3f, model1_psnr=%.3f \r",
i, #x,
model1_time,
math.sqrt(model1_mse / i), model1_psnr / i
if baseline_output then
io.stdout:write(
string.format("%d/%d; model1_time=%.2f, baseline_rmse=%.3f, model1_rmse=%.3f, baseline_psnr=%.3f, model1_psnr=%.3f \r",
i, #x,
model1_time,
math.sqrt(baseline_mse / i), math.sqrt(model1_mse / i),
baseline_psnr / i, model1_psnr / i
))
else
io.stdout:write(
string.format("%d/%d; model1_time=%.2f, model1_rmse=%.3f, model1_psnr=%.3f \r",
i, #x,
model1_time,
math.sqrt(model1_mse / i), model1_psnr / i
))
end
end
end
io.stdout:flush()
@ -515,6 +593,14 @@ local function benchmark(opt, x, model1, model2)
fp:write(string.format("model2 : RMSE = %.3f, PSNR = %.3f, evaluation time = %.3f\n",
math.sqrt(model2_mse / #x), model2_psnr / #x, model2_time))
end
-- Summary lines for the accuracy metric. The metric name fills the `%s`
-- (it was missing, so string.format ran out of arguments and raised), and
-- the value is the plain mean over all images, matching the progress output;
-- a square root only makes sense for the MSE -> RMSE lines.
if model1_am > 0 then
   fp:write(string.format("model1 : %s = %.3f, evaluation time = %.3f\n",
                          am.name, model1_am / #x, model1_time))
end
if model2_am > 0 then
   fp:write(string.format("model2 : %s = %.3f, evaluation time = %.3f\n",
                          am.name, model2_am / #x, model2_time))
end
fp:close()
if detail_fp then
detail_fp:close()

124
train.lua
View file

@ -29,17 +29,57 @@ local function save_test_user(model, rgb, file)
end
end
local function split_data(x, test_size)
local index = torch.randperm(#x)
local train_size = #x - test_size
local train_x = {}
local valid_x = {}
for i = 1, train_size do
train_x[i] = x[index[i]]
if settings.validation_filename_split then
if not (x[1][2].data and x[1][2].data.basename) then
error("`images.t` does not have basename info. You need to re-run `convert_data.lua`.")
end
local basename_db = {}
for i = 1, #x do
local meta = x[i][2].data
if basename_db[meta.basename] then
table.insert(basename_db[meta.basename], x[i])
else
basename_db[meta.basename] = {x[i]}
end
end
local basename_list = {}
for k, v in pairs(basename_db) do
table.insert(basename_list, v)
end
local index = torch.randperm(#basename_list)
local train_x = {}
local valid_x = {}
local pos = 1
for i = 1, #basename_list do
if #valid_x >= test_size then
break
end
local xs = basename_list[index[pos]]
for j = 1, #xs do
table.insert(valid_x, xs[j])
end
pos = pos + 1
end
for i = pos, #basename_list do
local xs = basename_list[index[i]]
for j = 1, #xs do
table.insert(train_x, xs[j])
end
end
return train_x, valid_x
else
local index = torch.randperm(#x)
local train_size = #x - test_size
local train_x = {}
local valid_x = {}
for i = 1, train_size do
train_x[i] = x[index[i]]
end
for i = 1, test_size do
valid_x[i] = x[index[train_size + i]]
end
return train_x, valid_x
end
for i = 1, test_size do
valid_x[i] = x[index[train_size + i]]
end
return train_x, valid_x
end
local g_transform_pool = nil
@ -175,35 +215,19 @@ local function transform_pool_init(has_resize, offset)
settings.crop_size, offset,
n, conf)
elseif settings.method == "user" then
if is_validation == nil then is_validation = false end
local rotate_rate = nil
local scale_rate = nil
local negate_rate = nil
local negate_x_rate = nil
if is_validation then
rotate_rate = 0
scale_rate = 0
negate_rate = 0
negate_x_rate = 0
else
rotate_rate = settings.random_pairwise_rotate_rate
scale_rate = settings.random_pairwise_scale_rate
negate_rate = settings.random_pairwise_negate_rate
negate_x_rate = settings.random_pairwise_negate_x_rate
end
local conf = tablex.update({
gcn = settings.gcn,
max_size = settings.max_size,
active_cropping_rate = active_cropping_rate,
active_cropping_tries = active_cropping_tries,
random_pairwise_rotate_rate = rotate_rate,
random_pairwise_rotate_rate = settings.random_pairwise_rotate_rate,
random_pairwise_rotate_min = settings.random_pairwise_rotate_min,
random_pairwise_rotate_max = settings.random_pairwise_rotate_max,
random_pairwise_scale_rate = scale_rate,
random_pairwise_scale_rate = settings.random_pairwise_scale_rate,
random_pairwise_scale_min = settings.random_pairwise_scale_min,
random_pairwise_scale_max = settings.random_pairwise_scale_max,
random_pairwise_negate_rate = negate_rate,
random_pairwise_negate_x_rate = negate_x_rate,
random_pairwise_negate_rate = settings.random_pairwise_negate_rate,
random_pairwise_negate_x_rate = settings.random_pairwise_negate_x_rate,
pairwise_y_binary = settings.pairwise_y_binary,
pairwise_flip = settings.pairwise_flip,
rgb = (settings.color == "rgb")}, meta)
@ -290,7 +314,7 @@ local function validate(model, criterion, eval_metric, data, batch_size)
local batch_mse = eval_metric:forward(z, targets)
loss = loss + criterion:forward(z, targets)
mse = mse + batch_mse
psnr = psnr + (10 * math.log10(1 / batch_mse))
psnr = psnr + (10 * math.log10(1 / (batch_mse + 1.0e-6)))
loss_count = loss_count + 1
if loss_count % 10 == 0 then
xlua.progress(t, #data)
@ -322,6 +346,10 @@ local function create_criterion(model)
return w2nn.L1Criterion():cuda()
elseif settings.loss == "mse" then
return w2nn.ClippedMSECriterion(0, 1.0):cuda()
elseif settings.loss == "bce" then
local bce = nn.BCECriterion()
bce.sizeAverage = true
return bce:cuda()
else
error("unsupported loss .." .. settings.loss)
end
@ -421,7 +449,10 @@ local function plot(train, valid)
{'validation', torch.Tensor(valid), '-'}})
end
local function train()
local x = remove_small_image(torch.load(settings.images))
local x = torch.load(settings.images)
if settings.method ~= "user" then
x = remove_small_image(x)
end
local train_x, valid_x = split_data(x, math.max(math.floor(settings.validation_rate * #x), 1))
local hist_train = {}
local hist_valid = {}
@ -429,7 +460,12 @@ local function train()
if settings.resume:len() > 0 then
model = torch.load(settings.resume, "ascii")
else
model = srcnn.create(settings.model, settings.backend, settings.color)
if stringx.endswith(settings.model, ".lua") then
local create_model = dofile(settings.model)
model = create_model(srcnn, settings)
else
model = srcnn.create(settings.model, settings.backend, settings.color)
end
end
if model.w2nn_input_size then
if settings.crop_size ~= model.w2nn_input_size then
@ -484,8 +520,9 @@ local function train()
ch, settings.crop_size, settings.crop_size)
end
local instance_loss = nil
local pmodel = w2nn.data_parallel(model, settings.gpu)
for epoch = 1, settings.epoch do
model:training()
pmodel:training()
print("# " .. epoch)
if adam_config.learningRate then
print("learning rate: " .. adam_config.learningRate)
@ -523,13 +560,13 @@ local function train()
instance_loss = torch.Tensor(x:size(1)):zero()
for i = 1, settings.inner_epoch do
model:training()
local train_score, il = minibatch_adam(model, criterion, eval_metric, x, y, adam_config)
pmodel:training()
local train_score, il = minibatch_adam(pmodel, criterion, eval_metric, x, y, adam_config)
instance_loss:copy(il)
print(train_score)
model:evaluate()
pmodel:evaluate()
print("# validation")
local score = validate(model, criterion, eval_metric, valid_xy, adam_config.xBatchSize)
local score = validate(pmodel, criterion, eval_metric, valid_xy, adam_config.xBatchSize)
table.insert(hist_train, train_score.loss)
table.insert(hist_valid, score.loss)
if settings.plot then
@ -546,8 +583,9 @@ local function train()
best_score = score_for_update
print("* model has updated")
if settings.save_history then
torch.save(settings.model_file_best, model:clearState(), "ascii")
torch.save(string.format(settings.model_file, epoch, i), model:clearState(), "ascii")
pmodel:clearState()
torch.save(settings.model_file_best, model, "ascii")
torch.save(string.format(settings.model_file, epoch, i), model, "ascii")
if settings.method == "noise" then
local log = path.join(settings.model_dir,
("noise%d_best.%d-%d.png"):format(settings.noise_level,
@ -571,7 +609,8 @@ local function train()
save_test_user(model, test_image, log)
end
else
torch.save(settings.model_file, model:clearState(), "ascii")
pmodel:clearState()
torch.save(settings.model_file, model, "ascii")
if settings.method == "noise" then
local log = path.join(settings.model_dir,
("noise%d_best.png"):format(settings.noise_level))
@ -597,9 +636,6 @@ local function train()
end
end
end
if settings.gpu > 0 then
cutorch.setDevice(settings.gpu)
end
torch.manualSeed(settings.seed)
cutorch.manualSeed(settings.seed)
print(settings)

View file

@ -276,6 +276,7 @@ local function waifu2x()
if opt.thread > 0 then
torch.setnumthreads(opt.thread)
end
cutorch.setDevice(opt.gpu)
if cudnn then
cudnn.fastest = true
if opt.l:len() > 0 then
@ -293,6 +294,5 @@ local function waifu2x()
else
convert_frames(opt)
end
cutorch.setDevice(opt.gpu)
end
waifu2x()