1
0
Fork 0
mirror of synced 2024-06-18 19:04:30 +12:00

Fix embed gamma handling

This commit is contained in:
nagadomi 2016-04-15 09:13:37 +09:00
parent cafc579627
commit fbad30c031
4 changed files with 61 additions and 55 deletions

View file

@ -33,8 +33,8 @@ local function load_images(list)
local x = {}
for i = 1, #lines do
local line = lines[i]
local im, alpha = image_loader.load_byte(line)
if alpha then
local im, meta = image_loader.load_byte(line)
if meta and meta.alpha then
io.stderr:write(string.format("\n%s: skip: image has alpha channel.\n", line))
else
if settings.max_training_image_size > 0 then

View file

@ -9,14 +9,15 @@ local clip_eps8 = (1.0 / 255.0) * 0.5 - (1.0e-7 * (1.0 / 255.0) * 0.5)
local clip_eps16 = (1.0 / 65535.0) * 0.5 - (1.0e-7 * (1.0 / 65535.0) * 0.5)
local background_color = 0.5
function image_loader.encode_png(rgb, depth, inplace)
if inplace == nil then
inplace = false
function image_loader.encode_png(rgb, options)
options = options or {}
options.depth = options.depth or 8
if options.inplace == nil then
options.inplace = false
end
depth = depth or 8
rgb = iproc.byte2float(rgb)
if depth < 16 then
if inplace then
if options.depth < 16 then
if options.inplace then
rgb:add(clip_eps8)
else
rgb = rgb:clone():add(clip_eps8)
@ -25,7 +26,7 @@ function image_loader.encode_png(rgb, depth, inplace)
rgb[torch.gt(rgb, 1.0)] = 1.0
rgb = rgb:mul(255):floor():div(255)
else
if inplace then
if options.inplace then
rgb:add(clip_eps16)
else
rgb = rgb:clone():add(clip_eps16)
@ -43,10 +44,13 @@ function image_loader.encode_png(rgb, depth, inplace)
im = gm.Image(rgb, "I", "DHW")
-- im:colorspace("GRAY") -- it does not work
end
return im:depth(depth):format("PNG"):toString(9)
if options.gamma then
im:gamma(options.gamma)
end
return im:depth(options.depth):format("PNG"):toString(9)
end
function image_loader.save_png(filename, rgb, depth, inplace)
local blob = image_loader.encode_png(rgb, depth, inplace)
function image_loader.save_png(filename, rgb, options)
local blob = image_loader.encode_png(rgb, options)
local fp = io.open(filename, "wb")
if not fp then
error("IO error: " .. filename)
@ -57,8 +61,8 @@ function image_loader.save_png(filename, rgb, depth, inplace)
end
function image_loader.decode_float(blob)
local load_image = function()
local meta = {}
local im = gm.Image()
local alpha = nil
local gamma_lcd = 0.454545
im:fromBlob(blob, #blob)
@ -66,12 +70,8 @@ function image_loader.decode_float(blob)
if im:colorspace() == "CMYK" then
im:colorspace("RGB")
end
local gamma = math.floor(im:gamma() * 1000000) / 1000000
if gamma ~= 0 and gamma ~= gamma_lcd then
local cg = gamma / gamma_lcd
im:gammaCorrection(cg, "Red")
im:gammaCorrection(cg, "Blue")
im:gammaCorrection(cg, "Green")
if gamma ~= 0 and math.floor(im:gamma() * 1000000) / 1000000 ~= gamma_lcd then
meta.gamma = im:gamma()
end
-- FIXME: How to detect that a image has an alpha channel?
if blob:sub(1, 4) == "\x89PNG" or blob:sub(1, 3) == "GIF" then
@ -79,9 +79,9 @@ function image_loader.decode_float(blob)
im = im:toTensor('float', 'RGBA', 'DHW')
local sum_alpha = (im[4] - 1.0):sum()
if sum_alpha < 0 then
alpha = im[4]:reshape(1, im:size(2), im:size(3))
meta.alpha = im[4]:reshape(1, im:size(2), im:size(3))
-- drop full transparent background
local mask = torch.le(alpha, 0.0)
local mask = torch.le(meta.alpha, 0.0)
im[1][mask] = background_color
im[2][mask] = background_color
im[3][mask] = background_color
@ -94,25 +94,26 @@ function image_loader.decode_float(blob)
else
im = im:toTensor('float', 'RGB', 'DHW')
end
return {im, alpha, blob}
meta.blob = blob
return {im, meta}
end
local state, ret = pcall(load_image)
if state then
return ret[1], ret[2], ret[3]
return ret[1], ret[2]
else
return nil, nil, nil
return nil, nil
end
end
function image_loader.decode_byte(blob)
local im, alpha
im, alpha, blob = image_loader.decode_float(blob)
local im, meta
im, meta = image_loader.decode_float(blob)
if im then
im = iproc.float2byte(im)
-- hmm, alpha does not convert here
return im, alpha, blob
return im, meta
else
return nil, nil, nil
return nil, nil
end
end
function image_loader.load_float(file)

View file

@ -11,7 +11,8 @@ local alpha_util = require 'alpha_util'
torch.setdefaulttensortype('torch.FloatTensor')
local function convert_image(opt)
local x, alpha = image_loader.load_float(opt.i)
local x, meta = image_loader.load_float(opt.i)
local alpha = meta.alpha
local new_x = nil
local t = sys.clock()
local scale_f, image_f
@ -65,7 +66,7 @@ local function convert_image(opt)
else
error("undefined method:" .. opt.method)
end
image_loader.save_png(opt.o, new_x, opt.depth, true)
image_loader.save_png(opt.o, new_x, {depth = opt.depth, inplace = true, gamma = meta.gamma})
print(opt.o .. ": " .. (sys.clock() - t) .. " sec")
end
local function convert_frames(opt)
@ -115,7 +116,8 @@ local function convert_frames(opt)
fp:close()
for i = 1, #lines do
if opt.resume == 0 or path.exists(string.format(opt.o, i)) == false then
local x, alpha = image_loader.load_float(lines[i])
local x, meta = image_loader.load_float(lines[i])
local alpha = meta.alpha
local new_x = nil
if opt.m == "noise" then
new_x = image_f(noise_model[opt.noise_level], x, opt.crop_size)
@ -141,7 +143,8 @@ local function convert_frames(opt)
else
output = string.format(opt.o, i)
end
image_loader.save_png(output, new_x, opt.depth, true)
image_loader.save_png(output, new_x,
{depth = opt.depth, inplace = true, gamma = meta.gamma})
xlua.progress(i, #lines)
if i % 10 == 0 then
collectgarbage()

48
web.lua
View file

@ -93,7 +93,7 @@ local function cache_url(url)
end
end
end
return nil, nil, nil
return nil, nil
end
local function get_image(req)
local file_info = req:get_arguments("file")
@ -108,22 +108,23 @@ local function get_image(req)
end
end
if file and file:len() > 0 then
local x, alpha, blob = image_loader.decode_float(file)
return x, alpha, blob, filename
local x, meta = image_loader.decode_float(file)
return x, meta, filename
elseif url and url:len() > 0 then
local x, alpha, blob = cache_url(url)
return x, alpha, blob, filename
local x, meta = cache_url(url)
return x, meta, filename
end
return nil, nil, nil, nil
return nil, nil, nil
end
local function cleanup_model(model)
if CLEANUP_MODEL then
model:clearState() -- release GPU memory
end
end
local function convert(x, alpha, options)
local function convert(x, meta, options)
local cache_file = path.join(CACHE_DIR, options.prefix .. ".png")
local alpha_cache_file = path.join(CACHE_DIR, options.alpha_prefix .. ".png")
local alpha = meta.alpha
local alpha_orig = alpha
if path.exists(alpha_cache_file) then
@ -137,7 +138,7 @@ local function convert(x, alpha, options)
end
if path.exists(cache_file) then
x = image_loader.load_float(cache_file)
return x, alpha
return x, {alpha = alpha, gamma = meta.gamma, blob = meta.blob}
else
if options.style == "art" then
if options.border then
@ -192,7 +193,7 @@ local function convert(x, alpha, options)
end
image_loader.save_png(cache_file, x)
return x, alpha
return x, {alpha = alpha, gamma = meta.gamma, blob = meta.blob}
end
end
local function client_disconnected(handler)
@ -218,7 +219,7 @@ function APIHandler:post()
self:write("client disconnected")
return
end
local x, alpha, blob, filename = get_image(self)
local x, meta, filename = get_image(self)
local scale = tonumber(self:get_argument("scale", "0"))
local noise = tonumber(self:get_argument("noise", "0"))
local style = self:get_argument("style", "art")
@ -230,29 +231,29 @@ function APIHandler:post()
if x and valid_size(x, scale) then
local prefix = nil
if (noise ~= 0 or scale ~= 0) then
local hash = md5.sumhexa(blob)
local hash = md5.sumhexa(meta.blob)
local alpha_prefix = style .. "_" .. hash .. "_alpha"
local border = false
if scale ~= 0 and alpha then
if scale ~= 0 and meta.alpha then
border = true
end
if noise == 1 then
prefix = style .. "_noise1_"
x = convert(x, alpha, {method = "noise1", style = style,
prefix = prefix .. hash,
alpha_prefix = alpha_prefix, border = border})
x = convert(x, meta, {method = "noise1", style = style,
prefix = prefix .. hash,
alpha_prefix = alpha_prefix, border = border})
border = false
elseif noise == 2 then
prefix = style .. "_noise2_"
x = convert(x, alpha, {method = "noise2", style = style,
prefix = prefix .. hash,
alpha_prefix = alpha_prefix, border = border})
x = convert(x, meta, {method = "noise2", style = style,
prefix = prefix .. hash,
alpha_prefix = alpha_prefix, border = border})
border = false
elseif noise == 3 then
prefix = style .. "_noise3_"
x = convert(x, alpha, {method = "noise3", style = style,
prefix = prefix .. hash,
alpha_prefix = alpha_prefix, border = border})
x = convert(x, meta, {method = "noise3", style = style,
prefix = prefix .. hash,
alpha_prefix = alpha_prefix, border = border})
border = false
end
if scale == 1 or scale == 2 then
@ -265,7 +266,7 @@ function APIHandler:post()
else
prefix = style .. "_scale_"
end
x, alpha = convert(x, alpha, {method = "scale", style = style, prefix = prefix .. hash, alpha_prefix = alpha_prefix, border = border})
x, meta = convert(x, meta, {method = "scale", style = style, prefix = prefix .. hash, alpha_prefix = alpha_prefix, border = border})
if scale == 1 then
x = iproc.scale(x, x:size(3) * (1.6 / 2.0), x:size(2) * (1.6 / 2.0), "Sinc")
end
@ -281,7 +282,8 @@ function APIHandler:post()
else
name = uuid() .. ".png"
end
local blob = image_loader.encode_png(alpha_util.composite(x, alpha), 8, true)
local blob = image_loader.encode_png(alpha_util.composite(x, meta.alpha),
{ depth = 8, inplace = true, gamma = meta.gamma})
self:set_header("Content-Length", string.format("%d", #blob))
if download > 0 then