diff --git a/web.lua b/web.lua index 1e58433..f894a81 100644 --- a/web.lua +++ b/web.lua @@ -26,6 +26,7 @@ cmd:text("waifu2x-api") cmd:text("Options:") cmd:option("-port", 8812, 'listen port') cmd:option("-gpu", 1, 'Device ID') +cmd:option("-enable_tta", 0, 'enable TTA query(0|1)') cmd:option("-crop_size", 128, 'patch size per process') cmd:option("-batch_size", 1, 'batch size') cmd:option("-thread", -1, 'number of CPU threads') @@ -48,6 +49,7 @@ if cudnn then cudnn.benchmark = true end opt.force_cudnn = opt.force_cudnn == 1 +opt.enable_tta = opt.enable_tta == 1 local ART_MODEL_DIR = path.join(ROOT, "models", "upconv_7", "art") local PHOTO_MODEL_DIR = path.join(ROOT, "models", "upconv_7", "photo") local art_model = { @@ -313,11 +315,14 @@ function APIHandler:post() self:write("client disconnected") return end - - if tta_level == 0 then - tta_level = auto_tta_level(x, scale) - end - if not (tta_level == 0 or tta_level == 1 or tta_level == 2 or tta_level == 4 or tta_level == 8) then + if opt.enable_tta then + if tta_level == 0 then + tta_level = auto_tta_level(x, scale) + end + if not (tta_level == 0 or tta_level == 1 or tta_level == 2 or tta_level == 4 or tta_level == 8) then + tta_level = 1 + end + else tta_level = 1 end if style ~= "art" then