From fbad30c0311eaaac8caddd259d180bbc782ebc72 Mon Sep 17 00:00:00 2001 From: nagadomi Date: Fri, 15 Apr 2016 09:13:37 +0900 Subject: [PATCH] Fix embed gamma handling --- convert_data.lua | 4 ++-- lib/image_loader.lua | 53 ++++++++++++++++++++++---------------------- waifu2x.lua | 11 +++++---- web.lua | 48 ++++++++++++++++++++------------------- 4 files changed, 61 insertions(+), 55 deletions(-) diff --git a/convert_data.lua b/convert_data.lua index 6efd9a5..4cebfbd 100644 --- a/convert_data.lua +++ b/convert_data.lua @@ -33,8 +33,8 @@ local function load_images(list) local x = {} for i = 1, #lines do local line = lines[i] - local im, alpha = image_loader.load_byte(line) - if alpha then + local im, meta = image_loader.load_byte(line) + if meta and meta.alpha then io.stderr:write(string.format("\n%s: skip: image has alpha channel.\n", line)) else if settings.max_training_image_size > 0 then diff --git a/lib/image_loader.lua b/lib/image_loader.lua index 3b2be8f..e164539 100644 --- a/lib/image_loader.lua +++ b/lib/image_loader.lua @@ -9,14 +9,15 @@ local clip_eps8 = (1.0 / 255.0) * 0.5 - (1.0e-7 * (1.0 / 255.0) * 0.5) local clip_eps16 = (1.0 / 65535.0) * 0.5 - (1.0e-7 * (1.0 / 65535.0) * 0.5) local background_color = 0.5 -function image_loader.encode_png(rgb, depth, inplace) - if inplace == nil then - inplace = false +function image_loader.encode_png(rgb, options) + options = options or {} + options.depth = options.depth or 8 + if options.inplace == nil then + options.inplace = false end - depth = depth or 8 rgb = iproc.byte2float(rgb) - if depth < 16 then - if inplace then + if options.depth < 16 then + if options.inplace then rgb:add(clip_eps8) else rgb = rgb:clone():add(clip_eps8) @@ -25,7 +26,7 @@ function image_loader.encode_png(rgb, depth, inplace) rgb[torch.gt(rgb, 1.0)] = 1.0 rgb = rgb:mul(255):floor():div(255) else - if inplace then + if options.inplace then rgb:add(clip_eps16) else rgb = rgb:clone():add(clip_eps16) @@ -43,10 +44,13 @@ function 
image_loader.encode_png(rgb, depth, inplace) im = gm.Image(rgb, "I", "DHW") -- im:colorspace("GRAY") -- it does not work end - return im:depth(depth):format("PNG"):toString(9) + if options.gamma then + im:gamma(options.gamma) + end + return im:depth(options.depth):format("PNG"):toString(9) end -function image_loader.save_png(filename, rgb, depth, inplace) - local blob = image_loader.encode_png(rgb, depth, inplace) +function image_loader.save_png(filename, rgb, options) + local blob = image_loader.encode_png(rgb, options) local fp = io.open(filename, "wb") if not fp then error("IO error: " .. filename) @@ -57,8 +61,8 @@ function image_loader.save_png(filename, rgb, depth, inplace) end function image_loader.decode_float(blob) local load_image = function() + local meta = {} local im = gm.Image() - local alpha = nil local gamma_lcd = 0.454545 im:fromBlob(blob, #blob) @@ -66,12 +70,8 @@ function image_loader.decode_float(blob) if im:colorspace() == "CMYK" then im:colorspace("RGB") end - local gamma = math.floor(im:gamma() * 1000000) / 1000000 - if gamma ~= 0 and gamma ~= gamma_lcd then - local cg = gamma / gamma_lcd - im:gammaCorrection(cg, "Red") - im:gammaCorrection(cg, "Blue") - im:gammaCorrection(cg, "Green") + if im:gamma() ~= 0 and math.floor(im:gamma() * 1000000) / 1000000 ~= gamma_lcd then + meta.gamma = im:gamma() end -- FIXME: How to detect that a image has an alpha channel? 
if blob:sub(1, 4) == "\x89PNG" or blob:sub(1, 3) == "GIF" then @@ -79,9 +79,9 @@ function image_loader.decode_float(blob) im = im:toTensor('float', 'RGBA', 'DHW') local sum_alpha = (im[4] - 1.0):sum() if sum_alpha < 0 then - alpha = im[4]:reshape(1, im:size(2), im:size(3)) + meta.alpha = im[4]:reshape(1, im:size(2), im:size(3)) -- drop full transparent background - local mask = torch.le(alpha, 0.0) + local mask = torch.le(meta.alpha, 0.0) im[1][mask] = background_color im[2][mask] = background_color im[3][mask] = background_color @@ -94,25 +94,26 @@ function image_loader.decode_float(blob) else im = im:toTensor('float', 'RGB', 'DHW') end - return {im, alpha, blob} + meta.blob = blob + return {im, meta} end local state, ret = pcall(load_image) if state then - return ret[1], ret[2], ret[3] + return ret[1], ret[2] else - return nil, nil, nil + return nil, nil end end function image_loader.decode_byte(blob) - local im, alpha - im, alpha, blob = image_loader.decode_float(blob) + local im, meta + im, meta = image_loader.decode_float(blob) if im then im = iproc.float2byte(im) -- hmm, alpha does not convert here - return im, alpha, blob + return im, meta else - return nil, nil, nil + return nil, nil end end function image_loader.load_float(file) diff --git a/waifu2x.lua b/waifu2x.lua index f2a8a7f..1138ef7 100644 --- a/waifu2x.lua +++ b/waifu2x.lua @@ -11,7 +11,8 @@ local alpha_util = require 'alpha_util' torch.setdefaulttensortype('torch.FloatTensor') local function convert_image(opt) - local x, alpha = image_loader.load_float(opt.i) + local x, meta = image_loader.load_float(opt.i) + local alpha = meta.alpha local new_x = nil local t = sys.clock() local scale_f, image_f @@ -65,7 +66,7 @@ local function convert_image(opt) else error("undefined method:" .. opt.method) end - image_loader.save_png(opt.o, new_x, opt.depth, true) + image_loader.save_png(opt.o, new_x, {depth = opt.depth, inplace = true, gamma = meta.gamma}) print(opt.o .. ": " .. (sys.clock() - t) .. 
" sec") end local function convert_frames(opt) @@ -115,7 +116,8 @@ local function convert_frames(opt) fp:close() for i = 1, #lines do if opt.resume == 0 or path.exists(string.format(opt.o, i)) == false then - local x, alpha = image_loader.load_float(lines[i]) + local x, meta = image_loader.load_float(lines[i]) + local alpha = meta.alpha local new_x = nil if opt.m == "noise" then new_x = image_f(noise_model[opt.noise_level], x, opt.crop_size) @@ -141,7 +143,8 @@ local function convert_frames(opt) else output = string.format(opt.o, i) end - image_loader.save_png(output, new_x, opt.depth, true) + image_loader.save_png(output, new_x, + {depth = opt.depth, inplace = true, gamma = meta.gamma}) xlua.progress(i, #lines) if i % 10 == 0 then collectgarbage() diff --git a/web.lua b/web.lua index 559b20a..8d16803 100644 --- a/web.lua +++ b/web.lua @@ -93,7 +93,7 @@ local function cache_url(url) end end end - return nil, nil, nil + return nil, nil end local function get_image(req) local file_info = req:get_arguments("file") @@ -108,22 +108,23 @@ local function get_image(req) end end if file and file:len() > 0 then - local x, alpha, blob = image_loader.decode_float(file) - return x, alpha, blob, filename + local x, meta = image_loader.decode_float(file) + return x, meta, filename elseif url and url:len() > 0 then - local x, alpha, blob = cache_url(url) - return x, alpha, blob, filename + local x, meta = cache_url(url) + return x, meta, filename end - return nil, nil, nil, nil + return nil, nil, nil end local function cleanup_model(model) if CLEANUP_MODEL then model:clearState() -- release GPU memory end end -local function convert(x, alpha, options) +local function convert(x, meta, options) local cache_file = path.join(CACHE_DIR, options.prefix .. ".png") local alpha_cache_file = path.join(CACHE_DIR, options.alpha_prefix .. 
".png") + local alpha = meta.alpha local alpha_orig = alpha if path.exists(alpha_cache_file) then @@ -137,7 +138,7 @@ local function convert(x, alpha, options) end if path.exists(cache_file) then x = image_loader.load_float(cache_file) - return x, alpha + return x, {alpha = alpha, gamma = meta.gamma, blob = meta.blob} else if options.style == "art" then if options.border then @@ -192,7 +193,7 @@ local function convert(x, alpha, options) end image_loader.save_png(cache_file, x) - return x, alpha + return x, {alpha = alpha, gamma = meta.gamma, blob = meta.blob} end end local function client_disconnected(handler) @@ -218,7 +219,7 @@ function APIHandler:post() self:write("client disconnected") return end - local x, alpha, blob, filename = get_image(self) + local x, meta, filename = get_image(self) local scale = tonumber(self:get_argument("scale", "0")) local noise = tonumber(self:get_argument("noise", "0")) local style = self:get_argument("style", "art") @@ -230,29 +231,29 @@ function APIHandler:post() if x and valid_size(x, scale) then local prefix = nil if (noise ~= 0 or scale ~= 0) then - local hash = md5.sumhexa(blob) + local hash = md5.sumhexa(meta.blob) local alpha_prefix = style .. "_" .. hash .. "_alpha" local border = false - if scale ~= 0 and alpha then + if scale ~= 0 and meta.alpha then border = true end if noise == 1 then prefix = style .. "_noise1_" - x = convert(x, alpha, {method = "noise1", style = style, - prefix = prefix .. hash, - alpha_prefix = alpha_prefix, border = border}) + x = convert(x, meta, {method = "noise1", style = style, + prefix = prefix .. hash, + alpha_prefix = alpha_prefix, border = border}) border = false elseif noise == 2 then prefix = style .. "_noise2_" - x = convert(x, alpha, {method = "noise2", style = style, - prefix = prefix .. hash, - alpha_prefix = alpha_prefix, border = border}) + x = convert(x, meta, {method = "noise2", style = style, + prefix = prefix .. 
hash, + alpha_prefix = alpha_prefix, border = border}) border = false elseif noise == 3 then prefix = style .. "_noise3_" - x = convert(x, alpha, {method = "noise3", style = style, - prefix = prefix .. hash, - alpha_prefix = alpha_prefix, border = border}) + x = convert(x, meta, {method = "noise3", style = style, + prefix = prefix .. hash, + alpha_prefix = alpha_prefix, border = border}) border = false end if scale == 1 or scale == 2 then @@ -265,7 +266,7 @@ function APIHandler:post() else prefix = style .. "_scale_" end - x, alpha = convert(x, alpha, {method = "scale", style = style, prefix = prefix .. hash, alpha_prefix = alpha_prefix, border = border}) + x, meta = convert(x, meta, {method = "scale", style = style, prefix = prefix .. hash, alpha_prefix = alpha_prefix, border = border}) if scale == 1 then x = iproc.scale(x, x:size(3) * (1.6 / 2.0), x:size(2) * (1.6 / 2.0), "Sinc") end @@ -281,7 +282,8 @@ function APIHandler:post() else name = uuid() .. ".png" end - local blob = image_loader.encode_png(alpha_util.composite(x, alpha), 8, true) + local blob = image_loader.encode_png(alpha_util.composite(x, meta.alpha), + { depth = 8, inplace = true, gamma = meta.gamma}) self:set_header("Content-Length", string.format("%d", #blob)) if download > 0 then