1
0
Fork 0
mirror of synced 2024-06-14 17:04:31 +12:00

Merge branch 'master' of github.com:nagadomi/waifu2x into rgb

This commit is contained in:
nagadomi 2015-06-22 23:04:17 +09:00
commit 628bd971c9
4 changed files with 91 additions and 60 deletions

View file

@ -1,6 +1,6 @@
# waifu2x
Image Super-Resolution for anime/fan-art using Deep Convolutional Neural Networks.
Image Super-Resolution for anime-style-art using Deep Convolutional Neural Networks.
Demo-Application can be found at http://waifu2x.udp.jp/ .
@ -20,6 +20,9 @@ waifu2x is inspired by SRCNN [1]. 2D character picture (HatsuneMiku) is licensed
## Public AMI
(maintenance)
## Third Party Software
[Third-Party](https://github.com/nagadomi/waifu2x/wiki/Third-Party)
## Dependencies
### Hardware

View file

@ -1,44 +1,72 @@
local gm = require 'graphicsmagick'
local ffi = require 'ffi'
require 'pl'
local image_loader = {}
function image_loader.decode_float(blob)
local im = image_loader.decode_byte(blob)
local im, alpha = image_loader.decode_byte(blob)
if im then
im = im:float():div(255)
end
return im
return im, alpha
end
function image_loader.encode_png(tensor)
local im = gm.Image(tensor, "RGB", "DHW")
im:format("png")
return im:toBlob()
--- Encode an image tensor as a PNG blob, optionally attaching an alpha channel.
-- @param rgb image tensor indexed as rgb[1..3] with height at size(2) and
--   width at size(3); a torch.ByteTensor is rejected.
-- @param alpha optional alpha tensor; when its size(2)/size(3) differ from
--   those of `rgb` it is resampled to the rgb dimensions before merging.
-- @return whatever gm.Image:toBlob() returns (callers in this file unpack it
--   as `blob, len`).
function image_loader.encode_png(rgb, alpha)
   if rgb:type() == "torch.ByteTensor" then
      error("expect FloatTensor")
   end

   local png
   if not alpha then
      -- No alpha supplied: encode the three colour planes directly.
      png = gm.Image(rgb, "RGB", "DHW")
   else
      local height, width = rgb:size(2), rgb:size(3)
      if alpha:size(2) ~= height or alpha:size(3) ~= width then
         -- Bring the alpha plane to the same height/width as rgb.
         alpha = gm.Image(alpha, "I", "DHW"):size(width, height, "Sinc"):toTensor("float", "I", "DHW")
      end
      -- Assemble a 4-plane tensor: planes 1-3 from rgb, plane 4 from alpha.
      local rgba = torch.Tensor(4, height, width)
      for channel = 1, 3 do
         rgba[channel]:copy(rgb[channel])
      end
      rgba[4]:copy(alpha)
      png = gm.Image():fromTensor(rgba, "RGBA", "DHW")
   end

   png:format("png")
   return png:toBlob()
end
--- Encode an image as PNG and write it to a file.
-- @param filename path of the file to create; opened in binary write mode.
-- @param rgb image tensor, forwarded unchanged to image_loader.encode_png.
-- @param alpha optional alpha tensor, forwarded unchanged to
--   image_loader.encode_png.
-- @return true on success.
-- Raises an error (with the path and the OS message) when the file cannot be
-- opened or the write fails.
function image_loader.save_png(filename, rgb, alpha)
   local blob, len = image_loader.encode_png(rgb, alpha)
   -- Materialise the Lua string before touching the filesystem so that a
   -- failure here does not leave behind a truncated, empty output file.
   local data = ffi.string(blob, len)

   local fp, open_err = io.open(filename, "wb")
   if not fp then
      error(string.format("save_png: cannot open '%s' for writing: %s",
                          tostring(filename), tostring(open_err)))
   end

   -- file:write returns a true value on success and nil plus a message on
   -- failure; always close the handle before reporting the outcome.
   local ok, write_err = fp:write(data)
   fp:close()
   if not ok then
      error(string.format("save_png: failed to write '%s': %s",
                          tostring(filename), tostring(write_err)))
   end
   return true
end
function image_loader.decode_byte(blob)
local load_image = function()
local im = gm.Image()
local alpha = nil
im:fromBlob(blob, #blob)
-- FIXME: How to detect that a image has an alpha channel?
if blob:sub(1, 4) == "\x89PNG" or blob:sub(1, 3) == "GIF" then
-- merge alpha channel
-- split alpha channel
im = im:toTensor('float', 'RGBA', 'DHW')
local w2 = im[4]
local w1 = im[4] * -1 + 1
local sum_alpha = (im[4] - 1):sum()
if sum_alpha > 0 or sum_alpha < 0 then
alpha = im[4]:reshape(1, im:size(2), im:size(3))
end
local new_im = torch.FloatTensor(3, im:size(2), im:size(3))
-- apply the white background
new_im[1]:copy(im[1]):cmul(w2):add(w1)
new_im[2]:copy(im[2]):cmul(w2):add(w1)
new_im[3]:copy(im[3]):cmul(w2):add(w1)
new_im[1]:copy(im[1])
new_im[2]:copy(im[2])
new_im[3]:copy(im[3])
im = new_im:mul(255):byte()
else
im = im:toTensor('byte', 'RGB', 'DHW')
end
return im
return {im, alpha}
end
local state, ret = pcall(load_image)
if state then
return ret
return ret[1], ret[2]
else
return nil
end

View file

@ -6,13 +6,12 @@ require './lib/LeakyReLU'
local iproc = require './lib/iproc'
local reconstruct = require './lib/reconstruct'
local image_loader = require './lib/image_loader'
local BLOCK_OFFSET = 7
torch.setdefaulttensortype('torch.FloatTensor')
local function convert_image(opt)
local x = image_loader.load_float(opt.i)
local x, alpha = image_loader.load_float(opt.i)
local new_x = nil
local t = sys.clock()
if opt.o == "(auto)" then
@ -39,7 +38,7 @@ local function convert_image(opt)
else
error("undefined method:" .. opt.method)
end
image.save(opt.o, new_x)
image_loader.save_png(opt.o, new_x, alpha)
print(opt.o .. ": " .. (sys.clock() - t) .. " sec")
end
local function convert_frames(opt)
@ -59,41 +58,41 @@ local function convert_frames(opt)
end
fp:close()
for i = 1, #lines do
if opt.resume == 0 or path.exists(string.format(opt.o, i)) == false then
local x = image_loader.load_float(lines[i])
local new_x = nil
if opt.m == "noise" and opt.noise_level == 1 then
new_x = reconstruct.image(noise1_model, x, BLOCK_OFFSET, opt.crop_size)
elseif opt.m == "noise" and opt.noise_level == 2 then
new_x = reconstruct.image(noise2_model, x, BLOCK_OFFSET)
elseif opt.m == "scale" then
new_x = reconstruct.scale(scale_model, opt.scale, x, BLOCK_OFFSET, opt.crop_size)
elseif opt.m == "noise_scale" and opt.noise_level == 1 then
x = reconstruct.image(noise1_model, x, BLOCK_OFFSET)
new_x = reconstruct.scale(scale_model, opt.scale, x, BLOCK_OFFSET, opt.crop_size)
elseif opt.m == "noise_scale" and opt.noise_level == 2 then
x = reconstruct.image(noise2_model, x, BLOCK_OFFSET)
new_x = reconstruct.scale(scale_model, opt.scale, x, BLOCK_OFFSET, opt.crop_size)
else
error("undefined method:" .. opt.method)
end
local output = nil
if opt.o == "(auto)" then
local name = path.basename(lines[i])
local e = path.extension(name)
local base = name:sub(0, name:len() - e:len())
output = path.join(path.dirname(opt.i), string.format("%s(%s).png", base, opt.m))
else
output = string.format(opt.o, i)
end
image.save(output, new_x)
xlua.progress(i, #lines)
if i % 10 == 0 then
collectgarbage()
end
else
xlua.progress(i, #lines)
end
if opt.resume == 0 or path.exists(string.format(opt.o, i)) == false then
local x, alpha = image_loader.load_float(lines[i])
local new_x = nil
if opt.m == "noise" and opt.noise_level == 1 then
new_x = reconstruct.image(noise1_model, x, BLOCK_OFFSET, opt.crop_size)
elseif opt.m == "noise" and opt.noise_level == 2 then
new_x = reconstruct.image(noise2_model, x, BLOCK_OFFSET)
elseif opt.m == "scale" then
new_x = reconstruct.scale(scale_model, opt.scale, x, BLOCK_OFFSET, opt.crop_size)
elseif opt.m == "noise_scale" and opt.noise_level == 1 then
x = reconstruct.image(noise1_model, x, BLOCK_OFFSET)
new_x = reconstruct.scale(scale_model, opt.scale, x, BLOCK_OFFSET, opt.crop_size)
elseif opt.m == "noise_scale" and opt.noise_level == 2 then
x = reconstruct.image(noise2_model, x, BLOCK_OFFSET)
new_x = reconstruct.scale(scale_model, opt.scale, x, BLOCK_OFFSET, opt.crop_size)
else
error("undefined method:" .. opt.method)
end
local output = nil
if opt.o == "(auto)" then
local name = path.basename(lines[i])
local e = path.extension(name)
local base = name:sub(0, name:len() - e:len())
output = path.join(path.dirname(opt.i), string.format("%s(%s).png", base, opt.m))
else
output = string.format(opt.o, i)
end
image_loader.save_png(output, new_x, alpha)
xlua.progress(i, #lines)
if i % 10 == 0 then
collectgarbage()
end
else
xlua.progress(i, #lines)
end
end
end

13
web.lua
View file

@ -1,3 +1,4 @@
_G.TURBO_SSL = true
local turbo = require 'turbo'
local uuid = require 'uuid'
local ffi = require 'ffi'
@ -46,10 +47,10 @@ local function get_image(req)
local url = req:get_argument("url", "")
local blob = nil
local img = nil
local alpha = nil
if file and file:len() > 0 then
blob = file
img = image_loader.decode_float(blob)
img, alpha = image_loader.decode_float(blob)
elseif url and url:len() > 0 then
local res = coroutine.yield(
turbo.async.HTTPClient({verify_ca=false},
@ -63,11 +64,11 @@ local function get_image(req)
end
if content_type and content_type:find("image") then
blob = res.body
img = image_loader.decode_float(blob)
img, alpha = image_loader.decode_float(blob)
end
end
end
return img, blob
return img, blob, alpha
end
local function apply_denoise1(x)
@ -103,7 +104,7 @@ function APIHandler:post()
self:write("client disconnected")
return
end
local x, src = get_image(self)
local x, src, alpha = get_image(self)
local scale = tonumber(self:get_argument("scale", "0"))
local noise = tonumber(self:get_argument("noise", "0"))
if x and valid_size(x, scale) then
@ -150,7 +151,7 @@ function APIHandler:post()
end
end
local name = uuid() .. ".png"
local blob, len = image_loader.encode_png(x)
local blob, len = image_loader.encode_png(x, alpha)
self:set_header("Content-Disposition", string.format('filename="%s"', name))
self:set_header("Content-Type", "image/png")