commit fa6ee00624
Author: nagadomi
Date:   2018-01-26 08:09:40 +09:00

14 changed files with 516 additions and 140 deletions


@@ -63,7 +63,24 @@ local function crop_if_large_pair(x, y, max_size)
       return x, y
    end
 end
+local function padding_x(x, pad)
+   if pad > 0 then
+      x = iproc.padding(x, pad, pad, pad, pad)
+   end
+   return x
+end
+local function padding_xy(x, y, pad, y_zero)
+   local scale = y:size(2) / x:size(2)
+   if pad > 0 then
+      x = iproc.padding(x, pad, pad, pad, pad)
+      if y_zero then
+         y = iproc.zero_padding(y, pad * scale, pad * scale, pad * scale, pad * scale)
+      else
+         y = iproc.padding(y, pad * scale, pad * scale, pad * scale, pad * scale)
+      end
+   end
+   return x, y
+end
 local function load_images(list)
    local MARGIN = 32
    local csv = csvigo.load({path = list, verbose = false, mode = "raw"})
@@ -78,6 +95,7 @@ local function load_images(list)
       if csv_meta and csv_meta.filters then
          filters = csv_meta.filters
       end
+      local basename_y = path.basename(filename)
       local im, meta = image_loader.load_byte(filename)
       local skip = false
       local alpha_color = torch.random(0, 1)
@@ -100,25 +118,38 @@ local function load_images(list)
         -- method == user
         local yy = im
         local xx, meta2 = image_loader.load_byte(csv_meta.x)
+        if settings.invert_x then
+           xx = (-(xx:long()) + 255):byte()
+        end
         if xx then
            if meta2 and meta2.alpha then
               xx = alpha_util.fill(xx, meta2.alpha, alpha_color)
            end
            xx, yy = crop_if_large_pair(xx, yy, settings.max_training_image_size)
+           xx, yy = padding_xy(xx, yy, settings.padding, settings.padding_y_zero)
+           if settings.grayscale then
+              xx = iproc.rgb2y(xx)
+              yy = iproc.rgb2y(yy)
+           end
            table.insert(x, {{y = compression.compress(yy), x = compression.compress(xx)},
-                            {data = {filters = filters, has_x = true}}})
+                            {data = {filters = filters, has_x = true, basename = basename_y}}})
         else
            io.stderr:write(string.format("\n%s: skip: load error.\n", csv_meta.x))
         end
      else
         im = crop_if_large(im, settings.max_training_image_size)
         im = iproc.crop_mod4(im)
+        im = padding_x(im, settings.padding)
         local scale = 1.0
         if settings.random_half_rate > 0.0 then
            scale = 2.0
         end
         if im:size(2) > (settings.crop_size * scale + MARGIN) and im:size(3) > (settings.crop_size * scale + MARGIN) then
-           table.insert(x, {compression.compress(im), {data = {filters = filters}}})
+           if settings.grayscale then
+              im = iproc.rgb2y(im)
+           end
+           table.insert(x, {compression.compress(im), {data = {filters = filters, basename = basename_y}}})
         else
            io.stderr:write(string.format("\n%s: skip: image is too small (%d > size).\n", filename, settings.crop_size * scale + MARGIN))
         end

lib/ShakeShakeTable.lua (new file)

@@ -0,0 +1,42 @@
+local ShakeShakeTable, parent = torch.class('w2nn.ShakeShakeTable','nn.Module')
+function ShakeShakeTable:__init()
+   parent.__init(self)
+   self.alpha = torch.Tensor()
+   self.beta = torch.Tensor()
+   self.first = torch.Tensor()
+   self.second = torch.Tensor()
+   self.train = true
+end
+function ShakeShakeTable:updateOutput(input)
+   local batch_size = input[1]:size(1)
+   if self.train then
+      self.alpha:resize(batch_size):uniform()
+      self.beta:resize(batch_size):uniform()
+      self.second:resizeAs(input[1]):copy(input[2])
+      for i = 1, batch_size do
+         self.second[i]:mul(self.alpha[i])
+      end
+      self.output:resizeAs(input[1]):copy(input[1])
+      for i = 1, batch_size do
+         self.output[i]:mul(1.0 - self.alpha[i])
+      end
+      self.output:add(self.second):mul(2)
+   else
+      self.output:resizeAs(input[1]):copy(input[1]):add(input[2])
+   end
+   return self.output
+end
+function ShakeShakeTable:updateGradInput(input, gradOutput)
+   local batch_size = input[1]:size(1)
+   self.first:resizeAs(gradOutput):copy(gradOutput)
+   for i = 1, batch_size do
+      self.first[i]:mul(self.beta[i])
+   end
+   self.second:resizeAs(gradOutput):copy(gradOutput)
+   for i = 1, batch_size do
+      self.second[i]:mul(1.0 - self.beta[i])
+   end
+   self.gradOutput = {self.first, self.second}
+   return self.gradOutput
+end
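Note: w2nn.ShakeShakeTable takes a table of two equally sized branch outputs; during training it blends them with a per-sample random weight alpha (and an independent beta in the backward pass, i.e. shake-shake regularization), while at evaluation time it simply sums the two branches. A minimal sketch of how the module could be wired up, assuming the waifu2x lib directory is on package.path and the usual w2nn dependencies are installed; the two convolution branches and their sizes below are illustrative only, not part of this commit:

   -- Sketch only: blend two example branches with w2nn.ShakeShakeTable.
   require 'nn'
   require 'w2nn' -- registers w2nn.ShakeShakeTable (pulls in the usual waifu2x deps)
   local branch_a = nn.SpatialConvolution(64, 64, 3, 3, 1, 1, 1, 1)
   local branch_b = nn.SpatialConvolution(64, 64, 3, 3, 1, 1, 1, 1)
   local block = nn.Sequential()
   block:add(nn.ConcatTable():add(branch_a):add(branch_b)) -- forward produces {out_a, out_b}
   block:add(w2nn.ShakeShakeTable())                       -- random convex blend (train) / plain sum (eval)
   local out = block:forward(torch.rand(4, 64, 8, 8))      -- batch of 4, same shape as each branch output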


@@ -102,7 +102,9 @@ function data_augmentation.pairwise_scale(x, y, p, scale_min, scale_max)
       local scale = torch.uniform(scale_min, scale_max)
       local h = math.floor(x:size(2) * scale)
       local w = math.floor(x:size(3) * scale)
-      x = iproc.scale(x, w, h, "Triangle")
+      local filters = {"Lanczos", "Catrom"}
+      local x_filter = filters[torch.random(1, 2)]
+      x = iproc.scale(x, w, h, x_filter)
       y = iproc.scale(y, w, h, "Triangle")
       return x, y
    else
@@ -139,6 +141,36 @@ function data_augmentation.pairwise_negate_x(x, y, p)
       return x, y
    end
 end
+function data_augmentation.pairwise_flip(x, y)
+   local flip = torch.random(1, 4)
+   local tr = torch.random(1, 2)
+   local x, conversion = iproc.byte2float(x)
+   y = iproc.byte2float(y)
+   x = x:contiguous()
+   y = y:contiguous()
+   if tr == 1 then
+      -- pass
+   elseif tr == 2 then
+      x = x:transpose(2, 3):contiguous()
+      y = y:transpose(2, 3):contiguous()
+   end
+   if flip == 1 then
+      x = iproc.hflip(x)
+      y = iproc.hflip(y)
+   elseif flip == 2 then
+      x = iproc.vflip(x)
+      y = iproc.vflip(y)
+   elseif flip == 3 then
+      x = iproc.hflip(iproc.vflip(x))
+      y = iproc.hflip(iproc.vflip(y))
+   elseif flip == 4 then
+   end
+   if conversion then
+      x = iproc.float2byte(x)
+      y = iproc.float2byte(y)
+   end
+   return x, y
+end
 function data_augmentation.shift_1px(src)
    -- reducing the even/odd issue in nearest neighbor scaler.
    local direction = torch.random(1, 4)


@@ -80,6 +80,8 @@ function iproc.scale_with_gamma22(src, width, height, filter, blur)
    return dest
 end
 function iproc.padding(img, w1, w2, h1, h2)
+   local conversion
+   img, conversion = iproc.byte2float(img)
    image = image or require 'image'
    local dst_height = img:size(2) + h1 + h2
    local dst_width = img:size(3) + w1 + w2
@@ -88,9 +90,15 @@ function iproc.padding(img, w1, w2, h1, h2)
    flow[2] = torch.ger(torch.ones(dst_height), torch.linspace(0, dst_width - 1, dst_width))
    flow[1]:add(-h1)
    flow[2]:add(-w1)
-   return image.warp(img, flow, "simple", false, "clamp")
+   local dest = image.warp(img, flow, "simple", false, "clamp")
+   if conversion then
+      dest = iproc.float2byte(dest)
+   end
+   return dest
 end
 function iproc.zero_padding(img, w1, w2, h1, h2)
+   local conversion
+   img, conversion = iproc.byte2float(img)
    image = image or require 'image'
    local dst_height = img:size(2) + h1 + h2
    local dst_width = img:size(3) + w1 + w2
@@ -99,7 +107,11 @@ function iproc.zero_padding(img, w1, w2, h1, h2)
    flow[2] = torch.ger(torch.ones(dst_height), torch.linspace(0, dst_width - 1, dst_width))
    flow[1]:add(-h1)
    flow[2]:add(-w1)
-   return image.warp(img, flow, "simple", false, "pad", 0)
+   local dest = image.warp(img, flow, "simple", false, "pad", 0)
+   if conversion then
+      dest = iproc.float2byte(dest)
+   end
+   return dest
 end
 function iproc.white_noise(src, std, rgb_weights, gamma)
    gamma = gamma or 0.454545
@@ -217,6 +229,7 @@ function iproc.rgb2y(src)
    src, conversion = iproc.byte2float(src)
    local dest = torch.FloatTensor(1, src:size(2), src:size(3)):zero()
    dest:add(0.299, src[1]):add(0.587, src[2]):add(0.114, src[3])
+   dest:clamp(0, 1)
    if conversion then
       dest = iproc.float2byte(dest)
    end


@@ -43,8 +43,10 @@ function pairwise_transform.jpeg_(src, quality, size, offset, n, options)
       yc = iproc.byte2float(yc)
       if options.rgb then
       else
-         yc = iproc.rgb2y(yc)
-         xc = iproc.rgb2y(xc)
+         if xc:size(1) > 1 then
+            yc = iproc.rgb2y(yc)
+            xc = iproc.rgb2y(xc)
+         end
       end
       if torch.uniform() < options.nr_rate then
         -- reducing noise


@@ -51,8 +51,10 @@ function pairwise_transform.scale(src, scale, size, offset, n, options)
       yc = iproc.byte2float(yc)
       if options.rgb then
       else
-         yc = iproc.rgb2y(yc)
-         xc = iproc.rgb2y(xc)
+         if xc:size(1) > 1 then
+            yc = iproc.rgb2y(yc)
+            xc = iproc.rgb2y(xc)
+         end
       end
       table.insert(batch, {xc, iproc.crop(yc, offset, offset, size - offset, size - offset)})
    end


@@ -1,4 +1,5 @@
 local pairwise_utils = require 'pairwise_transform_utils'
+local data_augmentation = require 'data_augmentation'
 local iproc = require 'iproc'
 local gm = {}
 gm.Image = require 'graphicsmagick.Image'
@@ -21,12 +22,15 @@ function pairwise_transform.user(x, y, size, offset, n, options)
    if options.active_cropping_rate > 0 then
       lowres_y = pairwise_utils.low_resolution(y)
    end
-   if options.pairwise_flip then
+   if options.pairwise_flip and n == 1 then
+      xs[1], ys[1] = data_augmentation.pairwise_flip(xs[1], ys[1])
+   elseif options.pairwise_flip then
       xs, ys, ls = pairwise_utils.flip_augmentation(x, y, lowres_y)
    end
    assert(#xs == #ys)
+   local perm = torch.randperm(#xs)
    for i = 1, n do
-      local t = (i % #xs) + 1
+      local t = perm[(i % #xs) + 1]
       local xc, yc = pairwise_utils.active_cropping(xs[t], ys[t], ls[t], size, scale_y,
                                                     options.active_cropping_rate,
                                                     options.active_cropping_tries)
@@ -34,8 +38,10 @@ function pairwise_transform.user(x, y, size, offset, n, options)
      yc = iproc.byte2float(yc)
      if options.rgb then
      else
-        yc = iproc.rgb2y(yc)
-        xc = iproc.rgb2y(xc)
+        if xc:size(1) > 1 then
+           yc = iproc.rgb2y(yc)
+           xc = iproc.rgb2y(xc)
+        end
      end
      if options.gcn then
         local mean = xc:mean()
@@ -46,7 +52,12 @@ function pairwise_transform.user(x, y, size, offset, n, options)
           xc:add(-mean)
        end
      end
-     table.insert(batch, {xc, iproc.crop(yc, offset, offset, size - offset, size - offset)})
+     yc = iproc.crop(yc, offset, offset, size - offset, size - offset)
+     if options.pairwise_y_binary then
+        yc[torch.lt(yc, 0.5)] = 0
+        yc[torch.gt(yc, 0)] = 1
+     end
+     table.insert(batch, {xc, yc})
   end
   return batch


@@ -108,12 +108,6 @@ function pairwise_transform_utils.preprocess_user(x, y, scale_y, size, options)
    x = iproc.crop_mod4(x)
    y = iproc.crop_mod4(y)
-   if options.pairwise_y_binary then
-      y[torch.lt(y, 128)] = 0
-      y[torch.gt(y, 0)] = 255
-   end
    return x, y
 end
 function pairwise_transform_utils.active_cropping(x, y, lowres_y, size, scale, p, tries)
@@ -125,8 +119,14 @@ function pairwise_transform_utils.active_cropping(x, y, lowres_y, size, scale, p
       t = "byte"
    end
    if p < r then
-      local xi = torch.random(1, x:size(3) - (size + 1)) * scale
-      local yi = torch.random(1, x:size(2) - (size + 1)) * scale
+      local xi = 0
+      local yi = 0
+      if x:size(3) > size + 1 then
+         xi = torch.random(0, x:size(3) - (size + 1)) * scale
+      end
+      if x:size(2) > size + 1 then
+         yi = torch.random(0, x:size(2) - (size + 1)) * scale
+      end
       local yc = iproc.crop(y, xi, yi, xi + size, yi + size)
       local xc = iproc.crop(x, xi / scale, yi / scale, xi / scale + size / scale, yi / scale + size / scale)
       return xc, yc
@@ -273,10 +273,17 @@ function pairwise_transform_utils.low_resolution(src)
         toTensor("byte", "RGB", "DHW")
    end
    --]]
-   return gm.Image(src, "RGB", "DHW"):
-      size(src:size(3) * 0.5, src:size(2) * 0.5, "Box"):
-      size(src:size(3), src:size(2), "Box"):
-      toTensor("byte", "RGB", "DHW")
+   if src:size(1) == 1 then
+      return gm.Image(src, "I", "DHW"):
+         size(src:size(3) * 0.5, src:size(2) * 0.5, "Box"):
+         size(src:size(3), src:size(2), "Box"):
+         toTensor("byte", "I", "DHW")
+   else
+      return gm.Image(src, "RGB", "DHW"):
+         size(src:size(3) * 0.5, src:size(2) * 0.5, "Box"):
+         size(src:size(3), src:size(2), "Box"):
+         toTensor("byte", "RGB", "DHW")
+   end
 end
 return pairwise_transform_utils


@@ -18,7 +18,6 @@ local cmd = torch.CmdLine()
 cmd:text()
 cmd:text("waifu2x-training")
 cmd:text("Options:")
-cmd:option("-gpu", -1, 'GPU Device ID')
 cmd:option("-seed", 11, 'RNG seed (note: it only able to reproduce the training results with `-thread 1`)')
 cmd:option("-data_dir", "./data", 'path to data directory')
 cmd:option("-backend", "cunn", '(cunn|cudnn)')
@@ -74,9 +73,14 @@ cmd:option("-oracle_drop_rate", 0.5, '')
 cmd:option("-learning_rate_decay", 3.0e-7, 'learning rate decay (learning_rate * 1/(1+num_of_data*patches*epoch))')
 cmd:option("-resume", "", 'resume model file')
 cmd:option("-name", "user", 'model name for user method')
-cmd:option("-gpu", 1, 'Device ID')
-cmd:option("-loss", "huber", 'loss function (huber|l1|mse)')
+cmd:option("-gpu", "", 'GPU Device ID or ID lists (comma seprated)')
+cmd:option("-loss", "huber", 'loss function (huber|l1|mse|bce)')
 cmd:option("-update_criterion", "mse", 'mse|loss')
+cmd:option("-padding", 0, 'replication padding size')
+cmd:option("-padding_y_zero", 0, 'zero padding y for segmentation (0|1)')
+cmd:option("-grayscale", 0, 'grayscale x&y (0|1)')
+cmd:option("-validation_filename_split", 0, 'make validation-set based on filename(basename)')
+cmd:option("-invert_x", 0, 'invert x image in convert_lua')
 local function to_bool(settings, name)
    if settings[name] == 1 then
@@ -95,6 +99,10 @@ to_bool(settings, "save_history")
 to_bool(settings, "use_transparent_png")
 to_bool(settings, "pairwise_y_binary")
 to_bool(settings, "pairwise_flip")
+to_bool(settings, "padding_y_zero")
+to_bool(settings, "grayscale")
+to_bool(settings, "validation_filename_split")
+to_bool(settings, "invert_x")
 if settings.plot then
    require 'gnuplot'
@@ -168,10 +176,20 @@ end
 settings.images = string.format("%s/images.t7", settings.data_dir)
 settings.image_list = string.format("%s/image_list.txt", settings.data_dir)
-cutorch.setDevice(opt.gpu)
 -- patch for lua52
 if not math.log10 then
    math.log10 = function(x) return math.log(x, 10) end
 end
+if settings.gpu:len() > 0 then
+   local gpus = {}
+   local gpu_string = utils.split(settings.gpu, ",")
+   for i = 1, #gpu_string do
+      table.insert(gpus, tonumber(gpu_string[i]))
+   end
+   settings.gpu = gpus
+else
+   settings.gpu = {1}
+end
+cutorch.setDevice(settings.gpu[1])
 return settings


@@ -4,34 +4,52 @@ require 'w2nn'
 -- ref: http://arxiv.org/abs/1501.00092
 local srcnn = {}
-function nn.SpatialConvolutionMM:reset(stdv)
-   local fin = self.kW * self.kH * self.nInputPlane
-   local fout = self.kW * self.kH * self.nOutputPlane
+local function msra_filler(mod)
+   local fin = mod.kW * mod.kH * mod.nInputPlane
+   local fout = mod.kW * mod.kH * mod.nOutputPlane
    stdv = math.sqrt(4 / ((1.0 + 0.1 * 0.1) * (fin + fout)))
-   self.weight:normal(0, stdv)
-   self.bias:zero()
+   mod.weight:normal(0, stdv)
+   mod.bias:zero()
+end
+local function identity_filler(mod)
+   assert(mod.nInputPlane <= mod.nOutputPlane)
+   mod.weight:normal(0, 0.01)
+   mod.bias:zero()
+   local num_groups = mod.nInputPlane -- fixed
+   local filler_value = num_groups / mod.nOutputPlane
+   local in_group_size = math.floor(mod.nInputPlane / num_groups)
+   local out_group_size = math.floor(mod.nOutputPlane / num_groups)
+   local x = math.floor(mod.kW / 2)
+   local y = math.floor(mod.kH / 2)
+   for i = 0, num_groups - 1 do
+      for j = i * out_group_size, (i + 1) * out_group_size - 1 do
+         for k = i * in_group_size, (i + 1) * in_group_size - 1 do
+            mod.weight[j+1][k+1][y+1][x+1] = filler_value
+         end
+      end
+   end
+end
+function nn.SpatialConvolutionMM:reset(stdv)
+   msra_filler(self)
 end
 function nn.SpatialFullConvolution:reset(stdv)
-   local fin = self.kW * self.kH * self.nInputPlane
-   local fout = self.kW * self.kH * self.nOutputPlane
-   stdv = math.sqrt(4 / ((1.0 + 0.1 * 0.1) * (fin + fout)))
-   self.weight:normal(0, stdv)
-   self.bias:zero()
+   msra_filler(self)
 end
+function nn.SpatialDilatedConvolution:reset(stdv)
+   identity_filler(self)
+end
 if cudnn and cudnn.SpatialConvolution then
    function cudnn.SpatialConvolution:reset(stdv)
-      local fin = self.kW * self.kH * self.nInputPlane
-      local fout = self.kW * self.kH * self.nOutputPlane
-      stdv = math.sqrt(4 / ((1.0 + 0.1 * 0.1) * (fin + fout)))
-      self.weight:normal(0, stdv)
-      self.bias:zero()
+      msra_filler(self)
    end
    function cudnn.SpatialFullConvolution:reset(stdv)
-      local fin = self.kW * self.kH * self.nInputPlane
-      local fout = self.kW * self.kH * self.nOutputPlane
-      stdv = math.sqrt(4 / ((1.0 + 0.1 * 0.1) * (fin + fout)))
-      self.weight:normal(0, stdv)
-      self.bias:zero()
+      msra_filler(self)
+   end
+   if cudnn.SpatialDilatedConvolution then
+      function cudnn.SpatialDilatedConvolution:reset(stdv)
+         identity_filler(self)
+      end
    end
 end
 function nn.SpatialConvolutionMM:clearState()
@@ -127,6 +145,8 @@ local function SpatialConvolution(backend, nInputPlane, nOutputPlane, kW, kH, dW
       error("unsupported backend:" .. backend)
    end
 end
+srcnn.SpatialConvolution = SpatialConvolution
 local function SpatialFullConvolution(backend, nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, adjW, adjH)
    if backend == "cunn" then
       return nn.SpatialFullConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, adjW, adjH)
@@ -136,6 +156,8 @@ local function SpatialFullConvolution(backend, nInputPlane, nOutputPlane, kW, kH
       error("unsupported backend:" .. backend)
    end
 end
+srcnn.SpatialFullConvolution = SpatialFullConvolution
 local function ReLU(backend)
    if backend == "cunn" then
       return nn.ReLU(true)
@@ -145,6 +167,8 @@ local function ReLU(backend)
      error("unsupported backend:" .. backend)
    end
 end
+srcnn.ReLU = ReLU
 local function SpatialMaxPooling(backend, kW, kH, dW, dH, padW, padH)
    if backend == "cunn" then
       return nn.SpatialMaxPooling(kW, kH, dW, dH, padW, padH)
@@ -154,6 +178,35 @@ local function SpatialMaxPooling(backend, kW, kH, dW, dH, padW, padH)
      error("unsupported backend:" .. backend)
    end
 end
+srcnn.SpatialMaxPooling = SpatialMaxPooling
+local function SpatialAveragePooling(backend, kW, kH, dW, dH, padW, padH)
+   if backend == "cunn" then
+      return nn.SpatialAveragePooling(kW, kH, dW, dH, padW, padH)
+   elseif backend == "cudnn" then
+      return cudnn.SpatialAveragePooling(kW, kH, dW, dH, padW, padH)
+   else
+      error("unsupported backend:" .. backend)
+   end
+end
+srcnn.SpatialAveragePooling = SpatialAveragePooling
+local function SpatialDilatedConvolution(backend, nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, dilationW, dilationH)
+   if backend == "cunn" then
+      return nn.SpatialDilatedConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, dilationW, dilationH)
+   elseif backend == "cudnn" then
+      if cudnn.SpatialDilatedConvolution then
+         -- cudnn v 6
+         return cudnn.SpatialDilatedConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, dilationW, dilationH)
+      else
+         return nn.SpatialDilatedConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, dilationW, dilationH)
+      end
+   else
+      error("unsupported backend:" .. backend)
+   end
+end
+srcnn.SpatialDilatedConvolution = SpatialDilatedConvolution
 -- VGG style net(7 layers)
 function srcnn.vgg_7(backend, ch)
@@ -548,6 +601,7 @@ function srcnn.create(model_name, backend, color)
      error("unsupported model_name: " .. model_name)
    end
 end
 --[[
 local model = srcnn.fcn_v1("cunn", 3):cuda()
 print(model:forward(torch.Tensor(1, 3, 108, 108):zero():cuda()):size())


@@ -9,6 +9,40 @@
 local function load_cudnn()
    cudnn = require('cudnn')
 end
+local function make_data_parallel_table(model, gpus)
+   if cudnn then
+      local fastest, benchmark = cudnn.fastest, cudnn.benchmark
+      local dpt = nn.DataParallelTable(1, true, true)
+         :add(model, gpus)
+         :threads(function()
+               require 'pl'
+               local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
+               package.path = path.join(path.dirname(__FILE__), "..", "lib", "?.lua;") .. package.path
+               require 'torch'
+               require 'cunn'
+               require 'w2nn'
+               local cudnn = require 'cudnn'
+               cudnn.fastest, cudnn.benchmark = fastest, benchmark
+         end)
+      dpt.gradInput = nil
+      model = dpt:cuda()
+   else
+      local dpt = nn.DataParallelTable(1, true, true)
+         :add(model, gpus)
+         :threads(function()
+               require 'pl'
+               local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
+               package.path = path.join(path.dirname(__FILE__), "..", "lib", "?.lua;") .. package.path
+               require 'torch'
+               require 'cunn'
+               require 'w2nn'
+         end)
+      dpt.gradInput = nil
+      model = dpt:cuda()
+   end
+   return model
+end
 if w2nn then
    return w2nn
 else
@@ -27,11 +61,19 @@ else
      model:cuda():evaluate()
      return model
   end
+  function w2nn.data_parallel(model, gpus)
+     if #gpus > 1 then
+        return make_data_parallel_table(model, gpus)
+     else
+        return model
+     end
+  end
   require 'LeakyReLU'
   require 'ClippedWeightedHuberCriterion'
   require 'ClippedMSECriterion'
   require 'SSIMCriterion'
   require 'InplaceClip01'
   require 'L1Criterion'
+  require 'ShakeShakeTable'
   return w2nn
 end
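Note: w2nn.data_parallel wraps the model in an nn.DataParallelTable only when more than one device ID is passed; the ID list comes from the new comma-separated -gpu option parsed in settings.lua (e.g. -gpu 1,2). A rough usage sketch, assuming a CUDA-enabled Torch install with the waifu2x lib directory on package.path; the model choice and device IDs below are illustrative, not prescribed by this commit:

   -- Sketch only: wrap a model for multi-GPU training with the new helper.
   require 'cunn'
   local w2nn = require 'w2nn'
   local srcnn = require 'srcnn'
   local model = srcnn.create("vgg_7", "cunn", "rgb"):cuda()
   local pmodel = w2nn.data_parallel(model, {1, 2}) -- DataParallelTable over GPUs 1 and 2
   -- with a single-entry list ({1}) the model is returned unchanged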


@@ -47,6 +47,7 @@ cmd:option("-y_dir", "", 'groundtruth image for user method. filename must be th
 cmd:option("-x_file", "", 'input image for user method')
 cmd:option("-y_file", "", 'groundtruth image for user method. filename must be the same as x_file')
 cmd:option("-border", 0, 'border px that will removed')
+cmd:option("-metric", "", '(jaccard)')
 local function to_bool(settings, name)
    if settings[name] == 1 then
@@ -203,8 +204,34 @@ local function remove_border(x, border)
                      x:size(3) - border,
                      x:size(2) - border)
 end
+local function create_metric(metric)
+   if metric and metric:len() > 0 then
+      if metric == "jaccard" then
+         return {
+            name = "jaccard",
+            func = function (a, b)
+               local ga = iproc.rgb2y(a)
+               local gb = iproc.rgb2y(b)
+               local ba = torch.Tensor():resizeAs(ga)
+               local bb = torch.Tensor():resizeAs(gb)
+               ba:zero()
+               bb:zero()
+               ba[torch.gt(ga, 0.5)] = 1.0
+               bb[torch.gt(gb, 0.5)] = 1.0
+               local num_a = ba:sum()
+               local num_b = bb:sum()
+               local a_and_b = ba:cmul(bb):sum()
+               return (a_and_b / (num_a + num_b - a_and_b))
+            end}
+      else
+         error("unknown metric: " .. metric)
+      end
+   else
+      return nil
+   end
+end
 local function benchmark(opt, x, model1, model2)
-   local mse1, mse2
+   local mse1, mse2, am1, am2
    local won = {0, 0}
    local model1_mse = 0
    local model2_mse = 0
@@ -217,6 +244,13 @@ local function benchmark(opt, x, model1, model2)
    local scale_f = reconstruct.scale
    local image_f = reconstruct.image
    local detail_fp = nil
+   local am = nil
+   local model1_am = 0
+   local model2_am = 0
+   if opt.method == "user" or opt.method == "diff" then
+      am = create_metric(opt.metric)
+   end
    if opt.save_info then
       detail_fp = io.open(path.join(opt.output_dir, "benchmark_details.txt"), "w")
    end
@@ -406,32 +440,57 @@ local function benchmark(opt, x, model1, model2)
         ground_truth = remove_border(ground_truth, opt.border)
         model1_output = remove_border(model1_output, opt.border)
      end
-     mse1 = MSE(ground_truth, model1_output, opt.color)
-     model1_mse = model1_mse + mse1
-     model1_psnr = model1_psnr + MSE2PSNR(mse1)
+     if am then
+        am1 = am.func(ground_truth, model1_output)
+        model1_am = model1_am + am1
+     else
+        mse1 = MSE(ground_truth, model1_output, opt.color)
+        model1_mse = model1_mse + mse1
+        model1_psnr = model1_psnr + MSE2PSNR(mse1)
+     end
      local won_model = 1
      if model2 then
         if opt.border > 0 then
           model2_output = remove_border(model2_output, opt.border)
        end
-       mse2 = MSE(ground_truth, model2_output, opt.color)
-       model2_mse = model2_mse + mse2
-       model2_psnr = model2_psnr + MSE2PSNR(mse2)
-       if mse1 < mse2 then
-          won[1] = won[1] + 1
-       elseif mse1 > mse2 then
-          won[2] = won[2] + 1
-          won_model = 2
-       end
+       if am then
+          am2 = am.func(ground_truth, model2_output)
+          model2_am = model2_am + am2
+       else
+          mse2 = MSE(ground_truth, model2_output, opt.color)
+          model2_mse = model2_mse + mse2
+          model2_psnr = model2_psnr + MSE2PSNR(mse2)
+       end
+       if am then
+          if am1 < am2 then
+             won[1] = won[1] + 1
+          elseif am1 > am2 then
+             won[2] = won[2] + 1
+             won_model = 2
+          end
+       else
+          if mse1 < mse2 then
+             won[1] = won[1] + 1
+          elseif mse1 > mse2 then
+             won[2] = won[2] + 1
+             won_model = 2
+          end
+       end
        if detail_fp then
-          detail_fp:write(string.format("%s,%f,%f,%d\n", x[i].basename,
-                                        MSE2PSNR(mse1), MSE2PSNR(mse2), won_model))
+          if am then
+             detail_fp:write(string.format("%s,%f,%d\n", x[i].basename, am1, am2, won_model))
+          else
+             detail_fp:write(string.format("%s,%f,%f,%d\n", x[i].basename,
+                                           MSE2PSNR(mse1), MSE2PSNR(mse2), won_model))
+          end
        end
     else
        if detail_fp then
-          detail_fp:write(string.format("%s,%f\n", x[i].basename, MSE2PSNR(mse1)))
+          if am then
+             detail_fp:write(string.format("%s,%f\n", x[i].basename, am1))
+          else
+             detail_fp:write(string.format("%s,%f\n", x[i].basename, MSE2PSNR(mse1)))
+          end
        end
     end
     if baseline_output then
@@ -455,46 +514,65 @@ local function benchmark(opt, x, model1, model2)
        end
     end
     if opt.show_progress or i == #x then
-       if model2 then
-          if baseline_output then
-             io.stdout:write(
-                string.format("%d/%d; model1_time=%.2f, model2_time=%.2f, baseline_rmse=%.3f, model1_rmse=%.3f, model2_rmse=%.3f, baseline_psnr=%.3f, model1_psnr=%.3f, model2_psnr=%.3f, model1_won=%d, model2_won=%d \r",
-                   i, #x,
-                   model1_time,
-                   model2_time,
-                   math.sqrt(baseline_mse / i),
-                   math.sqrt(model1_mse / i), math.sqrt(model2_mse / i),
-                   baseline_psnr / i,
-                   model1_psnr / i, model2_psnr / i,
-                   won[1], won[2]
-             ))
-          else
-             io.stdout:write(
-                string.format("%d/%d; model1_time=%.2f, model2_time=%.2f, model1_rmse=%.3f, model2_rmse=%.3f, model1_psnr=%.3f, model2_psnr=%.3f, model1_own=%d, model2_won=%d \r",
-                   i, #x,
-                   model1_time,
-                   model2_time,
-                   math.sqrt(model1_mse / i), math.sqrt(model2_mse / i),
-                   model1_psnr / i, model2_psnr / i,
-                   won[1], won[2]
-             ))
-          end
-       else
-          if baseline_output then
-             io.stdout:write(
-                string.format("%d/%d; model1_time=%.2f, baseline_rmse=%.3f, model1_rmse=%.3f, baseline_psnr=%.3f, model1_psnr=%.3f \r",
-                   i, #x,
-                   model1_time,
-                   math.sqrt(baseline_mse / i), math.sqrt(model1_mse / i),
-                   baseline_psnr / i, model1_psnr / i
-             ))
-          else
-             io.stdout:write(
-                string.format("%d/%d; model1_time=%.2f, model1_rmse=%.3f, model1_psnr=%.3f \r",
-                   i, #x,
-                   model1_time,
-                   math.sqrt(model1_mse / i), model1_psnr / i
-             ))
-          end
-       end
+       if am then
+          if model2 then
+             io.stdout:write(
+                string.format("%d/%d; model1_time=%.2f, model2_time=%.2f, model1_%s=%.3f, model2_%s=%.3f \r",
+                   i, #x,
+                   model1_time,
+                   model2_time,
+                   am.name, model1_am / i, am.name, model2_am / i
+             ))
+          else
+             io.stdout:write(
+                string.format("%d/%d; model1_time=%.2f, model1_%s=%.3f \r",
+                   i, #x,
+                   model1_time,
+                   am.name, model1_am / i
+             ))
+          end
+       else
+          if model2 then
+             if baseline_output then
+                io.stdout:write(
+                   string.format("%d/%d; model1_time=%.2f, model2_time=%.2f, baseline_rmse=%.3f, model1_rmse=%.3f, model2_rmse=%.3f, baseline_psnr=%.3f, model1_psnr=%.3f, model2_psnr=%.3f, model1_won=%d, model2_won=%d \r",
                      i, #x,
                      model1_time,
                      model2_time,
                      math.sqrt(baseline_mse / i),
                      math.sqrt(model1_mse / i), math.sqrt(model2_mse / i),
                      baseline_psnr / i,
                      model1_psnr / i, model2_psnr / i,
                      won[1], won[2]
+                ))
+             else
+                io.stdout:write(
+                   string.format("%d/%d; model1_time=%.2f, model2_time=%.2f, model1_rmse=%.3f, model2_rmse=%.3f, model1_psnr=%.3f, model2_psnr=%.3f, model1_own=%d, model2_won=%d \r",
+                      i, #x,
+                      model1_time,
+                      model2_time,
+                      math.sqrt(model1_mse / i), math.sqrt(model2_mse / i),
+                      model1_psnr / i, model2_psnr / i,
+                      won[1], won[2]
+                ))
+             end
+          else
+             if baseline_output then
+                io.stdout:write(
+                   string.format("%d/%d; model1_time=%.2f, baseline_rmse=%.3f, model1_rmse=%.3f, baseline_psnr=%.3f, model1_psnr=%.3f \r",
                      i, #x,
                      model1_time,
                      math.sqrt(baseline_mse / i), math.sqrt(model1_mse / i),
                      baseline_psnr / i, model1_psnr / i
+                ))
+             else
+                io.stdout:write(
+                   string.format("%d/%d; model1_time=%.2f, model1_rmse=%.3f, model1_psnr=%.3f \r",
+                      i, #x,
+                      model1_time,
+                      math.sqrt(model1_mse / i), model1_psnr / i
+                ))
+             end
+          end
        end
     end
     io.stdout:flush()
@@ -515,6 +593,14 @@
    fp:write(string.format("model2 : RMSE = %.3f, PSNR = %.3f, evaluation time = %.3f\n",
                           math.sqrt(model2_mse / #x), model2_psnr / #x, model2_time))
 end
+if model1_am > 0 then
+   fp:write(string.format("model1 : %s = %.3f, evaluation time = %.3f\n",
+                          math.sqrt(model1_am / #x), model1_time))
+end
+if model2_am > 0 then
+   fp:write(string.format("model2 : %s = %.3f, evaluation time = %.3f\n",
+                          math.sqrt(model2_am / #x), model2_time))
+end
 fp:close()
 if detail_fp then
    detail_fp:close()
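Note: the jaccard metric added above converts both images to luminance, binarizes them at 0.5, and reports the intersection-over-union |A∩B| / |A∪B| instead of RMSE/PSNR. The same arithmetic on a pair of tiny made-up masks, for illustration only:

   -- Illustration only: Jaccard index of two 4-element binary masks.
   require 'torch'
   local a = torch.Tensor({1, 1, 0, 0})    -- 2 foreground pixels
   local b = torch.Tensor({1, 0, 1, 0})    -- 2 foreground pixels, 1 overlaps a
   local inter = torch.cmul(a, b):sum()    -- |A and B| = 1
   local union = a:sum() + b:sum() - inter -- |A or B| = 3
   print(inter / union)                    -- 0.3333...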

train.lua

@@ -29,17 +29,57 @@ local function save_test_user(model, rgb, file)
    end
 end
 local function split_data(x, test_size)
-   local index = torch.randperm(#x)
-   local train_size = #x - test_size
-   local train_x = {}
-   local valid_x = {}
-   for i = 1, train_size do
-      train_x[i] = x[index[i]]
-   end
-   for i = 1, test_size do
-      valid_x[i] = x[index[train_size + i]]
-   end
-   return train_x, valid_x
+   if settings.validation_filename_split then
+      if not (x[1][2].data and x[1][2].data.basename) then
+         error("`images.t` does not have basename info. You need to re-run `convert_data.lua`.")
+      end
+      local basename_db = {}
+      for i = 1, #x do
+         local meta = x[i][2].data
+         if basename_db[meta.basename] then
+            table.insert(basename_db[meta.basename], x[i])
+         else
+            basename_db[meta.basename] = {x[i]}
+         end
+      end
+      local basename_list = {}
+      for k, v in pairs(basename_db) do
+         table.insert(basename_list, v)
+      end
+      local index = torch.randperm(#basename_list)
+      local train_x = {}
+      local valid_x = {}
+      local pos = 1
+      for i = 1, #basename_list do
+         if #valid_x >= test_size then
+            break
+         end
+         local xs = basename_list[index[pos]]
+         for j = 1, #xs do
+            table.insert(valid_x, xs[j])
+         end
+         pos = pos + 1
+      end
+      for i = pos, #basename_list do
+         local xs = basename_list[index[i]]
+         for j = 1, #xs do
+            table.insert(train_x, xs[j])
+         end
+      end
+      return train_x, valid_x
+   else
+      local index = torch.randperm(#x)
+      local train_size = #x - test_size
+      local train_x = {}
+      local valid_x = {}
+      for i = 1, train_size do
+         train_x[i] = x[index[i]]
+      end
+      for i = 1, test_size do
+         valid_x[i] = x[index[train_size + i]]
+      end
+      return train_x, valid_x
+   end
 end
 local g_transform_pool = nil
@@ -175,35 +215,19 @@ local function transform_pool_init(has_resize, offset)
                                     settings.crop_size, offset,
                                     n, conf)
       elseif settings.method == "user" then
-         if is_validation == nil then is_validation = false end
-         local rotate_rate = nil
-         local scale_rate = nil
-         local negate_rate = nil
-         local negate_x_rate = nil
-         if is_validation then
-            rotate_rate = 0
-            scale_rate = 0
-            negate_rate = 0
-            negate_x_rate = 0
-         else
-            rotate_rate = settings.random_pairwise_rotate_rate
-            scale_rate = settings.random_pairwise_scale_rate
-            negate_rate = settings.random_pairwise_negate_rate
-            negate_x_rate = settings.random_pairwise_negate_x_rate
-         end
         local conf = tablex.update({
               gcn = settings.gcn,
               max_size = settings.max_size,
               active_cropping_rate = active_cropping_rate,
               active_cropping_tries = active_cropping_tries,
-              random_pairwise_rotate_rate = rotate_rate,
+              random_pairwise_rotate_rate = settings.random_pairwise_rotate_rate,
              random_pairwise_rotate_min = settings.random_pairwise_rotate_min,
              random_pairwise_rotate_max = settings.random_pairwise_rotate_max,
-              random_pairwise_scale_rate = scale_rate,
+              random_pairwise_scale_rate = settings.random_pairwise_scale_rate,
              random_pairwise_scale_min = settings.random_pairwise_scale_min,
              random_pairwise_scale_max = settings.random_pairwise_scale_max,
-              random_pairwise_negate_rate = negate_rate,
-              random_pairwise_negate_x_rate = negate_x_rate,
+              random_pairwise_negate_rate = settings.random_pairwise_negate_rate,
+              random_pairwise_negate_x_rate = settings.random_pairwise_negate_x_rate,
              pairwise_y_binary = settings.pairwise_y_binary,
              pairwise_flip = settings.pairwise_flip,
              rgb = (settings.color == "rgb")}, meta)
@@ -290,7 +314,7 @@ local function validate(model, criterion, eval_metric, data, batch_size)
         local batch_mse = eval_metric:forward(z, targets)
         loss = loss + criterion:forward(z, targets)
         mse = mse + batch_mse
-        psnr = psnr + (10 * math.log10(1 / batch_mse))
+        psnr = psnr + (10 * math.log10(1 / (batch_mse + 1.0e-6)))
         loss_count = loss_count + 1
         if loss_count % 10 == 0 then
            xlua.progress(t, #data)
@@ -322,6 +346,10 @@ local function create_criterion(model)
      return w2nn.L1Criterion():cuda()
   elseif settings.loss == "mse" then
      return w2nn.ClippedMSECriterion(0, 1.0):cuda()
+  elseif settings.loss == "bce" then
+     local bce = nn.BCECriterion()
+     bce.sizeAverage = true
+     return bce:cuda()
   else
      error("unsupported loss .." .. settings.loss)
   end
@@ -421,7 +449,10 @@ local function plot(train, valid)
               {'validation', torch.Tensor(valid), '-'}})
 end
 local function train()
-   local x = remove_small_image(torch.load(settings.images))
+   local x = torch.load(settings.images)
+   if settings.method ~= "user" then
+      x = remove_small_image(x)
+   end
    local train_x, valid_x = split_data(x, math.max(math.floor(settings.validation_rate * #x), 1))
    local hist_train = {}
    local hist_valid = {}
@@ -429,7 +460,12 @@
    if settings.resume:len() > 0 then
       model = torch.load(settings.resume, "ascii")
    else
-      model = srcnn.create(settings.model, settings.backend, settings.color)
+      if stringx.endswith(settings.model, ".lua") then
+         local create_model = dofile(settings.model)
+         model = create_model(srcnn, settings)
+      else
+         model = srcnn.create(settings.model, settings.backend, settings.color)
+      end
    end
    if model.w2nn_input_size then
       if settings.crop_size ~= model.w2nn_input_size then
@@ -484,8 +520,9 @@
                              ch, settings.crop_size, settings.crop_size)
    end
    local instance_loss = nil
+   local pmodel = w2nn.data_parallel(model, settings.gpu)
    for epoch = 1, settings.epoch do
-      model:training()
+      pmodel:training()
       print("# " .. epoch)
       if adam_config.learningRate then
          print("learning rate: " .. adam_config.learningRate)
@@ -523,13 +560,13 @@
      instance_loss = torch.Tensor(x:size(1)):zero()
      for i = 1, settings.inner_epoch do
-        model:training()
-        local train_score, il = minibatch_adam(model, criterion, eval_metric, x, y, adam_config)
+        pmodel:training()
+        local train_score, il = minibatch_adam(pmodel, criterion, eval_metric, x, y, adam_config)
         instance_loss:copy(il)
         print(train_score)
-        model:evaluate()
+        pmodel:evaluate()
         print("# validation")
-        local score = validate(model, criterion, eval_metric, valid_xy, adam_config.xBatchSize)
+        local score = validate(pmodel, criterion, eval_metric, valid_xy, adam_config.xBatchSize)
         table.insert(hist_train, train_score.loss)
         table.insert(hist_valid, score.loss)
         if settings.plot then
@@ -546,8 +583,9 @@
           best_score = score_for_update
           print("* model has updated")
           if settings.save_history then
-             torch.save(settings.model_file_best, model:clearState(), "ascii")
-             torch.save(string.format(settings.model_file, epoch, i), model:clearState(), "ascii")
+             pmodel:clearState()
+             torch.save(settings.model_file_best, model, "ascii")
+             torch.save(string.format(settings.model_file, epoch, i), model, "ascii")
             if settings.method == "noise" then
                local log = path.join(settings.model_dir,
                                      ("noise%d_best.%d-%d.png"):format(settings.noise_level,
@@ -571,7 +609,8 @@
               save_test_user(model, test_image, log)
            end
          else
-            torch.save(settings.model_file, model:clearState(), "ascii")
+            pmodel:clearState()
+            torch.save(settings.model_file, model, "ascii")
            if settings.method == "noise" then
               local log = path.join(settings.model_dir,
                                     ("noise%d_best.png"):format(settings.noise_level))
@@ -597,9 +636,6 @@
      end
   end
 end
-if settings.gpu > 0 then
-   cutorch.setDevice(settings.gpu)
-end
 torch.manualSeed(settings.seed)
 cutorch.manualSeed(settings.seed)
 print(settings)


@@ -276,6 +276,7 @@ local function waifu2x()
    if opt.thread > 0 then
       torch.setnumthreads(opt.thread)
    end
+   cutorch.setDevice(opt.gpu)
    if cudnn then
       cudnn.fastest = true
       if opt.l:len() > 0 then
@@ -293,6 +294,5 @@
    else
       convert_frames(opt)
    end
-   cutorch.setDevice(opt.gpu)
 end
 waifu2x()