merge
This commit is contained in:
commit
fa6ee00624
|
@ -63,7 +63,24 @@ local function crop_if_large_pair(x, y, max_size)
|
||||||
return x, y
|
return x, y
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
-- Pad image `x` by `pad` pixels on all four sides using iproc.padding.
-- When `pad` is not greater than zero the image is returned unchanged.
local function padding_x(x, pad)
   if not (pad > 0) then
      return x
   end
   local padded = iproc.padding(x, pad, pad, pad, pad)
   return padded
end
|
||||||
|
-- Pad an (x, y) training pair consistently.
--   x      : input image, padded by `pad` on every side with iproc.padding
--   y      : target image, padded by `pad * (y:size(2) / x:size(2))` on every
--            side so the border matches when y is larger than x
--   pad    : padding for x; nothing is padded unless pad > 0
--   y_zero : truthy -> y uses iproc.zero_padding, otherwise iproc.padding
-- Returns the (possibly padded) x and y.
local function padding_xy(x, y, pad, y_zero)
   local ratio = y:size(2) / x:size(2)
   if not (pad > 0) then
      return x, y
   end

   local y_pad = pad * ratio
   local padded_x = iproc.padding(x, pad, pad, pad, pad)
   local padded_y
   if y_zero then
      padded_y = iproc.zero_padding(y, y_pad, y_pad, y_pad, y_pad)
   else
      padded_y = iproc.padding(y, y_pad, y_pad, y_pad, y_pad)
   end
   return padded_x, padded_y
end
|
||||||
local function load_images(list)
|
local function load_images(list)
|
||||||
local MARGIN = 32
|
local MARGIN = 32
|
||||||
local csv = csvigo.load({path = list, verbose = false, mode = "raw"})
|
local csv = csvigo.load({path = list, verbose = false, mode = "raw"})
|
||||||
|
@ -78,6 +95,7 @@ local function load_images(list)
|
||||||
if csv_meta and csv_meta.filters then
|
if csv_meta and csv_meta.filters then
|
||||||
filters = csv_meta.filters
|
filters = csv_meta.filters
|
||||||
end
|
end
|
||||||
|
local basename_y = path.basename(filename)
|
||||||
local im, meta = image_loader.load_byte(filename)
|
local im, meta = image_loader.load_byte(filename)
|
||||||
local skip = false
|
local skip = false
|
||||||
local alpha_color = torch.random(0, 1)
|
local alpha_color = torch.random(0, 1)
|
||||||
|
@ -100,25 +118,38 @@ local function load_images(list)
|
||||||
-- method == user
|
-- method == user
|
||||||
local yy = im
|
local yy = im
|
||||||
local xx, meta2 = image_loader.load_byte(csv_meta.x)
|
local xx, meta2 = image_loader.load_byte(csv_meta.x)
|
||||||
|
if settings.invert_x then
|
||||||
|
xx = (-(xx:long()) + 255):byte()
|
||||||
|
end
|
||||||
|
|
||||||
if xx then
|
if xx then
|
||||||
if meta2 and meta2.alpha then
|
if meta2 and meta2.alpha then
|
||||||
xx = alpha_util.fill(xx, meta2.alpha, alpha_color)
|
xx = alpha_util.fill(xx, meta2.alpha, alpha_color)
|
||||||
end
|
end
|
||||||
xx, yy = crop_if_large_pair(xx, yy, settings.max_training_image_size)
|
xx, yy = crop_if_large_pair(xx, yy, settings.max_training_image_size)
|
||||||
|
xx, yy = padding_xy(xx, yy, settings.padding, settings.padding_y_zero)
|
||||||
|
if settings.grayscale then
|
||||||
|
xx = iproc.rgb2y(xx)
|
||||||
|
yy = iproc.rgb2y(yy)
|
||||||
|
end
|
||||||
table.insert(x, {{y = compression.compress(yy), x = compression.compress(xx)},
|
table.insert(x, {{y = compression.compress(yy), x = compression.compress(xx)},
|
||||||
{data = {filters = filters, has_x = true}}})
|
{data = {filters = filters, has_x = true, basename = basename_y}}})
|
||||||
else
|
else
|
||||||
io.stderr:write(string.format("\n%s: skip: load error.\n", csv_meta.x))
|
io.stderr:write(string.format("\n%s: skip: load error.\n", csv_meta.x))
|
||||||
end
|
end
|
||||||
else
|
else
|
||||||
im = crop_if_large(im, settings.max_training_image_size)
|
im = crop_if_large(im, settings.max_training_image_size)
|
||||||
im = iproc.crop_mod4(im)
|
im = iproc.crop_mod4(im)
|
||||||
|
im = padding_x(im, settings.padding)
|
||||||
local scale = 1.0
|
local scale = 1.0
|
||||||
if settings.random_half_rate > 0.0 then
|
if settings.random_half_rate > 0.0 then
|
||||||
scale = 2.0
|
scale = 2.0
|
||||||
end
|
end
|
||||||
if im:size(2) > (settings.crop_size * scale + MARGIN) and im:size(3) > (settings.crop_size * scale + MARGIN) then
|
if im:size(2) > (settings.crop_size * scale + MARGIN) and im:size(3) > (settings.crop_size * scale + MARGIN) then
|
||||||
table.insert(x, {compression.compress(im), {data = {filters = filters}}})
|
if settings.grayscale then
|
||||||
|
im = iproc.rgb2y(im)
|
||||||
|
end
|
||||||
|
table.insert(x, {compression.compress(im), {data = {filters = filters, basename = basename_y}}})
|
||||||
else
|
else
|
||||||
io.stderr:write(string.format("\n%s: skip: image is too small (%d > size).\n", filename, settings.crop_size * scale + MARGIN))
|
io.stderr:write(string.format("\n%s: skip: image is too small (%d > size).\n", filename, settings.crop_size * scale + MARGIN))
|
||||||
end
|
end
|
||||||
|
|
42
lib/ShakeShakeTable.lua
Normal file
42
lib/ShakeShakeTable.lua
Normal file
|
@ -0,0 +1,42 @@
|
||||||
|
-- w2nn.ShakeShakeTable
-- Merges a table of two tensors {input[1], input[2]} into one tensor using
-- per-sample random weights while training, and a plain sum at test time.
-- The first dimension of the tensors is treated as the batch dimension.
local ShakeShakeTable, parent = torch.class('w2nn.ShakeShakeTable','nn.Module')

-- Allocates the per-sample coefficient buffers (alpha for forward, beta for
-- backward) and two scratch tensors reused between calls.
function ShakeShakeTable:__init()
   parent.__init(self)
   self.alpha = torch.Tensor()
   self.beta = torch.Tensor()
   self.first = torch.Tensor()
   self.second = torch.Tensor()
   self.train = true
end

-- Forward pass.
--   training: output[i] = 2 * ((1 - alpha[i]) * input[1][i] + alpha[i] * input[2][i])
--             with alpha[i] drawn uniformly per sample; beta is drawn here too
--             so that updateGradInput can use it.
--   eval:     output = input[1] + input[2]
function ShakeShakeTable:updateOutput(input)
   local batch_size = input[1]:size(1)
   if self.train then
      self.alpha:resize(batch_size):uniform()
      self.beta:resize(batch_size):uniform()
      self.second:resizeAs(input[1]):copy(input[2])
      for i = 1, batch_size do
         self.second[i]:mul(self.alpha[i])
      end
      self.output:resizeAs(input[1]):copy(input[1])
      for i = 1, batch_size do
         self.output[i]:mul(1.0 - self.alpha[i])
      end
      self.output:add(self.second):mul(2)
   else
      self.output:resizeAs(input[1]):copy(input[1]):add(input[2])
   end
   return self.output
end

-- Backward pass.
-- Splits gradOutput between the two branches with the per-sample beta drawn
-- in the last training-mode updateOutput:
--   gradInput[1][i] = beta[i] * gradOutput[i]
--   gradInput[2][i] = (1 - beta[i]) * gradOutput[i]
-- NOTE(review): beta is only refreshed in training mode; calling this after an
-- eval-mode forward reuses whatever beta holds -- confirm that never happens.
function ShakeShakeTable:updateGradInput(input, gradOutput)
   local batch_size = input[1]:size(1)
   self.first:resizeAs(gradOutput):copy(gradOutput)
   for i = 1, batch_size do
      self.first[i]:mul(self.beta[i])
   end
   self.second:resizeAs(gradOutput):copy(gradOutput)
   for i = 1, batch_size do
      self.second[i]:mul(1.0 - self.beta[i])
   end
   -- The result must live in self.gradInput: nn.Module:backward() returns
   -- self.gradInput, not the value returned by updateGradInput.
   self.gradInput = {self.first, self.second}
   return self.gradInput
