Fix embed gamma handling
This commit is contained in:
parent
cafc579627
commit
fbad30c031
|
@ -33,8 +33,8 @@ local function load_images(list)
|
||||||
local x = {}
|
local x = {}
|
||||||
for i = 1, #lines do
|
for i = 1, #lines do
|
||||||
local line = lines[i]
|
local line = lines[i]
|
||||||
local im, alpha = image_loader.load_byte(line)
|
local im, meta = image_loader.load_byte(line)
|
||||||
if alpha then
|
if meta and meta.alpha then
|
||||||
io.stderr:write(string.format("\n%s: skip: image has alpha channel.\n", line))
|
io.stderr:write(string.format("\n%s: skip: image has alpha channel.\n", line))
|
||||||
else
|
else
|
||||||
if settings.max_training_image_size > 0 then
|
if settings.max_training_image_size > 0 then
|
||||||
|
|
|
@ -9,14 +9,15 @@ local clip_eps8 = (1.0 / 255.0) * 0.5 - (1.0e-7 * (1.0 / 255.0) * 0.5)
|
||||||
local clip_eps16 = (1.0 / 65535.0) * 0.5 - (1.0e-7 * (1.0 / 65535.0) * 0.5)
|
local clip_eps16 = (1.0 / 65535.0) * 0.5 - (1.0e-7 * (1.0 / 65535.0) * 0.5)
|
||||||
local background_color = 0.5
|
local background_color = 0.5
|
||||||
|
|
||||||
function image_loader.encode_png(rgb, depth, inplace)
|
function image_loader.encode_png(rgb, options)
|
||||||
if inplace == nil then
|
options = options or {}
|
||||||
inplace = false
|
options.depth = options.depth or 8
|
||||||
|
if options.inplace == nil then
|
||||||
|
options.inplace = false
|
||||||
end
|
end
|
||||||
depth = depth or 8
|
|
||||||
rgb = iproc.byte2float(rgb)
|
rgb = iproc.byte2float(rgb)
|
||||||
if depth < 16 then
|
if options.depth < 16 then
|
||||||
if inplace then
|
if options.inplace then
|
||||||
rgb:add(clip_eps8)
|
rgb:add(clip_eps8)
|
||||||
else
|
else
|
||||||
rgb = rgb:clone():add(clip_eps8)
|
rgb = rgb:clone():add(clip_eps8)
|
||||||
|
@ -25,7 +26,7 @@ function image_loader.encode_png(rgb, depth, inplace)
|
||||||
rgb[torch.gt(rgb, 1.0)] = 1.0
|
rgb[torch.gt(rgb, 1.0)] = 1.0
|
||||||
rgb = rgb:mul(255):floor():div(255)
|
rgb = rgb:mul(255):floor():div(255)
|
||||||
else
|
else
|
||||||
if inplace then
|
if options.inplace then
|
||||||
rgb:add(clip_eps16)
|
rgb:add(clip_eps16)
|
||||||
else
|
else
|
||||||
rgb = rgb:clone():add(clip_eps16)
|
rgb = rgb:clone():add(clip_eps16)
|
||||||
|
@ -43,10 +44,13 @@ function image_loader.encode_png(rgb, depth, inplace)
|
||||||
im = gm.Image(rgb, "I", "DHW")
|
im = gm.Image(rgb, "I", "DHW")
|
||||||
-- im:colorspace("GRAY") -- it does not work
|
-- im:colorspace("GRAY") -- it does not work
|
||||||
end
|
end
|
||||||
return im:depth(depth):format("PNG"):toString(9)
|
if options.gamma then
|
||||||
|
im:gamma(options.gamma)
|
||||||
|
end
|
||||||
|
return im:depth(options.depth):format("PNG"):toString(9)
|
||||||
end
|
end
|
||||||
function image_loader.save_png(filename, rgb, depth, inplace)
|
function image_loader.save_png(filename, rgb, options)
|
||||||
local blob = image_loader.encode_png(rgb, depth, inplace)
|
local blob = image_loader.encode_png(rgb, options)
|
||||||
local fp = io.open(filename, "wb")
|
local fp = io.open(filename, "wb")
|
||||||
if not fp then
|
if not fp then
|
||||||
error("IO error: " .. filename)
|
error("IO error: " .. filename)
|
||||||
|
@ -57,8 +61,8 @@ function image_loader.save_png(filename, rgb, depth, inplace)
|
||||||
end
|
end
|
||||||
function image_loader.decode_float(blob)
|
function image_loader.decode_float(blob)
|
||||||
local load_image = function()
|
local load_image = function()
|
||||||
|
local meta = {}
|
||||||
local im = gm.Image()
|
local im = gm.Image()
|
||||||
local alpha = nil
|
|
||||||
local gamma_lcd = 0.454545
|
local gamma_lcd = 0.454545
|
||||||
|
|
||||||
im:fromBlob(blob, #blob)
|
im:fromBlob(blob, #blob)
|
||||||
|
@ -66,12 +70,8 @@ function image_loader.decode_float(blob)
|
||||||
if im:colorspace() == "CMYK" then
|
if im:colorspace() == "CMYK" then
|
||||||
im:colorspace("RGB")
|
im:colorspace("RGB")
|
||||||
end
|
end
|
||||||
local gamma = math.floor(im:gamma() * 1000000) / 1000000
|
if gamma ~= 0 and math.floor(im:gamma() * 1000000) / 1000000 ~= gamma_lcd then
|
||||||
if gamma ~= 0 and gamma ~= gamma_lcd then
|
meta.gamma = im:gamma()
|
||||||
local cg = gamma / gamma_lcd
|
|
||||||
im:gammaCorrection(cg, "Red")
|
|
||||||
im:gammaCorrection(cg, "Blue")
|
|
||||||
im:gammaCorrection(cg, "Green")
|
|
||||||
end
|
end
|
||||||
-- FIXME: How to detect that a image has an alpha channel?
|
-- FIXME: How to detect that a image has an alpha channel?
|
||||||
if blob:sub(1, 4) == "\x89PNG" or blob:sub(1, 3) == "GIF" then
|
if blob:sub(1, 4) == "\x89PNG" or blob:sub(1, 3) == "GIF" then
|
||||||
|
@ -79,9 +79,9 @@ function image_loader.decode_float(blob)
|
||||||
im = im:toTensor('float', 'RGBA', 'DHW')
|
im = im:toTensor('float', 'RGBA', 'DHW')
|
||||||
local sum_alpha = (im[4] - 1.0):sum()
|
local sum_alpha = (im[4] - 1.0):sum()
|
||||||
if sum_alpha < 0 then
|
if sum_alpha < 0 then
|
||||||
alpha = im[4]:reshape(1, im:size(2), im:size(3))
|
meta.alpha = im[4]:reshape(1, im:size(2), im:size(3))
|
||||||
-- drop full transparent background
|
-- drop full transparent background
|
||||||
local mask = torch.le(alpha, 0.0)
|
local mask = torch.le(meta.alpha, 0.0)
|
||||||
im[1][mask] = background_color
|
im[1][mask] = background_color
|
||||||
im[2][mask] = background_color
|
im[2][mask] = background_color
|
||||||
im[3][mask] = background_color
|
im[3][mask] = background_color
|
||||||
|
@ -94,25 +94,26 @@ function image_loader.decode_float(blob)
|
||||||
else
|
else
|
||||||
im = im:toTensor('float', 'RGB', 'DHW')
|
im = im:toTensor('float', 'RGB', 'DHW')
|
||||||
end
|
end
|
||||||
return {im, alpha, blob}
|
meta.blob = blob
|
||||||
|
return {im, meta}
|
||||||
end
|
end
|
||||||
local state, ret = pcall(load_image)
|
local state, ret = pcall(load_image)
|
||||||
if state then
|
if state then
|
||||||
return ret[1], ret[2], ret[3]
|
return ret[1], ret[2]
|
||||||
else
|
else
|
||||||
return nil, nil, nil
|
return nil, nil
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
function image_loader.decode_byte(blob)
|
function image_loader.decode_byte(blob)
|
||||||
local im, alpha
|
local im, meta
|
||||||
im, alpha, blob = image_loader.decode_float(blob)
|
im, meta = image_loader.decode_float(blob)
|
||||||
|
|
||||||
if im then
|
if im then
|
||||||
im = iproc.float2byte(im)
|
im = iproc.float2byte(im)
|
||||||
-- hmm, alpha does not convert here
|
-- hmm, alpha does not convert here
|
||||||
return im, alpha, blob
|
return im, meta
|
||||||
else
|
else
|
||||||
return nil, nil, nil
|
return nil, nil
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
function image_loader.load_float(file)
|
function image_loader.load_float(file)
|
||||||
|
|
11
waifu2x.lua
11
waifu2x.lua
|
@ -11,7 +11,8 @@ local alpha_util = require 'alpha_util'
|
||||||
torch.setdefaulttensortype('torch.FloatTensor')
|
torch.setdefaulttensortype('torch.FloatTensor')
|
||||||
|
|
||||||
local function convert_image(opt)
|
local function convert_image(opt)
|
||||||
local x, alpha = image_loader.load_float(opt.i)
|
local x, meta = image_loader.load_float(opt.i)
|
||||||
|
local alpha = meta.alpha
|
||||||
local new_x = nil
|
local new_x = nil
|
||||||
local t = sys.clock()
|
local t = sys.clock()
|
||||||
local scale_f, image_f
|
local scale_f, image_f
|
||||||
|
@ -65,7 +66,7 @@ local function convert_image(opt)
|
||||||
else
|
else
|
||||||
error("undefined method:" .. opt.method)
|
error("undefined method:" .. opt.method)
|
||||||
end
|
end
|
||||||
image_loader.save_png(opt.o, new_x, opt.depth, true)
|
image_loader.save_png(opt.o, new_x, {depth = opt.depth, inplace = true, gamma = meta.gamma})
|
||||||
print(opt.o .. ": " .. (sys.clock() - t) .. " sec")
|
print(opt.o .. ": " .. (sys.clock() - t) .. " sec")
|
||||||
end
|
end
|
||||||
local function convert_frames(opt)
|
local function convert_frames(opt)
|
||||||
|
@ -115,7 +116,8 @@ local function convert_frames(opt)
|
||||||
fp:close()
|
fp:close()
|
||||||
for i = 1, #lines do
|
for i = 1, #lines do
|
||||||
if opt.resume == 0 or path.exists(string.format(opt.o, i)) == false then
|
if opt.resume == 0 or path.exists(string.format(opt.o, i)) == false then
|
||||||
local x, alpha = image_loader.load_float(lines[i])
|
local x, meta = image_loader.load_float(lines[i])
|
||||||
|
local alpha = meta.alpha
|
||||||
local new_x = nil
|
local new_x = nil
|
||||||
if opt.m == "noise" then
|
if opt.m == "noise" then
|
||||||
new_x = image_f(noise_model[opt.noise_level], x, opt.crop_size)
|
new_x = image_f(noise_model[opt.noise_level], x, opt.crop_size)
|
||||||
|
@ -141,7 +143,8 @@ local function convert_frames(opt)
|
||||||
else
|
else
|
||||||
output = string.format(opt.o, i)
|
output = string.format(opt.o, i)
|
||||||
end
|
end
|
||||||
image_loader.save_png(output, new_x, opt.depth, true)
|
image_loader.save_png(output, new_x,
|
||||||
|
{depth = opt.depth, inplace = true, gamma = meta.gamma})
|
||||||
xlua.progress(i, #lines)
|
xlua.progress(i, #lines)
|
||||||
if i % 10 == 0 then
|
if i % 10 == 0 then
|
||||||
collectgarbage()
|
collectgarbage()
|
||||||
|
|
48
web.lua
48
web.lua
|
@ -93,7 +93,7 @@ local function cache_url(url)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
return nil, nil, nil
|
return nil, nil
|
||||||
end
|
end
|
||||||
local function get_image(req)
|
local function get_image(req)
|
||||||
local file_info = req:get_arguments("file")
|
local file_info = req:get_arguments("file")
|
||||||
|
@ -108,22 +108,23 @@ local function get_image(req)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
if file and file:len() > 0 then
|
if file and file:len() > 0 then
|
||||||
local x, alpha, blob = image_loader.decode_float(file)
|
local x, meta = image_loader.decode_float(file)
|
||||||
return x, alpha, blob, filename
|
return x, meta, filename
|
||||||
elseif url and url:len() > 0 then
|
elseif url and url:len() > 0 then
|
||||||
local x, alpha, blob = cache_url(url)
|
local x, meta = cache_url(url)
|
||||||
return x, alpha, blob, filename
|
return x, meta, filename
|
||||||
end
|
end
|
||||||
return nil, nil, nil, nil
|
return nil, nil, nil
|
||||||
end
|
end
|
||||||
local function cleanup_model(model)
|
local function cleanup_model(model)
|
||||||
if CLEANUP_MODEL then
|
if CLEANUP_MODEL then
|
||||||
model:clearState() -- release GPU memory
|
model:clearState() -- release GPU memory
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
local function convert(x, alpha, options)
|
local function convert(x, meta, options)
|
||||||
local cache_file = path.join(CACHE_DIR, options.prefix .. ".png")
|
local cache_file = path.join(CACHE_DIR, options.prefix .. ".png")
|
||||||
local alpha_cache_file = path.join(CACHE_DIR, options.alpha_prefix .. ".png")
|
local alpha_cache_file = path.join(CACHE_DIR, options.alpha_prefix .. ".png")
|
||||||
|
local alpha = meta.alpha
|
||||||
local alpha_orig = alpha
|
local alpha_orig = alpha
|
||||||
|
|
||||||
if path.exists(alpha_cache_file) then
|
if path.exists(alpha_cache_file) then
|
||||||
|
@ -137,7 +138,7 @@ local function convert(x, alpha, options)
|
||||||
end
|
end
|
||||||
if path.exists(cache_file) then
|
if path.exists(cache_file) then
|
||||||
x = image_loader.load_float(cache_file)
|
x = image_loader.load_float(cache_file)
|
||||||
return x, alpha
|
return x, {alpha = alpha, gamma = meta.gamma, blob = meta.blob}
|
||||||
else
|
else
|
||||||
if options.style == "art" then
|
if options.style == "art" then
|
||||||
if options.border then
|
if options.border then
|
||||||
|
@ -192,7 +193,7 @@ local function convert(x, alpha, options)
|
||||||
end
|
end
|
||||||
image_loader.save_png(cache_file, x)
|
image_loader.save_png(cache_file, x)
|
||||||
|
|
||||||
return x, alpha
|
return x, {alpha = alpha, gamma = meta.gamma, blob = meta.blob}
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
local function client_disconnected(handler)
|
local function client_disconnected(handler)
|
||||||
|
@ -218,7 +219,7 @@ function APIHandler:post()
|
||||||
self:write("client disconnected")
|
self:write("client disconnected")
|
||||||
return
|
return
|
||||||
end
|
end
|
||||||
local x, alpha, blob, filename = get_image(self)
|
local x, meta, filename = get_image(self)
|
||||||
local scale = tonumber(self:get_argument("scale", "0"))
|
local scale = tonumber(self:get_argument("scale", "0"))
|
||||||
local noise = tonumber(self:get_argument("noise", "0"))
|
local noise = tonumber(self:get_argument("noise", "0"))
|
||||||
local style = self:get_argument("style", "art")
|
local style = self:get_argument("style", "art")
|
||||||
|
@ -230,29 +231,29 @@ function APIHandler:post()
|
||||||
if x and valid_size(x, scale) then
|
if x and valid_size(x, scale) then
|
||||||
local prefix = nil
|
local prefix = nil
|
||||||
if (noise ~= 0 or scale ~= 0) then
|
if (noise ~= 0 or scale ~= 0) then
|
||||||
local hash = md5.sumhexa(blob)
|
local hash = md5.sumhexa(meta.blob)
|
||||||
local alpha_prefix = style .. "_" .. hash .. "_alpha"
|
local alpha_prefix = style .. "_" .. hash .. "_alpha"
|
||||||
local border = false
|
local border = false
|
||||||
if scale ~= 0 and alpha then
|
if scale ~= 0 and meta.alpha then
|
||||||
border = true
|
border = true
|
||||||
end
|
end
|
||||||
if noise == 1 then
|
if noise == 1 then
|
||||||
prefix = style .. "_noise1_"
|
prefix = style .. "_noise1_"
|
||||||
x = convert(x, alpha, {method = "noise1", style = style,
|
x = convert(x, meta, {method = "noise1", style = style,
|
||||||
prefix = prefix .. hash,
|
prefix = prefix .. hash,
|
||||||
alpha_prefix = alpha_prefix, border = border})
|
alpha_prefix = alpha_prefix, border = border})
|
||||||
border = false
|
border = false
|
||||||
elseif noise == 2 then
|
elseif noise == 2 then
|
||||||
prefix = style .. "_noise2_"
|
prefix = style .. "_noise2_"
|
||||||
x = convert(x, alpha, {method = "noise2", style = style,
|
x = convert(x, meta, {method = "noise2", style = style,
|
||||||
prefix = prefix .. hash,
|
prefix = prefix .. hash,
|
||||||
alpha_prefix = alpha_prefix, border = border})
|
alpha_prefix = alpha_prefix, border = border})
|
||||||
border = false
|
border = false
|
||||||
elseif noise == 3 then
|
elseif noise == 3 then
|
||||||
prefix = style .. "_noise3_"
|
prefix = style .. "_noise3_"
|
||||||
x = convert(x, alpha, {method = "noise3", style = style,
|
x = convert(x, meta, {method = "noise3", style = style,
|
||||||
prefix = prefix .. hash,
|
prefix = prefix .. hash,
|
||||||
alpha_prefix = alpha_prefix, border = border})
|
alpha_prefix = alpha_prefix, border = border})
|
||||||
border = false
|
border = false
|
||||||
end
|
end
|
||||||
if scale == 1 or scale == 2 then
|
if scale == 1 or scale == 2 then
|
||||||
|
@ -265,7 +266,7 @@ function APIHandler:post()
|
||||||
else
|
else
|
||||||
prefix = style .. "_scale_"
|
prefix = style .. "_scale_"
|
||||||
end
|
end
|
||||||
x, alpha = convert(x, alpha, {method = "scale", style = style, prefix = prefix .. hash, alpha_prefix = alpha_prefix, border = border})
|
x, meta = convert(x, meta, {method = "scale", style = style, prefix = prefix .. hash, alpha_prefix = alpha_prefix, border = border})
|
||||||
if scale == 1 then
|
if scale == 1 then
|
||||||
x = iproc.scale(x, x:size(3) * (1.6 / 2.0), x:size(2) * (1.6 / 2.0), "Sinc")
|
x = iproc.scale(x, x:size(3) * (1.6 / 2.0), x:size(2) * (1.6 / 2.0), "Sinc")
|
||||||
end
|
end
|
||||||
|
@ -281,7 +282,8 @@ function APIHandler:post()
|
||||||
else
|
else
|
||||||
name = uuid() .. ".png"
|
name = uuid() .. ".png"
|
||||||
end
|
end
|
||||||
local blob = image_loader.encode_png(alpha_util.composite(x, alpha), 8, true)
|
local blob = image_loader.encode_png(alpha_util.composite(x, meta.alpha),
|
||||||
|
{ depth = 8, inplace = true, gamma = meta.gamma})
|
||||||
|
|
||||||
self:set_header("Content-Length", string.format("%d", #blob))
|
self:set_header("Content-Length", string.format("%d", #blob))
|
||||||
if download > 0 then
|
if download > 0 then
|
||||||
|
|
Loading…
Reference in a new issue