1
0
Fork 0
mirror of synced 2024-05-21 05:12:21 +12:00
waifu2x/web.lua

301 lines
10 KiB
Lua
Raw Normal View History

2015-11-08 02:54:29 +13:00
require 'pl'
2015-11-08 22:31:46 +13:00
-- Work out where this script lives so that bundled libraries, models and
-- assets can be addressed relative to it.
local __FILE__ = (function()
   -- Stack level 2 is the caller of this anonymous function (the main chunk).
   local chunk_source = debug.getinfo(2, 'S').source
   -- Drop the leading "@" from the reported chunk source, if there is one.
   return string.gsub(chunk_source, "^@", "")
end)()
local ROOT = path.dirname(__FILE__)
-- Search <ROOT>/lib before the pre-existing package.path entries.
local lib_search_path = path.join(ROOT, "lib", "?.lua;")
package.path = lib_search_path .. package.path
2015-06-14 03:26:00 +12:00
_G.TURBO_SSL = true
2015-10-28 19:30:47 +13:00
require 'w2nn'
2015-05-16 17:48:05 +12:00
local uuid = require 'uuid'
local ffi = require 'ffi'
local md5 = require 'md5'
2015-10-28 19:30:47 +13:00
local iproc = require 'iproc'
local reconstruct = require 'reconstruct'
local image_loader = require 'image_loader'
local alpha_util = require 'alpha_util'
local gm = require 'graphicsmagick'
-- Note: turbo and xlua have different implementations of string:split().
2015-11-08 02:54:29 +13:00
-- Therefore, the two string:split() implementations conflict with each other.
-- In this script, use turbo's string:split().
local turbo = require 'turbo'
2015-06-23 06:17:41 +12:00
-- Command line interface of the web server.
local cmd = torch.CmdLine()
cmd:text()
cmd:text("waifu2x-api")
cmd:text("Options:")
do
   -- Each entry is {flag, default value, help text}.
   local option_specs = {
      {"-port", 8812, 'listen port'},
      {"-gpu", 1, 'Device ID'},
      {"-thread", -1, 'number of CPU threads'},
   }
   for i = 1, #option_specs do
      local spec = option_specs[i]
      cmd:option(spec[1], spec[2], spec[3])
   end
end
local opt = cmd:parse(arg)

-- Runtime configuration derived from the parsed options.
cutorch.setDevice(opt.gpu)
torch.setdefaulttensortype('torch.FloatTensor')
-- Only override the thread count when a positive value was requested.
if opt.thread > 0 then
   torch.setnumthreads(opt.thread)
end
-- cudnn is optional; tune it only when the global is present.
if cudnn then
   cudnn.fastest, cudnn.benchmark = true, false
end
2015-11-08 02:54:29 +13:00
-- Model directories, one per supported style.
local ART_MODEL_DIR = path.join(ROOT, "models", "anime_style_art_rgb")
local PHOTO_MODEL_DIR = path.join(ROOT, "models", "photo")

-- Loads one model file, stored in torch's "ascii" format, from `dir`.
local function load_model(dir, filename)
   return torch.load(path.join(dir, filename), "ascii")
end

-- All models are loaded once at start-up.
local art_noise1_model = load_model(ART_MODEL_DIR, "noise1_model.t7")
local art_noise2_model = load_model(ART_MODEL_DIR, "noise2_model.t7")
local art_scale2_model = load_model(ART_MODEL_DIR, "scale2.0x_model.t7")
local photo_scale2_model = load_model(PHOTO_MODEL_DIR, "scale2.0x_model.t7")
local photo_noise1_model = load_model(PHOTO_MODEL_DIR, "noise1_model.t7")
local photo_noise2_model = load_model(PHOTO_MODEL_DIR, "noise2_model.t7")

-- Set to true on a GPU with little memory: models are then cleaned up
-- after every use (see cleanup_model).
local CLEANUP_MODEL = false

-- Directory holding downloaded sources and converted results.
local CACHE_DIR = path.join(ROOT, "cache")

-- Upper bounds on size(2) * size(3) of an input image, for requests
-- without and with scaling respectively (see valid_size).
local MAX_NOISE_IMAGE = 2560 * 2560
local MAX_SCALE_IMAGE = 1280 * 1280

