1
0
Fork 0
mirror of synced 2024-06-01 10:39:30 +12:00

Merge branch 'master' into photo

This commit is contained in:
nagadomi 2015-12-02 06:54:37 +09:00
commit 2305e31616
6 changed files with 235 additions and 76 deletions

80
lib/alpha_util.lua Normal file
View file

@ -0,0 +1,80 @@
local w2nn = require 'w2nn'
local reconstruct = require 'reconstruct'
local image = require 'image'
local iproc = require 'iproc'
local gm = require 'graphicsmagick'
alpha_util = {}
function alpha_util.make_border(rgb, alpha, offset)
if not alpha then
return rgb
end
local sum2d = nn.SpatialConvolutionMM(1, 1, 3, 3, 1, 1, 1, 1):cuda()
sum2d.weight:fill(1)
sum2d.bias:zero()
local mask = alpha:clone()
mask[torch.gt(mask, 0.0)] = 1
mask[torch.eq(mask, 0.0)] = 0
local mask_nega = (mask - 1):abs():byte()
local eps = 1.0e-7
rgb = rgb:clone()
rgb[1][mask_nega] = 0
rgb[2][mask_nega] = 0
rgb[3][mask_nega] = 0
for i = 1, offset do
local mask_weight = sum2d:forward(mask:cuda()):float()
local border = rgb:clone()
for j = 1, 3 do
border[j]:copy(sum2d:forward(rgb[j]:reshape(1, rgb:size(2), rgb:size(3)):cuda()))
border[j]:cdiv((mask_weight + eps))
rgb[j][mask_nega] = border[j][mask_nega]
end
mask = mask_weight:clone()
mask[torch.gt(mask_weight, 0.0)] = 1
mask_nega = (mask - 1):abs():byte()
end
rgb[torch.gt(rgb, 1.0)] = 1.0
rgb[torch.lt(rgb, 0.0)] = 0.0
return rgb
end
function alpha_util.composite(rgb, alpha, model2x)
if not alpha then
return rgb
end
if not (alpha:size(2) == rgb:size(2) and alpha:size(3) == rgb:size(3)) then
if model2x then
alpha = reconstruct.scale(model2x, 2.0, alpha)
else
alpha = gm.Image(alpha, "I", "DHW"):size(rgb:size(3), rgb:size(2), "Sinc"):toTensor("float", "I", "DHW")
end
end
local out = torch.Tensor(4, rgb:size(2), rgb:size(3))
out[1]:copy(rgb[1])
out[2]:copy(rgb[2])
out[3]:copy(rgb[3])
out[4]:copy(alpha)
return out
end
local function test()
require 'sys'
require 'trepl'
torch.setdefaulttensortype("torch.FloatTensor")
local image_loader = require 'image_loader'
local rgb, alpha = image_loader.load_float("alpha.png")
local t = sys.clock()
rgb = alpha_util.make_border(rgb, alpha, 7)
print(sys.clock() - t)
print(rgb:min(), rgb:max())
image.display({image = rgb, min = 0, max = 1})
image.save("out.png", rgb)
end
--test()
return alpha_util

View file

