diff --git a/lib/reconstruct.lua b/lib/reconstruct.lua index d7119c2..6382108 100644 --- a/lib/reconstruct.lua +++ b/lib/reconstruct.lua @@ -269,7 +269,10 @@ local augmented_patterns = { } } local function get_augmented_patterns(n) - if n == 2 then + if n == 1 then + -- no tta + return {augmented_patterns[1]} + elseif n == 2 then return {augmented_patterns[1], augmented_patterns[5]} elseif n == 4 then return {augmented_patterns[1], augmented_patterns[5], diff --git a/web.lua b/web.lua index a518055..89716e3 100644 --- a/web.lua +++ b/web.lua @@ -63,6 +63,7 @@ local CURL_OPTIONS = { max_redirects = 2 } local CURL_MAX_SIZE = 3 * 1024 * 1024 +local TTA_SUPPORT = false local function valid_size(x, scale) if scale == 0 then @@ -151,8 +152,8 @@ local function convert(x, meta, options) x = alpha_util.make_border(x, alpha_orig, reconstruct.offset_size(art_scale2_model)) end if options.method == "scale" then - x = reconstruct.scale(art_scale2_model, 2.0, x, - opt.crop_size, opt.batch_size) + x = reconstruct.scale_tta(art_scale2_model, options.tta_level, 2.0, x, + opt.crop_size, opt.batch_size) if alpha then if not (alpha:size(2) == x:size(2) and alpha:size(3) == x:size(3)) then alpha = reconstruct.scale(art_scale2_model, 2.0, alpha, @@ -162,13 +163,16 @@ local function convert(x, meta, options) end cleanup_model(art_scale2_model) elseif options.method == "noise1" then - x = reconstruct.image(art_noise1_model, x, opt.crop_size, opt.batch_size) + x = reconstruct.image_tta(art_noise1_model, options.tta_level, + x, opt.crop_size, opt.batch_size) cleanup_model(art_noise1_model) elseif options.method == "noise2" then - x = reconstruct.image(art_noise2_model, x, opt.crop_size, opt.batch_size) + x = reconstruct.image_tta(art_noise2_model, options.tta_level, + x, opt.crop_size, opt.batch_size) cleanup_model(art_noise2_model) elseif options.method == "noise3" then - x = reconstruct.image(art_noise3_model, x, opt.crop_size, opt.batch_size) + x = reconstruct.image_tta(art_noise3_model, options.tta_level, + x, opt.crop_size, opt.batch_size) cleanup_model(art_noise3_model) end else -- photo @@ -176,7 +180,7 @@ local function convert(x, meta, options) x = alpha_util.make_border(x, alpha, reconstruct.offset_size(photo_scale2_model)) end if options.method == "scale" then - x = reconstruct.scale(photo_scale2_model, 2.0, x, + x = reconstruct.scale_tta(photo_scale2_model, options.tta_level, 2.0, x, opt.crop_size, opt.batch_size) if alpha then if not (alpha:size(2) == x:size(2) and alpha:size(3) == x:size(3)) then @@ -187,13 +191,16 @@ local function convert(x, meta, options) end cleanup_model(photo_scale2_model) elseif options.method == "noise1" then - x = reconstruct.image(photo_noise1_model, x, opt.crop_size, opt.batch_size) + x = reconstruct.image_tta(photo_noise1_model, options.tta_level, + x, opt.crop_size, opt.batch_size) cleanup_model(photo_noise1_model) elseif options.method == "noise2" then - x = reconstruct.image(photo_noise2_model, x, opt.crop_size, opt.batch_size) + x = reconstruct.image_tta(photo_noise2_model, options.tta_level, + x, opt.crop_size, opt.batch_size) cleanup_model(photo_noise2_model) elseif options.method == "noise3" then - x = reconstruct.image(photo_noise3_model, x, opt.crop_size, opt.batch_size) + x = reconstruct.image_tta(photo_noise3_model, options.tta_level, + x, opt.crop_size, opt.batch_size) cleanup_model(photo_noise3_model) end end @@ -230,9 +237,18 @@ function APIHandler:post() local x, meta, filename = get_image(self) local scale = tonumber(self:get_argument("scale", "0")) local noise = tonumber(self:get_argument("noise", "0")) + local tta_level = tonumber(self:get_argument("noise", "1")) local style = self:get_argument("style", "art") local download = (self:get_argument("download", "")):len() + if not TTA_SUPPORT then + tta_level = 1 -- disable TTA mode + else + if not (tta_level == 1 or tta_level == 2 or tta_level == 4 or tta_level == 8) then + tta_level = 1 + end + end + if style ~= "art" then style = "photo" -- style must be art or photo end @@ -246,35 +262,36 @@ function APIHandler:post() border = true end if noise == 1 then - prefix = style .. "_noise1_" - x = convert(x, meta, {method = "noise1", style = style, + prefix = style .. "_noise1_tta_" .. tta_level .. "_" + x = convert(x, meta, {method = "noise1", style = style, tta_level = tta_level, prefix = prefix .. hash, alpha_prefix = alpha_prefix, border = border}) border = false elseif noise == 2 then - prefix = style .. "_noise2_" - x = convert(x, meta, {method = "noise2", style = style, + prefix = style .. "_noise2_tta_" .. tta_level .. "_" + x = convert(x, meta, {method = "noise2", style = style, tta_level = tta_level, prefix = prefix .. hash, alpha_prefix = alpha_prefix, border = border}) border = false elseif noise == 3 then - prefix = style .. "_noise3_" - x = convert(x, meta, {method = "noise3", style = style, + prefix = style .. "_noise3_tta_" .. tta_level .. "_" + x = convert(x, meta, {method = "noise3", style = style, tta_level = tta_level, prefix = prefix .. hash, alpha_prefix = alpha_prefix, border = border}) border = false end if scale == 1 or scale == 2 then if noise == 1 then - prefix = style .. "_noise1_scale_" + prefix = style .. "_noise1_scale_tta_" .. tta_level .. "_" elseif noise == 2 then - prefix = style .. "_noise2_scale_" + prefix = style .. "_noise2_scale_tta_" .. tta_level .. "_" elseif noise == 3 then - prefix = style .. "_noise3_scale_" + prefix = style .. "_noise3_scale_tta_" .. tta_level .. "_" else - prefix = style .. "_scale_" + prefix = style .. "_scale_tta_" .. tta_level .. "_" end - x, meta = convert(x, meta, {method = "scale", style = style, prefix = prefix .. hash, alpha_prefix = alpha_prefix, border = border}) + x, meta = convert(x, meta, {method = "scale", style = style, tta_level = tta_level, + prefix = prefix .. hash, alpha_prefix = alpha_prefix, border = border}) if scale == 1 then x = iproc.scale(x, x:size(3) * (1.6 / 2.0), x:size(2) * (1.6 / 2.0), "Sinc") end