commit fa6ee00624 (merge)
@@ -63,7 +63,24 @@ local function crop_if_large_pair(x, y, max_size)
       return x, y
    end
 end
 
+local function padding_x(x, pad)
+   if pad > 0 then
+      x = iproc.padding(x, pad, pad, pad, pad)
+   end
+   return x
+end
+local function padding_xy(x, y, pad, y_zero)
+   local scale = y:size(2) / x:size(2)
+   if pad > 0 then
+      x = iproc.padding(x, pad, pad, pad, pad)
+      if y_zero then
+         y = iproc.zero_padding(y, pad * scale, pad * scale, pad * scale, pad * scale)
+      else
+         y = iproc.padding(y, pad * scale, pad * scale, pad * scale, pad * scale)
+      end
+   end
+   return x, y
+end
 local function load_images(list)
    local MARGIN = 32
    local csv = csvigo.load({path = list, verbose = false, mode = "raw"})
@@ -78,6 +95,7 @@ local function load_images(list)
       if csv_meta and csv_meta.filters then
          filters = csv_meta.filters
       end
+      local basename_y = path.basename(filename)
       local im, meta = image_loader.load_byte(filename)
       local skip = false
       local alpha_color = torch.random(0, 1)
@@ -100,25 +118,38 @@ local function load_images(list)
          -- method == user
          local yy = im
          local xx, meta2 = image_loader.load_byte(csv_meta.x)
+         if settings.invert_x then
+            xx = (-(xx:long()) + 255):byte()
+         end
          if xx then
            if meta2 and meta2.alpha then
               xx = alpha_util.fill(xx, meta2.alpha, alpha_color)
            end
+           xx, yy = crop_if_large_pair(xx, yy, settings.max_training_image_size)
+           xx, yy = padding_xy(xx, yy, settings.padding, settings.padding_y_zero)
+           if settings.grayscale then
+              xx = iproc.rgb2y(xx)
+              yy = iproc.rgb2y(yy)
+           end
           table.insert(x, {{y = compression.compress(yy), x = compression.compress(xx)},
-                           {data = {filters = filters, has_x = true}}})
+                           {data = {filters = filters, has_x = true, basename = basename_y}}})
        else
           io.stderr:write(string.format("\n%s: skip: load error.\n", csv_meta.x))
        end
      else
        im = crop_if_large(im, settings.max_training_image_size)
        im = iproc.crop_mod4(im)
+       im = padding_x(im, settings.padding)
        local scale = 1.0
        if settings.random_half_rate > 0.0 then
           scale = 2.0
        end
        if im:size(2) > (settings.crop_size * scale + MARGIN) and im:size(3) > (settings.crop_size * scale + MARGIN) then
-          table.insert(x, {compression.compress(im), {data = {filters = filters}}})
+          if settings.grayscale then
+             im = iproc.rgb2y(im)
+          end
+          table.insert(x, {compression.compress(im), {data = {filters = filters, basename = basename_y}}})
        else
           io.stderr:write(string.format("\n%s: skip: image is too small (%d > size).\n", filename, settings.crop_size * scale + MARGIN))
        end
lib/ShakeShakeTable.lua (new file, 42 lines)
@@ -0,0 +1,42 @@
+local ShakeShakeTable, parent = torch.class('w2nn.ShakeShakeTable','nn.Module')
+
+function ShakeShakeTable:__init()
+   parent.__init(self)
+   self.alpha = torch.Tensor()
+   self.beta = torch.Tensor()
+   self.first = torch.Tensor()
+   self.second = torch.Tensor()
+   self.train = true
+end
+function ShakeShakeTable:updateOutput(input)
+   local batch_size = input[1]:size(1)
+   if self.train then
+      self.alpha:resize(batch_size):uniform()
+      self.beta:resize(batch_size):uniform()
+      self.second:resizeAs(input[1]):copy(input[2])
+      for i = 1, batch_size do
+         self.second[i]:mul(self.alpha[i])
+      end
+      self.output:resizeAs(input[1]):copy(input[1])
+      for i = 1, batch_size do
+         self.output[i]:mul(1.0 - self.alpha[i])
+      end
+      self.output:add(self.second):mul(2)
+   else
+      self.output:resizeAs(input[1]):copy(input[1]):add(input[2])
+   end
+   return self.output
+end
+function ShakeShakeTable:updateGradInput(input, gradOutput)
+   local batch_size = input[1]:size(1)
+   self.first:resizeAs(gradOutput):copy(gradOutput)
+   for i = 1, batch_size do
+      self.first[i]:mul(self.beta[i])
+   end
+   self.second:resizeAs(gradOutput):copy(gradOutput)
+   for i = 1, batch_size do
+      self.second[i]:mul(1.0 - self.beta[i])
+   end
+   self.gradInput = {self.first, self.second}
+   return self.gradInput
+end
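The new module implements shake-shake style blending of two branches: in training, each sample in the batch gets its own random mixing weight alpha for the forward pass and an independent weight beta for the backward pass; at inference the two branches are simply summed. A minimal usage sketch (hypothetical tensors; assumes torch, nn and w2nn are loaded so the class above is registered):

   local ss = w2nn.ShakeShakeTable()
   local a = torch.Tensor(4, 8):uniform()  -- branch 1, batch of 4
   local b = torch.Tensor(4, 8):uniform()  -- branch 2, same shape
   ss.train = true
   local out = ss:forward({a, b})  -- per sample i: 2 * ((1 - alpha_i) * a_i + alpha_i * b_i)
   ss.train = false
   local sum = ss:forward({a, b})  -- inference path: plain a + b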
@@ -102,7 +102,9 @@ function data_augmentation.pairwise_scale(x, y, p, scale_min, scale_max)
       local scale = torch.uniform(scale_min, scale_max)
       local h = math.floor(x:size(2) * scale)
       local w = math.floor(x:size(3) * scale)
-      x = iproc.scale(x, w, h, "Triangle")
+      local filters = {"Lanczos", "Catrom"}
+      local x_filter = filters[torch.random(1, 2)]
+      x = iproc.scale(x, w, h, x_filter)
       y = iproc.scale(y, w, h, "Triangle")
       return x, y
    else
@@ -139,6 +141,36 @@ function data_augmentation.pairwise_negate_x(x, y, p)
       return x, y
    end
 end
+function data_augmentation.pairwise_flip(x, y)
+   local flip = torch.random(1, 4)
+   local tr = torch.random(1, 2)
+   local x, conversion = iproc.byte2float(x)
+   y = iproc.byte2float(y)
+   x = x:contiguous()
+   y = y:contiguous()
+   if tr == 1 then
+      -- pass
+   elseif tr == 2 then
+      x = x:transpose(2, 3):contiguous()
+      y = y:transpose(2, 3):contiguous()
+   end
+   if flip == 1 then
+      x = iproc.hflip(x)
+      y = iproc.hflip(y)
+   elseif flip == 2 then
+      x = iproc.vflip(x)
+      y = iproc.vflip(y)
+   elseif flip == 3 then
+      x = iproc.hflip(iproc.vflip(x))
+      y = iproc.hflip(iproc.vflip(y))
+   elseif flip == 4 then
+   end
+   if conversion then
+      x = iproc.float2byte(x)
+      y = iproc.float2byte(y)
+   end
+   return x, y
+end
 function data_augmentation.shift_1px(src)
    -- reducing the even/odd issue in nearest neighbor scaler.
    local direction = torch.random(1, 4)
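Taken together, the transpose draw (2 choices) and the flip draw (4 choices, one of them a no-op) cover the 8 symmetries of the square, i.e. the dihedral group D4, and apply the same symmetry to both members of the pair. A small shape-level sketch (hypothetical tensors; assumes this module is on the path):

   local data_augmentation = require 'data_augmentation'
   local x = torch.ByteTensor(3, 4, 6):random(0, 255)
   local y = torch.ByteTensor(3, 8, 12):random(0, 255)  -- y may be a scaled pair of x
   local x2, y2 = data_augmentation.pairwise_flip(x, y)
   -- x2 is 3x4x6 or 3x6x4 depending on the transpose draw; x2 and y2
   -- always receive the same transform, so pixel correspondence is kept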
@@ -80,6 +80,8 @@ function iproc.scale_with_gamma22(src, width, height, filter, blur)
    return dest
 end
 function iproc.padding(img, w1, w2, h1, h2)
+   local conversion
+   img, conversion = iproc.byte2float(img)
    image = image or require 'image'
    local dst_height = img:size(2) + h1 + h2
    local dst_width = img:size(3) + w1 + w2
@@ -88,9 +90,15 @@ function iproc.padding(img, w1, w2, h1, h2)
    flow[2] = torch.ger(torch.ones(dst_height), torch.linspace(0, dst_width - 1, dst_width))
    flow[1]:add(-h1)
    flow[2]:add(-w1)
-   return image.warp(img, flow, "simple", false, "clamp")
+   local dest = image.warp(img, flow, "simple", false, "clamp")
+   if conversion then
+      dest = iproc.float2byte(dest)
+   end
+   return dest
 end
 function iproc.zero_padding(img, w1, w2, h1, h2)
+   local conversion
+   img, conversion = iproc.byte2float(img)
    image = image or require 'image'
    local dst_height = img:size(2) + h1 + h2
    local dst_width = img:size(3) + w1 + w2
@@ -99,7 +107,11 @@ function iproc.zero_padding(img, w1, w2, h1, h2)
    flow[2] = torch.ger(torch.ones(dst_height), torch.linspace(0, dst_width - 1, dst_width))
    flow[1]:add(-h1)
    flow[2]:add(-w1)
-   return image.warp(img, flow, "simple", false, "pad", 0)
+   local dest = image.warp(img, flow, "simple", false, "pad", 0)
+   if conversion then
+      dest = iproc.float2byte(dest)
+   end
+   return dest
 end
 function iproc.white_noise(src, std, rgb_weights, gamma)
    gamma = gamma or 0.454545
@@ -217,6 +229,7 @@ function iproc.rgb2y(src)
    src, conversion = iproc.byte2float(src)
    local dest = torch.FloatTensor(1, src:size(2), src:size(3)):zero()
    dest:add(0.299, src[1]):add(0.587, src[2]):add(0.114, src[3])
+   dest:clamp(0, 1)
    if conversion then
       dest = iproc.float2byte(dest)
    end
@@ -43,8 +43,10 @@ function pairwise_transform.jpeg_(src, quality, size, offset, n, options)
       yc = iproc.byte2float(yc)
       if options.rgb then
       else
-         yc = iproc.rgb2y(yc)
-         xc = iproc.rgb2y(xc)
+         if xc:size(1) > 1 then
+            yc = iproc.rgb2y(yc)
+            xc = iproc.rgb2y(xc)
+         end
       end
       if torch.uniform() < options.nr_rate then
          -- reducing noise
@@ -51,8 +51,10 @@ function pairwise_transform.scale(src, scale, size, offset, n, options)
       yc = iproc.byte2float(yc)
       if options.rgb then
       else
-         yc = iproc.rgb2y(yc)
-         xc = iproc.rgb2y(xc)
+         if xc:size(1) > 1 then
+            yc = iproc.rgb2y(yc)
+            xc = iproc.rgb2y(xc)
+         end
       end
       table.insert(batch, {xc, iproc.crop(yc, offset, offset, size - offset, size - offset)})
    end
@@ -1,4 +1,5 @@
 local pairwise_utils = require 'pairwise_transform_utils'
+local data_augmentation = require 'data_augmentation'
 local iproc = require 'iproc'
 local gm = {}
 gm.Image = require 'graphicsmagick.Image'
@@ -21,12 +22,15 @@ function pairwise_transform.user(x, y, size, offset, n, options)
    if options.active_cropping_rate > 0 then
       lowres_y = pairwise_utils.low_resolution(y)
    end
-   if options.pairwise_flip then
+   if options.pairwise_flip and n == 1 then
+      xs[1], ys[1] = data_augmentation.pairwise_flip(xs[1], ys[1])
+   elseif options.pairwise_flip then
       xs, ys, ls = pairwise_utils.flip_augmentation(x, y, lowres_y)
    end
    assert(#xs == #ys)
+   local perm = torch.randperm(#xs)
    for i = 1, n do
-      local t = (i % #xs) + 1
+      local t = perm[(i % #xs) + 1]
       local xc, yc = pairwise_utils.active_cropping(xs[t], ys[t], ls[t], size, scale_y,
                                                     options.active_cropping_rate,
                                                     options.active_cropping_tries)
@@ -34,8 +38,10 @@ function pairwise_transform.user(x, y, size, offset, n, options)
       yc = iproc.byte2float(yc)
       if options.rgb then
       else
-         yc = iproc.rgb2y(yc)
-         xc = iproc.rgb2y(xc)
+         if xc:size(1) > 1 then
+            yc = iproc.rgb2y(yc)
+            xc = iproc.rgb2y(xc)
+         end
       end
       if options.gcn then
          local mean = xc:mean()
@@ -46,7 +52,12 @@ function pairwise_transform.user(x, y, size, offset, n, options)
            xc:add(-mean)
         end
      end
-     table.insert(batch, {xc, iproc.crop(yc, offset, offset, size - offset, size - offset)})
+     yc = iproc.crop(yc, offset, offset, size - offset, size - offset)
+     if options.pairwise_y_binary then
+        yc[torch.lt(yc, 0.5)] = 0
+        yc[torch.gt(yc, 0)] = 1
+     end
+     table.insert(batch, {xc, yc})
   end
 
   return batch
@@ -108,12 +108,6 @@ function pairwise_transform_utils.preprocess_user(x, y, scale_y, size, options)
 
    x = iproc.crop_mod4(x)
    y = iproc.crop_mod4(y)
-
-   if options.pairwise_y_binary then
-      y[torch.lt(y, 128)] = 0
-      y[torch.gt(y, 0)] = 255
-   end
-
    return x, y
 end
 function pairwise_transform_utils.active_cropping(x, y, lowres_y, size, scale, p, tries)
@@ -125,8 +119,14 @@ function pairwise_transform_utils.active_cropping(x, y, lowres_y, size, scale, p
       t = "byte"
    end
    if p < r then
-      local xi = torch.random(1, x:size(3) - (size + 1)) * scale
-      local yi = torch.random(1, x:size(2) - (size + 1)) * scale
+      local xi = 0
+      local yi = 0
+      if x:size(3) > size + 1 then
+         xi = torch.random(0, x:size(3) - (size + 1)) * scale
+      end
+      if x:size(2) > size + 1 then
+         yi = torch.random(0, x:size(2) - (size + 1)) * scale
+      end
       local yc = iproc.crop(y, xi, yi, xi + size, yi + size)
       local xc = iproc.crop(x, xi / scale, yi / scale, xi / scale + size / scale, yi / scale + size / scale)
       return xc, yc
@@ -273,10 +273,17 @@ function pairwise_transform_utils.low_resolution(src)
          toTensor("byte", "RGB", "DHW")
    end
 --]]
-   return gm.Image(src, "RGB", "DHW"):
-      size(src:size(3) * 0.5, src:size(2) * 0.5, "Box"):
-      size(src:size(3), src:size(2), "Box"):
-      toTensor("byte", "RGB", "DHW")
+   if src:size(1) == 1 then
+      return gm.Image(src, "I", "DHW"):
+         size(src:size(3) * 0.5, src:size(2) * 0.5, "Box"):
+         size(src:size(3), src:size(2), "Box"):
+         toTensor("byte", "I", "DHW")
+   else
+      return gm.Image(src, "RGB", "DHW"):
+         size(src:size(3) * 0.5, src:size(2) * 0.5, "Box"):
+         size(src:size(3), src:size(2), "Box"):
+         toTensor("byte", "RGB", "DHW")
+   end
 end
 
 return pairwise_transform_utils
@@ -18,7 +18,6 @@ local cmd = torch.CmdLine()
 cmd:text()
 cmd:text("waifu2x-training")
 cmd:text("Options:")
-cmd:option("-gpu", -1, 'GPU Device ID')
 cmd:option("-seed", 11, 'RNG seed (note: it only able to reproduce the training results with `-thread 1`)')
 cmd:option("-data_dir", "./data", 'path to data directory')
 cmd:option("-backend", "cunn", '(cunn|cudnn)')
@@ -74,9 +73,14 @@ cmd:option("-oracle_drop_rate", 0.5, '')
 cmd:option("-learning_rate_decay", 3.0e-7, 'learning rate decay (learning_rate * 1/(1+num_of_data*patches*epoch))')
 cmd:option("-resume", "", 'resume model file')
 cmd:option("-name", "user", 'model name for user method')
-cmd:option("-gpu", 1, 'Device ID')
-cmd:option("-loss", "huber", 'loss function (huber|l1|mse)')
+cmd:option("-gpu", "", 'GPU Device ID or ID lists (comma seprated)')
+cmd:option("-loss", "huber", 'loss function (huber|l1|mse|bce)')
 cmd:option("-update_criterion", "mse", 'mse|loss')
+cmd:option("-padding", 0, 'replication padding size')
+cmd:option("-padding_y_zero", 0, 'zero padding y for segmentation (0|1)')
+cmd:option("-grayscale", 0, 'grayscale x&y (0|1)')
+cmd:option("-validation_filename_split", 0, 'make validation-set based on filename(basename)')
+cmd:option("-invert_x", 0, 'invert x image in convert_lua')
 
 local function to_bool(settings, name)
    if settings[name] == 1 then
@@ -95,6 +99,10 @@ to_bool(settings, "save_history")
 to_bool(settings, "use_transparent_png")
 to_bool(settings, "pairwise_y_binary")
 to_bool(settings, "pairwise_flip")
+to_bool(settings, "padding_y_zero")
+to_bool(settings, "grayscale")
+to_bool(settings, "validation_filename_split")
+to_bool(settings, "invert_x")
 
 if settings.plot then
    require 'gnuplot'
@@ -168,10 +176,20 @@ end
 settings.images = string.format("%s/images.t7", settings.data_dir)
 settings.image_list = string.format("%s/image_list.txt", settings.data_dir)
 
-cutorch.setDevice(opt.gpu)
 -- patch for lua52
 if not math.log10 then
    math.log10 = function(x) return math.log(x, 10) end
 end
+if settings.gpu:len() > 0 then
+   local gpus = {}
+   local gpu_string = utils.split(settings.gpu, ",")
+   for i = 1, #gpu_string do
+      table.insert(gpus, tonumber(gpu_string[i]))
+   end
+   settings.gpu = gpus
+else
+   settings.gpu = {1}
+end
+cutorch.setDevice(settings.gpu[1])
 
 return settings
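With this change `-gpu` takes a comma-separated device list instead of a single ID; an empty value falls back to {1}, and the first entry becomes the primary device. A small sketch of the parsing (assumes Penlight, which provides `utils.split`; the argument value is hypothetical):

   local utils = require 'pl.utils'
   local gpu_arg = "2,3"  -- hypothetical value passed as: -gpu 2,3
   local gpus = {}
   for _, id in ipairs(utils.split(gpu_arg, ",")) do
      table.insert(gpus, tonumber(id))
   end
   -- gpus == {2, 3}; training then runs cutorch.setDevice(gpus[1])
   -- and hands the full list to w2nn.data_parallel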
@@ -4,34 +4,52 @@ require 'w2nn'
 -- ref: http://arxiv.org/abs/1501.00092
 local srcnn = {}
 
-function nn.SpatialConvolutionMM:reset(stdv)
-   local fin = self.kW * self.kH * self.nInputPlane
-   local fout = self.kW * self.kH * self.nOutputPlane
+local function msra_filler(mod)
+   local fin = mod.kW * mod.kH * mod.nInputPlane
+   local fout = mod.kW * mod.kH * mod.nOutputPlane
    stdv = math.sqrt(4 / ((1.0 + 0.1 * 0.1) * (fin + fout)))
-   self.weight:normal(0, stdv)
-   self.bias:zero()
+   mod.weight:normal(0, stdv)
+   mod.bias:zero()
 end
+local function identity_filler(mod)
+   assert(mod.nInputPlane <= mod.nOutputPlane)
+   mod.weight:normal(0, 0.01)
+   mod.bias:zero()
+   local num_groups = mod.nInputPlane -- fixed
+   local filler_value = num_groups / mod.nOutputPlane
+   local in_group_size = math.floor(mod.nInputPlane / num_groups)
+   local out_group_size = math.floor(mod.nOutputPlane / num_groups)
+   local x = math.floor(mod.kW / 2)
+   local y = math.floor(mod.kH / 2)
+   for i = 0, num_groups - 1 do
+      for j = i * out_group_size, (i + 1) * out_group_size - 1 do
+         for k = i * in_group_size, (i + 1) * in_group_size - 1 do
+            mod.weight[j+1][k+1][y+1][x+1] = filler_value
+         end
+      end
+   end
+end
+function nn.SpatialConvolutionMM:reset(stdv)
+   msra_filler(self)
+end
 function nn.SpatialFullConvolution:reset(stdv)
-   local fin = self.kW * self.kH * self.nInputPlane
-   local fout = self.kW * self.kH * self.nOutputPlane
-   stdv = math.sqrt(4 / ((1.0 + 0.1 * 0.1) * (fin + fout)))
-   self.weight:normal(0, stdv)
-   self.bias:zero()
+   msra_filler(self)
 end
+function nn.SpatialDilatedConvolution:reset(stdv)
+   identity_filler(self)
+end
 
 if cudnn and cudnn.SpatialConvolution then
    function cudnn.SpatialConvolution:reset(stdv)
-      local fin = self.kW * self.kH * self.nInputPlane
-      local fout = self.kW * self.kH * self.nOutputPlane
-      stdv = math.sqrt(4 / ((1.0 + 0.1 * 0.1) * (fin + fout)))
-      self.weight:normal(0, stdv)
-      self.bias:zero()
+      msra_filler(self)
    end
    function cudnn.SpatialFullConvolution:reset(stdv)
-      local fin = self.kW * self.kH * self.nInputPlane
-      local fout = self.kW * self.kH * self.nOutputPlane
-      stdv = math.sqrt(4 / ((1.0 + 0.1 * 0.1) * (fin + fout)))
-      self.weight:normal(0, stdv)
-      self.bias:zero()
+      msra_filler(self)
    end
+   if cudnn.SpatialDilatedConvolution then
+      function cudnn.SpatialDilatedConvolution:reset(stdv)
+         identity_filler(self)
+      end
+   end
 end
 function nn.SpatialConvolutionMM:clearState()
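msra_filler keeps the He-style normal initialization the per-class reset methods previously duplicated, while identity_filler starts a dilated convolution close to an identity map: small noise everywhere, plus a constant at the kernel center that copies each input plane into its output group, scaled by num_groups / nOutputPlane so each group's outputs sum back to the input. A rough sketch of the effect (hypothetical shapes; assumes this file has been required so the reset overrides are installed):

   local conv = nn.SpatialDilatedConvolution(4, 8, 3, 3, 1, 1, 2, 2, 2, 2)
   -- identity_filler ran inside reset(): filler_value = 4 / 8 = 0.5, so
   -- input plane i is copied, times 0.5, into both outputs of group i;
   -- up to the 0.01-stdev noise, forward(x) initially behaves like a
   -- per-channel passthrough when the matching output pairs are summed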
@@ -127,6 +145,8 @@ local function SpatialConvolution(backend, nInputPlane, nOutputPlane, kW, kH, dW
       error("unsupported backend:" .. backend)
    end
 end
+srcnn.SpatialConvolution = SpatialConvolution
+
 local function SpatialFullConvolution(backend, nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, adjW, adjH)
    if backend == "cunn" then
       return nn.SpatialFullConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, adjW, adjH)
@@ -136,6 +156,8 @@ local function SpatialFullConvolution(backend, nInputPlane, nOutputPlane, kW, kH
       error("unsupported backend:" .. backend)
    end
 end
+srcnn.SpatialFullConvolution = SpatialFullConvolution
+
 local function ReLU(backend)
    if backend == "cunn" then
       return nn.ReLU(true)
@@ -145,6 +167,8 @@ local function ReLU(backend)
       error("unsupported backend:" .. backend)
    end
 end
+srcnn.ReLU = ReLU
+
 local function SpatialMaxPooling(backend, kW, kH, dW, dH, padW, padH)
    if backend == "cunn" then
       return nn.SpatialMaxPooling(kW, kH, dW, dH, padW, padH)
@@ -154,6 +178,35 @@ local function SpatialMaxPooling(backend, kW, kH, dW, dH, padW, padH)
       error("unsupported backend:" .. backend)
    end
 end
+srcnn.SpatialMaxPooling = SpatialMaxPooling
+
+local function SpatialAveragePooling(backend, kW, kH, dW, dH, padW, padH)
+   if backend == "cunn" then
+      return nn.SpatialAveragePooling(kW, kH, dW, dH, padW, padH)
+   elseif backend == "cudnn" then
+      return cudnn.SpatialAveragePooling(kW, kH, dW, dH, padW, padH)
+   else
+      error("unsupported backend:" .. backend)
+   end
+end
+srcnn.SpatialAveragePooling = SpatialAveragePooling
+
+local function SpatialDilatedConvolution(backend, nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, dilationW, dilationH)
+   if backend == "cunn" then
+      return nn.SpatialDilatedConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, dilationW, dilationH)
+   elseif backend == "cudnn" then
+      if cudnn.SpatialDilatedConvolution then
+         -- cudnn v 6
+         return cudnn.SpatialDilatedConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, dilationW, dilationH)
+      else
+         return nn.SpatialDilatedConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, dilationW, dilationH)
+      end
+   else
+      error("unsupported backend:" .. backend)
+   end
+end
+srcnn.SpatialDilatedConvolution = SpatialDilatedConvolution
+
+
 -- VGG style net(7 layers)
 function srcnn.vgg_7(backend, ch)
@@ -548,6 +601,7 @@ function srcnn.create(model_name, backend, color)
       error("unsupported model_name: " .. model_name)
    end
 end
+
 --[[
 local model = srcnn.fcn_v1("cunn", 3):cuda()
 print(model:forward(torch.Tensor(1, 3, 108, 108):zero():cuda()):size())
lib/w2nn.lua (42 changed lines)
@@ -9,6 +9,40 @@ end
 local function load_cudnn()
    cudnn = require('cudnn')
 end
+local function make_data_parallel_table(model, gpus)
+   if cudnn then
+      local fastest, benchmark = cudnn.fastest, cudnn.benchmark
+      local dpt = nn.DataParallelTable(1, true, true)
+         :add(model, gpus)
+         :threads(function()
+                     require 'pl'
+                     local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
+                     package.path = path.join(path.dirname(__FILE__), "..", "lib", "?.lua;") .. package.path
+                     require 'torch'
+                     require 'cunn'
+                     require 'w2nn'
+                     local cudnn = require 'cudnn'
+                     cudnn.fastest, cudnn.benchmark = fastest, benchmark
+                  end)
+      dpt.gradInput = nil
+      model = dpt:cuda()
+   else
+      local dpt = nn.DataParallelTable(1, true, true)
+         :add(model, gpus)
+         :threads(function()
+                     require 'pl'
+                     local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
+                     package.path = path.join(path.dirname(__FILE__), "..", "lib", "?.lua;") .. package.path
+                     require 'torch'
+                     require 'cunn'
+                     require 'w2nn'
+                  end)
+      dpt.gradInput = nil
+      model = dpt:cuda()
+   end
+   return model
+end
 
 if w2nn then
    return w2nn
 else
@@ -27,11 +61,19 @@ else
       model:cuda():evaluate()
       return model
    end
+   function w2nn.data_parallel(model, gpus)
+      if #gpus > 1 then
+         return make_data_parallel_table(model, gpus)
+      else
+         return model
+      end
+   end
    require 'LeakyReLU'
    require 'ClippedWeightedHuberCriterion'
    require 'ClippedMSECriterion'
    require 'SSIMCriterion'
    require 'InplaceClip01'
    require 'L1Criterion'
+   require 'ShakeShakeTable'
    return w2nn
 end
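w2nn.data_parallel is the entry point train.lua uses below: with a single GPU the model is returned untouched, with several it is wrapped in nn.DataParallelTable whose worker threads re-require the w2nn stack (and mirror the cudnn.fastest/benchmark flags when cudnn is active). A minimal usage sketch (hypothetical model; assumes cunn and this library are loaded):

   local model = srcnn.create("vgg_7", "cunn", "rgb"):cuda()
   local single = w2nn.data_parallel(model, {1})    -- returned unchanged
   local multi  = w2nn.data_parallel(model, {1, 2}) -- nn.DataParallelTable over GPUs 1 and 2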
@@ -47,6 +47,7 @@ cmd:option("-y_dir", "", 'groundtruth image for user method. filename must be th
 cmd:option("-x_file", "", 'input image for user method')
 cmd:option("-y_file", "", 'groundtruth image for user method. filename must be the same as x_file')
 cmd:option("-border", 0, 'border px that will removed')
+cmd:option("-metric", "", '(jaccard)')
 
 local function to_bool(settings, name)
    if settings[name] == 1 then
@@ -203,8 +204,34 @@ local function remove_border(x, border)
                      x:size(3) - border,
                      x:size(2) - border)
 end
+local function create_metric(metric)
+   if metric and metric:len() > 0 then
+      if metric == "jaccard" then
+         return {
+            name = "jaccard",
+            func = function (a, b)
+               local ga = iproc.rgb2y(a)
+               local gb = iproc.rgb2y(b)
+               local ba = torch.Tensor():resizeAs(ga)
+               local bb = torch.Tensor():resizeAs(gb)
+               ba:zero()
+               bb:zero()
+               ba[torch.gt(ga, 0.5)] = 1.0
+               bb[torch.gt(gb, 0.5)] = 1.0
+               local num_a = ba:sum()
+               local num_b = bb:sum()
+               local a_and_b = ba:cmul(bb):sum()
+               return (a_and_b / (num_a + num_b - a_and_b))
+            end}
+      else
+         error("unknown metric: " .. metric)
+      end
+   else
+      return nil
+   end
+end
 local function benchmark(opt, x, model1, model2)
-   local mse1, mse2
+   local mse1, mse2, am1, am2
    local won = {0, 0}
    local model1_mse = 0
    local model2_mse = 0
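The jaccard metric converts both images to luminance, binarizes them at 0.5, and reports the intersection over union |A∩B| / (|A| + |B| - |A∩B|). A worked toy example with 1-d masks in place of images:

   local ba = torch.Tensor({1, 1, 0, 0})      -- binarized model output
   local bb = torch.Tensor({1, 0, 1, 0})      -- binarized ground truth
   local num_a = ba:sum()                     -- 2
   local num_b = bb:sum()                     -- 2
   local a_and_b = ba:cmul(bb):sum()          -- 1 (cmul is in place, as above)
   print(a_and_b / (num_a + num_b - a_and_b)) -- 1 / (2 + 2 - 1) = 0.333...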
@@ -217,6 +244,13 @@ local function benchmark(opt, x, model1, model2)
    local scale_f = reconstruct.scale
    local image_f = reconstruct.image
    local detail_fp = nil
+   local am = nil
+   local model1_am = 0
+   local model2_am = 0
+
+   if opt.method == "user" or opt.method == "diff" then
+      am = create_metric(opt.metric)
+   end
    if opt.save_info then
       detail_fp = io.open(path.join(opt.output_dir, "benchmark_details.txt"), "w")
    end
@@ -406,32 +440,57 @@ local function benchmark(opt, x, model1, model2)
          ground_truth = remove_border(ground_truth, opt.border)
          model1_output = remove_border(model1_output, opt.border)
       end
-      mse1 = MSE(ground_truth, model1_output, opt.color)
-      model1_mse = model1_mse + mse1
-      model1_psnr = model1_psnr + MSE2PSNR(mse1)
-
+      if am then
+         am1 = am.func(ground_truth, model1_output)
+         model1_am = model1_am + am1
+      else
+         mse1 = MSE(ground_truth, model1_output, opt.color)
+         model1_mse = model1_mse + mse1
+         model1_psnr = model1_psnr + MSE2PSNR(mse1)
+      end
       local won_model = 1
       if model2 then
         if opt.border > 0 then
            model2_output = remove_border(model2_output, opt.border)
        end
-       mse2 = MSE(ground_truth, model2_output, opt.color)
-       model2_mse = model2_mse + mse2
-       model2_psnr = model2_psnr + MSE2PSNR(mse2)
-
-       if mse1 < mse2 then
-          won[1] = won[1] + 1
-       elseif mse1 > mse2 then
-          won[2] = won[2] + 1
-          won_model = 2
+       if am then
+          am2 = am.func(ground_truth, model2_output)
+          model2_am = model2_am + am2
+       else
+          mse2 = MSE(ground_truth, model2_output, opt.color)
+          model2_mse = model2_mse + mse2
+          model2_psnr = model2_psnr + MSE2PSNR(mse2)
+       end
+       if am then
+          if am1 < am2 then
+             won[1] = won[1] + 1
+          elseif am1 > am2 then
+             won[2] = won[2] + 1
+             won_model = 2
+          end
+       else
+          if mse1 < mse2 then
+             won[1] = won[1] + 1
+          elseif mse1 > mse2 then
+             won[2] = won[2] + 1
+             won_model = 2
+          end
        end
       if detail_fp then
-         detail_fp:write(string.format("%s,%f,%f,%d\n", x[i].basename,
-                                       MSE2PSNR(mse1), MSE2PSNR(mse2), won_model))
+         if am then
+            detail_fp:write(string.format("%s,%f,%f,%d\n", x[i].basename, am1, am2, won_model))
+         else
+            detail_fp:write(string.format("%s,%f,%f,%d\n", x[i].basename,
+                                          MSE2PSNR(mse1), MSE2PSNR(mse2), won_model))
+         end
      end
    else
      if detail_fp then
-        detail_fp:write(string.format("%s,%f\n", x[i].basename, MSE2PSNR(mse1)))
+        if am then
+           detail_fp:write(string.format("%s,%f\n", x[i].basename, am1))
+        else
+           detail_fp:write(string.format("%s,%f\n", x[i].basename, MSE2PSNR(mse1)))
+        end
     end
  end
  if baseline_output then
@@ -455,46 +514,65 @@ local function benchmark(opt, x, model1, model2)
          end
       end
       if opt.show_progress or i == #x then
-         if model2 then
-            if baseline_output then
+         if am then
+            if model2 then
                io.stdout:write(
-                  string.format("%d/%d; model1_time=%.2f, model2_time=%.2f, baseline_rmse=%.3f, model1_rmse=%.3f, model2_rmse=%.3f, baseline_psnr=%.3f, model1_psnr=%.3f, model2_psnr=%.3f, model1_won=%d, model2_won=%d \r",
+                  string.format("%d/%d; model1_time=%.2f, model2_time=%.2f, model1_%s=%.3f, model2_%s=%.3f \r",
                                 i, #x,
                                 model1_time,
                                 model2_time,
-                                math.sqrt(baseline_mse / i),
-                                math.sqrt(model1_mse / i), math.sqrt(model2_mse / i),
-                                baseline_psnr / i,
-                                model1_psnr / i, model2_psnr / i,
-                                won[1], won[2]
-                  ))
+                                am.name, model1_am / i, am.name, model2_am / i
+                  ))
             else
                io.stdout:write(
-                  string.format("%d/%d; model1_time=%.2f, model2_time=%.2f, model1_rmse=%.3f, model2_rmse=%.3f, model1_psnr=%.3f, model2_psnr=%.3f, model1_own=%d, model2_won=%d \r",
+                  string.format("%d/%d; model1_time=%.2f, model1_%s=%.3f \r",
                                 i, #x,
                                 model1_time,
-                                model2_time,
-                                math.sqrt(model1_mse / i), math.sqrt(model2_mse / i),
-                                model1_psnr / i, model2_psnr / i,
-                                won[1], won[2]
-                  ))
+                                am.name, model1_am / i
+                  ))
             end
          else
-            if baseline_output then
-               io.stdout:write(
-                  string.format("%d/%d; model1_time=%.2f, baseline_rmse=%.3f, model1_rmse=%.3f, baseline_psnr=%.3f, model1_psnr=%.3f \r",
-                                i, #x,
-                                model1_time,
-                                math.sqrt(baseline_mse / i), math.sqrt(model1_mse / i),
-                                baseline_psnr / i, model1_psnr / i
+            if model2 then
+               if baseline_output then
+                  io.stdout:write(
+                     string.format("%d/%d; model1_time=%.2f, model2_time=%.2f, baseline_rmse=%.3f, model1_rmse=%.3f, model2_rmse=%.3f, baseline_psnr=%.3f, model1_psnr=%.3f, model2_psnr=%.3f, model1_won=%d, model2_won=%d \r",
+                                   i, #x,
+                                   model1_time,
+                                   model2_time,
+                                   math.sqrt(baseline_mse / i),
+                                   math.sqrt(model1_mse / i), math.sqrt(model2_mse / i),
+                                   baseline_psnr / i,
+                                   model1_psnr / i, model2_psnr / i,
+                                   won[1], won[2]
+                     ))
+               else
+                  io.stdout:write(
+                     string.format("%d/%d; model1_time=%.2f, model2_time=%.2f, model1_rmse=%.3f, model2_rmse=%.3f, model1_psnr=%.3f, model2_psnr=%.3f, model1_own=%d, model2_won=%d \r",
+                                   i, #x,
+                                   model1_time,
+                                   model2_time,
+                                   math.sqrt(model1_mse / i), math.sqrt(model2_mse / i),
+                                   model1_psnr / i, model2_psnr / i,
+                                   won[1], won[2]
+                     ))
+               end
+            else
+               if baseline_output then
+                  io.stdout:write(
+                     string.format("%d/%d; model1_time=%.2f, baseline_rmse=%.3f, model1_rmse=%.3f, baseline_psnr=%.3f, model1_psnr=%.3f \r",
+                                   i, #x,
+                                   model1_time,
+                                   math.sqrt(baseline_mse / i), math.sqrt(model1_mse / i),
+                                   baseline_psnr / i, model1_psnr / i
                   ))
-            else
-               io.stdout:write(
-                  string.format("%d/%d; model1_time=%.2f, model1_rmse=%.3f, model1_psnr=%.3f \r",
-                                i, #x,
-                                model1_time,
-                                math.sqrt(model1_mse / i), model1_psnr / i
-                  ))
+               else
+                  io.stdout:write(
+                     string.format("%d/%d; model1_time=%.2f, model1_rmse=%.3f, model1_psnr=%.3f \r",
+                                   i, #x,
+                                   model1_time,
+                                   math.sqrt(model1_mse / i), model1_psnr / i
+                     ))
+               end
+            end
          end
      end
     io.stdout:flush()
@@ -515,6 +593,14 @@ local function benchmark(opt, x, model1, model2)
       fp:write(string.format("model2 : RMSE = %.3f, PSNR = %.3f, evaluation time = %.3f\n",
                              math.sqrt(model2_mse / #x), model2_psnr / #x, model2_time))
    end
+   if model1_am > 0 then
+      fp:write(string.format("model1 : %s = %.3f, evaluation time = %.3f\n",
+                             am.name, math.sqrt(model1_am / #x), model1_time))
+   end
+   if model2_am > 0 then
+      fp:write(string.format("model2 : %s = %.3f, evaluation time = %.3f\n",
+                             am.name, math.sqrt(model2_am / #x), model2_time))
+   end
    fp:close()
    if detail_fp then
       detail_fp:close()
train.lua (124 changed lines)
@@ -29,17 +29,57 @@ local function save_test_user(model, rgb, file)
    end
 end
 local function split_data(x, test_size)
-   local index = torch.randperm(#x)
-   local train_size = #x - test_size
-   local train_x = {}
-   local valid_x = {}
-   for i = 1, train_size do
-      train_x[i] = x[index[i]]
-   end
-   for i = 1, test_size do
-      valid_x[i] = x[index[train_size + i]]
-   end
-   return train_x, valid_x
+   if settings.validation_filename_split then
+      if not (x[1][2].data and x[1][2].data.basename) then
+         error("`images.t` does not have basename info. You need to re-run `convert_data.lua`.")
+      end
+      local basename_db = {}
+      for i = 1, #x do
+         local meta = x[i][2].data
+         if basename_db[meta.basename] then
+            table.insert(basename_db[meta.basename], x[i])
+         else
+            basename_db[meta.basename] = {x[i]}
+         end
+      end
+      local basename_list = {}
+      for k, v in pairs(basename_db) do
+         table.insert(basename_list, v)
+      end
+      local index = torch.randperm(#basename_list)
+      local train_x = {}
+      local valid_x = {}
+      local pos = 1
+      for i = 1, #basename_list do
+         if #valid_x >= test_size then
+            break
+         end
+         local xs = basename_list[index[pos]]
+         for j = 1, #xs do
+            table.insert(valid_x, xs[j])
+         end
+         pos = pos + 1
+      end
+      for i = pos, #basename_list do
+         local xs = basename_list[index[i]]
+         for j = 1, #xs do
+            table.insert(train_x, xs[j])
+         end
+      end
+      return train_x, valid_x
+   else
+      local index = torch.randperm(#x)
+      local train_size = #x - test_size
+      local train_x = {}
+      local valid_x = {}
+      for i = 1, train_size do
+         train_x[i] = x[index[i]]
+      end
+      for i = 1, test_size do
+         valid_x[i] = x[index[train_size + i]]
+      end
+      return train_x, valid_x
+   end
 end
 
 local g_transform_pool = nil
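Grouping by basename keeps every entry derived from the same source file on one side of the train/validation split, so the model is never validated on near-duplicates of its training data. A toy illustration of the grouping (hypothetical entries, following the {patch, {data = {...}}} layout used above):

   local x = {
      {"p1", {data = {basename = "a.png"}}},
      {"p2", {data = {basename = "a.png"}}},
      {"p3", {data = {basename = "b.png"}}},
   }
   -- basename_db becomes {["a.png"] = {entries 1, 2}, ["b.png"] = {entry 3}};
   -- the validation set receives whole groups, so data from a.png can
   -- never appear in both train_x and valid_x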
@@ -175,35 +215,19 @@ local function transform_pool_init(has_resize, offset)
                                           settings.crop_size, offset,
                                           n, conf)
       elseif settings.method == "user" then
-         if is_validation == nil then is_validation = false end
-         local rotate_rate = nil
-         local scale_rate = nil
-         local negate_rate = nil
-         local negate_x_rate = nil
-         if is_validation then
-            rotate_rate = 0
-            scale_rate = 0
-            negate_rate = 0
-            negate_x_rate = 0
-         else
-            rotate_rate = settings.random_pairwise_rotate_rate
-            scale_rate = settings.random_pairwise_scale_rate
-            negate_rate = settings.random_pairwise_negate_rate
-            negate_x_rate = settings.random_pairwise_negate_x_rate
-         end
          local conf = tablex.update({
                                        gcn = settings.gcn,
                                        max_size = settings.max_size,
                                        active_cropping_rate = active_cropping_rate,
                                        active_cropping_tries = active_cropping_tries,
-                                       random_pairwise_rotate_rate = rotate_rate,
+                                       random_pairwise_rotate_rate = settings.random_pairwise_rotate_rate,
                                        random_pairwise_rotate_min = settings.random_pairwise_rotate_min,
                                        random_pairwise_rotate_max = settings.random_pairwise_rotate_max,
-                                       random_pairwise_scale_rate = scale_rate,
+                                       random_pairwise_scale_rate = settings.random_pairwise_scale_rate,
                                        random_pairwise_scale_min = settings.random_pairwise_scale_min,
                                        random_pairwise_scale_max = settings.random_pairwise_scale_max,
-                                       random_pairwise_negate_rate = negate_rate,
-                                       random_pairwise_negate_x_rate = negate_x_rate,
+                                       random_pairwise_negate_rate = settings.random_pairwise_negate_rate,
+                                       random_pairwise_negate_x_rate = settings.random_pairwise_negate_x_rate,
+                                       pairwise_y_binary = settings.pairwise_y_binary,
+                                       pairwise_flip = settings.pairwise_flip,
                                        rgb = (settings.color == "rgb")}, meta)
@@ -290,7 +314,7 @@ local function validate(model, criterion, eval_metric, data, batch_size)
       local batch_mse = eval_metric:forward(z, targets)
       loss = loss + criterion:forward(z, targets)
       mse = mse + batch_mse
-      psnr = psnr + (10 * math.log10(1 / batch_mse))
+      psnr = psnr + (10 * math.log10(1 / (batch_mse + 1.0e-6)))
       loss_count = loss_count + 1
       if loss_count % 10 == 0 then
         xlua.progress(t, #data)
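The added epsilon keeps the PSNR average finite when a validation batch is reproduced exactly: batch_mse == 0 previously produced log10(1/0) = inf, which poisons the accumulated psnr. With the epsilon, the contribution of a perfect batch is capped:

   print(10 * math.log10(1 / 0))             -- inf
   print(10 * math.log10(1 / (0 + 1.0e-6)))  -- 60 (dB cap per batch)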
@@ -322,6 +346,10 @@ local function create_criterion(model)
       return w2nn.L1Criterion():cuda()
    elseif settings.loss == "mse" then
       return w2nn.ClippedMSECriterion(0, 1.0):cuda()
+   elseif settings.loss == "bce" then
+      local bce = nn.BCECriterion()
+      bce.sizeAverage = true
+      return bce:cuda()
    else
       error("unsupported loss .." .. settings.loss)
    end
@@ -421,7 +449,10 @@ local function plot(train, valid)
                {'validation', torch.Tensor(valid), '-'}})
 end
 local function train()
-   local x = remove_small_image(torch.load(settings.images))
+   local x = torch.load(settings.images)
+   if settings.method ~= "user" then
+      x = remove_small_image(x)
+   end
    local train_x, valid_x = split_data(x, math.max(math.floor(settings.validation_rate * #x), 1))
    local hist_train = {}
    local hist_valid = {}
@@ -429,7 +460,12 @@ local function train()
    if settings.resume:len() > 0 then
       model = torch.load(settings.resume, "ascii")
    else
-      model = srcnn.create(settings.model, settings.backend, settings.color)
+      if stringx.endswith(settings.model, ".lua") then
+         local create_model = dofile(settings.model)
+         model = create_model(srcnn, settings)
+      else
+         model = srcnn.create(settings.model, settings.backend, settings.color)
+      end
    end
    if model.w2nn_input_size then
       if settings.crop_size ~= model.w2nn_input_size then
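Besides the built-in architecture names, `-model` now accepts a path ending in .lua; the file is loaded with dofile and must return a function that receives the srcnn module table and the parsed settings and returns the network. A hypothetical my_model.lua sketch (the file name and layer sizes are illustrative; any attributes the trainer may expect on the returned module beyond the network itself are not shown in this diff):

   -- my_model.lua, used as: th train.lua -model my_model.lua ...
   -- assumes the trainer's environment (nn, cunn) is already loaded
   return function(srcnn, settings)
      local backend = settings.backend
      local model = nn.Sequential()
      model:add(srcnn.SpatialConvolution(backend, 3, 32, 3, 3, 1, 1, 0, 0))
      model:add(srcnn.ReLU(backend))
      model:add(srcnn.SpatialConvolution(backend, 32, 3, 3, 3, 1, 1, 0, 0))
      return model
   end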
@@ -484,8 +520,9 @@ local function train()
                               ch, settings.crop_size, settings.crop_size)
    end
    local instance_loss = nil
+   local pmodel = w2nn.data_parallel(model, settings.gpu)
    for epoch = 1, settings.epoch do
-      model:training()
+      pmodel:training()
       print("# " .. epoch)
       if adam_config.learningRate then
         print("learning rate: " .. adam_config.learningRate)
@@ -523,13 +560,13 @@ local function train()
          instance_loss = torch.Tensor(x:size(1)):zero()
 
       for i = 1, settings.inner_epoch do
-         model:training()
-         local train_score, il = minibatch_adam(model, criterion, eval_metric, x, y, adam_config)
+         pmodel:training()
+         local train_score, il = minibatch_adam(pmodel, criterion, eval_metric, x, y, adam_config)
          instance_loss:copy(il)
         print(train_score)
-        model:evaluate()
+        pmodel:evaluate()
        print("# validation")
-       local score = validate(model, criterion, eval_metric, valid_xy, adam_config.xBatchSize)
+       local score = validate(pmodel, criterion, eval_metric, valid_xy, adam_config.xBatchSize)
       table.insert(hist_train, train_score.loss)
       table.insert(hist_valid, score.loss)
       if settings.plot then
@@ -546,8 +583,9 @@ local function train()
             best_score = score_for_update
             print("* model has updated")
             if settings.save_history then
-               torch.save(settings.model_file_best, model:clearState(), "ascii")
-               torch.save(string.format(settings.model_file, epoch, i), model:clearState(), "ascii")
+               pmodel:clearState()
+               torch.save(settings.model_file_best, model, "ascii")
+               torch.save(string.format(settings.model_file, epoch, i), model, "ascii")
                if settings.method == "noise" then
                   local log = path.join(settings.model_dir,
                                         ("noise%d_best.%d-%d.png"):format(settings.noise_level,
@@ -571,7 +609,8 @@ local function train()
                   save_test_user(model, test_image, log)
                end
             else
-               torch.save(settings.model_file, model:clearState(), "ascii")
+               pmodel:clearState()
+               torch.save(settings.model_file, model, "ascii")
                if settings.method == "noise" then
                   local log = path.join(settings.model_dir,
                                         ("noise%d_best.png"):format(settings.noise_level))
@@ -597,9 +636,6 @@ local function train()
       end
    end
 end
-if settings.gpu > 0 then
-   cutorch.setDevice(settings.gpu)
-end
 torch.manualSeed(settings.seed)
 cutorch.manualSeed(settings.seed)
 print(settings)
@@ -276,6 +276,7 @@ local function waifu2x()
    if opt.thread > 0 then
       torch.setnumthreads(opt.thread)
    end
+   cutorch.setDevice(opt.gpu)
    if cudnn then
       cudnn.fastest = true
       if opt.l:len() > 0 then
@@ -293,6 +294,5 @@ local function waifu2x()
    else
       convert_frames(opt)
    end
-   cutorch.setDevice(opt.gpu)
 end
 waifu2x()