sync from internal repo
- Memory compression by snappy (lua-csnappy) - Use RGB-wise Weighted MSE(R*0.299, G*0.587, B*0.114) instead of MSE - Aggressive cropping for edge region and some change.
4
.gitignore
vendored
|
@ -1,4 +1,8 @@
|
|||
*~
|
||||
/*.png
|
||||
/*.mp4
|
||||
/*.jpg
|
||||
cache/*.png
|
||||
models/*.png
|
||||
models/*/*.png
|
||||
waifu2x.log
|
||||
|
|
280
benchmark.lua
Normal file
|
@ -0,0 +1,280 @@
|
|||
require './lib/portable'
|
||||
require './lib/mynn'
|
||||
require 'xlua'
|
||||
require 'pl'
|
||||
|
||||
local iproc = require './lib/iproc'
|
||||
local reconstruct = require './lib/reconstruct'
|
||||
local image_loader = require './lib/image_loader'
|
||||
local gm = require 'graphicsmagick'
|
||||
|
||||
local cmd = torch.CmdLine()
|
||||
cmd:text()
|
||||
cmd:text("waifu2x-benchmark")
|
||||
cmd:text("Options:")
|
||||
|
||||
cmd:option("-seed", 11, 'fixed input seed')
|
||||
cmd:option("-test_dir", "./test", 'test image directory')
|
||||
cmd:option("-jpeg_quality", 50, 'jpeg quality')
|
||||
cmd:option("-jpeg_times", 3, 'number of jpeg compression ')
|
||||
cmd:option("-jpeg_quality_down", 5, 'reducing jpeg quality each times')
|
||||
cmd:option("-core", 4, 'threads')
|
||||
|
||||
local opt = cmd:parse(arg)
|
||||
torch.setnumthreads(opt.core)
|
||||
torch.setdefaulttensortype('torch.FloatTensor')
|
||||
|
||||
local function MSE(x1, x2)
|
||||
return (x1 - x2):pow(2):mean()
|
||||
end
|
||||
local function YMSE(x1, x2)
|
||||
local x1_2 = x1:clone()
|
||||
local x2_2 = x2:clone()
|
||||
|
||||
x1_2[1]:mul(0.299 * 3)
|
||||
x1_2[2]:mul(0.587 * 3)
|
||||
x1_2[3]:mul(0.114 * 3)
|
||||
|
||||
x2_2[1]:mul(0.299 * 3)
|
||||
x2_2[2]:mul(0.587 * 3)
|
||||
x2_2[3]:mul(0.114 * 3)
|
||||
|
||||
return (x1_2 - x2_2):pow(2):mean()
|
||||
end
|
||||
local function PSNR(x1, x2)
|
||||
local mse = MSE(x1, x2)
|
||||
return 20 * (math.log(1.0 / math.sqrt(mse)) / math.log(10))
|
||||
end
|
||||
local function YPSNR(x1, x2)
|
||||
local mse = YMSE(x1, x2)
|
||||
return 20 * (math.log((0.587 * 3) / math.sqrt(mse)) / math.log(10))
|
||||
end
|
||||
|
||||
local function transform_jpeg(x)
|
||||
for i = 1, opt.jpeg_times do
|
||||
jpeg = gm.Image(x, "RGB", "DHW")
|
||||
jpeg:format("jpeg")
|
||||
jpeg:samplingFactors({1.0, 1.0, 1.0})
|
||||
blob, len = jpeg:toBlob(opt.jpeg_quality - (i - 1) * opt.jpeg_quality_down)
|
||||
jpeg:fromBlob(blob, len)
|
||||
x = jpeg:toTensor("byte", "RGB", "DHW")
|
||||
end
|
||||
return x
|
||||
end
|
||||
|
||||
local function noise_benchmark(x, v1_noise, v2_noise)
|
||||
local v1_mse = 0
|
||||
local v2_mse = 0
|
||||
local jpeg_mse = 0
|
||||
local v1_psnr = 0
|
||||
local v2_psnr = 0
|
||||
local jpeg_psnr = 0
|
||||
local v1_time = 0
|
||||
local v2_time = 0
|
||||
|
||||
for i = 1, #x do
|
||||
local ground_truth = x[i]
|
||||
local jpg, blob, len, input, v1_out, v2_out, t, mse
|
||||
|
||||
input = transform_jpeg(ground_truth)
|
||||
input = input:float():div(255)
|
||||
ground_truth = ground_truth:float():div(255)
|
||||
|
||||
jpeg_mse = jpeg_mse + MSE(ground_truth, input)
|
||||
jpeg_psnr = jpeg_psnr + PSNR(ground_truth, input)
|
||||
|
||||
t = sys.clock()
|
||||
v1_output = reconstruct.image(v1_noise, input)
|
||||
v1_time = v1_time + (sys.clock() - t)
|
||||
v1_mse = v1_mse + MSE(ground_truth, v1_output)
|
||||
v1_psnr = v1_psnr + PSNR(ground_truth, v1_output)
|
||||
|
||||
t = sys.clock()
|
||||
v2_output = reconstruct.image(v2_noise, input)
|
||||
v2_time = v2_time + (sys.clock() - t)
|
||||
v2_mse = v2_mse + MSE(ground_truth, v2_output)
|
||||
v2_psnr = v2_psnr + PSNR(ground_truth, v2_output)
|
||||
|
||||
io.stdout:write(
|
||||
string.format("%d/%d; v1_time=%f, v2_time=%f, jpeg_mse=%f, v1_mse=%f, v2_mse=%f, jpeg_psnr=%f, v1_psnr=%f, v2_psnr=%f \r",
|
||||
i, #x,
|
||||
v1_time / i, v2_time / i,
|
||||
jpeg_mse / i,
|
||||
v1_mse / i, v2_mse / i,
|
||||
jpeg_psnr / i,
|
||||
v1_psnr / i, v2_psnr / i
|
||||
)
|
||||
)
|
||||
io.stdout:flush()
|
||||
end
|
||||
io.stdout:write("\n")
|
||||
end
|
||||
local function noise_scale_benchmark(x, params, v1_noise, v1_scale, v2_noise, v2_scale)
|
||||
local v1_mse = 0
|
||||
local v2_mse = 0
|
||||
local jinc_mse = 0
|
||||
local v1_time = 0
|
||||
local v2_time = 0
|
||||
|
||||
for i = 1, #x do
|
||||
local ground_truth = x[i]
|
||||
local downscale = iproc.scale(ground_truth,
|
||||
ground_truth:size(3) * 0.5,
|
||||
ground_truth:size(2) * 0.5,
|
||||
params[i].filter)
|
||||
local jpg, blob, len, input, v1_output, v2_output, jinc_output, t, mse
|
||||
|
||||
jpeg = gm.Image(downscale, "RGB", "DHW")
|
||||
jpeg:format("jpeg")
|
||||
blob, len = jpeg:toBlob(params[i].quality)
|
||||
jpeg:fromBlob(blob, len)
|
||||
input = jpeg:toTensor("byte", "RGB", "DHW")
|
||||
|
||||
input = input:float():div(255)
|
||||
ground_truth = ground_truth:float():div(255)
|
||||
|
||||
jinc_output = iproc.scale(input, input:size(3) * 2, input:size(2) * 2, "Jinc")
|
||||
jinc_mse = jinc_mse + (ground_truth - jinc_output):pow(2):mean()
|
||||
|
||||
t = sys.clock()
|
||||
v1_output = reconstruct.image(v1_noise, input)
|
||||
v1_output = reconstruct.scale(v1_scale, 2.0, v1_output)
|
||||
v1_time = v1_time + (sys.clock() - t)
|
||||
mse = (ground_truth - v1_output):pow(2):mean()
|
||||
v1_mse = v1_mse + mse
|
||||
|
||||
t = sys.clock()
|
||||
v2_output = reconstruct.image(v2_noise, input)
|
||||
v2_output = reconstruct.scale(v2_scale, 2.0, v2_output)
|
||||
v2_time = v2_time + (sys.clock() - t)
|
||||
mse = (ground_truth - v2_output):pow(2):mean()
|
||||
v2_mse = v2_mse + mse
|
||||
|
||||
io.stdout:write(string.format("%d/%d; time: v1=%f, v2=%f, v1/v2=%f; mse: jinc=%f, v1=%f(%f), v2=%f(%f), v1/v2=%f \r",
|
||||
i, #x,
|
||||
v1_time / i, v2_time / i,
|
||||
(v1_time / i) / (v2_time / i),
|
||||
jinc_mse / i,
|
||||
v1_mse / i, (v1_mse/i) / (jinc_mse/i),
|
||||
v2_mse / i, (v2_mse/i) / (jinc_mse/i),
|
||||
(v1_mse / i) / (v2_mse / i)))
|
||||
|
||||
io.stdout:flush()
|
||||
end
|
||||
io.stdout:write("\n")
|
||||
end
|
||||
local function scale_benchmark(x, params, v1_scale, v2_scale)
|
||||
local v1_mse = 0
|
||||
local v2_mse = 0
|
||||
local jinc_mse = 0
|
||||
local v1_psnr = 0
|
||||
local v2_psnr = 0
|
||||
local jinc_psnr = 0
|
||||
|
||||
local v1_time = 0
|
||||
local v2_time = 0
|
||||
|
||||
for i = 1, #x do
|
||||
local ground_truth = x[i]
|
||||
local downscale = iproc.scale(ground_truth,
|
||||
ground_truth:size(3) * 0.5,
|
||||
ground_truth:size(2) * 0.5,
|
||||
params[i].filter)
|
||||
local jpg, blob, len, input, v1_output, v2_output, jinc_output, t, mse
|
||||
input = downscale
|
||||
|
||||
input = input:float():div(255)
|
||||
ground_truth = ground_truth:float():div(255)
|
||||
|
||||
jinc_output = iproc.scale(input, input:size(3) * 2, input:size(2) * 2, "Jinc")
|
||||
mse = (ground_truth - jinc_output):pow(2):mean()
|
||||
jinc_mse = jinc_mse + mse
|
||||
jinc_psnr = jinc_psnr + (10 * (math.log(1.0 / mse) / math.log(10)))
|
||||
|
||||
t = sys.clock()
|
||||
v1_output = reconstruct.scale(v1_scale, 2.0, input)
|
||||
v1_time = v1_time + (sys.clock() - t)
|
||||
mse = (ground_truth - v1_output):pow(2):mean()
|
||||
v1_mse = v1_mse + mse
|
||||
v1_psnr = v1_psnr + (10 * (math.log(1.0 / mse) / math.log(10)))
|
||||
|
||||
t = sys.clock()
|
||||
v2_output = reconstruct.scale(v2_scale, 2.0, input)
|
||||
v2_time = v2_time + (sys.clock() - t)
|
||||
mse = (ground_truth - v2_output):pow(2):mean()
|
||||
v2_mse = v2_mse + mse
|
||||
v2_psnr = v2_psnr + (10 * (math.log(1.0 / mse) / math.log(10)))
|
||||
|
||||
io.stdout:write(string.format("%d/%d; time: v1=%f, v2=%f, v1/v2=%f; mse: jinc=%f, v1=%f(%f), v2=%f(%f), v1/v2=%f \r",
|
||||
i, #x,
|
||||
v1_time / i, v2_time / i,
|
||||
(v1_time / i) / (v2_time / i),
|
||||
jinc_psnr / i,
|
||||
v1_psnr / i, (v1_psnr/i) / (jinc_psnr/i),
|
||||
v2_psnr / i, (v2_psnr/i) / (jinc_psnr/i),
|
||||
(v1_psnr / i) / (v2_psnr / i)))
|
||||
|
||||
io.stdout:flush()
|
||||
end
|
||||
io.stdout:write("\n")
|
||||
end
|
||||
|
||||
local function split_data(x, test_size)
|
||||
local index = torch.randperm(#x)
|
||||
local train_size = #x - test_size
|
||||
local train_x = {}
|
||||
local valid_x = {}
|
||||
for i = 1, train_size do
|
||||
train_x[i] = x[index[i]]
|
||||
end
|
||||
for i = 1, test_size do
|
||||
valid_x[i] = x[index[train_size + i]]
|
||||
end
|
||||
return train_x, valid_x
|
||||
end
|
||||
local function crop_4x(x)
|
||||
local w = x:size(3) % 4
|
||||
local h = x:size(2) % 4
|
||||
return image.crop(x, 0, 0, x:size(3) - w, x:size(2) - h)
|
||||
end
|
||||
local function load_data(valid_dir)
|
||||
local valid_x = {}
|
||||
local files = dir.getfiles(valid_dir, "*.png")
|
||||
for i = 1, #files do
|
||||
table.insert(valid_x, crop_4x(image_loader.load_byte(files[i])))
|
||||
xlua.progress(i, #files)
|
||||
end
|
||||
return valid_x
|
||||
end
|
||||
|
||||
local function noise_main(valid_dir, level)
|
||||
local v1_noise = torch.load(path.join(V1_DIR, string.format("noise%d_model.t7", level)), "ascii")
|
||||
local v2_noise = torch.load(path.join(V2_DIR, string.format("noise%d_model.t7", level)), "ascii")
|
||||
local valid_x = load_data(valid_dir)
|
||||
noise_benchmark(valid_x, v1_noise, v2_noise)
|
||||
end
|
||||
local function scale_main(valid_dir)
|
||||
local v1 = torch.load(path.join(V1_DIR, "scale2.0x_model.t7"), "ascii")
|
||||
local v2 = torch.load(path.join(V2_DIR, "scale2.0x_model.t7"), "ascii")
|
||||
local valid_x = load_data(valid_dir)
|
||||
local params = random_params(valid_x, 2)
|
||||
scale_benchmark(valid_x, params, v1, v2)
|
||||
end
|
||||
local function noise_scale_main(valid_dir)
|
||||
local v1_noise = torch.load(path.join(V1_DIR, "noise2_model.t7"), "ascii")
|
||||
local v1_scale = torch.load(path.join(V1_DIR, "scale2.0x_model.t7"), "ascii")
|
||||
local v2_noise = torch.load(path.join(V2_DIR, "noise2_model.t7"), "ascii")
|
||||
local v2_scale = torch.load(path.join(V2_DIR, "scale2.0x_model.t7"), "ascii")
|
||||
local valid_x = load_data(valid_dir)
|
||||
local params = random_params(valid_x, 2)
|
||||
noise_scale_benchmark(valid_x, params, v1_noise, v1_scale, v2_noise, v2_scale)
|
||||
end
|
||||
|
||||
V1_DIR = "models/anime_style_art_rgb"
|
||||
V2_DIR = "models/anime_style_art_rgb5"
|
||||
|
||||
torch.manualSeed(opt.seed)
|
||||
cutorch.manualSeed(opt.seed)
|
||||
noise_main("./test", 2)
|
||||
--scale_main("./test")
|
||||
--noise_scale_main("./test")
|
|
@ -1,5 +1,5 @@
|
|||
require './lib/portable'
|
||||
require './lib/LeakyReLU'
|
||||
require './lib/mynn'
|
||||
|
||||
torch.setdefaulttensortype("torch.FloatTensor")
|
||||
|
||||
|
|
|
@ -1,8 +1,12 @@
|
|||
local ffi = require 'ffi'
|
||||
require './lib/portable'
|
||||
require 'image'
|
||||
require 'snappy'
|
||||
local settings = require './lib/settings'
|
||||
local image_loader = require './lib/image_loader'
|
||||
|
||||
local MAX_SIZE = 1440
|
||||
|
||||
local function count_lines(file)
|
||||
local fp = io.open(file, "r")
|
||||
local count = 0
|
||||
|
@ -13,7 +17,17 @@ local function count_lines(file)
|
|||
|
||||
return count
|
||||
end
|
||||
|
||||
local function crop_if_large(src, max_size)
|
||||
if max_size > 0 and (src:size(2) > max_size or src:size(3) > max_size) then
|
||||
local sx = torch.random(0, src:size(3) - math.min(max_size, src:size(3)))
|
||||
local sy = torch.random(0, src:size(2) - math.min(max_size, src:size(2)))
|
||||
return image.crop(src, sx, sy,
|
||||
math.min(sx + max_size, src:size(3)),
|
||||
math.min(sy + max_size, src:size(2)))
|
||||
else
|
||||
return src
|
||||
end
|
||||
end
|
||||
local function crop_4x(x)
|
||||
local w = x:size(3) % 4
|
||||
local h = x:size(2) % 4
|
||||
|
@ -27,13 +41,26 @@ local function load_images(list)
|
|||
local x = {}
|
||||
local c = 0
|
||||
for line in fp:lines() do
|
||||
local im = crop_4x(image_loader.load_byte(line))
|
||||
if im then
|
||||
if im:size(2) > (settings.crop_size * 2 + MARGIN) and im:size(3) > (settings.crop_size * 2 + MARGIN) then
|
||||
table.insert(x, im)
|
||||
end
|
||||
local im, alpha = image_loader.load_byte(line)
|
||||
im = crop_if_large(im, settings.max_size)
|
||||
im = crop_4x(im)
|
||||
|
||||
if alpha then
|
||||
io.stderr:write(string.format("%s: skip: reason: alpha channel.", line))
|
||||
else
|
||||
print("error:" .. line)
|
||||
local scale = 1.0
|
||||
if settings.random_half then
|
||||
scale = 2.0
|
||||
end
|
||||
if im then
|
||||
if im:size(2) > (settings.crop_size * scale + MARGIN) and im:size(3) > (settings.crop_size * scale + MARGIN) then
|
||||
table.insert(x, {im:size(), torch.ByteStorage():string(snappy.compress(im:storage():string()))})
|
||||
else
|
||||
io.stderr:write(string.format("%s: skip: reason: too small (%d > size).\n", line, settings.crop_size * scale + MARGIN))
|
||||
end
|
||||
else
|
||||
io.stderr:write(string.format("%s: skip: reason: load error.\n", line))
|
||||
end
|
||||
end
|
||||
c = c + 1
|
||||
xlua.progress(c, count)
|
||||
|
@ -43,7 +70,7 @@ local function load_images(list)
|
|||
end
|
||||
return x
|
||||
end
|
||||
torch.manualSeed(settings.seed)
|
||||
print(settings)
|
||||
local x = load_images(settings.image_list)
|
||||
torch.save(settings.images, x)
|
||||
|
||||
|
|
|
@ -1,8 +1,7 @@
|
|||
#!/bin/sh
|
||||
|
||||
th waifu2x.lua -noise_level 1 -m noise_scale -i images/miku_small.png -o images/miku_small_waifu2x.png
|
||||
th waifu2x.lua -m scale -i images/miku_small.png -o images/miku_small_waifu2x.png
|
||||
th waifu2x.lua -noise_level 2 -m noise_scale -i images/miku_small_noisy.png -o images/miku_small_noisy_waifu2x.png
|
||||
th waifu2x.lua -noise_level 2 -m noise -i images/miku_noisy.png -o images/miku_noisy_waifu2x.png
|
||||
th waifu2x.lua -noise_level 2 -m noise_scale -i images/miku_CC_BY-NC_noisy.jpg -o images/miku_CC_BY-NC_noisy_waifu2x.png
|
||||
th waifu2x.lua -noise_level 2 -m noise -i images/lena.png -o images/lena_waifu2x.png
|
||||
th waifu2x.lua -m scale -model_dir models/ukbench -i images/lena.png -o images/lena_waifu2x_ukbench.png
|
||||
|
|
Before Width: | Height: | Size: 315 KiB After Width: | Height: | Size: 355 KiB |
Before Width: | Height: | Size: 605 KiB After Width: | Height: | Size: 651 KiB |
Before Width: | Height: | Size: 154 KiB After Width: | Height: | Size: 156 KiB |
Before Width: | Height: | Size: 53 KiB After Width: | Height: | Size: 45 KiB |
Before Width: | Height: | Size: 177 KiB After Width: | Height: | Size: 156 KiB |
Before Width: | Height: | Size: 138 KiB After Width: | Height: | Size: 140 KiB |
Before Width: | Height: | Size: 136 KiB After Width: | Height: | Size: 140 KiB |
BIN
images/slide.odp
BIN
images/slide.png
Before Width: | Height: | Size: 1.2 MiB After Width: | Height: | Size: 1.2 MiB |
Before Width: | Height: | Size: 499 KiB After Width: | Height: | Size: 498 KiB |
Before Width: | Height: | Size: 380 KiB After Width: | Height: | Size: 377 KiB |
Before Width: | Height: | Size: 378 KiB After Width: | Height: | Size: 356 KiB |
75
lib/DepthExpand2x.lua
Normal file
|
@ -0,0 +1,75 @@
|
|||
if mynn.DepthExpand2x then
|
||||
return mynn.DepthExpand2x
|
||||
end
|
||||
local DepthExpand2x, parent = torch.class('mynn.DepthExpand2x','nn.Module')
|
||||
|
||||
function DepthExpand2x:__init()
|
||||
parent:__init()
|
||||
end
|
||||
|
||||
function DepthExpand2x:updateOutput(input)
|
||||
local x = input
|
||||
-- (batch_size, depth, height, width)
|
||||
self.shape = x:size()
|
||||
|
||||
assert(self.shape:size() == 4, "input must be 4d tensor")
|
||||
assert(self.shape[2] % 4 == 0, "depth must be depth % 4 = 0")
|
||||
-- (batch_size, width, height, depth)
|
||||
x = x:transpose(2, 4)
|
||||
-- (batch_size, width, height * 2, depth / 2)
|
||||
x = x:reshape(self.shape[1], self.shape[4], self.shape[3] * 2, self.shape[2] / 2)
|
||||
-- (batch_size, height * 2, width, depth / 2)
|
||||
x = x:transpose(2, 3)
|
||||
-- (batch_size, height * 2, width * 2, depth / 4)
|
||||
x = x:reshape(self.shape[1], self.shape[3] * 2, self.shape[4] * 2, self.shape[2] / 4)
|
||||
-- (batch_size, depth / 4, height * 2, width * 2)
|
||||
x = x:transpose(2, 4)
|
||||
x = x:transpose(3, 4)
|
||||
self.output:resizeAs(x):copy(x) -- contiguous
|
||||
|
||||
return self.output
|
||||
end
|
||||
|
||||
function DepthExpand2x:updateGradInput(input, gradOutput)
|
||||
-- (batch_size, depth / 4, height * 2, width * 2)
|
||||
local x = gradOutput
|
||||
-- (batch_size, height * 2, width * 2, depth / 4)
|
||||
x = x:transpose(2, 4)
|
||||
x = x:transpose(2, 3)
|
||||
-- (batch_size, height * 2, width, depth / 2)
|
||||
x = x:reshape(self.shape[1], self.shape[3] * 2, self.shape[4], self.shape[2] / 2)
|
||||
-- (batch_size, width, height * 2, depth / 2)
|
||||
x = x:transpose(2, 3)
|
||||
-- (batch_size, width, height, depth)
|
||||
x = x:reshape(self.shape[1], self.shape[4], self.shape[3], self.shape[2])
|
||||
-- (batch_size, depth, height, width)
|
||||
x = x:transpose(2, 4)
|
||||
|
||||
self.gradInput:resizeAs(x):copy(x)
|
||||
|
||||
return self.gradInput
|
||||
end
|
||||
|
||||
function DepthExpand2x.test()
|
||||
require 'image'
|
||||
local function show(x)
|
||||
local img = torch.Tensor(3, x:size(3), x:size(4))
|
||||
img[1]:copy(x[1][1])
|
||||
img[2]:copy(x[1][2])
|
||||
img[3]:copy(x[1][3])
|
||||
image.display(img)
|
||||
end
|
||||
local img = image.lena()
|
||||
local x = torch.Tensor(1, img:size(1) * 4, img:size(2), img:size(3))
|
||||
for i = 0, img:size(1) * 4 - 1 do
|
||||
src_index = ((i % 3) + 1)
|
||||
x[1][i + 1]:copy(img[src_index])
|
||||
end
|
||||
show(x)
|
||||
|
||||
local de2x = mynn.DepthExpand2x()
|
||||
out = de2x:forward(x)
|
||||
show(out)
|
||||
out = de2x:updateGradInput(x, out)
|
||||
show(out)
|
||||
end
|
|
@ -1,7 +1,8 @@
|
|||
if nn.LeakyReLU then
|
||||
return
|
||||
if mynn.LeakyReLU then
|
||||
return mynn.LeakyReLU
|
||||
end
|
||||
local LeakyReLU, parent = torch.class('nn.LeakyReLU','nn.Module')
|
||||
|
||||
local LeakyReLU, parent = torch.class('mynn.LeakyReLU','nn.Module')
|
||||
|
||||
function LeakyReLU:__init(negative_scale)
|
||||
parent.__init(self)
|
||||
|
|
31
lib/LeakyReLU_deprecated.lua
Normal file
|
@ -0,0 +1,31 @@
|
|||
if nn.LeakyReLU then
|
||||
return nn.LeakyReLU
|
||||
end
|
||||
|
||||
local LeakyReLU, parent = torch.class('nn.LeakyReLU','nn.Module')
|
||||
|
||||
function LeakyReLU:__init(negative_scale)
|
||||
parent.__init(self)
|
||||
self.negative_scale = negative_scale or 0.333
|
||||
self.negative = torch.Tensor()
|
||||
end
|
||||
|
||||
function LeakyReLU:updateOutput(input)
|
||||
self.output:resizeAs(input):copy(input):abs():add(input):div(2)
|
||||
self.negative:resizeAs(input):copy(input):abs():add(-1.0, input):mul(-0.5*self.negative_scale)
|
||||
self.output:add(self.negative)
|
||||
|
||||
return self.output
|
||||
end
|
||||
|
||||
function LeakyReLU:updateGradInput(input, gradOutput)
|
||||
self.gradInput:resizeAs(gradOutput)
|
||||
-- filter positive
|
||||
self.negative:sign():add(1)
|
||||
torch.cmul(self.gradInput, gradOutput, self.negative)
|
||||
-- filter negative
|
||||
self.negative:add(-1):mul(-1 * self.negative_scale):cmul(gradOutput)
|
||||
self.gradInput:add(self.negative)
|
||||
|
||||
return self.gradInput
|
||||
end
|
25
lib/RGBWeightedMSECriterion.lua
Normal file
|
@ -0,0 +1,25 @@
|
|||
local RGBWeightedMSECriterion, parent = torch.class('mynn.RGBWeightedMSECriterion','nn.Criterion')
|
||||
|
||||
function RGBWeightedMSECriterion:__init(w)
|
||||
parent.__init(self)
|
||||
self.weight = w:clone()
|
||||
self.diff = torch.Tensor()
|
||||
self.loss = torch.Tensor()
|
||||
end
|
||||
|
||||
function RGBWeightedMSECriterion:updateOutput(input, target)
|
||||
self.diff:resizeAs(input):copy(input)
|
||||
for i = 1, input:size(1) do
|
||||
self.diff[i]:add(-1, target[i]):cmul(self.weight)
|
||||
end
|
||||
self.loss:resizeAs(self.diff):copy(self.diff):cmul(self.diff)
|
||||
self.output = self.loss:mean()
|
||||
|
||||
return self.output
|
||||
end
|
||||
|
||||
function RGBWeightedMSECriterion:updateGradInput(input, target)
|
||||
self.gradInput:resizeAs(input):copy(self.diff)
|
||||
return self.gradInput
|
||||
end
|
||||
|
|
@ -13,7 +13,7 @@ function image_loader.decode_float(blob)
|
|||
end
|
||||
function image_loader.encode_png(rgb, alpha)
|
||||
if rgb:type() == "torch.ByteTensor" then
|
||||
error("expect FloatTensor")
|
||||
rgb = rgb:float():div(255)
|
||||
end
|
||||
if alpha then
|
||||
if not (alpha:size(2) == rgb:size(2) and alpha:size(3) == rgb:size(3)) then
|
||||
|
@ -26,11 +26,11 @@ function image_loader.encode_png(rgb, alpha)
|
|||
rgba[4]:copy(alpha)
|
||||
local im = gm.Image():fromTensor(rgba, "RGBA", "DHW")
|
||||
im:format("png")
|
||||
return im:toBlob()
|
||||
return im:toBlob(9)
|
||||
else
|
||||
local im = gm.Image(rgb, "RGB", "DHW")
|
||||
im:format("png")
|
||||
return im:toBlob()
|
||||
return im:toBlob(9)
|
||||
end
|
||||
end
|
||||
function image_loader.save_png(filename, rgb, alpha)
|
||||
|
@ -64,6 +64,7 @@ function image_loader.decode_byte(blob)
|
|||
end
|
||||
return {im, alpha}
|
||||
end
|
||||
load_image()
|
||||
local state, ret = pcall(load_image)
|
||||
if state then
|
||||
return ret[1], ret[2]
|
||||
|
|
|
@ -22,15 +22,12 @@ local function minibatch_adam(model, criterion,
|
|||
local targets_tmp = torch.Tensor(batch_size,
|
||||
target_size[1] * target_size[2] * target_size[3])
|
||||
|
||||
for t = 1, #train_x, batch_size do
|
||||
if t + batch_size > #train_x then
|
||||
break
|
||||
end
|
||||
for t = 1, #train_x do
|
||||
xlua.progress(t, #train_x)
|
||||
for i = 1, batch_size do
|
||||
local x, y = transformer(train_x[shuffle[t + i - 1]])
|
||||
inputs_tmp[i]:copy(x)
|
||||
targets_tmp[i]:copy(y)
|
||||
local xy = transformer(train_x[shuffle[t]], false, batch_size)
|
||||
for i = 1, #xy do
|
||||
inputs_tmp[i]:copy(xy[i][1])
|
||||
targets_tmp[i]:copy(xy[i][2])
|
||||
end
|
||||
inputs:copy(inputs_tmp)
|
||||
targets:copy(targets_tmp)
|
||||
|
|
20
lib/mynn.lua
Normal file
|
@ -0,0 +1,20 @@
|
|||
local function load_cunn()
|
||||
require 'nn'
|
||||
require 'cunn'
|
||||
end
|
||||
local function load_cudnn()
|
||||
require 'cudnn'
|
||||
cudnn.fastest = true
|
||||
end
|
||||
if mynn then
|
||||
return mynn
|
||||
else
|
||||
load_cunn()
|
||||
--load_cudnn()
|
||||
mynn = {}
|
||||
require './LeakyReLU'
|
||||
require './LeakyReLU_deprecated'
|
||||
require './DepthExpand2x'
|
||||
require './RGBWeightedMSECriterion'
|
||||
return mynn
|
||||
end
|
|
@ -4,10 +4,11 @@ local iproc = require './iproc'
|
|||
local reconstruct = require './reconstruct'
|
||||
local pairwise_transform = {}
|
||||
|
||||
local function random_half(src, p, min_size)
|
||||
p = p or 0.5
|
||||
local filter = ({"Box","Blackman", "SincFast", "Jinc"})[torch.random(1, 4)]
|
||||
if p > torch.uniform() then
|
||||
local function random_half(src, p)
|
||||
p = p or 0.25
|
||||
--local filter = ({"Box","Blackman", "SincFast", "Jinc"})[torch.random(1, 4)]
|
||||
local filter = "Box"
|
||||
if p < torch.uniform() and (src:size(2) > 768 and src:size(3) > 1024) then
|
||||
return iproc.scale(src, src:size(3) * 0.5, src:size(2) * 0.5, filter)
|
||||
else
|
||||
return src
|
||||
|
@ -21,6 +22,48 @@ local function pcacov(x)
|
|||
local ce, cv = torch.symeig(c, 'V')
|
||||
return ce, cv
|
||||
end
|
||||
local function crop_if_large(src, max_size)
|
||||
if src:size(2) > max_size and src:size(3) > max_size then
|
||||
local yi = torch.random(0, src:size(2) - max_size)
|
||||
local xi = torch.random(0, src:size(3) - max_size)
|
||||
return image.crop(src, xi, yi, xi + max_size, yi + max_size)
|
||||
else
|
||||
return src
|
||||
end
|
||||
end
|
||||
local function active_cropping(x, y, size, offset, p, tries)
|
||||
assert("x:size == y:size", x:size(2) == y:size(2) and x:size(3) == y:size(3))
|
||||
local r = torch.uniform()
|
||||
if p < r then
|
||||
local xi = torch.random(offset, y:size(3) - (size + offset + 1))
|
||||
local yi = torch.random(offset, y:size(2) - (size + offset + 1))
|
||||
local xc = image.crop(x, xi, yi, xi + size, yi + size)
|
||||
local yc = image.crop(y, xi, yi, xi + size, yi + size)
|
||||
yc = yc:float():div(255)
|
||||
xc = xc:float():div(255)
|
||||
return xc, yc
|
||||
else
|
||||
local samples = {}
|
||||
local sum_mse = 0
|
||||
for i = 1, tries do
|
||||
local xi = torch.random(offset, y:size(3) - (size + offset + 1))
|
||||
local yi = torch.random(offset, y:size(2) - (size + offset + 1))
|
||||
local xc = image.crop(x, xi, yi, xi + size, yi + size):float():div(255)
|
||||
local yc = image.crop(y, xi, yi, xi + size, yi + size):float():div(255)
|
||||
local mse = (xc - yc):pow(2):mean()
|
||||
sum_mse = sum_mse + mse
|
||||
table.insert(samples, {xc = xc, yc = yc, mse = mse})
|
||||
end
|
||||
if sum_mse > 0 then
|
||||
table.sort(samples,
|
||||
function (a, b)
|
||||
return a.mse > b.mse
|
||||
end)
|
||||
end
|
||||
return samples[1].xc, samples[1].yc
|
||||
end
|
||||
end
|
||||
|
||||
local function color_noise(src)
|
||||
local p = 0.1
|
||||
src = src:float():div(255)
|
||||
|
@ -84,29 +127,7 @@ local function overlay_augment(src, p)
|
|||
return src
|
||||
end
|
||||
end
|
||||
local INTERPOLATION_PADDING = 16
|
||||
function pairwise_transform.scale(src, scale, size, offset, options)
|
||||
options = options or {color_noise = false, overlay = false, random_half = true, rgb = true}
|
||||
if options.random_half then
|
||||
src = random_half(src)
|
||||
end
|
||||
local yi = torch.random(INTERPOLATION_PADDING, src:size(2) - size - INTERPOLATION_PADDING)
|
||||
local xi = torch.random(INTERPOLATION_PADDING, src:size(3) - size - INTERPOLATION_PADDING)
|
||||
local down_scale = 1.0 / scale
|
||||
local y = image.crop(src,
|
||||
xi - INTERPOLATION_PADDING, yi - INTERPOLATION_PADDING,
|
||||
xi + size + INTERPOLATION_PADDING, yi + size + INTERPOLATION_PADDING)
|
||||
local filters = {
|
||||
"Box", -- 0.012756949974688
|
||||
"Blackman", -- 0.013191924552285
|
||||
--"Cartom", -- 0.013753536746706
|
||||
--"Hanning", -- 0.013761314529647
|
||||
--"Hermite", -- 0.013850225205266
|
||||
"SincFast", -- 0.014095824314306
|
||||
"Jinc", -- 0.014244299255442
|
||||
}
|
||||
local downscale_filter = filters[torch.random(1, #filters)]
|
||||
|
||||
local function data_augment(y, options)
|
||||
y = flip_augment(y)
|
||||
if options.color_noise then
|
||||
y = color_noise(y)
|
||||
|
@ -114,29 +135,51 @@ function pairwise_transform.scale(src, scale, size, offset, options)
|
|||
if options.overlay then
|
||||
y = overlay_augment(y)
|
||||
end
|
||||
local x = iproc.scale(y, y:size(3) * down_scale, y:size(2) * down_scale, downscale_filter)
|
||||
x = iproc.scale(x, y:size(3), y:size(2))
|
||||
y = y:float():div(255)
|
||||
x = x:float():div(255)
|
||||
|
||||
if options.rgb then
|
||||
else
|
||||
y = image.rgb2yuv(y)[1]:reshape(1, y:size(2), y:size(3))
|
||||
x = image.rgb2yuv(x)[1]:reshape(1, x:size(2), x:size(3))
|
||||
end
|
||||
|
||||
y = image.crop(y, INTERPOLATION_PADDING + offset, INTERPOLATION_PADDING + offset, y:size(3) - offset - INTERPOLATION_PADDING, y:size(2) - offset - INTERPOLATION_PADDING)
|
||||
x = image.crop(x, INTERPOLATION_PADDING, INTERPOLATION_PADDING, x:size(3) - INTERPOLATION_PADDING, x:size(2) - INTERPOLATION_PADDING)
|
||||
|
||||
return x, y
|
||||
return y
|
||||
end
|
||||
function pairwise_transform.jpeg_(src, quality, size, offset, options)
|
||||
options = options or {color_noise = false, overlay = false, random_half = true, rgb = true}
|
||||
|
||||
|
||||
local INTERPOLATION_PADDING = 16
|
||||
function pairwise_transform.scale(src, scale, size, offset, n, options)
|
||||
local filters = {
|
||||
"Box","Box","Box", -- 0.012756949974688
|
||||
"Blackman", -- 0.013191924552285
|
||||
--"Cartom", -- 0.013753536746706
|
||||
--"Hanning", -- 0.013761314529647
|
||||
--"Hermite", -- 0.013850225205266
|
||||
"SincFast", -- 0.014095824314306
|
||||
"Jinc", -- 0.014244299255442
|
||||
}
|
||||
if options.random_half then
|
||||
src = random_half(src)
|
||||
end
|
||||
local yi = torch.random(0, src:size(2) - size - 1)
|
||||
local xi = torch.random(0, src:size(3) - size - 1)
|
||||
local downscale_filter = filters[torch.random(1, #filters)]
|
||||
local y = data_augment(crop_if_large(src, math.max(size * 4, 512)), options)
|
||||
local down_scale = 1.0 / scale
|
||||
local x = iproc.scale(iproc.scale(y, y:size(3) * down_scale,
|
||||
y:size(2) * down_scale, downscale_filter),
|
||||
y:size(3), y:size(2))
|
||||
local batch = {}
|
||||
for i = 1, n do
|
||||
local xc, yc = active_cropping(x, y,
|
||||
size,
|
||||
INTERPOLATION_PADDING,
|
||||
options.active_cropping_rate,
|
||||
options.active_cropping_tries)
|
||||
if options.rgb then
|
||||
else
|
||||
yc = image.rgb2yuv(yc)[1]:reshape(1, yc:size(2), yc:size(3))
|
||||
xc = image.rgb2yuv(xc)[1]:reshape(1, xc:size(2), xc:size(3))
|
||||
end
|
||||
table.insert(batch, {xc, image.crop(yc, offset, offset, size - offset, size - offset)})
|
||||
end
|
||||
return batch
|
||||
end
|
||||
function pairwise_transform.jpeg_(src, quality, size, offset, n, options)
|
||||
if options.random_half then
|
||||
src = random_half(src)
|
||||
end
|
||||
src = crop_if_large(src, math.max(size * 4, 512))
|
||||
local y = src
|
||||
local x
|
||||
|
||||
|
@ -150,63 +193,64 @@ function pairwise_transform.jpeg_(src, quality, size, offset, options)
|
|||
for i = 1, #quality do
|
||||
x = gm.Image(x, "RGB", "DHW")
|
||||
x:format("jpeg")
|
||||
x:samplingFactors({1.0, 1.0, 1.0})
|
||||
if options.jpeg_sampling_factors == 444 then
|
||||
x:samplingFactors({1.0, 1.0, 1.0})
|
||||
else -- 422
|
||||
x:samplingFactors({2.0, 1.0, 1.0})
|
||||
end
|
||||
local blob, len = x:toBlob(quality[i])
|
||||
x:fromBlob(blob, len)
|
||||
x = x:toTensor("byte", "RGB", "DHW")
|
||||
end
|
||||
|
||||
y = image.crop(y, xi, yi, xi + size, yi + size)
|
||||
x = image.crop(x, xi, yi, xi + size, yi + size)
|
||||
y = y:float():div(255)
|
||||
x = x:float():div(255)
|
||||
x, y = flip_augment(x, y)
|
||||
|
||||
if options.rgb then
|
||||
else
|
||||
y = image.rgb2yuv(y)[1]:reshape(1, y:size(2), y:size(3))
|
||||
x = image.rgb2yuv(x)[1]:reshape(1, x:size(2), x:size(3))
|
||||
local batch = {}
|
||||
for i = 1, n do
|
||||
local xc, yc = active_cropping(x, y, size, 0,
|
||||
options.active_cropping_rate,
|
||||
options.active_cropping_tries)
|
||||
xc, yc = flip_augment(xc, yc)
|
||||
|
||||
if options.rgb then
|
||||
else
|
||||
yc = image.rgb2yuv(yc)[1]:reshape(1, yc:size(2), yc:size(3))
|
||||
xc = image.rgb2yuv(xc)[1]:reshape(1, xc:size(2), xc:size(3))
|
||||
end
|
||||
table.insert(batch, {xc, image.crop(yc, offset, offset, size - offset, size - offset)})
|
||||
end
|
||||
|
||||
return x, image.crop(y, offset, offset, size - offset, size - offset)
|
||||
return batch
|
||||
end
|
||||
function pairwise_transform.jpeg(src, category, level, size, offset, options)
|
||||
function pairwise_transform.jpeg(src, category, level, size, offset, n, options)
|
||||
if category == "anime_style_art" then
|
||||
if level == 1 then
|
||||
if torch.uniform() > 0.7 then
|
||||
if torch.uniform() > 0.8 then
|
||||
return pairwise_transform.jpeg_(src, {},
|
||||
size, offset,
|
||||
options)
|
||||
size, offset, n, options)
|
||||
else
|
||||
return pairwise_transform.jpeg_(src, {torch.random(65, 85)},
|
||||
size, offset,
|
||||
options)
|
||||
size, offset, n, options)
|
||||
end
|
||||
elseif level == 2 then
|
||||
if torch.uniform() > 0.7 then
|
||||
local r = torch.uniform()
|
||||
if torch.uniform() > 0.8 then
|
||||
return pairwise_transform.jpeg_(src, {},
|
||||
size, offset,
|
||||
options)
|
||||
size, offset, n, options)
|
||||
else
|
||||
local r = torch.uniform()
|
||||
if r > 0.6 then
|
||||
return pairwise_transform.jpeg_(src, {torch.random(27, 70)},
|
||||
size, offset,
|
||||
options)
|
||||
size, offset, n, options)
|
||||
elseif r > 0.3 then
|
||||
local quality1 = torch.random(37, 70)
|
||||
local quality2 = quality1 - torch.random(5, 10)
|
||||
return pairwise_transform.jpeg_(src, {quality1, quality2},
|
||||
size, offset,
|
||||
options)
|
||||
size, offset, n, options)
|
||||
else
|
||||
local quality1 = torch.random(52, 70)
|
||||
return pairwise_transform.jpeg_(src,
|
||||
{quality1,
|
||||
quality1 - torch.random(5, 15),
|
||||
quality1 - torch.random(15, 25)},
|
||||
size, offset,
|
||||
options)
|
||||
local quality2 = quality1 - torch.random(5, 15)
|
||||
local quality3 = quality1 - torch.random(15, 25)
|
||||
|
||||
return pairwise_transform.jpeg_(src,
|
||||
{quality1, quality2, quality3},
|
||||
size, offset, n, options)
|
||||
end
|
||||
end
|
||||
else
|
||||
|
@ -216,23 +260,25 @@ function pairwise_transform.jpeg(src, category, level, size, offset, options)
|
|||
if level == 1 then
|
||||
if torch.uniform() > 0.7 then
|
||||
return pairwise_transform.jpeg_(src, {},
|
||||
size, offset,
|
||||
size, offset, n,
|
||||
options)
|
||||
else
|
||||
return pairwise_transform.jpeg_(src, {torch.random(80, 95)},
|
||||
size, offset,
|
||||
size, offset, n,
|
||||
options)
|
||||
end
|
||||
elseif level == 2 then
|
||||
if torch.uniform() > 0.7 then
|
||||
return pairwise_transform.jpeg_(src, {},
|
||||
size, offset,
|
||||
size, offset, n,
|
||||
options)
|
||||
else
|
||||
return pairwise_transform.jpeg_(src, {torch.random(65, 85)},
|
||||
size, offset,
|
||||
size, offset, n,
|
||||
options)
|
||||
end
|
||||
else
|
||||
error("unknown noise level: " .. level)
|
||||
end
|
||||
else
|
||||
error("unknown category: " .. category)
|
||||
|
@ -242,6 +288,7 @@ function pairwise_transform.jpeg_scale_(src, scale, quality, size, offset, optio
|
|||
if options.random_half then
|
||||
src = random_half(src)
|
||||
end
|
||||
src = crop_if_large(src, math.max(size * 4, 512))
|
||||
local down_scale = 1.0 / scale
|
||||
local filters = {
|
||||
"Box", -- 0.012756949974688
|
||||
|
@ -270,7 +317,11 @@ function pairwise_transform.jpeg_scale_(src, scale, quality, size, offset, optio
|
|||
for i = 1, #quality do
|
||||
x = gm.Image(x, "RGB", "DHW")
|
||||
x:format("jpeg")
|
||||
x:samplingFactors({1.0, 1.0, 1.0})
|
||||
if options.jpeg_sampling_factors == 444 then
|
||||
x:samplingFactors({1.0, 1.0, 1.0})
|
||||
else -- 422
|
||||
x:samplingFactors({2.0, 1.0, 1.0})
|
||||
end
|
||||
local blob, len = x:toBlob(quality[i])
|
||||
x:fromBlob(blob, len)
|
||||
x = x:toTensor("byte", "RGB", "DHW")
|
||||
|
@ -321,10 +372,11 @@ function pairwise_transform.jpeg_scale(src, scale, category, level, size, offset
|
|||
size, offset, options)
|
||||
else
|
||||
local quality1 = torch.random(52, 70)
|
||||
local quality2 = quality1 - torch.random(5, 15)
|
||||
local quality3 = quality1 - torch.random(15, 25)
|
||||
|
||||
return pairwise_transform.jpeg_scale_(src, scale,
|
||||
{quality1,
|
||||
quality1 - torch.random(5, 15),
|
||||
quality1 - torch.random(15, 25)},
|
||||
{quality1, quality2, quality3 },
|
||||
size, offset, options)
|
||||
end
|
||||
end
|
||||
|
@ -354,14 +406,13 @@ end
|
|||
local function test_jpeg()
|
||||
local loader = require './image_loader'
|
||||
local src = loader.load_byte("../images/miku_CC_BY-NC.jpg")
|
||||
local y, x = pairwise_transform.jpeg_(src, {}, 128, 0, {})
|
||||
image.display({image = y, legend = "y:0"})
|
||||
image.display({image = x, legend = "x:0"})
|
||||
for i = 2, 9 do
|
||||
local y, x = pairwise_transform.jpeg_(random_half(src),
|
||||
{i * 10}, 128, 0, {color_noise = false, random_half = true, overlay = true, rgb = true})
|
||||
image.display({image = y, legend = "y:" .. (i * 10), max=1,min=0})
|
||||
image.display({image = x, legend = "x:" .. (i * 10),max=1,min=0})
|
||||
local xy = pairwise_transform.jpeg_(random_half(src),
|
||||
{i * 10}, 128, 0, 2, {color_noise = false, random_half = true, overlay = true, rgb = true})
|
||||
for i = 1, #xy do
|
||||
image.display({image = xy[i][1], legend = "y:" .. (i * 10), max=1,min=0})
|
||||
image.display({image = xy[i][2], legend = "x:" .. (i * 10),max=1,min=0})
|
||||
end
|
||||
--print(x:mean(), y:mean())
|
||||
end
|
||||
end
|
||||
|
@ -370,27 +421,40 @@ local function test_scale()
|
|||
torch.setdefaulttensortype('torch.FloatTensor')
|
||||
local loader = require './image_loader'
|
||||
local src = loader.load_byte("../images/miku_CC_BY-NC.jpg")
|
||||
local options = {color_noise = true,
|
||||
random_half = true,
|
||||
overlay = false,
|
||||
active_cropping_rate = 1.5,
|
||||
active_cropping_tries = 10,
|
||||
rgb = true
|
||||
}
|
||||
for i = 1, 9 do
|
||||
local y, x = pairwise_transform.scale(src, 2.0, 128, 7, {color_noise = true, random_half = true, rgb = true, overlay = true})
|
||||
image.display({image = y, legend = "y:" .. (i * 10), min = 0, max = 1})
|
||||
image.display({image = x, legend = "x:" .. (i * 10), min = 0, max = 1})
|
||||
print(y:size(), x:size())
|
||||
local xy = pairwise_transform.scale(src, 2.0, 128, 7, 1, options)
|
||||
image.display({image = xy[1][1], legend = "y:" .. (i * 10), min = 0, max = 1})
|
||||
image.display({image = xy[1][2], legend = "x:" .. (i * 10), min = 0, max = 1})
|
||||
print(xy[1][1]:size(), xy[1][2]:size())
|
||||
--print(x:mean(), y:mean())
|
||||
end
|
||||
end
|
||||
local function test_jpeg_scale()
|
||||
torch.setdefaulttensortype('torch.FloatTensor')
|
||||
local loader = require './image_loader'
|
||||
local src = loader.load_byte("../images/miku_CC_BY-NC.jpg")
|
||||
local src = loader.load_byte("../images/miku_CC_BY-NC.jpg")
|
||||
local options = {color_noise = true,
|
||||
random_half = true,
|
||||
overlay = true,
|
||||
active_cropping_ratio = 0.5,
|
||||
active_cropping_times = 10
|
||||
}
|
||||
for i = 1, 9 do
|
||||
local y, x = pairwise_transform.jpeg_scale(src, 2.0, 1, 128, 7, {color_noise = true, random_half = true, overlay = true})
|
||||
local y, x = pairwise_transform.jpeg_scale(src, 2.0, 1, 128, 7, options)
|
||||
image.display({image = y, legend = "y1:" .. (i * 10), min = 0, max = 1})
|
||||
image.display({image = x, legend = "x1:" .. (i * 10), min = 0, max = 1})
|
||||
print(y:size(), x:size())
|
||||
--print(x:mean(), y:mean())
|
||||
end
|
||||
for i = 1, 9 do
|
||||
local y, x = pairwise_transform.jpeg_scale(src, 2.0, 2, 128, 7, {color_noise = true, random_half = true, overlay = true})
|
||||
local y, x = pairwise_transform.jpeg_scale(src, 2.0, 2, 128, 7, options)
|
||||
image.display({image = y, legend = "y2:" .. (i * 10), min = 0, max = 1})
|
||||
image.display({image = x, legend = "x2:" .. (i * 10), min = 0, max = 1})
|
||||
print(y:size(), x:size())
|
||||
|
|
|
@ -1,9 +1,13 @@
|
|||
local function load_cuda()
|
||||
require 'nn'
|
||||
require 'cunn'
|
||||
end
|
||||
local function load_cudnn()
|
||||
require 'cudnn'
|
||||
--cudnn.fastest = true
|
||||
end
|
||||
|
||||
if pcall(load_cuda) then
|
||||
require 'cunn'
|
||||
else
|
||||
--[[ TODO: fakecuda does not work.
|
||||
|
||||
|
@ -13,3 +17,5 @@ else
|
|||
require('fakecuda').init(true)
|
||||
--]]
|
||||
end
|
||||
if pcall(load_cudnn) then
|
||||
end
|
||||
|
|
|
@ -48,7 +48,8 @@ local function reconstruct_rgb(model, x, offset, block_size)
|
|||
end
|
||||
return new_x
|
||||
end
|
||||
function model_is_rgb(model)
|
||||
local reconstruct = {}
|
||||
function reconstruct.is_rgb(model)
|
||||
if model:get(model:size() - 1).weight:size(1) == 3 then
|
||||
-- 3ch RGB
|
||||
return true
|
||||
|
@ -57,8 +58,23 @@ function model_is_rgb(model)
|
|||
return false
|
||||
end
|
||||
end
|
||||
|
||||
local reconstruct = {}
|
||||
function reconstruct.offset_size(model)
|
||||
local conv = model:findModules("nn.SpatialConvolutionMM")
|
||||
if #conv > 0 then
|
||||
local offset = 0
|
||||
for i = 1, #conv do
|
||||
offset = offset + (conv[i].kW - 1) / 2
|
||||
end
|
||||
return math.floor(offset)
|
||||
else
|
||||
conv = model:findModules("cudnn.SpatialConvolution")
|
||||
local offset = 0
|
||||
for i = 1, #conv do
|
||||
offset = offset + (conv[i].kW - 1) / 2
|
||||
end
|
||||
return math.floor(offset)
|
||||
end
|
||||
end
|
||||
function reconstruct.image_y(model, x, offset, block_size)
|
||||
block_size = block_size or 128
|
||||
local output_size = block_size - offset * 2
|
||||
|
@ -172,18 +188,22 @@ function reconstruct.scale_rgb(model, scale, x, offset, block_size)
|
|||
return output
|
||||
end
|
||||
|
||||
function reconstruct.image(model, x, offset, block_size)
|
||||
if model_is_rgb(model) then
|
||||
return reconstruct.image_rgb(model, x, offset, block_size)
|
||||
function reconstruct.image(model, x, block_size)
|
||||
if reconstruct.is_rgb(model) then
|
||||
return reconstruct.image_rgb(model, x,
|
||||
reconstruct.offset_size(model), block_size)
|
||||
else
|
||||
return reconstruct.image_y(model, x, offset, block_size)
|
||||
return reconstruct.image_y(model, x,
|
||||
reconstruct.offset_size(model), block_size)
|
||||
end
|
||||
end
|
||||
function reconstruct.scale(model, scale, x, offset, block_size)
|
||||
if model_is_rgb(model) then
|
||||
return reconstruct.scale_rgb(model, scale, x, offset, block_size)
|
||||
function reconstruct.scale(model, scale, x, block_size)
|
||||
if reconstruct.is_rgb(model) then
|
||||
return reconstruct.scale_rgb(model, scale, x,
|
||||
reconstruct.offset_size(model), block_size)
|
||||
else
|
||||
return reconstruct.scale_y(model, scale, x, offset, block_size)
|
||||
return reconstruct.scale_y(model, scale, x,
|
||||
reconstruct.offset_size(model), block_size)
|
||||
end
|
||||
end
|
||||
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
require 'xlua'
|
||||
require 'pl'
|
||||
require 'trepl'
|
||||
|
||||
-- global settings
|
||||
|
||||
|
@ -18,6 +19,7 @@ cmd:text("waifu2x")
|
|||
cmd:text("Options:")
|
||||
cmd:option("-seed", 11, 'fixed input seed')
|
||||
cmd:option("-data_dir", "./data", 'data directory')
|
||||
-- cmd:option("-backend", "cunn", '(cunn|cudnn)') -- cudnn is slow than cunn
|
||||
cmd:option("-test", "images/miku_small.png", 'test image file')
|
||||
cmd:option("-model_dir", "./models", 'model directory')
|
||||
cmd:option("-method", "scale", '(noise|scale|noise_scale)')
|
||||
|
@ -30,9 +32,15 @@ cmd:option("-scale", 2.0, 'scale')
|
|||
cmd:option("-learning_rate", 0.00025, 'learning rate for adam')
|
||||
cmd:option("-random_half", 1, 'enable data augmentation using half resolution image (0|1)')
|
||||
cmd:option("-crop_size", 128, 'crop size')
|
||||
cmd:option("-max_size", -1, 'crop if image size larger then this value.')
|
||||
cmd:option("-batch_size", 2, 'mini batch size')
|
||||
cmd:option("-epoch", 200, 'epoch')
|
||||
cmd:option("-core", 2, 'cpu core')
|
||||
cmd:option("-jpeg_sampling_factors", 444, '(444|422)')
|
||||
cmd:option("-validation_ratio", 0.1, 'validation ratio')
|
||||
cmd:option("-validation_crops", 40, 'number of crop region in validation')
|
||||
cmd:option("-active_cropping_rate", 0.5, 'active cropping rate')
|
||||
cmd:option("-active_cropping_tries", 20, 'active cropping tries')
|
||||
|
||||
local opt = cmd:parse(arg)
|
||||
for k, v in pairs(opt) do
|
||||
|
@ -81,16 +89,6 @@ torch.setnumthreads(settings.core)
|
|||
settings.images = string.format("%s/images.t7", settings.data_dir)
|
||||
settings.image_list = string.format("%s/image_list.txt", settings.data_dir)
|
||||
|
||||
settings.validation_ratio = 0.1
|
||||
settings.validation_crops = 30
|
||||
|
||||
local srcnn = require './srcnn'
|
||||
if (settings.method == "scale" or settings.method == "noise_scale") and settings.scale == 4 then
|
||||
settings.create_model = srcnn.waifu4x
|
||||
settings.block_offset = 13
|
||||
else
|
||||
settings.create_model = srcnn.waifu2x
|
||||
settings.block_offset = 7
|
||||
end
|
||||
settings.backend = "cunn"
|
||||
|
||||
return settings
|
||||
|
|
|
@ -1,77 +1,66 @@
|
|||
require './LeakyReLU'
|
||||
|
||||
require './mynn'
|
||||
|
||||
-- ref: http://arxiv.org/abs/1502.01852
|
||||
function nn.SpatialConvolutionMM:reset(stdv)
|
||||
stdv = math.sqrt(2 / ((1.0 + 0.1 * 0.1) * self.kW * self.kH * self.nOutputPlane))
|
||||
self.weight:normal(0, stdv)
|
||||
self.bias:zero()
|
||||
end
|
||||
|
||||
-- ref: http://arxiv.org/abs/1501.00092
|
||||
local srcnn = {}
|
||||
function srcnn.waifu2x(color)
|
||||
function srcnn.waifu2x_cunn(ch)
|
||||
local model = nn.Sequential()
|
||||
local ch = nil
|
||||
if color == "rgb" then
|
||||
ch = 3
|
||||
elseif color == "y" then
|
||||
ch = 1
|
||||
else
|
||||
if color then
|
||||
error("unknown color: " .. color)
|
||||
else
|
||||
error("unknown color: nil")
|
||||
end
|
||||
end
|
||||
-- very deep model
|
||||
model:add(nn.SpatialConvolutionMM(ch, 32, 3, 3, 1, 1, 0, 0))
|
||||
model:add(nn.LeakyReLU(0.1))
|
||||
model:add(mynn.LeakyReLU(0.1))
|
||||
model:add(nn.SpatialConvolutionMM(32, 32, 3, 3, 1, 1, 0, 0))
|
||||
model:add(nn.LeakyReLU(0.1))
|
||||
model:add(mynn.LeakyReLU(0.1))
|
||||
model:add(nn.SpatialConvolutionMM(32, 64, 3, 3, 1, 1, 0, 0))
|
||||
model:add(nn.LeakyReLU(0.1))
|
||||
model:add(mynn.LeakyReLU(0.1))
|
||||
model:add(nn.SpatialConvolutionMM(64, 64, 3, 3, 1, 1, 0, 0))
|
||||
model:add(nn.LeakyReLU(0.1))
|
||||
model:add(mynn.LeakyReLU(0.1))
|
||||
model:add(nn.SpatialConvolutionMM(64, 128, 3, 3, 1, 1, 0, 0))
|
||||
model:add(nn.LeakyReLU(0.1))
|
||||
model:add(mynn.LeakyReLU(0.1))
|
||||
model:add(nn.SpatialConvolutionMM(128, 128, 3, 3, 1, 1, 0, 0))
|
||||
model:add(nn.LeakyReLU(0.1))
|
||||
model:add(mynn.LeakyReLU(0.1))
|
||||
model:add(nn.SpatialConvolutionMM(128, ch, 3, 3, 1, 1, 0, 0))
|
||||
model:add(nn.View(-1):setNumInputDims(3))
|
||||
--model:cuda()
|
||||
--print(model:forward(torch.Tensor(32, 1, 92, 92):uniform():cuda()):size())
|
||||
--model:cuda()
|
||||
--print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
|
||||
|
||||
return model, 7
|
||||
return model
|
||||
end
|
||||
|
||||
-- current 4x is worse then 2x * 2
|
||||
function srcnn.waifu4x(color)
|
||||
function srcnn.waifu2x_cudnn(ch)
|
||||
local model = nn.Sequential()
|
||||
|
||||
local ch = nil
|
||||
model:add(cudnn.SpatialConvolution(ch, 32, 3, 3, 1, 1, 0, 0))
|
||||
model:add(mynn.LeakyReLU(0.1))
|
||||
model:add(cudnn.SpatialConvolution(32, 32, 3, 3, 1, 1, 0, 0))
|
||||
model:add(mynn.LeakyReLU(0.1))
|
||||
model:add(cudnn.SpatialConvolution(32, 64, 3, 3, 1, 1, 0, 0))
|
||||
model:add(mynn.LeakyReLU(0.1))
|
||||
model:add(cudnn.SpatialConvolution(64, 64, 3, 3, 1, 1, 0, 0))
|
||||
model:add(mynn.LeakyReLU(0.1))
|
||||
model:add(cudnn.SpatialConvolution(64, 128, 3, 3, 1, 1, 0, 0))
|
||||
model:add(mynn.LeakyReLU(0.1))
|
||||
model:add(cudnn.SpatialConvolution(128, 128, 3, 3, 1, 1, 0, 0))
|
||||
model:add(mynn.LeakyReLU(0.1))
|
||||
model:add(cudnn.SpatialConvolution(128, ch, 3, 3, 1, 1, 0, 0))
|
||||
model:add(nn.View(-1):setNumInputDims(3))
|
||||
--model:cuda()
|
||||
--print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
|
||||
|
||||
return model
|
||||
end
|
||||
function srcnn.create(model_name, backend, color)
|
||||
local ch = 3
|
||||
if color == "rgb" then
|
||||
ch = 3
|
||||
elseif color == "y" then
|
||||
ch = 1
|
||||
else
|
||||
error("unknown color: " .. color)
|
||||
error("unsupported color: " + color)
|
||||
end
|
||||
if backend == "cunn" then
|
||||
return srcnn.waifu2x_cunn(ch)
|
||||
elseif backend == "cudnn" then
|
||||
return srcnn.waifu2x_cudnn(ch)
|
||||
else
|
||||
error("unsupported backend: " + backend)
|
||||
end
|
||||
|
||||
model:add(nn.SpatialConvolutionMM(ch, 32, 9, 9, 1, 1, 0, 0))
|
||||
model:add(nn.LeakyReLU(0.1))
|
||||
model:add(nn.SpatialConvolutionMM(32, 32, 3, 3, 1, 1, 0, 0))
|
||||
model:add(nn.LeakyReLU(0.1))
|
||||
model:add(nn.SpatialConvolutionMM(32, 64, 5, 5, 1, 1, 0, 0))
|
||||
model:add(nn.LeakyReLU(0.1))
|
||||
model:add(nn.SpatialConvolutionMM(64, 64, 3, 3, 1, 1, 0, 0))
|
||||
model:add(nn.LeakyReLU(0.1))
|
||||
model:add(nn.SpatialConvolutionMM(64, 128, 5, 5, 1, 1, 0, 0))
|
||||
model:add(nn.LeakyReLU(0.1))
|
||||
model:add(nn.SpatialConvolutionMM(128, 128, 3, 3, 1, 1, 0, 0))
|
||||
model:add(nn.LeakyReLU(0.1))
|
||||
model:add(nn.SpatialConvolutionMM(128, ch, 5, 5, 1, 1, 0, 0))
|
||||
model:add(nn.View(-1):setNumInputDims(3))
|
||||
|
||||
return model, 13
|
||||
end
|
||||
return srcnn
|
||||
|
|
159
train.lua
|
@ -1,9 +1,12 @@
|
|||
require './lib/portable'
|
||||
require './lib/mynn'
|
||||
require 'optim'
|
||||
require 'xlua'
|
||||
require 'pl'
|
||||
require 'snappy'
|
||||
|
||||
local settings = require './lib/settings'
|
||||
local srcnn = require './lib/srcnn'
|
||||
local minibatch_adam = require './lib/minibatch_adam'
|
||||
local iproc = require './lib/iproc'
|
||||
local reconstruct = require './lib/reconstruct'
|
||||
|
@ -11,11 +14,11 @@ local pairwise_transform = require './lib/pairwise_transform'
|
|||
local image_loader = require './lib/image_loader'
|
||||
|
||||
local function save_test_scale(model, rgb, file)
|
||||
local up = reconstruct.scale(model, settings.scale, rgb, settings.block_offset)
|
||||
local up = reconstruct.scale(model, settings.scale, rgb)
|
||||
image.save(file, up)
|
||||
end
|
||||
local function save_test_jpeg(model, rgb, file)
|
||||
local im, count = reconstruct.image(model, rgb, settings.block_offset)
|
||||
local im, count = reconstruct.image(model, rgb)
|
||||
image.save(file, im)
|
||||
end
|
||||
local function split_data(x, test_size)
|
||||
|
@ -35,10 +38,14 @@ local function make_validation_set(x, transformer, n)
|
|||
n = n or 4
|
||||
local data = {}
|
||||
for i = 1, #x do
|
||||
for k = 1, n do
|
||||
local x, y = transformer(x[i], true)
|
||||
table.insert(data, {x = x:reshape(1, x:size(1), x:size(2), x:size(3)),
|
||||
y = y:reshape(1, y:size(1), y:size(2), y:size(3))})
|
||||
for k = 1, math.max(n / 8, 1) do
|
||||
local xy = transformer(x[i], true, 8)
|
||||
for j = 1, #xy do
|
||||
local x = xy[j][1]
|
||||
local y = xy[j][2]
|
||||
table.insert(data, {x = x:reshape(1, x:size(1), x:size(2), x:size(3)),
|
||||
y = y:reshape(1, y:size(1), y:size(2), y:size(3))})
|
||||
end
|
||||
end
|
||||
xlua.progress(i, #x)
|
||||
collectgarbage()
|
||||
|
@ -58,15 +65,96 @@ local function validate(model, criterion, data)
|
|||
return loss / #data
|
||||
end
|
||||
|
||||
local function create_criterion(model)
|
||||
if reconstruct.is_rgb(model) then
|
||||
local offset = reconstruct.offset_size(model)
|
||||
local output_w = settings.crop_size - offset * 2
|
||||
local weight = torch.Tensor(3, output_w * output_w)
|
||||
weight[1]:fill(0.299 * 3) -- R
|
||||
weight[2]:fill(0.587 * 3) -- G
|
||||
weight[3]:fill(0.114 * 3) -- B
|
||||
return mynn.RGBWeightedMSECriterion(weight):cuda()
|
||||
else
|
||||
return nn.MSECriterion():cuda()
|
||||
end
|
||||
end
|
||||
local function transformer(x, is_validation, n, offset)
|
||||
local size = x[1]
|
||||
local dec = snappy.decompress(x[2]:string())
|
||||
x = torch.ByteTensor(size[1], size[2], size[3])
|
||||
x:storage():string(dec)
|
||||
|
||||
n = n or settings.batch_size;
|
||||
if is_validation == nil then is_validation = false end
|
||||
local color_noise = nil
|
||||
local overlay = nil
|
||||
local active_cropping_ratio = nil
|
||||
local active_cropping_tries = nil
|
||||
|
||||
if is_validation then
|
||||
active_cropping_rate = 0.0
|
||||
active_cropping_tries = 0
|
||||
color_noise = false
|
||||
overlay = false
|
||||
else
|
||||
active_cropping_rate = settings.active_cropping_rate
|
||||
active_cropping_tries = settings.active_cropping_tries
|
||||
color_noise = settings.color_noise
|
||||
overlay = settings.overlay
|
||||
end
|
||||
|
||||
if settings.method == "scale" then
|
||||
return pairwise_transform.scale(x,
|
||||
settings.scale,
|
||||
settings.crop_size, offset,
|
||||
n,
|
||||
{ color_noise = color_noise,
|
||||
overlay = overlay,
|
||||
random_half = settings.random_half,
|
||||
active_cropping_rate = active_cropping_rate,
|
||||
active_cropping_tries = active_cropping_tries,
|
||||
rgb = (settings.color == "rgb")
|
||||
})
|
||||
elseif settings.method == "noise" then
|
||||
return pairwise_transform.jpeg(x,
|
||||
settings.category,
|
||||
settings.noise_level,
|
||||
settings.crop_size, offset,
|
||||
n,
|
||||
{ color_noise = color_noise,
|
||||
overlay = overlay,
|
||||
active_cropping_rate = active_cropping_rate,
|
||||
active_cropping_tries = active_cropping_tries,
|
||||
random_half = settings.random_half,
|
||||
jpeg_sampling_factors = settings.jpeg_sampling_factors,
|
||||
rgb = (settings.color == "rgb")
|
||||
})
|
||||
elseif settings.method == "noise_scale" then
|
||||
return pairwise_transform.jpeg_scale(x,
|
||||
settings.scale,
|
||||
settings.category,
|
||||
settings.noise_level,
|
||||
settings.crop_size, offset,
|
||||
n,
|
||||
{ color_noise = color_noise,
|
||||
overlay = overlay,
|
||||
jpeg_sampling_factors = settings.jpeg_sampling_factors,
|
||||
random_half = settings.random_half,
|
||||
rgb = (settings.color == "rgb")
|
||||
})
|
||||
end
|
||||
end
|
||||
|
||||
local function train()
|
||||
local model, offset = settings.create_model(settings.color)
|
||||
assert(offset == settings.block_offset)
|
||||
local criterion = nn.MSECriterion():cuda()
|
||||
local model = srcnn.create(settings.method, settings.backend, settings.color)
|
||||
local offset = reconstruct.offset_size(model)
|
||||
local pairwise_func = function(x, is_validation, n)
|
||||
return transformer(x, is_validation, n, offset)
|
||||
end
|
||||
local criterion = create_criterion(model)
|
||||
local x = torch.load(settings.images)
|
||||
local lrd_count = 0
|
||||
local train_x, valid_x = split_data(x,
|
||||
math.floor(settings.validation_ratio * #x))
|
||||
local test = image_loader.load_float(settings.test)
|
||||
local train_x, valid_x = split_data(x, math.floor(settings.validation_ratio * #x))
|
||||
local adam_config = {
|
||||
learningRate = settings.learning_rate,
|
||||
xBatchSize = settings.batch_size,
|
||||
|
@ -77,45 +165,9 @@ local function train()
|
|||
elseif settings.color == "rgb" then
|
||||
ch = 3
|
||||
end
|
||||
local transformer = function(x, is_validation)
|
||||
if is_validation == nil then is_validation = false end
|
||||
local color_noise = (not is_validation) and settings.color_noise
|
||||
local overlay = (not is_validation) and settings.overlay
|
||||
if settings.method == "scale" then
|
||||
return pairwise_transform.scale(x,
|
||||
settings.scale,
|
||||
settings.crop_size, offset,
|
||||
{ color_noise = color_noise,
|
||||
overlay = overlay,
|
||||
random_half = settings.random_half,
|
||||
rgb = (settings.color == "rgb")
|
||||
})
|
||||
elseif settings.method == "noise" then
|
||||
return pairwise_transform.jpeg(x,
|
||||
settings.category,
|
||||
settings.noise_level,
|
||||
settings.crop_size, offset,
|
||||
{ color_noise = color_noise,
|
||||
overlay = overlay,
|
||||
random_half = settings.random_half,
|
||||
rgb = (settings.color == "rgb")
|
||||
})
|
||||
elseif settings.method == "noise_scale" then
|
||||
return pairwise_transform.jpeg_scale(x,
|
||||
settings.scale,
|
||||
settings.category,
|
||||
settings.noise_level,
|
||||
settings.crop_size, offset,
|
||||
{ color_noise = color_noise,
|
||||
overlay = overlay,
|
||||
random_half = settings.random_half,
|
||||
rgb = (settings.color == "rgb")
|
||||
})
|
||||
end
|
||||
end
|
||||
local best_score = 100000.0
|
||||
print("# make validation-set")
|
||||
local valid_xy = make_validation_set(valid_x, transformer, settings.validation_crops)
|
||||
local valid_xy = make_validation_set(valid_x, pairwise_func, settings.validation_crops)
|
||||
valid_x = nil
|
||||
|
||||
collectgarbage()
|
||||
|
@ -125,7 +177,7 @@ local function train()
|
|||
model:training()
|
||||
print("# " .. epoch)
|
||||
print(minibatch_adam(model, criterion, train_x, adam_config,
|
||||
transformer,
|
||||
pairwise_func,
|
||||
{ch, settings.crop_size, settings.crop_size},
|
||||
{ch, settings.crop_size - offset * 2, settings.crop_size - offset * 2}
|
||||
))
|
||||
|
@ -133,6 +185,7 @@ local function train()
|
|||
print("# validation")
|
||||
local score = validate(model, criterion, valid_xy)
|
||||
if score < best_score then
|
||||
local test_image = image_loader.load_float(settings.test) -- reload
|
||||
lrd_count = 0
|
||||
best_score = score
|
||||
print("* update best model")
|
||||
|
@ -140,16 +193,16 @@ local function train()
|
|||
if settings.method == "noise" then
|
||||
local log = path.join(settings.model_dir,
|
||||
("noise%d_best.png"):format(settings.noise_level))
|
||||
save_test_jpeg(model, test, log)
|
||||
save_test_jpeg(model, test_image, log)
|
||||
elseif settings.method == "scale" then
|
||||
local log = path.join(settings.model_dir,
|
||||
("scale%.1f_best.png"):format(settings.scale))
|
||||
save_test_scale(model, test, log)
|
||||
save_test_scale(model, test_image, log)
|
||||
elseif settings.method == "noise_scale" then
|
||||
local log = path.join(settings.model_dir,
|
||||
("noise%d_scale%.1f_best.png"):format(settings.noise_level,
|
||||
settings.scale))
|
||||
save_test_scale(model, test, log)
|
||||
save_test_scale(model, test_image, log)
|
||||
end
|
||||
else
|
||||
lrd_count = lrd_count + 1
|
||||
|
|
9
train.sh
|
@ -1,10 +1,13 @@
|
|||
#!/bin/sh
|
||||
|
||||
th train.lua -color rgb -method noise -noise_level 1 -model_dir models/anime_style_art_rgb -test images/miku_noisy.png
|
||||
th convert_data.lua
|
||||
|
||||
th train.lua -color rgb -random_half 1 -jpeg_sampling_factors 444 -color_noise 0 -overlay 0 -epoch 200 -method noise -noise_level 1 -crop_size 46 -batch_size 8 -model_dir models/anime_style_art_rgb -test images/miku_noisy.jpg -validation_ratio 0.1 -active_cropping_rate 0.5 -active_cropping_tries 10 -validation_crops 80
|
||||
th cleanup_model.lua -model models/anime_style_art_rgb/noise1_model.t7 -oformat ascii
|
||||
|
||||
th train.lua -color rgb -method noise -noise_level 2 -model_dir models/anime_style_art_rgb -test images/miku_noisy.png
|
||||
th train.lua -color rgb -random_half 1 -jpeg_sampling_factors 444 -color_noise 0 -overlay 0 -epoch 200 -method noise -noise_level 2 -crop_size 46 -batch_size 8 -model_dir models/anime_style_art_rgb -test images/miku_noisy.jpg -validation_ratio 0.1 -active_cropping_rate 0.5 -active_cropping_tries 10 -validation_crops 80
|
||||
th cleanup_model.lua -model models/anime_style_art_rgb/noise2_model.t7 -oformat ascii
|
||||
|
||||
th train.lua -color rgb -method scale -scale 2 -model_dir models/anime_style_art_rgb -test images/miku_small.png
|
||||
th train.lua -color rgb -random_half 1 -jpeg_sampling_factors 444 -color_noise 0 -overlay 0 -epoch 200 -method scale -crop_size 46 -batch_size 8 -model_dir models/anime_style_art_rgb -test images/miku_small_noisy.jpg -active_cropping_rate 0.5 -active_cropping_tries 10 -validation_ratio 0.1 -validation_crops 80
|
||||
th cleanup_model.lua -model models/anime_style_art_rgb/scale2.0x_model.t7 -oformat ascii
|
||||
|
||||
|
|
11
train_ukbench.sh
Executable file
|
@ -0,0 +1,11 @@
|
|||
#!/bin/sh
|
||||
|
||||
th train.lua -category photo -color rgb -color_noise 0 -overlay 0 -random_half 0 -epoch 300 -batch_size 1 -method noise -noise_level 1 -data_dir ukbench -model_dir models/ukbench2 -test photo2.jpg
|
||||
th cleanup_model.lua -model models/ukbench2/noise1_model.t7 -oformat ascii
|
||||
|
||||
th train.lua -core 1 -category photo -color rgb -color_noise 0 -overlay 0 -random_half 0 -epoch 300 -batch_size 1 -method noise -noise_level 2 -data_dir ukbench -model_dir models/ukbench2 -test photo2.jpg
|
||||
th cleanup_model.lua -model models/ukbench2/noise2_model.t7 -oformat ascii
|
||||
|
||||
th train.lua -category photo -color rgb -random_half 0 -epoch 400 -batch_size 1 -method scale -scale 2 -model_dir models/ukbench2 -data_dir ukbench -test photo2-noise.png
|
||||
th cleanup_model.lua -model models/ukbench2/scale2.0x_model.t7 -oformat ascii
|
||||
|
27
waifu2x.lua
|
@ -1,12 +1,11 @@
|
|||
require './lib/portable'
|
||||
require 'sys'
|
||||
require 'pl'
|
||||
require './lib/LeakyReLU'
|
||||
require './lib/mynn'
|
||||
|
||||
local iproc = require './lib/iproc'
|
||||
local reconstruct = require './lib/reconstruct'
|
||||
local image_loader = require './lib/image_loader'
|
||||
local BLOCK_OFFSET = 7
|
||||
|
||||
torch.setdefaulttensortype('torch.FloatTensor')
|
||||
|
||||
|
@ -22,19 +21,21 @@ local function convert_image(opt)
|
|||
end
|
||||
if opt.m == "noise" then
|
||||
local model = torch.load(path.join(opt.model_dir, ("noise%d_model.t7"):format(opt.noise_level)), "ascii")
|
||||
--local srcnn = require 'lib/srcnn'
|
||||
--local model = srcnn.waifu2x("rgb"):cuda()
|
||||
model:evaluate()
|
||||
new_x = reconstruct.image(model, x, BLOCK_OFFSET, opt.crop_size)
|
||||
new_x = reconstruct.image(model, x, opt.crop_size)
|
||||
elseif opt.m == "scale" then
|
||||
local model = torch.load(path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale)), "ascii")
|
||||
model:evaluate()
|
||||
new_x = reconstruct.scale(model, opt.scale, x, BLOCK_OFFSET, opt.crop_size)
|
||||
new_x = reconstruct.scale(model, opt.scale, x, opt.crop_size)
|
||||
elseif opt.m == "noise_scale" then
|
||||
local noise_model = torch.load(path.join(opt.model_dir, ("noise%d_model.t7"):format(opt.noise_level)), "ascii")
|
||||
local scale_model = torch.load(path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale)), "ascii")
|
||||
noise_model:evaluate()
|
||||
scale_model:evaluate()
|
||||
x = reconstruct.image(noise_model, x, BLOCK_OFFSET)
|
||||
new_x = reconstruct.scale(scale_model, opt.scale, x, BLOCK_OFFSET, opt.crop_size)
|
||||
x = reconstruct.image(noise_model, x)
|
||||
new_x = reconstruct.scale(scale_model, opt.scale, x, opt.crop_size)
|
||||
else
|
||||
error("undefined method:" .. opt.method)
|
||||
end
|
||||
|
@ -62,17 +63,17 @@ local function convert_frames(opt)
|
|||
local x, alpha = image_loader.load_float(lines[i])
|
||||
local new_x = nil
|
||||
if opt.m == "noise" and opt.noise_level == 1 then
|
||||
new_x = reconstruct.image(noise1_model, x, BLOCK_OFFSET, opt.crop_size)
|
||||
new_x = reconstruct.image(noise1_model, x, opt.crop_size)
|
||||
elseif opt.m == "noise" and opt.noise_level == 2 then
|
||||
new_x = reconstruct.image(noise2_model, x, BLOCK_OFFSET)
|
||||
new_x = reconstruct.image(noise2_model, x)
|
||||
elseif opt.m == "scale" then
|
||||
new_x = reconstruct.scale(scale_model, opt.scale, x, BLOCK_OFFSET, opt.crop_size)
|
||||
new_x = reconstruct.scale(scale_model, opt.scale, x, opt.crop_size)
|
||||
elseif opt.m == "noise_scale" and opt.noise_level == 1 then
|
||||
x = reconstruct.image(noise1_model, x, BLOCK_OFFSET)
|
||||
new_x = reconstruct.scale(scale_model, opt.scale, x, BLOCK_OFFSET, opt.crop_size)
|
||||
x = reconstruct.image(noise1_model, x)
|
||||
new_x = reconstruct.scale(scale_model, opt.scale, x, opt.crop_size)
|
||||
elseif opt.m == "noise_scale" and opt.noise_level == 2 then
|
||||
x = reconstruct.image(noise2_model, x, BLOCK_OFFSET)
|
||||
new_x = reconstruct.scale(scale_model, opt.scale, x, BLOCK_OFFSET, opt.crop_size)
|
||||
x = reconstruct.image(noise2_model, x)
|
||||
new_x = reconstruct.scale(scale_model, opt.scale, x, opt.crop_size)
|
||||
else
|
||||
error("undefined method:" .. opt.method)
|
||||
end
|
||||
|
|
13
web.lua
|
@ -4,8 +4,8 @@ local uuid = require 'uuid'
|
|||
local ffi = require 'ffi'
|
||||
local md5 = require 'md5'
|
||||
require 'pl'
|
||||
require './lib/portable'
|
||||
require './lib/LeakyReLU'
|
||||
require 'lib.portable'
|
||||
require 'lib.mynn'
|
||||
|
||||
local cmd = torch.CmdLine()
|
||||
cmd:text()
|
||||
|
@ -23,7 +23,7 @@ local iproc = require './lib/iproc'
|
|||
local reconstruct = require './lib/reconstruct'
|
||||
local image_loader = require './lib/image_loader'
|
||||
|
||||
local MODEL_DIR = "./models/anime_style_art_rgb"
|
||||
local MODEL_DIR = "./models/anime_style_art_rgb3"
|
||||
|
||||
local noise1_model = torch.load(path.join(MODEL_DIR, "noise1_model.t7"), "ascii")
|
||||
local noise2_model = torch.load(path.join(MODEL_DIR, "noise2_model.t7"), "ascii")
|
||||
|
@ -40,7 +40,6 @@ local CURL_OPTIONS = {
|
|||
max_redirects = 2
|
||||
}
|
||||
local CURL_MAX_SIZE = 2 * 1024 * 1024
|
||||
local BLOCK_OFFSET = 7 -- see srcnn.lua
|
||||
|
||||
local function valid_size(x, scale)
|
||||
if scale == 0 then
|
||||
|
@ -80,13 +79,13 @@ local function get_image(req)
|
|||
end
|
||||
|
||||
local function apply_denoise1(x)
|
||||
return reconstruct.image(noise1_model, x, BLOCK_OFFSET)
|
||||
return reconstruct.image(noise1_model, x)
|
||||
end
|
||||
local function apply_denoise2(x)
|
||||
return reconstruct.image(noise2_model, x, BLOCK_OFFSET)
|
||||
return reconstruct.image(noise2_model, x)
|
||||
end
|
||||
local function apply_scale2x(x)
|
||||
return reconstruct.scale(scale20_model, 2.0, x, BLOCK_OFFSET)
|
||||
return reconstruct.scale(scale20_model, 2.0, x)
|
||||
end
|
||||
local function cache_do(cache, x, func)
|
||||
if path.exists(cache) then
|
||||
|
|