diff --git a/waifu2x.lua b/waifu2x.lua index ce70107..4a29dae 100644 --- a/waifu2x.lua +++ b/waifu2x.lua @@ -17,23 +17,34 @@ local function convert_image(opt) local name = path.basename(opt.i) local e = path.extension(name) local base = name:sub(0, name:len() - e:len()) - opt.o = path.join(path.dirname(opt.i), string.format("%s(%s).png", base, opt.m)) + opt.o = path.join(path.dirname(opt.i), string.format("%s_%s.png", base, opt.m)) end if opt.m == "noise" then - local model = torch.load(path.join(opt.model_dir, ("noise%d_model.t7"):format(opt.noise_level)), "ascii") - --local srcnn = require 'lib/srcnn' - --local model = srcnn.waifu2x("rgb"):cuda() - model:evaluate() + local model_path = path.join(opt.model_dir, ("noise%d_model.t7"):format(opt.noise_level)) + local model = torch.load(model_path, "ascii") + if not model then + error("Load Error: " .. model_path) + end new_x = reconstruct.image(model, x, opt.crop_size) elseif opt.m == "scale" then - local model = torch.load(path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale)), "ascii") - model:evaluate() + local model_path = path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale)) + local model = torch.load(model_path, "ascii") + if not model then + error("Load Error: " .. model_path) + end new_x = reconstruct.scale(model, opt.scale, x, opt.crop_size) elseif opt.m == "noise_scale" then - local noise_model = torch.load(path.join(opt.model_dir, ("noise%d_model.t7"):format(opt.noise_level)), "ascii") - local scale_model = torch.load(path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale)), "ascii") - noise_model:evaluate() - scale_model:evaluate() + local noise_model_path = path.join(opt.model_dir, ("noise%d_model.t7"):format(opt.noise_level)) + local noise_model = torch.load(noise_model_path, "ascii") + local scale_model_path = path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale)) + local scale_model = torch.load(scale_model_path, "ascii") + + if not noise_model then + error("Load Error: " .. noise_model_path) + end + if not scale_model then + error("Load Error: " .. scale_model_path) + end x = reconstruct.image(noise_model, x) new_x = reconstruct.scale(scale_model, opt.scale, x, opt.crop_size) else @@ -43,15 +54,30 @@ local function convert_image(opt) print(opt.o .. ": " .. (sys.clock() - t) .. " sec") end local function convert_frames(opt) - local noise1_model = torch.load(path.join(opt.model_dir, "noise1_model.t7"), "ascii") - local noise2_model = torch.load(path.join(opt.model_dir, "noise2_model.t7"), "ascii") - local scale_model = torch.load(path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale)), "ascii") - - noise1_model:evaluate() - noise2_model:evaluate() - scale_model:evaluate() - + local noise1_model, noise2_model, scale_model + if opt.m == "scale" then + local model_path = path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale)) + scale_model = torch.load(model_path, "ascii") + if not scale_model then + error("Load Error: " .. model_path) + end + elseif opt.m == "noise" and opt.noise_level == 1 then + local model_path = path.join(opt.model_dir, "noise1_model.t7") + noise1_model = torch.load(model_path, "ascii") + if not noise1_model then + error("Load Error: " .. model_path) + end + elseif opt.m == "noise" and opt.noise_level == 2 then + local model_path = path.join(opt.model_dir, "noise2_model.t7") + noise2_model = torch.load(model_path, "ascii") + if not noise2_model then + error("Load Error: " .. model_path) + end + end local fp = io.open(opt.l) + if not fp then + error("Open Error: " .. opt.l) + end local count = 0 local lines = {} for line in fp:lines() do