refactor
This commit is contained in:
parent
cb81d21064
commit
3abc5a03e3
14
.gitignore
vendored
14
.gitignore
vendored
|
@ -1,8 +1,14 @@
|
|||
*~
|
||||
/*.png
|
||||
/*.mp4
|
||||
/*.jpg
|
||||
work/
|
||||
cache/*.png
|
||||
models/*.png
|
||||
data/
|
||||
!data/.gitkeep
|
||||
|
||||
models/
|
||||
!models/anime_style_art
|
||||
!models/anime_style_art_rgb
|
||||
!models/ukbench
|
||||
models/*/*.png
|
||||
|
||||
waifu2x.log
|
||||
|
||||
|
|
280
benchmark.lua
280
benchmark.lua
|
@ -1,280 +0,0 @@
|
|||
require './lib/portable'
|
||||
require './lib/mynn'
|
||||
require 'xlua'
|
||||
require 'pl'
|
||||
|
||||
local iproc = require './lib/iproc'
|
||||
local reconstruct = require './lib/reconstruct'
|
||||
local image_loader = require './lib/image_loader'
|
||||
local gm = require 'graphicsmagick'
|
||||
|
||||
local cmd = torch.CmdLine()
|
||||
cmd:text()
|
||||
cmd:text("waifu2x-benchmark")
|
||||
cmd:text("Options:")
|
||||
|
||||
cmd:option("-seed", 11, 'fixed input seed')
|
||||
cmd:option("-test_dir", "./test", 'test image directory')
|
||||
cmd:option("-jpeg_quality", 50, 'jpeg quality')
|
||||
cmd:option("-jpeg_times", 3, 'number of jpeg compression ')
|
||||
cmd:option("-jpeg_quality_down", 5, 'reducing jpeg quality each times')
|
||||
cmd:option("-core", 4, 'threads')
|
||||
|
||||
local opt = cmd:parse(arg)
|
||||
torch.setnumthreads(opt.core)
|
||||
torch.setdefaulttensortype('torch.FloatTensor')
|
||||
|
||||
local function MSE(x1, x2)
|
||||
return (x1 - x2):pow(2):mean()
|
||||
end
|
||||
local function YMSE(x1, x2)
|
||||
local x1_2 = x1:clone()
|
||||
local x2_2 = x2:clone()
|
||||
|
||||
x1_2[1]:mul(0.299 * 3)
|
||||
x1_2[2]:mul(0.587 * 3)
|
||||
x1_2[3]:mul(0.114 * 3)
|
||||
|
||||
x2_2[1]:mul(0.299 * 3)
|
||||
x2_2[2]:mul(0.587 * 3)
|
||||
x2_2[3]:mul(0.114 * 3)
|
||||
|
||||
return (x1_2 - x2_2):pow(2):mean()
|
||||
end
|
||||
local function PSNR(x1, x2)
|
||||
local mse = MSE(x1, x2)
|
||||
return 20 * (math.log(1.0 / math.sqrt(mse)) / math.log(10))
|
||||
end
|
||||
local function YPSNR(x1, x2)
|
||||
local mse = YMSE(x1, x2)
|
||||
return 20 * (math.log((0.587 * 3) / math.sqrt(mse)) / math.log(10))
|
||||
end
|
||||
|
||||
local function transform_jpeg(x)
|
||||
for i = 1, opt.jpeg_times do
|
||||
jpeg = gm.Image(x, "RGB", "DHW")
|
||||
jpeg:format("jpeg")
|
||||
jpeg:samplingFactors({1.0, 1.0, 1.0})
|
||||
blob, len = jpeg:toBlob(opt.jpeg_quality - (i - 1) * opt.jpeg_quality_down)
|
||||
jpeg:fromBlob(blob, len)
|
||||
x = jpeg:toTensor("byte", "RGB", "DHW")
|
||||
end
|
||||
return x
|
||||
end
|
||||
|
||||
local function noise_benchmark(x, v1_noise, v2_noise)
|
||||
local v1_mse = 0
|
||||
local v2_mse = 0
|
||||
local jpeg_mse = 0
|
||||
local v1_psnr = 0
|
||||
local v2_psnr = 0
|
||||
local jpeg_psnr = 0
|
||||
local v1_time = 0
|
||||
local v2_time = 0
|
||||
|
||||
for i = 1, #x do
|
||||
local ground_truth = x[i]
|
||||
local jpg, blob, len, input, v1_out, v2_out, t, mse
|
||||
|
||||
input = transform_jpeg(ground_truth)
|
||||
input = input:float():div(255)
|
||||
ground_truth = ground_truth:float():div(255)
|
||||
|
||||
jpeg_mse = jpeg_mse + MSE(ground_truth, input)
|
||||
jpeg_psnr = jpeg_psnr + PSNR(ground_truth, input)
|
||||
|
||||
t = sys.clock()
|
||||
v1_output = reconstruct.image(v1_noise, input)
|
||||
v1_time = v1_time + (sys.clock() - t)
|
||||
v1_mse = v1_mse + MSE(ground_truth, v1_output)
|
||||
v1_psnr = v1_psnr + PSNR(ground_truth, v1_output)
|
||||
|
||||
t = sys.clock()
|
||||
v2_output = reconstruct.image(v2_noise, input)
|
||||
v2_time = v2_time + (sys.clock() - t)
|
||||
v2_mse = v2_mse + MSE(ground_truth, v2_output)
|
||||
v2_psnr = v2_psnr + PSNR(ground_truth, v2_output)
|
||||
|
||||
io.stdout:write(
|
||||
string.format("%d/%d; v1_time=%f, v2_time=%f, jpeg_mse=%f, v1_mse=%f, v2_mse=%f, jpeg_psnr=%f, v1_psnr=%f, v2_psnr=%f \r",
|
||||
i, #x,
|
||||
v1_time / i, v2_time / i,
|
||||
jpeg_mse / i,
|
||||
v1_mse / i, v2_mse / i,
|
||||
jpeg_psnr / i,
|
||||
v1_psnr / i, v2_psnr / i
|
||||
)
|
||||
)
|
||||
io.stdout:flush()
|
||||
end
|
||||
io.stdout:write("\n")
|
||||
end
|
||||
local function noise_scale_benchmark(x, params, v1_noise, v1_scale, v2_noise, v2_scale)
|
||||
local v1_mse = 0
|
||||
local v2_mse = 0
|
||||
local jinc_mse = 0
|
||||
local v1_time = 0
|
||||
local v2_time = 0
|
||||
|
||||
for i = 1, #x do
|
||||
local ground_truth = x[i]
|
||||
local downscale = iproc.scale(ground_truth,
|
||||
ground_truth:size(3) * 0.5,
|
||||
ground_truth:size(2) * 0.5,
|
||||
params[i].filter)
|
||||
local jpg, blob, len, input, v1_output, v2_output, jinc_output, t, mse
|
||||
|
||||
jpeg = gm.Image(downscale, "RGB", "DHW")
|
||||
jpeg:format("jpeg")
|
||||
blob, len = jpeg:toBlob(params[i].quality)
|
||||
jpeg:fromBlob(blob, len)
|
||||
input = jpeg:toTensor("byte", "RGB", "DHW")
|
||||
|
||||
input = input:float():div(255)
|
||||
ground_truth = ground_truth:float():div(255)
|
||||
|
||||
jinc_output = iproc.scale(input, input:size(3) * 2, input:size(2) * 2, "Jinc")
|
||||
jinc_mse = jinc_mse + (ground_truth - jinc_output):pow(2):mean()
|
||||
|
||||
t = sys.clock()
|
||||
v1_output = reconstruct.image(v1_noise, input)
|
||||
v1_output = reconstruct.scale(v1_scale, 2.0, v1_output)
|
||||
v1_time = v1_time + (sys.clock() - t)
|
||||
mse = (ground_truth - v1_output):pow(2):mean()
|
||||
v1_mse = v1_mse + mse
|
||||
|
||||
t = sys.clock()
|
||||
v2_output = reconstruct.image(v2_noise, input)
|
||||
v2_output = reconstruct.scale(v2_scale, 2.0, v2_output)
|
||||
v2_time = v2_time + (sys.clock() - t)
|
||||
mse = (ground_truth - v2_output):pow(2):mean()
|
||||
v2_mse = v2_mse + mse
|
||||
|
||||
io.stdout:write(string.format("%d/%d; time: v1=%f, v2=%f, v1/v2=%f; mse: jinc=%f, v1=%f(%f), v2=%f(%f), v1/v2=%f \r",
|
||||
i, #x,
|
||||
v1_time / i, v2_time / i,
|
||||
(v1_time / i) / (v2_time / i),
|
||||
jinc_mse / i,
|
||||
v1_mse / i, (v1_mse/i) / (jinc_mse/i),
|
||||
v2_mse / i, (v2_mse/i) / (jinc_mse/i),
|
||||
(v1_mse / i) / (v2_mse / i)))
|
||||
|
||||
io.stdout:flush()
|
||||
end
|
||||
io.stdout:write("\n")
|
||||
end
|
||||
local function scale_benchmark(x, params, v1_scale, v2_scale)
|
||||
local v1_mse = 0
|
||||
local v2_mse = 0
|
||||
local jinc_mse = 0
|
||||
local v1_psnr = 0
|
||||
local v2_psnr = 0
|
||||
local jinc_psnr = 0
|
||||
|
||||
local v1_time = 0
|
||||
local v2_time = 0
|
||||
|
||||
for i = 1, #x do
|
||||
local ground_truth = x[i]
|
||||
local downscale = iproc.scale(ground_truth,
|
||||
ground_truth:size(3) * 0.5,
|
||||
ground_truth:size(2) * 0.5,
|
||||
params[i].filter)
|
||||
local jpg, blob, len, input, v1_output, v2_output, jinc_output, t, mse
|
||||
input = downscale
|
||||
|
||||
input = input:float():div(255)
|
||||
ground_truth = ground_truth:float():div(255)
|
||||
|
||||
jinc_output = iproc.scale(input, input:size(3) * 2, input:size(2) * 2, "Jinc")
|
||||
mse = (ground_truth - jinc_output):pow(2):mean()
|
||||
jinc_mse = jinc_mse + mse
|
||||
jinc_psnr = jinc_psnr + (10 * (math.log(1.0 / mse) / math.log(10)))
|
||||
|
||||
t = sys.clock()
|
||||
v1_output = reconstruct.scale(v1_scale, 2.0, input)
|
||||
v1_time = v1_time + (sys.clock() - t)
|
||||
mse = (ground_truth - v1_output):pow(2):mean()
|
||||
v1_mse = v1_mse + mse
|
||||
v1_psnr = v1_psnr + (10 * (math.log(1.0 / mse) / math.log(10)))
|
||||
|
||||
t = sys.clock()
|
||||
v2_output = reconstruct.scale(v2_scale, 2.0, input)
|
||||
v2_time = v2_time + (sys.clock() - t)
|
||||
mse = (ground_truth - v2_output):pow(2):mean()
|
||||
v2_mse = v2_mse + mse
|
||||
v2_psnr = v2_psnr + (10 * (math.log(1.0 / mse) / math.log(10)))
|
||||
|
||||
io.stdout:write(string.format("%d/%d; time: v1=%f, v2=%f, v1/v2=%f; mse: jinc=%f, v1=%f(%f), v2=%f(%f), v1/v2=%f \r",
|
||||
i, #x,
|
||||
v1_time / i, v2_time / i,
|
||||
(v1_time / i) / (v2_time / i),
|
||||
jinc_psnr / i,
|
||||
v1_psnr / i, (v1_psnr/i) / (jinc_psnr/i),
|
||||
v2_psnr / i, (v2_psnr/i) / (jinc_psnr/i),
|
||||
(v1_psnr / i) / (v2_psnr / i)))
|
||||
|
||||
io.stdout:flush()
|
||||
end
|
||||
io.stdout:write("\n")
|
||||
end
|
||||
|
||||
local function split_data(x, test_size)
|
||||
local index = torch.randperm(#x)
|
||||
local train_size = #x - test_size
|
||||
local train_x = {}
|
||||
local valid_x = {}
|
||||
for i = 1, train_size do
|
||||
train_x[i] = x[index[i]]
|
||||
end
|
||||
for i = 1, test_size do
|
||||
valid_x[i] = x[index[train_size + i]]
|
||||
end
|
||||
return train_x, valid_x
|
||||
end
|
||||
local function crop_4x(x)
|
||||
local w = x:size(3) % 4
|
||||
local h = x:size(2) % 4
|
||||
return image.crop(x, 0, 0, x:size(3) - w, x:size(2) - h)
|
||||
end
|
||||
local function load_data(valid_dir)
|
||||
local valid_x = {}
|
||||
local files = dir.getfiles(valid_dir, "*.png")
|
||||
for i = 1, #files do
|
||||
table.insert(valid_x, crop_4x(image_loader.load_byte(files[i])))
|
||||
xlua.progress(i, #files)
|
||||
end
|
||||
return valid_x
|
||||
end
|
||||
|
||||
local function noise_main(valid_dir, level)
|
||||
local v1_noise = torch.load(path.join(V1_DIR, string.format("noise%d_model.t7", level)), "ascii")
|
||||
local v2_noise = torch.load(path.join(V2_DIR, string.format("noise%d_model.t7", level)), "ascii")
|
||||
local valid_x = load_data(valid_dir)
|
||||
noise_benchmark(valid_x, v1_noise, v2_noise)
|
||||
end
|
||||
local function scale_main(valid_dir)
|
||||
local v1 = torch.load(path.join(V1_DIR, "scale2.0x_model.t7"), "ascii")
|
||||
local v2 = torch.load(path.join(V2_DIR, "scale2.0x_model.t7"), "ascii")
|
||||
local valid_x = load_data(valid_dir)
|
||||
local params = random_params(valid_x, 2)
|
||||
scale_benchmark(valid_x, params, v1, v2)
|
||||
end
|
||||
local function noise_scale_main(valid_dir)
|
||||
local v1_noise = torch.load(path.join(V1_DIR, "noise2_model.t7"), "ascii")
|
||||
local v1_scale = torch.load(path.join(V1_DIR, "scale2.0x_model.t7"), "ascii")
|
||||
local v2_noise = torch.load(path.join(V2_DIR, "noise2_model.t7"), "ascii")
|
||||
local v2_scale = torch.load(path.join(V2_DIR, "scale2.0x_model.t7"), "ascii")
|
||||
local valid_x = load_data(valid_dir)
|
||||
local params = random_params(valid_x, 2)
|
||||
noise_scale_benchmark(valid_x, params, v1_noise, v1_scale, v2_noise, v2_scale)
|
||||
end
|
||||
|
||||
V1_DIR = "models/anime_style_art_rgb"
|
||||
V2_DIR = "models/anime_style_art_rgb5"
|
||||
|
||||
torch.manualSeed(opt.seed)
|
||||
cutorch.manualSeed(opt.seed)
|
||||
noise_main("./test", 2)
|
||||
--scale_main("./test")
|
||||
--noise_scale_main("./test")
|
|
@ -1,22 +1,14 @@
|
|||
local ffi = require 'ffi'
|
||||
require './lib/portable'
|
||||
local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
|
||||
package.path = path.join(path.dirname(__FILE__), "lib", "?.lua;") .. package.path
|
||||
|
||||
require 'pl'
|
||||
require 'image'
|
||||
require 'snappy'
|
||||
local settings = require './lib/settings'
|
||||
local image_loader = require './lib/image_loader'
|
||||
local compression = require 'compression'
|
||||
local settings = require 'settings'
|
||||
local image_loader = require 'image_loader'
|
||||
|
||||
local MAX_SIZE = 1440
|
||||
|
||||
local function count_lines(file)
|
||||
local fp = io.open(file, "r")
|
||||
local count = 0
|
||||
for line in fp:lines() do
|
||||
count = count + 1
|
||||
end
|
||||
fp:close()
|
||||
|
||||
return count
|
||||
end
|
||||
local function crop_if_large(src, max_size)
|
||||
if max_size > 0 and (src:size(2) > max_size or src:size(3) > max_size) then
|
||||
local sx = torch.random(0, src:size(3) - math.min(max_size, src:size(3)))
|
||||
|
@ -36,40 +28,38 @@ end
|
|||
|
||||
local function load_images(list)
|
||||
local MARGIN = 32
|
||||
local count = count_lines(list)
|
||||
local fp = io.open(list, "r")
|
||||
local lines = utils.split(file.read(list), "\n")
|
||||
local x = {}
|
||||
local c = 0
|
||||
for line in fp:lines() do
|
||||
for i = 1, #lines do
|
||||
local line = lines[i]
|
||||
local im, alpha = image_loader.load_byte(line)
|
||||
im = crop_if_large(im, settings.max_size)
|
||||
im = crop_4x(im)
|
||||
|
||||
if alpha then
|
||||
io.stderr:write(string.format("%s: skip: reason: alpha channel.", line))
|
||||
io.stderr:write(string.format("\n%s: skip: image has alpha channel.\n", line))
|
||||
else
|
||||
im = crop_if_large(im, settings.max_size)
|
||||
im = crop_4x(im)
|
||||
local scale = 1.0
|
||||
if settings.random_half then
|
||||
scale = 2.0
|
||||
end
|
||||
if im then
|
||||
if im:size(2) > (settings.crop_size * scale + MARGIN) and im:size(3) > (settings.crop_size * scale + MARGIN) then
|
||||
table.insert(x, {im:size(), torch.ByteStorage():string(snappy.compress(im:storage():string()))})
|
||||
table.insert(x, compression.compress(im))
|
||||
else
|
||||
io.stderr:write(string.format("%s: skip: reason: too small (%d > size).\n", line, settings.crop_size * scale + MARGIN))
|
||||
io.stderr:write(string.format("\n%s: skip: image is too small (%d > size).\n", line, settings.crop_size * scale + MARGIN))
|
||||
end
|
||||
else
|
||||
io.stderr:write(string.format("%s: skip: reason: load error.\n", line))
|
||||
io.stderr:write(string.format("\n%s: skip: load error.\n", line))
|
||||
end
|
||||
end
|
||||
c = c + 1
|
||||
xlua.progress(c, count)
|
||||
if c % 10 == 0 then
|
||||
xlua.progress(i, #lines)
|
||||
if i % 10 == 0 then
|
||||
collectgarbage()
|
||||
end
|
||||
end
|
||||
return x
|
||||
end
|
||||
|
||||
torch.manualSeed(settings.seed)
|
||||
print(settings)
|
||||
local x = load_images(settings.image_list)
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
if mynn.DepthExpand2x then
|
||||
return mynn.DepthExpand2x
|
||||
if w2nn.DepthExpand2x then
|
||||
return w2nn.DepthExpand2x
|
||||
end
|
||||
local DepthExpand2x, parent = torch.class('mynn.DepthExpand2x','nn.Module')
|
||||
local DepthExpand2x, parent = torch.class('w2nn.DepthExpand2x','nn.Module')
|
||||
|
||||
function DepthExpand2x:__init()
|
||||
parent:__init()
|
||||
|
@ -67,9 +67,11 @@ function DepthExpand2x.test()
|
|||
end
|
||||
show(x)
|
||||
|
||||
local de2x = mynn.DepthExpand2x()
|
||||
local de2x = w2nn.DepthExpand2x()
|
||||
out = de2x:forward(x)
|
||||
show(out)
|
||||
out = de2x:updateGradInput(x, out)
|
||||
show(out)
|
||||
end
|
||||
|
||||
return DepthExpand2x
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
if mynn.LeakyReLU then
|
||||
return mynn.LeakyReLU
|
||||
if w2nn and w2nn.LeakyReLU then
|
||||
return w2nn.LeakyReLU
|
||||
end
|
||||
|
||||
local LeakyReLU, parent = torch.class('mynn.LeakyReLU','nn.Module')
|
||||
local LeakyReLU, parent = torch.class('w2nn.LeakyReLU','nn.Module')
|
||||
|
||||
function LeakyReLU:__init(negative_scale)
|
||||
parent.__init(self)
|
||||
|
|
|
@ -1,13 +1,13 @@
|
|||
local RGBWeightedMSECriterion, parent = torch.class('mynn.RGBWeightedMSECriterion','nn.Criterion')
|
||||
local WeightedMSECriterion, parent = torch.class('w2nn.WeightedMSECriterion','nn.Criterion')
|
||||
|
||||
function RGBWeightedMSECriterion:__init(w)
|
||||
function WeightedMSECriterion:__init(w)
|
||||
parent.__init(self)
|
||||
self.weight = w:clone()
|
||||
self.diff = torch.Tensor()
|
||||
self.loss = torch.Tensor()
|
||||
end
|
||||
|
||||
function RGBWeightedMSECriterion:updateOutput(input, target)
|
||||
function WeightedMSECriterion:updateOutput(input, target)
|
||||
self.diff:resizeAs(input):copy(input)
|
||||
for i = 1, input:size(1) do
|
||||
self.diff[i]:add(-1, target[i]):cmul(self.weight)
|
||||
|
@ -18,8 +18,7 @@ function RGBWeightedMSECriterion:updateOutput(input, target)
|
|||
return self.output
|
||||
end
|
||||
|
||||
function RGBWeightedMSECriterion:updateGradInput(input, target)
|
||||
function WeightedMSECriterion:updateGradInput(input, target)
|
||||
self.gradInput:resizeAs(input):copy(self.diff)
|
||||
return self.gradInput
|
||||
end
|
||||
|
17
lib/compression.lua
Normal file
17
lib/compression.lua
Normal file
|
@ -0,0 +1,17 @@
|
|||
-- snapply compression for ByteTensor
|
||||
require 'snappy'
|
||||
|
||||
local compression = {}
|
||||
compression.compress = function (bt)
|
||||
local enc = snappy.compress(bt:storage():string())
|
||||
return {bt:size(), torch.ByteStorage():string(enc)}
|
||||
end
|
||||
compression.decompress = function(data)
|
||||
local size = data[1]
|
||||
local dec = snappy.decompress(data[2]:string())
|
||||
local bt = torch.ByteTensor(unpack(torch.totable(size)))
|
||||
bt:storage():string(dec)
|
||||
return bt
|
||||
end
|
||||
|
||||
return compression
|
|
@ -17,7 +17,7 @@ function image_loader.encode_png(rgb, alpha)
|
|||
end
|
||||
if alpha then
|
||||
if not (alpha:size(2) == rgb:size(2) and alpha:size(3) == rgb:size(3)) then
|
||||
alpha = gm.Image(alpha, "I", "DHW"):size(rgb:size(3), rgb:size(2), "Sinc"):toTensor("float", "I", "DHW")
|
||||
alpha = gm.Image(alpha, "I", "DHW"):size(rgb:size(3), rgb:size(2), "SincFast"):toTensor("float", "I", "DHW")
|
||||
end
|
||||
local rgba = torch.Tensor(4, rgb:size(2), rgb:size(3))
|
||||
rgba[1]:copy(rgb[1])
|
||||
|
@ -50,8 +50,8 @@ function image_loader.decode_byte(blob)
|
|||
if blob:sub(1, 4) == "\x89PNG" or blob:sub(1, 3) == "GIF" then
|
||||
-- split alpha channel
|
||||
im = im:toTensor('float', 'RGBA', 'DHW')
|
||||
local sum_alpha = (im[4] - 1):sum()
|
||||
if sum_alpha > 0 or sum_alpha < 0 then
|
||||
local sum_alpha = (im[4] - 1.0):sum()
|
||||
if sum_alpha < 0 then
|
||||
alpha = im[4]:reshape(1, im:size(2), im:size(3))
|
||||
end
|
||||
local new_im = torch.FloatTensor(3, im:size(2), im:size(3))
|
||||
|
|
|
@ -22,5 +22,4 @@ function iproc.padding(img, w1, w2, h1, h2)
|
|||
flow[2]:add(-w1)
|
||||
return image.warp(img, flow, "simple", false, "clamp")
|
||||
end
|
||||
|
||||
return iproc
|
||||
|
|
20
lib/mynn.lua
20
lib/mynn.lua
|
@ -1,20 +0,0 @@
|
|||
local function load_cunn()
|
||||
require 'nn'
|
||||
require 'cunn'
|
||||
end
|
||||
local function load_cudnn()
|
||||
require 'cudnn'
|
||||
cudnn.fastest = true
|
||||
end
|
||||
if mynn then
|
||||
return mynn
|
||||
else
|
||||
load_cunn()
|
||||
--load_cudnn()
|
||||
mynn = {}
|
||||
require './LeakyReLU'
|
||||
require './LeakyReLU_deprecated'
|
||||
require './DepthExpand2x'
|
||||
require './RGBWeightedMSECriterion'
|
||||
return mynn
|
||||
end
|
|
@ -1,7 +1,7 @@
|
|||
require 'image'
|
||||
local gm = require 'graphicsmagick'
|
||||
local iproc = require './iproc'
|
||||
local reconstruct = require './reconstruct'
|
||||
local iproc = require 'iproc'
|
||||
local reconstruct = require 'reconstruct'
|
||||
local pairwise_transform = {}
|
||||
|
||||
local function random_half(src, p)
|
||||
|
@ -81,6 +81,11 @@ local function color_noise(src)
|
|||
|
||||
return x:mul(255):byte()
|
||||
end
|
||||
local function shift_1px(src)
|
||||
-- reducing the even/odd issue in nearest neighbor.
|
||||
local r = torch.random(1, 4)
|
||||
|
||||
end
|
||||
local function flip_augment(x, y)
|
||||
local flip = torch.random(1, 4)
|
||||
if y then
|
||||
|
@ -138,17 +143,16 @@ local function data_augment(y, options)
|
|||
return y
|
||||
end
|
||||
|
||||
|
||||
local INTERPOLATION_PADDING = 16
|
||||
function pairwise_transform.scale(src, scale, size, offset, n, options)
|
||||
local filters = {
|
||||
"Box","Box","Box", -- 0.012756949974688
|
||||
"Box","Box", -- 0.012756949974688
|
||||
"Blackman", -- 0.013191924552285
|
||||
--"Cartom", -- 0.013753536746706
|
||||
--"Hanning", -- 0.013761314529647
|
||||
--"Hermite", -- 0.013850225205266
|
||||
"SincFast", -- 0.014095824314306
|
||||
"Jinc", -- 0.014244299255442
|
||||
--"Jinc", -- 0.014244299255442
|
||||
}
|
||||
if options.random_half then
|
||||
src = random_half(src)
|
||||
|
@ -176,26 +180,14 @@ function pairwise_transform.scale(src, scale, size, offset, n, options)
|
|||
return batch
|
||||
end
|
||||
function pairwise_transform.jpeg_(src, quality, size, offset, n, options)
|
||||
if options.random_half then
|
||||
src = random_half(src)
|
||||
end
|
||||
src = crop_if_large(src, math.max(size * 4, 512))
|
||||
local y = src
|
||||
local x
|
||||
|
||||
if options.color_noise then
|
||||
y = color_noise(y)
|
||||
end
|
||||
if options.overlay then
|
||||
y = overlay_augment(y)
|
||||
end
|
||||
x = y
|
||||
local y = data_augment(crop_if_large(src, math.max(size * 4, 512)), options)
|
||||
local x = y
|
||||
for i = 1, #quality do
|
||||
x = gm.Image(x, "RGB", "DHW")
|
||||
x:format("jpeg")
|
||||
if options.jpeg_sampling_factors == 444 then
|
||||
x:samplingFactors({1.0, 1.0, 1.0})
|
||||
else -- 422
|
||||
else -- 420
|
||||
x:samplingFactors({2.0, 1.0, 1.0})
|
||||
end
|
||||
local blob, len = x:toBlob(quality[i])
|
||||
|
|
|
@ -1,17 +0,0 @@
|
|||
require 'torch'
|
||||
require 'nn'
|
||||
|
||||
local function load_cuda()
|
||||
require 'cutorch'
|
||||
require 'cunn'
|
||||
end
|
||||
local function load_cudnn()
|
||||
require 'cudnn'
|
||||
--cudnn.fastest = true
|
||||
end
|
||||
|
||||
if pcall(load_cuda) then
|
||||
else
|
||||
end
|
||||
if pcall(load_cudnn) then
|
||||
end
|
|
@ -1,5 +1,5 @@
|
|||
require 'image'
|
||||
local iproc = require './iproc'
|
||||
local iproc = require 'iproc'
|
||||
|
||||
local function reconstruct_y(model, x, offset, block_size)
|
||||
if x:dim() == 2 then
|
||||
|
|
|
@ -35,7 +35,7 @@ cmd:option("-crop_size", 128, 'crop size')
|
|||
cmd:option("-max_size", -1, 'crop if image size larger then this value.')
|
||||
cmd:option("-batch_size", 2, 'mini batch size')
|
||||
cmd:option("-epoch", 200, 'epoch')
|
||||
cmd:option("-core", 2, 'cpu core')
|
||||
cmd:option("-thread", -1, 'number of CPU threads')
|
||||
cmd:option("-jpeg_sampling_factors", 444, '(444|422)')
|
||||
cmd:option("-validation_ratio", 0.1, 'validation ratio')
|
||||
cmd:option("-validation_crops", 40, 'number of crop region in validation')
|
||||
|
@ -84,7 +84,9 @@ else
|
|||
settings.overlay = false
|
||||
end
|
||||
|
||||
torch.setnumthreads(settings.core)
|
||||
if settings.thread > 0 then
|
||||
torch.setnumthreads(tonumber(settings.thread))
|
||||
end
|
||||
|
||||
settings.images = string.format("%s/images.t7", settings.data_dir)
|
||||
settings.image_list = string.format("%s/image_list.txt", settings.data_dir)
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
|
||||
require './mynn'
|
||||
require 'w2nn'
|
||||
|
||||
-- ref: http://arxiv.org/abs/1502.01852
|
||||
-- ref: http://arxiv.org/abs/1501.00092
|
||||
|
@ -7,17 +6,17 @@ local srcnn = {}
|
|||
function srcnn.waifu2x_cunn(ch)
|
||||
local model = nn.Sequential()
|
||||
model:add(nn.SpatialConvolutionMM(ch, 32, 3, 3, 1, 1, 0, 0))
|
||||
model:add(mynn.LeakyReLU(0.1))
|
||||
model:add(w2nn.LeakyReLU(0.1))
|
||||
model:add(nn.SpatialConvolutionMM(32, 32, 3, 3, 1, 1, 0, 0))
|
||||
model:add(mynn.LeakyReLU(0.1))
|
||||
model:add(w2nn.LeakyReLU(0.1))
|
||||
model:add(nn.SpatialConvolutionMM(32, 64, 3, 3, 1, 1, 0, 0))
|
||||
model:add(mynn.LeakyReLU(0.1))
|
||||
model:add(w2nn.LeakyReLU(0.1))
|
||||
model:add(nn.SpatialConvolutionMM(64, 64, 3, 3, 1, 1, 0, 0))
|
||||
model:add(mynn.LeakyReLU(0.1))
|
||||
model:add(w2nn.LeakyReLU(0.1))
|
||||
model:add(nn.SpatialConvolutionMM(64, 128, 3, 3, 1, 1, 0, 0))
|
||||
model:add(mynn.LeakyReLU(0.1))
|
||||
model:add(w2nn.LeakyReLU(0.1))
|
||||
model:add(nn.SpatialConvolutionMM(128, 128, 3, 3, 1, 1, 0, 0))
|
||||
model:add(mynn.LeakyReLU(0.1))
|
||||
model:add(w2nn.LeakyReLU(0.1))
|
||||
model:add(nn.SpatialConvolutionMM(128, ch, 3, 3, 1, 1, 0, 0))
|
||||
model:add(nn.View(-1):setNumInputDims(3))
|
||||
--model:cuda()
|
||||
|
@ -28,17 +27,17 @@ end
|
|||
function srcnn.waifu2x_cudnn(ch)
|
||||
local model = nn.Sequential()
|
||||
model:add(cudnn.SpatialConvolution(ch, 32, 3, 3, 1, 1, 0, 0))
|
||||
model:add(mynn.LeakyReLU(0.1))
|
||||
model:add(w2nn.LeakyReLU(0.1))
|
||||
model:add(cudnn.SpatialConvolution(32, 32, 3, 3, 1, 1, 0, 0))
|
||||
model:add(mynn.LeakyReLU(0.1))
|
||||
model:add(w2nn.LeakyReLU(0.1))
|
||||
model:add(cudnn.SpatialConvolution(32, 64, 3, 3, 1, 1, 0, 0))
|
||||
model:add(mynn.LeakyReLU(0.1))
|
||||
model:add(w2nn.LeakyReLU(0.1))
|
||||
model:add(cudnn.SpatialConvolution(64, 64, 3, 3, 1, 1, 0, 0))
|
||||
model:add(mynn.LeakyReLU(0.1))
|
||||
model:add(w2nn.LeakyReLU(0.1))
|
||||
model:add(cudnn.SpatialConvolution(64, 128, 3, 3, 1, 1, 0, 0))
|
||||
model:add(mynn.LeakyReLU(0.1))
|
||||
model:add(w2nn.LeakyReLU(0.1))
|
||||
model:add(cudnn.SpatialConvolution(128, 128, 3, 3, 1, 1, 0, 0))
|
||||
model:add(mynn.LeakyReLU(0.1))
|
||||
model:add(w2nn.LeakyReLU(0.1))
|
||||
model:add(cudnn.SpatialConvolution(128, ch, 3, 3, 1, 1, 0, 0))
|
||||
model:add(nn.View(-1):setNumInputDims(3))
|
||||
--model:cuda()
|
||||
|
|
24
lib/w2nn.lua
Normal file
24
lib/w2nn.lua
Normal file
|
@ -0,0 +1,24 @@
|
|||
local function load_nn()
|
||||
require 'torch'
|
||||
require 'nn'
|
||||
end
|
||||
local function load_cunn()
|
||||
require 'cutorch'
|
||||
require 'cunn'
|
||||
end
|
||||
local function load_cudnn()
|
||||
require 'cudnn'
|
||||
cudnn.fastest = true
|
||||
end
|
||||
if w2nn then
|
||||
return w2nn
|
||||
else
|
||||
pcall(load_cunn)
|
||||
pcall(load_cudnn)
|
||||
w2nn = {}
|
||||
require 'LeakyReLU'
|
||||
require 'LeakyReLU_deprecated'
|
||||
require 'DepthExpand2x'
|
||||
require 'WeightedMSECriterion'
|
||||
return w2nn
|
||||
end
|
148
tools/benchmark.lua
Normal file
148
tools/benchmark.lua
Normal file
|
@ -0,0 +1,148 @@
|
|||
local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
|
||||
package.path = path.join(path.dirname(__FILE__), "..", "lib", "?.lua;") .. package.path
|
||||
require 'xlua'
|
||||
require 'pl'
|
||||
|
||||
require 'w2nn'
|
||||
local iproc = require 'iproc'
|
||||
local reconstruct = require 'reconstruct'
|
||||
local image_loader = require 'image_loader'
|
||||
local gm = require 'graphicsmagick'
|
||||
|
||||
local cmd = torch.CmdLine()
|
||||
cmd:text()
|
||||
cmd:text("waifu2x-benchmark")
|
||||
cmd:text("Options:")
|
||||
|
||||
cmd:option("-seed", 11, 'fixed input seed')
|
||||
cmd:option("-dir", "./data/test", 'test image directory')
|
||||
cmd:option("-model1_dir", "./models/anime_style_art", 'model1 directory')
|
||||
cmd:option("-model2_dir", "./models/anime_style_art_rgb", 'model2 directory')
|
||||
cmd:option("-method", "scale", '(scale|noise)')
|
||||
cmd:option("-noise_level", 1, '(1|2)')
|
||||
cmd:option("-color_weight", "y", '(y|rgb)')
|
||||
cmd:option("-jpeg_quality", 75, 'jpeg quality')
|
||||
cmd:option("-jpeg_times", 1, 'jpeg compression times')
|
||||
cmd:option("-jpeg_quality_down", 5, 'value of jpeg quality to decrease each times')
|
||||
|
||||
local opt = cmd:parse(arg)
|
||||
torch.setdefaulttensortype('torch.FloatTensor')
|
||||
|
||||
local function MSE(x1, x2)
|
||||
return (x1 - x2):pow(2):mean()
|
||||
end
|
||||
local function YMSE(x1, x2)
|
||||
local x1_2 = x1:clone()
|
||||
local x2_2 = x2:clone()
|
||||
|
||||
x1_2[1]:mul(0.299 * 3)
|
||||
x1_2[2]:mul(0.587 * 3)
|
||||
x1_2[3]:mul(0.114 * 3)
|
||||
|
||||
x2_2[1]:mul(0.299 * 3)
|
||||
x2_2[2]:mul(0.587 * 3)
|
||||
x2_2[3]:mul(0.114 * 3)
|
||||
|
||||
return (x1_2 - x2_2):pow(2):mean()
|
||||
end
|
||||
local function PSNR(x1, x2)
|
||||
local mse = MSE(x1, x2)
|
||||
return 20 * (math.log(1.0 / math.sqrt(mse)) / math.log(10))
|
||||
end
|
||||
local function YPSNR(x1, x2)
|
||||
local mse = YMSE(x1, x2)
|
||||
return 20 * (math.log((0.587 * 3) / math.sqrt(mse)) / math.log(10))
|
||||
end
|
||||
|
||||
local function transform_jpeg(x)
|
||||
for i = 1, opt.jpeg_times do
|
||||
jpeg = gm.Image(x, "RGB", "DHW")
|
||||
jpeg:format("jpeg")
|
||||
jpeg:samplingFactors({1.0, 1.0, 1.0})
|
||||
blob, len = jpeg:toBlob(opt.jpeg_quality - (i - 1) * opt.jpeg_quality_down)
|
||||
jpeg:fromBlob(blob, len)
|
||||
x = jpeg:toTensor("byte", "RGB", "DHW")
|
||||
end
|
||||
return x
|
||||
end
|
||||
local function transform_scale(x)
|
||||
return iproc.scale(x,
|
||||
x:size(3) * 0.5,
|
||||
x:size(2) * 0.5,
|
||||
"Box")
|
||||
end
|
||||
|
||||
local function benchmark(color_weight, x, input_func, v1_noise, v2_noise)
|
||||
local v1_mse = 0
|
||||
local v2_mse = 0
|
||||
local v1_psnr = 0
|
||||
local v2_psnr = 0
|
||||
|
||||
for i = 1, #x do
|
||||
local ground_truth = x[i]
|
||||
local input, v1_output, v2_output
|
||||
|
||||
input = input_func(ground_truth)
|
||||
input = input:float():div(255)
|
||||
ground_truth = ground_truth:float():div(255)
|
||||
|
||||
t = sys.clock()
|
||||
if input:size(3) == ground_truth:size(3) then
|
||||
v1_output = reconstruct.image(v1_noise, input)
|
||||
v2_output = reconstruct.image(v2_noise, input)
|
||||
else
|
||||
v1_output = reconstruct.scale(v1_noise, 2.0, input)
|
||||
v2_output = reconstruct.scale(v2_noise, 2.0, input)
|
||||
end
|
||||
if color_weight == "y" then
|
||||
v1_mse = v1_mse + YMSE(ground_truth, v1_output)
|
||||
v1_psnr = v1_psnr + YPSNR(ground_truth, v1_output)
|
||||
v2_mse = v2_mse + YMSE(ground_truth, v2_output)
|
||||
v2_psnr = v2_psnr + YPSNR(ground_truth, v2_output)
|
||||
elseif color_weight == "rgb" then
|
||||
v1_mse = v1_mse + MSE(ground_truth, v1_output)
|
||||
v1_psnr = v1_psnr + PSNR(ground_truth, v1_output)
|
||||
v2_mse = v2_mse + MSE(ground_truth, v2_output)
|
||||
v2_psnr = v2_psnr + PSNR(ground_truth, v2_output)
|
||||
end
|
||||
|
||||
io.stdout:write(
|
||||
string.format("%d/%d; v1_mse=%f, v2_mse=%f, v1_psnr=%f, v2_psnr=%f \r",
|
||||
i, #x,
|
||||
v1_mse / i, v2_mse / i,
|
||||
v1_psnr / i, v2_psnr / i
|
||||
)
|
||||
)
|
||||
io.stdout:flush()
|
||||
end
|
||||
io.stdout:write("\n")
|
||||
end
|
||||
local function crop_4x(x)
|
||||
local w = x:size(3) % 4
|
||||
local h = x:size(2) % 4
|
||||
return image.crop(x, 0, 0, x:size(3) - w, x:size(2) - h)
|
||||
end
|
||||
local function load_data(test_dir)
|
||||
local test_x = {}
|
||||
local files = dir.getfiles(test_dir, "*.*")
|
||||
for i = 1, #files do
|
||||
table.insert(test_x, crop_4x(image_loader.load_byte(files[i])))
|
||||
xlua.progress(i, #files)
|
||||
end
|
||||
return test_x
|
||||
end
|
||||
|
||||
print(opt)
|
||||
torch.manualSeed(opt.seed)
|
||||
cutorch.manualSeed(opt.seed)
|
||||
if opt.method == "scale" then
|
||||
local v1 = torch.load(path.join(opt.model1_dir, "scale2.0x_model.t7"), "ascii")
|
||||
local v2 = torch.load(path.join(opt.model2_dir, "scale2.0x_model.t7"), "ascii")
|
||||
local test_x = load_data(opt.dir)
|
||||
benchmark(opt.color_weight, test_x, transform_scale, v1, v2)
|
||||
elseif opt.method == "noise" then
|
||||
local v1 = torch.load(path.join(opt.model1_dir, string.format("noise%d_model.t7", opt.noise_level)), "ascii")
|
||||
local v2 = torch.load(path.join(opt.model2_dir, string.format("noise%d_model.t7", opt.noise_level)), "ascii")
|
||||
local test_x = load_data(opt.dir)
|
||||
benchmark(opt.color_weight, test_x, transform_jpeg, v1, v2)
|
||||
end
|
|
@ -1,6 +1,7 @@
|
|||
require './lib/portable'
|
||||
require './lib/mynn'
|
||||
local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
|
||||
package.path = path.join(path.dirname(__FILE__), "..", "lib", "?.lua;") .. package.path
|
||||
|
||||
require 'w2nn'
|
||||
torch.setdefaulttensortype("torch.FloatTensor")
|
||||
|
||||
-- ref: https://github.com/torch/nn/issues/112#issuecomment-64427049
|
||||
|
@ -27,7 +28,7 @@ local function cleanupModel(node)
|
|||
if node.finput ~= nil then
|
||||
node.finput = zeroDataSize(node.finput)
|
||||
end
|
||||
if tostring(node) == "nn.LeakyReLU" then
|
||||
if tostring(node) == "nn.LeakyReLU" or tostring(node) == "w2nn.LeakyReLU" then
|
||||
if node.negative ~= nil then
|
||||
node.negative = zeroDataSize(node.negative)
|
||||
end
|
|
@ -1,6 +1,7 @@
|
|||
-- adapted from https://github.com/marcan/cl-waifu2x
|
||||
require './lib/portable'
|
||||
require './lib/LeakyReLU'
|
||||
local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
|
||||
package.path = path.join(path.dirname(__FILE__), "..", "lib", "?.lua;") .. package.path
|
||||
require 'w2nn'
|
||||
local cjson = require "cjson"
|
||||
|
||||
local model = torch.load(arg[1], "ascii")
|
29
train.lua
29
train.lua
|
@ -1,17 +1,18 @@
|
|||
require './lib/portable'
|
||||
require './lib/mynn'
|
||||
local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
|
||||
package.path = path.join(path.dirname(__FILE__), "lib", "?.lua;") .. package.path
|
||||
require 'optim'
|
||||
require 'xlua'
|
||||
require 'pl'
|
||||
require 'snappy'
|
||||
|
||||
local settings = require './lib/settings'
|
||||
local srcnn = require './lib/srcnn'
|
||||
local minibatch_adam = require './lib/minibatch_adam'
|
||||
local iproc = require './lib/iproc'
|
||||
local reconstruct = require './lib/reconstruct'
|
||||
local pairwise_transform = require './lib/pairwise_transform'
|
||||
local image_loader = require './lib/image_loader'
|
||||
require 'w2nn'
|
||||
local settings = require 'settings'
|
||||
local srcnn = require 'srcnn'
|
||||
local minibatch_adam = require 'minibatch_adam'
|
||||
local iproc = require 'iproc'
|
||||
local reconstruct = require 'reconstruct'
|
||||
local compression = require 'compression'
|
||||
local pairwise_transform = require 'pairwise_transform'
|
||||
local image_loader = require 'image_loader'
|
||||
|
||||
local function save_test_scale(model, rgb, file)
|
||||
local up = reconstruct.scale(model, settings.scale, rgb)
|
||||
|
@ -73,17 +74,13 @@ local function create_criterion(model)
|
|||
weight[1]:fill(0.299 * 3) -- R
|
||||
weight[2]:fill(0.587 * 3) -- G
|
||||
weight[3]:fill(0.114 * 3) -- B
|
||||
return mynn.RGBWeightedMSECriterion(weight):cuda()
|
||||
return w2nn.WeightedMSECriterion(weight):cuda()
|
||||
else
|
||||
return nn.MSECriterion():cuda()
|
||||
end
|
||||
end
|
||||
local function transformer(x, is_validation, n, offset)
|
||||
local size = x[1]
|
||||
local dec = snappy.decompress(x[2]:string())
|
||||
x = torch.ByteTensor(size[1], size[2], size[3])
|
||||
x:storage():string(dec)
|
||||
|
||||
x = compression.decompress(x)
|
||||
n = n or settings.batch_size;
|
||||
if is_validation == nil then is_validation = false end
|
||||
local color_noise = nil
|
||||
|
|
18
waifu2x.lua
18
waifu2x.lua
|
@ -1,11 +1,11 @@
|
|||
require './lib/portable'
|
||||
local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
|
||||
package.path = path.join(path.dirname(__FILE__), "lib", "?.lua;") .. package.path
|
||||
require 'sys'
|
||||
require 'pl'
|
||||
require './lib/mynn'
|
||||
|
||||
local iproc = require './lib/iproc'
|
||||
local reconstruct = require './lib/reconstruct'
|
||||
local image_loader = require './lib/image_loader'
|
||||
require 'w2nn'
|
||||
local iproc = require 'iproc'
|
||||
local reconstruct = require 'reconstruct'
|
||||
local image_loader = require 'image_loader'
|
||||
|
||||
torch.setdefaulttensortype('torch.FloatTensor')
|
||||
|
||||
|
@ -111,8 +111,12 @@ local function waifu2x()
|
|||
cmd:option("-noise_level", 1, '(1|2)')
|
||||
cmd:option("-crop_size", 128, 'patch size per process')
|
||||
cmd:option("-resume", 0, "skip existing files (0|1)")
|
||||
|
||||
cmd:option("-thread", -1, "number of CPU threads")
|
||||
|
||||
local opt = cmd:parse(arg)
|
||||
if opt.thread > 0 then
|
||||
torch.setnumthreads(opt.thread)
|
||||
end
|
||||
if string.len(opt.l) == 0 then
|
||||
convert_image(opt)
|
||||
else
|
||||
|
|
24
web.lua
24
web.lua
|
@ -1,11 +1,16 @@
|
|||
local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
|
||||
package.path = path.join(path.dirname(__FILE__), "lib", "?.lua;") .. package.path
|
||||
_G.TURBO_SSL = true
|
||||
|
||||
require 'pl'
|
||||
require 'w2nn'
|
||||
local turbo = require 'turbo'
|
||||
local uuid = require 'uuid'
|
||||
local ffi = require 'ffi'
|
||||
local md5 = require 'md5'
|
||||
require 'pl'
|
||||
require 'lib.portable'
|
||||
require 'lib.mynn'
|
||||
local iproc = require 'iproc'
|
||||
local reconstruct = require 'reconstruct'
|
||||
local image_loader = require 'image_loader'
|
||||
|
||||
local cmd = torch.CmdLine()
|
||||
cmd:text()
|
||||
|
@ -13,18 +18,15 @@ cmd:text("waifu2x-api")
|
|||
cmd:text("Options:")
|
||||
cmd:option("-port", 8812, 'listen port')
|
||||
cmd:option("-gpu", 1, 'Device ID')
|
||||
cmd:option("-core", 2, 'number of CPU cores')
|
||||
cmd:option("-thread", -1, 'number of CPU threads')
|
||||
local opt = cmd:parse(arg)
|
||||
cutorch.setDevice(opt.gpu)
|
||||
torch.setdefaulttensortype('torch.FloatTensor')
|
||||
torch.setnumthreads(opt.core)
|
||||
|
||||
local iproc = require './lib/iproc'
|
||||
local reconstruct = require './lib/reconstruct'
|
||||
local image_loader = require './lib/image_loader'
|
||||
|
||||
local MODEL_DIR = "./models/anime_style_art_rgb3"
|
||||
if opt.thread > 0 then
|
||||
torch.setnumthreads(opt.thread)
|
||||
end
|
||||
|
||||
local MODEL_DIR = "./models/anime_style_art_rgb"
|
||||
local noise1_model = torch.load(path.join(MODEL_DIR, "noise1_model.t7"), "ascii")
|
||||
local noise2_model = torch.load(path.join(MODEL_DIR, "noise2_model.t7"), "ascii")
|
||||
local scale20_model = torch.load(path.join(MODEL_DIR, "scale2.0x_model.t7"), "ascii")
|
||||
|
|
Loading…
Reference in a new issue