Add support for transparent PNG images (retain the alpha channel)
This commit is contained in:
parent
f8112367e5
commit
bbc3cac273
3 changed files with 86 additions and 59 deletions
|
@ -1,44 +1,72 @@
|
||||||
local gm = require 'graphicsmagick'
|
local gm = require 'graphicsmagick'
|
||||||
|
local ffi = require 'ffi'
|
||||||
require 'pl'
|
require 'pl'
|
||||||
|
|
||||||
local image_loader = {}
|
local image_loader = {}
|
||||||
|
|
||||||
function image_loader.decode_float(blob)
|
function image_loader.decode_float(blob)
|
||||||
local im = image_loader.decode_byte(blob)
|
local im, alpha = image_loader.decode_byte(blob)
|
||||||
if im then
|
if im then
|
||||||
im = im:float():div(255)
|
im = im:float():div(255)
|
||||||
end
|
end
|
||||||
return im
|
return im, alpha
|
||||||
end
|
end
|
||||||
function image_loader.encode_png(tensor)
|
function image_loader.encode_png(rgb, alpha)
|
||||||
local im = gm.Image(tensor, "RGB", "DHW")
|
if rgb:type() == "torch.ByteTensor" then
|
||||||
im:format("png")
|
error("expect FloatTensor")
|
||||||
return im:toBlob()
|
end
|
||||||
|
if alpha then
|
||||||
|
if not (alpha:size(2) == rgb:size(2) and alpha:size(3) == rgb:size(3)) then
|
||||||
|
alpha = gm.Image(alpha, "I", "DHW"):size(rgb:size(3), rgb:size(2), "Sinc"):toTensor("float", "I", "DHW")
|
||||||
|
end
|
||||||
|
local rgba = torch.Tensor(4, rgb:size(2), rgb:size(3))
|
||||||
|
rgba[1]:copy(rgb[1])
|
||||||
|
rgba[2]:copy(rgb[2])
|
||||||
|
rgba[3]:copy(rgb[3])
|
||||||
|
rgba[4]:copy(alpha)
|
||||||
|
local im = gm.Image():fromTensor(rgba, "RGBA", "DHW")
|
||||||
|
im:format("png")
|
||||||
|
return im:toBlob()
|
||||||
|
else
|
||||||
|
local im = gm.Image(rgb, "RGB", "DHW")
|
||||||
|
im:format("png")
|
||||||
|
return im:toBlob()
|
||||||
|
end
|
||||||
|
end
|
||||||
|
function image_loader.save_png(filename, rgb, alpha)
|
||||||
|
local blob, len = image_loader.encode_png(rgb, alpha)
|
||||||
|
local fp = io.open(filename, "wb")
|
||||||
|
fp:write(ffi.string(blob, len))
|
||||||
|
fp:close()
|
||||||
|
return true
|
||||||
end
|
end
|
||||||
function image_loader.decode_byte(blob)
|
function image_loader.decode_byte(blob)
|
||||||
local load_image = function()
|
local load_image = function()
|
||||||
local im = gm.Image()
|
local im = gm.Image()
|
||||||
|
local alpha = nil
|
||||||
|
|
||||||
im:fromBlob(blob, #blob)
|
im:fromBlob(blob, #blob)
|
||||||
-- FIXME: How to detect that an image has an alpha channel?
|
-- FIXME: How to detect that an image has an alpha channel?
|
||||||
if blob:sub(1, 4) == "\x89PNG" or blob:sub(1, 3) == "GIF" then
|
if blob:sub(1, 4) == "\x89PNG" or blob:sub(1, 3) == "GIF" then
|
||||||
-- merge alpha channel
|
-- split alpha channel
|
||||||
im = im:toTensor('float', 'RGBA', 'DHW')
|
im = im:toTensor('float', 'RGBA', 'DHW')
|
||||||
local w2 = im[4]
|
local sum_alpha = (im[4] - 1):sum()
|
||||||
local w1 = im[4] * -1 + 1
|
if sum_alpha > 0 or sum_alpha < 0 then
|
||||||
|
alpha = im[4]:reshape(1, im:size(2), im:size(3))
|
||||||
|
end
|
||||||
local new_im = torch.FloatTensor(3, im:size(2), im:size(3))
|
local new_im = torch.FloatTensor(3, im:size(2), im:size(3))
|
||||||
-- apply the white background
|
new_im[1]:copy(im[1])
|
||||||
new_im[1]:copy(im[1]):cmul(w2):add(w1)
|
new_im[2]:copy(im[2])
|
||||||
new_im[2]:copy(im[2]):cmul(w2):add(w1)
|
new_im[3]:copy(im[3])
|
||||||
new_im[3]:copy(im[3]):cmul(w2):add(w1)
|
|
||||||
im = new_im:mul(255):byte()
|
im = new_im:mul(255):byte()
|
||||||
else
|
else
|
||||||
im = im:toTensor('byte', 'RGB', 'DHW')
|
im = im:toTensor('byte', 'RGB', 'DHW')
|
||||||
end
|
end
|
||||||
return im
|
return {im, alpha}
|
||||||
end
|
end
|
||||||
local state, ret = pcall(load_image)
|
local state, ret = pcall(load_image)
|
||||||
if state then
|
if state then
|
||||||
return ret
|
return ret[1], ret[2]
|
||||||
else
|
else
|
||||||
return nil
|
return nil
|
||||||
end
|
end
|
||||||
|
|
75
waifu2x.lua
75
waifu2x.lua
|
@ -6,13 +6,12 @@ require './lib/LeakyReLU'
|
||||||
local iproc = require './lib/iproc'
|
local iproc = require './lib/iproc'
|
||||||
local reconstruct = require './lib/reconstruct'
|
local reconstruct = require './lib/reconstruct'
|
||||||
local image_loader = require './lib/image_loader'
|
local image_loader = require './lib/image_loader'
|
||||||
|
|
||||||
local BLOCK_OFFSET = 7
|
local BLOCK_OFFSET = 7
|
||||||
|
|
||||||
torch.setdefaulttensortype('torch.FloatTensor')
|
torch.setdefaulttensortype('torch.FloatTensor')
|
||||||
|
|
||||||
local function convert_image(opt)
|
local function convert_image(opt)
|
||||||
local x = image_loader.load_float(opt.i)
|
local x, alpha = image_loader.load_float(opt.i)
|
||||||
local new_x = nil
|
local new_x = nil
|
||||||
local t = sys.clock()
|
local t = sys.clock()
|
||||||
if opt.o == "(auto)" then
|
if opt.o == "(auto)" then
|
||||||
|
@ -39,7 +38,7 @@ local function convert_image(opt)
|
||||||
else
|
else
|
||||||
error("undefined method:" .. opt.method)
|
error("undefined method:" .. opt.method)
|
||||||
end
|
end
|
||||||
image.save(opt.o, new_x)
|
image_loader.save_png(opt.o, new_x, alpha)
|
||||||
print(opt.o .. ": " .. (sys.clock() - t) .. " sec")
|
print(opt.o .. ": " .. (sys.clock() - t) .. " sec")
|
||||||
end
|
end
|
||||||
local function convert_frames(opt)
|
local function convert_frames(opt)
|
||||||
|
@ -59,41 +58,41 @@ local function convert_frames(opt)
|
||||||
end
|
end
|
||||||
fp:close()
|
fp:close()
|
||||||
for i = 1, #lines do
|
for i = 1, #lines do
|
||||||
if opt.resume == 0 or path.exists(string.format(opt.o, i)) == false then
|
if opt.resume == 0 or path.exists(string.format(opt.o, i)) == false then
|
||||||
local x = image_loader.load_float(lines[i])
|
local x, alpha = image_loader.load_float(lines[i])
|
||||||
local new_x = nil
|
local new_x = nil
|
||||||
if opt.m == "noise" and opt.noise_level == 1 then
|
if opt.m == "noise" and opt.noise_level == 1 then
|
||||||
new_x = reconstruct.image(noise1_model, x, BLOCK_OFFSET, opt.crop_size)
|
new_x = reconstruct.image(noise1_model, x, BLOCK_OFFSET, opt.crop_size)
|
||||||
elseif opt.m == "noise" and opt.noise_level == 2 then
|
elseif opt.m == "noise" and opt.noise_level == 2 then
|
||||||
new_x = reconstruct.image(noise2_model, x, BLOCK_OFFSET)
|
new_x = reconstruct.image(noise2_model, x, BLOCK_OFFSET)
|
||||||
elseif opt.m == "scale" then
|
elseif opt.m == "scale" then
|
||||||
new_x = reconstruct.scale(scale_model, opt.scale, x, BLOCK_OFFSET, opt.crop_size)
|
new_x = reconstruct.scale(scale_model, opt.scale, x, BLOCK_OFFSET, opt.crop_size)
|
||||||
elseif opt.m == "noise_scale" and opt.noise_level == 1 then
|
elseif opt.m == "noise_scale" and opt.noise_level == 1 then
|
||||||
x = reconstruct.image(noise1_model, x, BLOCK_OFFSET)
|
x = reconstruct.image(noise1_model, x, BLOCK_OFFSET)
|
||||||
new_x = reconstruct.scale(scale_model, opt.scale, x, BLOCK_OFFSET, opt.crop_size)
|
new_x = reconstruct.scale(scale_model, opt.scale, x, BLOCK_OFFSET, opt.crop_size)
|
||||||
elseif opt.m == "noise_scale" and opt.noise_level == 2 then
|
elseif opt.m == "noise_scale" and opt.noise_level == 2 then
|
||||||
x = reconstruct.image(noise2_model, x, BLOCK_OFFSET)
|
x = reconstruct.image(noise2_model, x, BLOCK_OFFSET)
|
||||||
new_x = reconstruct.scale(scale_model, opt.scale, x, BLOCK_OFFSET, opt.crop_size)
|
new_x = reconstruct.scale(scale_model, opt.scale, x, BLOCK_OFFSET, opt.crop_size)
|
||||||
else
|
else
|
||||||
error("undefined method:" .. opt.method)
|
error("undefined method:" .. opt.method)
|
||||||
end
|
end
|
||||||
local output = nil
|
local output = nil
|
||||||
if opt.o == "(auto)" then
|
if opt.o == "(auto)" then
|
||||||
local name = path.basename(lines[i])
|
local name = path.basename(lines[i])
|
||||||
local e = path.extension(name)
|
local e = path.extension(name)
|
||||||
local base = name:sub(0, name:len() - e:len())
|
local base = name:sub(0, name:len() - e:len())
|
||||||
output = path.join(path.dirname(opt.i), string.format("%s(%s).png", base, opt.m))
|
output = path.join(path.dirname(opt.i), string.format("%s(%s).png", base, opt.m))
|
||||||
else
|
else
|
||||||
output = string.format(opt.o, i)
|
output = string.format(opt.o, i)
|
||||||
end
|
end
|
||||||
image.save(output, new_x)
|
image_loader.save_png(output, new_x, alpha)
|
||||||
xlua.progress(i, #lines)
|
xlua.progress(i, #lines)
|
||||||
if i % 10 == 0 then
|
if i % 10 == 0 then
|
||||||
collectgarbage()
|
collectgarbage()
|
||||||
end
|
end
|
||||||
else
|
else
|
||||||
xlua.progress(i, #lines)
|
xlua.progress(i, #lines)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
12
web.lua
12
web.lua
|
@ -47,10 +47,10 @@ local function get_image(req)
|
||||||
local url = req:get_argument("url", "")
|
local url = req:get_argument("url", "")
|
||||||
local blob = nil
|
local blob = nil
|
||||||
local img = nil
|
local img = nil
|
||||||
|
local alpha = nil
|
||||||
if file and file:len() > 0 then
|
if file and file:len() > 0 then
|
||||||
blob = file
|
blob = file
|
||||||
img = image_loader.decode_float(blob)
|
img, alpha = image_loader.decode_float(blob)
|
||||||
elseif url and url:len() > 0 then
|
elseif url and url:len() > 0 then
|
||||||
local res = coroutine.yield(
|
local res = coroutine.yield(
|
||||||
turbo.async.HTTPClient({verify_ca=false},
|
turbo.async.HTTPClient({verify_ca=false},
|
||||||
|
@ -64,11 +64,11 @@ local function get_image(req)
|
||||||
end
|
end
|
||||||
if content_type and content_type:find("image") then
|
if content_type and content_type:find("image") then
|
||||||
blob = res.body
|
blob = res.body
|
||||||
img = image_loader.decode_float(blob)
|
img, alpha = image_loader.decode_float(blob)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
return img, blob
|
return img, blob, alpha
|
||||||
end
|
end
|
||||||
|
|
||||||
local function apply_denoise1(x)
|
local function apply_denoise1(x)
|
||||||
|
@ -104,7 +104,7 @@ function APIHandler:post()
|
||||||
self:write("client disconnected")
|
self:write("client disconnected")
|
||||||
return
|
return
|
||||||
end
|
end
|
||||||
local x, src = get_image(self)
|
local x, src, alpha = get_image(self)
|
||||||
local scale = tonumber(self:get_argument("scale", "0"))
|
local scale = tonumber(self:get_argument("scale", "0"))
|
||||||
local noise = tonumber(self:get_argument("noise", "0"))
|
local noise = tonumber(self:get_argument("noise", "0"))
|
||||||
if x and valid_size(x, scale) then
|
if x and valid_size(x, scale) then
|
||||||
|
@ -151,7 +151,7 @@ function APIHandler:post()
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
local name = uuid() .. ".png"
|
local name = uuid() .. ".png"
|
||||||
local blob, len = image_loader.encode_png(x)
|
local blob, len = image_loader.encode_png(x, alpha)
|
||||||
|
|
||||||
self:set_header("Content-Disposition", string.format('filename="%s"', name))
|
self:set_header("Content-Disposition", string.format('filename="%s"', name))
|
||||||
self:set_header("Content-Type", "image/png")
|
self:set_header("Content-Type", "image/png")
|
||||||
|
|
Loading…
Reference in a new issue