Improve alpha channel handling #29
- make border - scale the alpha channel by waifu2x - composite
This commit is contained in:
parent
b829595c21
commit
d2c081bbcf
6 changed files with 229 additions and 75 deletions
80
lib/alpha_util.lua
Normal file
80
lib/alpha_util.lua
Normal file
|
@ -0,0 +1,80 @@
|
|||
local w2nn = require 'w2nn'
|
||||
local reconstruct = require 'reconstruct'
|
||||
local image = require 'image'
|
||||
local iproc = require 'iproc'
|
||||
local gm = require 'graphicsmagick'
|
||||
|
||||
-- Module table. NOTE: this is assigned as a global (no `local`), and is also
-- returned at the end of the file for `require` callers.
alpha_util = {}

-- Shared single-channel convolution used by make_border as a neighborhood
-- sum: 1 input plane, 1 output plane, 3x3 kernel, stride 1, padding 1,
-- every weight set to 1 and the bias cleared. The layer lives on the GPU.
-- `nn` is not required in this file; it is expected to be in scope already
-- (presumably via 'w2nn' -- confirm).
do
   local box_sum = nn.SpatialConvolutionMM(1, 1, 3, 3, 1, 1, 1, 1):cuda()
   box_sum.weight:fill(1)
   box_sum.bias:zero()
   alpha_util.sum2d = box_sum
end
|
||||
|
||||
--- Spread visible colors outward into the fully transparent region.
--
-- Pixels whose alpha is exactly 0 are zeroed, then repeatedly replaced by the
-- neighborhood sum of the color plane divided by the neighborhood sum of the
-- coverage mask (both computed with alpha_util.sum2d). After every pass the
-- coverage mask is regrown to include every pixel that received a positive
-- weight, so each pass extends the border further into the transparent area.
--
-- @param rgb    color tensor; channels 1..3 are processed. The input is not
--               modified, a clone is returned.
-- @param alpha  alpha tensor, or nil/false for images without alpha
-- @param offset number of growth passes to run
-- @return rgb unchanged when alpha is absent, otherwise the bordered clone
--         clamped to [0, 1]
function alpha_util.make_border(rgb, alpha, offset)
   if not alpha then
      return rgb
   end

   local EPS = 1.0e-7 -- keeps the division finite where no neighbor is covered
   local box = alpha_util.sum2d
   local height, width = rgb:size(2), rgb:size(3)

   -- Binary coverage: 1 where alpha > 0, 0 where alpha == 0.
   local covered = alpha:clone()
   covered[torch.gt(covered, 0.0)] = 1
   covered[torch.eq(covered, 0.0)] = 0
   -- Byte mask selecting the pixels that still have to be filled.
   local hole = (covered - 1):abs():byte()

   -- Work on a copy and discard whatever color hides under alpha == 0.
   rgb = rgb:clone()
   for ch = 1, 3 do
      rgb[ch][hole] = 0
   end

   for _ = 1, offset do
      local weight = box:forward(covered:cuda()):float()
      local denom = weight + EPS
      local blurred = rgb:clone()
      for ch = 1, 3 do
         local plane = rgb[ch]:reshape(1, height, width):cuda()
         blurred[ch]:copy(box:forward(plane))
         blurred[ch]:cdiv(denom)
         -- Only uncovered pixels take the averaged value.
         rgb[ch][hole] = blurred[ch][hole]
      end
      -- Everything that picked up a positive weight counts as covered in the
      -- next pass.
      covered = weight:clone()
      covered[torch.gt(weight, 0.0)] = 1
      hole = (covered - 1):abs():byte()
   end

   rgb[torch.gt(rgb, 1.0)] = 1.0
   rgb[torch.lt(rgb, 0.0)] = 0.0

   return rgb
