2015-05-16 17:48:05 +12:00
|
|
|
local gm = require 'graphicsmagick'
|
2015-06-16 23:41:48 +12:00
|
|
|
local ffi = require 'ffi'
|
2015-05-16 17:48:05 +12:00
|
|
|
require 'pl'
|
|
|
|
|
|
|
|
-- Module table: the image encode/decode/load/save helpers below are attached to it.
local image_loader = {}
|
|
|
|
|
|
|
|
--- Decode an encoded image blob into a float tensor.
-- Delegates to image_loader.decode_byte and, when that yields an image,
-- converts it to float and divides every value by 255.
-- @param blob  string holding the encoded image data
-- @return the float image (or the falsy value from decode_byte on failure)
-- @return the alpha value exactly as returned by decode_byte
function image_loader.decode_float(blob)
  local pixels, alpha = image_loader.decode_byte(blob)
  if not pixels then
    -- Decoding failed: hand back whatever decode_byte produced.
    return pixels, alpha
  end
  return pixels:float():div(255), alpha
end
|
2015-06-16 23:41:48 +12:00
|
|
|
--- Encode an RGB image (and optional alpha plane) as a PNG blob.
-- A torch.ByteTensor input is first converted to float and divided by 255.
-- When alpha is given and its 2nd/3rd dimensions differ from those of rgb,
-- it is resized to match through GraphicsMagick using the "SincFast" filter.
-- @param rgb    image tensor indexed as [channel][row][col] ("DHW")
-- @param alpha  optional alpha tensor; falsy means "no alpha channel"
-- @return whatever gm.Image:toBlob(9) returns (save_png consumes it as
--         blob, length)
function image_loader.encode_png(rgb, alpha)
  if rgb:type() == "torch.ByteTensor" then
    rgb = rgb:float():div(255)
  end

  local png
  if not alpha then
    png = gm.Image(rgb, "RGB", "DHW")
  else
    local height, width = rgb:size(2), rgb:size(3)
    if alpha:size(2) ~= height or alpha:size(3) ~= width then
      -- Bring the alpha plane to the same spatial size as the colour planes.
      alpha = gm.Image(alpha, "I", "DHW"):size(width, height, "SincFast"):toTensor("float", "I", "DHW")
    end
    -- Assemble a 4-plane tensor: planes 1..3 from rgb, plane 4 from alpha.
    local rgba = torch.Tensor(4, height, width)
    for channel = 1, 3 do
      rgba[channel]:copy(rgb[channel])
    end
    rgba[4]:copy(alpha)
    png = gm.Image():fromTensor(rgba, "RGBA", "DHW")
  end

  png:format("png")
  return png:toBlob(9)
end
|
|
|
|
--- Encode an image as PNG and write it to a file.
-- @param filename  destination path, opened in binary write mode
-- @param rgb       image tensor passed through to encode_png
-- @param alpha     optional alpha tensor passed through to encode_png
-- @return true once the data has been written and the file closed
-- Raises "IO error: <filename>" when the file cannot be opened.
function image_loader.save_png(filename, rgb, alpha)
  local blob, len = image_loader.encode_png(rgb, alpha)

  local out = io.open(filename, "wb")
  if out == nil then
    error("IO error: " .. filename)
  end

  -- The blob is turned into a Lua string of `len` bytes before writing.
  local bytes = ffi.string(blob, len)
  out:write(bytes)
  out:close()
  return true
end
|
|
|
|
--- Decode an encoded image blob into a byte RGB tensor.
-- The whole decode runs inside pcall, so a blob GraphicsMagick cannot read
-- yields nil instead of raising; the error message itself is discarded.
-- @param blob  string holding the encoded image data
-- @return on success, the image as a byte tensor ('RGB', 'DHW' layout);
--         nil on any decode error
-- @return alpha: for PNG/GIF blobs whose alpha plane has at least one value
--         below 1.0, that plane as a 1 x H x W tensor; otherwise nil
function image_loader.decode_byte(blob)
  local load_image = function()
    local im = gm.Image()
    local alpha = nil
    local gamma_lcd = 0.454545

    im:fromBlob(blob, #blob)

    if im:colorspace() == "CMYK" then
      im:colorspace("RGB")
    end

    -- Truncate the reported gamma to 6 decimals so it can be compared with
    -- gamma_lcd; a gamma of 0 is left untouched.
    local gamma = math.floor(im:gamma() * 1000000) / 1000000
    if gamma ~= 0 and gamma ~= gamma_lcd then
      im:gammaCorrection(gamma / gamma_lcd)
    end

    -- FIXME: How to detect that a image has an alpha channel?
    -- For now only blobs starting with the PNG or GIF signature are probed.
    if blob:sub(1, 4) == "\x89PNG" or blob:sub(1, 3) == "GIF" then
      -- split alpha channel
      im = im:toTensor('float', 'RGBA', 'DHW')
      -- Negative only when some alpha value is below 1.0.
      local sum_alpha = (im[4] - 1.0):sum()
      if sum_alpha < 0 then
        alpha = im[4]:reshape(1, im:size(2), im:size(3))
      end
      local new_im = torch.FloatTensor(3, im:size(2), im:size(3))
      new_im[1]:copy(im[1])
      new_im[2]:copy(im[2])
      new_im[3]:copy(im[3])
      im = new_im:mul(255):byte()
    else
      im = im:toTensor('byte', 'RGB', 'DHW')
    end
    return {im, alpha}
  end

  -- load_image must only ever run under pcall. A bare call here would let
  -- decode errors escape to the caller and would decode every image twice.
  local state, ret = pcall(load_image)
  if state then
    return ret[1], ret[2]
  else
    return nil
  end
end
|
|
|
|
--- Read an image file and decode it with image_loader.decode_float.
-- @param file  path of the image file, opened in binary read mode
-- @return the values returned by image_loader.decode_float
-- Raises "<file>: failed to load image" when the file cannot be opened.
function image_loader.load_float(file)
  local handle = io.open(file, "rb")
  if handle == nil then
    error(file .. ": failed to load image")
  end

  -- Read the entire file into memory before decoding.
  local contents = handle:read("*a")
  handle:close()

  return image_loader.decode_float(contents)
end
|
|
|
|
--- Read an image file and decode it with image_loader.decode_byte.
-- @param file  path of the image file, opened in binary read mode
-- @return the values returned by image_loader.decode_byte
-- Raises "<file>: failed to load image" when the file cannot be opened.
function image_loader.load_byte(file)
  local handle = io.open(file, "rb")
  if handle == nil then
    error(file .. ": failed to load image")
  end

  -- Read the entire file into memory before decoding.
  local contents = handle:read("*a")
  handle:close()

  return image_loader.decode_byte(contents)
end
|
|
|
|
-- Manual smoke test: loads ./a.jpg and ./b.png through load_float and shows
-- them with the 'image' package. Not run automatically (the call at the
-- bottom of the file is commented out).
local function test()
  require 'image'

  local jpg = image_loader.load_float("./a.jpg")
  if jpg then
    -- Print the value range of the decoded JPEG before displaying it.
    print(jpg:min())
    print(jpg:max())
    image.display(jpg)
  end

  local png = image_loader.load_float("./b.png")
  if png then
    image.display(png)
  end
end
|
|
|
|
--test()
|
|
|
|
-- Export the module table as the result of require().
return image_loader
|