@ -9,47 +9,34 @@ local clip_eps8 = (1.0 / 255.0) * 0.5 - (1.0e-7 * (1.0 / 255.0) * 0.5)
local clip_eps16 = (1.0 / 65535.0) * 0.5 - (1.0e-7 * (1.0 / 65535.0) * 0.5)
local background_color = 0.5
function image_loader.encode_png(rgb, alpha, depth)
function image_loader.encode_png(rgb, depth)
depth = depth or 8
rgb = iproc.byte2float(rgb)
if alpha then
if not (alpha:size(2) == rgb:size(2) and alpha:size(3) == rgb:size(3)) then
alpha = gm.Image(alpha, "I", "DHW"):size(rgb:size(3), rgb:size(2), "Sinc"):toTensor("float", "I", "DHW")
end
local rgba = torch.Tensor(4, rgb:size(2), rgb:size(3))
rgba[1]:copy(rgb[1])
rgba[2]:copy(rgb[2])
rgba[3]:copy(rgb[3])
rgba[4]:copy(alpha)
if depth < 16 then
rgba:add(clip_eps8)
rgba[torch.lt(rgba, 0.0)] = 0.0
rgba[torch.gt(rgba, 1.0)] = 1.0
else
rgba:add(clip_eps16)
rgba[torch.lt(rgba, 0.0)] = 0.0
rgba[torch.gt(rgba, 1.0)] = 1.0
end
local im = gm.Image():fromTensor(rgba, "RGBA", "DHW")
return im:depth(depth):format("PNG"):toString(9)
if depth < 16 then
rgb = rgb:clone():add(clip_eps8)
rgb[torch.lt(rgb, 0.0)] = 0.0
rgb[torch.gt(rgb, 1.0)] = 1.0
rgb = rgb:mul(255):long():float():div(255)
else
if depth < 16 then
rgb = rgb:clone():add(clip_eps8)
rgb[torch.lt(rgb, 0.0)] = 0.0
rgb[torch.gt(rgb, 1.0)] = 1.0
else
rgb = rgb:clone():add(clip_eps16)
rgb[torch.lt(rgb, 0.0)] = 0.0
rgb[torch.gt(rgb, 1.0)] = 1.0
end
local im = gm.Image(rgb, "RGB", "DHW")
return im:depth(depth):format("PNG"):toString(9)
rgb = rgb:clone():add(clip_eps16)
rgb[torch.lt(rgb, 0.0)] = 0.0
rgb[torch.gt(rgb, 1.0)] = 1.0
rgb = rgb:mul(65535):long():float():div(65535)
end
local im
if rgb:size(1) == 4 then -- RGBA
im = gm.Image(rgb, "RGBA", "DHW")
elseif rgb:size(1) == 3 then -- RGB
im = gm.Image(rgb, "RGB", "DHW")
elseif rgb:size(1) == 1 then -- Y
im = gm.Image(rgb, "I", "DHW")
-- im:colorspace("GRAY") -- it does not work
end
return im:depth(depth):format("PNG"):toString(9)
end
function image_loader.save_png(filename, rgb, alpha, depth)
function image_loader.save_png(filename, rgb, depth)
depth = depth or 8
local blob = image_loader.encode_png(rgb, alpha, depth)
local blob = image_loader.encode_png(rgb, depth)
local fp = io.open(filename, "wb")
if not fp then
error("IO error: " .. filename)
@ -71,7 +58,10 @@ function image_loader.decode_float(blob)
end
local gamma = math.floor(im:gamma() * 1000000) / 1000000
if gamma ~= 0 and gamma ~= gamma_lcd then
im:gammaCorrection(gamma / gamma_lcd)
local cg = gamma / gamma_lcd
im:gammaCorrection(cg, "Red")
im:gammaCorrection(cg, "Blue")
im:gammaCorrection(cg, "Green")
end
-- FIXME: How to detect that a image has an alpha channel?
if blob:sub(1, 4) == "\x89PNG" or blob:sub(1, 3) == "GIF" then

View file

@ -49,12 +49,17 @@ function iproc.float2byte(src)
return dest, conversion
end
function iproc.scale(src, width, height, filter)
local conversion
local conversion, color
src, conversion = iproc.byte2float(src)
filter = filter or "Box"
local im = gm.Image(src, "RGB", "DHW")
if src:size(1) == 3 then
color = "RGB"
else
color = "I"
end
local im = gm.Image(src, color, "DHW")
im:size(math.ceil(width), math.ceil(height), filter)
local dest = im:toTensor("float", "RGB", "DHW")
local dest = im:toTensor("float", color, "DHW")
if conversion then
dest = iproc.float2byte(dest)
end
@ -84,6 +89,16 @@ function iproc.padding(img, w1, w2, h1, h2)
flow[2]:add(-w1)
return image.warp(img, flow, "simple", false, "clamp")
end
function iproc.zero_padding(img, w1, w2, h1, h2)
local dst_height = img:size(2) + h1 + h2
local dst_width = img:size(3) + w1 + w2
local flow = torch.Tensor(2, dst_height, dst_width)
flow[1] = torch.ger(torch.linspace(0, dst_height -1, dst_height), torch.ones(dst_width))
flow[2] = torch.ger(torch.ones(dst_height), torch.linspace(0, dst_width - 1, dst_width))
flow[1]:add(-h1)
flow[2]:add(-w1)
return image.warp(img, flow, "simple", false, "pad", 0)
end
function iproc.white_noise(src, std, rgb_weights, gamma)
gamma = gamma or 0.454545
local conversion

