diff --git a/lib/alpha_util.lua b/lib/alpha_util.lua new file mode 100644 index 0000000..9105f06 --- /dev/null +++ b/lib/alpha_util.lua @@ -0,0 +1,80 @@ +local w2nn = require 'w2nn' +local reconstruct = require 'reconstruct' +local image = require 'image' +local iproc = require 'iproc' +local gm = require 'graphicsmagick' + +alpha_util = {} +alpha_util.sum2d = nn.SpatialConvolutionMM(1, 1, 3, 3, 1, 1, 1, 1):cuda() +alpha_util.sum2d.weight:fill(1) +alpha_util.sum2d.bias:zero() + +function alpha_util.make_border(rgb, alpha, offset) + if not alpha then + return rgb + end + + local mask = alpha:clone() + mask[torch.gt(mask, 0.0)] = 1 + mask[torch.eq(mask, 0.0)] = 0 + local mask_nega = (mask - 1):abs():byte() + local eps = 1.0e-7 + + rgb = rgb:clone() + rgb[1][mask_nega] = 0 + rgb[2][mask_nega] = 0 + rgb[3][mask_nega] = 0 + + for i = 1, offset do + local mask_weight = alpha_util.sum2d:forward(mask:cuda()):float() + local border = rgb:clone() + for j = 1, 3 do + border[j]:copy(alpha_util.sum2d:forward(rgb[j]:reshape(1, rgb:size(2), rgb:size(3)):cuda())) + border[j]:cdiv((mask_weight + eps)) + rgb[j][mask_nega] = border[j][mask_nega] + end + mask = mask_weight:clone() + mask[torch.gt(mask_weight, 0.0)] = 1 + mask_nega = (mask - 1):abs():byte() + end + rgb[torch.gt(rgb, 1.0)] = 1.0 + rgb[torch.lt(rgb, 0.0)] = 0.0 + + return rgb +end +function alpha_util.composite(rgb, alpha, model2x) + if not alpha then + return rgb + end + if not (alpha:size(2) == rgb:size(2) and alpha:size(3) == rgb:size(3)) then + if model2x then + alpha = reconstruct.scale(model2x, 2.0, alpha) + else + alpha = gm.Image(alpha, "I", "DHW"):size(rgb:size(3), rgb:size(2), "Sinc"):toTensor("float", "I", "DHW") + end + end + local out = torch.Tensor(4, rgb:size(2), rgb:size(3)) + out[1]:copy(rgb[1]) + out[2]:copy(rgb[2]) + out[3]:copy(rgb[3]) + out[4]:copy(alpha) + return out +end + +local function test() + require 'sys' + require 'trepl' + torch.setdefaulttensortype("torch.FloatTensor") + + local image_loader = require 'image_loader' + local rgb, alpha = image_loader.load_float("alpha.png") + local t = sys.clock() + rgb = alpha_util.make_border(rgb, alpha, 7) + print(sys.clock() - t) + print(rgb:min(), rgb:max()) + image.display({image = rgb, min = 0, max = 1}) + image.save("out.png", rgb) +end +--test() + +return alpha_util diff --git a/lib/image_loader.lua b/lib/image_loader.lua index 6be2ba4..b5717b8 100644 --- a/lib/image_loader.lua +++ b/lib/image_loader.lua @@ -9,47 +9,32 @@ local clip_eps8 = (1.0 / 255.0) * 0.5 - (1.0e-7 * (1.0 / 255.0) * 0.5) local clip_eps16 = (1.0 / 65535.0) * 0.5 - (1.0e-7 * (1.0 / 65535.0) * 0.5) local background_color = 0.5 -function image_loader.encode_png(rgb, alpha, depth) +function image_loader.encode_png(rgb, depth) depth = depth or 8 rgb = iproc.byte2float(rgb) - if alpha then - if not (alpha:size(2) == rgb:size(2) and alpha:size(3) == rgb:size(3)) then - alpha = gm.Image(alpha, "I", "DHW"):size(rgb:size(3), rgb:size(2), "Sinc"):toTensor("float", "I", "DHW") - end - local rgba = torch.Tensor(4, rgb:size(2), rgb:size(3)) - rgba[1]:copy(rgb[1]) - rgba[2]:copy(rgb[2]) - rgba[3]:copy(rgb[3]) - rgba[4]:copy(alpha) - - if depth < 16 then - rgba:add(clip_eps8) - rgba[torch.lt(rgba, 0.0)] = 0.0 - rgba[torch.gt(rgba, 1.0)] = 1.0 - else - rgba:add(clip_eps16) - rgba[torch.lt(rgba, 0.0)] = 0.0 - rgba[torch.gt(rgba, 1.0)] = 1.0 - end - local im = gm.Image():fromTensor(rgba, "RGBA", "DHW") - return im:depth(depth):format("PNG"):toString(9) + if depth < 16 then + rgb = rgb:clone():add(clip_eps8) + rgb[torch.lt(rgb, 0.0)] = 0.0 + rgb[torch.gt(rgb, 1.0)] = 1.0 else - if depth < 16 then - rgb = rgb:clone():add(clip_eps8) - rgb[torch.lt(rgb, 0.0)] = 0.0 - rgb[torch.gt(rgb, 1.0)] = 1.0 - else - rgb = rgb:clone():add(clip_eps16) - rgb[torch.lt(rgb, 0.0)] = 0.0 - rgb[torch.gt(rgb, 1.0)] = 1.0 - end - local im = gm.Image(rgb, "RGB", "DHW") - return im:depth(depth):format("PNG"):toString(9) + rgb = rgb:clone():add(clip_eps16) + rgb[torch.lt(rgb, 0.0)] = 0.0 + rgb[torch.gt(rgb, 1.0)] = 1.0 end + local im + if rgb:size(1) == 4 then -- RGBA + im = gm.Image(rgb, "RGBA", "DHW") + elseif rgb:size(1) == 3 then -- RGB + im = gm.Image(rgb, "RGB", "DHW") + elseif rgb:size(1) == 1 then -- Y + im = gm.Image(rgb, "I", "DHW") + -- im:colorspace("GRAY") -- it does not work + end + return im:depth(depth):format("PNG"):toString(9) end -function image_loader.save_png(filename, rgb, alpha, depth) +function image_loader.save_png(filename, rgb, depth) depth = depth or 8 - local blob = image_loader.encode_png(rgb, alpha, depth) + local blob = image_loader.encode_png(rgb, depth) local fp = io.open(filename, "wb") if not fp then error("IO error: " .. filename) diff --git a/lib/iproc.lua b/lib/iproc.lua index 2ea20c4..3bec420 100644 --- a/lib/iproc.lua +++ b/lib/iproc.lua @@ -49,12 +49,17 @@ function iproc.float2byte(src) return dest, conversion end function iproc.scale(src, width, height, filter) - local conversion + local conversion, color src, conversion = iproc.byte2float(src) filter = filter or "Box" - local im = gm.Image(src, "RGB", "DHW") + if src:size(1) == 3 then + color = "RGB" + else + color = "I" + end + local im = gm.Image(src, color, "DHW") im:size(math.ceil(width), math.ceil(height), filter) - local dest = im:toTensor("float", "RGB", "DHW") + local dest = im:toTensor("float", color, "DHW") if conversion then dest = iproc.float2byte(dest) end @@ -84,6 +89,16 @@ function iproc.padding(img, w1, w2, h1, h2) flow[2]:add(-w1) return image.warp(img, flow, "simple", false, "clamp") end +function iproc.zero_padding(img, w1, w2, h1, h2) + local dst_height = img:size(2) + h1 + h2 + local dst_width = img:size(3) + w1 + w2 + local flow = torch.Tensor(2, dst_height, dst_width) + flow[1] = torch.ger(torch.linspace(0, dst_height -1, dst_height), torch.ones(dst_width)) + flow[2] = torch.ger(torch.ones(dst_height), torch.linspace(0, dst_width - 1, dst_width)) + flow[1]:add(-h1) + flow[2]:add(-w1) + return image.warp(img, flow, "simple", false, "pad", 0) +end function iproc.white_noise(src, std, rgb_weights, gamma) gamma = gamma or 0.454545 local conversion diff --git a/lib/reconstruct.lua b/lib/reconstruct.lua index 0aa03b6..1aeee12 100644 --- a/lib/reconstruct.lua +++ b/lib/reconstruct.lua @@ -189,22 +189,48 @@ function reconstruct.scale_rgb(model, scale, x, offset, block_size) end function reconstruct.image(model, x, block_size) - if reconstruct.is_rgb(model) then - return reconstruct.image_rgb(model, x, - reconstruct.offset_size(model), block_size) - else - return reconstruct.image_y(model, x, - reconstruct.offset_size(model), block_size) + local i2rgb = false + if x:size(1) == 1 then + local new_x = torch.Tensor(3, x:size(2), x:size(3)) + new_x[1]:copy(x) + new_x[2]:copy(x) + new_x[3]:copy(x) + x = new_x + i2rgb = true end + if reconstruct.is_rgb(model) then + x = reconstruct.image_rgb(model, x, + reconstruct.offset_size(model), block_size) + else + x = reconstruct.image_y(model, x, + reconstruct.offset_size(model), block_size) + end + if i2rgb then + x = image.rgb2y(x) + end + return x end function reconstruct.scale(model, scale, x, block_size) - if reconstruct.is_rgb(model) then - return reconstruct.scale_rgb(model, scale, x, - reconstruct.offset_size(model), block_size) - else - return reconstruct.scale_y(model, scale, x, - reconstruct.offset_size(model), block_size) + local i2rgb = false + if x:size(1) == 1 then + local new_x = torch.Tensor(3, x:size(2), x:size(3)) + new_x[1]:copy(x) + new_x[2]:copy(x) + new_x[3]:copy(x) + x = new_x + i2rgb = true end + if reconstruct.is_rgb(model) then + x = reconstruct.scale_rgb(model, scale, x, + reconstruct.offset_size(model), block_size) + else + x = reconstruct.scale_y(model, scale, x, + reconstruct.offset_size(model), block_size) + end + if i2rgb then + x = image.rgb2y(x) + end + return x end local function tta(f, model, x, block_size) local average = nil diff --git a/waifu2x.lua b/waifu2x.lua index a1faf7f..c61f325 100644 --- a/waifu2x.lua +++ b/waifu2x.lua @@ -6,6 +6,7 @@ require 'w2nn' local iproc = require 'iproc' local reconstruct = require 'reconstruct' local image_loader = require 'image_loader' +local alpha_util = require 'alpha_util' torch.setdefaulttensortype('torch.FloatTensor') @@ -14,6 +15,7 @@ local function convert_image(opt) local new_x = nil local t = sys.clock() local scale_f, image_f + if opt.tta == 1 then scale_f = reconstruct.scale_tta image_f = reconstruct.image_tta @@ -34,13 +36,16 @@ local function convert_image(opt) error("Load Error: " .. model_path) end new_x = image_f(model, x, opt.crop_size) + new_x = alpha_util.composite(new_x, alpha) elseif opt.m == "scale" then local model_path = path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale)) local model = torch.load(model_path, "ascii") if not model then error("Load Error: " .. model_path) end + x = alpha_util.make_border(x, alpha, reconstruct.offset_size(model)) new_x = scale_f(model, opt.scale, x, opt.crop_size) + new_x = alpha_util.composite(new_x, alpha, model) elseif opt.m == "noise_scale" then local noise_model_path = path.join(opt.model_dir, ("noise%d_model.t7"):format(opt.noise_level)) local noise_model = torch.load(noise_model_path, "ascii") @@ -53,15 +58,14 @@ local function convert_image(opt) if not scale_model then error("Load Error: " .. scale_model_path) end + x = alpha_util.make_border(x, alpha, reconstruct.offset_size(scale_model)) x = image_f(noise_model, x, opt.crop_size) new_x = scale_f(scale_model, opt.scale, x, opt.crop_size) + new_x = alpha_util.composite(new_x, alpha, scale_model) else error("undefined method:" .. opt.method) end - if opt.white_noise == 1 then - new_x = iproc.white_noise(new_x, opt.white_noise_std, {1.0, 0.8, 1.0}) - end - image_loader.save_png(opt.o, new_x, alpha, opt.depth) + image_loader.save_png(opt.o, new_x, opt.depth) print(opt.o .. ": " .. (sys.clock() - t) .. " sec") end local function convert_frames(opt) @@ -128,23 +132,27 @@ local function convert_frames(opt) local new_x = nil if opt.m == "noise" and opt.noise_level == 1 then new_x = image_f(noise1_model, x, opt.crop_size) + new_x = alpha_util.composite(new_x, alpha) elseif opt.m == "noise" and opt.noise_level == 2 then new_x = image_func(noise2_model, x, opt.crop_size) + new_x = alpha_util.composite(new_x, alpha) elseif opt.m == "scale" then + x = alpha_util.make_border(x, alpha, reconstruct.offset_size(scale_model)) new_x = scale_f(scale_model, opt.scale, x, opt.crop_size) + new_x = alpha_util.composite(new_x, alpha, scale_model) elseif opt.m == "noise_scale" and opt.noise_level == 1 then + x = alpha_util.make_border(x, alpha, reconstruct.offset_size(scale_model)) x = image_f(noise1_model, x, opt.crop_size) new_x = scale_f(scale_model, opt.scale, x, opt.crop_size) + new_x = alpha_util.composite(new_x, alpha, scale_model) elseif opt.m == "noise_scale" and opt.noise_level == 2 then + x = alpha_util.make_border(x, alpha, reconstruct.offset_size(scale_model)) x = image_f(noise2_model, x, opt.crop_size) new_x = scale_f(scale_model, opt.scale, x, opt.crop_size) + new_x = alpha_util.composite(new_x, alpha, scale_model) else error("undefined method:" .. opt.method) end - if opt.white_noise == 1 then - new_x = iproc.white_noise(new_x, opt.white_noise_std, {1.0, 0.8, 1.0}) - end - local output = nil if opt.o == "(auto)" then local name = path.basename(lines[i]) @@ -154,7 +162,7 @@ local function convert_frames(opt) else output = string.format(opt.o, i) end - image_loader.save_png(output, new_x, alpha, opt.depth) + image_loader.save_png(output, new_x, opt.depth) xlua.progress(i, #lines) if i % 10 == 0 then collectgarbage() @@ -182,8 +190,6 @@ local function waifu2x() cmd:option("-resume", 0, "skip existing files (0|1)") cmd:option("-thread", -1, "number of CPU threads") cmd:option("-tta", 0, '8x slower and slightly high quality (0|1)') - cmd:option("-white_noise", 0, 'adding white noise to output image (0|1)') - cmd:option("-white_noise_std", 0.0055, 'standard division of white noise') local opt = cmd:parse(arg) if opt.thread > 0 then diff --git a/web.lua b/web.lua index 918d5c0..7fb8076 100644 --- a/web.lua +++ b/web.lua @@ -11,6 +11,7 @@ local md5 = require 'md5' local iproc = require 'iproc' local reconstruct = require 'reconstruct' local image_loader = require 'image_loader' +local alpha_util = require 'alpha_util' -- Notes: turbo and xlua has different implementation of string:split(). -- Therefore, string:split() has conflict issue. @@ -104,14 +105,36 @@ local function cleanup_model(model) w2nn.cleanup_model(model) -- release GPU memory end end -local function convert(x, options) +local function convert(x, alpha, options) local cache_file = path.join(CACHE_DIR, options.prefix .. ".png") + local alpha_cache_file = path.join(CACHE_DIR, options.alpha_prefix .. ".png") + local alpha_orig = alpha + + if path.exists(alpha_cache_file) then + alpha = image_loader.load_float(alpha_cache_file) + if alpha:dim() == 2 then + alpha = alpha:reshape(1, alpha:size(1), alpha:size(2)) + end + if alpha:size(1) == 3 then + alpha = image.rgb2y(alpha) + end + end if path.exists(cache_file) then - return image.load(cache_file) + x = image_loader.load_float(cache_file) + return x, alpha else if options.style == "art" then + if options.border then + x = alpha_util.make_border(x, alpha_orig, reconstruct.offset_size(art_scale2_model)) + end if options.method == "scale" then x = reconstruct.scale(art_scale2_model, 2.0, x) + if alpha then + if not (alpha:size(2) == x:size(2) and alpha:size(3) == x:size(3)) then + alpha = reconstruct.scale(art_scale2_model, 2.0, alpha) + image_loader.save_png(alpha_cache_file, alpha) + end + end cleanup_model(art_scale2_model) elseif options.method == "noise1" then x = reconstruct.image(art_noise1_model, x) @@ -121,8 +144,17 @@ local function convert(x, options) cleanup_model(art_noise2_model) end else --[[photo + if options.border then + x = alpha_util.make_border(x, alpha, reconstruct.offset_size(photo_scale2_model)) + end if options.method == "scale" then x = reconstruct.scale(photo_scale2_model, 2.0, x) + if alpha then + if not (alpha:size(2) == x:size(2) and alpha:size(3) == x:size(3)) then + alpha = reconstruct.scale(photo_scale2_model, 2.0, alpha) + image_loader.save_png(alpha_cache_file, alpha) + end + end cleanup_model(photo_scale2_model) elseif options.method == "noise1" then x = reconstruct.image(photo_noise1_model, x) @@ -133,8 +165,9 @@ local function convert(x, options) end --]] end - image.save(cache_file, x) - return x + image_loader.save_png(cache_file, x) + + return x, alpha end end local function client_disconnected(handler) @@ -154,7 +187,6 @@ function APIHandler:post() local x, alpha, blob = get_image(self) local scale = tonumber(self:get_argument("scale", "0")) local noise = tonumber(self:get_argument("noise", "0")) - local white_noise = tonumber(self:get_argument("white_noise", "0")) local style = self:get_argument("style", "art") local download = (self:get_argument("download", "")):len() @@ -164,29 +196,39 @@ function APIHandler:post() if x and valid_size(x, scale) then if (noise ~= 0 or scale ~= 0) then local hash = md5.sumhexa(blob) + local alpha_prefix = style .. "_" .. hash .. "_alpha" + local border = false + if scale ~= 0 and alpha then + border = true + end if noise == 1 then - x = convert(x, {method = "noise1", style = style, prefix = style .. "_noise1_" .. hash}) + x = convert(x, alpha, {method = "noise1", style = style, + prefix = style .. "_noise1_" .. hash, + alpha_prefix = alpha_prefix, border = border}) + border = false elseif noise == 2 then - x = convert(x, {method = "noise2", style = style, prefix = style .. "_noise2_" .. hash}) + x = convert(x, alpha, {method = "noise2", style = style, + prefix = style .. "_noise2_" .. hash, + alpha_prefix = alpha_prefix, border = border}) + border = false end if scale == 1 or scale == 2 then + local prefix if noise == 1 then - x = convert(x, {method = "scale", style = style, prefix = style .. "_noise1_scale_" .. hash}) + prefix = style .. "_noise1_scale_" .. hash elseif noise == 2 then - x = convert(x, {method = "scale", style = style, prefix = style .. "_noise2_scale_" .. hash}) + prefix = style .. "_noise2_scale_" .. hash else - x = convert(x, {method = "scale", style = style, prefix = style .. "_scale_" .. hash}) + prefix = style .. "_scale_" .. hash end + x, alpha = convert(x, alpha, {method = "scale", style = style, prefix = prefix, alpha_prefix = alpha_prefix, border = border}) if scale == 1 then x = iproc.scale(x, x:size(3) * (1.6 / 2.0), x:size(2) * (1.6 / 2.0), "Sinc") end end - if white_noise == 1 then - x = iproc.white_noise(x, 0.005, {1.0, 0.8, 1.0}) - end end local name = uuid() .. ".png" - local blob = image_loader.encode_png(x, alpha) + local blob = image_loader.encode_png(alpha_util.composite(x, alpha)) self:set_header("Content-Disposition", string.format('filename="%s"', name)) self:set_header("Content-Length", string.format("%d", #blob)) if download > 0 then