-- Size limit in bytes, used both for remote downloads and for request bodies.
local CURL_MAX_SIZE = 3 * 1024 * 1024
-- Options handed to the HTTP client when fetching a remote image.
local CURL_OPTIONS = {
   request_timeout = 60,
   connect_timeout = 60,
   allow_redirects = true,
   max_redirects = 2
}
2015-05-16 17:48:05 +12:00
-- Tells whether image tensor `x` is small enough to be processed.
-- @param x      image tensor; size(2) * size(3) is compared against the limit
-- @param scale  requested scale mode; exactly 0 selects the noise-only limit,
--               anything else (including nil) selects the scaling limit
-- @return true when the image is within the applicable limit
local function valid_size(x, scale)
   local limit = MAX_SCALE_IMAGE
   if scale == 0 then
      limit = MAX_NOISE_IMAGE
   end
   return x:size(2) * x:size(3) <= limit
end
2015-11-06 16:24:27 +13:00
-- Fetches the image behind `url`, using the on-disk cache when possible.
--
-- Must be called from a coroutine: the HTTP request is awaited with
-- coroutine.yield().
--
-- NOTE(review): `url` comes straight from the request, so this fetches
-- arbitrary user-chosen locations, and certificate verification is
-- disabled (verify_ca=false). Confirm that this is acceptable for the
-- deployment.
--
-- @param url  remote address of the image
-- @return whatever image_loader.load_float / decode_float return (callers
--         in this file consume three values: image, alpha, raw blob), or
--         nil, nil, nil when the download fails or is not an image
local function cache_url(url)
   local hash = md5.sumhexa(url)
   local cache_file = path.join(CACHE_DIR, "url_" .. hash)
   if path.exists(cache_file) then
      return image_loader.load_float(cache_file)
   end

   local res = coroutine.yield(
      turbo.async.HTTPClient({verify_ca=false},
                             nil,
                             CURL_MAX_SIZE):fetch(url, CURL_OPTIONS)
   )
   if not res or res.code ~= 200 or not res.headers then
      return nil, nil, nil
   end

   local content_type = res.headers:get("Content-Type", true)
   if type(content_type) == "table" then
      -- The header was sent more than once; use the first value.
      content_type = content_type[1]
   end
   if not (content_type and content_type:find("image")) then
      return nil, nil, nil
   end

   local blob = res.body
   if type(blob) ~= "string" then
      -- No usable body: nothing to decode or cache.
      return nil, nil, nil
   end

   -- Decode before caching so that an undecodable payload is never stored
   -- (a stored one would be served from the cache on every later request).
   -- Presumably decode_float yields a nil image on failure, as the nil
   -- check in the request handler suggests -- TODO confirm.
   local x, alpha, decoded_blob = image_loader.decode_float(blob)
   if x then
      -- Caching is best effort: an unwritable cache directory must not
      -- fail the request.
      local fp = io.open(cache_file, "wb")
      if fp then
         fp:write(blob)
         fp:close()
      end
   end
   return x, alpha, decoded_blob
end
2015-11-06 16:24:27 +13:00
-- Extracts the source image from a request.
-- An uploaded "file" argument takes precedence over a "url" argument.
-- @param req  request handler providing get_argument()
-- @return the results of decoding the upload or fetching the URL, or
--         nil, nil, nil when neither argument carries any data
local function get_image(req)
   local upload = req:get_argument("file", "")
   local url = req:get_argument("url", "")
   if upload and upload:len() > 0 then
      return image_loader.decode_float(upload)
   end
   if url and url:len() > 0 then
      return cache_url(url)
   end
   return nil, nil, nil
end
-- Hands `model` to w2nn.cleanup_model (to release GPU memory) when the
-- CLEANUP_MODEL switch is on; does nothing otherwise.
local function cleanup_model(model)
   if not CLEANUP_MODEL then
      return
   end
   w2nn.cleanup_model(model)