end
|
|
@ -102,7 +102,9 @@ function data_augmentation.pairwise_scale(x, y, p, scale_min, scale_max)
|
||||||
local scale = torch.uniform(scale_min, scale_max)
|
local scale = torch.uniform(scale_min, scale_max)
|
||||||
local h = math.floor(x:size(2) * scale)
|
local h = math.floor(x:size(2) * scale)
|
||||||
local w = math.floor(x:size(3) * scale)
|
local w = math.floor(x:size(3) * scale)
|
||||||
x = iproc.scale(x, w, h, "Triangle")
|
local filters = {"Lanczos", "Catrom"}
|
||||||
|
local x_filter = filters[torch.random(1, 2)]
|
||||||
|
x = iproc.scale(x, w, h, x_filter)
|
||||||
y = iproc.scale(y, w, h, "Triangle")
|
y = iproc.scale(y, w, h, "Triangle")
|
||||||
return x, y
|
return x, y
|
||||||
else
|
else
|
||||||
|
@ -139,6 +141,36 @@ function data_augmentation.pairwise_negate_x(x, y, p)
|
||||||
return x, y
|
return x, y
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
-- Apply one random transpose/flip combination identically to x and y.
-- Two random draws are made up front (flip mode in 1..4, transpose mode in
-- 1..2). Both images go through iproc.byte2float; if that reports a conversion
-- for x, both results are converted back with iproc.float2byte before return.
--   transpose mode 2 : swap dimensions 2 and 3
--   flip mode 1      : hflip
--   flip mode 2      : vflip
--   flip mode 3      : vflip followed by hflip
--   flip mode 4      : no flip
function data_augmentation.pairwise_flip(x, y)
   local flip_mode = torch.random(1, 4)
   local transpose_mode = torch.random(1, 2)

   local fx, was_converted = iproc.byte2float(x)
   local fy = iproc.byte2float(y)
   fx = fx:contiguous()
   fy = fy:contiguous()

   if transpose_mode == 2 then
      fx = fx:transpose(2, 3):contiguous()
      fy = fy:transpose(2, 3):contiguous()
   end

   if flip_mode == 1 then
      fx = iproc.hflip(fx)
      fy = iproc.hflip(fy)
   elseif flip_mode == 2 then
      fx = iproc.vflip(fx)
      fy = iproc.vflip(fy)
   elseif flip_mode == 3 then
      fx = iproc.hflip(iproc.vflip(fx))
      fy = iproc.hflip(iproc.vflip(fy))
   end

   if was_converted then
      fx = iproc.float2byte(fx)
      fy = iproc.float2byte(fy)
   end
   return fx, fy
