From afac4b52ab687f8d26dd9d84e773383d458acfca Mon Sep 17 00:00:00 2001
From: nagadomi
Date: Thu, 9 Jun 2016 14:03:18 +0900
Subject: [PATCH] Add -batch_size option to waifu2x.lua/web.lua

---
Review notes (free text in this position is ignored by git am):

What the patch does
  * reconstruct_nn() first collects every input/output block index pair,
    then copies up to batch_size blocks into one tensor and runs a single
    model:forward() per batch. A final partial batch is forwarded through
    input_cuda:narrow(1, 1, c). batch_size defaults to 1.
  * reconstruct.image_y/image_rgb gain a trailing batch_size parameter.
    reconstruct.scale_y/scale_rgb gain batch_size BEFORE upsampling_filter,
    so positional callers that pass upsampling_filter must be updated.
  * waifu2x.lua and web.lua gain a -batch_size option (default 1) and pass
    opt.batch_size at every image_f/scale_f/reconstruct.* call site.

Fixes applied to added lines during review (waifu2x.lua only)
  * convert_image, TTA scale_f closure: the parameter was misspelled
    "batupsampling_filter", so "upsampling_filter" in the body resolved to
    an undefined global and the caller's filter was dropped.
  * convert_image, "scale" branch: opt.batch_size was passed twice, which
    shifted the batch size into the upsampling_filter slot.
  * convert_frames, "noise_scale" branch: bare "upsampling_filter" replaced
    with opt.upsampling_filter to match the "scale" branch. No local of
    that name is visible in the hunk; confirm against the full file.

To verify
  * This patch does not modify reconstruct.image, reconstruct.scale,
    reconstruct.image_tta or reconstruct.scale_tta. Confirm that they
    accept and forward batch_size in the same argument order used here.
  * Line breaks, blank context lines and indentation were reconstructed
    from a whitespace-collapsed copy of this patch. Hunk line counts and
    the diffstat are unchanged, but the whitespace may differ from the
    tree (try "git apply --ignore-whitespace"). The post-image blob id of
    waifu2x.lua on its "index" line predates the fixes above.

 lib/reconstruct.lua | 51 ++++++++++++++++++++++++++++++++-------------
 waifu2x.lua         | 38 ++++++++++++++++-----------------
 web.lua             | 21 ++++++++++---------
 3 files changed, 67 insertions(+), 43 deletions(-)

diff --git a/lib/reconstruct.lua b/lib/reconstruct.lua
index 7ff406c..d93b9da 100644
--- a/lib/reconstruct.lua
+++ b/lib/reconstruct.lua
@@ -2,7 +2,8 @@ require 'image'
 local iproc = require 'iproc'
 local srcnn = require 'srcnn'
 
-local function reconstruct_nn(model, x, inner_scale, offset, block_size)
+local function reconstruct_nn(model, x, inner_scale, offset, block_size, batch_size)
+   batch_size = batch_size or 1
    if x:dim() == 2 then
       x = x:reshape(1, x:size(1), x:size(2))
    end
@@ -12,24 +13,46 @@ local function reconstruct_nn(model, x, inner_scale, offset, block_size)
    local output_block_size = block_size
    local output_size = output_block_size - offset * 2
    local output_size_in_input = input_block_size - math.ceil(offset / inner_scale) * 2
-   local input = torch.CudaTensor(1, ch, input_block_size, input_block_size)
+   local input_indexes = {}
+   local output_indexes = {}
    for i = 1, x:size(2), output_size_in_input do
       for j = 1, x:size(3), output_size_in_input do
          if i + input_block_size - 1 <= x:size(2) and j + input_block_size - 1 <= x:size(3) then
             local index = {{},
                            {i, i + input_block_size - 1},
                            {j, j + input_block_size - 1}}