end
-- Runs one conversion step (noise reduction or 2x scaling) on an image,
-- reusing cached results from CACHE_DIR when they exist.
--
-- @param x        input image tensor
-- @param alpha    alpha channel of the input, or nil
-- @param options  table with the fields
--                   method       "scale", "noise1" or "noise2"
--                   style        "art" selects the art models, anything
--                                else the photo models
--                   prefix       cache file name (without ".png") for x
--                   alpha_prefix cache file name (without ".png") for alpha
--                   border       when true, alpha_util.make_border is
--                                applied to x before the conversion
-- @return the converted image and the alpha channel to use with it (the
--         cached/scaled alpha when there is one, else the one passed in)
local function convert(x, alpha, options)
   local cache_file = path.join(CACHE_DIR, options.prefix .. ".png")
   local alpha_cache_file = path.join(CACHE_DIR, options.alpha_prefix .. ".png")
   -- Keep the caller's alpha: it is the one whose size matches `x`.
   local alpha_orig = alpha

   if path.exists(alpha_cache_file) then
      local cached_alpha = image_loader.load_float(alpha_cache_file)
      -- An unreadable cache entry is ignored instead of crashing on nil.
      if cached_alpha then
         alpha = cached_alpha
         if alpha:dim() == 2 then
            alpha = alpha:reshape(1, alpha:size(1), alpha:size(2))
         end
         if alpha:size(1) == 3 then
            alpha = image.rgb2y(alpha)
         end
      end
   end

   if path.exists(cache_file) then
      x = image_loader.load_float(cache_file)
      return x, alpha
   end

   -- Pick the model set for the requested style.
   local is_art = (options.style == "art")
   local scale_model, noise1_model, noise2_model
   if is_art then
      scale_model = art_scale2_model
      noise1_model = art_noise1_model
      noise2_model = art_noise2_model
   else
      scale_model = photo_scale2_model
      noise1_model = photo_noise1_model
      noise2_model = photo_noise2_model
   end

   if options.border then
      -- Always use the original alpha here. `alpha` may have been replaced
      -- by the cached, already upscaled alpha, which no longer matches the
      -- size of `x` (the photo path used to pass `alpha`).
      x = alpha_util.make_border(x, alpha_orig, reconstruct.offset_size(scale_model))
   end

   if options.method == "scale" then
      x = reconstruct.scale(scale_model, 2.0, x)
      if alpha and not (alpha:size(2) == x:size(2) and alpha:size(3) == x:size(3)) then
         -- Alpha has not been upscaled yet: scale it too and cache it.
         alpha = reconstruct.scale(scale_model, 2.0, alpha)
         image_loader.save_png(alpha_cache_file, alpha)
      end
      cleanup_model(scale_model)
   elseif options.method == "noise1" then
      x = reconstruct.image(noise1_model, x)
      cleanup_model(noise1_model)
   elseif options.method == "noise2" or is_art then
      -- As before, the art style treats any other method name as "noise2",
      -- while the photo style leaves the image untouched.
      x = reconstruct.image(noise2_model, x)
      cleanup_model(noise2_model)
   end

   image_loader.save_png(cache_file, x)
   return x, alpha
end
2015-05-16 17:48:05 +12:00
-- Tells whether the client behind `handler` is gone.
-- A handler counts as disconnected when its request, connection or stream
-- is missing, or when the stream reports itself as closed.
-- @param handler  request handler (table with an optional `request` field)
-- @return boolean
local function client_disconnected(handler)
   local request = handler.request
   if not request then
      return true
   end
   local connection = request.connection
   if not connection then
      return true
   end
   local stream = connection.stream
   if not stream then
      return true
   end
   -- Coerce whatever closed() returns into a real boolean.
   return not not stream:closed()
end
-- Handler for the conversion API.
local APIHandler = class("APIHandler", turbo.web.RequestHandler)