View file

@ -189,22 +189,48 @@ function reconstruct.scale_rgb(model, scale, x, offset, block_size)
end
function reconstruct.image(model, x, block_size)
if reconstruct.is_rgb(model) then
return reconstruct.image_rgb(model, x,
reconstruct.offset_size(model), block_size)
else
return reconstruct.image_y(model, x,
reconstruct.offset_size(model), block_size)
local i2rgb = false
if x:size(1) == 1 then
local new_x = torch.Tensor(3, x:size(2), x:size(3))
new_x[1]:copy(x)
new_x[2]:copy(x)
new_x[3]:copy(x)
x = new_x
i2rgb = true
end
if reconstruct.is_rgb(model) then
x = reconstruct.image_rgb(model, x,
reconstruct.offset_size(model), block_size)
else
x = reconstruct.image_y(model, x,
reconstruct.offset_size(model), block_size)
end
if i2rgb then
x = image.rgb2y(x)
end
return x
end
function reconstruct.scale(model, scale, x, block_size)
if reconstruct.is_rgb(model) then
return reconstruct.scale_rgb(model, scale, x,
reconstruct.offset_size(model), block_size)
else
return reconstruct.scale_y(model, scale, x,
reconstruct.offset_size(model), block_size)
local i2rgb = false
if x:size(1) == 1 then
local new_x = torch.Tensor(3, x:size(2), x:size(3))
new_x[1]:copy(x)
new_x[2]:copy(x)
new_x[3]:copy(x)
x = new_x
i2rgb = true
end
if reconstruct.is_rgb(model) then
x = reconstruct.scale_rgb(model, scale, x,
reconstruct.offset_size(model), block_size)
else
x = reconstruct.scale_y(model, scale, x,
reconstruct.offset_size(model), block_size)
end
if i2rgb then
x = image.rgb2y(x)
end
return x
end
local function tta(f, model, x, block_size)
local average = nil

View file