-            input:copy(x[index])
-            local output = model:forward(input)
-            output = output:view(ch, output_size, output_size)
             local ii = (i - 1) * inner_scale + 1
             local jj = (j - 1) * inner_scale + 1
             local output_index = {{}, { ii , ii + output_size - 1 },
                { jj, jj + output_size - 1}}
-            new_x[output_index]:copy(output)
+            table.insert(input_indexes, index)
+            table.insert(output_indexes, output_index)
          end
       end
    end
+   local input = torch.Tensor(batch_size, ch, input_block_size, input_block_size)
+   local input_cuda = torch.CudaTensor(batch_size, ch, input_block_size, input_block_size)
+   for i = 1, #input_indexes, batch_size do
+      local c = 0
+      local output
+      for j = 0, batch_size - 1 do
+         if i + j > #input_indexes then
+            break
+         end
+         input[j+1]:copy(x[input_indexes[i + j]])
+         c = c + 1
+      end
+      input_cuda:copy(input)
+      if c == batch_size then
+         output = model:forward(input_cuda)
+      else
+         output = model:forward(input_cuda:narrow(1, 1, c))
+      end
+      --output = output:view(batch_size, ch, output_size, output_size)
+      for j = 0, c - 1 do
+         new_x[output_indexes[i + j]]:copy(output[j+1])
+      end
+   end
    return new_x
 end
 local reconstruct = {}
@@ -72,11 +95,11 @@ local function padding_params(x, model, block_size)
    p.pad_w2 = (w - input_offset) - p.x_w
    return p
 end
-function reconstruct.image_y(model, x, offset, block_size)
+function reconstruct.image_y(model, x, offset, block_size, batch_size)
    block_size = block_size or 128
    local p = padding_params(x, model, block_size)
    x = image.rgb2yuv(iproc.padding(x, p.pad_w1, p.pad_w2, p.pad_h1, p.pad_h2))
-   local y = reconstruct_nn(model, x[1], p.inner_scale, offset, block_size)
+   local y = reconstruct_nn(model, x[1], p.inner_scale, offset, block_size, batch_size)
    x = iproc.crop(x, p.pad_w1, p.pad_w2, p.pad_w1 + p.x_w, p.pad_w2 + p.x_h)
    y = iproc.crop(y, 0, 0, p.x_w, p.x_h)
    y[torch.lt(y, 0)] = 0
@@ -91,7 +114,7 @@ function reconstruct.image_y(model, x, offset, block_size)
 
    return output
 end
-function reconstruct.scale_y(model, scale, x, offset, block_size, upsampling_filter)
+function reconstruct.scale_y(model, scale, x, offset, block_size, batch_size, upsampling_filter)
    upsampling_filter = upsampling_filter or "Box"
    block_size = block_size or 128
    local x_lanczos