-- POST: converts the submitted image and answers with a PNG.
--
-- Request arguments read here:
--   file / url  source image (see get_image)
--   scale       0 = no scaling; 1 and 2 both run the 2x model, and 1 then
--               resizes the result by a factor of 1.6 / 2.0
--   noise       0 = none, 1 / 2 = noise reduction method "noise1" / "noise2"
--   style       "art" (default); every other value is treated as "photo"
--   download    any non-empty value sends the result as an attachment
--
-- Answers with status 400 when the client is gone, when no image could be
-- obtained, or when the image exceeds the size limit.
function APIHandler:post()
   if client_disconnected(self) then
      self:set_status(400)
      self:write("client disconnected")
      return
   end

   local x, alpha, blob = get_image(self)
   local scale = tonumber(self:get_argument("scale", "0"))
   local noise = tonumber(self:get_argument("noise", "0"))
   local style = self:get_argument("style", "art")
   local download = (self:get_argument("download", "")):len()
   if style ~= "art" then
      style = "photo" -- style must be art or photo
   end

   if not x then
      self:set_status(400)
      self:write("ERROR: An error occurred. (unsupported image format/connection timeout/file is too large)")
   elseif not valid_size(x, scale) then
      self:set_status(400)
      self:write("ERROR: image size exceeds maximum allowable size.")
   else
      if noise ~= 0 or scale ~= 0 then
         local hash = md5.sumhexa(blob)
         local alpha_prefix = style .. "_" .. hash .. "_alpha"
         -- The border is only requested for images with alpha that get
         -- scaled, and only on the first conversion step.
         local border = not not (scale ~= 0 and alpha)
         -- nil unless noise is exactly 1 or 2.
         local noise_method = ({"noise1", "noise2"})[noise]

         if noise_method then
            x = convert(x, alpha, {
               method = noise_method,
               style = style,
               prefix = style .. "_" .. noise_method .. "_" .. hash,
               alpha_prefix = alpha_prefix,
               border = border,
            })
            border = false
         end

         if scale == 1 or scale == 2 then
            local prefix
            if noise_method then
               prefix = style .. "_" .. noise_method .. "_scale_" .. hash
            else
               prefix = style .. "_scale_" .. hash
            end
            x, alpha = convert(x, alpha, {
               method = "scale",
               style = style,
               prefix = prefix,
               alpha_prefix = alpha_prefix,
               border = border,
            })
            if scale == 1 then
               -- Resize the 2x result by 1.6 / 2.0 in both dimensions.
               local ratio = 1.6 / 2.0
               x = iproc.scale(x, x:size(3) * ratio, x:size(2) * ratio, "Sinc")
            end
         end
      end

      local name = uuid() .. ".png"
      local png = image_loader.encode_png(alpha_util.composite(x, alpha))
      self:set_header("Content-Length", string.format("%d", #png))

      local content_type = "image/png"
      local disposition = 'inline; filename="%s"'
      if download > 0 then
         content_type = "application/octet-stream"
         disposition = 'attachment; filename="%s"'
      end
      self:set_header("Content-Type", content_type)
      self:set_header("Content-Disposition", string.format(disposition, name))
      self:write(png)
   end
   collectgarbage()
end
-- Handler for the landing page; serves a localized HTML form.
local FormHandler = class("FormHandler", turbo.web.RequestHandler)

-- The pages are read once at start-up.
local index_ja = file.read(path.join(ROOT, "assets", "index.ja.html"))
local index_ru = file.read(path.join(ROOT, "assets", "index.ru.html"))
local index_pt = file.read(path.join(ROOT, "assets", "index.pt.html"))
local index_es = file.read(path.join(ROOT, "assets", "index.es.html"))
local index_fr = file.read(path.join(ROOT, "assets", "index.fr.html"))
local index_en = file.read(path.join(ROOT, "assets", "index.html"))

-- Localized pages keyed by lower-case primary language subtag.
local INDEX_BY_LANGUAGE = {
   ja = index_ja,
   ru = index_ru,
   pt = index_pt,
   es = index_es,
   fr = index_fr,
}

-- GET: writes the page matching the first language of the Accept-Language
-- header, or the English page when the header is absent or the language
-- has no translation.
--
-- Only the primary subtag of the first language range is considered, and
-- it is compared case-insensitively. Region-qualified tags such as
-- "ja-JP", "fr-FR" or "pt-BR" therefore select the same page as "ja",
-- "fr" or "pt" (previously only a few exact tags were recognised and
-- everything else fell back to English).
function FormHandler:get()
   local page = index_en
   local accept = self.request.headers:get("Accept-Language")
   if type(accept) == "table" then
      -- The header was sent more than once; use the first value.
      accept = accept[1]
   end
   if type(accept) == "string" then
      -- Leading run of letters of the header = primary subtag of the first
      -- range; quality values and later ranges are ignored.
      local primary = accept:match("^%s*(%a+)")
      if primary then
         page = INDEX_BY_LANGUAGE[primary:lower()] or index_en
      end
   end
   self:write(page)
end
2015-06-29 11:56:12 +12:00
-- Log categories: success, warning and error stay enabled; notice, debug
-- and development output is switched off.
turbo.log.categories = {
   success = true,
   notice = false,
   warning = true,
   error = true,
   debug = false,
   development = false,
}
2015-05-16 17:48:05 +12:00
-- URL routing table.
local routes = {
   -- Landing page with the upload form.
   {"^/$", FormHandler},
   -- Conversion API.
   {"^/api$", APIHandler},
   -- Any other single path segment made of letters, digits, ".", "-" and
   -- "_" goes to the static file handler rooted at <ROOT>/assets/.
   {"^/([%a%d%.%-_]+)$", turbo.web.StaticFileHandler, path.join(ROOT, "assets/")},
}
local app = turbo.web.Application:new(routes)

-- Listen on all interfaces; request bodies are capped at CURL_MAX_SIZE.
app:listen(opt.port, "0.0.0.0", {max_body_size = CURL_MAX_SIZE})
-- Enter the event loop.
turbo.ioloop.instance():start()