@ -6,6 +6,7 @@ require 'w2nn'
local iproc = require 'iproc'
local reconstruct = require 'reconstruct'
local image_loader = require 'image_loader'
local alpha_util = require 'alpha_util'
torch.setdefaulttensortype('torch.FloatTensor')
@ -14,6 +15,7 @@ local function convert_image(opt)
local new_x = nil
local t = sys.clock()
local scale_f, image_f
if opt.tta == 1 then
scale_f = reconstruct.scale_tta
image_f = reconstruct.image_tta
@ -34,13 +36,16 @@ local function convert_image(opt)
error("Load Error: " .. model_path)
end
new_x = image_f(model, x, opt.crop_size)
new_x = alpha_util.composite(new_x, alpha)
elseif opt.m == "scale" then
local model_path = path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale))
local model = torch.load(model_path, "ascii")
if not model then
error("Load Error: " .. model_path)
end
x = alpha_util.make_border(x, alpha, reconstruct.offset_size(model))
new_x = scale_f(model, opt.scale, x, opt.crop_size)
new_x = alpha_util.composite(new_x, alpha, model)
elseif opt.m == "noise_scale" then
local noise_model_path = path.join(opt.model_dir, ("noise%d_model.t7"):format(opt.noise_level))
local noise_model = torch.load(noise_model_path, "ascii")
@ -53,15 +58,14 @@ local function convert_image(opt)
if not scale_model then
error("Load Error: " .. scale_model_path)
end
x = alpha_util.make_border(x, alpha, reconstruct.offset_size(scale_model))
x = image_f(noise_model, x, opt.crop_size)
new_x = scale_f(scale_model, opt.scale, x, opt.crop_size)
new_x = alpha_util.composite(new_x, alpha, scale_model)
else
error("undefined method:" .. opt.method)
end
if opt.white_noise == 1 then
new_x = iproc.white_noise(new_x, opt.white_noise_std, {1.0, 0.8, 1.0})
end
image_loader.save_png(opt.o, new_x, alpha, opt.depth)
image_loader.save_png(opt.o, new_x, opt.depth)
print(opt.o .. ": " .. (sys.clock() - t) .. " sec")
end
local function convert_frames(opt)
@ -128,23 +132,27 @@ local function convert_frames(opt)
local new_x = nil
if opt.m == "noise" and opt.noise_level == 1 then
new_x = image_f(noise1_model, x, opt.crop_size)
new_x = alpha_util.composite(new_x, alpha)
elseif opt.m == "noise" and opt.noise_level == 2 then
new_x = image_func(noise2_model, x, opt.crop_size)
new_x = alpha_util.composite(new_x, alpha)
elseif opt.m == "scale" then
x = alpha_util.make_border(x, alpha, reconstruct.offset_size(scale_model))
new_x = scale_f(scale_model, opt.scale, x, opt.crop_size)
new_x = alpha_util.composite(new_x, alpha, scale_model)
elseif opt.m == "noise_scale" and opt.noise_level == 1 then
x = alpha_util.make_border(x, alpha, reconstruct.offset_size(scale_model))
x = image_f(noise1_model, x, opt.crop_size)
new_x = scale_f(scale_model, opt.scale, x, opt.crop_size)
new_x = alpha_util.composite(new_x, alpha, scale_model)
elseif opt.m == "noise_scale" and opt.noise_level == 2 then
x = alpha_util.make_border(x, alpha, reconstruct.offset_size(scale_model))
x = image_f(noise2_model, x, opt.crop_size)
new_x = scale_f(scale_model, opt.scale, x, opt.crop_size)
new_x = alpha_util.composite(new_x, alpha, scale_model)
else
error("undefined method:" .. opt.method)
end
if opt.white_noise == 1 then
new_x = iproc.white_noise(new_x, opt.white_noise_std, {1.0, 0.8, 1.0})
end
local output = nil
if opt.o == "(auto)" then
local name = path.basename(lines[i])
@ -154,7 +162,7 @@ local function convert_frames(opt)
else
output = string.format(opt.o, i)
end
image_loader.save_png(output, new_x, alpha, opt.depth)
image_loader.save_png(output, new_x, opt.depth)
xlua.progress(i, #lines)
if i % 10 == 0 then
collectgarbage()
@ -182,8 +190,6 @@ local function waifu2x()
cmd:option("-resume", 0, "skip existing files (0|1)")
cmd:option("-thread", -1, "number of CPU threads")
cmd:option("-tta", 0, '8x slower and slightly high quality (0|1)')
cmd:option("-white_noise", 0, 'adding white noise to output image (0|1)')
cmd:option("-white_noise_std", 0.0055, 'standard division of white noise')
local opt = cmd:parse(arg)
if opt.thread > 0 then

70
web.lua
View file

@ -11,6 +11,7 @@ local md5 = require 'md5'
local iproc = require 'iproc'
local reconstruct = require 'reconstruct'
local image_loader = require 'image_loader'
local alpha_util = require 'alpha_util'
-- Notes: turbo and xlua has different implementation of string:split().
-- Therefore, string:split() has conflict issue.
@ -104,14 +105,36 @@ local function cleanup_model(model)
w2nn.cleanup_model(model) -- release GPU memory
end
end
local function convert(x, options)
local function convert(x, alpha, options)
local cache_file = path.join(CACHE_DIR, options.prefix .. ".png")
local alpha_cache_file = path.join(CACHE_DIR, options.alpha_prefix .. ".png")
local alpha_orig = alpha
if path.exists(alpha_cache_file) then
alpha = image_loader.load_float(alpha_cache_file)
if alpha:dim() == 2 then
alpha = alpha:reshape(1, alpha:size(1), alpha:size(2))
end
if alpha:size(1) == 3 then
alpha = image.rgb2y(alpha)
end
end
if path.exists(cache_file) then
return image.load(cache_file)
x = image_loader.load_float(cache_file)
return x, alpha
else
if options.style == "art" then
if options.border then
x = alpha_util.make_border(x, alpha_orig, reconstruct.offset_size(art_scale2_model))
end
if options.method == "scale" then
x = reconstruct.scale(art_scale2_model, 2.0, x)
if alpha then
if not (alpha:size(2) == x:size(2) and alpha:size(3) == x:size(3)) then
alpha = reconstruct.scale(art_scale2_model, 2.0, alpha)
image_loader.save_png(alpha_cache_file, alpha)
end
end
cleanup_model(art_scale2_model)
elseif options.method == "noise1" then
x = reconstruct.image(art_noise1_model, x)
@ -121,8 +144,17 @@ local function convert(x, options)
cleanup_model(art_noise2_model)
end
else --[[photo
if options.border then
x = alpha_util.make_border(x, alpha, reconstruct.offset_size(photo_scale2_model))
end
if options.method == "scale" then
x = reconstruct.scale(photo_scale2_model, 2.0, x)
if alpha then
if not (alpha:size(2) == x:size(2) and alpha:size(3) == x:size(3)) then
alpha = reconstruct.scale(photo_scale2_model, 2.0, alpha)
image_loader.save_png(alpha_cache_file, alpha)
end
end
cleanup_model(photo_scale2_model)
elseif options.method == "noise1" then
x = reconstruct.image(photo_noise1_model, x)
@ -133,8 +165,9 @@ local function convert(x, options)
end
--]]
end
image.save(cache_file, x)
return x
image_loader.save_png(cache_file, x)
return x, alpha
end
end
local function client_disconnected(handler)
@ -154,7 +187,6 @@ function APIHandler:post()
local x, alpha, blob = get_image(self)
local scale = tonumber(self:get_argument("scale", "0"))
local noise = tonumber(self:get_argument("noise", "0"))
local white_noise = tonumber(self:get_argument("white_noise", "0"))
local style = self:get_argument("style", "art")
local download = (self:get_argument("download", "")):len()
@ -164,29 +196,39 @@ function APIHandler:post()
if x and valid_size(x, scale) then
if (noise ~= 0 or scale ~= 0) then
local hash = md5.sumhexa(blob)
local alpha_prefix = style .. "_" .. hash .. "_alpha"
local border = false
if scale ~= 0 and alpha then
border = true
end
if noise == 1 then
x = convert(x, {method = "noise1", style = style, prefix = style .. "_noise1_" .. hash})
x = convert(x, alpha, {method = "noise1", style = style,
prefix = style .. "_noise1_" .. hash,
alpha_prefix = alpha_prefix, border = border})
border = false
elseif noise == 2 then
x = convert(x, {method = "noise2", style = style, prefix = style .. "_noise2_" .. hash})
x = convert(x, alpha, {method = "noise2", style = style,
prefix = style .. "_noise2_" .. hash,
alpha_prefix = alpha_prefix, border = border})
border = false
end
if scale == 1 or scale == 2 then
local prefix
if noise == 1 then
x = convert(x, {method = "scale", style = style, prefix = style .. "_noise1_scale_" .. hash})
prefix = style .. "_noise1_scale_" .. hash
elseif noise == 2 then
x = convert(x, {method = "scale", style = style, prefix = style .. "_noise2_scale_" .. hash})
prefix = style .. "_noise2_scale_" .. hash
else
x = convert(x, {method = "scale", style = style, prefix = style .. "_scale_" .. hash})
prefix = style .. "_scale_" .. hash
end
x, alpha = convert(x, alpha, {method = "scale", style = style, prefix = prefix, alpha_prefix = alpha_prefix, border = border})
if scale == 1 then
x = iproc.scale(x, x:size(3) * (1.6 / 2.0), x:size(2) * (1.6 / 2.0), "Sinc")
end
end
if white_noise == 1 then
x = iproc.white_noise(x, 0.005, {1.0, 0.8, 1.0})
end
end
local name = uuid() .. ".png"
local blob = image_loader.encode_png(x, alpha)
local blob = image_loader.encode_png(alpha_util.composite(x, alpha))
self:set_header("Content-Disposition", string.format('filename="%s"', name))
self:set_header("Content-Length", string.format("%d", #blob))
if download > 0 then