@@ -107,7 +130,7 @@ function reconstruct.scale_y(model, scale, x, offset, block_size, upsampling_fil
    end
    x = image.rgb2yuv(iproc.padding(x, p.pad_w1, p.pad_w2, p.pad_h1, p.pad_h2))
    x_lanczos = image.rgb2yuv(x_lanczos)
-   local y = reconstruct_nn(model, x[1], p.inner_scale, offset, block_size)
+   local y = reconstruct_nn(model, x[1], p.inner_scale, offset, block_size, batch_size)
    y = iproc.crop(y, 0, 0, p.x_w * p.inner_scale, p.x_h * p.inner_scale)
    y[torch.lt(y, 0)] = 0
    y[torch.gt(y, 1)] = 1
@@ -122,14 +145,14 @@ function reconstruct.scale_y(model, scale, x, offset, block_size, upsampling_fil
 
    return output
 end
-function reconstruct.image_rgb(model, x, offset, block_size)
+function reconstruct.image_rgb(model, x, offset, block_size, batch_size)
    block_size = block_size or 128
    local p = padding_params(x, model, block_size)
    x = iproc.padding(x, p.pad_w1, p.pad_w2, p.pad_h1, p.pad_h2)
    if p.x_w * p.x_h > 2048*2048 then
       collectgarbage()
    end
-   local y = reconstruct_nn(model, x, p.inner_scale, offset, block_size)
+   local y = reconstruct_nn(model, x, p.inner_scale, offset, block_size, batch_size)
    local output = iproc.crop(y, 0, 0, p.x_w, p.x_h)
    output[torch.lt(output, 0)] = 0
    output[torch.gt(output, 1)] = 1
@@ -139,7 +162,7 @@ function reconstruct.image_rgb(model, x, offset, block_size)
 
    return output
 end
-function reconstruct.scale_rgb(model, scale, x, offset, block_size, upsampling_filter)
+function reconstruct.scale_rgb(model, scale, x, offset, block_size, batch_size, upsampling_filter)
    upsampling_filter = upsampling_filter or "Box"
    block_size = block_size or 128
    if not reconstruct.has_resize(model) then
@@ -151,7 +174,7 @@ function reconstruct.scale_rgb(model, scale, x, offset, block_size, upsampling_f
       collectgarbage()
    end
    local y
-   y = reconstruct_nn(model, x, p.inner_scale, offset, block_size)
+   y = reconstruct_nn(model, x, p.inner_scale, offset, block_size, batch_size)
    local output = iproc.crop(y, 0, 0, p.x_w * p.inner_scale, p.x_h * p.inner_scale)
    output[torch.lt(output, 0)] = 0
    output[torch.gt(output, 1)] = 1
diff --git a/waifu2x.lua b/waifu2x.lua
index fb58079..c3bb77f 100644
--- a/waifu2x.lua
+++ b/waifu2x.lua
@@ -44,13 +44,13 @@ local function convert_image(opt)
    local scale_f, image_f
 
    if opt.tta == 1 then
-      scale_f = function(model, scale, x, block_size, upsampling_filter)
+      scale_f = function(model, scale, x, block_size, batch_size, upsampling_filter)
          return reconstruct.scale_tta(model, opt.tta_level,
-                                      scale, x, block_size, upsampling_filter)
+                                      scale, x, block_size, batch_size, upsampling_filter)
       end
-      image_f = function(model, x, block_size)
+      image_f = function(model, x, block_size, batch_size)
          return reconstruct.image_tta(model, opt.tta_level,
-                                      x, block_size)
+                                      x, block_size, batch_size)
       end
    else
       scale_f = reconstruct.scale
@@ -64,7 +64,7 @@ local function convert_image(opt)
          error("Load Error: " .. model_path)
       end
       local t = sys.clock()
-      new_x = image_f(model, x, opt.crop_size)
+      new_x = image_f(model, x, opt.crop_size, opt.batch_size)
       new_x = alpha_util.composite(new_x, alpha)
       print(opt.o .. ": " .. (sys.clock() - t) .. " sec")
    elseif opt.m == "scale" then
@@ -75,7 +75,7 @@ local function convert_image(opt)
       end
       local t = sys.clock()
       x = alpha_util.make_border(x, alpha, reconstruct.offset_size(model))
-      new_x = scale_f(model, opt.scale, x, opt.crop_size, opt.upsampling_filter)
+      new_x = scale_f(model, opt.scale, x, opt.crop_size, opt.batch_size, opt.upsampling_filter)
       new_x = alpha_util.composite(new_x, alpha, model)
       print(opt.o .. ": " .. (sys.clock() - t) .. " sec")
    elseif opt.m == "noise_scale" then
@@ -92,7 +92,7 @@ local function convert_image(opt)
          end
          local t = sys.clock()
          x = alpha_util.make_border(x, alpha, reconstruct.offset_size(scale_model))
-         new_x = scale_f(model, opt.scale, x, opt.crop_size, opt.upsampling_filter)
+         new_x = scale_f(model, opt.scale, x, opt.crop_size, opt.batch_size, opt.upsampling_filter)
          new_x = alpha_util.composite(new_x, alpha, scale_model)
          print(opt.o .. ": " .. (sys.clock() - t) .. " sec")
       else
@@ -109,8 +109,8 @@ local function convert_image(opt)
          end
          local t = sys.clock()
          x = alpha_util.make_border(x, alpha, reconstruct.offset_size(scale_model))
-         x = image_f(noise_model, x, opt.crop_size)
-         new_x = scale_f(scale_model, opt.scale, x, opt.crop_size, opt.upsampling_filter)
+         x = image_f(noise_model, x, opt.crop_size, opt.batch_size)
+         new_x = scale_f(scale_model, opt.scale, x, opt.crop_size, opt.batch_size, opt.upsampling_filter)
          new_x = alpha_util.composite(new_x, alpha, scale_model)
          print(opt.o .. ": " .. (sys.clock() - t) .. " sec")
       end
@@ -125,13 +125,13 @@ local function convert_frames(opt)
    local noise_model = {}
    local scale_f, image_f
    if opt.tta == 1 then
-      scale_f = function(model, scale, x, block_size, upsampling_filter)
+      scale_f = function(model, scale, x, block_size, batch_size, upsampling_filter)
          return reconstruct.scale_tta(model, opt.tta_level,
-                                      scale, x, block_size, upsampling_filter)
+                                      scale, x, block_size, batch_size, upsampling_filter)
       end
