
sync from internal repo

- Memory compression with snappy (lua-csnappy)
- Use RGB-wise weighted MSE (R*0.299, G*0.587, B*0.114) instead of plain MSE
- Aggressive cropping for edge regions
and some other changes.
nagadomi 2015-10-26 09:23:52 +09:00
parent 54580ba8c0
commit 8dea362bed
35 changed files with 882 additions and 278 deletions
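For reference, the memory-compression change stores each training image in convert_data.lua as its size() plus a snappy-deflated copy of its raw pixel bytes; train.lua's transformer inflates it again on demand. A minimal sketch of that roundtrip, assuming the lua-csnappy 'snappy' module and an arbitrary ByteTensor:

require 'torch'
local snappy = require 'snappy' -- lua-csnappy
-- pack: remember the shape, compress the raw bytes
local im = torch.ByteTensor(3, 256, 256):random(0, 255)
local entry = {im:size(),
               torch.ByteStorage():string(snappy.compress(im:storage():string()))}
-- unpack: rebuild the ByteTensor from the stored shape and the inflated bytes
local size = entry[1]
local restored = torch.ByteTensor(size[1], size[2], size[3])
restored:storage():string(snappy.decompress(entry[2]:string()))
assert(torch.equal(restored, im))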

4
.gitignore vendored

@ -1,4 +1,8 @@
*~
/*.png
/*.mp4
/*.jpg
cache/*.png
models/*.png
models/*/*.png
waifu2x.log

280
benchmark.lua Normal file

@ -0,0 +1,280 @@
require './lib/portable'
require './lib/mynn'
require 'xlua'
require 'pl'
local iproc = require './lib/iproc'
local reconstruct = require './lib/reconstruct'
local image_loader = require './lib/image_loader'
local gm = require 'graphicsmagick'
local cmd = torch.CmdLine()
cmd:text()
cmd:text("waifu2x-benchmark")
cmd:text("Options:")
cmd:option("-seed", 11, 'fixed input seed')
cmd:option("-test_dir", "./test", 'test image directory')
cmd:option("-jpeg_quality", 50, 'jpeg quality')
cmd:option("-jpeg_times", 3, 'number of jpeg compression ')
cmd:option("-jpeg_quality_down", 5, 'reducing jpeg quality each times')
cmd:option("-core", 4, 'threads')
local opt = cmd:parse(arg)
torch.setnumthreads(opt.core)
torch.setdefaulttensortype('torch.FloatTensor')
local function MSE(x1, x2)
return (x1 - x2):pow(2):mean()
end
local function YMSE(x1, x2)
local x1_2 = x1:clone()
local x2_2 = x2:clone()
x1_2[1]:mul(0.299 * 3)
x1_2[2]:mul(0.587 * 3)
x1_2[3]:mul(0.114 * 3)
x2_2[1]:mul(0.299 * 3)
x2_2[2]:mul(0.587 * 3)
x2_2[3]:mul(0.114 * 3)
return (x1_2 - x2_2):pow(2):mean()
end
local function PSNR(x1, x2)
local mse = MSE(x1, x2)
return 20 * (math.log(1.0 / math.sqrt(mse)) / math.log(10))
end
local function YPSNR(x1, x2)
local mse = YMSE(x1, x2)
return 20 * (math.log((0.587 * 3) / math.sqrt(mse)) / math.log(10))
end
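-- note: the luma weights (0.299, 0.587, 0.114) sum to 1, so the x3 factor
-- gives them a mean of 1 and keeps YMSE roughly on the scale of plain MSE;
-- YPSNR's peak signal is the largest weighted value, G's 0.587 * 3,
-- because inputs lie in [0, 1] before weighting.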
local function transform_jpeg(x)
for i = 1, opt.jpeg_times do
local jpeg = gm.Image(x, "RGB", "DHW")
jpeg:format("jpeg")
jpeg:samplingFactors({1.0, 1.0, 1.0})
local blob, len = jpeg:toBlob(opt.jpeg_quality - (i - 1) * opt.jpeg_quality_down)
jpeg:fromBlob(blob, len)
x = jpeg:toTensor("byte", "RGB", "DHW")
end
return x
end
local function noise_benchmark(x, v1_noise, v2_noise)
local v1_mse = 0
local v2_mse = 0
local jpeg_mse = 0
local v1_psnr = 0
local v2_psnr = 0
local jpeg_psnr = 0
local v1_time = 0
local v2_time = 0
for i = 1, #x do
local ground_truth = x[i]
local input, v1_output, v2_output, t
input = transform_jpeg(ground_truth)
input = input:float():div(255)
ground_truth = ground_truth:float():div(255)
jpeg_mse = jpeg_mse + MSE(ground_truth, input)
jpeg_psnr = jpeg_psnr + PSNR(ground_truth, input)
t = sys.clock()
v1_output = reconstruct.image(v1_noise, input)
v1_time = v1_time + (sys.clock() - t)
v1_mse = v1_mse + MSE(ground_truth, v1_output)
v1_psnr = v1_psnr + PSNR(ground_truth, v1_output)
t = sys.clock()
v2_output = reconstruct.image(v2_noise, input)
v2_time = v2_time + (sys.clock() - t)
v2_mse = v2_mse + MSE(ground_truth, v2_output)
v2_psnr = v2_psnr + PSNR(ground_truth, v2_output)
io.stdout:write(
string.format("%d/%d; v1_time=%f, v2_time=%f, jpeg_mse=%f, v1_mse=%f, v2_mse=%f, jpeg_psnr=%f, v1_psnr=%f, v2_psnr=%f \r",
i, #x,
v1_time / i, v2_time / i,
jpeg_mse / i,
v1_mse / i, v2_mse / i,
jpeg_psnr / i,
v1_psnr / i, v2_psnr / i
)
)
io.stdout:flush()
end
io.stdout:write("\n")
end
local function noise_scale_benchmark(x, params, v1_noise, v1_scale, v2_noise, v2_scale)
local v1_mse = 0
local v2_mse = 0
local jinc_mse = 0
local v1_time = 0
local v2_time = 0
for i = 1, #x do
local ground_truth = x[i]
local downscale = iproc.scale(ground_truth,
ground_truth:size(3) * 0.5,
ground_truth:size(2) * 0.5,
params[i].filter)
local jpeg, blob, len, input, v1_output, v2_output, jinc_output, t, mse
jpeg = gm.Image(downscale, "RGB", "DHW")
jpeg:format("jpeg")
blob, len = jpeg:toBlob(params[i].quality)
jpeg:fromBlob(blob, len)
input = jpeg:toTensor("byte", "RGB", "DHW")
input = input:float():div(255)
ground_truth = ground_truth:float():div(255)
jinc_output = iproc.scale(input, input:size(3) * 2, input:size(2) * 2, "Jinc")
jinc_mse = jinc_mse + (ground_truth - jinc_output):pow(2):mean()
t = sys.clock()
v1_output = reconstruct.image(v1_noise, input)
v1_output = reconstruct.scale(v1_scale, 2.0, v1_output)
v1_time = v1_time + (sys.clock() - t)
mse = (ground_truth - v1_output):pow(2):mean()
v1_mse = v1_mse + mse
t = sys.clock()
v2_output = reconstruct.image(v2_noise, input)
v2_output = reconstruct.scale(v2_scale, 2.0, v2_output)
v2_time = v2_time + (sys.clock() - t)
mse = (ground_truth - v2_output):pow(2):mean()
v2_mse = v2_mse + mse
io.stdout:write(string.format("%d/%d; time: v1=%f, v2=%f, v1/v2=%f; mse: jinc=%f, v1=%f(%f), v2=%f(%f), v1/v2=%f \r",
i, #x,
v1_time / i, v2_time / i,
(v1_time / i) / (v2_time / i),
jinc_mse / i,
v1_mse / i, (v1_mse/i) / (jinc_mse/i),
v2_mse / i, (v2_mse/i) / (jinc_mse/i),
(v1_mse / i) / (v2_mse / i)))
io.stdout:flush()
end
io.stdout:write("\n")
end
local function scale_benchmark(x, params, v1_scale, v2_scale)
local v1_mse = 0
local v2_mse = 0
local jinc_mse = 0
local v1_psnr = 0
local v2_psnr = 0
local jinc_psnr = 0
local v1_time = 0
local v2_time = 0
for i = 1, #x do
local ground_truth = x[i]
local downscale = iproc.scale(ground_truth,
ground_truth:size(3) * 0.5,
ground_truth:size(2) * 0.5,
params[i].filter)
local input, v1_output, v2_output, jinc_output, t, mse
input = downscale
input = input:float():div(255)
ground_truth = ground_truth:float():div(255)
jinc_output = iproc.scale(input, input:size(3) * 2, input:size(2) * 2, "Jinc")
mse = (ground_truth - jinc_output):pow(2):mean()
jinc_mse = jinc_mse + mse
jinc_psnr = jinc_psnr + (10 * (math.log(1.0 / mse) / math.log(10)))
t = sys.clock()
v1_output = reconstruct.scale(v1_scale, 2.0, input)
v1_time = v1_time + (sys.clock() - t)
mse = (ground_truth - v1_output):pow(2):mean()
v1_mse = v1_mse + mse
v1_psnr = v1_psnr + (10 * (math.log(1.0 / mse) / math.log(10)))
t = sys.clock()
v2_output = reconstruct.scale(v2_scale, 2.0, input)
v2_time = v2_time + (sys.clock() - t)
mse = (ground_truth - v2_output):pow(2):mean()
v2_mse = v2_mse + mse
v2_psnr = v2_psnr + (10 * (math.log(1.0 / mse) / math.log(10)))
io.stdout:write(string.format("%d/%d; time: v1=%f, v2=%f, v1/v2=%f; mse: jinc=%f, v1=%f(%f), v2=%f(%f), v1/v2=%f \r",
i, #x,
v1_time / i, v2_time / i,
(v1_time / i) / (v2_time / i),
jinc_psnr / i,
v1_psnr / i, (v1_psnr/i) / (jinc_psnr/i),
v2_psnr / i, (v2_psnr/i) / (jinc_psnr/i),
(v1_psnr / i) / (v2_psnr / i)))
io.stdout:flush()
end
io.stdout:write("\n")
end
local function split_data(x, test_size)
local index = torch.randperm(#x)
local train_size = #x - test_size
local train_x = {}
local valid_x = {}
for i = 1, train_size do
train_x[i] = x[index[i]]
end
for i = 1, test_size do
valid_x[i] = x[index[train_size + i]]
end
return train_x, valid_x
end
local function crop_4x(x)
local w = x:size(3) % 4
local h = x:size(2) % 4
return image.crop(x, 0, 0, x:size(3) - w, x:size(2) - h)
end
local function load_data(valid_dir)
local valid_x = {}
local files = dir.getfiles(valid_dir, "*.png")
for i = 1, #files do
table.insert(valid_x, crop_4x(image_loader.load_byte(files[i])))
xlua.progress(i, #files)
end
return valid_x
end
local function noise_main(valid_dir, level)
local v1_noise = torch.load(path.join(V1_DIR, string.format("noise%d_model.t7", level)), "ascii")
local v2_noise = torch.load(path.join(V2_DIR, string.format("noise%d_model.t7", level)), "ascii")
local valid_x = load_data(valid_dir)
noise_benchmark(valid_x, v1_noise, v2_noise)
end
local function scale_main(valid_dir)
local v1 = torch.load(path.join(V1_DIR, "scale2.0x_model.t7"), "ascii")
local v2 = torch.load(path.join(V2_DIR, "scale2.0x_model.t7"), "ascii")
local valid_x = load_data(valid_dir)
local params = random_params(valid_x, 2)
scale_benchmark(valid_x, params, v1, v2)
end
local function noise_scale_main(valid_dir)
local v1_noise = torch.load(path.join(V1_DIR, "noise2_model.t7"), "ascii")
local v1_scale = torch.load(path.join(V1_DIR, "scale2.0x_model.t7"), "ascii")
local v2_noise = torch.load(path.join(V2_DIR, "noise2_model.t7"), "ascii")
local v2_scale = torch.load(path.join(V2_DIR, "scale2.0x_model.t7"), "ascii")
local valid_x = load_data(valid_dir)
local params = random_params(valid_x, 2)
noise_scale_benchmark(valid_x, params, v1_noise, v1_scale, v2_noise, v2_scale)
end
V1_DIR = "models/anime_style_art_rgb"
V2_DIR = "models/anime_style_art_rgb5"
torch.manualSeed(opt.seed)
cutorch.manualSeed(opt.seed)
noise_main("./test", 2)
--scale_main("./test")
--noise_scale_main("./test")


@ -1,5 +1,5 @@
require './lib/portable'
require './lib/LeakyReLU'
require './lib/mynn'
torch.setdefaulttensortype("torch.FloatTensor")


@ -1,8 +1,12 @@
local ffi = require 'ffi'
require './lib/portable'
require 'image'
require 'snappy'
local settings = require './lib/settings'
local image_loader = require './lib/image_loader'
local MAX_SIZE = 1440
local function count_lines(file)
local fp = io.open(file, "r")
local count = 0
@ -13,7 +17,17 @@ local function count_lines(file)
return count
end
local function crop_if_large(src, max_size)
if max_size > 0 and (src:size(2) > max_size or src:size(3) > max_size) then
local sx = torch.random(0, src:size(3) - math.min(max_size, src:size(3)))
local sy = torch.random(0, src:size(2) - math.min(max_size, src:size(2)))
return image.crop(src, sx, sy,
math.min(sx + max_size, src:size(3)),
math.min(sy + max_size, src:size(2)))
else
return src
end
end
local function crop_4x(x)
local w = x:size(3) % 4
local h = x:size(2) % 4
@ -27,13 +41,26 @@ local function load_images(list)
local x = {}
local c = 0
for line in fp:lines() do
local im = crop_4x(image_loader.load_byte(line))
if im then
if im:size(2) > (settings.crop_size * 2 + MARGIN) and im:size(3) > (settings.crop_size * 2 + MARGIN) then
table.insert(x, im)
end
local im, alpha = image_loader.load_byte(line)
im = crop_if_large(im, settings.max_size)
im = crop_4x(im)
if alpha then
io.stderr:write(string.format("%s: skip: reason: alpha channel.", line))
else
print("error:" .. line)
local scale = 1.0
if settings.random_half then
scale = 2.0
end
if im then
if im:size(2) > (settings.crop_size * scale + MARGIN) and im:size(3) > (settings.crop_size * scale + MARGIN) then
table.insert(x, {im:size(), torch.ByteStorage():string(snappy.compress(im:storage():string()))})
else
io.stderr:write(string.format("%s: skip: reason: too small (%d > size).\n", line, settings.crop_size * scale + MARGIN))
end
else
io.stderr:write(string.format("%s: skip: reason: load error.\n", line))
end
end
c = c + 1
xlua.progress(c, count)
@ -43,7 +70,7 @@ local function load_images(list)
end
return x
end
torch.manualSeed(settings.seed)
print(settings)
local x = load_images(settings.image_list)
torch.save(settings.images, x)


@ -1,8 +1,7 @@
#!/bin/sh
th waifu2x.lua -noise_level 1 -m noise_scale -i images/miku_small.png -o images/miku_small_waifu2x.png
th waifu2x.lua -m scale -i images/miku_small.png -o images/miku_small_waifu2x.png
th waifu2x.lua -noise_level 2 -m noise_scale -i images/miku_small_noisy.png -o images/miku_small_noisy_waifu2x.png
th waifu2x.lua -noise_level 2 -m noise -i images/miku_noisy.png -o images/miku_noisy_waifu2x.png
th waifu2x.lua -noise_level 2 -m noise_scale -i images/miku_CC_BY-NC_noisy.jpg -o images/miku_CC_BY-NC_noisy_waifu2x.png
th waifu2x.lua -noise_level 2 -m noise -i images/lena.png -o images/lena_waifu2x.png
th waifu2x.lua -m scale -model_dir models/ukbench -i images/lena.png -o images/lena_waifu2x_ukbench.png

12 binary image files changed (previews not shown).

75
lib/DepthExpand2x.lua Normal file

@ -0,0 +1,75 @@
if mynn.DepthExpand2x then
return mynn.DepthExpand2x
end
local DepthExpand2x, parent = torch.class('mynn.DepthExpand2x','nn.Module')
function DepthExpand2x:__init()
parent.__init(self)
end
function DepthExpand2x:updateOutput(input)
local x = input
-- (batch_size, depth, height, width)
self.shape = x:size()
assert(self.shape:size() == 4, "input must be 4d tensor")
assert(self.shape[2] % 4 == 0, "input depth must be divisible by 4")
-- (batch_size, width, height, depth)
x = x:transpose(2, 4)
-- (batch_size, width, height * 2, depth / 2)
x = x:reshape(self.shape[1], self.shape[4], self.shape[3] * 2, self.shape[2] / 2)
-- (batch_size, height * 2, width, depth / 2)
x = x:transpose(2, 3)
-- (batch_size, height * 2, width * 2, depth / 4)
x = x:reshape(self.shape[1], self.shape[3] * 2, self.shape[4] * 2, self.shape[2] / 4)
-- (batch_size, depth / 4, height * 2, width * 2)
x = x:transpose(2, 4)
x = x:transpose(3, 4)
self.output:resizeAs(x):copy(x) -- contiguous
return self.output
end
function DepthExpand2x:updateGradInput(input, gradOutput)
-- (batch_size, depth / 4, height * 2, width * 2)
local x = gradOutput
-- (batch_size, height * 2, width * 2, depth / 4)
x = x:transpose(2, 4)
x = x:transpose(2, 3)
-- (batch_size, height * 2, width, depth / 2)
x = x:reshape(self.shape[1], self.shape[3] * 2, self.shape[4], self.shape[2] / 2)
-- (batch_size, width, height * 2, depth / 2)
x = x:transpose(2, 3)
-- (batch_size, width, height, depth)
x = x:reshape(self.shape[1], self.shape[4], self.shape[3], self.shape[2])
-- (batch_size, depth, height, width)
x = x:transpose(2, 4)
self.gradInput:resizeAs(x):copy(x)
return self.gradInput
end
function DepthExpand2x.test()
require 'image'
local function show(x)
local img = torch.Tensor(3, x:size(3), x:size(4))
img[1]:copy(x[1][1])
img[2]:copy(x[1][2])
img[3]:copy(x[1][3])
image.display(img)
end
local img = image.lena()
local x = torch.Tensor(1, img:size(1) * 4, img:size(2), img:size(3))
for i = 0, img:size(1) * 4 - 1 do
local src_index = (i % 3) + 1
x[1][i + 1]:copy(img[src_index])
end
show(x)
local de2x = mynn.DepthExpand2x()
local out = de2x:forward(x)
show(out)
out = de2x:updateGradInput(x, out)
show(out)
end
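A quick shape check for the module above (a sketch; assumes lib/mynn.lua has been required so that mynn.DepthExpand2x is registered):

-- depth-to-space: (N, D, H, W) -> (N, D / 4, H * 2, W * 2)
local de2x = mynn.DepthExpand2x()
local x = torch.Tensor(1, 12, 4, 4):uniform()
print(de2x:forward(x):size()) -- 1 x 3 x 8 x 8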


@ -1,7 +1,8 @@
if nn.LeakyReLU then
return
if mynn.LeakyReLU then
return mynn.LeakyReLU
end
local LeakyReLU, parent = torch.class('nn.LeakyReLU','nn.Module')
local LeakyReLU, parent = torch.class('mynn.LeakyReLU','nn.Module')
function LeakyReLU:__init(negative_scale)
parent.__init(self)


@ -0,0 +1,31 @@
if nn.LeakyReLU then
return nn.LeakyReLU
end
local LeakyReLU, parent = torch.class('nn.LeakyReLU','nn.Module')
function LeakyReLU:__init(negative_scale)
parent.__init(self)
self.negative_scale = negative_scale or 0.333
self.negative = torch.Tensor()
end
function LeakyReLU:updateOutput(input)
self.output:resizeAs(input):copy(input):abs():add(input):div(2)
self.negative:resizeAs(input):copy(input):abs():add(-1.0, input):mul(-0.5*self.negative_scale)
self.output:add(self.negative)
return self.output
end
function LeakyReLU:updateGradInput(input, gradOutput)
self.gradInput:resizeAs(gradOutput)
-- filter positive
self.negative:sign():add(1)
torch.cmul(self.gradInput, gradOutput, self.negative)
-- filter negative
self.negative:add(-1):mul(-1 * self.negative_scale):cmul(gradOutput)
self.gradInput:add(self.negative)
return self.gradInput
end


@ -0,0 +1,25 @@
local RGBWeightedMSECriterion, parent = torch.class('mynn.RGBWeightedMSECriterion','nn.Criterion')
function RGBWeightedMSECriterion:__init(w)
parent.__init(self)
self.weight = w:clone()
self.diff = torch.Tensor()
self.loss = torch.Tensor()
end
function RGBWeightedMSECriterion:updateOutput(input, target)
self.diff:resizeAs(input):copy(input)
for i = 1, input:size(1) do
self.diff[i]:add(-1, target[i]):cmul(self.weight)
end
self.loss:resizeAs(self.diff):copy(self.diff):cmul(self.diff)
self.output = self.loss:mean()
return self.output
end
function RGBWeightedMSECriterion:updateGradInput(input, target)
self.gradInput:resizeAs(input):copy(self.diff)
return self.gradInput
end
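A minimal CPU sketch of how this criterion is driven (mirrors create_criterion in train.lua; the 16x16 patch size and batch of 4 are illustrative assumptions):

local n = 16 * 16 -- pixels per channel (hypothetical patch size)
local w = torch.Tensor(3, n)
w[1]:fill(0.299 * 3) -- R
w[2]:fill(0.587 * 3) -- G
w[3]:fill(0.114 * 3) -- B
local crit = mynn.RGBWeightedMSECriterion(w)
local input = torch.Tensor(4, 3 * n):uniform() -- batch of flattened RGB patches
local target = torch.Tensor(4, 3 * n):uniform()
print(crit:forward(input, target))         -- weighted mean squared error
print(crit:backward(input, target):size()) -- gradient, same shape as input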


@ -13,7 +13,7 @@ function image_loader.decode_float(blob)
end
function image_loader.encode_png(rgb, alpha)
if rgb:type() == "torch.ByteTensor" then
error("expect FloatTensor")
rgb = rgb:float():div(255)
end
if alpha then
if not (alpha:size(2) == rgb:size(2) and alpha:size(3) == rgb:size(3)) then
@ -26,11 +26,11 @@ function image_loader.encode_png(rgb, alpha)
rgba[4]:copy(alpha)
local im = gm.Image():fromTensor(rgba, "RGBA", "DHW")
im:format("png")
return im:toBlob()
return im:toBlob(9)
else
local im = gm.Image(rgb, "RGB", "DHW")
im:format("png")
return im:toBlob()
return im:toBlob(9)
end
end
function image_loader.save_png(filename, rgb, alpha)
@ -64,6 +64,7 @@ function image_loader.decode_byte(blob)
end
return {im, alpha}
end
local state, ret = pcall(load_image)
if state then
return ret[1], ret[2]


@ -22,15 +22,12 @@ local function minibatch_adam(model, criterion,
local targets_tmp = torch.Tensor(batch_size,
target_size[1] * target_size[2] * target_size[3])
for t = 1, #train_x, batch_size do
if t + batch_size > #train_x then
break
end
for t = 1, #train_x do
xlua.progress(t, #train_x)
for i = 1, batch_size do
local x, y = transformer(train_x[shuffle[t + i - 1]])
inputs_tmp[i]:copy(x)
targets_tmp[i]:copy(y)
local xy = transformer(train_x[shuffle[t]], false, batch_size)
for i = 1, #xy do
inputs_tmp[i]:copy(xy[i][1])
targets_tmp[i]:copy(xy[i][2])
end
inputs:copy(inputs_tmp)
targets:copy(targets_tmp)

20
lib/mynn.lua Normal file

@ -0,0 +1,20 @@
local function load_cunn()
require 'nn'
require 'cunn'
end
local function load_cudnn()
require 'cudnn'
cudnn.fastest = true
end
if mynn then
return mynn
else
load_cunn()
--load_cudnn()
mynn = {}
require './LeakyReLU'
require './LeakyReLU_deprecated'
require './DepthExpand2x'
require './RGBWeightedMSECriterion'
return mynn
end


@ -4,10 +4,11 @@ local iproc = require './iproc'
local reconstruct = require './reconstruct'
local pairwise_transform = {}
local function random_half(src, p, min_size)
p = p or 0.5
local filter = ({"Box","Blackman", "SincFast", "Jinc"})[torch.random(1, 4)]
if p > torch.uniform() then
local function random_half(src, p)
p = p or 0.25
--local filter = ({"Box","Blackman", "SincFast", "Jinc"})[torch.random(1, 4)]
local filter = "Box"
if p < torch.uniform() and (src:size(2) > 768 and src:size(3) > 1024) then
return iproc.scale(src, src:size(3) * 0.5, src:size(2) * 0.5, filter)
else
return src
@ -21,6 +22,48 @@ local function pcacov(x)
local ce, cv = torch.symeig(c, 'V')
return ce, cv
end
local function crop_if_large(src, max_size)
if src:size(2) > max_size and src:size(3) > max_size then
local yi = torch.random(0, src:size(2) - max_size)
local xi = torch.random(0, src:size(3) - max_size)
return image.crop(src, xi, yi, xi + max_size, yi + max_size)
else
return src
end
end
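-- active cropping: with probability (1 - p) take a single random crop;
-- otherwise draw `tries` candidate crops and keep the pair whose input and
-- target differ most (largest MSE), biasing training toward edges and
-- detailed regions ("aggressive cropping" in the commit message).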
local function active_cropping(x, y, size, offset, p, tries)
assert("x:size == y:size", x:size(2) == y:size(2) and x:size(3) == y:size(3))
local r = torch.uniform()
if p < r then
local xi = torch.random(offset, y:size(3) - (size + offset + 1))
local yi = torch.random(offset, y:size(2) - (size + offset + 1))
local xc = image.crop(x, xi, yi, xi + size, yi + size)
local yc = image.crop(y, xi, yi, xi + size, yi + size)
yc = yc:float():div(255)
xc = xc:float():div(255)
return xc, yc
else
local samples = {}
local sum_mse = 0
for i = 1, tries do
local xi = torch.random(offset, y:size(3) - (size + offset + 1))
local yi = torch.random(offset, y:size(2) - (size + offset + 1))
local xc = image.crop(x, xi, yi, xi + size, yi + size):float():div(255)
local yc = image.crop(y, xi, yi, xi + size, yi + size):float():div(255)
local mse = (xc - yc):pow(2):mean()
sum_mse = sum_mse + mse
table.insert(samples, {xc = xc, yc = yc, mse = mse})
end
if sum_mse > 0 then
table.sort(samples,
function (a, b)
return a.mse > b.mse
end)
end
return samples[1].xc, samples[1].yc
end
end
local function color_noise(src)
local p = 0.1
src = src:float():div(255)
@ -84,29 +127,7 @@ local function overlay_augment(src, p)
return src
end
end
local INTERPOLATION_PADDING = 16
function pairwise_transform.scale(src, scale, size, offset, options)
options = options or {color_noise = false, overlay = false, random_half = true, rgb = true}
if options.random_half then
src = random_half(src)
end
local yi = torch.random(INTERPOLATION_PADDING, src:size(2) - size - INTERPOLATION_PADDING)
local xi = torch.random(INTERPOLATION_PADDING, src:size(3) - size - INTERPOLATION_PADDING)
local down_scale = 1.0 / scale
local y = image.crop(src,
xi - INTERPOLATION_PADDING, yi - INTERPOLATION_PADDING,
xi + size + INTERPOLATION_PADDING, yi + size + INTERPOLATION_PADDING)
local filters = {
"Box", -- 0.012756949974688
"Blackman", -- 0.013191924552285
--"Cartom", -- 0.013753536746706
--"Hanning", -- 0.013761314529647
--"Hermite", -- 0.013850225205266
"SincFast", -- 0.014095824314306
"Jinc", -- 0.014244299255442
}
local downscale_filter = filters[torch.random(1, #filters)]
local function data_augment(y, options)
y = flip_augment(y)
if options.color_noise then
y = color_noise(y)
@ -114,29 +135,51 @@ function pairwise_transform.scale(src, scale, size, offset, options)
if options.overlay then
y = overlay_augment(y)
end
local x = iproc.scale(y, y:size(3) * down_scale, y:size(2) * down_scale, downscale_filter)
x = iproc.scale(x, y:size(3), y:size(2))
y = y:float():div(255)
x = x:float():div(255)
if options.rgb then
else
y = image.rgb2yuv(y)[1]:reshape(1, y:size(2), y:size(3))
x = image.rgb2yuv(x)[1]:reshape(1, x:size(2), x:size(3))
end
y = image.crop(y, INTERPOLATION_PADDING + offset, INTERPOLATION_PADDING + offset, y:size(3) - offset - INTERPOLATION_PADDING, y:size(2) - offset - INTERPOLATION_PADDING)
x = image.crop(x, INTERPOLATION_PADDING, INTERPOLATION_PADDING, x:size(3) - INTERPOLATION_PADDING, x:size(2) - INTERPOLATION_PADDING)
return x, y
return y
end
function pairwise_transform.jpeg_(src, quality, size, offset, options)
options = options or {color_noise = false, overlay = false, random_half = true, rgb = true}
local INTERPOLATION_PADDING = 16
function pairwise_transform.scale(src, scale, size, offset, n, options)
local filters = {
"Box","Box","Box", -- 0.012756949974688
"Blackman", -- 0.013191924552285
--"Cartom", -- 0.013753536746706
--"Hanning", -- 0.013761314529647
--"Hermite", -- 0.013850225205266
"SincFast", -- 0.014095824314306
"Jinc", -- 0.014244299255442
}
if options.random_half then
src = random_half(src)
end
local yi = torch.random(0, src:size(2) - size - 1)
local xi = torch.random(0, src:size(3) - size - 1)
local downscale_filter = filters[torch.random(1, #filters)]
local y = data_augment(crop_if_large(src, math.max(size * 4, 512)), options)
local down_scale = 1.0 / scale
local x = iproc.scale(iproc.scale(y, y:size(3) * down_scale,
y:size(2) * down_scale, downscale_filter),
y:size(3), y:size(2))
local batch = {}
for i = 1, n do
local xc, yc = active_cropping(x, y,
size,
INTERPOLATION_PADDING,
options.active_cropping_rate,
options.active_cropping_tries)
if options.rgb then
else
yc = image.rgb2yuv(yc)[1]:reshape(1, yc:size(2), yc:size(3))
xc = image.rgb2yuv(xc)[1]:reshape(1, xc:size(2), xc:size(3))
end
table.insert(batch, {xc, image.crop(yc, offset, offset, size - offset, size - offset)})
end
return batch
end
function pairwise_transform.jpeg_(src, quality, size, offset, n, options)
if options.random_half then
src = random_half(src)
end
src = crop_if_large(src, math.max(size * 4, 512))
local y = src
local x
@ -150,63 +193,64 @@ function pairwise_transform.jpeg_(src, quality, size, offset, options)
for i = 1, #quality do
x = gm.Image(x, "RGB", "DHW")
x:format("jpeg")
x:samplingFactors({1.0, 1.0, 1.0})
if options.jpeg_sampling_factors == 444 then
x:samplingFactors({1.0, 1.0, 1.0})
else -- 422
x:samplingFactors({2.0, 1.0, 1.0})
end
local blob, len = x:toBlob(quality[i])
x:fromBlob(blob, len)
x = x:toTensor("byte", "RGB", "DHW")
end
y = image.crop(y, xi, yi, xi + size, yi + size)
x = image.crop(x, xi, yi, xi + size, yi + size)
y = y:float():div(255)
x = x:float():div(255)
x, y = flip_augment(x, y)
if options.rgb then
else
y = image.rgb2yuv(y)[1]:reshape(1, y:size(2), y:size(3))
x = image.rgb2yuv(x)[1]:reshape(1, x:size(2), x:size(3))
local batch = {}
for i = 1, n do
local xc, yc = active_cropping(x, y, size, 0,
options.active_cropping_rate,
options.active_cropping_tries)
xc, yc = flip_augment(xc, yc)
if options.rgb then
else
yc = image.rgb2yuv(yc)[1]:reshape(1, yc:size(2), yc:size(3))
xc = image.rgb2yuv(xc)[1]:reshape(1, xc:size(2), xc:size(3))
end
table.insert(batch, {xc, image.crop(yc, offset, offset, size - offset, size - offset)})
end
return x, image.crop(y, offset, offset, size - offset, size - offset)
return batch
end
function pairwise_transform.jpeg(src, category, level, size, offset, options)
function pairwise_transform.jpeg(src, category, level, size, offset, n, options)
if category == "anime_style_art" then
if level == 1 then
if torch.uniform() > 0.7 then
if torch.uniform() > 0.8 then
return pairwise_transform.jpeg_(src, {},
size, offset,
options)
size, offset, n, options)
else
return pairwise_transform.jpeg_(src, {torch.random(65, 85)},
size, offset,
options)
size, offset, n, options)
end
elseif level == 2 then
if torch.uniform() > 0.7 then
local r = torch.uniform()
if torch.uniform() > 0.8 then
return pairwise_transform.jpeg_(src, {},
size, offset,
options)
size, offset, n, options)
else
local r = torch.uniform()
if r > 0.6 then
return pairwise_transform.jpeg_(src, {torch.random(27, 70)},
size, offset,
options)
size, offset, n, options)
elseif r > 0.3 then
local quality1 = torch.random(37, 70)
local quality2 = quality1 - torch.random(5, 10)
return pairwise_transform.jpeg_(src, {quality1, quality2},
size, offset,
options)
size, offset, n, options)
else
local quality1 = torch.random(52, 70)
return pairwise_transform.jpeg_(src,
{quality1,
quality1 - torch.random(5, 15),
quality1 - torch.random(15, 25)},
size, offset,
options)
local quality2 = quality1 - torch.random(5, 15)
local quality3 = quality1 - torch.random(15, 25)
return pairwise_transform.jpeg_(src,
{quality1, quality2, quality3},
size, offset, n, options)
end
end
else
@ -216,23 +260,25 @@ function pairwise_transform.jpeg(src, category, level, size, offset, options)
if level == 1 then
if torch.uniform() > 0.7 then
return pairwise_transform.jpeg_(src, {},
size, offset,
size, offset, n,
options)
else
return pairwise_transform.jpeg_(src, {torch.random(80, 95)},
size, offset,
size, offset, n,
options)
end
elseif level == 2 then
if torch.uniform() > 0.7 then
return pairwise_transform.jpeg_(src, {},
size, offset,
size, offset, n,
options)
else
return pairwise_transform.jpeg_(src, {torch.random(65, 85)},
size, offset,
size, offset, n,
options)
end
else
error("unknown noise level: " .. level)
end
else
error("unknown category: " .. category)
@ -242,6 +288,7 @@ function pairwise_transform.jpeg_scale_(src, scale, quality, size, offset, optio
if options.random_half then
src = random_half(src)
end
src = crop_if_large(src, math.max(size * 4, 512))
local down_scale = 1.0 / scale
local filters = {
"Box", -- 0.012756949974688
@ -270,7 +317,11 @@ function pairwise_transform.jpeg_scale_(src, scale, quality, size, offset, optio
for i = 1, #quality do
x = gm.Image(x, "RGB", "DHW")
x:format("jpeg")
x:samplingFactors({1.0, 1.0, 1.0})
if options.jpeg_sampling_factors == 444 then
x:samplingFactors({1.0, 1.0, 1.0})
else -- 422
x:samplingFactors({2.0, 1.0, 1.0})
end
local blob, len = x:toBlob(quality[i])
x:fromBlob(blob, len)
x = x:toTensor("byte", "RGB", "DHW")
@ -321,10 +372,11 @@ function pairwise_transform.jpeg_scale(src, scale, category, level, size, offset
size, offset, options)
else
local quality1 = torch.random(52, 70)
local quality2 = quality1 - torch.random(5, 15)
local quality3 = quality1 - torch.random(15, 25)
return pairwise_transform.jpeg_scale_(src, scale,
{quality1,
quality1 - torch.random(5, 15),
quality1 - torch.random(15, 25)},
{quality1, quality2, quality3 },
size, offset, options)
end
end
@ -354,14 +406,13 @@ end
local function test_jpeg()
local loader = require './image_loader'
local src = loader.load_byte("../images/miku_CC_BY-NC.jpg")
local y, x = pairwise_transform.jpeg_(src, {}, 128, 0, {})
image.display({image = y, legend = "y:0"})
image.display({image = x, legend = "x:0"})
for i = 2, 9 do
local y, x = pairwise_transform.jpeg_(random_half(src),
{i * 10}, 128, 0, {color_noise = false, random_half = true, overlay = true, rgb = true})
image.display({image = y, legend = "y:" .. (i * 10), max=1,min=0})
image.display({image = x, legend = "x:" .. (i * 10),max=1,min=0})
local xy = pairwise_transform.jpeg_(random_half(src),
{i * 10}, 128, 0, 2, {color_noise = false, random_half = true, overlay = true, rgb = true})
for i = 1, #xy do
image.display({image = xy[i][1], legend = "y:" .. (i * 10), max=1,min=0})
image.display({image = xy[i][2], legend = "x:" .. (i * 10),max=1,min=0})
end
--print(x:mean(), y:mean())
end
end
@ -370,27 +421,40 @@ local function test_scale()
torch.setdefaulttensortype('torch.FloatTensor')
local loader = require './image_loader'
local src = loader.load_byte("../images/miku_CC_BY-NC.jpg")
local options = {color_noise = true,
random_half = true,
overlay = false,
active_cropping_rate = 1.5,
active_cropping_tries = 10,
rgb = true
}
for i = 1, 9 do
local y, x = pairwise_transform.scale(src, 2.0, 128, 7, {color_noise = true, random_half = true, rgb = true, overlay = true})
image.display({image = y, legend = "y:" .. (i * 10), min = 0, max = 1})
image.display({image = x, legend = "x:" .. (i * 10), min = 0, max = 1})
print(y:size(), x:size())
local xy = pairwise_transform.scale(src, 2.0, 128, 7, 1, options)
image.display({image = xy[1][1], legend = "y:" .. (i * 10), min = 0, max = 1})
image.display({image = xy[1][2], legend = "x:" .. (i * 10), min = 0, max = 1})
print(xy[1][1]:size(), xy[1][2]:size())
--print(x:mean(), y:mean())
end
end
local function test_jpeg_scale()
torch.setdefaulttensortype('torch.FloatTensor')
local loader = require './image_loader'
local src = loader.load_byte("../images/miku_CC_BY-NC.jpg")
local options = {color_noise = true,
random_half = true,
overlay = true,
active_cropping_rate = 0.5,
active_cropping_tries = 10
}
for i = 1, 9 do
local y, x = pairwise_transform.jpeg_scale(src, 2.0, 1, 128, 7, {color_noise = true, random_half = true, overlay = true})
local y, x = pairwise_transform.jpeg_scale(src, 2.0, 1, 128, 7, options)
image.display({image = y, legend = "y1:" .. (i * 10), min = 0, max = 1})
image.display({image = x, legend = "x1:" .. (i * 10), min = 0, max = 1})
print(y:size(), x:size())
--print(x:mean(), y:mean())
end
for i = 1, 9 do
local y, x = pairwise_transform.jpeg_scale(src, 2.0, 2, 128, 7, {color_noise = true, random_half = true, overlay = true})
local y, x = pairwise_transform.jpeg_scale(src, 2.0, 2, 128, 7, options)
image.display({image = y, legend = "y2:" .. (i * 10), min = 0, max = 1})
image.display({image = x, legend = "x2:" .. (i * 10), min = 0, max = 1})
print(y:size(), x:size())


@ -1,9 +1,13 @@
local function load_cuda()
require 'nn'
require 'cunn'
end
local function load_cudnn()
require 'cudnn'
--cudnn.fastest = true
end
if pcall(load_cuda) then
require 'cunn'
else
--[[ TODO: fakecuda does not work.
@ -13,3 +17,5 @@ else
require('fakecuda').init(true)
--]]
end
pcall(load_cudnn) -- optional: ignore failure if cudnn is unavailable


@ -48,7 +48,8 @@ local function reconstruct_rgb(model, x, offset, block_size)
end
return new_x
end
function model_is_rgb(model)
local reconstruct = {}
function reconstruct.is_rgb(model)
if model:get(model:size() - 1).weight:size(1) == 3 then
-- 3ch RGB
return true
@ -57,8 +58,23 @@ function model_is_rgb(model)
return false
end
end
local reconstruct = {}
function reconstruct.offset_size(model)
local conv = model:findModules("nn.SpatialConvolutionMM")
if #conv > 0 then
local offset = 0
for i = 1, #conv do
offset = offset + (conv[i].kW - 1) / 2
end
return math.floor(offset)
else
conv = model:findModules("cudnn.SpatialConvolution")
local offset = 0
for i = 1, #conv do
offset = offset + (conv[i].kW - 1) / 2
end
return math.floor(offset)
end
end
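-- e.g. the waifu2x model is seven unpadded 3x3 convolutions, each trimming
-- (3 - 1) / 2 = 1 pixel per border, so offset_size returns 7, the value
-- previously hard-coded as BLOCK_OFFSET in waifu2x.lua and web.lua.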
function reconstruct.image_y(model, x, offset, block_size)
block_size = block_size or 128
local output_size = block_size - offset * 2
@ -172,18 +188,22 @@ function reconstruct.scale_rgb(model, scale, x, offset, block_size)
return output
end
function reconstruct.image(model, x, offset, block_size)
if model_is_rgb(model) then
return reconstruct.image_rgb(model, x, offset, block_size)
function reconstruct.image(model, x, block_size)
if reconstruct.is_rgb(model) then
return reconstruct.image_rgb(model, x,
reconstruct.offset_size(model), block_size)
else
return reconstruct.image_y(model, x, offset, block_size)
return reconstruct.image_y(model, x,
reconstruct.offset_size(model), block_size)
end
end
function reconstruct.scale(model, scale, x, offset, block_size)
if model_is_rgb(model) then
return reconstruct.scale_rgb(model, scale, x, offset, block_size)
function reconstruct.scale(model, scale, x, block_size)
if reconstruct.is_rgb(model) then
return reconstruct.scale_rgb(model, scale, x,
reconstruct.offset_size(model), block_size)
else
return reconstruct.scale_y(model, scale, x, offset, block_size)
return reconstruct.scale_y(model, scale, x,
reconstruct.offset_size(model), block_size)
end
end


@ -1,5 +1,6 @@
require 'xlua'
require 'pl'
require 'trepl'
-- global settings
@ -18,6 +19,7 @@ cmd:text("waifu2x")
cmd:text("Options:")
cmd:option("-seed", 11, 'fixed input seed')
cmd:option("-data_dir", "./data", 'data directory')
-- cmd:option("-backend", "cunn", '(cunn|cudnn)') -- cudnn is slow than cunn
cmd:option("-test", "images/miku_small.png", 'test image file')
cmd:option("-model_dir", "./models", 'model directory')
cmd:option("-method", "scale", '(noise|scale|noise_scale)')
@ -30,9 +32,15 @@ cmd:option("-scale", 2.0, 'scale')
cmd:option("-learning_rate", 0.00025, 'learning rate for adam')
cmd:option("-random_half", 1, 'enable data augmentation using half resolution image (0|1)')
cmd:option("-crop_size", 128, 'crop size')
cmd:option("-max_size", -1, 'crop if image size larger then this value.')
cmd:option("-batch_size", 2, 'mini batch size')
cmd:option("-epoch", 200, 'epoch')
cmd:option("-core", 2, 'cpu core')
cmd:option("-jpeg_sampling_factors", 444, '(444|422)')
cmd:option("-validation_ratio", 0.1, 'validation ratio')
cmd:option("-validation_crops", 40, 'number of crop region in validation')
cmd:option("-active_cropping_rate", 0.5, 'active cropping rate')
cmd:option("-active_cropping_tries", 20, 'active cropping tries')
local opt = cmd:parse(arg)
for k, v in pairs(opt) do
@ -81,16 +89,6 @@ torch.setnumthreads(settings.core)
settings.images = string.format("%s/images.t7", settings.data_dir)
settings.image_list = string.format("%s/image_list.txt", settings.data_dir)
settings.validation_ratio = 0.1
settings.validation_crops = 30
local srcnn = require './srcnn'
if (settings.method == "scale" or settings.method == "noise_scale") and settings.scale == 4 then
settings.create_model = srcnn.waifu4x
settings.block_offset = 13
else
settings.create_model = srcnn.waifu2x
settings.block_offset = 7
end
settings.backend = "cunn"
return settings


@ -1,77 +1,66 @@
require './LeakyReLU'
require './mynn'
-- ref: http://arxiv.org/abs/1502.01852
function nn.SpatialConvolutionMM:reset(stdv)
stdv = math.sqrt(2 / ((1.0 + 0.1 * 0.1) * self.kW * self.kH * self.nOutputPlane))
self.weight:normal(0, stdv)
self.bias:zero()
end
-- ref: http://arxiv.org/abs/1501.00092
local srcnn = {}
function srcnn.waifu2x(color)
function srcnn.waifu2x_cunn(ch)
local model = nn.Sequential()
local ch = nil
if color == "rgb" then
ch = 3
elseif color == "y" then
ch = 1
else
if color then
error("unknown color: " .. color)
else
error("unknown color: nil")
end
end
-- very deep model
model:add(nn.SpatialConvolutionMM(ch, 32, 3, 3, 1, 1, 0, 0))
model:add(nn.LeakyReLU(0.1))
model:add(mynn.LeakyReLU(0.1))
model:add(nn.SpatialConvolutionMM(32, 32, 3, 3, 1, 1, 0, 0))
model:add(nn.LeakyReLU(0.1))
model:add(mynn.LeakyReLU(0.1))
model:add(nn.SpatialConvolutionMM(32, 64, 3, 3, 1, 1, 0, 0))
model:add(nn.LeakyReLU(0.1))
model:add(mynn.LeakyReLU(0.1))
model:add(nn.SpatialConvolutionMM(64, 64, 3, 3, 1, 1, 0, 0))
model:add(nn.LeakyReLU(0.1))
model:add(mynn.LeakyReLU(0.1))
model:add(nn.SpatialConvolutionMM(64, 128, 3, 3, 1, 1, 0, 0))
model:add(nn.LeakyReLU(0.1))
model:add(mynn.LeakyReLU(0.1))
model:add(nn.SpatialConvolutionMM(128, 128, 3, 3, 1, 1, 0, 0))
model:add(nn.LeakyReLU(0.1))
model:add(mynn.LeakyReLU(0.1))
model:add(nn.SpatialConvolutionMM(128, ch, 3, 3, 1, 1, 0, 0))
model:add(nn.View(-1):setNumInputDims(3))
--model:cuda()
--print(model:forward(torch.Tensor(32, 1, 92, 92):uniform():cuda()):size())
--model:cuda()
--print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
return model, 7
return model
end
-- current 4x is worse than 2x * 2
function srcnn.waifu4x(color)
function srcnn.waifu2x_cudnn(ch)
local model = nn.Sequential()
local ch = nil
model:add(cudnn.SpatialConvolution(ch, 32, 3, 3, 1, 1, 0, 0))
model:add(mynn.LeakyReLU(0.1))
model:add(cudnn.SpatialConvolution(32, 32, 3, 3, 1, 1, 0, 0))
model:add(mynn.LeakyReLU(0.1))
model:add(cudnn.SpatialConvolution(32, 64, 3, 3, 1, 1, 0, 0))
model:add(mynn.LeakyReLU(0.1))
model:add(cudnn.SpatialConvolution(64, 64, 3, 3, 1, 1, 0, 0))
model:add(mynn.LeakyReLU(0.1))
model:add(cudnn.SpatialConvolution(64, 128, 3, 3, 1, 1, 0, 0))
model:add(mynn.LeakyReLU(0.1))
model:add(cudnn.SpatialConvolution(128, 128, 3, 3, 1, 1, 0, 0))
model:add(mynn.LeakyReLU(0.1))
model:add(cudnn.SpatialConvolution(128, ch, 3, 3, 1, 1, 0, 0))
model:add(nn.View(-1):setNumInputDims(3))
--model:cuda()
--print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
return model
end
function srcnn.create(model_name, backend, color)
local ch = 3
if color == "rgb" then
ch = 3
elseif color == "y" then
ch = 1
else
error("unknown color: " .. color)
error("unsupported color: " + color)
end
if backend == "cunn" then
return srcnn.waifu2x_cunn(ch)
elseif backend == "cudnn" then
return srcnn.waifu2x_cudnn(ch)
else
error("unsupported backend: " + backend)
end
model:add(nn.SpatialConvolutionMM(ch, 32, 9, 9, 1, 1, 0, 0))
model:add(nn.LeakyReLU(0.1))
model:add(nn.SpatialConvolutionMM(32, 32, 3, 3, 1, 1, 0, 0))
model:add(nn.LeakyReLU(0.1))
model:add(nn.SpatialConvolutionMM(32, 64, 5, 5, 1, 1, 0, 0))
model:add(nn.LeakyReLU(0.1))
model:add(nn.SpatialConvolutionMM(64, 64, 3, 3, 1, 1, 0, 0))
model:add(nn.LeakyReLU(0.1))
model:add(nn.SpatialConvolutionMM(64, 128, 5, 5, 1, 1, 0, 0))
model:add(nn.LeakyReLU(0.1))
model:add(nn.SpatialConvolutionMM(128, 128, 3, 3, 1, 1, 0, 0))
model:add(nn.LeakyReLU(0.1))
model:add(nn.SpatialConvolutionMM(128, ch, 5, 5, 1, 1, 0, 0))
model:add(nn.View(-1):setNumInputDims(3))
return model, 13
end
return srcnn
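train.lua now builds its network through this factory instead of settings.create_model; a minimal sketch (srcnn.lua itself requires ./mynn, which pulls in cunn; argument values are illustrative):

local srcnn = require './lib/srcnn'
-- method name, backend ("cunn" or "cudnn"), color ("rgb" or "y")
local model = srcnn.create("noise", "cunn", "rgb")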

159
train.lua

@ -1,9 +1,12 @@
require './lib/portable'
require './lib/mynn'
require 'optim'
require 'xlua'
require 'pl'
require 'snappy'
local settings = require './lib/settings'
local srcnn = require './lib/srcnn'
local minibatch_adam = require './lib/minibatch_adam'
local iproc = require './lib/iproc'
local reconstruct = require './lib/reconstruct'
@ -11,11 +14,11 @@ local pairwise_transform = require './lib/pairwise_transform'
local image_loader = require './lib/image_loader'
local function save_test_scale(model, rgb, file)
local up = reconstruct.scale(model, settings.scale, rgb, settings.block_offset)
local up = reconstruct.scale(model, settings.scale, rgb)
image.save(file, up)
end
local function save_test_jpeg(model, rgb, file)
local im, count = reconstruct.image(model, rgb, settings.block_offset)
local im, count = reconstruct.image(model, rgb)
image.save(file, im)
end
local function split_data(x, test_size)
@ -35,10 +38,14 @@ local function make_validation_set(x, transformer, n)
n = n or 4
local data = {}
for i = 1, #x do
for k = 1, n do
local x, y = transformer(x[i], true)
table.insert(data, {x = x:reshape(1, x:size(1), x:size(2), x:size(3)),
y = y:reshape(1, y:size(1), y:size(2), y:size(3))})
for k = 1, math.max(n / 8, 1) do
local xy = transformer(x[i], true, 8)
for j = 1, #xy do
local x = xy[j][1]
local y = xy[j][2]
table.insert(data, {x = x:reshape(1, x:size(1), x:size(2), x:size(3)),
y = y:reshape(1, y:size(1), y:size(2), y:size(3))})
end
end
xlua.progress(i, #x)
collectgarbage()
@ -58,15 +65,96 @@ local function validate(model, criterion, data)
return loss / #data
end
local function create_criterion(model)
if reconstruct.is_rgb(model) then
local offset = reconstruct.offset_size(model)
local output_w = settings.crop_size - offset * 2
local weight = torch.Tensor(3, output_w * output_w)
weight[1]:fill(0.299 * 3) -- R
weight[2]:fill(0.587 * 3) -- G
weight[3]:fill(0.114 * 3) -- B
return mynn.RGBWeightedMSECriterion(weight):cuda()
else
return nn.MSECriterion():cuda()
end
end
local function transformer(x, is_validation, n, offset)
local size = x[1]
local dec = snappy.decompress(x[2]:string())
x = torch.ByteTensor(size[1], size[2], size[3])
x:storage():string(dec)
n = n or settings.batch_size;
if is_validation == nil then is_validation = false end
local color_noise = nil
local overlay = nil
local active_cropping_rate = nil
local active_cropping_tries = nil
if is_validation then
active_cropping_rate = 0.0
active_cropping_tries = 0
color_noise = false
overlay = false
else
active_cropping_rate = settings.active_cropping_rate
active_cropping_tries = settings.active_cropping_tries
color_noise = settings.color_noise
overlay = settings.overlay
end
if settings.method == "scale" then
return pairwise_transform.scale(x,
settings.scale,
settings.crop_size, offset,
n,
{ color_noise = color_noise,
overlay = overlay,
random_half = settings.random_half,
active_cropping_rate = active_cropping_rate,
active_cropping_tries = active_cropping_tries,
rgb = (settings.color == "rgb")
})
elseif settings.method == "noise" then
return pairwise_transform.jpeg(x,
settings.category,
settings.noise_level,
settings.crop_size, offset,
n,
{ color_noise = color_noise,
overlay = overlay,
active_cropping_rate = active_cropping_rate,
active_cropping_tries = active_cropping_tries,
random_half = settings.random_half,
jpeg_sampling_factors = settings.jpeg_sampling_factors,
rgb = (settings.color == "rgb")
})
elseif settings.method == "noise_scale" then
return pairwise_transform.jpeg_scale(x,
settings.scale,
settings.category,
settings.noise_level,
settings.crop_size, offset,
n,
{ color_noise = color_noise,
overlay = overlay,
jpeg_sampling_factors = settings.jpeg_sampling_factors,
random_half = settings.random_half,
rgb = (settings.color == "rgb")
})
end
end
local function train()
local model, offset = settings.create_model(settings.color)
assert(offset == settings.block_offset)
local criterion = nn.MSECriterion():cuda()
local model = srcnn.create(settings.method, settings.backend, settings.color)
local offset = reconstruct.offset_size(model)
local pairwise_func = function(x, is_validation, n)
return transformer(x, is_validation, n, offset)
end
local criterion = create_criterion(model)
local x = torch.load(settings.images)
local lrd_count = 0
local train_x, valid_x = split_data(x,
math.floor(settings.validation_ratio * #x))
local test = image_loader.load_float(settings.test)
local train_x, valid_x = split_data(x, math.floor(settings.validation_ratio * #x))
local adam_config = {
learningRate = settings.learning_rate,
xBatchSize = settings.batch_size,
@ -77,45 +165,9 @@ local function train()
elseif settings.color == "rgb" then
ch = 3
end
local transformer = function(x, is_validation)
if is_validation == nil then is_validation = false end
local color_noise = (not is_validation) and settings.color_noise
local overlay = (not is_validation) and settings.overlay
if settings.method == "scale" then
return pairwise_transform.scale(x,
settings.scale,
settings.crop_size, offset,
{ color_noise = color_noise,
overlay = overlay,
random_half = settings.random_half,
rgb = (settings.color == "rgb")
})
elseif settings.method == "noise" then
return pairwise_transform.jpeg(x,
settings.category,
settings.noise_level,
settings.crop_size, offset,
{ color_noise = color_noise,
overlay = overlay,
random_half = settings.random_half,
rgb = (settings.color == "rgb")
})
elseif settings.method == "noise_scale" then
return pairwise_transform.jpeg_scale(x,
settings.scale,
settings.category,
settings.noise_level,
settings.crop_size, offset,
{ color_noise = color_noise,
overlay = overlay,
random_half = settings.random_half,
rgb = (settings.color == "rgb")
})
end
end
local best_score = 100000.0
print("# make validation-set")
local valid_xy = make_validation_set(valid_x, transformer, settings.validation_crops)
local valid_xy = make_validation_set(valid_x, pairwise_func, settings.validation_crops)
valid_x = nil
collectgarbage()
@ -125,7 +177,7 @@ local function train()
model:training()
print("# " .. epoch)
print(minibatch_adam(model, criterion, train_x, adam_config,
transformer,
pairwise_func,
{ch, settings.crop_size, settings.crop_size},
{ch, settings.crop_size - offset * 2, settings.crop_size - offset * 2}
))
@ -133,6 +185,7 @@ local function train()
print("# validation")
local score = validate(model, criterion, valid_xy)
if score < best_score then
local test_image = image_loader.load_float(settings.test) -- reload
lrd_count = 0
best_score = score
print("* update best model")
@ -140,16 +193,16 @@ local function train()
if settings.method == "noise" then
local log = path.join(settings.model_dir,
("noise%d_best.png"):format(settings.noise_level))
save_test_jpeg(model, test, log)
save_test_jpeg(model, test_image, log)
elseif settings.method == "scale" then
local log = path.join(settings.model_dir,
("scale%.1f_best.png"):format(settings.scale))
save_test_scale(model, test, log)
save_test_scale(model, test_image, log)
elseif settings.method == "noise_scale" then
local log = path.join(settings.model_dir,
("noise%d_scale%.1f_best.png"):format(settings.noise_level,
settings.scale))
save_test_scale(model, test, log)
save_test_scale(model, test_image, log)
end
else
lrd_count = lrd_count + 1


@ -1,10 +1,13 @@
#!/bin/sh
th train.lua -color rgb -method noise -noise_level 1 -model_dir models/anime_style_art_rgb -test images/miku_noisy.png
th convert_data.lua
th train.lua -color rgb -random_half 1 -jpeg_sampling_factors 444 -color_noise 0 -overlay 0 -epoch 200 -method noise -noise_level 1 -crop_size 46 -batch_size 8 -model_dir models/anime_style_art_rgb -test images/miku_noisy.jpg -validation_ratio 0.1 -active_cropping_rate 0.5 -active_cropping_tries 10 -validation_crops 80
th cleanup_model.lua -model models/anime_style_art_rgb/noise1_model.t7 -oformat ascii
th train.lua -color rgb -method noise -noise_level 2 -model_dir models/anime_style_art_rgb -test images/miku_noisy.png
th train.lua -color rgb -random_half 1 -jpeg_sampling_factors 444 -color_noise 0 -overlay 0 -epoch 200 -method noise -noise_level 2 -crop_size 46 -batch_size 8 -model_dir models/anime_style_art_rgb -test images/miku_noisy.jpg -validation_ratio 0.1 -active_cropping_rate 0.5 -active_cropping_tries 10 -validation_crops 80
th cleanup_model.lua -model models/anime_style_art_rgb/noise2_model.t7 -oformat ascii
th train.lua -color rgb -method scale -scale 2 -model_dir models/anime_style_art_rgb -test images/miku_small.png
th train.lua -color rgb -random_half 1 -jpeg_sampling_factors 444 -color_noise 0 -overlay 0 -epoch 200 -method scale -crop_size 46 -batch_size 8 -model_dir models/anime_style_art_rgb -test images/miku_small_noisy.jpg -active_cropping_rate 0.5 -active_cropping_tries 10 -validation_ratio 0.1 -validation_crops 80
th cleanup_model.lua -model models/anime_style_art_rgb/scale2.0x_model.t7 -oformat ascii

11
train_ukbench.sh Executable file

@ -0,0 +1,11 @@
#!/bin/sh
th train.lua -category photo -color rgb -color_noise 0 -overlay 0 -random_half 0 -epoch 300 -batch_size 1 -method noise -noise_level 1 -data_dir ukbench -model_dir models/ukbench2 -test photo2.jpg
th cleanup_model.lua -model models/ukbench2/noise1_model.t7 -oformat ascii
th train.lua -core 1 -category photo -color rgb -color_noise 0 -overlay 0 -random_half 0 -epoch 300 -batch_size 1 -method noise -noise_level 2 -data_dir ukbench -model_dir models/ukbench2 -test photo2.jpg
th cleanup_model.lua -model models/ukbench2/noise2_model.t7 -oformat ascii
th train.lua -category photo -color rgb -random_half 0 -epoch 400 -batch_size 1 -method scale -scale 2 -model_dir models/ukbench2 -data_dir ukbench -test photo2-noise.png
th cleanup_model.lua -model models/ukbench2/scale2.0x_model.t7 -oformat ascii


@ -1,12 +1,11 @@
require './lib/portable'
require 'sys'
require 'pl'
require './lib/LeakyReLU'
require './lib/mynn'
local iproc = require './lib/iproc'
local reconstruct = require './lib/reconstruct'
local image_loader = require './lib/image_loader'
local BLOCK_OFFSET = 7
torch.setdefaulttensortype('torch.FloatTensor')
@ -22,19 +21,21 @@ local function convert_image(opt)
end
if opt.m == "noise" then
local model = torch.load(path.join(opt.model_dir, ("noise%d_model.t7"):format(opt.noise_level)), "ascii")
--local srcnn = require 'lib/srcnn'
--local model = srcnn.waifu2x("rgb"):cuda()
model:evaluate()
new_x = reconstruct.image(model, x, BLOCK_OFFSET, opt.crop_size)
new_x = reconstruct.image(model, x, opt.crop_size)
elseif opt.m == "scale" then
local model = torch.load(path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale)), "ascii")
model:evaluate()
new_x = reconstruct.scale(model, opt.scale, x, BLOCK_OFFSET, opt.crop_size)
new_x = reconstruct.scale(model, opt.scale, x, opt.crop_size)
elseif opt.m == "noise_scale" then
local noise_model = torch.load(path.join(opt.model_dir, ("noise%d_model.t7"):format(opt.noise_level)), "ascii")
local scale_model = torch.load(path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale)), "ascii")
noise_model:evaluate()
scale_model:evaluate()
x = reconstruct.image(noise_model, x, BLOCK_OFFSET)
new_x = reconstruct.scale(scale_model, opt.scale, x, BLOCK_OFFSET, opt.crop_size)
x = reconstruct.image(noise_model, x)
new_x = reconstruct.scale(scale_model, opt.scale, x, opt.crop_size)
else
error("undefined method:" .. opt.method)
end
@ -62,17 +63,17 @@ local function convert_frames(opt)
local x, alpha = image_loader.load_float(lines[i])
local new_x = nil
if opt.m == "noise" and opt.noise_level == 1 then
new_x = reconstruct.image(noise1_model, x, BLOCK_OFFSET, opt.crop_size)
new_x = reconstruct.image(noise1_model, x, opt.crop_size)
elseif opt.m == "noise" and opt.noise_level == 2 then
new_x = reconstruct.image(noise2_model, x, BLOCK_OFFSET)
new_x = reconstruct.image(noise2_model, x)
elseif opt.m == "scale" then
new_x = reconstruct.scale(scale_model, opt.scale, x, BLOCK_OFFSET, opt.crop_size)
new_x = reconstruct.scale(scale_model, opt.scale, x, opt.crop_size)
elseif opt.m == "noise_scale" and opt.noise_level == 1 then
x = reconstruct.image(noise1_model, x, BLOCK_OFFSET)
new_x = reconstruct.scale(scale_model, opt.scale, x, BLOCK_OFFSET, opt.crop_size)
x = reconstruct.image(noise1_model, x)
new_x = reconstruct.scale(scale_model, opt.scale, x, opt.crop_size)
elseif opt.m == "noise_scale" and opt.noise_level == 2 then
x = reconstruct.image(noise2_model, x, BLOCK_OFFSET)
new_x = reconstruct.scale(scale_model, opt.scale, x, BLOCK_OFFSET, opt.crop_size)
x = reconstruct.image(noise2_model, x)
new_x = reconstruct.scale(scale_model, opt.scale, x, opt.crop_size)
else
error("undefined method:" .. opt.method)
end

13
web.lua

@ -4,8 +4,8 @@ local uuid = require 'uuid'
local ffi = require 'ffi'
local md5 = require 'md5'
require 'pl'
require './lib/portable'
require './lib/LeakyReLU'
require 'lib.portable'
require 'lib.mynn'
local cmd = torch.CmdLine()
cmd:text()
@ -23,7 +23,7 @@ local iproc = require './lib/iproc'
local reconstruct = require './lib/reconstruct'
local image_loader = require './lib/image_loader'
local MODEL_DIR = "./models/anime_style_art_rgb"
local MODEL_DIR = "./models/anime_style_art_rgb3"
local noise1_model = torch.load(path.join(MODEL_DIR, "noise1_model.t7"), "ascii")
local noise2_model = torch.load(path.join(MODEL_DIR, "noise2_model.t7"), "ascii")
@ -40,7 +40,6 @@ local CURL_OPTIONS = {
max_redirects = 2
}
local CURL_MAX_SIZE = 2 * 1024 * 1024
local BLOCK_OFFSET = 7 -- see srcnn.lua
local function valid_size(x, scale)
if scale == 0 then
@ -80,13 +79,13 @@ local function get_image(req)
end
local function apply_denoise1(x)
return reconstruct.image(noise1_model, x, BLOCK_OFFSET)
return reconstruct.image(noise1_model, x)
end
local function apply_denoise2(x)
return reconstruct.image(noise2_model, x, BLOCK_OFFSET)
return reconstruct.image(noise2_model, x)
end
local function apply_scale2x(x)
return reconstruct.scale(scale20_model, 2.0, x, BLOCK_OFFSET)
return reconstruct.scale(scale20_model, 2.0, x)
end
local function cache_do(cache, x, func)
if path.exists(cache) then