add support for transparent png (retain the alpha channel)
This commit is contained in:
parent
f8112367e5
commit
bbc3cac273
3 changed files with 86 additions and 59 deletions
|
@ -1,44 +1,72 @@
|
|||
local gm = require 'graphicsmagick'
|
||||
local ffi = require 'ffi'
|
||||
require 'pl'
|
||||
|
||||
local image_loader = {}
|
||||
|
||||
function image_loader.decode_float(blob)
|
||||
local im = image_loader.decode_byte(blob)
|
||||
local im, alpha = image_loader.decode_byte(blob)
|
||||
if im then
|
||||
im = im:float():div(255)
|
||||
end
|
||||
return im
|
||||
return im, alpha
|
||||
end
|
||||
function image_loader.encode_png(tensor)
|
||||
local im = gm.Image(tensor, "RGB", "DHW")
|
||||
function image_loader.encode_png(rgb, alpha)
|
||||
if rgb:type() == "torch.ByteTensor" then
|
||||
error("expect FloatTensor")
|
||||
end
|
||||
if alpha then
|
||||
if not (alpha:size(2) == rgb:size(2) and alpha:size(3) == rgb:size(3)) then
|
||||
alpha = gm.Image(alpha, "I", "DHW"):size(rgb:size(3), rgb:size(2), "Sinc"):toTensor("float", "I", "DHW")
|
||||
end
|
||||
local rgba = torch.Tensor(4, rgb:size(2), rgb:size(3))
|
||||
rgba[1]:copy(rgb[1])
|
||||
rgba[2]:copy(rgb[2])
|
||||
rgba[3]:copy(rgb[3])
|
||||
rgba[4]:copy(alpha)
|
||||
local im = gm.Image():fromTensor(rgba, "RGBA", "DHW")
|
||||
im:format("png")
|
||||
return im:toBlob()
|
||||
else
|
||||
local im = gm.Image(rgb, "RGB", "DHW")
|
||||
im:format("png")
|
||||
return im:toBlob()
|
||||
end
|
||||
end
|
||||
function image_loader.save_png(filename, rgb, alpha)
|
||||
local blob, len = image_loader.encode_png(rgb, alpha)
|
||||
local fp = io.open(filename, "wb")
|
||||
fp:write(ffi.string(blob, len))
|
||||
fp:close()
|
||||
return true
|
||||
end
|
||||
function image_loader.decode_byte(blob)
|
||||
local load_image = function()
|
||||
local im = gm.Image()
|
||||
local alpha = nil
|
||||
|
||||
im:fromBlob(blob, #blob)
|
||||
-- FIXME: How to detect that a image has an alpha channel?
|
||||
if blob:sub(1, 4) == "\x89PNG" or blob:sub(1, 3) == "GIF" then
|
||||
-- merge alpha channel
|
||||
-- split alpha channel
|
||||
im = im:toTensor('float', 'RGBA', 'DHW')
|
||||
local w2 = im[4]
|
||||
local w1 = im[4] * -1 + 1
|
||||
local sum_alpha = (im[4] - 1):sum()
|
||||
if sum_alpha > 0 or sum_alpha < 0 then
|
||||
alpha = im[4]:reshape(1, im:size(2), im:size(3))
|
||||
end
|
||||
local new_im = torch.FloatTensor(3, im:size(2), im:size(3))
|
||||
-- apply the white background
|
||||
new_im[1]:copy(im[1]):cmul(w2):add(w1)
|
||||
new_im[2]:copy(im[2]):cmul(w2):add(w1)
|
||||
new_im[3]:copy(im[3]):cmul(w2):add(w1)
|
||||
new_im[1]:copy(im[1])
|
||||
new_im[2]:copy(im[2])
|
||||
new_im[3]:copy(im[3])
|
||||
im = new_im:mul(255):byte()
|
||||
else
|
||||
im = im:toTensor('byte', 'RGB', 'DHW')
|
||||
end
|
||||
return im
|
||||
return {im, alpha}
|
||||
end
|
||||
local state, ret = pcall(load_image)
|
||||
if state then
|
||||
return ret
|
||||
return ret[1], ret[2]
|
||||
else
|
||||
return nil
|
||||
end
|
||||
|
|
|
@ -6,13 +6,12 @@ require './lib/LeakyReLU'
|
|||
local iproc = require './lib/iproc'
|
||||
local reconstruct = require './lib/reconstruct'
|
||||
local image_loader = require './lib/image_loader'
|
||||
|
||||
local BLOCK_OFFSET = 7
|
||||
|
||||
torch.setdefaulttensortype('torch.FloatTensor')
|
||||
|
||||
local function convert_image(opt)
|
||||
local x = image_loader.load_float(opt.i)
|
||||
local x, alpha = image_loader.load_float(opt.i)
|
||||
local new_x = nil
|
||||
local t = sys.clock()
|
||||
if opt.o == "(auto)" then
|
||||
|
@ -39,7 +38,7 @@ local function convert_image(opt)
|
|||
else
|
||||
error("undefined method:" .. opt.method)
|
||||
end
|
||||
image.save(opt.o, new_x)
|
||||
image_loader.save_png(opt.o, new_x, alpha)
|
||||
print(opt.o .. ": " .. (sys.clock() - t) .. " sec")
|
||||
end
|
||||
local function convert_frames(opt)
|
||||
|
@ -60,7 +59,7 @@ local function convert_frames(opt)
|
|||
fp:close()
|
||||
for i = 1, #lines do
|
||||
if opt.resume == 0 or path.exists(string.format(opt.o, i)) == false then
|
||||
local x = image_loader.load_float(lines[i])
|
||||
local x, alpha = image_loader.load_float(lines[i])
|
||||
local new_x = nil
|
||||
if opt.m == "noise" and opt.noise_level == 1 then
|
||||
new_x = reconstruct.image(noise1_model, x, BLOCK_OFFSET, opt.crop_size)
|
||||
|
@ -86,7 +85,7 @@ local function convert_frames(opt)
|
|||
else
|
||||
output = string.format(opt.o, i)
|
||||
end
|
||||
image.save(output, new_x)
|
||||
image_loader.save_png(output, new_x, alpha)
|
||||
xlua.progress(i, #lines)
|
||||
if i % 10 == 0 then
|
||||
collectgarbage()
|
||||
|
|
12
web.lua
12
web.lua
|
@ -47,10 +47,10 @@ local function get_image(req)
|
|||
local url = req:get_argument("url", "")
|
||||
local blob = nil
|
||||
local img = nil
|
||||
|
||||
local alpha = nil
|
||||
if file and file:len() > 0 then
|
||||
blob = file
|
||||
img = image_loader.decode_float(blob)
|
||||
img, alpha = image_loader.decode_float(blob)
|
||||
elseif url and url:len() > 0 then
|
||||
local res = coroutine.yield(
|
||||
turbo.async.HTTPClient({verify_ca=false},
|
||||
|
@ -64,11 +64,11 @@ local function get_image(req)
|
|||
end
|
||||
if content_type and content_type:find("image") then
|
||||
blob = res.body
|
||||
img = image_loader.decode_float(blob)
|
||||
img, alpha = image_loader.decode_float(blob)
|
||||
end
|
||||
end
|
||||
end
|
||||
return img, blob
|
||||
return img, blob, alpha
|
||||
end
|
||||
|
||||
local function apply_denoise1(x)
|
||||
|
@ -104,7 +104,7 @@ function APIHandler:post()
|
|||
self:write("client disconnected")
|
||||
return
|
||||
end
|
||||
local x, src = get_image(self)
|
||||
local x, src, alpha = get_image(self)
|
||||
local scale = tonumber(self:get_argument("scale", "0"))
|
||||
local noise = tonumber(self:get_argument("noise", "0"))
|
||||
if x and valid_size(x, scale) then
|
||||
|
@ -151,7 +151,7 @@ function APIHandler:post()
|
|||
end
|
||||
end
|
||||
local name = uuid() .. ".png"
|
||||
local blob, len = image_loader.encode_png(x)
|
||||
local blob, len = image_loader.encode_png(x, alpha)
|
||||
|
||||
self:set_header("Content-Disposition", string.format('filename="%s"', name))
|
||||
self:set_header("Content-Type", "image/png")
|
||||
|
|
Loading…
Reference in a new issue