-      image_f = function(model, x, block_size)
+      image_f = function(model, x, block_size, batch_size)
          return reconstruct.image_tta(model, opt.tta_level,
-                                      x, block_size)
+                                      x, block_size, batch_size)
       end
    else
       scale_f = reconstruct.scale
@@ -191,19 +191,19 @@ local function convert_frames(opt)
          local alpha = meta.alpha
          local new_x = nil
          if opt.m == "noise" then
-            new_x = image_f(noise_model[opt.noise_level], x, opt.crop_size)
+            new_x = image_f(noise_model[opt.noise_level], x, opt.crop_size, opt.batch_size)
             new_x = alpha_util.composite(new_x, alpha)
          elseif opt.m == "scale" then
             x = alpha_util.make_border(x, alpha, reconstruct.offset_size(scale_model))
-            new_x = scale_f(scale_model, opt.scale, x, opt.crop_size, opt.upsampling_filter)
+            new_x = scale_f(scale_model, opt.scale, x, opt.crop_size, opt.batch_size, opt.upsampling_filter)
             new_x = alpha_util.composite(new_x, alpha, scale_model)
          elseif opt.m == "noise_scale" then
             x = alpha_util.make_border(x, alpha, reconstruct.offset_size(scale_model))
             if noise_scale_model[opt.noise_level] then
-               new_x = scale_f(noise_scale_model[opt.noise_level], opt.scale, x, opt.crop_size, upsampling_filter)
+               new_x = scale_f(noise_scale_model[opt.noise_level], opt.scale, x, opt.crop_size, opt.batch_size, opt.upsampling_filter)
             else
-               x = image_f(noise_model[opt.noise_level], x, opt.crop_size)
-               new_x = scale_f(scale_model, opt.scale, x, opt.crop_size, upsampling_filter)
+               x = image_f(noise_model[opt.noise_level], x, opt.crop_size, opt.batch_size)
+               new_x = scale_f(scale_model, opt.scale, x, opt.crop_size, opt.batch_size, opt.upsampling_filter)
             end
             new_x = alpha_util.composite(new_x, alpha, scale_model)
          else
@@ -220,7 +220,6 @@ local function convert_frames(opt)
       end
    end
 end
-
 local function waifu2x()
    local cmd = torch.CmdLine()
    cmd:text()
@@ -235,6 +234,7 @@ local function waifu2x()
    cmd:option("-m", "noise_scale", 'method (noise|scale|noise_scale)')
    cmd:option("-noise_level", 1, '(1|2|3)')
    cmd:option("-crop_size", 128, 'patch size per process')
