From bbc3cac273c76e5f9423b9d98d940cc44f90f2ea Mon Sep 17 00:00:00 2001 From: nagadomi Date: Tue, 16 Jun 2015 20:41:48 +0900 Subject: [PATCH] add support for transparent png (retain the alpha channel) --- lib/image_loader.lua | 58 +++++++++++++++++++++++++--------- waifu2x.lua | 75 ++++++++++++++++++++++---------------------- web.lua | 12 +++---- 3 files changed, 86 insertions(+), 59 deletions(-) diff --git a/lib/image_loader.lua b/lib/image_loader.lua index 0b1009d..bdca38a 100644 --- a/lib/image_loader.lua +++ b/lib/image_loader.lua @@ -1,44 +1,72 @@ local gm = require 'graphicsmagick' +local ffi = require 'ffi' require 'pl' local image_loader = {} function image_loader.decode_float(blob) - local im = image_loader.decode_byte(blob) + local im, alpha = image_loader.decode_byte(blob) if im then im = im:float():div(255) end - return im + return im, alpha end -function image_loader.encode_png(tensor) - local im = gm.Image(tensor, "RGB", "DHW") - im:format("png") - return im:toBlob() +function image_loader.encode_png(rgb, alpha) + if rgb:type() == "torch.ByteTensor" then + error("expect FloatTensor") + end + if alpha then + if not (alpha:size(2) == rgb:size(2) and alpha:size(3) == rgb:size(3)) then + alpha = gm.Image(alpha, "I", "DHW"):size(rgb:size(3), rgb:size(2), "Sinc"):toTensor("float", "I", "DHW") + end + local rgba = torch.Tensor(4, rgb:size(2), rgb:size(3)) + rgba[1]:copy(rgb[1]) + rgba[2]:copy(rgb[2]) + rgba[3]:copy(rgb[3]) + rgba[4]:copy(alpha) + local im = gm.Image():fromTensor(rgba, "RGBA", "DHW") + im:format("png") + return im:toBlob() + else + local im = gm.Image(rgb, "RGB", "DHW") + im:format("png") + return im:toBlob() + end +end +function image_loader.save_png(filename, rgb, alpha) + local blob, len = image_loader.encode_png(rgb, alpha) + local fp = io.open(filename, "wb") + fp:write(ffi.string(blob, len)) + fp:close() + return true end function image_loader.decode_byte(blob) local load_image = function() local im = gm.Image() + local alpha = nil + im:fromBlob(blob, #blob) -- FIXME: How to detect that a image has an alpha channel? if blob:sub(1, 4) == "\x89PNG" or blob:sub(1, 3) == "GIF" then - -- merge alpha channel + -- split alpha channel im = im:toTensor('float', 'RGBA', 'DHW') - local w2 = im[4] - local w1 = im[4] * -1 + 1 + local sum_alpha = (im[4] - 1):sum() + if sum_alpha > 0 or sum_alpha < 0 then + alpha = im[4]:reshape(1, im:size(2), im:size(3)) + end local new_im = torch.FloatTensor(3, im:size(2), im:size(3)) - -- apply the white background - new_im[1]:copy(im[1]):cmul(w2):add(w1) - new_im[2]:copy(im[2]):cmul(w2):add(w1) - new_im[3]:copy(im[3]):cmul(w2):add(w1) + new_im[1]:copy(im[1]) + new_im[2]:copy(im[2]) + new_im[3]:copy(im[3]) im = new_im:mul(255):byte() else im = im:toTensor('byte', 'RGB', 'DHW') end - return im + return {im, alpha} end local state, ret = pcall(load_image) if state then - return ret + return ret[1], ret[2] else return nil end diff --git a/waifu2x.lua b/waifu2x.lua index da04feb..dde08ad 100644 --- a/waifu2x.lua +++ b/waifu2x.lua @@ -6,13 +6,12 @@ require './lib/LeakyReLU' local iproc = require './lib/iproc' local reconstruct = require './lib/reconstruct' local image_loader = require './lib/image_loader' - local BLOCK_OFFSET = 7 torch.setdefaulttensortype('torch.FloatTensor') local function convert_image(opt) - local x = image_loader.load_float(opt.i) + local x, alpha = image_loader.load_float(opt.i) local new_x = nil local t = sys.clock() if opt.o == "(auto)" then @@ -39,7 +38,7 @@ local function convert_image(opt) else error("undefined method:" .. opt.method) end - image.save(opt.o, new_x) + image_loader.save_png(opt.o, new_x, alpha) print(opt.o .. ": " .. (sys.clock() - t) .. " sec") end local function convert_frames(opt) @@ -59,41 +58,41 @@ local function convert_frames(opt) end fp:close() for i = 1, #lines do - if opt.resume == 0 or path.exists(string.format(opt.o, i)) == false then - local x = image_loader.load_float(lines[i]) - local new_x = nil - if opt.m == "noise" and opt.noise_level == 1 then - new_x = reconstruct.image(noise1_model, x, BLOCK_OFFSET, opt.crop_size) - elseif opt.m == "noise" and opt.noise_level == 2 then - new_x = reconstruct.image(noise2_model, x, BLOCK_OFFSET) - elseif opt.m == "scale" then - new_x = reconstruct.scale(scale_model, opt.scale, x, BLOCK_OFFSET, opt.crop_size) - elseif opt.m == "noise_scale" and opt.noise_level == 1 then - x = reconstruct.image(noise1_model, x, BLOCK_OFFSET) - new_x = reconstruct.scale(scale_model, opt.scale, x, BLOCK_OFFSET, opt.crop_size) - elseif opt.m == "noise_scale" and opt.noise_level == 2 then - x = reconstruct.image(noise2_model, x, BLOCK_OFFSET) - new_x = reconstruct.scale(scale_model, opt.scale, x, BLOCK_OFFSET, opt.crop_size) - else - error("undefined method:" .. opt.method) - end - local output = nil - if opt.o == "(auto)" then - local name = path.basename(lines[i]) - local e = path.extension(name) - local base = name:sub(0, name:len() - e:len()) - output = path.join(path.dirname(opt.i), string.format("%s(%s).png", base, opt.m)) - else - output = string.format(opt.o, i) - end - image.save(output, new_x) - xlua.progress(i, #lines) - if i % 10 == 0 then - collectgarbage() - end - else - xlua.progress(i, #lines) - end + if opt.resume == 0 or path.exists(string.format(opt.o, i)) == false then + local x, alpha = image_loader.load_float(lines[i]) + local new_x = nil + if opt.m == "noise" and opt.noise_level == 1 then + new_x = reconstruct.image(noise1_model, x, BLOCK_OFFSET, opt.crop_size) + elseif opt.m == "noise" and opt.noise_level == 2 then + new_x = reconstruct.image(noise2_model, x, BLOCK_OFFSET) + elseif opt.m == "scale" then + new_x = reconstruct.scale(scale_model, opt.scale, x, BLOCK_OFFSET, opt.crop_size) + elseif opt.m == "noise_scale" and opt.noise_level == 1 then + x = reconstruct.image(noise1_model, x, BLOCK_OFFSET) + new_x = reconstruct.scale(scale_model, opt.scale, x, BLOCK_OFFSET, opt.crop_size) + elseif opt.m == "noise_scale" and opt.noise_level == 2 then + x = reconstruct.image(noise2_model, x, BLOCK_OFFSET) + new_x = reconstruct.scale(scale_model, opt.scale, x, BLOCK_OFFSET, opt.crop_size) + else + error("undefined method:" .. opt.method) + end + local output = nil + if opt.o == "(auto)" then + local name = path.basename(lines[i]) + local e = path.extension(name) + local base = name:sub(0, name:len() - e:len()) + output = path.join(path.dirname(opt.i), string.format("%s(%s).png", base, opt.m)) + else + output = string.format(opt.o, i) + end + image_loader.save_png(output, new_x, alpha) + xlua.progress(i, #lines) + if i % 10 == 0 then + collectgarbage() + end + else + xlua.progress(i, #lines) + end end end diff --git a/web.lua b/web.lua index 4077935..51ce9b2 100644 --- a/web.lua +++ b/web.lua @@ -47,10 +47,10 @@ local function get_image(req) local url = req:get_argument("url", "") local blob = nil local img = nil - + local alpha = nil if file and file:len() > 0 then blob = file - img = image_loader.decode_float(blob) + img, alpha = image_loader.decode_float(blob) elseif url and url:len() > 0 then local res = coroutine.yield( turbo.async.HTTPClient({verify_ca=false}, @@ -64,11 +64,11 @@ local function get_image(req) end if content_type and content_type:find("image") then blob = res.body - img = image_loader.decode_float(blob) + img, alpha = image_loader.decode_float(blob) end end end - return img, blob + return img, blob, alpha end local function apply_denoise1(x) @@ -104,7 +104,7 @@ function APIHandler:post() self:write("client disconnected") return end - local x, src = get_image(self) + local x, src, alpha = get_image(self) local scale = tonumber(self:get_argument("scale", "0")) local noise = tonumber(self:get_argument("noise", "0")) if x and valid_size(x, scale) then @@ -151,7 +151,7 @@ function APIHandler:post() end end local name = uuid() .. ".png" - local blob, len = image_loader.encode_png(x) + local blob, len = image_loader.encode_png(x, alpha) self:set_header("Content-Disposition", string.format('filename="%s"', name)) self:set_header("Content-Type", "image/png")