Add support for tta_level=1; Add support for TTA to web.lua
This commit is contained in:
parent
af74a67bd1
commit
25e293202a
2 changed files with 41 additions and 21 deletions
|
@ -269,7 +269,10 @@ local augmented_patterns = {
|
|||
}
|
||||
}
|
||||
local function get_augmented_patterns(n)
|
||||
if n == 2 then
|
||||
if n == 1 then
|
||||
-- no tta
|
||||
return {augmented_patterns[1]}
|
||||
elseif n == 2 then
|
||||
return {augmented_patterns[1], augmented_patterns[5]}
|
||||
elseif n == 4 then
|
||||
return {augmented_patterns[1], augmented_patterns[5],
|
||||
|
|
55
web.lua
55
web.lua
|
@ -63,6 +63,7 @@ local CURL_OPTIONS = {
|
|||
max_redirects = 2
|
||||
}
|
||||
local CURL_MAX_SIZE = 3 * 1024 * 1024
|
||||
local TTA_SUPPORT = false
|
||||
|
||||
local function valid_size(x, scale)
|
||||
if scale == 0 then
|
||||
|
@ -151,7 +152,7 @@ local function convert(x, meta, options)
|
|||
x = alpha_util.make_border(x, alpha_orig, reconstruct.offset_size(art_scale2_model))
|
||||
end
|
||||
if options.method == "scale" then
|
||||
x = reconstruct.scale(art_scale2_model, 2.0, x,
|
||||
x = reconstruct.scale_tta(art_scale2_model, options.tta_level, 2.0, x,
|
||||
opt.crop_size, opt.batch_size)
|
||||
if alpha then
|
||||
if not (alpha:size(2) == x:size(2) and alpha:size(3) == x:size(3)) then
|
||||
|
@ -162,13 +163,16 @@ local function convert(x, meta, options)
|
|||
end
|
||||
cleanup_model(art_scale2_model)
|
||||
elseif options.method == "noise1" then
|
||||
x = reconstruct.image(art_noise1_model, x, opt.crop_size, opt.batch_size)
|
||||
x = reconstruct.image_tta(art_noise1_model, options.tta_level,
|
||||
x, opt.crop_size, opt.batch_size)
|
||||
cleanup_model(art_noise1_model)
|
||||
elseif options.method == "noise2" then
|
||||
x = reconstruct.image(art_noise2_model, x, opt.crop_size, opt.batch_size)
|
||||
x = reconstruct.image_tta(art_noise2_model, options.tta_level,
|
||||
x, opt.crop_size, opt.batch_size)
|
||||
cleanup_model(art_noise2_model)
|
||||
elseif options.method == "noise3" then
|
||||
x = reconstruct.image(art_noise3_model, x, opt.crop_size, opt.batch_size)
|
||||
x = reconstruct.image_tta(art_noise3_model, options.tta_level,
|
||||
x, opt.crop_size, opt.batch_size)
|
||||
cleanup_model(art_noise3_model)
|
||||
end
|
||||
else -- photo
|
||||
|
@ -176,7 +180,7 @@ local function convert(x, meta, options)
|
|||
x = alpha_util.make_border(x, alpha, reconstruct.offset_size(photo_scale2_model))
|
||||
end
|
||||
if options.method == "scale" then
|
||||
x = reconstruct.scale(photo_scale2_model, 2.0, x,
|
||||
x = reconstruct.scale_tta(photo_scale2_model, options.tta_level, 2.0, x,
|
||||
opt.crop_size, opt.batch_size)
|
||||
if alpha then
|
||||
if not (alpha:size(2) == x:size(2) and alpha:size(3) == x:size(3)) then
|
||||
|
@ -187,13 +191,16 @@ local function convert(x, meta, options)
|
|||
end
|
||||
cleanup_model(photo_scale2_model)
|
||||
elseif options.method == "noise1" then
|
||||
x = reconstruct.image(photo_noise1_model, x, opt.crop_size, opt.batch_size)
|
||||
x = reconstruct.image_tta(photo_noise1_model, options.tta_level,
|
||||
x, opt.crop_size, opt.batch_size)
|
||||
cleanup_model(photo_noise1_model)
|
||||
elseif options.method == "noise2" then
|
||||
x = reconstruct.image(photo_noise2_model, x, opt.crop_size, opt.batch_size)
|
||||
x = reconstruct.image_tta(photo_noise2_model, options.tta_level,
|
||||
x, opt.crop_size, opt.batch_size)
|
||||
cleanup_model(photo_noise2_model)
|
||||
elseif options.method == "noise3" then
|
||||
x = reconstruct.image(photo_noise3_model, x, opt.crop_size, opt.batch_size)
|
||||
x = reconstruct.image_tta(photo_noise3_model, options.tta_level,
|
||||
x, opt.crop_size, opt.batch_size)
|
||||
cleanup_model(photo_noise3_model)
|
||||
end
|
||||
end
|
||||
|
@ -230,9 +237,18 @@ function APIHandler:post()
|
|||
local x, meta, filename = get_image(self)
|
||||
local scale = tonumber(self:get_argument("scale", "0"))
|
||||
local noise = tonumber(self:get_argument("noise", "0"))
|
||||
local tta_level = tonumber(self:get_argument("noise", "1"))
|
||||
local style = self:get_argument("style", "art")
|
||||
local download = (self:get_argument("download", "")):len()
|
||||
|
||||
if not TTA_SUPPORT then
|
||||
tta_level = 1 -- disable TTA mode
|
||||
else
|
||||
if not (tta_level == 1 or tta_level == 2 or tta_level == 4 or tta_level == 8) then
|
||||
tta_level = 1
|
||||
end
|
||||
end
|
||||
|
||||
if style ~= "art" then
|
||||
style = "photo" -- style must be art or photo
|
||||
end
|
||||
|
@ -246,35 +262,36 @@ function APIHandler:post()
|
|||
border = true
|
||||
end
|
||||
if noise == 1 then
|
||||
prefix = style .. "_noise1_"
|
||||
x = convert(x, meta, {method = "noise1", style = style,
|
||||
prefix = style .. "_noise1_tta_" .. tta_level .. "_"
|
||||
x = convert(x, meta, {method = "noise1", style = style, tta_level = tta_level,
|
||||
prefix = prefix .. hash,
|
||||
alpha_prefix = alpha_prefix, border = border})
|
||||
border = false
|
||||
elseif noise == 2 then
|
||||
prefix = style .. "_noise2_"
|
||||
x = convert(x, meta, {method = "noise2", style = style,
|
||||
prefix = style .. "_noise2_tta_" .. tta_level .. "_"
|
||||
x = convert(x, meta, {method = "noise2", style = style, tta_level = tta_level,
|
||||
prefix = prefix .. hash,
|
||||
alpha_prefix = alpha_prefix, border = border})
|
||||
border = false
|
||||
elseif noise == 3 then
|
||||
prefix = style .. "_noise3_"
|
||||
x = convert(x, meta, {method = "noise3", style = style,
|
||||
prefix = style .. "_noise3_tta_" .. tta_level .. "_"
|
||||
x = convert(x, meta, {method = "noise3", style = style, tta_level = tta_level,
|
||||
prefix = prefix .. hash,
|
||||
alpha_prefix = alpha_prefix, border = border})
|
||||
border = false
|
||||
end
|
||||
if scale == 1 or scale == 2 then
|
||||
if noise == 1 then
|
||||
prefix = style .. "_noise1_scale_"
|
||||
prefix = style .. "_noise1_scale_tta_" .. tta_level .. "_"
|
||||
elseif noise == 2 then
|
||||
prefix = style .. "_noise2_scale_"
|
||||
prefix = style .. "_noise2_scale_tta_" .. tta_level .. "_"
|
||||
elseif noise == 3 then
|
||||
prefix = style .. "_noise3_scale_"
|
||||
prefix = style .. "_noise3_scale_tta_" .. tta_level .. "_"
|
||||
else
|
||||
prefix = style .. "_scale_"
|
||||
prefix = style .. "_scale_tta_" .. tta_level .. "_"
|
||||
end
|
||||
x, meta = convert(x, meta, {method = "scale", style = style, prefix = prefix .. hash, alpha_prefix = alpha_prefix, border = border})
|
||||
x, meta = convert(x, meta, {method = "scale", style = style, tta_level = tta_level,
|
||||
prefix = prefix .. hash, alpha_prefix = alpha_prefix, border = border})
|
||||
if scale == 1 then
|
||||
x = iproc.scale(x, x:size(3) * (1.6 / 2.0), x:size(2) * (1.6 / 2.0), "Sinc")
|
||||
end
|
||||
|
|
Loading…
Reference in a new issue