end
|
||||
--- Attach an alpha plane to an RGB image, producing a 4-channel tensor.
--
-- When the alpha plane does not match the size of `rgb` it is first upscaled
-- with `model2x` (if one is given) via reconstruct.scale at a factor of 2.0.
-- If the sizes still differ afterwards -- or no model was given -- the plane
-- is resized to the exact size of `rgb` with GraphicsMagick's "Sinc" filter.
-- The second step guards against outputs that are not exactly twice the size
-- of the alpha plane, which previously failed in the final copy.
--
-- @param rgb     color tensor; channels 1..3 are copied into the result
-- @param alpha   alpha tensor, or nil/false for images without alpha
-- @param model2x optional 2x scaling model used to upscale the alpha plane
-- @return rgb itself when alpha is absent, otherwise a new 4 x H x W tensor
function alpha_util.composite(rgb, alpha, model2x)
   if not alpha then
      return rgb
   end

   local height, width = rgb:size(2), rgb:size(3)

   -- Compare the two trailing dimensions; for the usual 3-D alpha tensor
   -- these are size(2) and size(3).
   local function same_size(plane)
      local d = plane:dim()
      return plane:size(d - 1) == height and plane:size(d) == width
   end

   if model2x and not same_size(alpha) then
      alpha = reconstruct.scale(model2x, 2.0, alpha)
   end
   if not same_size(alpha) then
      -- Size still differs (no model, or rgb is not exactly 2x): resample the
      -- plane to the target size.
      alpha = gm.Image(alpha, "I", "DHW"):size(width, height, "Sinc"):toTensor("float", "I", "DHW")
   end

   local out = torch.Tensor(4, height, width)
   for ch = 1, 3 do
      out[ch]:copy(rgb[ch])
   end
   out[4]:copy(alpha)

   return out
end
|
||||
|
||||
-- Manual check for make_border: loads "alpha.png", runs 7 border passes,
-- prints the elapsed time and the value range of the result, then displays
-- the image and writes it to "out.png". Not called automatically.
local function test()
   require 'sys'
   require 'trepl'
   torch.setdefaulttensortype("torch.FloatTensor")

   local loader = require 'image_loader'
   local src_rgb, src_alpha = loader.load_float("alpha.png")

   local started = sys.clock()
   local bordered = alpha_util.make_border(src_rgb, src_alpha, 7)
   local elapsed = sys.clock() - started

   print(elapsed)
   print(bordered:min(), bordered:max())

   image.display({image = bordered, min = 0, max = 1})
   image.save("out.png", bordered)
