From 425898a3aaaa36c88144516e34e1fb2a0a9b8c0c Mon Sep 17 00:00:00 2001 From: nagadomi Date: Sat, 31 Oct 2015 22:09:21 +0900 Subject: [PATCH] Don't use cudnn.benchmark mode when predicting --- tools/benchmark.lua | 5 +++++ waifu2x.lua | 7 ++++++- web.lua | 4 ++++ 3 files changed, 15 insertions(+), 1 deletion(-) diff --git a/tools/benchmark.lua b/tools/benchmark.lua index 83a44c3..4e8ab9d 100644 --- a/tools/benchmark.lua +++ b/tools/benchmark.lua @@ -27,6 +27,11 @@ cmd:option("-jpeg_quality_down", 5, 'value of jpeg quality to decrease each time local opt = cmd:parse(arg) torch.setdefaulttensortype('torch.FloatTensor') +if cudnn then + cudnn.fastest = true + cudnn.benchmark = false +end + local function MSE(x1, x2) return (x1 - x2):pow(2):mean() diff --git a/waifu2x.lua b/waifu2x.lua index ea1bbe5..ce70107 100644 --- a/waifu2x.lua +++ b/waifu2x.lua @@ -112,11 +112,16 @@ local function waifu2x() cmd:option("-crop_size", 128, 'patch size per process') cmd:option("-resume", 0, "skip existing files (0|1)") cmd:option("-thread", -1, "number of CPU threads") - + local opt = cmd:parse(arg) if opt.thread > 0 then torch.setnumthreads(opt.thread) end + if cudnn then + cudnn.fastest = true + cudnn.benchmark = false + end + if string.len(opt.l) == 0 then convert_image(opt) else diff --git a/web.lua b/web.lua index 9d1fa00..d888989 100644 --- a/web.lua +++ b/web.lua @@ -25,6 +25,10 @@ torch.setdefaulttensortype('torch.FloatTensor') if opt.thread > 0 then torch.setnumthreads(opt.thread) end +if cudnn then + cudnn.fastest = true + cudnn.benchmark = false +end local MODEL_DIR = "./models/anime_style_art_rgb" local noise1_model = torch.load(path.join(MODEL_DIR, "noise1_model.t7"), "ascii")