1
0
Fork 0
mirror of synced 2024-07-01 04:31:14 +12:00

Don't use cudnn.benchmark mode when predicting

This commit is contained in:
nagadomi 2015-10-31 22:09:21 +09:00
parent 490eb33a6b
commit 425898a3aa
3 changed files with 15 additions and 1 deletions

View file

@ -27,6 +27,11 @@ cmd:option("-jpeg_quality_down", 5, 'value of jpeg quality to decrease each time
local opt = cmd:parse(arg) local opt = cmd:parse(arg)
torch.setdefaulttensortype('torch.FloatTensor') torch.setdefaulttensortype('torch.FloatTensor')
if cudnn then
cudnn.fastest = true
cudnn.benchmark = false
end
local function MSE(x1, x2) local function MSE(x1, x2)
return (x1 - x2):pow(2):mean() return (x1 - x2):pow(2):mean()

View file

@ -117,6 +117,11 @@ local function waifu2x()
if opt.thread > 0 then if opt.thread > 0 then
torch.setnumthreads(opt.thread) torch.setnumthreads(opt.thread)
end end
if cudnn then
cudnn.fastest = true
cudnn.benchmark = false
end
if string.len(opt.l) == 0 then if string.len(opt.l) == 0 then
convert_image(opt) convert_image(opt)
else else

View file

@ -25,6 +25,10 @@ torch.setdefaulttensortype('torch.FloatTensor')
if opt.thread > 0 then if opt.thread > 0 then
torch.setnumthreads(opt.thread) torch.setnumthreads(opt.thread)
end end
if cudnn then
cudnn.fastest = true
cudnn.benchmark = false
end
local MODEL_DIR = "./models/anime_style_art_rgb" local MODEL_DIR = "./models/anime_style_art_rgb"
local noise1_model = torch.load(path.join(MODEL_DIR, "noise1_model.t7"), "ascii") local noise1_model = torch.load(path.join(MODEL_DIR, "noise1_model.t7"), "ascii")