From 1d463b7e5f69c47f3de9dae26ed8fd1eeb5a2fc3 Mon Sep 17 00:00:00 2001 From: nagadomi Date: Mon, 14 Mar 2016 11:17:38 +0900 Subject: [PATCH] Change output filename in WebUI --- web.lua | 61 +++++++++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 46 insertions(+), 15 deletions(-) diff --git a/web.lua b/web.lua index 86e985c..8f07052 100644 --- a/web.lua +++ b/web.lua @@ -92,14 +92,25 @@ local function cache_url(url) return nil, nil, nil end local function get_image(req) - local file = req:get_argument("file", "") + local file_info = req:get_arguments("file") local url = req:get_argument("url", "") - if file and file:len() > 0 then - return image_loader.decode_float(file) - elseif url and url:len() > 0 then - return cache_url(url) + local file = nil + local filename = nil + if file_info and #file_info == 1 then + file = file_info[1][1] + local disp = file_info[1]["content-disposition"] + if disp and disp["filename"] then + filename = path.basename(disp["filename"]) + end end - return nil, nil, nil + if file and file:len() > 0 then + local x, alpha, blob = image_loader.decode_float(file) + return x, alpha, blob, filename + elseif url and url:len() > 0 then + local x, alpha, blob = cache_url(url) + return x, alpha, blob, filename + end + return nil, nil, nil, nil end local function cleanup_model(model) if CLEANUP_MODEL then @@ -176,6 +187,15 @@ local function client_disconnected(handler) handler.request.connection.stream and (not handler.request.connection.stream:closed())) end +local function make_output_filename(filename, mode) + local e = path.extension(filename) + local base = filename:sub(0, filename:len() - e:len()) + if mode then + return base .. "_waifu2x_" .. mode .. ".png" + else + return base .. ".png" + end +end local APIHandler = class("APIHandler", turbo.web.RequestHandler) function APIHandler:post() @@ -184,7 +204,7 @@ function APIHandler:post() self:write("client disconnected") return end - local x, alpha, blob = get_image(self) + local x, alpha, blob, filename = get_image(self) local scale = tonumber(self:get_argument("scale", "0")) local noise = tonumber(self:get_argument("noise", "0")) local style = self:get_argument("style", "art") @@ -194,6 +214,7 @@ function APIHandler:post() style = "photo" -- style must be art or photo end if x and valid_size(x, scale) then + local prefix = nil if (noise ~= 0 or scale ~= 0) then local hash = md5.sumhexa(blob) local alpha_prefix = style .. "_" .. hash .. "_alpha" @@ -202,32 +223,42 @@ function APIHandler:post() border = true end if noise == 1 then + prefix = style .. "_noise1_" x = convert(x, alpha, {method = "noise1", style = style, - prefix = style .. "_noise1_" .. hash, + prefix = prefix .. hash, alpha_prefix = alpha_prefix, border = border}) border = false elseif noise == 2 then + prefix = style .. "_noise1_" x = convert(x, alpha, {method = "noise2", style = style, - prefix = style .. "_noise2_" .. hash, + prefix = prefix .. hash, alpha_prefix = alpha_prefix, border = border}) border = false end if scale == 1 or scale == 2 then - local prefix if noise == 1 then - prefix = style .. "_noise1_scale_" .. hash + prefix = style .. "_noise1_scale_" elseif noise == 2 then - prefix = style .. "_noise2_scale_" .. hash + prefix = style .. "_noise2_scale_" else - prefix = style .. "_scale_" .. hash + prefix = style .. "_scale_" end - x, alpha = convert(x, alpha, {method = "scale", style = style, prefix = prefix, alpha_prefix = alpha_prefix, border = border}) + x, alpha = convert(x, alpha, {method = "scale", style = style, prefix = prefix .. hash, alpha_prefix = alpha_prefix, border = border}) if scale == 1 then x = iproc.scale(x, x:size(3) * (1.6 / 2.0), x:size(2) * (1.6 / 2.0), "Sinc") end end end - local name = uuid() .. ".png" + local name = nil + if filename then + if prefix then + name = make_output_filename(filename, prefix:sub(0, prefix:len()-1)) + else + name = make_output_filename(filename, nil) + end + else + name = uuid() .. ".png" + end local blob = image_loader.encode_png(alpha_util.composite(x, alpha)) self:set_header("Content-Length", string.format("%d", #blob))