end
|
||||||
function data_augmentation.shift_1px(src)
|
function data_augmentation.shift_1px(src)
|
||||||
-- reducing the even/odd issue in nearest neighbor scaler.
|
-- reducing the even/odd issue in nearest neighbor scaler.
|
||||||
local direction = torch.random(1, 4)
|
local direction = torch.random(1, 4)
|
||||||
|
|
|
@ -80,6 +80,8 @@ function iproc.scale_with_gamma22(src, width, height, filter, blur)
|
||||||
return dest
|
return dest
|
||||||
end
|
end
|
||||||
function iproc.padding(img, w1, w2, h1, h2)
|
function iproc.padding(img, w1, w2, h1, h2)
|
||||||
|
local conversion
|
||||||
|
img, conversion = iproc.byte2float(img)
|
||||||
image = image or require 'image'
|
image = image or require 'image'
|
||||||
local dst_height = img:size(2) + h1 + h2
|
local dst_height = img:size(2) + h1 + h2
|
||||||
local dst_width = img:size(3) + w1 + w2
|
local dst_width = img:size(3) + w1 + w2
|
||||||
|
@ -88,9 +90,15 @@ function iproc.padding(img, w1, w2, h1, h2)
|
||||||
flow[2] = torch.ger(torch.ones(dst_height), torch.linspace(0, dst_width - 1, dst_width))
|
flow[2] = torch.ger(torch.ones(dst_height), torch.linspace(0, dst_width - 1, dst_width))
|
||||||
flow[1]:add(-h1)
|
flow[1]:add(-h1)
|
||||||
flow[2]:add(-w1)
|
flow[2]:add(-w1)
|
||||||
return image.warp(img, flow, "simple", false, "clamp")
|
local dest = image.warp(img, flow, "simple", false, "clamp")
|
||||||
|
if conversion then
|
||||||
|
dest = iproc.float2byte(dest)
|
||||||
|
end
|
||||||
|
return dest
|
||||||
end
|
end
|
||||||
function iproc.zero_padding(img, w1, w2, h1, h2)
|
function iproc.zero_padding(img, w1, w2, h1, h2)
|
||||||
|
local conversion
|
||||||
|
img, conversion = iproc.byte2float(img)
|
||||||
image = image or require 'image'
|
image = image or require 'image'
|
||||||
local dst_height = img:size(2) + h1 + h2
|
local dst_height = img:size(2) + h1 + h2
|
||||||
local dst_width = img:size(3) + w1 + w2
|
local dst_width = img:size(3) + w1 + w2
|
||||||
|
@ -99,7 +107,11 @@ function iproc.zero_padding(img, w1, w2, h1, h2)
|
||||||
flow[2] = torch.ger(torch.ones(dst_height), torch.linspace(0, dst_width - 1, dst_width))
|
flow[2] = torch.ger(torch.ones(dst_height), torch.linspace(0, dst_width - 1, dst_width))
|
||||||
flow[1]:add(-h1)
|
flow[1]:add(-h1)
|
||||||
flow[2]:add(-w1)
|
flow[2]:add(-w1)
|
||||||
return image.warp(img, flow, "simple", false, "pad", 0)
|
local dest = image.warp(img, flow, "simple", false, "pad", 0)
|
||||||
|
if conversion then
|
||||||
|
dest = iproc.float2byte(dest)
|
||||||
|
end
|
||||||
|
return dest
|
||||||
end
|
end
|
||||||
function iproc.white_noise(src, std, rgb_weights, gamma)
|
function iproc.white_noise(src, std, rgb_weights, gamma)
|
||||||
gamma = gamma or 0.454545
|
gamma = gamma or 0.454545
|
||||||
|
@ -217,6 +229,7 @@ function iproc.rgb2y(src)
|
||||||
src, conversion = iproc.byte2float(src)
|
src, conversion = iproc.byte2float(src)
|
||||||
local dest = torch.FloatTensor(1, src:size(2), src:size(3)):zero()
|
local dest = torch.FloatTensor(1, src:size(2), src:size(3)):zero()
|
||||||
dest:add(0.299, src[1]):add(0.587, src[2]):add(0.114, src[3])
|
dest:add(0.299, src[1]):add(0.587, src[2]):add(0.114, src[3])
|
||||||
|
dest:clamp(0, 1)
|
||||||
if conversion then
|
if conversion then
|
||||||
dest = iproc.float2byte(dest)
|
dest = iproc.float2byte(dest)
|
||||||
end
|
end
|
||||||
|
|
|
@ -43,8 +43,10 @@ function pairwise_transform.jpeg_(src, quality, size, offset, n, options)
|
||||||
yc = iproc.byte2float(yc)
|
yc = iproc.byte2float(yc)
|
||||||
if options.rgb then
|
if options.rgb then
|
||||||
else
|
else
|
||||||
yc = iproc.rgb2y(yc)
|
if xc:size(1) > 1 then
|
||||||
xc = iproc.rgb2y(xc)
|
yc = iproc.rgb2y(yc)
|
||||||
|
xc = iproc.rgb2y(xc)
|
||||||
|
end
|
||||||
end
|
end
|
||||||
if torch.uniform() < options.nr_rate then
|
if torch.uniform() < options.nr_rate then
|
||||||
-- reducing noise
|
-- reducing noise
|
||||||
|
|
|
@ -51,8 +51,10 @@ function pairwise_transform.scale(src, scale, size, offset, n, options)
|
||||||
yc = iproc.byte2float(yc)
|
yc = iproc.byte2float(yc)
|
||||||
if options.rgb then
|
if options.rgb then
|
||||||
else
|
else
|
||||||
yc = iproc.rgb2y(yc)
|
if xc:size(1) > 1 then
|
||||||
xc = iproc.rgb2y(xc)
|
yc = iproc.rgb2y(yc)
|
||||||
|
xc = iproc.rgb2y(xc)
|
||||||
|
end
|
||||||
end
|
end
|
||||||
table.insert(batch, {xc, iproc.crop(yc, offset, offset, size - offset, size - offset)})
|
table.insert(batch, {xc, iproc.crop(yc, offset, offset, size - offset, size - offset)})
|
||||||
end
|
end
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
local pairwise_utils = require 'pairwise_transform_utils'
|
local pairwise_utils = require 'pairwise_transform_utils'
|
||||||
|
local data_augmentation = require 'data_augmentation'
|
||||||
local iproc = require 'iproc'
|
local iproc = require 'iproc'
|
||||||
local gm = {}
|
local gm = {}
|
||||||
gm.Image = require 'graphicsmagick.Image'
|
gm.Image = require 'graphicsmagick.Image'
|
||||||
|
@ -21,12 +22,15 @@ function pairwise_transform.user(x, y, size, offset, n, options)
|
||||||
if options.active_cropping_rate > 0 then
|
if options.active_cropping_rate > 0 then
|
||||||
lowres_y = pairwise_utils.low_resolution(y)
|
lowres_y = pairwise_utils.low_resolution(y)
|
||||||
end
|
end
|
||||||
if options.pairwise_flip then
|
if options.pairwise_flip and n == 1 then
|
||||||
|
xs[1], ys[1] = data_augmentation.pairwise_flip(xs[1], ys[1])
|
||||||
|
elseif options.pairwise_flip then
|
||||||
xs, ys, ls = pairwise_utils.flip_augmentation(x, y, lowres_y)
|
xs, ys, ls = pairwise_utils.flip_augmentation(x, y, lowres_y)
|
||||||
end
|
end
|
||||||
assert(#xs == #ys)
|
assert(#xs == #ys)
|
||||||
|
local perm = torch.randperm(#xs)
|
||||||
for i = 1, n do
|
for i = 1, n do
|
||||||
local t = (i % #xs) + 1
|
local t = perm[(i % #xs) + 1]
|
||||||
local xc, yc = pairwise_utils.active_cropping(xs[t], ys[t], ls[t], size, scale_y,
|
local xc, yc = pairwise_utils.active_cropping(xs[t], ys[t], ls[t], size, scale_y,
|
||||||
options.active_cropping_rate,
|
options.active_cropping_rate,
|
||||||
options.active_cropping_tries)
|
options.active_cropping_tries)
|
||||||
|
@ -34,8 +38,10 @@ function pairwise_transform.user(x, y, size, offset, n, options)
|
||||||
yc = iproc.byte2float(yc)
|
yc = iproc.byte2float(yc)
|
||||||
if options.rgb then
|
if options.rgb then
|
||||||
else
|
else
|
||||||
yc = iproc.rgb2y(yc)
|
if xc:size(1) > 1 then
|
||||||
xc = iproc.rgb2y(xc)
|
yc = iproc.rgb2y(yc)
|
||||||
|
xc = iproc.rgb2y(xc)
|
||||||
|
end
|
||||||
end
|
end
|
||||||
if options.gcn then
|
if options.gcn then
|
||||||
local mean = xc:mean()
|
local mean = xc:mean()
|
||||||
|
@ -46,7 +52,12 @@ function pairwise_transform.user(x, y, size, offset, n, options)
|
||||||
xc:add(-mean)
|
xc:add(-mean)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
table.insert(batch, {xc, iproc.crop(yc, offset, offset, size - offset, size - offset)})
|
yc = iproc.crop(yc, offset, offset, size - offset, size - offset)
|
||||||
|
if options.pairwise_y_binary then
|
||||||
|
yc[torch.lt(yc, 0.5)] = 0
|
||||||
|
yc[torch.gt(yc, 0)] = 1
|
||||||
|
end
|
||||||
|
table.insert(batch, {xc, yc})
|
||||||
end
|
end
|
||||||
|
|
||||||
return batch
|
return batch
|
||||||
|
|
|
@ -108,12 +108,6 @@ function pairwise_transform_utils.preprocess_user(x, y, scale_y, size, options)
|
||||||
|
|
||||||
x = iproc.crop_mod4(x)
|
x = iproc.crop_mod4(x)
|
||||||
y = iproc.crop_mod4(y)
|
y = iproc.crop_mod4(y)
|
||||||
|
|
||||||
if options.pairwise_y_binary then
|
|
||||||
y[torch.lt(y, 128)] = 0
|
|
||||||
y[torch.gt(y, 0)] = 255
|
|
||||||
end
|
|
||||||
|
|
||||||
return x, y
|
return x, y
|
||||||
end
|
end
|
||||||
function pairwise_transform_utils.active_cropping(x, y, lowres_y, size, scale, p, tries)
|
function pairwise_transform_utils.active_cropping(x, y, lowres_y, size, scale, p, tries)
|
||||||
|
@ -125,8 +119,14 @@ function pairwise_transform_utils.active_cropping(x, y, lowres_y, size, scale, p
|
||||||
t = "byte"
|
t = "byte"
|
||||||
end
|
end
|
||||||
if p < r then
|
if p < r then
|
||||||
local xi = torch.random(1, x:size(3) - (size + 1)) * scale
|
local xi = 0
|
||||||
local yi = torch.random(1, x:size(2) - (size + 1)) * scale
|
local yi = 0
|
||||||
|
if x:size(3) > size + 1 then
|
||||||
|
xi = torch.random(0, x:size(3) - (size + 1)) * scale
|
||||||
|
end
|
||||||
|
if x:size(2) > size + 1 then
|
||||||
|
yi = torch.random(0, x:size(2) - (size + 1)) * scale
|
||||||
|
end
|
||||||
local yc = iproc.crop(y, xi, yi, xi + size, yi + size)
|
local yc = iproc.crop(y, xi, yi, xi + size, yi + size)
|
||||||
local xc = iproc.crop(x, xi / scale, yi / scale, xi / scale + size / scale, yi / scale + size / scale)
|
local xc = iproc.crop(x, xi / scale, yi / scale, xi / scale + size / scale, yi / scale + size / scale)
|
||||||
return xc, yc
|
return xc, yc
|
||||||
|
@ -273,10 +273,17 @@ function pairwise_transform_utils.low_resolution(src)
|
||||||
toTensor("byte", "RGB", "DHW")
|
toTensor("byte", "RGB", "DHW")
|
||||||
end
|
end
|
||||||
--]]
|
--]]
|
||||||
return gm.Image(src, "RGB", "DHW"):
|
if src:size(1) == 1 then
|
||||||
size(src:size(3) * 0.5, src:size(2) * 0.5, "Box"):
|
return gm.Image(src, "I", "DHW"):
|
||||||
size(src:size(3), src:size(2), "Box"):
|
size(src:size(3) * 0.5, src:size(2) * 0.5, "Box"):
|
||||||
toTensor("byte", "RGB", "DHW")
|
size(src:size(3), src:size(2), "Box"):
|
||||||
|
toTensor("byte", "I", "DHW")
|
||||||
|
else
|
||||||
|
return gm.Image(src, "RGB", "DHW"):
|
||||||
|
size(src:size(3) * 0.5, src:size(2) * 0.5, "Box"):
|
||||||
|
size(src:size(3), src:size(2), "Box"):
|
||||||
|
toTensor("byte", "RGB", "DHW")
|
||||||
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
return pairwise_transform_utils
|
return pairwise_transform_utils
|
||||||
|
|
|
@ -18,7 +18,6 @@ local cmd = torch.CmdLine()
|
||||||
cmd:text()
|
cmd:text()
|
||||||
cmd:text("waifu2x-training")
|
cmd:text("waifu2x-training")
|
||||||
cmd:text("Options:")
|
cmd:text("Options:")
|
||||||
cmd:option("-gpu", -1, 'GPU Device ID')
|
|
||||||
cmd:option("-seed", 11, 'RNG seed (note: it only able to reproduce the training results with `-thread 1`)')
|
cmd:option("-seed", 11, 'RNG seed (note: it only able to reproduce the training results with `-thread 1`)')
|
||||||
cmd:option("-data_dir", "./data", 'path to data directory')
|
cmd:option("-data_dir", "./data", 'path to data directory')
|
||||||
cmd:option("-backend", "cunn", '(cunn|cudnn)')
|
cmd:option("-backend", "cunn", '(cunn|cudnn)')
|
||||||
|
@ -74,9 +73,14 @@ cmd:option("-oracle_drop_rate", 0.5, '')
|
||||||
cmd:option("-learning_rate_decay", 3.0e-7, 'learning rate decay (learning_rate * 1/(1+num_of_data*patches*epoch))')
|
cmd:option("-learning_rate_decay", 3.0e-7, 'learning rate decay (learning_rate * 1/(1+num_of_data*patches*epoch))')
|
||||||
cmd:option("-resume", "", 'resume model file')
|
cmd:option("-resume", "", 'resume model file')
|
||||||
cmd:option("-name", "user", 'model name for user method')
|
cmd:option("-name", "user", 'model name for user method')
|
||||||
cmd:option("-gpu", 1, 'Device ID')
|
cmd:option("-gpu", "", 'GPU Device ID or ID lists (comma seprated)')
|
||||||
cmd:option("-loss", "huber", 'loss function (huber|l1|mse)')
|
cmd:option("-loss", "huber", 'loss function (huber|l1|mse|bce)')
|
||||||
cmd:option("-update_criterion", "mse", 'mse|loss')
|
cmd:option("-update_criterion", "mse", 'mse|loss')
|
||||||
|
cmd:option("-padding", 0, 'replication padding size')
|
||||||
|
cmd:option("-padding_y_zero", 0, 'zero padding y for segmentation (0|1)')
|
||||||
|
cmd:option("-grayscale", 0, 'grayscale x&y (0|1)')
|
||||||
|
cmd:option("-validation_filename_split", 0, 'make validation-set based on filename(basename)')
|
||||||
|
cmd:option("-invert_x", 0, 'invert x image in convert_lua')
|
||||||
|
|
||||||
local function to_bool(settings, name)
|
local function to_bool(settings, name)
|
||||||
if settings[name] == 1 then
|
if settings[name] == 1 then
|
||||||
|
@ -95,6 +99,10 @@ to_bool(settings, "save_history")
|
||||||
to_bool(settings, "use_transparent_png")
|
to_bool(settings, "use_transparent_png")
|
||||||
to_bool(settings, "pairwise_y_binary")
|
to_bool(settings, "pairwise_y_binary")
|
||||||
to_bool(settings, "pairwise_flip")
|
to_bool(settings, "pairwise_flip")
|
||||||
|
to_bool(settings, "padding_y_zero")
|
||||||
|
to_bool(settings, "grayscale")
|
||||||
|
to_bool(settings, "validation_filename_split")
|
||||||
|
to_bool(settings, "invert_x")
|
||||||
|
|
||||||
if settings.plot then
|
if settings.plot then
|
||||||
require 'gnuplot'
|
require 'gnuplot'
|
||||||
|
@ -168,10 +176,20 @@ end
|
||||||
settings.images = string.format("%s/images.t7", settings.data_dir)
|
settings.images = string.format("%s/images.t7", settings.data_dir)
|
||||||
settings.image_list = string.format("%s/image_list.txt", settings.data_dir)
|
settings.image_list = string.format("%s/image_list.txt", settings.data_dir)
|
||||||
|
|
||||||
cutorch.setDevice(opt.gpu)
|
|
||||||
-- patch for lua52
|
-- patch for lua52
|
||||||
if not math.log10 then
|
if not math.log10 then
|
||||||
math.log10 = function(x) return math.log(x, 10) end
|
math.log10 = function(x) return math.log(x, 10) end
|
||||||
end
|
end
|
||||||
|
-- Turn the `-gpu` option (a comma separated string of device IDs, possibly
-- empty) into a list of numbers in settings.gpu and select the first device.
-- An empty option falls back to device 1.
if settings.gpu:len() > 0 then
   local gpus = {}
   local gpu_string = utils.split(settings.gpu, ",")
   for i = 1, #gpu_string do
      local id = tonumber(gpu_string[i])
      -- Tokens that are not numbers are skipped explicitly instead of being
      -- passed to table.insert as nil.
      if id then
         gpus[#gpus + 1] = id
      end
   end
   if #gpus == 0 then
      -- Fail here with a readable message rather than calling
      -- cutorch.setDevice(nil) below.
      error(string.format("invalid -gpu option: %s", settings.gpu))
   end
   settings.gpu = gpus
else
   settings.gpu = {1}
end
cutorch.setDevice(settings.gpu[1])
|
||||||
|
|
||||||
return settings
|
return settings
|
||||||
|
|
|
@ -4,34 +4,52 @@ require 'w2nn'
|
||||||
-- ref: http://arxiv.org/abs/1501.00092
|
-- ref: http://arxiv.org/abs/1501.00092
|
||||||
local srcnn = {}
|
local srcnn = {}
|
||||||
|
|
||||||
function nn.SpatialConvolutionMM:reset(stdv)
|
local function msra_filler(mod)
|
||||||
local fin = self.kW * self.kH * self.nInputPlane
|
local fin = mod.kW * mod.kH * mod.nInputPlane
|
||||||
local fout = self.kW * self.kH * self.nOutputPlane
|
local fout = mod.kW * mod.kH * mod.nOutputPlane
|
||||||
stdv = math.sqrt(4 / ((1.0 + 0.1 * 0.1) * (fin + fout)))
|
stdv = math.sqrt(4 / ((1.0 + 0.1 * 0.1) * (fin + fout)))
|
||||||
self.weight:normal(0, stdv)
|
mod.weight:normal(0, stdv)
|
||||||
self.bias:zero()
|
mod.bias:zero()
|
||||||
|
end
|
||||||
|
-- Initialise a convolution module so that it starts close to an identity map.
-- Weights are drawn from N(0, 0.01) and the bias is zeroed; then, with one
-- group per input plane, the centre tap of every (output, input) pair inside
-- the same group is set to nInputPlane / nOutputPlane.
-- Requires mod.nInputPlane <= mod.nOutputPlane (asserted).
local function identity_filler(mod)
   assert(mod.nInputPlane <= mod.nOutputPlane)
   mod.weight:normal(0, 0.01)
   mod.bias:zero()

   local groups = mod.nInputPlane -- fixed
   local center_value = groups / mod.nOutputPlane
   local ins_per_group = math.floor(mod.nInputPlane / groups)
   local outs_per_group = math.floor(mod.nOutputPlane / groups)
   -- 1-based coordinates of the kernel centre
   local cx = math.floor(mod.kW / 2) + 1
   local cy = math.floor(mod.kH / 2) + 1

   for g = 1, groups do
      local first_out = (g - 1) * outs_per_group + 1
      local first_in = (g - 1) * ins_per_group + 1
      for o = first_out, first_out + outs_per_group - 1 do
         for c = first_in, first_in + ins_per_group - 1 do
            mod.weight[o][c][cy][cx] = center_value
         end
      end
   end
end
|
||||||
|
function nn.SpatialConvolutionMM:reset(stdv)
|
||||||
|
msra_filler(self)
|
||||||
end
|
end
|
||||||
function nn.SpatialFullConvolution:reset(stdv)
|
function nn.SpatialFullConvolution:reset(stdv)
|
||||||
local fin = self.kW * self.kH * self.nInputPlane
|
msra_filler(self)
|
||||||
local fout = self.kW * self.kH * self.nOutputPlane
|
|
||||||
stdv = math.sqrt(4 / ((1.0 + 0.1 * 0.1) * (fin + fout)))
|
|
||||||
self.weight:normal(0, stdv)
|
|
||||||
self.bias:zero()
|
|
||||||
end
|
end
|
||||||
|
-- Override the default parameter reset of nn.SpatialDilatedConvolution:
-- the `stdv` argument is ignored and identity_filler initialises the module.
nn.SpatialDilatedConvolution.reset = function(self, stdv)
   identity_filler(self)
end
|
||||||
|
|
||||||
if cudnn and cudnn.SpatialConvolution then
|
if cudnn and cudnn.SpatialConvolution then
|
||||||
function cudnn.SpatialConvolution:reset(stdv)
|
function cudnn.SpatialConvolution:reset(stdv)
|
||||||
local fin = self.kW * self.kH * self.nInputPlane
|
msra_filler(self)
|
||||||
local fout = self.kW * self.kH * self.nOutputPlane
|
|
||||||
stdv = math.sqrt(4 / ((1.0 + 0.1 * 0.1) * (fin + fout)))
|
|
||||||
self.weight:normal(0, stdv)
|
|
||||||
self.bias:zero()
|
|
||||||
end
|
end
|
||||||
function cudnn.SpatialFullConvolution:reset(stdv)
|
function cudnn.SpatialFullConvolution:reset(stdv)
|
||||||
local fin = self.kW * self.kH * self.nInputPlane
|
msra_filler(self)
|
||||||
local fout = self.kW * self.kH * self.nOutputPlane
|
end
|
||||||
stdv = math.sqrt(4 / ((1.0 + 0.1 * 0.1) * (fin + fout)))
|
if cudnn.SpatialDilatedConvolution then
|
||||||
self.weight:normal(0, stdv)
|
function cudnn.SpatialDilatedConvolution:reset(stdv)
|
||||||
self.bias:zero()
|
identity_filler(self)
|
||||||
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
function nn.SpatialConvolutionMM:clearState()
|
function nn.SpatialConvolutionMM:clearState()
|
||||||
|
@ -127,6 +145,8 @@ local function SpatialConvolution(backend, nInputPlane, nOutputPlane, kW, kH, dW
|
||||||
error("unsupported backend:" .. backend)
|
error("unsupported backend:" .. backend)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
srcnn.SpatialConvolution = SpatialConvolution
|
||||||
|
|
||||||
local function SpatialFullConvolution(backend, nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, adjW, adjH)
|
local function SpatialFullConvolution(backend, nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, adjW, adjH)
|
||||||
if backend == "cunn" then
|
if backend == "cunn" then
|
||||||
return nn.SpatialFullConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, adjW, adjH)
|
return nn.SpatialFullConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, adjW, adjH)
|
||||||
|
@ -136,6 +156,8 @@ local function SpatialFullConvolution(backend, nInputPlane, nOutputPlane, kW, kH
|
||||||
error("unsupported backend:" .. backend)
|
error("unsupported backend:" .. backend)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
srcnn.SpatialFullConvolution = SpatialFullConvolution
|
||||||
|
|
||||||
local function ReLU(backend)
|
local function ReLU(backend)
|
||||||
if backend == "cunn" then
|
if backend == "cunn" then
|
||||||
return nn.ReLU(true)
|
return nn.ReLU(true)
|
||||||
|
@ -145,6 +167,8 @@ local function ReLU(backend)
|
||||||
error("unsupported backend:" .. backend)
|
error("unsupported backend:" .. backend)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
srcnn.ReLU = ReLU
|
||||||
|
|
||||||
local function SpatialMaxPooling(backend, kW, kH, dW, dH, padW, padH)
|
local function SpatialMaxPooling(backend, kW, kH, dW, dH, padW, padH)
|
||||||
if backend == "cunn" then
|
if backend == "cunn" then
|
||||||
return nn.SpatialMaxPooling(kW, kH, dW, dH, padW, padH)
|
return nn.SpatialMaxPooling(kW, kH, dW, dH, padW, padH)
|
||||||
|
@ -154,6 +178,35 @@ local function SpatialMaxPooling(backend, kW, kH, dW, dH, padW, padH)
|
||||||
error("unsupported backend:" .. backend)
|
error("unsupported backend:" .. backend)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
srcnn.SpatialMaxPooling = SpatialMaxPooling
|
||||||
|
|
||||||
|
-- Build a SpatialAveragePooling layer for the requested backend.
--   backend : "cunn" -> nn.SpatialAveragePooling
--             "cudnn" -> cudnn.SpatialAveragePooling
-- Remaining arguments are passed to the constructor unchanged.
-- Raises an error for any other backend name.
local function SpatialAveragePooling(backend, kW, kH, dW, dH, padW, padH)
   if backend == "cunn" then
      return nn.SpatialAveragePooling(kW, kH, dW, dH, padW, padH)
   end
   if backend == "cudnn" then
      return cudnn.SpatialAveragePooling(kW, kH, dW, dH, padW, padH)
   end
   error("unsupported backend:" .. backend)
end
|
||||||
|
srcnn.SpatialAveragePooling = SpatialAveragePooling
|
||||||
|
|
||||||
|
-- Build a SpatialDilatedConvolution layer for the requested backend.
--   backend : "cunn"  -> nn.SpatialDilatedConvolution
--             "cudnn" -> cudnn.SpatialDilatedConvolution when the loaded cudnn
--                        binding provides it (cudnn v6), otherwise the nn
--                        implementation is used as a fallback
-- Remaining arguments are passed to the constructor unchanged.
-- Raises an error for any other backend name.
local function SpatialDilatedConvolution(backend, nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, dilationW, dilationH)
   if backend ~= "cunn" and backend ~= "cudnn" then
      error("unsupported backend:" .. backend)
   end

   local lib = nn
   if backend == "cudnn" and cudnn.SpatialDilatedConvolution then
      lib = cudnn
   end
   return lib.SpatialDilatedConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, dilationW, dilationH)
end
|
||||||
|
srcnn.SpatialDilatedConvolution = SpatialDilatedConvolution
|
||||||
|
|
||||||
|
|
||||||
-- VGG style net(7 layers)
|
-- VGG style net(7 layers)
|
||||||
function srcnn.vgg_7(backend, ch)
|
function srcnn.vgg_7(backend, ch)
|
||||||
|
@ -548,6 +601,7 @@ function srcnn.create(model_name, backend, color)
|
||||||
error("unsupported model_name: " .. model_name)
|
error("unsupported model_name: " .. model_name)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
--[[
|
--[[
|
||||||
local model = srcnn.fcn_v1("cunn", 3):cuda()
|
local model = srcnn.fcn_v1("cunn", 3):cuda()
|
||||||
print(model:forward(torch.Tensor(1, 3, 108, 108):zero():cuda()):size())
|
print(model:forward(torch.Tensor(1, 3, 108, 108):zero():cuda()):size())
|
||||||
|
|
42
lib/w2nn.lua
42
lib/w2nn.lua
|
@ -9,6 +9,40 @@ end
|
||||||
local function load_cudnn()
|
local function load_cudnn()
|
||||||
cudnn = require('cudnn')
|
cudnn = require('cudnn')
|
||||||
end
|
end
|
||||||
|
-- Wrap `model` in an nn.DataParallelTable (constructed with arguments
-- 1, true, true) that runs on the devices listed in `gpus`, with one init
-- function executed in each worker thread. The init function extends
-- package.path with the sibling "lib" directory and loads torch, cunn and
-- w2nn; when cudnn is loaded in the calling thread it also loads cudnn and
-- copies the caller's cudnn.fastest / cudnn.benchmark flags.
-- Returns the DataParallelTable after :cuda(), with its gradInput cleared.
local function make_data_parallel_table(model, gpus)
   local init_thread
   if cudnn then
      -- captured as upvalues so each worker sees the caller's settings
      local fastest, benchmark = cudnn.fastest, cudnn.benchmark
      init_thread = function()
         require 'pl'
         local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
         package.path = path.join(path.dirname(__FILE__), "..", "lib", "?.lua;") .. package.path
         require 'torch'
         require 'cunn'
         require 'w2nn'
         local cudnn = require 'cudnn'
         cudnn.fastest, cudnn.benchmark = fastest, benchmark
      end
   else
      init_thread = function()
         require 'pl'
         local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
         package.path = path.join(path.dirname(__FILE__), "..", "lib", "?.lua;") .. package.path
         require 'torch'
         require 'cunn'
         require 'w2nn'
      end
   end

   local dpt = nn.DataParallelTable(1, true, true)
      :add(model, gpus)
      :threads(init_thread)
   dpt.gradInput = nil
   return dpt:cuda()
end
|
||||||
|
|
||||||
if w2nn then
|
if w2nn then
|
||||||
return w2nn
|
return w2nn
|
||||||
else
|
else
|
||||||
|
@ -27,11 +61,19 @@ else
|
||||||
model:cuda():evaluate()
|
model:cuda():evaluate()
|
||||||
return model
|
return model
|
||||||
end
|
end
|
||||||
|
-- Return `model` unchanged when `gpus` lists at most one device; otherwise
-- return it wrapped by make_data_parallel_table for all listed devices.
function w2nn.data_parallel(model, gpus)
   if not (#gpus > 1) then
      return model
   end
   return make_data_parallel_table(model, gpus)
end
|
||||||
require 'LeakyReLU'
|
require 'LeakyReLU'
|
||||||
require 'ClippedWeightedHuberCriterion'
|
require 'ClippedWeightedHuberCriterion'
|
||||||
require 'ClippedMSECriterion'
|
require 'ClippedMSECriterion'
|
||||||
require 'SSIMCriterion'
|
require 'SSIMCriterion'
|
||||||
require 'InplaceClip01'
|
require 'InplaceClip01'
|
||||||
require 'L1Criterion'
|
require 'L1Criterion'
|
||||||
|
require 'ShakeShakeTable'
|
||||||
return w2nn
|
return w2nn
|
||||||
end
|
end
|
||||||
|
|
|
@ -47,6 +47,7 @@ cmd:option("-y_dir", "", 'groundtruth image for user method. filename must be th
|
||||||
cmd:option("-x_file", "", 'input image for user method')
|
cmd:option("-x_file", "", 'input image for user method')
|
||||||
cmd:option("-y_file", "", 'groundtruth image for user method. filename must be the same as x_file')
|
cmd:option("-y_file", "", 'groundtruth image for user method. filename must be the same as x_file')
|
||||||
cmd:option("-border", 0, 'border px that will removed')
|
cmd:option("-border", 0, 'border px that will removed')
|
||||||
|
cmd:option("-metric", "", '(jaccard)')
|
||||||
|
|
||||||
local function to_bool(settings, name)
|
local function to_bool(settings, name)
|
||||||
if settings[name] == 1 then
|
if settings[name] == 1 then
|
||||||
|
@ -203,8 +204,34 @@ local function remove_border(x, border)
|
||||||
x:size(3) - border,
|
x:size(3) - border,
|
||||||
x:size(2) - border)
|
x:size(2) - border)
|
||||||
end
|
end
|
||||||
|
--- Build the optional accuracy metric selected by name.
-- @param metric  metric name; nil or the empty string means "no metric"
-- @return nil when no metric is requested, otherwise a table
--         `{name = <string>, func = function(a, b) -> number}`
-- Raises an error for any non-empty name other than "jaccard".
local function create_metric(metric)
   if not (metric and metric:len() > 0) then
      return nil
   end
   if metric ~= "jaccard" then
      error("unknown metric: " .. metric)
   end
   return {
      name = "jaccard",
      -- Jaccard index (|A and B| / |A or B|) of the two images after
      -- converting each with iproc.rgb2y and thresholding at 0.5.
      func = function (a, b)
         local ga = iproc.rgb2y(a)
         local gb = iproc.rgb2y(b)
         -- Binary masks: 1.0 where the converted value exceeds 0.5, else 0.
         local ba = torch.Tensor():resizeAs(ga):zero()
         local bb = torch.Tensor():resizeAs(gb):zero()
         ba[torch.gt(ga, 0.5)] = 1.0
         bb[torch.gt(gb, 0.5)] = 1.0
         -- Count each mask before the in-place product below overwrites ba.
         local num_a = ba:sum()
         local num_b = bb:sum()
         local a_and_b = ba:cmul(bb):sum()
         local a_or_b = num_a + num_b - a_and_b
         if a_or_b == 0 then
            -- Both masks are empty: 0/0 would yield NaN and poison the
            -- caller's running sum. Two empty sets are identical, so the
            -- index is defined as 1.
            return 1.0
         end
         return (a_and_b / a_or_b)
      end}
end
|
||||||
local function benchmark(opt, x, model1, model2)
|
local function benchmark(opt, x, model1, model2)
|
||||||
local mse1, mse2
|
local mse1, mse2, am1, am2
|
||||||
local won = {0, 0}
|
local won = {0, 0}
|
||||||
local model1_mse = 0
|
local model1_mse = 0
|
||||||
local model2_mse = 0
|
local model2_mse = 0
|
||||||
|
@ -217,6 +244,13 @@ local function benchmark(opt, x, model1, model2)
|
||||||
local scale_f = reconstruct.scale
|
local scale_f = reconstruct.scale
|
||||||
local image_f = reconstruct.image
|
local image_f = reconstruct.image
|
||||||
local detail_fp = nil
|
local detail_fp = nil
|
||||||
|
local am = nil
|
||||||
|
local model1_am = 0
|
||||||
|
local model2_am = 0
|
||||||
|
|
||||||
|
if opt.method == "user" or opt.method == "diff" then
|
||||||
|
am = create_metric(opt.metric)
|
||||||
|
end
|
||||||
if opt.save_info then
|
if opt.save_info then
|
||||||
detail_fp = io.open(path.join(opt.output_dir, "benchmark_details.txt"), "w")
|
detail_fp = io.open(path.join(opt.output_dir, "benchmark_details.txt"), "w")
|
||||||
end
|
end
|
||||||
|
@ -406,32 +440,57 @@ local function benchmark(opt, x, model1, model2)
|
||||||
ground_truth = remove_border(ground_truth, opt.border)
|
ground_truth = remove_border(ground_truth, opt.border)
|
||||||
model1_output = remove_border(model1_output, opt.border)
|
model1_output = remove_border(model1_output, opt.border)
|
||||||
end
|
end
|
||||||
mse1 = MSE(ground_truth, model1_output, opt.color)
|
if am then
|
||||||
model1_mse = model1_mse + mse1
|
am1 = am.func(ground_truth, model1_output)
|
||||||
model1_psnr = model1_psnr + MSE2PSNR(mse1)
|
model1_am = model1_am + am1
|
||||||
|
else
|
||||||
|
mse1 = MSE(ground_truth, model1_output, opt.color)
|
||||||
|
model1_mse = model1_mse + mse1
|
||||||
|
model1_psnr = model1_psnr + MSE2PSNR(mse1)
|
||||||
|
end
|
||||||
local won_model = 1
|
local won_model = 1
|
||||||
if model2 then
|
if model2 then
|
||||||
if opt.border > 0 then
|
if opt.border > 0 then
|
||||||
model2_output = remove_border(model2_output, opt.border)
|
model2_output = remove_border(model2_output, opt.border)
|
||||||
end
|
end
|
||||||
mse2 = MSE(ground_truth, model2_output, opt.color)
|
if am then
|
||||||
model2_mse = model2_mse + mse2
|
am2 = am.func(ground_truth, model2_output)
|
||||||
model2_psnr = model2_psnr + MSE2PSNR(mse2)
|
model2_am = model2_am + am2
|
||||||
|
else
|
||||||
if mse1 < mse2 then
|
mse2 = MSE(ground_truth, model2_output, opt.color)
|
||||||
won[1] = won[1] + 1
|
model2_mse = model2_mse + mse2
|
||||||
elseif mse1 > mse2 then
|
model2_psnr = model2_psnr + MSE2PSNR(mse2)
|
||||||
won[2] = won[2] + 1
|
end
|
||||||
won_model = 2
|
if am then
|
||||||
|
if am1 < am2 then
|
||||||
|
won[1] = won[1] + 1
|
||||||
|
elseif am1 > am2 then
|
||||||
|
won[2] = won[2] + 1
|
||||||
|
won_model = 2
|
||||||
|
end
|
||||||
|
else
|
||||||
|
if mse1 < mse2 then
|
||||||
|
won[1] = won[1] + 1
|
||||||
|
elseif mse1 > mse2 then
|
||||||
|
won[2] = won[2] + 1
|
||||||
|
won_model = 2
|
||||||
|
end
|
||||||
end
|
end
|
||||||
if detail_fp then
|
if detail_fp then
|
||||||
detail_fp:write(string.format("%s,%f,%f,%d\n", x[i].basename,
|
if am then
|
||||||
MSE2PSNR(mse1), MSE2PSNR(mse2), won_model))
|
detail_fp:write(string.format("%s,%f,%d\n", x[i].basename, am1, am2, won_model))
|
||||||
|
else
|
||||||
|
detail_fp:write(string.format("%s,%f,%f,%d\n", x[i].basename,
|
||||||
|
MSE2PSNR(mse1), MSE2PSNR(mse2), won_model))
|
||||||
|
end
|
||||||
end
|
end
|
||||||
else
|
else
|
||||||
if detail_fp then
|
if detail_fp then
|
||||||
detail_fp:write(string.format("%s,%f\n", x[i].basename, MSE2PSNR(mse1)))
|
if am then
|
||||||
|
detail_fp:write(string.format("%s,%f\n", x[i].basename, am1))
|
||||||
|
else
|
||||||
|
detail_fp:write(string.format("%s,%f\n", x[i].basename, MSE2PSNR(mse1)))
|
||||||
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
if baseline_output then
|
if baseline_output then
|
||||||
|
@ -455,46 +514,65 @@ local function benchmark(opt, x, model1, model2)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
if opt.show_progress or i == #x then
|
if opt.show_progress or i == #x then
|
||||||
if model2 then
|
if am then
|
||||||
if baseline_output then
|
if model2 then
|
||||||
io.stdout:write(
|
io.stdout:write(
|
||||||
string.format("%d/%d; model1_time=%.2f, model2_time=%.2f, baseline_rmse=%.3f, model1_rmse=%.3f, model2_rmse=%.3f, baseline_psnr=%.3f, model1_psnr=%.3f, model2_psnr=%.3f, model1_won=%d, model2_won=%d \r",
|
string.format("%d/%d; model1_time=%.2f, model2_time=%.2f, model1_%s=%.3f, model2_%s=%.3f \r",
|
||||||
i, #x,
|
i, #x,
|
||||||
model1_time,
|
model1_time,
|
||||||
model2_time,
|
model2_time,
|
||||||
math.sqrt(baseline_mse / i),
|
am.name, model1_am / i, am.name, model2_am / i
|
||||||
math.sqrt(model1_mse / i), math.sqrt(model2_mse / i),
|
))
|
||||||
baseline_psnr / i,
|
|
||||||
model1_psnr / i, model2_psnr / i,
|
|
||||||
won[1], won[2]
|
|
||||||
))
|
|
||||||
else
|
else
|
||||||
io.stdout:write(
|
io.stdout:write(
|
||||||
string.format("%d/%d; model1_time=%.2f, model2_time=%.2f, model1_rmse=%.3f, model2_rmse=%.3f, model1_psnr=%.3f, model2_psnr=%.3f, model1_own=%d, model2_won=%d \r",
|
string.format("%d/%d; model1_time=%.2f, model1_%s=%.3f \r",
|
||||||
i, #x,
|
i, #x,
|
||||||
model1_time,
|
model1_time,
|
||||||
model2_time,
|
am.name, model1_am / i
|
||||||
math.sqrt(model1_mse / i), math.sqrt(model2_mse / i),
|
))
|
||||||
model1_psnr / i, model2_psnr / i,
|
|
||||||
won[1], won[2]
|
|
||||||
))
|
|
||||||
end
|
end
|
||||||
else
|
else
|
||||||
if baseline_output then
|
if model2 then
|
||||||
io.stdout:write(
|
if baseline_output then
|
||||||
string.format("%d/%d; model1_time=%.2f, baseline_rmse=%.3f, model1_rmse=%.3f, baseline_psnr=%.3f, model1_psnr=%.3f \r",
|
io.stdout:write(
|
||||||
i, #x,
|
string.format("%d/%d; model1_time=%.2f, model2_time=%.2f, baseline_rmse=%.3f, model1_rmse=%.3f, model2_rmse=%.3f, baseline_psnr=%.3f, model1_psnr=%.3f, model2_psnr=%.3f, model1_won=%d, model2_won=%d \r",
|
||||||
model1_time,
|
i, #x,
|
||||||
math.sqrt(baseline_mse / i), math.sqrt(model1_mse / i),
|
model1_time,
|
||||||
baseline_psnr / i, model1_psnr / i
|
model2_time,
|
||||||
|
math.sqrt(baseline_mse / i),
|
||||||
|
math.sqrt(model1_mse / i), math.sqrt(model2_mse / i),
|
||||||
|
baseline_psnr / i,
|
||||||
|
model1_psnr / i, model2_psnr / i,
|
||||||
|
won[1], won[2]
|
||||||
))
|
))
|
||||||
|
else
|
||||||
|
io.stdout:write(
|
||||||
|
string.format("%d/%d; model1_time=%.2f, model2_time=%.2f, model1_rmse=%.3f, model2_rmse=%.3f, model1_psnr=%.3f, model2_psnr=%.3f, model1_own=%d, model2_won=%d \r",
|
||||||
|
i, #x,
|
||||||
|
model1_time,
|
||||||
|
model2_time,
|
||||||
|
math.sqrt(model1_mse / i), math.sqrt(model2_mse / i),
|
||||||
|
model1_psnr / i, model2_psnr / i,
|
||||||
|
won[1], won[2]
|
||||||
|
))
|
||||||
|
end
|
||||||
else
|
else
|
||||||
io.stdout:write(
|
if baseline_output then
|
||||||
string.format("%d/%d; model1_time=%.2f, model1_rmse=%.3f, model1_psnr=%.3f \r",
|
io.stdout:write(
|
||||||
i, #x,
|
string.format("%d/%d; model1_time=%.2f, baseline_rmse=%.3f, model1_rmse=%.3f, baseline_psnr=%.3f, model1_psnr=%.3f \r",
|
||||||
model1_time,
|
i, #x,
|
||||||
math.sqrt(model1_mse / i), model1_psnr / i
|
model1_time,
|
||||||
|
math.sqrt(baseline_mse / i), math.sqrt(model1_mse / i),
|
||||||
|
baseline_psnr / i, model1_psnr / i
|
||||||
))
|
))
|
||||||
|
else
|
||||||
|
io.stdout:write(
|
||||||
|
string.format("%d/%d; model1_time=%.2f, model1_rmse=%.3f, model1_psnr=%.3f \r",
|
||||||
|
i, #x,
|
||||||
|
model1_time,
|
||||||
|
math.sqrt(model1_mse / i), model1_psnr / i
|
||||||
|
))
|
||||||
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
io.stdout:flush()
|
io.stdout:flush()
|
||||||
|
@ -515,6 +593,14 @@ local function benchmark(opt, x, model1, model2)
|
||||||
fp:write(string.format("model2 : RMSE = %.3f, PSNR = %.3f, evaluation time = %.3f\n",
|
fp:write(string.format("model2 : RMSE = %.3f, PSNR = %.3f, evaluation time = %.3f\n",
|
||||||
math.sqrt(model2_mse / #x), model2_psnr / #x, model2_time))
|
math.sqrt(model2_mse / #x), model2_psnr / #x, model2_time))
|
||||||
end
|
end
|
||||||
|
if model1_am > 0 then
|
||||||
|
fp:write(string.format("model1 : %s = %.3f, evaluation time = %.3f\n",
|
||||||
|
math.sqrt(model1_am / #x), model1_time))
|
||||||
|
end
|
||||||
|
if model2_am > 0 then
|
||||||
|
fp:write(string.format("model2 : %s = %.3f, evaluation time = %.3f\n",
|
||||||
|
math.sqrt(model2_am / #x), model2_time))
|
||||||
|
end
|
||||||
fp:close()
|
fp:close()
|
||||||
if detail_fp then
|
if detail_fp then
|
||||||
detail_fp:close()
|
detail_fp:close()
|
||||||
|
|
124
train.lua
124
train.lua
|
@ -29,17 +29,57 @@ local function save_test_user(model, rgb, file)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
local function split_data(x, test_size)
|
local function split_data(x, test_size)
|
||||||
local index = torch.randperm(#x)
|
if settings.validation_filename_split then
|
||||||
local train_size = #x - test_size
|
if not (x[1][2].data and x[1][2].data.basename) then
|
||||||
local train_x = {}
|
error("`images.t` does not have basename info. You need to re-run `convert_data.lua`.")
|
||||||
local valid_x = {}
|
end
|
||||||
for i = 1, train_size do
|
local basename_db = {}
|
||||||
train_x[i] = x[index[i]]
|
for i = 1, #x do
|
||||||
|
local meta = x[i][2].data
|
||||||
|
if basename_db[meta.basename] then
|
||||||
|
table.insert(basename_db[meta.basename], x[i])
|
||||||
|
else
|
||||||
|
basename_db[meta.basename] = {x[i]}
|
||||||
|
end
|
||||||
|
end
|
||||||
|
local basename_list = {}
|
||||||
|
for k, v in pairs(basename_db) do
|
||||||
|
table.insert(basename_list, v)
|
||||||
|
end
|
||||||
|
local index = torch.randperm(#basename_list)
|
||||||
|
local train_x = {}
|
||||||
|
local valid_x = {}
|
||||||
|
local pos = 1
|
||||||
|
for i = 1, #basename_list do
|
||||||
|
if #valid_x >= test_size then
|
||||||
|
break
|
||||||
|
end
|
||||||
|
local xs = basename_list[index[pos]]
|
||||||
|
for j = 1, #xs do
|
||||||
|
table.insert(valid_x, xs[j])
|
||||||
|
end
|
||||||
|
pos = pos + 1
|
||||||
|
end
|
||||||
|
for i = pos, #basename_list do
|
||||||
|
local xs = basename_list[index[i]]
|
||||||
|
for j = 1, #xs do
|
||||||
|
table.insert(train_x, xs[j])
|
||||||
|
end
|
||||||
|
end
|
||||||
|
return train_x, valid_x
|
||||||
|
else
|
||||||
|
local index = torch.randperm(#x)
|
||||||
|
local train_size = #x - test_size
|
||||||
|
local train_x = {}
|
||||||
|
local valid_x = {}
|
||||||
|
for i = 1, train_size do
|
||||||
|
train_x[i] = x[index[i]]
|
||||||
|
end
|
||||||
|
for i = 1, test_size do
|
||||||
|
valid_x[i] = x[index[train_size + i]]
|
||||||
|
end
|
||||||
|
return train_x, valid_x
|
||||||
end
|
end
|
||||||
for i = 1, test_size do
|
|
||||||
valid_x[i] = x[index[train_size + i]]
|
|
||||||
end
|
|
||||||
return train_x, valid_x
|
|
||||||
end
|
end
|
||||||
|
|
||||||
local g_transform_pool = nil
|
local g_transform_pool = nil
|
||||||
|
@ -175,35 +215,19 @@ local function transform_pool_init(has_resize, offset)
|
||||||
settings.crop_size, offset,
|
settings.crop_size, offset,
|
||||||
n, conf)
|
n, conf)
|
||||||
elseif settings.method == "user" then
|
elseif settings.method == "user" then
|
||||||
if is_validation == nil then is_validation = false end
|
|
||||||
local rotate_rate = nil
|
|
||||||
local scale_rate = nil
|
|
||||||
local negate_rate = nil
|
|
||||||
local negate_x_rate = nil
|
|
||||||
if is_validation then
|
|
||||||
rotate_rate = 0
|
|
||||||
scale_rate = 0
|
|
||||||
negate_rate = 0
|
|
||||||
negate_x_rate = 0
|
|
||||||
else
|
|
||||||
rotate_rate = settings.random_pairwise_rotate_rate
|
|
||||||
scale_rate = settings.random_pairwise_scale_rate
|
|
||||||
negate_rate = settings.random_pairwise_negate_rate
|
|
||||||
negate_x_rate = settings.random_pairwise_negate_x_rate
|
|
||||||
end
|
|
||||||
local conf = tablex.update({
|
local conf = tablex.update({
|
||||||
gcn = settings.gcn,
|
gcn = settings.gcn,
|
||||||
max_size = settings.max_size,
|
max_size = settings.max_size,
|
||||||
active_cropping_rate = active_cropping_rate,
|
active_cropping_rate = active_cropping_rate,
|
||||||
active_cropping_tries = active_cropping_tries,
|
active_cropping_tries = active_cropping_tries,
|
||||||
random_pairwise_rotate_rate = rotate_rate,
|
random_pairwise_rotate_rate = settings.random_pairwise_rotate_rate,
|
||||||
random_pairwise_rotate_min = settings.random_pairwise_rotate_min,
|
random_pairwise_rotate_min = settings.random_pairwise_rotate_min,
|
||||||
random_pairwise_rotate_max = settings.random_pairwise_rotate_max,
|
random_pairwise_rotate_max = settings.random_pairwise_rotate_max,
|
||||||
random_pairwise_scale_rate = scale_rate,
|
random_pairwise_scale_rate = settings.random_pairwise_scale_rate,
|
||||||
random_pairwise_scale_min = settings.random_pairwise_scale_min,
|
random_pairwise_scale_min = settings.random_pairwise_scale_min,
|
||||||
random_pairwise_scale_max = settings.random_pairwise_scale_max,
|
random_pairwise_scale_max = settings.random_pairwise_scale_max,
|
||||||
random_pairwise_negate_rate = negate_rate,
|
random_pairwise_negate_rate = settings.random_pairwise_negate_rate,
|
||||||
random_pairwise_negate_x_rate = negate_x_rate,
|
random_pairwise_negate_x_rate = settings.random_pairwise_negate_x_rate,
|
||||||
pairwise_y_binary = settings.pairwise_y_binary,
|
pairwise_y_binary = settings.pairwise_y_binary,
|
||||||
pairwise_flip = settings.pairwise_flip,
|
pairwise_flip = settings.pairwise_flip,
|
||||||
rgb = (settings.color == "rgb")}, meta)
|
rgb = (settings.color == "rgb")}, meta)
|
||||||
|
@ -290,7 +314,7 @@ local function validate(model, criterion, eval_metric, data, batch_size)
|
||||||
local batch_mse = eval_metric:forward(z, targets)
|
local batch_mse = eval_metric:forward(z, targets)
|
||||||
loss = loss + criterion:forward(z, targets)
|
loss = loss + criterion:forward(z, targets)
|
||||||
mse = mse + batch_mse
|
mse = mse + batch_mse
|
||||||
psnr = psnr + (10 * math.log10(1 / batch_mse))
|
psnr = psnr + (10 * math.log10(1 / (batch_mse + 1.0e-6)))
|
||||||
loss_count = loss_count + 1
|
loss_count = loss_count + 1
|
||||||
if loss_count % 10 == 0 then
|
if loss_count % 10 == 0 then
|
||||||
xlua.progress(t, #data)
|
xlua.progress(t, #data)
|
||||||
|
@ -322,6 +346,10 @@ local function create_criterion(model)
|
||||||
return w2nn.L1Criterion():cuda()
|
return w2nn.L1Criterion():cuda()
|
||||||
elseif settings.loss == "mse" then
|
elseif settings.loss == "mse" then
|
||||||
return w2nn.ClippedMSECriterion(0, 1.0):cuda()
|
return w2nn.ClippedMSECriterion(0, 1.0):cuda()
|
||||||
|
elseif settings.loss == "bce" then
|
||||||
|
local bce = nn.BCECriterion()
|
||||||
|
bce.sizeAverage = true
|
||||||
|
return bce:cuda()
|
||||||
else
|
else
|
||||||
error("unsupported loss .." .. settings.loss)
|
error("unsupported loss .." .. settings.loss)
|
||||||
end
|
end
|
||||||
|
@ -421,7 +449,10 @@ local function plot(train, valid)
|
||||||
{'validation', torch.Tensor(valid), '-'}})
|
{'validation', torch.Tensor(valid), '-'}})
|
||||||
end
|
end
|
||||||
local function train()
|
local function train()
|
||||||
local x = remove_small_image(torch.load(settings.images))
|
local x = torch.load(settings.images)
|
||||||
|
if settings.method ~= "user" then
|
||||||
|
x = remove_small_image(x)
|
||||||
|
end
|
||||||
local train_x, valid_x = split_data(x, math.max(math.floor(settings.validation_rate * #x), 1))
|
local train_x, valid_x = split_data(x, math.max(math.floor(settings.validation_rate * #x), 1))
|
||||||
local hist_train = {}
|
local hist_train = {}
|
||||||
local hist_valid = {}
|
local hist_valid = {}
|
||||||
|
@ -429,7 +460,12 @@ local function train()
|
||||||
if settings.resume:len() > 0 then
|
if settings.resume:len() > 0 then
|
||||||
model = torch.load(settings.resume, "ascii")
|
model = torch.load(settings.resume, "ascii")
|
||||||
else
|
else
|
||||||
model = srcnn.create(settings.model, settings.backend, settings.color)
|
if stringx.endswith(settings.model, ".lua") then
|
||||||
|
local create_model = dofile(settings.model)
|
||||||
|
model = create_model(srcnn, settings)
|
||||||
|
else
|
||||||
|
model = srcnn.create(settings.model, settings.backend, settings.color)
|
||||||
|
end
|
||||||
end
|
end
|
||||||
if model.w2nn_input_size then
|
if model.w2nn_input_size then
|
||||||
if settings.crop_size ~= model.w2nn_input_size then
|
if settings.crop_size ~= model.w2nn_input_size then
|
||||||
|
@ -484,8 +520,9 @@ local function train()
|
||||||
ch, settings.crop_size, settings.crop_size)
|
ch, settings.crop_size, settings.crop_size)
|
||||||
end
|
end
|
||||||
local instance_loss = nil
|
local instance_loss = nil
|
||||||
|
local pmodel = w2nn.data_parallel(model, settings.gpu)
|
||||||
for epoch = 1, settings.epoch do
|
for epoch = 1, settings.epoch do
|
||||||
model:training()
|
pmodel:training()
|
||||||
print("# " .. epoch)
|
print("# " .. epoch)
|
||||||
if adam_config.learningRate then
|
if adam_config.learningRate then
|
||||||
print("learning rate: " .. adam_config.learningRate)
|
print("learning rate: " .. adam_config.learningRate)
|
||||||
|
@ -523,13 +560,13 @@ local function train()
|
||||||
instance_loss = torch.Tensor(x:size(1)):zero()
|
instance_loss = torch.Tensor(x:size(1)):zero()
|
||||||
|
|
||||||
for i = 1, settings.inner_epoch do
|
for i = 1, settings.inner_epoch do
|
||||||
model:training()
|
pmodel:training()
|
||||||
local train_score, il = minibatch_adam(model, criterion, eval_metric, x, y, adam_config)
|
local train_score, il = minibatch_adam(pmodel, criterion, eval_metric, x, y, adam_config)
|
||||||
instance_loss:copy(il)
|
instance_loss:copy(il)
|
||||||
print(train_score)
|
print(train_score)
|
||||||
model:evaluate()
|
pmodel:evaluate()
|
||||||
print("# validation")
|
print("# validation")
|
||||||
local score = validate(model, criterion, eval_metric, valid_xy, adam_config.xBatchSize)
|
local score = validate(pmodel, criterion, eval_metric, valid_xy, adam_config.xBatchSize)
|
||||||
table.insert(hist_train, train_score.loss)
|
table.insert(hist_train, train_score.loss)
|
||||||
table.insert(hist_valid, score.loss)
|
table.insert(hist_valid, score.loss)
|
||||||
if settings.plot then
|
if settings.plot then
|
||||||
|
@ -546,8 +583,9 @@ local function train()
|
||||||
best_score = score_for_update
|
best_score = score_for_update
|
||||||
print("* model has updated")
|
print("* model has updated")
|
||||||
if settings.save_history then
|
if settings.save_history then
|
||||||
torch.save(settings.model_file_best, model:clearState(), "ascii")
|
pmodel:clearState()
|
||||||
torch.save(string.format(settings.model_file, epoch, i), model:clearState(), "ascii")
|
torch.save(settings.model_file_best, model, "ascii")
|
||||||
|
torch.save(string.format(settings.model_file, epoch, i), model, "ascii")
|
||||||
if settings.method == "noise" then
|
if settings.method == "noise" then
|
||||||
local log = path.join(settings.model_dir,
|
local log = path.join(settings.model_dir,
|
||||||
("noise%d_best.%d-%d.png"):format(settings.noise_level,
|
("noise%d_best.%d-%d.png"):format(settings.noise_level,
|
||||||
|
@ -571,7 +609,8 @@ local function train()
|
||||||
save_test_user(model, test_image, log)
|
save_test_user(model, test_image, log)
|
||||||
end
|
end
|
||||||
else
|
else
|
||||||
torch.save(settings.model_file, model:clearState(), "ascii")
|
pmodel:clearState()
|
||||||
|
torch.save(settings.model_file, model, "ascii")
|
||||||
if settings.method == "noise" then
|
if settings.method == "noise" then
|
||||||
local log = path.join(settings.model_dir,
|
local log = path.join(settings.model_dir,
|
||||||
("noise%d_best.png"):format(settings.noise_level))
|
("noise%d_best.png"):format(settings.noise_level))
|
||||||
|
@ -597,9 +636,6 @@ local function train()
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
if settings.gpu > 0 then
|
|
||||||
cutorch.setDevice(settings.gpu)
|
|
||||||
end
|
|
||||||
torch.manualSeed(settings.seed)
|
torch.manualSeed(settings.seed)
|
||||||
cutorch.manualSeed(settings.seed)
|
cutorch.manualSeed(settings.seed)
|
||||||
print(settings)
|
print(settings)
|
||||||
|
|
|
@ -276,6 +276,7 @@ local function waifu2x()
|
||||||
if opt.thread > 0 then
|
if opt.thread > 0 then
|
||||||
torch.setnumthreads(opt.thread)
|
torch.setnumthreads(opt.thread)
|
||||||
end
|
end
|
||||||
|
cutorch.setDevice(opt.gpu)
|
||||||
if cudnn then
|
if cudnn then
|
||||||
cudnn.fastest = true
|
cudnn.fastest = true
|
||||||
if opt.l:len() > 0 then
|
if opt.l:len() > 0 then
|
||||||
|
@ -293,6 +294,5 @@ local function waifu2x()
|
||||||
else
|
else
|
||||||
convert_frames(opt)
|
convert_frames(opt)
|
||||||
end
|
end
|
||||||
cutorch.setDevice(opt.gpu)
|
|
||||||
end
|
end
|
||||||
waifu2x()
|
waifu2x()
|
||||||
|
|
Loading…
Reference in a new issue