
Add -batch_size option to waifu2x.lua/web.lua

nagadomi 2016-06-09 14:03:18 +09:00
parent 0b949c05a7
commit afac4b52ab
3 changed files with 67 additions and 43 deletions
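
The new option plugs into the existing command lines; a quick usage sketch (the -i/-o arguments and the file names below are the usual waifu2x I/O options and placeholders, not part of this diff):

    th waifu2x.lua -m noise_scale -noise_level 1 -i input.png -o output.png -crop_size 128 -batch_size 8
    th web.lua -gpu 1 -crop_size 128 -batch_size 4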

lib/reconstruct.lua

@@ -2,7 +2,8 @@ require 'image'
 local iproc = require 'iproc'
 local srcnn = require 'srcnn'
-local function reconstruct_nn(model, x, inner_scale, offset, block_size)
+local function reconstruct_nn(model, x, inner_scale, offset, block_size, batch_size)
+   batch_size = batch_size or 1
    if x:dim() == 2 then
       x = x:reshape(1, x:size(1), x:size(2))
    end
@@ -12,24 +13,46 @@ local function reconstruct_nn(model, x, inner_scale, offset, block_size)
    local output_block_size = block_size
    local output_size = output_block_size - offset * 2
    local output_size_in_input = input_block_size - math.ceil(offset / inner_scale) * 2
-   local input = torch.CudaTensor(1, ch, input_block_size, input_block_size)
+   local input_indexes = {}
+   local output_indexes = {}
    for i = 1, x:size(2), output_size_in_input do
       for j = 1, x:size(3), output_size_in_input do
          if i + input_block_size - 1 <= x:size(2) and j + input_block_size - 1 <= x:size(3) then
             local index = {{},
                            {i, i + input_block_size - 1},
                            {j, j + input_block_size - 1}}
-            input:copy(x[index])
-            local output = model:forward(input)
-            output = output:view(ch, output_size, output_size)
             local ii = (i - 1) * inner_scale + 1
             local jj = (j - 1) * inner_scale + 1
             local output_index = {{}, { ii , ii + output_size - 1 },
                                       { jj, jj + output_size - 1}}
-            new_x[output_index]:copy(output)
+            table.insert(input_indexes, index)
+            table.insert(output_indexes, output_index)
          end
       end
    end
+   local input = torch.Tensor(batch_size, ch, input_block_size, input_block_size)
+   local input_cuda = torch.CudaTensor(batch_size, ch, input_block_size, input_block_size)
+   for i = 1, #input_indexes, batch_size do
+      local c = 0
+      local output
+      for j = 0, batch_size - 1 do
+         if i + j > #input_indexes then
+            break
+         end
+         input[j+1]:copy(x[input_indexes[i + j]])
+         c = c + 1
+      end
+      input_cuda:copy(input)
+      if c == batch_size then
+         output = model:forward(input_cuda)
+      else
+         output = model:forward(input_cuda:narrow(1, 1, c))
+      end
+      --output = output:view(batch_size, ch, output_size, output_size)
+      for j = 0, c - 1 do
+         new_x[output_indexes[i + j]]:copy(output[j+1])
+      end
+   end
    return new_x
 end
 local reconstruct = {}
@@ -72,11 +95,11 @@ local function padding_params(x, model, block_size)
    p.pad_w2 = (w - input_offset) - p.x_w
    return p
 end
-function reconstruct.image_y(model, x, offset, block_size)
+function reconstruct.image_y(model, x, offset, block_size, batch_size)
    block_size = block_size or 128
    local p = padding_params(x, model, block_size)
    x = image.rgb2yuv(iproc.padding(x, p.pad_w1, p.pad_w2, p.pad_h1, p.pad_h2))
-   local y = reconstruct_nn(model, x[1], p.inner_scale, offset, block_size)
+   local y = reconstruct_nn(model, x[1], p.inner_scale, offset, block_size, batch_size)
    x = iproc.crop(x, p.pad_w1, p.pad_w2, p.pad_w1 + p.x_w, p.pad_w2 + p.x_h)
    y = iproc.crop(y, 0, 0, p.x_w, p.x_h)
    y[torch.lt(y, 0)] = 0
@@ -91,7 +114,7 @@ function reconstruct.image_y(model, x, offset, block_size)
    return output
 end
-function reconstruct.scale_y(model, scale, x, offset, block_size, upsampling_filter)
+function reconstruct.scale_y(model, scale, x, offset, block_size, batch_size, upsampling_filter)
    upsampling_filter = upsampling_filter or "Box"
    block_size = block_size or 128
    local x_lanczos
@@ -107,7 +130,7 @@ function reconstruct.scale_y(model, scale, x, offset, block_size, upsampling_filter)
    end
    x = image.rgb2yuv(iproc.padding(x, p.pad_w1, p.pad_w2, p.pad_h1, p.pad_h2))
    x_lanczos = image.rgb2yuv(x_lanczos)
-   local y = reconstruct_nn(model, x[1], p.inner_scale, offset, block_size)
+   local y = reconstruct_nn(model, x[1], p.inner_scale, offset, block_size, batch_size)
    y = iproc.crop(y, 0, 0, p.x_w * p.inner_scale, p.x_h * p.inner_scale)
    y[torch.lt(y, 0)] = 0
    y[torch.gt(y, 1)] = 1
@@ -122,14 +145,14 @@ function reconstruct.scale_y(model, scale, x, offset, block_size, upsampling_filter)
    return output
 end
-function reconstruct.image_rgb(model, x, offset, block_size)
+function reconstruct.image_rgb(model, x, offset, block_size, batch_size)
    block_size = block_size or 128
    local p = padding_params(x, model, block_size)
    x = iproc.padding(x, p.pad_w1, p.pad_w2, p.pad_h1, p.pad_h2)
    if p.x_w * p.x_h > 2048*2048 then
       collectgarbage()
    end
-   local y = reconstruct_nn(model, x, p.inner_scale, offset, block_size)
+   local y = reconstruct_nn(model, x, p.inner_scale, offset, block_size, batch_size)
    local output = iproc.crop(y, 0, 0, p.x_w, p.x_h)
    output[torch.lt(output, 0)] = 0
    output[torch.gt(output, 1)] = 1
@@ -139,7 +162,7 @@ function reconstruct.image_rgb(model, x, offset, block_size)
    return output
 end
-function reconstruct.scale_rgb(model, scale, x, offset, block_size, upsampling_filter)
+function reconstruct.scale_rgb(model, scale, x, offset, block_size, batch_size, upsampling_filter)
    upsampling_filter = upsampling_filter or "Box"
    block_size = block_size or 128
    if not reconstruct.has_resize(model) then
@@ -151,7 +174,7 @@ function reconstruct.scale_rgb(model, scale, x, offset, block_size, upsampling_filter)
       collectgarbage()
    end
    local y
-   y = reconstruct_nn(model, x, p.inner_scale, offset, block_size)
+   y = reconstruct_nn(model, x, p.inner_scale, offset, block_size, batch_size)
    local output = iproc.crop(y, 0, 0, p.x_w * p.inner_scale, p.x_h * p.inner_scale)
    output[torch.lt(output, 0)] = 0
    output[torch.gt(output, 1)] = 1
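
The change above turns reconstruct_nn from a one-block-at-a-time forward pass into two phases: it first records every block's input and output index, then runs model:forward on batch_size blocks at a time, narrowing the staging buffer for the final partial batch. A standalone sketch of that batching pattern (forward_in_batches, the identity stand-in model and the 4x4 blocks are illustrative, not code from this repository):

    -- Collect items first, then forward them in groups of batch_size,
    -- narrowing the reusable buffer when the item count is not a
    -- multiple of batch_size (the same tail handling as reconstruct_nn).
    require 'torch'

    local function forward_in_batches(model, blocks, batch_size)
       local n = #blocks
       local outputs = {}
       -- reusable staging buffer, sized for a full batch
       local input = torch.Tensor(batch_size, blocks[1]:size(1), blocks[1]:size(2))
       for i = 1, n, batch_size do
          local c = 0
          for j = 0, batch_size - 1 do
             if i + j > n then break end
             input[j + 1]:copy(blocks[i + j])
             c = c + 1
          end
          -- forward a full batch, or a narrowed view for the tail
          local batch = (c == batch_size) and input or input:narrow(1, 1, c)
          local output = model(batch)
          for j = 1, c do
             table.insert(outputs, output[j]:clone())
          end
       end
       return outputs
    end

    -- usage: 7 blocks with batch_size 3 -> forward calls of size 3, 3, 1
    local blocks = {}
    for k = 1, 7 do table.insert(blocks, torch.rand(4, 4)) end
    local identity = function(t) return t end -- stand-in for model:forward
    print(#forward_in_batches(identity, blocks, 3)) -- 7

With the default batch_size of 1 each forward call still sees a single block, so existing invocations behave as before.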

waifu2x.lua

@@ -44,13 +44,13 @@ local function convert_image(opt)
    local scale_f, image_f
    if opt.tta == 1 then
-      scale_f = function(model, scale, x, block_size, upsampling_filter)
+      scale_f = function(model, scale, x, block_size, batch_size, upsampling_filter)
          return reconstruct.scale_tta(model, opt.tta_level,
-                                      scale, x, block_size, upsampling_filter)
+                                      scale, x, block_size, batch_size, upsampling_filter)
       end
-      image_f = function(model, x, block_size)
+      image_f = function(model, x, block_size, batch_size)
          return reconstruct.image_tta(model, opt.tta_level,
-                                      x, block_size)
+                                      x, block_size, batch_size)
       end
    else
       scale_f = reconstruct.scale
@@ -64,7 +64,7 @@ local function convert_image(opt)
          error("Load Error: " .. model_path)
       end
       local t = sys.clock()
-      new_x = image_f(model, x, opt.crop_size)
+      new_x = image_f(model, x, opt.crop_size, opt.batch_size)
       new_x = alpha_util.composite(new_x, alpha)
       print(opt.o .. ": " .. (sys.clock() - t) .. " sec")
    elseif opt.m == "scale" then
@@ -75,7 +75,7 @@ local function convert_image(opt)
       end
       local t = sys.clock()
       x = alpha_util.make_border(x, alpha, reconstruct.offset_size(model))
-      new_x = scale_f(model, opt.scale, x, opt.crop_size, opt.upsampling_filter)
+      new_x = scale_f(model, opt.scale, x, opt.crop_size, opt.batch_size, opt.upsampling_filter)
       new_x = alpha_util.composite(new_x, alpha, model)
       print(opt.o .. ": " .. (sys.clock() - t) .. " sec")
    elseif opt.m == "noise_scale" then
@@ -92,7 +92,7 @@ local function convert_image(opt)
          end
          local t = sys.clock()
          x = alpha_util.make_border(x, alpha, reconstruct.offset_size(scale_model))
-         new_x = scale_f(model, opt.scale, x, opt.crop_size, opt.upsampling_filter)
+         new_x = scale_f(model, opt.scale, x, opt.crop_size, opt.batch_size, opt.upsampling_filter)
          new_x = alpha_util.composite(new_x, alpha, scale_model)
          print(opt.o .. ": " .. (sys.clock() - t) .. " sec")
       else
@@ -109,8 +109,8 @@ local function convert_image(opt)
          end
          local t = sys.clock()
          x = alpha_util.make_border(x, alpha, reconstruct.offset_size(scale_model))
-         x = image_f(noise_model, x, opt.crop_size)
-         new_x = scale_f(scale_model, opt.scale, x, opt.crop_size, opt.upsampling_filter)
+         x = image_f(noise_model, x, opt.crop_size, opt.batch_size)
+         new_x = scale_f(scale_model, opt.scale, x, opt.crop_size, opt.batch_size, opt.upsampling_filter)
          new_x = alpha_util.composite(new_x, alpha, scale_model)
          print(opt.o .. ": " .. (sys.clock() - t) .. " sec")
       end
@@ -125,13 +125,13 @@ local function convert_frames(opt)
    local noise_model = {}
    local scale_f, image_f
    if opt.tta == 1 then
-      scale_f = function(model, scale, x, block_size, upsampling_filter)
+      scale_f = function(model, scale, x, block_size, batch_size, upsampling_filter)
          return reconstruct.scale_tta(model, opt.tta_level,
-                                      scale, x, block_size, upsampling_filter)
+                                      scale, x, block_size, batch_size, upsampling_filter)
       end
-      image_f = function(model, x, block_size)
+      image_f = function(model, x, block_size, batch_size)
          return reconstruct.image_tta(model, opt.tta_level,
-                                      x, block_size)
+                                      x, block_size, batch_size)
       end
    else
       scale_f = reconstruct.scale
@@ -191,19 +191,19 @@ local function convert_frames(opt)
       local alpha = meta.alpha
       local new_x = nil
       if opt.m == "noise" then
-         new_x = image_f(noise_model[opt.noise_level], x, opt.crop_size)
+         new_x = image_f(noise_model[opt.noise_level], x, opt.crop_size, opt.batch_size)
          new_x = alpha_util.composite(new_x, alpha)
       elseif opt.m == "scale" then
          x = alpha_util.make_border(x, alpha, reconstruct.offset_size(scale_model))
-         new_x = scale_f(scale_model, opt.scale, x, opt.crop_size, opt.upsampling_filter)
+         new_x = scale_f(scale_model, opt.scale, x, opt.crop_size, opt.batch_size, opt.upsampling_filter)
          new_x = alpha_util.composite(new_x, alpha, scale_model)
       elseif opt.m == "noise_scale" then
         x = alpha_util.make_border(x, alpha, reconstruct.offset_size(scale_model))
         if noise_scale_model[opt.noise_level] then
-            new_x = scale_f(noise_scale_model[opt.noise_level], opt.scale, x, opt.crop_size, upsampling_filter)
+            new_x = scale_f(noise_scale_model[opt.noise_level], opt.scale, x, opt.crop_size, opt.batch_size, upsampling_filter)
         else
-            x = image_f(noise_model[opt.noise_level], x, opt.crop_size)
-            new_x = scale_f(scale_model, opt.scale, x, opt.crop_size, upsampling_filter)
+            x = image_f(noise_model[opt.noise_level], x, opt.crop_size, opt.batch_size)
+            new_x = scale_f(scale_model, opt.scale, x, opt.crop_size, opt.batch_size, upsampling_filter)
         end
         new_x = alpha_util.composite(new_x, alpha, scale_model)
      else
@@ -220,7 +220,6 @@
       end
    end
 end
 local function waifu2x()
    local cmd = torch.CmdLine()
    cmd:text()
@@ -235,6 +234,7 @@ local function waifu2x()
    cmd:option("-m", "noise_scale", 'method (noise|scale|noise_scale)')
    cmd:option("-noise_level", 1, '(1|2|3)')
    cmd:option("-crop_size", 128, 'patch size per process')
+   cmd:option("-batch_size", 1, 'batch_size')
    cmd:option("-resume", 0, "skip existing files (0|1)")
    cmd:option("-thread", -1, "number of CPU threads")
    cmd:option("-tta", 0, '8x slower and slightly high quality (0|1)')

web.lua

@@ -27,6 +27,7 @@ cmd:option("-port", 8812, 'listen port')
 cmd:option("-gpu", 1, 'Device ID')
 cmd:option("-upsampling_filter", "Box", 'Upsampling filter (for dev)')
 cmd:option("-crop_size", 128, 'patch size per process')
+cmd:option("-batch_size", 1, 'batch size')
 cmd:option("-thread", -1, 'number of CPU threads')
 local opt = cmd:parse(arg)
 cutorch.setDevice(opt.gpu)
@@ -148,23 +149,23 @@ local function convert(x, meta, options)
       end
       if options.method == "scale" then
          x = reconstruct.scale(art_scale2_model, 2.0, x,
-                               opt.crop_size, opt.upsampling_filter)
+                               opt.crop_size, opt.batch_size, opt.upsampling_filter)
          if alpha then
             if not (alpha:size(2) == x:size(2) and alpha:size(3) == x:size(3)) then
                alpha = reconstruct.scale(art_scale2_model, 2.0, alpha,
-                                         opt.crop_size, opt.upsampling_filter)
+                                         opt.crop_size, opt.batch_size, opt.upsampling_filter)
                image_loader.save_png(alpha_cache_file, alpha)
             end
          end
          cleanup_model(art_scale2_model)
       elseif options.method == "noise1" then
-         x = reconstruct.image(art_noise1_model, x)
+         x = reconstruct.image(art_noise1_model, x, opt.crop_size, opt.batch_size)
          cleanup_model(art_noise1_model)
       elseif options.method == "noise2" then
-         x = reconstruct.image(art_noise2_model, x)
+         x = reconstruct.image(art_noise2_model, x, opt.crop_size, opt.batch_size)
          cleanup_model(art_noise2_model)
       elseif options.method == "noise3" then
-         x = reconstruct.image(art_noise3_model, x)
+         x = reconstruct.image(art_noise3_model, x, opt.crop_size, opt.batch_size)
          cleanup_model(art_noise3_model)
       end
    else -- photo
@@ -173,23 +174,23 @@ local function convert(x, meta, options)
       end
       if options.method == "scale" then
          x = reconstruct.scale(photo_scale2_model, 2.0, x,
-                               opt.crop_size, opt.upsampling_filter)
+                               opt.crop_size, opt.batch_size, opt.upsampling_filter)
          if alpha then
             if not (alpha:size(2) == x:size(2) and alpha:size(3) == x:size(3)) then
                alpha = reconstruct.scale(photo_scale2_model, 2.0, alpha,
-                                         opt.crop_size, opt.upsampling_filter)
+                                         opt.crop_size, opt.batch_size, opt.upsampling_filter)
                image_loader.save_png(alpha_cache_file, alpha)
             end
          end
          cleanup_model(photo_scale2_model)
       elseif options.method == "noise1" then
-         x = reconstruct.image(photo_noise1_model, x)
+         x = reconstruct.image(photo_noise1_model, x, opt.crop_size, opt.batch_size)
          cleanup_model(photo_noise1_model)
       elseif options.method == "noise2" then
-         x = reconstruct.image(photo_noise2_model, x)
+         x = reconstruct.image(photo_noise2_model, x, opt.crop_size, opt.batch_size)
          cleanup_model(photo_noise2_model)
       elseif options.method == "noise3" then
-         x = reconstruct.image(photo_noise3_model, x)
+         x = reconstruct.image(photo_noise3_model, x, opt.crop_size, opt.batch_size)
          cleanup_model(photo_noise3_model)
       end
    end
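
When picking a value for -batch_size (for either waifu2x.lua or web.lua), the main cost is the memory behind the batched buffers allocated in reconstruct_nn above: batch_size x ch x input_block_size x input_block_size floats, held once on the CPU and once on the GPU, plus the network's activations, which also grow roughly linearly with the batch. A back-of-the-envelope sketch (the 3-channel, 128-pixel block is an assumed example, not a value taken from this commit):

    -- bytes held by one staging buffer for float32 blocks
    local function buffer_bytes(batch_size, ch, block_size)
       return batch_size * ch * block_size * block_size * 4
    end
    print(buffer_bytes(1, 3, 128) / 1024) -- 192 (KiB)
    print(buffer_bytes(8, 3, 128) / 1024) -- 1536 (KiB)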