Add support for tta_level=1; Add support for TTA to web.lua
This commit is contained in:
parent
af74a67bd1
commit
25e293202a
|
@ -269,7 +269,10 @@ local augmented_patterns = {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
local function get_augmented_patterns(n)
|
local function get_augmented_patterns(n)
|
||||||
if n == 2 then
|
if n == 1 then
|
||||||
|
-- no tta
|
||||||
|
return {augmented_patterns[1]}
|
||||||
|
elseif n == 2 then
|
||||||
return {augmented_patterns[1], augmented_patterns[5]}
|
return {augmented_patterns[1], augmented_patterns[5]}
|
||||||
elseif n == 4 then
|
elseif n == 4 then
|
||||||
return {augmented_patterns[1], augmented_patterns[5],
|
return {augmented_patterns[1], augmented_patterns[5],
|
||||||
|
|
57
web.lua
57
web.lua
|
@ -63,6 +63,7 @@ local CURL_OPTIONS = {
|
||||||
max_redirects = 2
|
max_redirects = 2
|
||||||
}
|
}
|
||||||
local CURL_MAX_SIZE = 3 * 1024 * 1024
|
local CURL_MAX_SIZE = 3 * 1024 * 1024
|
||||||
|
local TTA_SUPPORT = false
|
||||||
|
|
||||||
local function valid_size(x, scale)
|
local function valid_size(x, scale)
|
||||||
if scale == 0 then
|
if scale == 0 then
|
||||||
|
@ -151,8 +152,8 @@ local function convert(x, meta, options)
|
||||||
x = alpha_util.make_border(x, alpha_orig, reconstruct.offset_size(art_scale2_model))
|
x = alpha_util.make_border(x, alpha_orig, reconstruct.offset_size(art_scale2_model))
|
||||||
end
|
end
|
||||||
if options.method == "scale" then
|
if options.method == "scale" then
|
||||||
x = reconstruct.scale(art_scale2_model, 2.0, x,
|
x = reconstruct.scale_tta(art_scale2_model, options.tta_level, 2.0, x,
|
||||||
opt.crop_size, opt.batch_size)
|
opt.crop_size, opt.batch_size)
|
||||||
if alpha then
|
if alpha then
|
||||||
if not (alpha:size(2) == x:size(2) and alpha:size(3) == x:size(3)) then
|
if not (alpha:size(2) == x:size(2) and alpha:size(3) == x:size(3)) then
|
||||||
alpha = reconstruct.scale(art_scale2_model, 2.0, alpha,
|
alpha = reconstruct.scale(art_scale2_model, 2.0, alpha,
|
||||||
|
@ -162,13 +163,16 @@ local function convert(x, meta, options)
|
||||||
end
|
end
|
||||||
cleanup_model(art_scale2_model)
|
cleanup_model(art_scale2_model)
|
||||||
elseif options.method == "noise1" then
|
elseif options.method == "noise1" then
|
||||||
x = reconstruct.image(art_noise1_model, x, opt.crop_size, opt.batch_size)
|
x = reconstruct.image_tta(art_noise1_model, options.tta_level,
|
||||||
|
x, opt.crop_size, opt.batch_size)
|
||||||
cleanup_model(art_noise1_model)
|
cleanup_model(art_noise1_model)
|
||||||
elseif options.method == "noise2" then
|
elseif options.method == "noise2" then
|
||||||
x = reconstruct.image(art_noise2_model, x, opt.crop_size, opt.batch_size)
|
x = reconstruct.image_tta(art_noise2_model, options.tta_level,
|
||||||
|
x, opt.crop_size, opt.batch_size)
|
||||||
cleanup_model(art_noise2_model)
|
cleanup_model(art_noise2_model)
|
||||||
elseif options.method == "noise3" then
|
elseif options.method == "noise3" then
|
||||||
x = reconstruct.image(art_noise3_model, x, opt.crop_size, opt.batch_size)
|
x = reconstruct.image_tta(art_noise3_model, options.tta_level,
|
||||||
|
x, opt.crop_size, opt.batch_size)
|
||||||
cleanup_model(art_noise3_model)
|
cleanup_model(art_noise3_model)
|
||||||
end
|
end
|
||||||
else -- photo
|
else -- photo
|
||||||
|
@ -176,7 +180,7 @@ local function convert(x, meta, options)
|
||||||
x = alpha_util.make_border(x, alpha, reconstruct.offset_size(photo_scale2_model))
|
x = alpha_util.make_border(x, alpha, reconstruct.offset_size(photo_scale2_model))
|
||||||
end
|
end
|
||||||
if options.method == "scale" then
|
if options.method == "scale" then
|
||||||
x = reconstruct.scale(photo_scale2_model, 2.0, x,
|
x = reconstruct.scale_tta(photo_scale2_model, options.tta_level, 2.0, x,
|
||||||
opt.crop_size, opt.batch_size)
|
opt.crop_size, opt.batch_size)
|
||||||
if alpha then
|
if alpha then
|
||||||
if not (alpha:size(2) == x:size(2) and alpha:size(3) == x:size(3)) then
|
if not (alpha:size(2) == x:size(2) and alpha:size(3) == x:size(3)) then
|
||||||
|
@ -187,13 +191,16 @@ local function convert(x, meta, options)
|
||||||
end
|
end
|
||||||
cleanup_model(photo_scale2_model)
|
cleanup_model(photo_scale2_model)
|
||||||
elseif options.method == "noise1" then
|
elseif options.method == "noise1" then
|
||||||
x = reconstruct.image(photo_noise1_model, x, opt.crop_size, opt.batch_size)
|
x = reconstruct.image_tta(photo_noise1_model, options.tta_level,
|
||||||
|
x, opt.crop_size, opt.batch_size)
|
||||||
cleanup_model(photo_noise1_model)
|
cleanup_model(photo_noise1_model)
|
||||||
elseif options.method == "noise2" then
|
elseif options.method == "noise2" then
|
||||||
x = reconstruct.image(photo_noise2_model, x, opt.crop_size, opt.batch_size)
|
x = reconstruct.image_tta(photo_noise2_model, options.tta_level,
|
||||||
|
x, opt.crop_size, opt.batch_size)
|
||||||
cleanup_model(photo_noise2_model)
|
cleanup_model(photo_noise2_model)
|
||||||
elseif options.method == "noise3" then
|
elseif options.method == "noise3" then
|
||||||
x = reconstruct.image(photo_noise3_model, x, opt.crop_size, opt.batch_size)
|
x = reconstruct.image_tta(photo_noise3_model, options.tta_level,
|
||||||
|
x, opt.crop_size, opt.batch_size)
|
||||||
cleanup_model(photo_noise3_model)
|
cleanup_model(photo_noise3_model)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
@ -230,9 +237,18 @@ function APIHandler:post()
|
||||||
local x, meta, filename = get_image(self)
|
local x, meta, filename = get_image(self)
|
||||||
local scale = tonumber(self:get_argument("scale", "0"))
|
local scale = tonumber(self:get_argument("scale", "0"))
|
||||||
local noise = tonumber(self:get_argument("noise", "0"))
|
local noise = tonumber(self:get_argument("noise", "0"))
|
||||||
|
local tta_level = tonumber(self:get_argument("noise", "1"))
|
||||||
local style = self:get_argument("style", "art")
|
local style = self:get_argument("style", "art")
|
||||||
local download = (self:get_argument("download", "")):len()
|
local download = (self:get_argument("download", "")):len()
|
||||||
|
|
||||||
|
if not TTA_SUPPORT then
|
||||||
|
tta_level = 1 -- disable TTA mode
|
||||||
|
else
|
||||||
|
if not (tta_level == 1 or tta_level == 2 or tta_level == 4 or tta_level == 8) then
|
||||||
|
tta_level = 1
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
if style ~= "art" then
|
if style ~= "art" then
|
||||||
style = "photo" -- style must be art or photo
|
style = "photo" -- style must be art or photo
|
||||||
end
|
end
|
||||||
|
@ -246,35 +262,36 @@ function APIHandler:post()
|
||||||
border = true
|
border = true
|
||||||
end
|
end
|
||||||
if noise == 1 then
|
if noise == 1 then
|
||||||
prefix = style .. "_noise1_"
|
prefix = style .. "_noise1_tta_" .. tta_level .. "_"
|
||||||
x = convert(x, meta, {method = "noise1", style = style,
|
x = convert(x, meta, {method = "noise1", style = style, tta_level = tta_level,
|
||||||
prefix = prefix .. hash,
|
prefix = prefix .. hash,
|
||||||
alpha_prefix = alpha_prefix, border = border})
|
alpha_prefix = alpha_prefix, border = border})
|
||||||
border = false
|
border = false
|
||||||
elseif noise == 2 then
|
elseif noise == 2 then
|
||||||
prefix = style .. "_noise2_"
|
prefix = style .. "_noise2_tta_" .. tta_level .. "_"
|
||||||
x = convert(x, meta, {method = "noise2", style = style,
|
x = convert(x, meta, {method = "noise2", style = style, tta_level = tta_level,
|
||||||
prefix = prefix .. hash,
|
prefix = prefix .. hash,
|
||||||
alpha_prefix = alpha_prefix, border = border})
|
alpha_prefix = alpha_prefix, border = border})
|
||||||
border = false
|
border = false
|
||||||
elseif noise == 3 then
|
elseif noise == 3 then
|
||||||
prefix = style .. "_noise3_"
|
prefix = style .. "_noise3_tta_" .. tta_level .. "_"
|
||||||
x = convert(x, meta, {method = "noise3", style = style,
|
x = convert(x, meta, {method = "noise3", style = style, tta_level = tta_level,
|
||||||
prefix = prefix .. hash,
|
prefix = prefix .. hash,
|
||||||
alpha_prefix = alpha_prefix, border = border})
|
alpha_prefix = alpha_prefix, border = border})
|
||||||
border = false
|
border = false
|
||||||
end
|
end
|
||||||
if scale == 1 or scale == 2 then
|
if scale == 1 or scale == 2 then
|
||||||
if noise == 1 then
|
if noise == 1 then
|
||||||
prefix = style .. "_noise1_scale_"
|
prefix = style .. "_noise1_scale_tta_" .. tta_level .. "_"
|
||||||
elseif noise == 2 then
|
elseif noise == 2 then
|
||||||
prefix = style .. "_noise2_scale_"
|
prefix = style .. "_noise2_scale_tta_" .. tta_level .. "_"
|
||||||
elseif noise == 3 then
|
elseif noise == 3 then
|
||||||
prefix = style .. "_noise3_scale_"
|
prefix = style .. "_noise3_scale_tta_" .. tta_level .. "_"
|
||||||
else
|
else
|
||||||
prefix = style .. "_scale_"
|
prefix = style .. "_scale_tta_" .. tta_level .. "_"
|
||||||
end
|
end
|
||||||
x, meta = convert(x, meta, {method = "scale", style = style, prefix = prefix .. hash, alpha_prefix = alpha_prefix, border = border})
|
x, meta = convert(x, meta, {method = "scale", style = style, tta_level = tta_level,
|
||||||
|
prefix = prefix .. hash, alpha_prefix = alpha_prefix, border = border})
|
||||||
if scale == 1 then
|
if scale == 1 then
|
||||||
x = iproc.scale(x, x:size(3) * (1.6 / 2.0), x:size(2) * (1.6 / 2.0), "Sinc")
|
x = iproc.scale(x, x:size(3) * (1.6 / 2.0), x:size(2) * (1.6 / 2.0), "Sinc")
|
||||||
end
|
end
|
||||||
|
|
Loading…
Reference in a new issue