From 17b8de2d366a731b6ce972de7ceabf85416edb61 Mon Sep 17 00:00:00 2001 From: nagadomi Date: Sat, 27 Oct 2018 20:38:04 +0900 Subject: [PATCH] Make the performance benchmark practical --- lib/srcnn.lua | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/lib/srcnn.lua b/lib/srcnn.lua index 68f0b14..c064256 100644 --- a/lib/srcnn.lua +++ b/lib/srcnn.lua @@ -890,18 +890,31 @@ function srcnn.upcunet_v2(backend, ch) end local function bench() local sys = require 'sys' - cudnn.benchmark = false + cudnn.benchmark = true local model = nil local arch = {"upconv_7", "upcunet", "upcunet_v2"} - local backend = "cunn" + local backend = "cudnn" for k = 1, #arch do model = srcnn[arch[k]](backend, 3):cuda() - model:training() + model:evaluate() + local dummy = nil + -- warm up + for i = 1, 20 do + local x = torch.Tensor(4, 3, 172, 172):uniform():cuda() + model:forward(x) + end t = sys.clock() - for i = 1, 10 do - model:forward(torch.Tensor(1, 3, 172, 172):zero():cuda()) + for i = 1, 20 do + local x = torch.Tensor(4, 3, 172, 172):uniform():cuda() + local z = model:forward(x) + if dummy == nil then + dummy = z:clone() + else + dummy:add(z) + end end print(arch[k], sys.clock() - t) + model:clearState() end end function srcnn.create(model_name, backend, color) @@ -935,4 +948,5 @@ model:training() print(model:forward(torch.Tensor(1, 3, 76, 76):zero():cuda())) os.exit() --]] + return srcnn