+   cmd:option("-batch_size", 1, 'batch_size')
    cmd:option("-resume", 0, "skip existing files (0|1)")
    cmd:option("-thread", -1, "number of CPU threads")
    cmd:option("-tta", 0, '8x slower and slightly high quality (0|1)')
diff --git a/web.lua b/web.lua
index c5aca15..96335a1 100644
--- a/web.lua
+++ b/web.lua
@@ -27,6 +27,7 @@ cmd:option("-port", 8812, 'listen port')
 cmd:option("-gpu", 1, 'Device ID')
 cmd:option("-upsampling_filter", "Box", 'Upsampling filter (for dev)')
 cmd:option("-crop_size", 128, 'patch size per process')
+cmd:option("-batch_size", 1, 'batch size')
 cmd:option("-thread", -1, 'number of CPU threads')
 local opt = cmd:parse(arg)
 cutorch.setDevice(opt.gpu)
@@ -148,23 +149,23 @@ local function convert(x, meta, options)
       end
       if options.method == "scale" then
          x = reconstruct.scale(art_scale2_model, 2.0, x,
-                               opt.crop_size, opt.upsampling_filter)
+                               opt.crop_size, opt.batch_size, opt.upsampling_filter)
          if alpha then
             if not (alpha:size(2) == x:size(2) and alpha:size(3) == x:size(3)) then
                alpha = reconstruct.scale(art_scale2_model, 2.0, alpha,
-                                         opt.crop_size, opt.upsampling_filter)
+                                         opt.crop_size, opt.batch_size, opt.upsampling_filter)
                image_loader.save_png(alpha_cache_file, alpha)
             end
          end
          cleanup_model(art_scale2_model)
       elseif options.method == "noise1" then
-         x = reconstruct.image(art_noise1_model, x)
+         x = reconstruct.image(art_noise1_model, x, opt.crop_size, opt.batch_size)
          cleanup_model(art_noise1_model)
       elseif options.method == "noise2" then
-         x = reconstruct.image(art_noise2_model, x)
+         x = reconstruct.image(art_noise2_model, x, opt.crop_size, opt.batch_size)
          cleanup_model(art_noise2_model)
       elseif options.method == "noise3" then
-         x = reconstruct.image(art_noise3_model, x)
+         x = reconstruct.image(art_noise3_model, x, opt.crop_size, opt.batch_size)
          cleanup_model(art_noise3_model)
       end
    else -- photo
@@ -173,23 +174,23 @@ local function convert(x, meta, options)
       end
       if options.method == "scale" then
          x = reconstruct.scale(photo_scale2_model, 2.0, x,
-                               opt.crop_size, opt.upsampling_filter)
+                               opt.crop_size, opt.batch_size, opt.upsampling_filter)
          if alpha then
             if not (alpha:size(2) == x:size(2) and alpha:size(3) == x:size(3)) then
                alpha = reconstruct.scale(photo_scale2_model, 2.0, alpha,
-                                         opt.crop_size, opt.upsampling_filter)
+                                         opt.crop_size, opt.batch_size, opt.upsampling_filter)
                image_loader.save_png(alpha_cache_file, alpha)
             end
          end
          cleanup_model(photo_scale2_model)
       elseif options.method == "noise1" then
-         x = reconstruct.image(photo_noise1_model, x)
+         x = reconstruct.image(photo_noise1_model, x, opt.crop_size, opt.batch_size)
          cleanup_model(photo_noise1_model)
       elseif options.method == "noise2" then
-         x = reconstruct.image(photo_noise2_model, x)
+         x = reconstruct.image(photo_noise2_model, x, opt.crop_size, opt.batch_size)
          cleanup_model(photo_noise2_model)
       elseif options.method == "noise3" then
-         x = reconstruct.image(photo_noise3_model, x)
+         x = reconstruct.image(photo_noise3_model, x, opt.crop_size, opt.batch_size)
          cleanup_model(photo_noise3_model)
       end
    end