Don't use cudnn.benchmark mode when predicting
This commit is contained in:
parent
490eb33a6b
commit
425898a3aa
|
@ -27,6 +27,11 @@ cmd:option("-jpeg_quality_down", 5, 'value of jpeg quality to decrease each time
|
||||||
|
|
||||||
local opt = cmd:parse(arg)
|
local opt = cmd:parse(arg)
|
||||||
torch.setdefaulttensortype('torch.FloatTensor')
|
torch.setdefaulttensortype('torch.FloatTensor')
|
||||||
|
if cudnn then
|
||||||
|
cudnn.fastest = true
|
||||||
|
cudnn.benchmark = false
|
||||||
|
end
|
||||||
|
|
||||||
|
|
||||||
local function MSE(x1, x2)
|
local function MSE(x1, x2)
|
||||||
return (x1 - x2):pow(2):mean()
|
return (x1 - x2):pow(2):mean()
|
||||||
|
|
|
@ -112,11 +112,16 @@ local function waifu2x()
|
||||||
cmd:option("-crop_size", 128, 'patch size per process')
|
cmd:option("-crop_size", 128, 'patch size per process')
|
||||||
cmd:option("-resume", 0, "skip existing files (0|1)")
|
cmd:option("-resume", 0, "skip existing files (0|1)")
|
||||||
cmd:option("-thread", -1, "number of CPU threads")
|
cmd:option("-thread", -1, "number of CPU threads")
|
||||||
|
|
||||||
local opt = cmd:parse(arg)
|
local opt = cmd:parse(arg)
|
||||||
if opt.thread > 0 then
|
if opt.thread > 0 then
|
||||||
torch.setnumthreads(opt.thread)
|
torch.setnumthreads(opt.thread)
|
||||||
end
|
end
|
||||||
|
if cudnn then
|
||||||
|
cudnn.fastest = true
|
||||||
|
cudnn.benchmark = false
|
||||||
|
end
|
||||||
|
|
||||||
if string.len(opt.l) == 0 then
|
if string.len(opt.l) == 0 then
|
||||||
convert_image(opt)
|
convert_image(opt)
|
||||||
else
|
else
|
||||||
|
|
4
web.lua
4
web.lua
|
@ -25,6 +25,10 @@ torch.setdefaulttensortype('torch.FloatTensor')
|
||||||
if opt.thread > 0 then
|
if opt.thread > 0 then
|
||||||
torch.setnumthreads(opt.thread)
|
torch.setnumthreads(opt.thread)
|
||||||
end
|
end
|
||||||
|
if cudnn then
|
||||||
|
cudnn.fastest = true
|
||||||
|
cudnn.benchmark = false
|
||||||
|
end
|
||||||
|
|
||||||
local MODEL_DIR = "./models/anime_style_art_rgb"
|
local MODEL_DIR = "./models/anime_style_art_rgb"
|
||||||
local noise1_model = torch.load(path.join(MODEL_DIR, "noise1_model.t7"), "ascii")
|
local noise1_model = torch.load(path.join(MODEL_DIR, "noise1_model.t7"), "ascii")
|
||||||
|
|
Loading…
Reference in a new issue