diff --git a/tools/benchmark.lua b/tools/benchmark.lua index 233e542..41119cc 100644 --- a/tools/benchmark.lua +++ b/tools/benchmark.lua @@ -17,7 +17,7 @@ cmd:text("Options:") cmd:option("-dir", "./data/test", 'test image directory') cmd:option("-model1_dir", "./models/anime_style_art_rgb", 'model1 directory') cmd:option("-model2_dir", "", 'model2 directory (optional)') -cmd:option("-method", "scale", '(scale|noise)') +cmd:option("-method", "scale", '(scale|noise|noise_scale)') cmd:option("-filter", "Catrom", "downscaling filter (Box|Lanczos|Catrom(Bicubic))") cmd:option("-resize_blur", 1.0, 'blur parameter for resize') cmd:option("-color", "y", '(rgb|y)') @@ -129,6 +129,22 @@ local function transform_scale(x, opt) opt.filter, opt.resize_blur) end +local function transform_scale_jpeg(x, opt) + x = iproc.scale(x, + x:size(3) * 0.5, + x:size(2) * 0.5, + opt.filter, opt.resize_blur) + for i = 1, opt.jpeg_times do + jpeg = gm.Image(x, "RGB", "DHW") + jpeg:format("jpeg") + jpeg:samplingFactors({1.0, 1.0, 1.0}) + blob, len = jpeg:toBlob(opt.jpeg_quality - (i - 1) * opt.jpeg_quality_down) + jpeg:fromBlob(blob, len) + x = jpeg:toTensor("byte", "RGB", "DHW") + end + return iproc.byte2float(x) +end + local function benchmark(opt, x, input_func, model1, model2) local model1_mse = 0 local model2_mse = 0 @@ -157,15 +173,45 @@ local function benchmark(opt, x, input_func, model1, model2) input = input_func(ground_truth, opt) t = sys.clock() - if input:size(3) == ground_truth:size(3) then + if opt.method == "scale" then + model1_output = scale_f(model1, 2.0, input, opt.crop_size, opt.batch_size) + if model2 then + model2_output = scale_f(model2, 2.0, input, opt.crop_size, opt.batch_size) + end + baseline_output = baseline_scale(input, opt.baseline_filter) + elseif opt.method == "noise" then model1_output = image_f(model1, input, opt.crop_size, opt.batch_size) if model2 then model2_output = image_f(model2, input, opt.crop_size, opt.batch_size) end - else - model1_output = scale_f(model1, 2.0, input, opt.crop_size, opt.batch_size) + baseline_output = input + elseif opt.method == "noise_scale" then + if model1.noise_scale_model then + model1_output = scale_f(model1.noise_scale_model, 2.0, + input, opt.crop_size, opt.batch_size) + else + if model1.noise_model then + model1_output = image_f(model1.noise_model, input, opt.crop_size, opt.batch_size) + else + model1_output = input + end + model1_output = scale_f(model1.scale_model, 2.0, model1_output, + opt.crop_size, opt.batch_size) + end if model2 then - model2_output = scale_f(model2, 2.0, input, opt.crop_size, opt.batch_size) + if model2.noise_scale_model then + model2_output = scale_f(model2.noise_scale_model, 2.0, + input, opt.crop_size, opt.batch_size) + else + if model2.noise_model then + model2_output = image_f(model2.noise_model, input, + opt.crop_size, opt.batch_size) + else + model2_output = input + end + model2_output = scale_f(model2.scale_model, 2.0, model2_output, + opt.crop_size, opt.batch_size) + end end baseline_output = baseline_scale(input, opt.baseline_filter) end @@ -271,9 +317,36 @@ end function load_model(filename) return torch.load(filename, "ascii") end +function load_noise_scale_model(model_dir, noise_level) + local f = path.join(model_dir, string.format("noise%d_scale2.0x_model.t7", opt.noise_level)) + local s1, noise_scale = pcall(load_model, f) + local model = {} + if not s1 then + f = path.join(model_dir, string.format("noise%d_model.t7", opt.noise_level)) + local noise + s1, noise = pcall(load_model, f) + if not s1 then + model.noise_model = nil + print(model_dir .. "'s noise model is not found. benchmark will use only scale model.") + else + model.noise_model = noise + end + f = path.join(model_dir, "scale2.0x_model.t7") + local scale + s1, scale = pcall(load_model, f) + if not s1 then + return nil + end + model.scale_model = scale + else + model.noise_scale_model = noise_scale + end + return model +end if opt.show_progress then print(opt) end + if opt.method == "scale" then local f1 = path.join(opt.model1_dir, "scale2.0x_model.t7") local f2 = path.join(opt.model2_dir, "scale2.0x_model.t7") @@ -300,4 +373,12 @@ elseif opt.method == "noise" then end local test_x = load_data(opt.dir) benchmark(opt, test_x, transform_jpeg, model1, model2) +elseif opt.method == "noise_scale" then + local model2 = nil + local model1 = load_noise_scale_model(opt.model1_dir, opt.noise_level) + if opt.model2_dir:len() > 0 then + model2 = load_noise_scale_model(opt.model2_dir, opt.noise_level) + end + local test_x = load_data(opt.dir) + benchmark(opt, test_x, transform_scale_jpeg, model1, model2) end