1
0
Fork 0
mirror of synced 2024-06-18 19:04:30 +12:00

Add support for tta_level=1; Add support for TTA to web.lua

This commit is contained in:
nagadomi 2016-06-12 16:55:05 +09:00
parent af74a67bd1
commit 25e293202a
2 changed files with 41 additions and 21 deletions

View file

@ -269,7 +269,10 @@ local augmented_patterns = {
}
}
local function get_augmented_patterns(n)
if n == 2 then
if n == 1 then
-- no tta
return {augmented_patterns[1]}
elseif n == 2 then
return {augmented_patterns[1], augmented_patterns[5]}
elseif n == 4 then
return {augmented_patterns[1], augmented_patterns[5],

57
web.lua
View file

@ -63,6 +63,7 @@ local CURL_OPTIONS = {
max_redirects = 2
}
local CURL_MAX_SIZE = 3 * 1024 * 1024
local TTA_SUPPORT = false
local function valid_size(x, scale)
if scale == 0 then
@ -151,8 +152,8 @@ local function convert(x, meta, options)
x = alpha_util.make_border(x, alpha_orig, reconstruct.offset_size(art_scale2_model))
end
if options.method == "scale" then
x = reconstruct.scale(art_scale2_model, 2.0, x,
opt.crop_size, opt.batch_size)
x = reconstruct.scale_tta(art_scale2_model, options.tta_level, 2.0, x,
opt.crop_size, opt.batch_size)
if alpha then
if not (alpha:size(2) == x:size(2) and alpha:size(3) == x:size(3)) then
alpha = reconstruct.scale(art_scale2_model, 2.0, alpha,
@ -162,13 +163,16 @@ local function convert(x, meta, options)
end
cleanup_model(art_scale2_model)
elseif options.method == "noise1" then
x = reconstruct.image(art_noise1_model, x, opt.crop_size, opt.batch_size)
x = reconstruct.image_tta(art_noise1_model, options.tta_level,
x, opt.crop_size, opt.batch_size)
cleanup_model(art_noise1_model)
elseif options.method == "noise2" then
x = reconstruct.image(art_noise2_model, x, opt.crop_size, opt.batch_size)
x = reconstruct.image_tta(art_noise2_model, options.tta_level,
x, opt.crop_size, opt.batch_size)
cleanup_model(art_noise2_model)
elseif options.method == "noise3" then
x = reconstruct.image(art_noise3_model, x, opt.crop_size, opt.batch_size)
x = reconstruct.image_tta(art_noise3_model, options.tta_level,
x, opt.crop_size, opt.batch_size)
cleanup_model(art_noise3_model)
end
else -- photo
@ -176,7 +180,7 @@ local function convert(x, meta, options)
x = alpha_util.make_border(x, alpha, reconstruct.offset_size(photo_scale2_model))
end
if options.method == "scale" then
x = reconstruct.scale(photo_scale2_model, 2.0, x,
x = reconstruct.scale_tta(photo_scale2_model, options.tta_level, 2.0, x,
opt.crop_size, opt.batch_size)
if alpha then
if not (alpha:size(2) == x:size(2) and alpha:size(3) == x:size(3)) then
@ -187,13 +191,16 @@ local function convert(x, meta, options)
end
cleanup_model(photo_scale2_model)
elseif options.method == "noise1" then
x = reconstruct.image(photo_noise1_model, x, opt.crop_size, opt.batch_size)
x = reconstruct.image_tta(photo_noise1_model, options.tta_level,
x, opt.crop_size, opt.batch_size)
cleanup_model(photo_noise1_model)
elseif options.method == "noise2" then
x = reconstruct.image(photo_noise2_model, x, opt.crop_size, opt.batch_size)
x = reconstruct.image_tta(photo_noise2_model, options.tta_level,
x, opt.crop_size, opt.batch_size)
cleanup_model(photo_noise2_model)
elseif options.method == "noise3" then
x = reconstruct.image(photo_noise3_model, x, opt.crop_size, opt.batch_size)
x = reconstruct.image_tta(photo_noise3_model, options.tta_level,
x, opt.crop_size, opt.batch_size)
cleanup_model(photo_noise3_model)
end
end
@ -230,9 +237,18 @@ function APIHandler:post()
local x, meta, filename = get_image(self)
local scale = tonumber(self:get_argument("scale", "0"))
local noise = tonumber(self:get_argument("noise", "0"))
local tta_level = tonumber(self:get_argument("noise", "1"))
local style = self:get_argument("style", "art")
local download = (self:get_argument("download", "")):len()
if not TTA_SUPPORT then
tta_level = 1 -- disable TTA mode
else
if not (tta_level == 1 or tta_level == 2 or tta_level == 4 or tta_level == 8) then
tta_level = 1
end
end
if style ~= "art" then
style = "photo" -- style must be art or photo
end
@ -246,35 +262,36 @@ function APIHandler:post()
border = true
end
if noise == 1 then
prefix = style .. "_noise1_"
x = convert(x, meta, {method = "noise1", style = style,
prefix = style .. "_noise1_tta_" .. tta_level .. "_"
x = convert(x, meta, {method = "noise1", style = style, tta_level = tta_level,
prefix = prefix .. hash,
alpha_prefix = alpha_prefix, border = border})
border = false
elseif noise == 2 then
prefix = style .. "_noise2_"
x = convert(x, meta, {method = "noise2", style = style,
prefix = style .. "_noise2_tta_" .. tta_level .. "_"
x = convert(x, meta, {method = "noise2", style = style, tta_level = tta_level,
prefix = prefix .. hash,
alpha_prefix = alpha_prefix, border = border})
border = false
elseif noise == 3 then
prefix = style .. "_noise3_"
x = convert(x, meta, {method = "noise3", style = style,
prefix = style .. "_noise3_tta_" .. tta_level .. "_"
x = convert(x, meta, {method = "noise3", style = style, tta_level = tta_level,
prefix = prefix .. hash,
alpha_prefix = alpha_prefix, border = border})
border = false
end
if scale == 1 or scale == 2 then
if noise == 1 then
prefix = style .. "_noise1_scale_"
prefix = style .. "_noise1_scale_tta_" .. tta_level .. "_"
elseif noise == 2 then
prefix = style .. "_noise2_scale_"
prefix = style .. "_noise2_scale_tta_" .. tta_level .. "_"
elseif noise == 3 then
prefix = style .. "_noise3_scale_"
prefix = style .. "_noise3_scale_tta_" .. tta_level .. "_"
else
prefix = style .. "_scale_"
prefix = style .. "_scale_tta_" .. tta_level .. "_"
end
x, meta = convert(x, meta, {method = "scale", style = style, prefix = prefix .. hash, alpha_prefix = alpha_prefix, border = border})
x, meta = convert(x, meta, {method = "scale", style = style, tta_level = tta_level,
prefix = prefix .. hash, alpha_prefix = alpha_prefix, border = border})
if scale == 1 then
x = iproc.scale(x, x:size(3) * (1.6 / 2.0), x:size(2) * (1.6 / 2.0), "Sinc")
end