Fix embed gamma handling
This commit is contained in:
parent
cafc579627
commit
fbad30c031
4 changed files with 61 additions and 55 deletions
|
@ -33,8 +33,8 @@ local function load_images(list)
|
|||
local x = {}
|
||||
for i = 1, #lines do
|
||||
local line = lines[i]
|
||||
local im, alpha = image_loader.load_byte(line)
|
||||
if alpha then
|
||||
local im, meta = image_loader.load_byte(line)
|
||||
if meta and meta.alpha then
|
||||
io.stderr:write(string.format("\n%s: skip: image has alpha channel.\n", line))
|
||||
else
|
||||
if settings.max_training_image_size > 0 then
|
||||
|
|
|
@ -9,14 +9,15 @@ local clip_eps8 = (1.0 / 255.0) * 0.5 - (1.0e-7 * (1.0 / 255.0) * 0.5)
|
|||
local clip_eps16 = (1.0 / 65535.0) * 0.5 - (1.0e-7 * (1.0 / 65535.0) * 0.5)
|
||||
local background_color = 0.5
|
||||
|
||||
function image_loader.encode_png(rgb, depth, inplace)
|
||||
if inplace == nil then
|
||||
inplace = false
|
||||
function image_loader.encode_png(rgb, options)
|
||||
options = options or {}
|
||||
options.depth = options.depth or 8
|
||||
if options.inplace == nil then
|
||||
options.inplace = false
|
||||
end
|
||||
depth = depth or 8
|
||||
rgb = iproc.byte2float(rgb)
|
||||
if depth < 16 then
|
||||
if inplace then
|
||||
if options.depth < 16 then
|
||||
if options.inplace then
|
||||
rgb:add(clip_eps8)
|
||||
else
|
||||
rgb = rgb:clone():add(clip_eps8)
|
||||
|
@ -25,7 +26,7 @@ function image_loader.encode_png(rgb, depth, inplace)
|
|||
rgb[torch.gt(rgb, 1.0)] = 1.0
|
||||
rgb = rgb:mul(255):floor():div(255)
|
||||
else
|
||||
if inplace then
|
||||
if options.inplace then
|
||||
rgb:add(clip_eps16)
|
||||
else
|
||||
rgb = rgb:clone():add(clip_eps16)
|
||||
|
@ -43,10 +44,13 @@ function image_loader.encode_png(rgb, depth, inplace)
|
|||
im = gm.Image(rgb, "I", "DHW")
|
||||
-- im:colorspace("GRAY") -- it does not work
|
||||
end
|
||||
return im:depth(depth):format("PNG"):toString(9)
|
||||
if options.gamma then
|
||||
im:gamma(options.gamma)
|
||||
end
|
||||
return im:depth(options.depth):format("PNG"):toString(9)
|
||||
end
|
||||
function image_loader.save_png(filename, rgb, depth, inplace)
|
||||
local blob = image_loader.encode_png(rgb, depth, inplace)
|
||||
function image_loader.save_png(filename, rgb, options)
|
||||
local blob = image_loader.encode_png(rgb, options)
|
||||
local fp = io.open(filename, "wb")
|
||||
if not fp then
|
||||
error("IO error: " .. filename)
|
||||
|
@ -57,8 +61,8 @@ function image_loader.save_png(filename, rgb, depth, inplace)
|
|||
end
|
||||
function image_loader.decode_float(blob)
|
||||
local load_image = function()
|
||||
local meta = {}
|
||||
local im = gm.Image()
|
||||
local alpha = nil
|
||||
local gamma_lcd = 0.454545
|
||||
|
||||
im:fromBlob(blob, #blob)
|
||||
|
@ -66,12 +70,8 @@ function image_loader.decode_float(blob)
|
|||
if im:colorspace() == "CMYK" then
|
||||
im:colorspace("RGB")
|
||||
end
|
||||
local gamma = math.floor(im:gamma() * 1000000) / 1000000
|
||||
if gamma ~= 0 and gamma ~= gamma_lcd then
|
||||
local cg = gamma / gamma_lcd
|
||||
im:gammaCorrection(cg, "Red")
|
||||
im:gammaCorrection(cg, "Blue")
|
||||
im:gammaCorrection(cg, "Green")
|
||||
if gamma ~= 0 and math.floor(im:gamma() * 1000000) / 1000000 ~= gamma_lcd then
|
||||
meta.gamma = im:gamma()
|
||||
end
|
||||
-- FIXME: How to detect that a image has an alpha channel?
|
||||
if blob:sub(1, 4) == "\x89PNG" or blob:sub(1, 3) == "GIF" then
|
||||
|
@ -79,9 +79,9 @@ function image_loader.decode_float(blob)
|
|||
im = im:toTensor('float', 'RGBA', 'DHW')
|
||||
local sum_alpha = (im[4] - 1.0):sum()
|
||||
if sum_alpha < 0 then
|
||||
alpha = im[4]:reshape(1, im:size(2), im:size(3))
|
||||
meta.alpha = im[4]:reshape(1, im:size(2), im:size(3))
|
||||
-- drop full transparent background
|
||||
local mask = torch.le(alpha, 0.0)
|
||||
local mask = torch.le(meta.alpha, 0.0)
|
||||
im[1][mask] = background_color
|
||||
im[2][mask] = background_color
|
||||
im[3][mask] = background_color
|
||||
|
@ -94,25 +94,26 @@ function image_loader.decode_float(blob)
|
|||
else
|
||||
im = im:toTensor('float', 'RGB', 'DHW')
|
||||
end
|
||||
return {im, alpha, blob}
|
||||
meta.blob = blob
|
||||
return {im, meta}
|
||||
end
|
||||
local state, ret = pcall(load_image)
|
||||
if state then
|
||||
return ret[1], ret[2], ret[3]
|
||||
return ret[1], ret[2]
|
||||
else
|
||||
return nil, nil, nil
|
||||
return nil, nil
|
||||
end
|
||||
end
|
||||
function image_loader.decode_byte(blob)
|
||||
local im, alpha
|
||||
im, alpha, blob = image_loader.decode_float(blob)
|
||||
local im, meta
|
||||
im, meta = image_loader.decode_float(blob)
|
||||
|
||||
if im then
|
||||
im = iproc.float2byte(im)
|
||||
-- hmm, alpha does not convert here
|
||||
return im, alpha, blob
|
||||
return im, meta
|
||||
else
|
||||
return nil, nil, nil
|
||||
return nil, nil
|
||||
end
|
||||
end
|
||||
function image_loader.load_float(file)
|
||||
|
|
11
waifu2x.lua
11
waifu2x.lua
|
@ -11,7 +11,8 @@ local alpha_util = require 'alpha_util'
|
|||
torch.setdefaulttensortype('torch.FloatTensor')
|
||||
|
||||
local function convert_image(opt)
|
||||
local x, alpha = image_loader.load_float(opt.i)
|
||||
local x, meta = image_loader.load_float(opt.i)
|
||||
local alpha = meta.alpha
|
||||
local new_x = nil
|
||||
local t = sys.clock()
|
||||
local scale_f, image_f
|
||||
|
@ -65,7 +66,7 @@ local function convert_image(opt)
|
|||
else
|
||||
error("undefined method:" .. opt.method)
|
||||
end
|
||||
image_loader.save_png(opt.o, new_x, opt.depth, true)
|
||||
image_loader.save_png(opt.o, new_x, {depth = opt.depth, inplace = true, gamma = meta.gamma})
|
||||
print(opt.o .. ": " .. (sys.clock() - t) .. " sec")
|
||||
end
|
||||
local function convert_frames(opt)
|
||||
|
@ -115,7 +116,8 @@ local function convert_frames(opt)
|
|||
fp:close()
|
||||
for i = 1, #lines do
|
||||
if opt.resume == 0 or path.exists(string.format(opt.o, i)) == false then
|
||||
local x, alpha = image_loader.load_float(lines[i])
|
||||
local x, meta = image_loader.load_float(lines[i])
|
||||
local alpha = meta.alpha
|
||||
local new_x = nil
|
||||
if opt.m == "noise" then
|
||||
new_x = image_f(noise_model[opt.noise_level], x, opt.crop_size)
|
||||
|
@ -141,7 +143,8 @@ local function convert_frames(opt)
|
|||
else
|
||||
output = string.format(opt.o, i)
|
||||
end
|
||||
image_loader.save_png(output, new_x, opt.depth, true)
|
||||
image_loader.save_png(output, new_x,
|
||||
{depth = opt.depth, inplace = true, gamma = meta.gamma})
|
||||
xlua.progress(i, #lines)
|
||||
if i % 10 == 0 then
|
||||
collectgarbage()
|
||||
|
|
36
web.lua
36
web.lua
|
@ -93,7 +93,7 @@ local function cache_url(url)
|
|||
end
|
||||
end
|
||||
end
|
||||
return nil, nil, nil
|
||||
return nil, nil
|
||||
end
|
||||
local function get_image(req)
|
||||
local file_info = req:get_arguments("file")
|
||||
|
@ -108,22 +108,23 @@ local function get_image(req)
|
|||
end
|
||||
end
|
||||
if file and file:len() > 0 then
|
||||
local x, alpha, blob = image_loader.decode_float(file)
|
||||
return x, alpha, blob, filename
|
||||
local x, meta = image_loader.decode_float(file)
|
||||
return x, meta, filename
|
||||
elseif url and url:len() > 0 then
|
||||
local x, alpha, blob = cache_url(url)
|
||||
return x, alpha, blob, filename
|
||||
local x, meta = cache_url(url)
|
||||
return x, meta, filename
|
||||
end
|
||||
return nil, nil, nil, nil
|
||||
return nil, nil, nil
|
||||
end
|
||||
local function cleanup_model(model)
|
||||
if CLEANUP_MODEL then
|
||||
model:clearState() -- release GPU memory
|
||||
end
|
||||
end
|
||||
local function convert(x, alpha, options)
|
||||
local function convert(x, meta, options)
|
||||
local cache_file = path.join(CACHE_DIR, options.prefix .. ".png")
|
||||
local alpha_cache_file = path.join(CACHE_DIR, options.alpha_prefix .. ".png")
|
||||
local alpha = meta.alpha
|
||||
local alpha_orig = alpha
|
||||
|
||||
if path.exists(alpha_cache_file) then
|
||||
|
@ -137,7 +138,7 @@ local function convert(x, alpha, options)
|
|||
end
|
||||
if path.exists(cache_file) then
|
||||
x = image_loader.load_float(cache_file)
|
||||
return x, alpha
|
||||
return x, {alpha = alpha, gamma = meta.gamma, blob = meta.blob}
|
||||
else
|
||||
if options.style == "art" then
|
||||
if options.border then
|
||||
|
@ -192,7 +193,7 @@ local function convert(x, alpha, options)
|
|||
end
|
||||
image_loader.save_png(cache_file, x)
|
||||
|
||||
return x, alpha
|
||||
return x, {alpha = alpha, gamma = meta.gamma, blob = meta.blob}
|
||||
end
|
||||
end
|
||||
local function client_disconnected(handler)
|
||||
|
@ -218,7 +219,7 @@ function APIHandler:post()
|
|||
self:write("client disconnected")
|
||||
return
|
||||
end
|
||||
local x, alpha, blob, filename = get_image(self)
|
||||
local x, meta, filename = get_image(self)
|
||||
local scale = tonumber(self:get_argument("scale", "0"))
|
||||
local noise = tonumber(self:get_argument("noise", "0"))
|
||||
local style = self:get_argument("style", "art")
|
||||
|
@ -230,27 +231,27 @@ function APIHandler:post()
|
|||
if x and valid_size(x, scale) then
|
||||
local prefix = nil
|
||||
if (noise ~= 0 or scale ~= 0) then
|
||||
local hash = md5.sumhexa(blob)
|
||||
local hash = md5.sumhexa(meta.blob)
|
||||
local alpha_prefix = style .. "_" .. hash .. "_alpha"
|
||||
local border = false
|
||||
if scale ~= 0 and alpha then
|
||||
if scale ~= 0 and meta.alpha then
|
||||
border = true
|
||||
end
|
||||
if noise == 1 then
|
||||
prefix = style .. "_noise1_"
|
||||
x = convert(x, alpha, {method = "noise1", style = style,
|
||||
x = convert(x, meta, {method = "noise1", style = style,
|
||||
prefix = prefix .. hash,
|
||||
alpha_prefix = alpha_prefix, border = border})
|
||||
border = false
|
||||
elseif noise == 2 then
|
||||
prefix = style .. "_noise2_"
|
||||
x = convert(x, alpha, {method = "noise2", style = style,
|
||||
x = convert(x, meta, {method = "noise2", style = style,
|
||||
prefix = prefix .. hash,
|
||||
alpha_prefix = alpha_prefix, border = border})
|
||||
border = false
|
||||
elseif noise == 3 then
|
||||
prefix = style .. "_noise3_"
|
||||
x = convert(x, alpha, {method = "noise3", style = style,
|
||||
x = convert(x, meta, {method = "noise3", style = style,
|
||||
prefix = prefix .. hash,
|
||||
alpha_prefix = alpha_prefix, border = border})
|
||||
border = false
|
||||
|
@ -265,7 +266,7 @@ function APIHandler:post()
|
|||
else
|
||||
prefix = style .. "_scale_"
|
||||
end
|
||||
x, alpha = convert(x, alpha, {method = "scale", style = style, prefix = prefix .. hash, alpha_prefix = alpha_prefix, border = border})
|
||||
x, meta = convert(x, meta, {method = "scale", style = style, prefix = prefix .. hash, alpha_prefix = alpha_prefix, border = border})
|
||||
if scale == 1 then
|
||||
x = iproc.scale(x, x:size(3) * (1.6 / 2.0), x:size(2) * (1.6 / 2.0), "Sinc")
|
||||
end
|
||||
|
@ -281,7 +282,8 @@ function APIHandler:post()
|
|||
else
|
||||
name = uuid() .. ".png"
|
||||
end
|
||||
local blob = image_loader.encode_png(alpha_util.composite(x, alpha), 8, true)
|
||||
local blob = image_loader.encode_png(alpha_util.composite(x, meta.alpha),
|
||||
{ depth = 8, inplace = true, gamma = meta.gamma})
|
||||
|
||||
self:set_header("Content-Length", string.format("%d", #blob))
|
||||
if download > 0 then
|
||||
|
|
Loading…
Reference in a new issue