end
|
||||
--test()
|
||||
|
||||
return alpha_util
|
|
@ -9,31 +9,9 @@ local clip_eps8 = (1.0 / 255.0) * 0.5 - (1.0e-7 * (1.0 / 255.0) * 0.5)
|
|||
local clip_eps16 = (1.0 / 65535.0) * 0.5 - (1.0e-7 * (1.0 / 65535.0) * 0.5)
|
||||
local background_color = 0.5
|
||||
|
||||
function image_loader.encode_png(rgb, alpha, depth)
|
||||
function image_loader.encode_png(rgb, depth)
|
||||
depth = depth or 8
|
||||
rgb = iproc.byte2float(rgb)
|
||||
if alpha then
|
||||
if not (alpha:size(2) == rgb:size(2) and alpha:size(3) == rgb:size(3)) then
|
||||
alpha = gm.Image(alpha, "I", "DHW"):size(rgb:size(3), rgb:size(2), "Sinc"):toTensor("float", "I", "DHW")
|
||||
end
|
||||
local rgba = torch.Tensor(4, rgb:size(2), rgb:size(3))
|
||||
rgba[1]:copy(rgb[1])
|
||||
rgba[2]:copy(rgb[2])
|
||||
rgba[3]:copy(rgb[3])
|
||||
rgba[4]:copy(alpha)
|
||||
|
||||
if depth < 16 then
|
||||
rgba:add(clip_eps8)
|
||||
rgba[torch.lt(rgba, 0.0)] = 0.0
|
||||
rgba[torch.gt(rgba, 1.0)] = 1.0
|
||||
else
|
||||
rgba:add(clip_eps16)
|
||||
rgba[torch.lt(rgba, 0.0)] = 0.0
|
||||
rgba[torch.gt(rgba, 1.0)] = 1.0
|
||||
end
|
||||
local im = gm.Image():fromTensor(rgba, "RGBA", "DHW")
|
||||
return im:depth(depth):format("PNG"):toString(9)
|
||||
else
|
||||
if depth < 16 then
|
||||
rgb = rgb:clone():add(clip_eps8)
|
||||
rgb[torch.lt(rgb, 0.0)] = 0.0
|
||||
|
@ -43,13 +21,20 @@ function image_loader.encode_png(rgb, alpha, depth)
|
|||
rgb[torch.lt(rgb, 0.0)] = 0.0
|
||||
rgb[torch.gt(rgb, 1.0)] = 1.0
|
||||
end
|
||||
local im = gm.Image(rgb, "RGB", "DHW")
|
||||
return im:depth(depth):format("PNG"):toString(9)
|
||||
local im
|
||||
if rgb:size(1) == 4 then -- RGBA
|
||||
im = gm.Image(rgb, "RGBA", "DHW")
|
||||
elseif rgb:size(1) == 3 then -- RGB
|
||||
im = gm.Image(rgb, "RGB", "DHW")
|
||||
elseif rgb:size(1) == 1 then -- Y
|
||||
im = gm.Image(rgb, "I", "DHW")
|
||||
-- im:colorspace("GRAY") -- it does not work
|
||||
end
|
||||
return im:depth(depth):format("PNG"):toString(9)
|
||||
end
|
||||
function image_loader.save_png(filename, rgb, alpha, depth)
|
||||
function image_loader.save_png(filename, rgb, depth)
|
||||
depth = depth or 8
|
||||
local blob = image_loader.encode_png(rgb, alpha, depth)
|
||||
local blob = image_loader.encode_png(rgb, depth)
|
||||
local fp = io.open(filename, "wb")
|
||||
if not fp then
|
||||
error("IO error: " .. filename)
|
||||
|
|
|
@ -49,12 +49,17 @@ function iproc.float2byte(src)
|
|||
return dest, conversion
|
||||
end
|
||||
function iproc.scale(src, width, height, filter)
|
||||
local conversion
|
||||
local conversion, color
|
||||
src, conversion = iproc.byte2float(src)
|
||||
filter = filter or "Box"
|
||||
local im = gm.Image(src, "RGB", "DHW")
|
||||
if src:size(1) == 3 then
|
||||
color = "RGB"
|
||||
else
|
||||
color = "I"
|
||||
end
|
||||
local im = gm.Image(src, color, "DHW")
|
||||
im:size(math.ceil(width), math.ceil(height), filter)
|
||||
local dest = im:toTensor("float", "RGB", "DHW")
|
||||
local dest = im:toTensor("float", color, "DHW")
|
||||
if conversion then
|
||||
dest = iproc.float2byte(dest)
|
||||
end
|
||||
|
@ -84,6 +89,16 @@ function iproc.padding(img, w1, w2, h1, h2)
|
|||
flow[2]:add(-w1)
|
||||
return image.warp(img, flow, "simple", false, "clamp")
|
||||
end
|
||||
--- Pad an image on all four sides, filling the new area with zeros.
--
-- Builds a 2 x H' x W' flow field in which entry (y, x) holds the source
-- coordinate (y - h1, x - w1), then hands it to image.warp with the "pad"
-- clamp mode and a pad value of 0.
--
-- @param img image tensor; size(2) is used as height and size(3) as width
-- @param w1  columns added on the left
-- @param w2  columns added on the right
-- @param h1  rows added at the top
-- @param h2  rows added at the bottom
-- @return the result of image.warp for the enlarged canvas
function iproc.zero_padding(img, w1, w2, h1, h2)
   local out_h = img:size(2) + h1 + h2
   local out_w = img:size(3) + w1 + w2

   -- Destination row / column indices, 0-based.
   local rows = torch.linspace(0, out_h - 1, out_h)
   local cols = torch.linspace(0, out_w - 1, out_w)

   local flow = torch.Tensor(2, out_h, out_w)
   flow[1] = torch.ger(rows, torch.ones(out_w))
   flow[2] = torch.ger(torch.ones(out_h), cols)

   -- Shift so that destination (h1, w1) maps to source (0, 0).
   flow[1]:add(-h1)
   flow[2]:add(-w1)

   return image.warp(img, flow, "simple", false, "pad", 0)
