1
0
Fork 0
mirror of synced 2024-06-28 19:20:32 +12:00

add support for transparent png (retain the alpha channel)

This commit is contained in:
nagadomi 2015-06-16 20:41:48 +09:00
parent f8112367e5
commit bbc3cac273
3 changed files with 86 additions and 59 deletions

View file

@ -1,44 +1,72 @@
local gm = require 'graphicsmagick' local gm = require 'graphicsmagick'
local ffi = require 'ffi'
require 'pl' require 'pl'
local image_loader = {} local image_loader = {}
function image_loader.decode_float(blob) function image_loader.decode_float(blob)
local im = image_loader.decode_byte(blob) local im, alpha = image_loader.decode_byte(blob)
if im then if im then
im = im:float():div(255) im = im:float():div(255)
end end
return im return im, alpha
end end
function image_loader.encode_png(tensor) function image_loader.encode_png(rgb, alpha)
local im = gm.Image(tensor, "RGB", "DHW") if rgb:type() == "torch.ByteTensor" then
im:format("png") error("expect FloatTensor")
return im:toBlob() end
if alpha then
if not (alpha:size(2) == rgb:size(2) and alpha:size(3) == rgb:size(3)) then
alpha = gm.Image(alpha, "I", "DHW"):size(rgb:size(3), rgb:size(2), "Sinc"):toTensor("float", "I", "DHW")
end
local rgba = torch.Tensor(4, rgb:size(2), rgb:size(3))
rgba[1]:copy(rgb[1])
rgba[2]:copy(rgb[2])
rgba[3]:copy(rgb[3])
rgba[4]:copy(alpha)
local im = gm.Image():fromTensor(rgba, "RGBA", "DHW")
im:format("png")
return im:toBlob()
else
local im = gm.Image(rgb, "RGB", "DHW")
im:format("png")
return im:toBlob()
end
end
function image_loader.save_png(filename, rgb, alpha)
local blob, len = image_loader.encode_png(rgb, alpha)
local fp = io.open(filename, "wb")
fp:write(ffi.string(blob, len))
fp:close()
return true
end end
function image_loader.decode_byte(blob) function image_loader.decode_byte(blob)
local load_image = function() local load_image = function()
local im = gm.Image() local im = gm.Image()
local alpha = nil
im:fromBlob(blob, #blob) im:fromBlob(blob, #blob)
-- FIXME: How to detect that a image has an alpha channel? -- FIXME: How to detect that a image has an alpha channel?
if blob:sub(1, 4) == "\x89PNG" or blob:sub(1, 3) == "GIF" then if blob:sub(1, 4) == "\x89PNG" or blob:sub(1, 3) == "GIF" then
-- merge alpha channel -- split alpha channel
im = im:toTensor('float', 'RGBA', 'DHW') im = im:toTensor('float', 'RGBA', 'DHW')
local w2 = im[4] local sum_alpha = (im[4] - 1):sum()
local w1 = im[4] * -1 + 1 if sum_alpha > 0 or sum_alpha < 0 then
alpha = im[4]:reshape(1, im:size(2), im:size(3))
end
local new_im = torch.FloatTensor(3, im:size(2), im:size(3)) local new_im = torch.FloatTensor(3, im:size(2), im:size(3))
-- apply the white background new_im[1]:copy(im[1])
new_im[1]:copy(im[1]):cmul(w2):add(w1) new_im[2]:copy(im[2])
new_im[2]:copy(im[2]):cmul(w2):add(w1) new_im[3]:copy(im[3])
new_im[3]:copy(im[3]):cmul(w2):add(w1)
im = new_im:mul(255):byte() im = new_im:mul(255):byte()
else else
im = im:toTensor('byte', 'RGB', 'DHW') im = im:toTensor('byte', 'RGB', 'DHW')
end end
return im return {im, alpha}
end end
local state, ret = pcall(load_image) local state, ret = pcall(load_image)
if state then if state then
return ret return ret[1], ret[2]
else else
return nil return nil
end end

View file

@ -6,13 +6,12 @@ require './lib/LeakyReLU'
local iproc = require './lib/iproc' local iproc = require './lib/iproc'
local reconstruct = require './lib/reconstruct' local reconstruct = require './lib/reconstruct'
local image_loader = require './lib/image_loader' local image_loader = require './lib/image_loader'
local BLOCK_OFFSET = 7 local BLOCK_OFFSET = 7
torch.setdefaulttensortype('torch.FloatTensor') torch.setdefaulttensortype('torch.FloatTensor')
local function convert_image(opt) local function convert_image(opt)
local x = image_loader.load_float(opt.i) local x, alpha = image_loader.load_float(opt.i)
local new_x = nil local new_x = nil
local t = sys.clock() local t = sys.clock()
if opt.o == "(auto)" then if opt.o == "(auto)" then
@ -39,7 +38,7 @@ local function convert_image(opt)
else else
error("undefined method:" .. opt.method) error("undefined method:" .. opt.method)
end end
image.save(opt.o, new_x) image_loader.save_png(opt.o, new_x, alpha)
print(opt.o .. ": " .. (sys.clock() - t) .. " sec") print(opt.o .. ": " .. (sys.clock() - t) .. " sec")
end end
local function convert_frames(opt) local function convert_frames(opt)
@ -59,41 +58,41 @@ local function convert_frames(opt)
end end
fp:close() fp:close()
for i = 1, #lines do for i = 1, #lines do
if opt.resume == 0 or path.exists(string.format(opt.o, i)) == false then if opt.resume == 0 or path.exists(string.format(opt.o, i)) == false then
local x = image_loader.load_float(lines[i]) local x, alpha = image_loader.load_float(lines[i])
local new_x = nil local new_x = nil
if opt.m == "noise" and opt.noise_level == 1 then if opt.m == "noise" and opt.noise_level == 1 then
new_x = reconstruct.image(noise1_model, x, BLOCK_OFFSET, opt.crop_size) new_x = reconstruct.image(noise1_model, x, BLOCK_OFFSET, opt.crop_size)
elseif opt.m == "noise" and opt.noise_level == 2 then elseif opt.m == "noise" and opt.noise_level == 2 then
new_x = reconstruct.image(noise2_model, x, BLOCK_OFFSET) new_x = reconstruct.image(noise2_model, x, BLOCK_OFFSET)
elseif opt.m == "scale" then elseif opt.m == "scale" then
new_x = reconstruct.scale(scale_model, opt.scale, x, BLOCK_OFFSET, opt.crop_size) new_x = reconstruct.scale(scale_model, opt.scale, x, BLOCK_OFFSET, opt.crop_size)
elseif opt.m == "noise_scale" and opt.noise_level == 1 then elseif opt.m == "noise_scale" and opt.noise_level == 1 then
x = reconstruct.image(noise1_model, x, BLOCK_OFFSET) x = reconstruct.image(noise1_model, x, BLOCK_OFFSET)
new_x = reconstruct.scale(scale_model, opt.scale, x, BLOCK_OFFSET, opt.crop_size) new_x = reconstruct.scale(scale_model, opt.scale, x, BLOCK_OFFSET, opt.crop_size)
elseif opt.m == "noise_scale" and opt.noise_level == 2 then elseif opt.m == "noise_scale" and opt.noise_level == 2 then
x = reconstruct.image(noise2_model, x, BLOCK_OFFSET) x = reconstruct.image(noise2_model, x, BLOCK_OFFSET)
new_x = reconstruct.scale(scale_model, opt.scale, x, BLOCK_OFFSET, opt.crop_size) new_x = reconstruct.scale(scale_model, opt.scale, x, BLOCK_OFFSET, opt.crop_size)
else else
error("undefined method:" .. opt.method) error("undefined method:" .. opt.method)
end end
local output = nil local output = nil
if opt.o == "(auto)" then if opt.o == "(auto)" then
local name = path.basename(lines[i]) local name = path.basename(lines[i])
local e = path.extension(name) local e = path.extension(name)
local base = name:sub(0, name:len() - e:len()) local base = name:sub(0, name:len() - e:len())
output = path.join(path.dirname(opt.i), string.format("%s(%s).png", base, opt.m)) output = path.join(path.dirname(opt.i), string.format("%s(%s).png", base, opt.m))
else else
output = string.format(opt.o, i) output = string.format(opt.o, i)
end end
image.save(output, new_x) image_loader.save_png(output, new_x, alpha)
xlua.progress(i, #lines) xlua.progress(i, #lines)
if i % 10 == 0 then if i % 10 == 0 then
collectgarbage() collectgarbage()
end end
else else
xlua.progress(i, #lines) xlua.progress(i, #lines)
end end
end end
end end

12
web.lua
View file

@ -47,10 +47,10 @@ local function get_image(req)
local url = req:get_argument("url", "") local url = req:get_argument("url", "")
local blob = nil local blob = nil
local img = nil local img = nil
local alpha = nil
if file and file:len() > 0 then if file and file:len() > 0 then
blob = file blob = file
img = image_loader.decode_float(blob) img, alpha = image_loader.decode_float(blob)
elseif url and url:len() > 0 then elseif url and url:len() > 0 then
local res = coroutine.yield( local res = coroutine.yield(
turbo.async.HTTPClient({verify_ca=false}, turbo.async.HTTPClient({verify_ca=false},
@ -64,11 +64,11 @@ local function get_image(req)
end end
if content_type and content_type:find("image") then if content_type and content_type:find("image") then
blob = res.body blob = res.body
img = image_loader.decode_float(blob) img, alpha = image_loader.decode_float(blob)
end end
end end
end end
return img, blob return img, blob, alpha
end end
local function apply_denoise1(x) local function apply_denoise1(x)
@ -104,7 +104,7 @@ function APIHandler:post()
self:write("client disconnected") self:write("client disconnected")
return return
end end
local x, src = get_image(self) local x, src, alpha = get_image(self)
local scale = tonumber(self:get_argument("scale", "0")) local scale = tonumber(self:get_argument("scale", "0"))
local noise = tonumber(self:get_argument("noise", "0")) local noise = tonumber(self:get_argument("noise", "0"))
if x and valid_size(x, scale) then if x and valid_size(x, scale) then
@ -151,7 +151,7 @@ function APIHandler:post()
end end
end end
local name = uuid() .. ".png" local name = uuid() .. ".png"
local blob, len = image_loader.encode_png(x) local blob, len = image_loader.encode_png(x, alpha)
self:set_header("Content-Disposition", string.format('filename="%s"', name)) self:set_header("Content-Disposition", string.format('filename="%s"', name))
self:set_header("Content-Type", "image/png") self:set_header("Content-Type", "image/png")