diff --git a/.gitignore b/.gitignore index 43374ac..c490ae4 100644 --- a/.gitignore +++ b/.gitignore @@ -1,8 +1,14 @@ *~ -/*.png -/*.mp4 -/*.jpg +work/ cache/*.png -models/*.png +data/ +!data/.gitkeep + +models/ +!models/anime_style_art +!models/anime_style_art_rgb +!models/ukbench models/*/*.png + waifu2x.log + diff --git a/benchmark.lua b/benchmark.lua deleted file mode 100644 index dc9f5a2..0000000 --- a/benchmark.lua +++ /dev/null @@ -1,280 +0,0 @@ -require './lib/portable' -require './lib/mynn' -require 'xlua' -require 'pl' - -local iproc = require './lib/iproc' -local reconstruct = require './lib/reconstruct' -local image_loader = require './lib/image_loader' -local gm = require 'graphicsmagick' - -local cmd = torch.CmdLine() -cmd:text() -cmd:text("waifu2x-benchmark") -cmd:text("Options:") - -cmd:option("-seed", 11, 'fixed input seed') -cmd:option("-test_dir", "./test", 'test image directory') -cmd:option("-jpeg_quality", 50, 'jpeg quality') -cmd:option("-jpeg_times", 3, 'number of jpeg compression ') -cmd:option("-jpeg_quality_down", 5, 'reducing jpeg quality each times') -cmd:option("-core", 4, 'threads') - -local opt = cmd:parse(arg) -torch.setnumthreads(opt.core) -torch.setdefaulttensortype('torch.FloatTensor') - -local function MSE(x1, x2) - return (x1 - x2):pow(2):mean() -end -local function YMSE(x1, x2) - local x1_2 = x1:clone() - local x2_2 = x2:clone() - - x1_2[1]:mul(0.299 * 3) - x1_2[2]:mul(0.587 * 3) - x1_2[3]:mul(0.114 * 3) - - x2_2[1]:mul(0.299 * 3) - x2_2[2]:mul(0.587 * 3) - x2_2[3]:mul(0.114 * 3) - - return (x1_2 - x2_2):pow(2):mean() -end -local function PSNR(x1, x2) - local mse = MSE(x1, x2) - return 20 * (math.log(1.0 / math.sqrt(mse)) / math.log(10)) -end -local function YPSNR(x1, x2) - local mse = YMSE(x1, x2) - return 20 * (math.log((0.587 * 3) / math.sqrt(mse)) / math.log(10)) -end - -local function transform_jpeg(x) - for i = 1, opt.jpeg_times do - jpeg = gm.Image(x, "RGB", "DHW") - jpeg:format("jpeg") - jpeg:samplingFactors({1.0, 1.0, 1.0}) - blob, len = jpeg:toBlob(opt.jpeg_quality - (i - 1) * opt.jpeg_quality_down) - jpeg:fromBlob(blob, len) - x = jpeg:toTensor("byte", "RGB", "DHW") - end - return x -end - -local function noise_benchmark(x, v1_noise, v2_noise) - local v1_mse = 0 - local v2_mse = 0 - local jpeg_mse = 0 - local v1_psnr = 0 - local v2_psnr = 0 - local jpeg_psnr = 0 - local v1_time = 0 - local v2_time = 0 - - for i = 1, #x do - local ground_truth = x[i] - local jpg, blob, len, input, v1_out, v2_out, t, mse - - input = transform_jpeg(ground_truth) - input = input:float():div(255) - ground_truth = ground_truth:float():div(255) - - jpeg_mse = jpeg_mse + MSE(ground_truth, input) - jpeg_psnr = jpeg_psnr + PSNR(ground_truth, input) - - t = sys.clock() - v1_output = reconstruct.image(v1_noise, input) - v1_time = v1_time + (sys.clock() - t) - v1_mse = v1_mse + MSE(ground_truth, v1_output) - v1_psnr = v1_psnr + PSNR(ground_truth, v1_output) - - t = sys.clock() - v2_output = reconstruct.image(v2_noise, input) - v2_time = v2_time + (sys.clock() - t) - v2_mse = v2_mse + MSE(ground_truth, v2_output) - v2_psnr = v2_psnr + PSNR(ground_truth, v2_output) - - io.stdout:write( - string.format("%d/%d; v1_time=%f, v2_time=%f, jpeg_mse=%f, v1_mse=%f, v2_mse=%f, jpeg_psnr=%f, v1_psnr=%f, v2_psnr=%f \r", - i, #x, - v1_time / i, v2_time / i, - jpeg_mse / i, - v1_mse / i, v2_mse / i, - jpeg_psnr / i, - v1_psnr / i, v2_psnr / i - ) - ) - io.stdout:flush() - end - io.stdout:write("\n") -end -local function noise_scale_benchmark(x, params, v1_noise, v1_scale, v2_noise, v2_scale) - local v1_mse = 0 - local v2_mse = 0 - local jinc_mse = 0 - local v1_time = 0 - local v2_time = 0 - - for i = 1, #x do - local ground_truth = x[i] - local downscale = iproc.scale(ground_truth, - ground_truth:size(3) * 0.5, - ground_truth:size(2) * 0.5, - params[i].filter) - local jpg, blob, len, input, v1_output, v2_output, jinc_output, t, mse - - jpeg = gm.Image(downscale, "RGB", "DHW") - jpeg:format("jpeg") - blob, len = jpeg:toBlob(params[i].quality) - jpeg:fromBlob(blob, len) - input = jpeg:toTensor("byte", "RGB", "DHW") - - input = input:float():div(255) - ground_truth = ground_truth:float():div(255) - - jinc_output = iproc.scale(input, input:size(3) * 2, input:size(2) * 2, "Jinc") - jinc_mse = jinc_mse + (ground_truth - jinc_output):pow(2):mean() - - t = sys.clock() - v1_output = reconstruct.image(v1_noise, input) - v1_output = reconstruct.scale(v1_scale, 2.0, v1_output) - v1_time = v1_time + (sys.clock() - t) - mse = (ground_truth - v1_output):pow(2):mean() - v1_mse = v1_mse + mse - - t = sys.clock() - v2_output = reconstruct.image(v2_noise, input) - v2_output = reconstruct.scale(v2_scale, 2.0, v2_output) - v2_time = v2_time + (sys.clock() - t) - mse = (ground_truth - v2_output):pow(2):mean() - v2_mse = v2_mse + mse - - io.stdout:write(string.format("%d/%d; time: v1=%f, v2=%f, v1/v2=%f; mse: jinc=%f, v1=%f(%f), v2=%f(%f), v1/v2=%f \r", - i, #x, - v1_time / i, v2_time / i, - (v1_time / i) / (v2_time / i), - jinc_mse / i, - v1_mse / i, (v1_mse/i) / (jinc_mse/i), - v2_mse / i, (v2_mse/i) / (jinc_mse/i), - (v1_mse / i) / (v2_mse / i))) - - io.stdout:flush() - end - io.stdout:write("\n") -end -local function scale_benchmark(x, params, v1_scale, v2_scale) - local v1_mse = 0 - local v2_mse = 0 - local jinc_mse = 0 - local v1_psnr = 0 - local v2_psnr = 0 - local jinc_psnr = 0 - - local v1_time = 0 - local v2_time = 0 - - for i = 1, #x do - local ground_truth = x[i] - local downscale = iproc.scale(ground_truth, - ground_truth:size(3) * 0.5, - ground_truth:size(2) * 0.5, - params[i].filter) - local jpg, blob, len, input, v1_output, v2_output, jinc_output, t, mse - input = downscale - - input = input:float():div(255) - ground_truth = ground_truth:float():div(255) - - jinc_output = iproc.scale(input, input:size(3) * 2, input:size(2) * 2, "Jinc") - mse = (ground_truth - jinc_output):pow(2):mean() - jinc_mse = jinc_mse + mse - jinc_psnr = jinc_psnr + (10 * (math.log(1.0 / mse) / math.log(10))) - - t = sys.clock() - v1_output = reconstruct.scale(v1_scale, 2.0, input) - v1_time = v1_time + (sys.clock() - t) - mse = (ground_truth - v1_output):pow(2):mean() - v1_mse = v1_mse + mse - v1_psnr = v1_psnr + (10 * (math.log(1.0 / mse) / math.log(10))) - - t = sys.clock() - v2_output = reconstruct.scale(v2_scale, 2.0, input) - v2_time = v2_time + (sys.clock() - t) - mse = (ground_truth - v2_output):pow(2):mean() - v2_mse = v2_mse + mse - v2_psnr = v2_psnr + (10 * (math.log(1.0 / mse) / math.log(10))) - - io.stdout:write(string.format("%d/%d; time: v1=%f, v2=%f, v1/v2=%f; mse: jinc=%f, v1=%f(%f), v2=%f(%f), v1/v2=%f \r", - i, #x, - v1_time / i, v2_time / i, - (v1_time / i) / (v2_time / i), - jinc_psnr / i, - v1_psnr / i, (v1_psnr/i) / (jinc_psnr/i), - v2_psnr / i, (v2_psnr/i) / (jinc_psnr/i), - (v1_psnr / i) / (v2_psnr / i))) - - io.stdout:flush() - end - io.stdout:write("\n") -end - -local function split_data(x, test_size) - local index = torch.randperm(#x) - local train_size = #x - test_size - local train_x = {} - local valid_x = {} - for i = 1, train_size do - train_x[i] = x[index[i]] - end - for i = 1, test_size do - valid_x[i] = x[index[train_size + i]] - end - return train_x, valid_x -end -local function crop_4x(x) - local w = x:size(3) % 4 - local h = x:size(2) % 4 - return image.crop(x, 0, 0, x:size(3) - w, x:size(2) - h) -end -local function load_data(valid_dir) - local valid_x = {} - local files = dir.getfiles(valid_dir, "*.png") - for i = 1, #files do - table.insert(valid_x, crop_4x(image_loader.load_byte(files[i]))) - xlua.progress(i, #files) - end - return valid_x -end - -local function noise_main(valid_dir, level) - local v1_noise = torch.load(path.join(V1_DIR, string.format("noise%d_model.t7", level)), "ascii") - local v2_noise = torch.load(path.join(V2_DIR, string.format("noise%d_model.t7", level)), "ascii") - local valid_x = load_data(valid_dir) - noise_benchmark(valid_x, v1_noise, v2_noise) -end -local function scale_main(valid_dir) - local v1 = torch.load(path.join(V1_DIR, "scale2.0x_model.t7"), "ascii") - local v2 = torch.load(path.join(V2_DIR, "scale2.0x_model.t7"), "ascii") - local valid_x = load_data(valid_dir) - local params = random_params(valid_x, 2) - scale_benchmark(valid_x, params, v1, v2) -end -local function noise_scale_main(valid_dir) - local v1_noise = torch.load(path.join(V1_DIR, "noise2_model.t7"), "ascii") - local v1_scale = torch.load(path.join(V1_DIR, "scale2.0x_model.t7"), "ascii") - local v2_noise = torch.load(path.join(V2_DIR, "noise2_model.t7"), "ascii") - local v2_scale = torch.load(path.join(V2_DIR, "scale2.0x_model.t7"), "ascii") - local valid_x = load_data(valid_dir) - local params = random_params(valid_x, 2) - noise_scale_benchmark(valid_x, params, v1_noise, v1_scale, v2_noise, v2_scale) -end - -V1_DIR = "models/anime_style_art_rgb" -V2_DIR = "models/anime_style_art_rgb5" - -torch.manualSeed(opt.seed) -cutorch.manualSeed(opt.seed) -noise_main("./test", 2) ---scale_main("./test") ---noise_scale_main("./test") diff --git a/convert_data.lua b/convert_data.lua index ceb7323..0f8f368 100644 --- a/convert_data.lua +++ b/convert_data.lua @@ -1,22 +1,14 @@ -local ffi = require 'ffi' -require './lib/portable' +local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)() +package.path = path.join(path.dirname(__FILE__), "lib", "?.lua;") .. package.path + +require 'pl' require 'image' -require 'snappy' -local settings = require './lib/settings' -local image_loader = require './lib/image_loader' +local compression = require 'compression' +local settings = require 'settings' +local image_loader = require 'image_loader' local MAX_SIZE = 1440 -local function count_lines(file) - local fp = io.open(file, "r") - local count = 0 - for line in fp:lines() do - count = count + 1 - end - fp:close() - - return count -end local function crop_if_large(src, max_size) if max_size > 0 and (src:size(2) > max_size or src:size(3) > max_size) then local sx = torch.random(0, src:size(3) - math.min(max_size, src:size(3))) @@ -36,40 +28,38 @@ end local function load_images(list) local MARGIN = 32 - local count = count_lines(list) - local fp = io.open(list, "r") + local lines = utils.split(file.read(list), "\n") local x = {} - local c = 0 - for line in fp:lines() do + for i = 1, #lines do + local line = lines[i] local im, alpha = image_loader.load_byte(line) - im = crop_if_large(im, settings.max_size) - im = crop_4x(im) - if alpha then - io.stderr:write(string.format("%s: skip: reason: alpha channel.", line)) + io.stderr:write(string.format("\n%s: skip: image has alpha channel.\n", line)) else + im = crop_if_large(im, settings.max_size) + im = crop_4x(im) local scale = 1.0 if settings.random_half then scale = 2.0 end if im then if im:size(2) > (settings.crop_size * scale + MARGIN) and im:size(3) > (settings.crop_size * scale + MARGIN) then - table.insert(x, {im:size(), torch.ByteStorage():string(snappy.compress(im:storage():string()))}) + table.insert(x, compression.compress(im)) else - io.stderr:write(string.format("%s: skip: reason: too small (%d > size).\n", line, settings.crop_size * scale + MARGIN)) + io.stderr:write(string.format("\n%s: skip: image is too small (%d > size).\n", line, settings.crop_size * scale + MARGIN)) end else - io.stderr:write(string.format("%s: skip: reason: load error.\n", line)) + io.stderr:write(string.format("\n%s: skip: load error.\n", line)) end end - c = c + 1 - xlua.progress(c, count) - if c % 10 == 0 then + xlua.progress(i, #lines) + if i % 10 == 0 then collectgarbage() end end return x end + torch.manualSeed(settings.seed) print(settings) local x = load_images(settings.image_list) diff --git a/lib/DepthExpand2x.lua b/lib/DepthExpand2x.lua index 1a6348b..3f28dd5 100644 --- a/lib/DepthExpand2x.lua +++ b/lib/DepthExpand2x.lua @@ -1,7 +1,7 @@ -if mynn.DepthExpand2x then - return mynn.DepthExpand2x +if w2nn.DepthExpand2x then + return w2nn.DepthExpand2x end -local DepthExpand2x, parent = torch.class('mynn.DepthExpand2x','nn.Module') +local DepthExpand2x, parent = torch.class('w2nn.DepthExpand2x','nn.Module') function DepthExpand2x:__init() parent:__init() @@ -67,9 +67,11 @@ function DepthExpand2x.test() end show(x) - local de2x = mynn.DepthExpand2x() + local de2x = w2nn.DepthExpand2x() out = de2x:forward(x) show(out) out = de2x:updateGradInput(x, out) show(out) end + +return DepthExpand2x diff --git a/lib/LeakyReLU.lua b/lib/LeakyReLU.lua index 7096506..5b27bc9 100644 --- a/lib/LeakyReLU.lua +++ b/lib/LeakyReLU.lua @@ -1,8 +1,8 @@ -if mynn.LeakyReLU then - return mynn.LeakyReLU +if w2nn and w2nn.LeakyReLU then + return w2nn.LeakyReLU end -local LeakyReLU, parent = torch.class('mynn.LeakyReLU','nn.Module') +local LeakyReLU, parent = torch.class('w2nn.LeakyReLU','nn.Module') function LeakyReLU:__init(negative_scale) parent.__init(self) diff --git a/lib/RGBWeightedMSECriterion.lua b/lib/WeightedMSECriterion.lua similarity index 63% rename from lib/RGBWeightedMSECriterion.lua rename to lib/WeightedMSECriterion.lua index cd267f4..2e2d999 100644 --- a/lib/RGBWeightedMSECriterion.lua +++ b/lib/WeightedMSECriterion.lua @@ -1,13 +1,13 @@ -local RGBWeightedMSECriterion, parent = torch.class('mynn.RGBWeightedMSECriterion','nn.Criterion') +local WeightedMSECriterion, parent = torch.class('w2nn.WeightedMSECriterion','nn.Criterion') -function RGBWeightedMSECriterion:__init(w) +function WeightedMSECriterion:__init(w) parent.__init(self) self.weight = w:clone() self.diff = torch.Tensor() self.loss = torch.Tensor() end -function RGBWeightedMSECriterion:updateOutput(input, target) +function WeightedMSECriterion:updateOutput(input, target) self.diff:resizeAs(input):copy(input) for i = 1, input:size(1) do self.diff[i]:add(-1, target[i]):cmul(self.weight) @@ -18,8 +18,7 @@ function RGBWeightedMSECriterion:updateOutput(input, target) return self.output end -function RGBWeightedMSECriterion:updateGradInput(input, target) +function WeightedMSECriterion:updateGradInput(input, target) self.gradInput:resizeAs(input):copy(self.diff) return self.gradInput end - diff --git a/lib/compression.lua b/lib/compression.lua new file mode 100644 index 0000000..50c235e --- /dev/null +++ b/lib/compression.lua @@ -0,0 +1,17 @@ +-- snapply compression for ByteTensor +require 'snappy' + +local compression = {} +compression.compress = function (bt) + local enc = snappy.compress(bt:storage():string()) + return {bt:size(), torch.ByteStorage():string(enc)} +end +compression.decompress = function(data) + local size = data[1] + local dec = snappy.decompress(data[2]:string()) + local bt = torch.ByteTensor(unpack(torch.totable(size))) + bt:storage():string(dec) + return bt +end + +return compression diff --git a/lib/image_loader.lua b/lib/image_loader.lua index 7ca78cb..136758d 100644 --- a/lib/image_loader.lua +++ b/lib/image_loader.lua @@ -17,7 +17,7 @@ function image_loader.encode_png(rgb, alpha) end if alpha then if not (alpha:size(2) == rgb:size(2) and alpha:size(3) == rgb:size(3)) then - alpha = gm.Image(alpha, "I", "DHW"):size(rgb:size(3), rgb:size(2), "Sinc"):toTensor("float", "I", "DHW") + alpha = gm.Image(alpha, "I", "DHW"):size(rgb:size(3), rgb:size(2), "SincFast"):toTensor("float", "I", "DHW") end local rgba = torch.Tensor(4, rgb:size(2), rgb:size(3)) rgba[1]:copy(rgb[1]) @@ -50,8 +50,8 @@ function image_loader.decode_byte(blob) if blob:sub(1, 4) == "\x89PNG" or blob:sub(1, 3) == "GIF" then -- split alpha channel im = im:toTensor('float', 'RGBA', 'DHW') - local sum_alpha = (im[4] - 1):sum() - if sum_alpha > 0 or sum_alpha < 0 then + local sum_alpha = (im[4] - 1.0):sum() + if sum_alpha < 0 then alpha = im[4]:reshape(1, im:size(2), im:size(3)) end local new_im = torch.FloatTensor(3, im:size(2), im:size(3)) diff --git a/lib/iproc.lua b/lib/iproc.lua index 9655771..b80b213 100644 --- a/lib/iproc.lua +++ b/lib/iproc.lua @@ -22,5 +22,4 @@ function iproc.padding(img, w1, w2, h1, h2) flow[2]:add(-w1) return image.warp(img, flow, "simple", false, "clamp") end - return iproc diff --git a/lib/mynn.lua b/lib/mynn.lua deleted file mode 100644 index dc81c39..0000000 --- a/lib/mynn.lua +++ /dev/null @@ -1,20 +0,0 @@ -local function load_cunn() - require 'nn' - require 'cunn' -end -local function load_cudnn() - require 'cudnn' - cudnn.fastest = true -end -if mynn then - return mynn -else - load_cunn() - --load_cudnn() - mynn = {} - require './LeakyReLU' - require './LeakyReLU_deprecated' - require './DepthExpand2x' - require './RGBWeightedMSECriterion' - return mynn -end diff --git a/lib/pairwise_transform.lua b/lib/pairwise_transform.lua index 5abfc76..a6de6f6 100644 --- a/lib/pairwise_transform.lua +++ b/lib/pairwise_transform.lua @@ -1,7 +1,7 @@ require 'image' local gm = require 'graphicsmagick' -local iproc = require './iproc' -local reconstruct = require './reconstruct' +local iproc = require 'iproc' +local reconstruct = require 'reconstruct' local pairwise_transform = {} local function random_half(src, p) @@ -81,6 +81,11 @@ local function color_noise(src) return x:mul(255):byte() end +local function shift_1px(src) + -- reducing the even/odd issue in nearest neighbor. + local r = torch.random(1, 4) + +end local function flip_augment(x, y) local flip = torch.random(1, 4) if y then @@ -138,17 +143,16 @@ local function data_augment(y, options) return y end - local INTERPOLATION_PADDING = 16 function pairwise_transform.scale(src, scale, size, offset, n, options) local filters = { - "Box","Box","Box", -- 0.012756949974688 + "Box","Box", -- 0.012756949974688 "Blackman", -- 0.013191924552285 --"Cartom", -- 0.013753536746706 --"Hanning", -- 0.013761314529647 --"Hermite", -- 0.013850225205266 "SincFast", -- 0.014095824314306 - "Jinc", -- 0.014244299255442 + --"Jinc", -- 0.014244299255442 } if options.random_half then src = random_half(src) @@ -176,26 +180,14 @@ function pairwise_transform.scale(src, scale, size, offset, n, options) return batch end function pairwise_transform.jpeg_(src, quality, size, offset, n, options) - if options.random_half then - src = random_half(src) - end - src = crop_if_large(src, math.max(size * 4, 512)) - local y = src - local x - - if options.color_noise then - y = color_noise(y) - end - if options.overlay then - y = overlay_augment(y) - end - x = y + local y = data_augment(crop_if_large(src, math.max(size * 4, 512)), options) + local x = y for i = 1, #quality do x = gm.Image(x, "RGB", "DHW") x:format("jpeg") if options.jpeg_sampling_factors == 444 then x:samplingFactors({1.0, 1.0, 1.0}) - else -- 422 + else -- 420 x:samplingFactors({2.0, 1.0, 1.0}) end local blob, len = x:toBlob(quality[i]) diff --git a/lib/portable.lua b/lib/portable.lua deleted file mode 100644 index d79b98b..0000000 --- a/lib/portable.lua +++ /dev/null @@ -1,17 +0,0 @@ -require 'torch' -require 'nn' - -local function load_cuda() - require 'cutorch' - require 'cunn' -end -local function load_cudnn() - require 'cudnn' - --cudnn.fastest = true -end - -if pcall(load_cuda) then -else -end -if pcall(load_cudnn) then -end diff --git a/lib/reconstruct.lua b/lib/reconstruct.lua index fbd1f51..94683f2 100644 --- a/lib/reconstruct.lua +++ b/lib/reconstruct.lua @@ -1,5 +1,5 @@ require 'image' -local iproc = require './iproc' +local iproc = require 'iproc' local function reconstruct_y(model, x, offset, block_size) if x:dim() == 2 then diff --git a/lib/settings.lua b/lib/settings.lua index eb8d8ce..4031f6d 100644 --- a/lib/settings.lua +++ b/lib/settings.lua @@ -35,7 +35,7 @@ cmd:option("-crop_size", 128, 'crop size') cmd:option("-max_size", -1, 'crop if image size larger then this value.') cmd:option("-batch_size", 2, 'mini batch size') cmd:option("-epoch", 200, 'epoch') -cmd:option("-core", 2, 'cpu core') +cmd:option("-thread", -1, 'number of CPU threads') cmd:option("-jpeg_sampling_factors", 444, '(444|422)') cmd:option("-validation_ratio", 0.1, 'validation ratio') cmd:option("-validation_crops", 40, 'number of crop region in validation') @@ -84,7 +84,9 @@ else settings.overlay = false end -torch.setnumthreads(settings.core) +if settings.thread > 0 then + torch.setnumthreads(tonumber(settings.thread)) +end settings.images = string.format("%s/images.t7", settings.data_dir) settings.image_list = string.format("%s/image_list.txt", settings.data_dir) diff --git a/lib/srcnn.lua b/lib/srcnn.lua index 6dbdebd..6e273c9 100644 --- a/lib/srcnn.lua +++ b/lib/srcnn.lua @@ -1,5 +1,4 @@ - -require './mynn' +require 'w2nn' -- ref: http://arxiv.org/abs/1502.01852 -- ref: http://arxiv.org/abs/1501.00092 @@ -7,17 +6,17 @@ local srcnn = {} function srcnn.waifu2x_cunn(ch) local model = nn.Sequential() model:add(nn.SpatialConvolutionMM(ch, 32, 3, 3, 1, 1, 0, 0)) - model:add(mynn.LeakyReLU(0.1)) + model:add(w2nn.LeakyReLU(0.1)) model:add(nn.SpatialConvolutionMM(32, 32, 3, 3, 1, 1, 0, 0)) - model:add(mynn.LeakyReLU(0.1)) + model:add(w2nn.LeakyReLU(0.1)) model:add(nn.SpatialConvolutionMM(32, 64, 3, 3, 1, 1, 0, 0)) - model:add(mynn.LeakyReLU(0.1)) + model:add(w2nn.LeakyReLU(0.1)) model:add(nn.SpatialConvolutionMM(64, 64, 3, 3, 1, 1, 0, 0)) - model:add(mynn.LeakyReLU(0.1)) + model:add(w2nn.LeakyReLU(0.1)) model:add(nn.SpatialConvolutionMM(64, 128, 3, 3, 1, 1, 0, 0)) - model:add(mynn.LeakyReLU(0.1)) + model:add(w2nn.LeakyReLU(0.1)) model:add(nn.SpatialConvolutionMM(128, 128, 3, 3, 1, 1, 0, 0)) - model:add(mynn.LeakyReLU(0.1)) + model:add(w2nn.LeakyReLU(0.1)) model:add(nn.SpatialConvolutionMM(128, ch, 3, 3, 1, 1, 0, 0)) model:add(nn.View(-1):setNumInputDims(3)) --model:cuda() @@ -28,17 +27,17 @@ end function srcnn.waifu2x_cudnn(ch) local model = nn.Sequential() model:add(cudnn.SpatialConvolution(ch, 32, 3, 3, 1, 1, 0, 0)) - model:add(mynn.LeakyReLU(0.1)) + model:add(w2nn.LeakyReLU(0.1)) model:add(cudnn.SpatialConvolution(32, 32, 3, 3, 1, 1, 0, 0)) - model:add(mynn.LeakyReLU(0.1)) + model:add(w2nn.LeakyReLU(0.1)) model:add(cudnn.SpatialConvolution(32, 64, 3, 3, 1, 1, 0, 0)) - model:add(mynn.LeakyReLU(0.1)) + model:add(w2nn.LeakyReLU(0.1)) model:add(cudnn.SpatialConvolution(64, 64, 3, 3, 1, 1, 0, 0)) - model:add(mynn.LeakyReLU(0.1)) + model:add(w2nn.LeakyReLU(0.1)) model:add(cudnn.SpatialConvolution(64, 128, 3, 3, 1, 1, 0, 0)) - model:add(mynn.LeakyReLU(0.1)) + model:add(w2nn.LeakyReLU(0.1)) model:add(cudnn.SpatialConvolution(128, 128, 3, 3, 1, 1, 0, 0)) - model:add(mynn.LeakyReLU(0.1)) + model:add(w2nn.LeakyReLU(0.1)) model:add(cudnn.SpatialConvolution(128, ch, 3, 3, 1, 1, 0, 0)) model:add(nn.View(-1):setNumInputDims(3)) --model:cuda() diff --git a/lib/w2nn.lua b/lib/w2nn.lua new file mode 100644 index 0000000..c47e21e --- /dev/null +++ b/lib/w2nn.lua @@ -0,0 +1,24 @@ +local function load_nn() + require 'torch' + require 'nn' +end +local function load_cunn() + require 'cutorch' + require 'cunn' +end +local function load_cudnn() + require 'cudnn' + cudnn.fastest = true +end +if w2nn then + return w2nn +else + pcall(load_cunn) + pcall(load_cudnn) + w2nn = {} + require 'LeakyReLU' + require 'LeakyReLU_deprecated' + require 'DepthExpand2x' + require 'WeightedMSECriterion' + return w2nn +end diff --git a/tools/benchmark.lua b/tools/benchmark.lua new file mode 100644 index 0000000..dd89ad7 --- /dev/null +++ b/tools/benchmark.lua @@ -0,0 +1,148 @@ +local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)() +package.path = path.join(path.dirname(__FILE__), "..", "lib", "?.lua;") .. package.path +require 'xlua' +require 'pl' + +require 'w2nn' +local iproc = require 'iproc' +local reconstruct = require 'reconstruct' +local image_loader = require 'image_loader' +local gm = require 'graphicsmagick' + +local cmd = torch.CmdLine() +cmd:text() +cmd:text("waifu2x-benchmark") +cmd:text("Options:") + +cmd:option("-seed", 11, 'fixed input seed') +cmd:option("-dir", "./data/test", 'test image directory') +cmd:option("-model1_dir", "./models/anime_style_art", 'model1 directory') +cmd:option("-model2_dir", "./models/anime_style_art_rgb", 'model2 directory') +cmd:option("-method", "scale", '(scale|noise)') +cmd:option("-noise_level", 1, '(1|2)') +cmd:option("-color_weight", "y", '(y|rgb)') +cmd:option("-jpeg_quality", 75, 'jpeg quality') +cmd:option("-jpeg_times", 1, 'jpeg compression times') +cmd:option("-jpeg_quality_down", 5, 'value of jpeg quality to decrease each times') + +local opt = cmd:parse(arg) +torch.setdefaulttensortype('torch.FloatTensor') + +local function MSE(x1, x2) + return (x1 - x2):pow(2):mean() +end +local function YMSE(x1, x2) + local x1_2 = x1:clone() + local x2_2 = x2:clone() + + x1_2[1]:mul(0.299 * 3) + x1_2[2]:mul(0.587 * 3) + x1_2[3]:mul(0.114 * 3) + + x2_2[1]:mul(0.299 * 3) + x2_2[2]:mul(0.587 * 3) + x2_2[3]:mul(0.114 * 3) + + return (x1_2 - x2_2):pow(2):mean() +end +local function PSNR(x1, x2) + local mse = MSE(x1, x2) + return 20 * (math.log(1.0 / math.sqrt(mse)) / math.log(10)) +end +local function YPSNR(x1, x2) + local mse = YMSE(x1, x2) + return 20 * (math.log((0.587 * 3) / math.sqrt(mse)) / math.log(10)) +end + +local function transform_jpeg(x) + for i = 1, opt.jpeg_times do + jpeg = gm.Image(x, "RGB", "DHW") + jpeg:format("jpeg") + jpeg:samplingFactors({1.0, 1.0, 1.0}) + blob, len = jpeg:toBlob(opt.jpeg_quality - (i - 1) * opt.jpeg_quality_down) + jpeg:fromBlob(blob, len) + x = jpeg:toTensor("byte", "RGB", "DHW") + end + return x +end +local function transform_scale(x) + return iproc.scale(x, + x:size(3) * 0.5, + x:size(2) * 0.5, + "Box") +end + +local function benchmark(color_weight, x, input_func, v1_noise, v2_noise) + local v1_mse = 0 + local v2_mse = 0 + local v1_psnr = 0 + local v2_psnr = 0 + + for i = 1, #x do + local ground_truth = x[i] + local input, v1_output, v2_output + + input = input_func(ground_truth) + input = input:float():div(255) + ground_truth = ground_truth:float():div(255) + + t = sys.clock() + if input:size(3) == ground_truth:size(3) then + v1_output = reconstruct.image(v1_noise, input) + v2_output = reconstruct.image(v2_noise, input) + else + v1_output = reconstruct.scale(v1_noise, 2.0, input) + v2_output = reconstruct.scale(v2_noise, 2.0, input) + end + if color_weight == "y" then + v1_mse = v1_mse + YMSE(ground_truth, v1_output) + v1_psnr = v1_psnr + YPSNR(ground_truth, v1_output) + v2_mse = v2_mse + YMSE(ground_truth, v2_output) + v2_psnr = v2_psnr + YPSNR(ground_truth, v2_output) + elseif color_weight == "rgb" then + v1_mse = v1_mse + MSE(ground_truth, v1_output) + v1_psnr = v1_psnr + PSNR(ground_truth, v1_output) + v2_mse = v2_mse + MSE(ground_truth, v2_output) + v2_psnr = v2_psnr + PSNR(ground_truth, v2_output) + end + + io.stdout:write( + string.format("%d/%d; v1_mse=%f, v2_mse=%f, v1_psnr=%f, v2_psnr=%f \r", + i, #x, + v1_mse / i, v2_mse / i, + v1_psnr / i, v2_psnr / i + ) + ) + io.stdout:flush() + end + io.stdout:write("\n") +end +local function crop_4x(x) + local w = x:size(3) % 4 + local h = x:size(2) % 4 + return image.crop(x, 0, 0, x:size(3) - w, x:size(2) - h) +end +local function load_data(test_dir) + local test_x = {} + local files = dir.getfiles(test_dir, "*.*") + for i = 1, #files do + table.insert(test_x, crop_4x(image_loader.load_byte(files[i]))) + xlua.progress(i, #files) + end + return test_x +end + +print(opt) +torch.manualSeed(opt.seed) +cutorch.manualSeed(opt.seed) +if opt.method == "scale" then + local v1 = torch.load(path.join(opt.model1_dir, "scale2.0x_model.t7"), "ascii") + local v2 = torch.load(path.join(opt.model2_dir, "scale2.0x_model.t7"), "ascii") + local test_x = load_data(opt.dir) + benchmark(opt.color_weight, test_x, transform_scale, v1, v2) +elseif opt.method == "noise" then + local v1 = torch.load(path.join(opt.model1_dir, string.format("noise%d_model.t7", opt.noise_level)), "ascii") + local v2 = torch.load(path.join(opt.model2_dir, string.format("noise%d_model.t7", opt.noise_level)), "ascii") + local test_x = load_data(opt.dir) + benchmark(opt.color_weight, test_x, transform_jpeg, v1, v2) +end diff --git a/cleanup_model.lua b/tools/cleanup_model.lua similarity index 85% rename from cleanup_model.lua rename to tools/cleanup_model.lua index 7a77ec1..408ae5d 100644 --- a/cleanup_model.lua +++ b/tools/cleanup_model.lua @@ -1,6 +1,7 @@ -require './lib/portable' -require './lib/mynn' +local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)() +package.path = path.join(path.dirname(__FILE__), "..", "lib", "?.lua;") .. package.path +require 'w2nn' torch.setdefaulttensortype("torch.FloatTensor") -- ref: https://github.com/torch/nn/issues/112#issuecomment-64427049 @@ -27,7 +28,7 @@ local function cleanupModel(node) if node.finput ~= nil then node.finput = zeroDataSize(node.finput) end - if tostring(node) == "nn.LeakyReLU" then + if tostring(node) == "nn.LeakyReLU" or tostring(node) == "w2nn.LeakyReLU" then if node.negative ~= nil then node.negative = zeroDataSize(node.negative) end diff --git a/export_model.lua b/tools/export_model.lua similarity index 76% rename from export_model.lua rename to tools/export_model.lua index bb91a0d..b89e4b4 100644 --- a/export_model.lua +++ b/tools/export_model.lua @@ -1,6 +1,7 @@ -- adapted from https://github.com/marcan/cl-waifu2x -require './lib/portable' -require './lib/LeakyReLU' +local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)() +package.path = path.join(path.dirname(__FILE__), "..", "lib", "?.lua;") .. package.path +require 'w2nn' local cjson = require "cjson" local model = torch.load(arg[1], "ascii") diff --git a/train.lua b/train.lua index 22b817f..43fb63d 100644 --- a/train.lua +++ b/train.lua @@ -1,17 +1,18 @@ -require './lib/portable' -require './lib/mynn' +local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)() +package.path = path.join(path.dirname(__FILE__), "lib", "?.lua;") .. package.path require 'optim' require 'xlua' require 'pl' -require 'snappy' -local settings = require './lib/settings' -local srcnn = require './lib/srcnn' -local minibatch_adam = require './lib/minibatch_adam' -local iproc = require './lib/iproc' -local reconstruct = require './lib/reconstruct' -local pairwise_transform = require './lib/pairwise_transform' -local image_loader = require './lib/image_loader' +require 'w2nn' +local settings = require 'settings' +local srcnn = require 'srcnn' +local minibatch_adam = require 'minibatch_adam' +local iproc = require 'iproc' +local reconstruct = require 'reconstruct' +local compression = require 'compression' +local pairwise_transform = require 'pairwise_transform' +local image_loader = require 'image_loader' local function save_test_scale(model, rgb, file) local up = reconstruct.scale(model, settings.scale, rgb) @@ -73,17 +74,13 @@ local function create_criterion(model) weight[1]:fill(0.299 * 3) -- R weight[2]:fill(0.587 * 3) -- G weight[3]:fill(0.114 * 3) -- B - return mynn.RGBWeightedMSECriterion(weight):cuda() + return w2nn.WeightedMSECriterion(weight):cuda() else return nn.MSECriterion():cuda() end end local function transformer(x, is_validation, n, offset) - local size = x[1] - local dec = snappy.decompress(x[2]:string()) - x = torch.ByteTensor(size[1], size[2], size[3]) - x:storage():string(dec) - + x = compression.decompress(x) n = n or settings.batch_size; if is_validation == nil then is_validation = false end local color_noise = nil diff --git a/waifu2x.lua b/waifu2x.lua index 3548302..ea1bbe5 100644 --- a/waifu2x.lua +++ b/waifu2x.lua @@ -1,11 +1,11 @@ -require './lib/portable' +local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)() +package.path = path.join(path.dirname(__FILE__), "lib", "?.lua;") .. package.path require 'sys' require 'pl' -require './lib/mynn' - -local iproc = require './lib/iproc' -local reconstruct = require './lib/reconstruct' -local image_loader = require './lib/image_loader' +require 'w2nn' +local iproc = require 'iproc' +local reconstruct = require 'reconstruct' +local image_loader = require 'image_loader' torch.setdefaulttensortype('torch.FloatTensor') @@ -111,8 +111,12 @@ local function waifu2x() cmd:option("-noise_level", 1, '(1|2)') cmd:option("-crop_size", 128, 'patch size per process') cmd:option("-resume", 0, "skip existing files (0|1)") - + cmd:option("-thread", -1, "number of CPU threads") + local opt = cmd:parse(arg) + if opt.thread > 0 then + torch.setnumthreads(opt.thread) + end if string.len(opt.l) == 0 then convert_image(opt) else diff --git a/web.lua b/web.lua index e2a06a4..9d1fa00 100644 --- a/web.lua +++ b/web.lua @@ -1,11 +1,16 @@ +local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)() +package.path = path.join(path.dirname(__FILE__), "lib", "?.lua;") .. package.path _G.TURBO_SSL = true + +require 'pl' +require 'w2nn' local turbo = require 'turbo' local uuid = require 'uuid' local ffi = require 'ffi' local md5 = require 'md5' -require 'pl' -require 'lib.portable' -require 'lib.mynn' +local iproc = require 'iproc' +local reconstruct = require 'reconstruct' +local image_loader = require 'image_loader' local cmd = torch.CmdLine() cmd:text() @@ -13,18 +18,15 @@ cmd:text("waifu2x-api") cmd:text("Options:") cmd:option("-port", 8812, 'listen port') cmd:option("-gpu", 1, 'Device ID') -cmd:option("-core", 2, 'number of CPU cores') +cmd:option("-thread", -1, 'number of CPU threads') local opt = cmd:parse(arg) cutorch.setDevice(opt.gpu) torch.setdefaulttensortype('torch.FloatTensor') -torch.setnumthreads(opt.core) - -local iproc = require './lib/iproc' -local reconstruct = require './lib/reconstruct' -local image_loader = require './lib/image_loader' - -local MODEL_DIR = "./models/anime_style_art_rgb3" +if opt.thread > 0 then + torch.setnumthreads(opt.thread) +end +local MODEL_DIR = "./models/anime_style_art_rgb" local noise1_model = torch.load(path.join(MODEL_DIR, "noise1_model.t7"), "ascii") local noise2_model = torch.load(path.join(MODEL_DIR, "noise2_model.t7"), "ascii") local scale20_model = torch.load(path.join(MODEL_DIR, "scale2.0x_model.t7"), "ascii")