From 599da6a6650502c15f87a64c261f7533669bcdc3 Mon Sep 17 00:00:00 2001 From: nagadomi Date: Sun, 12 Jun 2016 15:56:44 +0900 Subject: [PATCH] refactor --- waifu2x.lua | 39 +++++++-------------------------------- 1 file changed, 7 insertions(+), 32 deletions(-) diff --git a/waifu2x.lua b/waifu2x.lua index 28170d0..24b7ca8 100644 --- a/waifu2x.lua +++ b/waifu2x.lua @@ -82,13 +82,10 @@ local function convert_image(opt) local model_path = path.join(opt.model_dir, ("noise%d_scale%.1fx_model.t7"):format(opt.noise_level, opt.scale)) if path.exists(model_path) then local scale_model_path = path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale)) - local scale_model = torch.load(scale_model_path, "ascii") + local t, scale_model = pcall(torch.load, scale_model_path, "ascii") local model = torch.load(model_path, "ascii") - if not model then - error("Load Error: " .. model_path) - end - if not scale_model then - error("Load Error: " .. model_path) + if not t then + scale_model = model end local t = sys.clock() x = alpha_util.make_border(x, alpha, reconstruct.offset_size(scale_model)) @@ -100,13 +97,6 @@ local function convert_image(opt) local noise_model = torch.load(noise_model_path, "ascii") local scale_model_path = path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale)) local scale_model = torch.load(scale_model_path, "ascii") - - if not noise_model then - error("Load Error: " .. noise_model_path) - end - if not scale_model then - error("Load Error: " .. scale_model_path) - end local t = sys.clock() x = alpha_util.make_border(x, alpha, reconstruct.offset_size(scale_model)) x = image_f(noise_model, x, opt.crop_size, opt.batch_size) @@ -120,7 +110,7 @@ local function convert_image(opt) image_loader.save_png(opt.o, new_x, tablex.update({depth = opt.depth, inplace = true}, meta)) end local function convert_frames(opt) - local model_path, scale_model + local model_path, scale_model, t local noise_scale_model = {} local noise_model = {} local scale_f, image_f @@ -140,38 +130,23 @@ local function convert_frames(opt) if opt.m == "scale" then model_path = path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale)) scale_model = torch.load(model_path, "ascii") - if not scale_model then - error("Load Error: " .. model_path) - end elseif opt.m == "noise" then model_path = path.join(opt.model_dir, string.format("noise%d_model.t7", opt.noise_level)) noise_model[opt.noise_level] = torch.load(model_path, "ascii") - if not noise_model[opt.noise_level] then - error("Load Error: " .. model_path) - end elseif opt.m == "noise_scale" then local model_path = path.join(opt.model_dir, ("noise%d_scale%.1fx_model.t7"):format(opt.noise_level, opt.scale)) if path.exists(model_path) then noise_scale_model[opt.noise_level] = torch.load(model_path, "ascii") - if not noise_scale_model[opt.noise_level] then - error("Load Error: " .. model_path) - end model_path = path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale)) - scale_model = torch.load(model_path, "ascii") - if not scale_model then - error("Load Error: " .. model_path) + t, scale_model = pcall(torch.load, model_path, "ascii") + if not t then + scale_model = noise_scale_model[opt.noise_level] end else model_path = path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale)) scale_model = torch.load(model_path, "ascii") - if not scale_model then - error("Load Error: " .. model_path) - end model_path = path.join(opt.model_dir, string.format("noise%d_model.t7", opt.noise_level)) noise_model[opt.noise_level] = torch.load(model_path, "ascii") - if not noise_model[opt.noise_level] then - error("Load Error: " .. model_path) - end end end local fp = io.open(opt.l)