end
|
||||
function iproc.white_noise(src, std, rgb_weights, gamma)
|
||||
gamma = gamma or 0.454545
|
||||
local conversion
|
||||
|
|
|
@ -189,22 +189,48 @@ function reconstruct.scale_rgb(model, scale, x, offset, block_size)
|
|||
end
|
||||
|
||||
function reconstruct.image(model, x, block_size)
|
||||
local i2rgb = false
|
||||
if x:size(1) == 1 then
|
||||
local new_x = torch.Tensor(3, x:size(2), x:size(3))
|
||||
new_x[1]:copy(x)
|
||||
new_x[2]:copy(x)
|
||||
new_x[3]:copy(x)
|
||||
x = new_x
|
||||
i2rgb = true
|
||||
end
|
||||
if reconstruct.is_rgb(model) then
|
||||
return reconstruct.image_rgb(model, x,
|
||||
x = reconstruct.image_rgb(model, x,
|
||||
reconstruct.offset_size(model), block_size)
|
||||
else
|
||||
return reconstruct.image_y(model, x,
|
||||
x = reconstruct.image_y(model, x,
|
||||
reconstruct.offset_size(model), block_size)
|
||||
end
|
||||
if i2rgb then
|
||||
x = image.rgb2y(x)
|
||||
end
|
||||
return x
|
||||
end
|
||||
function reconstruct.scale(model, scale, x, block_size)
|
||||
local i2rgb = false
|
||||
if x:size(1) == 1 then
|
||||
local new_x = torch.Tensor(3, x:size(2), x:size(3))
|
||||
new_x[1]:copy(x)
|
||||
new_x[2]:copy(x)
|
||||
new_x[3]:copy(x)
|
||||
x = new_x
|
||||
i2rgb = true
|
||||
end
|
||||
if reconstruct.is_rgb(model) then
|
||||
return reconstruct.scale_rgb(model, scale, x,
|
||||
x = reconstruct.scale_rgb(model, scale, x,
|
||||
reconstruct.offset_size(model), block_size)
|
||||
else
|
||||
return reconstruct.scale_y(model, scale, x,
|
||||
x = reconstruct.scale_y(model, scale, x,
|
||||
reconstruct.offset_size(model), block_size)
|
||||
end
|
||||
if i2rgb then
|
||||
x = image.rgb2y(x)
|
||||
end
|
||||
return x
|
||||
end
|
||||
local function tta(f, model, x, block_size)
|
||||
local average = nil
|
||||
|
|
28
waifu2x.lua
28
waifu2x.lua
|
@ -6,6 +6,7 @@ require 'w2nn'
|
|||
local iproc = require 'iproc'
|
||||
local reconstruct = require 'reconstruct'
|
||||
local image_loader = require 'image_loader'
|
||||
local alpha_util = require 'alpha_util'
|
||||
|
||||
torch.setdefaulttensortype('torch.FloatTensor')
|
||||
|
||||
|
@ -14,6 +15,7 @@ local function convert_image(opt)
|
|||
local new_x = nil
|
||||
local t = sys.clock()
|
||||
local scale_f, image_f
|
||||
|
||||
if opt.tta == 1 then
|
||||
scale_f = reconstruct.scale_tta
|
||||
image_f = reconstruct.image_tta
|
||||
|
@ -34,13 +36,16 @@ local function convert_image(opt)
|
|||
error("Load Error: " .. model_path)
|
||||
end
|
||||
new_x = image_f(model, x, opt.crop_size)
|
||||
new_x = alpha_util.composite(new_x, alpha)
|
||||
elseif opt.m == "scale" then
|
||||
local model_path = path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale))
|
||||
local model = torch.load(model_path, "ascii")
|
||||
if not model then
|
||||
error("Load Error: " .. model_path)
|
||||
end
|
||||
x = alpha_util.make_border(x, alpha, reconstruct.offset_size(model))
|
||||
new_x = scale_f(model, opt.scale, x, opt.crop_size)
|
||||
new_x = alpha_util.composite(new_x, alpha, model)
|
||||
elseif opt.m == "noise_scale" then
|
||||
local noise_model_path = path.join(opt.model_dir, ("noise%d_model.t7"):format(opt.noise_level))
|
||||
local noise_model = torch.load(noise_model_path, "ascii")
|
||||
|
@ -53,15 +58,14 @@ local function convert_image(opt)
|
|||
if not scale_model then
|
||||
error("Load Error: " .. scale_model_path)
|
||||
end
|
||||
x = alpha_util.make_border(x, alpha, reconstruct.offset_size(scale_model))
|
||||
x = image_f(noise_model, x, opt.crop_size)
|
||||
new_x = scale_f(scale_model, opt.scale, x, opt.crop_size)
|
||||
new_x = alpha_util.composite(new_x, alpha, scale_model)
|
||||
else
|
||||
error("undefined method:" .. opt.method)
|
||||
end
|
||||
if opt.white_noise == 1 then
|
||||
new_x = iproc.white_noise(new_x, opt.white_noise_std, {1.0, 0.8, 1.0})
|
||||
end
|
||||
image_loader.save_png(opt.o, new_x, alpha, opt.depth)
|
||||
image_loader.save_png(opt.o, new_x, opt.depth)
|
||||
print(opt.o .. ": " .. (sys.clock() - t) .. " sec")
|
||||
end
|
||||
local function convert_frames(opt)
|
||||
|
@ -128,23 +132,27 @@ local function convert_frames(opt)
|
|||
local new_x = nil
|
||||
if opt.m == "noise" and opt.noise_level == 1 then
|
||||
new_x = image_f(noise1_model, x, opt.crop_size)
|
||||
new_x = alpha_util.composite(new_x, alpha)
|
||||
elseif opt.m == "noise" and opt.noise_level == 2 then
|
||||
new_x = image_func(noise2_model, x, opt.crop_size)
|
||||
new_x = alpha_util.composite(new_x, alpha)
|
||||
elseif opt.m == "scale" then
|
||||
x = alpha_util.make_border(x, alpha, reconstruct.offset_size(scale_model))
|
||||
new_x = scale_f(scale_model, opt.scale, x, opt.crop_size)
|
||||
new_x = alpha_util.composite(new_x, alpha, scale_model)
|
||||
elseif opt.m == "noise_scale" and opt.noise_level == 1 then
|
||||
x = alpha_util.make_border(x, alpha, reconstruct.offset_size(scale_model))
|
||||
x = image_f(noise1_model, x, opt.crop_size)
|
||||
new_x = scale_f(scale_model, opt.scale, x, opt.crop_size)
|
||||
new_x = alpha_util.composite(new_x, alpha, scale_model)
|
||||
elseif opt.m == "noise_scale" and opt.noise_level == 2 then
|
||||
x = alpha_util.make_border(x, alpha, reconstruct.offset_size(scale_model))
|
||||
x = image_f(noise2_model, x, opt.crop_size)
|
||||
new_x = scale_f(scale_model, opt.scale, x, opt.crop_size)
|
||||
new_x = alpha_util.composite(new_x, alpha, scale_model)
|
||||
else
|
||||
error("undefined method:" .. opt.method)
|
||||
end
|
||||
if opt.white_noise == 1 then
|
||||
new_x = iproc.white_noise(new_x, opt.white_noise_std, {1.0, 0.8, 1.0})
|
||||
end
|
||||
|
||||
local output = nil
|
||||
if opt.o == "(auto)" then
|
||||
local name = path.basename(lines[i])
|
||||
|
@ -154,7 +162,7 @@ local function convert_frames(opt)
|
|||
else
|
||||
output = string.format(opt.o, i)
|
||||
end
|
||||
image_loader.save_png(output, new_x, alpha, opt.depth)
|
||||
image_loader.save_png(output, new_x, opt.depth)
|
||||
xlua.progress(i, #lines)
|
||||
if i % 10 == 0 then
|
||||
collectgarbage()
|
||||
|
@ -182,8 +190,6 @@ local function waifu2x()
|
|||
cmd:option("-resume", 0, "skip existing files (0|1)")
|
||||
cmd:option("-thread", -1, "number of CPU threads")
|
||||
cmd:option("-tta", 0, '8x slower and slightly high quality (0|1)')
|
||||
cmd:option("-white_noise", 0, 'adding white noise to output image (0|1)')
|
||||
cmd:option("-white_noise_std", 0.0055, 'standard division of white noise')
|
||||
|
||||
local opt = cmd:parse(arg)
|
||||
if opt.thread > 0 then
|
||||
|
|
70
web.lua
70
web.lua
|
@ -11,6 +11,7 @@ local md5 = require 'md5'
|
|||
local iproc = require 'iproc'
|
||||
local reconstruct = require 'reconstruct'
|
||||
local image_loader = require 'image_loader'
|
||||
local alpha_util = require 'alpha_util'
|
||||
|
||||
-- Note: turbo and xlua have different implementations of string:split().
|
||||
-- Therefore, the two string:split() definitions conflict with each other.
|
||||
|
@ -104,14 +105,36 @@ local function cleanup_model(model)
|
|||
w2nn.cleanup_model(model) -- release GPU memory
|
||||
end
|
||||
end
|
||||
local function convert(x, options)
|
||||
local function convert(x, alpha, options)
|
||||
local cache_file = path.join(CACHE_DIR, options.prefix .. ".png")
|
||||
local alpha_cache_file = path.join(CACHE_DIR, options.alpha_prefix .. ".png")
|
||||
local alpha_orig = alpha
|
||||
|
||||
if path.exists(alpha_cache_file) then
|
||||
alpha = image_loader.load_float(alpha_cache_file)
|
||||
if alpha:dim() == 2 then
|
||||
alpha = alpha:reshape(1, alpha:size(1), alpha:size(2))
|
||||
end
|
||||
if alpha:size(1) == 3 then
|
||||
alpha = image.rgb2y(alpha)
|
||||
end
|
||||
end
|
||||
if path.exists(cache_file) then
|
||||
return image.load(cache_file)
|
||||
x = image_loader.load_float(cache_file)
|
||||
return x, alpha
|
||||
else
|
||||
if options.style == "art" then
|
||||
if options.border then
|
||||
x = alpha_util.make_border(x, alpha_orig, reconstruct.offset_size(art_scale2_model))
|
||||
end
|
||||
if options.method == "scale" then
|
||||
x = reconstruct.scale(art_scale2_model, 2.0, x)
|
||||
if alpha then
|
||||
if not (alpha:size(2) == x:size(2) and alpha:size(3) == x:size(3)) then
|
||||
alpha = reconstruct.scale(art_scale2_model, 2.0, alpha)
|
||||
image_loader.save_png(alpha_cache_file, alpha)
|
||||
end
|
||||
end
|
||||
cleanup_model(art_scale2_model)
|
||||
elseif options.method == "noise1" then
|
||||
x = reconstruct.image(art_noise1_model, x)
|
||||
|
@ -121,8 +144,17 @@ local function convert(x, options)
|
|||
cleanup_model(art_noise2_model)
|
||||
end
|
||||
else --[[photo
|
||||
if options.border then
|
||||
x = alpha_util.make_border(x, alpha, reconstruct.offset_size(photo_scale2_model))
|
||||
end
|
||||
if options.method == "scale" then
|
||||
x = reconstruct.scale(photo_scale2_model, 2.0, x)
|
||||
if alpha then
|
||||
if not (alpha:size(2) == x:size(2) and alpha:size(3) == x:size(3)) then
|
||||
alpha = reconstruct.scale(photo_scale2_model, 2.0, alpha)
|
||||
image_loader.save_png(alpha_cache_file, alpha)
|
||||
end
|
||||
end
|
||||
cleanup_model(photo_scale2_model)
|
||||
elseif options.method == "noise1" then
|
||||
x = reconstruct.image(photo_noise1_model, x)
|
||||
|
@ -133,8 +165,9 @@ local function convert(x, options)
|
|||
end
|
||||
--]]
|
||||
end
|
||||
image.save(cache_file, x)
|
||||
return x
|
||||
image_loader.save_png(cache_file, x)
|
||||
|
||||
return x, alpha
|
||||
end
|
||||
end
|
||||
local function client_disconnected(handler)
|
||||
|
@ -154,7 +187,6 @@ function APIHandler:post()
|
|||
local x, alpha, blob = get_image(self)
|
||||
local scale = tonumber(self:get_argument("scale", "0"))
|
||||
local noise = tonumber(self:get_argument("noise", "0"))
|
||||
local white_noise = tonumber(self:get_argument("white_noise", "0"))
|
||||
local style = self:get_argument("style", "art")
|
||||
local download = (self:get_argument("download", "")):len()
|
||||
|
||||
|
@ -164,29 +196,39 @@ function APIHandler:post()
|
|||
if x and valid_size(x, scale) then
|
||||
if (noise ~= 0 or scale ~= 0) then
|
||||
local hash = md5.sumhexa(blob)
|
||||
local alpha_prefix = style .. "_" .. hash .. "_alpha"
|
||||
local border = false
|
||||
if scale ~= 0 and alpha then
|
||||
border = true
|
||||
end
|
||||
if noise == 1 then
|
||||
x = convert(x, {method = "noise1", style = style, prefix = style .. "_noise1_" .. hash})
|
||||
x = convert(x, alpha, {method = "noise1", style = style,
|
||||
prefix = style .. "_noise1_" .. hash,
|
||||
alpha_prefix = alpha_prefix, border = border})
|
||||
border = false
|
||||
elseif noise == 2 then
|
||||
x = convert(x, {method = "noise2", style = style, prefix = style .. "_noise2_" .. hash})
|
||||
x = convert(x, alpha, {method = "noise2", style = style,
|
||||
prefix = style .. "_noise2_" .. hash,
|
||||
alpha_prefix = alpha_prefix, border = border})
|
||||
border = false
|
||||
end
|
||||
if scale == 1 or scale == 2 then
|
||||
local prefix
|
||||
if noise == 1 then
|
||||
x = convert(x, {method = "scale", style = style, prefix = style .. "_noise1_scale_" .. hash})
|
||||
prefix = style .. "_noise1_scale_" .. hash
|
||||
elseif noise == 2 then
|
||||
x = convert(x, {method = "scale", style = style, prefix = style .. "_noise2_scale_" .. hash})
|
||||
prefix = style .. "_noise2_scale_" .. hash
|
||||
else
|
||||
x = convert(x, {method = "scale", style = style, prefix = style .. "_scale_" .. hash})
|
||||
prefix = style .. "_scale_" .. hash
|
||||
end
|
||||
x, alpha = convert(x, alpha, {method = "scale", style = style, prefix = prefix, alpha_prefix = alpha_prefix, border = border})
|
||||
if scale == 1 then
|
||||
x = iproc.scale(x, x:size(3) * (1.6 / 2.0), x:size(2) * (1.6 / 2.0), "Sinc")
|
||||
end
|
||||
end
|
||||
if white_noise == 1 then
|
||||
x = iproc.white_noise(x, 0.005, {1.0, 0.8, 1.0})
|
||||
end
|
||||
end
|
||||
local name = uuid() .. ".png"
|
||||
local blob = image_loader.encode_png(x, alpha)
|
||||
local blob = image_loader.encode_png(alpha_util.composite(x, alpha))
|
||||
self:set_header("Content-Disposition", string.format('filename="%s"', name))
|
||||
self:set_header("Content-Length", string.format("%d", #blob))
|
||||
if download > 0 then
|
||||
|
|
Loading…
Reference in a new issue