1
0
Fork 0
mirror of synced 2024-05-17 03:12:18 +12:00

Add enable_tta option in web.lua

This commit is contained in:
nagadomi 2018-11-09 17:44:51 +09:00
parent 9589259812
commit 208043fd89

15
web.lua
View file

@ -26,6 +26,7 @@ cmd:text("waifu2x-api")
cmd:text("Options:")
cmd:option("-port", 8812, 'listen port')
cmd:option("-gpu", 1, 'Device ID')
cmd:option("-enable_tta", 0, 'enable TTA query(0|1)')
cmd:option("-crop_size", 128, 'patch size per process')
cmd:option("-batch_size", 1, 'batch size')
cmd:option("-thread", -1, 'number of CPU threads')
@ -48,6 +49,7 @@ if cudnn then
cudnn.benchmark = true
end
opt.force_cudnn = opt.force_cudnn == 1
opt.enable_tta = opt.enable_tta == 1
local ART_MODEL_DIR = path.join(ROOT, "models", "upconv_7", "art")
local PHOTO_MODEL_DIR = path.join(ROOT, "models", "upconv_7", "photo")
local art_model = {
@ -313,11 +315,14 @@ function APIHandler:post()
self:write("client disconnected")
return
end
if tta_level == 0 then
tta_level = auto_tta_level(x, scale)
end
if not (tta_level == 0 or tta_level == 1 or tta_level == 2 or tta_level == 4 or tta_level == 8) then
if opt.enable_tta then
if tta_level == 0 then
tta_level = auto_tta_level(x, scale)
end
if not (tta_level == 0 or tta_level == 1 or tta_level == 2 or tta_level == 4 or tta_level == 8) then
tta_level = 1
end
else
tta_level = 1
end
if style ~= "art" then