From 962bdcf300fe5212ce2ead602a0168f50c3de06a Mon Sep 17 00:00:00 2001 From: nagadomi Date: Fri, 22 Jul 2016 02:15:00 +0900 Subject: [PATCH] Add support for user method in benchmark --- tools/benchmark.lua | 121 +++++++++++++++++++++++++++++++++++++++----- 1 file changed, 109 insertions(+), 12 deletions(-) diff --git a/tools/benchmark.lua b/tools/benchmark.lua index 29b8a67..4a60007 100644 --- a/tools/benchmark.lua +++ b/tools/benchmark.lua @@ -17,7 +17,7 @@ cmd:text("Options:") cmd:option("-dir", "./data/test", 'test image directory') cmd:option("-model1_dir", "./models/anime_style_art_rgb", 'model1 directory') cmd:option("-model2_dir", "", 'model2 directory (optional)') -cmd:option("-method", "scale", '(scale|noise|noise_scale)') +cmd:option("-method", "scale", '(scale|noise|noise_scale|user)') cmd:option("-filter", "Catrom", "downscaling filter (Box|Lanczos|Catrom(Bicubic))") cmd:option("-resize_blur", 1.0, 'blur parameter for resize') cmd:option("-color", "y", '(rgb|y)') @@ -40,6 +40,9 @@ cmd:option("-crop_size", 128, 'patch size per process') cmd:option("-batch_size", 1, 'batch_size') cmd:option("-force_cudnn", 0, 'use cuDNN backend') cmd:option("-yuv420", 0, 'use yuv420 jpeg') +cmd:option("-name", "", 'model name for user method') +cmd:option("-x_dir", "", 'input image for user method') +cmd:option("-y_dir", "", 'groundtruth image for user method. filename must be the same as x_dir') local function to_bool(settings, name) if settings[name] == 1 then @@ -112,7 +115,6 @@ end local function MSE2PSNR(mse) return 10 * math.log10((255.0 * 255.0) / math.max(mse, 1)) end - local function transform_jpeg(x, opt) for i = 1, opt.jpeg_times do jpeg = gm.Image(x, "RGB", "DHW") @@ -161,7 +163,7 @@ local function transform_scale_jpeg(x, opt) return iproc.byte2float(x) end -local function benchmark(opt, x, input_func, model1, model2) +local function benchmark(opt, x, model1, model2) local mse local model1_mse = 0 local model2_mse = 0 @@ -185,12 +187,13 @@ local function benchmark(opt, x, input_func, model1, model2) end for i = 1, #x do - local ground_truth = x[i].image local basename = x[i].basename - local input, model1_output, model2_output, baseline_output + local input, model1_output, model2_output, baseline_output, ground_truth - input = input_func(ground_truth, opt) if opt.method == "scale" then + input = transform_scale(x[i].y, opt) + ground_truth = x[i].y + if opt.force_cudnn and i == 1 then -- run cuDNN benchmark first model1_output = scale_f(model1, 2.0, input, opt.crop_size, opt.batch_size) if model2 then @@ -207,7 +210,10 @@ local function benchmark(opt, x, input_func, model1, model2) end baseline_output = baseline_scale(input, opt.baseline_filter) elseif opt.method == "noise" then - if opt.force_cudnn and i == 1 then -- run cuDNN benchmark first + input = transform_jpeg(x[i].y, opt) + ground_truth = x[i].y + + if opt.force_cudnn and i == 1 then model1_output = image_f(model1, input, opt.crop_size, opt.batch_size) if model2 then model2_output = image_f(model2, input, opt.crop_size, opt.batch_size) @@ -223,7 +229,10 @@ local function benchmark(opt, x, input_func, model1, model2) end baseline_output = input elseif opt.method == "noise_scale" then - if opt.force_cudnn and i == 1 then -- run cuDNN benchmark first + input = transform_scale_jpeg(x[i].y, opt) + ground_truth = x[i].y + + if opt.force_cudnn and i == 1 then if model1.noise_scale_model then model1_output = scale_f(model1.noise_scale_model, 2.0, input, opt.crop_size, opt.batch_size) @@ -285,6 +294,37 @@ local function benchmark(opt, x, input_func, model1, model2) model2_time = model2_time + (sys.clock() - t) end baseline_output = baseline_scale(input, opt.baseline_filter) + elseif opt.method == "user" then + input = x[i].x + ground_truth = x[i].y + local y_scale = ground_truth:size(2) / input:size(2) + if y_scale > 1 then + if opt.force_cudnn and i == 1 then + model1_output = scale_f(model1, y_scale, input, opt.crop_size, opt.batch_size) + if model2 then + model2_output = scale_f(model2, y_scale, input, opt.crop_size, opt.batch_size) + end + end + t = sys.clock() + model1_output = scale_f(model1, y_scale, input, opt.crop_size, opt.batch_size) + model1_time = model1_time + (sys.clock() - t) + if model2 then + t = sys.clock() + model2_output = scale_f(model2, y_scale, input, opt.crop_size, opt.batch_size) + model2_time = model2_time + (sys.clock() - t) + end + else + if opt.force_cudnn and i == 1 then + model1_output = image_f(model1, input, opt.crop_size, opt.batch_size) + if model2 then + model2_output = image_f(model2, input, opt.crop_size, opt.batch_size) + end + end + model1_output = image_f(model1, input, opt.crop_size, opt.batch_size) + if model2 then + model2_output = image_f(model2, input, opt.crop_size, opt.batch_size) + end + end end mse = MSE(ground_truth, model1_output, opt.color) model1_mse = model1_mse + mse @@ -385,7 +425,7 @@ local function load_data(test_dir) local base = name:sub(0, name:len() - e:len()) local img = image_loader.load_float(files[i]) if img then - table.insert(test_x, {image = iproc.crop_mod4(img), + table.insert(test_x, {y = iproc.crop_mod4(img), basename = base}) end if opt.show_progress then @@ -394,6 +434,50 @@ local function load_data(test_dir) end return test_x end +local function get_basename(f) + local name = path.basename(f) + local e = path.extension(name) + local base = name:sub(0, name:len() - e:len()) + return base +end +local function load_user_data(y_dir, x_dir) + local test = {} + local y_files = dir.getfiles(y_dir, "*.*") + local x_files = dir.getfiles(x_dir, "*.*") + local basename_db = {} + for i = 1, #y_files do + basename_db[get_basename(y_files[i])] = {y = y_files[i]} + end + for i = 1, #x_files do + local key = get_basename(x_files[i]) + if basename_db[key] then + basename_db[key].x = x_files[i] + else + error(string.format("%s is not found in %s", key, y_dir)) + end + end + for i = 1, #y_files do + local key = get_basename(y_files[i]) + local d = basename_db[key] + if not (d.x and d.y) then + error(string.format("%s is not found in %s", key, x_dir)) + end + end + for i = 1, #y_files do + local key = get_basename(y_files[i]) + local x = image_loader.load_float(basename_db[key].x) + local y = image_loader.load_float(basename_db[key].y) + if x and y then + table.insert(test, {y = y, + x = x, + basename = base}) + end + if opt.show_progress then + xlua.progress(i, #y_files) + end + end + return test +end function load_noise_scale_model(model_dir, noise_level, force_cudnn) local f = path.join(model_dir, string.format("noise%d_scale2.0x_model.t7", opt.noise_level)) local s1, noise_scale = pcall(w2nn.load_model, f, force_cudnn) @@ -437,7 +521,7 @@ if opt.method == "scale" then model2 = nil end local test_x = load_data(opt.dir) - benchmark(opt, test_x, transform_scale, model1, model2) + benchmark(opt, test_x, model1, model2) elseif opt.method == "noise" then local f1 = path.join(opt.model1_dir, string.format("noise%d_model.t7", opt.noise_level)) local f2 = path.join(opt.model2_dir, string.format("noise%d_model.t7", opt.noise_level)) @@ -450,7 +534,7 @@ elseif opt.method == "noise" then model2 = nil end local test_x = load_data(opt.dir) - benchmark(opt, test_x, transform_jpeg, model1, model2) + benchmark(opt, test_x, model1, model2) elseif opt.method == "noise_scale" then local model2 = nil local model1 = load_noise_scale_model(opt.model1_dir, opt.noise_level, opt.force_cudnn) @@ -458,5 +542,18 @@ elseif opt.method == "noise_scale" then model2 = load_noise_scale_model(opt.model2_dir, opt.noise_level, opt.force_cudnn) end local test_x = load_data(opt.dir) - benchmark(opt, test_x, transform_scale_jpeg, model1, model2) + benchmark(opt, test_x, model1, model2) +elseif opt.method == "user" then + local f1 = path.join(opt.model1_dir, string.format("%s_model.t7", opt.name)) + local f2 = path.join(opt.model2_dir, string.format("%s_model.t7", opt.name)) + local s1, model1 = pcall(w2nn.load_model, f1, opt.force_cudnn) + local s2, model2 = pcall(w2nn.load_model, f2, opt.force_cudnn) + if not s1 then + error("Load error: " .. f1) + end + if not s2 then + model2 = nil + end + local test = load_user_data(opt.y_dir, opt.x_dir) + benchmark(opt, test, model1, model2) end