Fix model loading error when using ukbench model; Add error handling
This commit is contained in:
parent
1a20ec501b
commit
1f91548c6e
1 changed files with 45 additions and 19 deletions
64
waifu2x.lua
64
waifu2x.lua
|
@ -17,23 +17,34 @@ local function convert_image(opt)
|
||||||
local name = path.basename(opt.i)
|
local name = path.basename(opt.i)
|
||||||
local e = path.extension(name)
|
local e = path.extension(name)
|
||||||
local base = name:sub(0, name:len() - e:len())
|
local base = name:sub(0, name:len() - e:len())
|
||||||
opt.o = path.join(path.dirname(opt.i), string.format("%s(%s).png", base, opt.m))
|
opt.o = path.join(path.dirname(opt.i), string.format("%s_%s.png", base, opt.m))
|
||||||
end
|
end
|
||||||
if opt.m == "noise" then
|
if opt.m == "noise" then
|
||||||
local model = torch.load(path.join(opt.model_dir, ("noise%d_model.t7"):format(opt.noise_level)), "ascii")
|
local model_path = path.join(opt.model_dir, ("noise%d_model.t7"):format(opt.noise_level))
|
||||||
--local srcnn = require 'lib/srcnn'
|
local model = torch.load(model_path, "ascii")
|
||||||
--local model = srcnn.waifu2x("rgb"):cuda()
|
if not model then
|
||||||
model:evaluate()
|
error("Load Error: " .. model_path)
|
||||||
|
end
|
||||||
new_x = reconstruct.image(model, x, opt.crop_size)
|
new_x = reconstruct.image(model, x, opt.crop_size)
|
||||||
elseif opt.m == "scale" then
|
elseif opt.m == "scale" then
|
||||||
local model = torch.load(path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale)), "ascii")
|
local model_path = path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale))
|
||||||
model:evaluate()
|
local model = torch.load(model_path, "ascii")
|
||||||
|
if not model then
|
||||||
|
error("Load Error: " .. model_path)
|
||||||
|
end
|
||||||
new_x = reconstruct.scale(model, opt.scale, x, opt.crop_size)
|
new_x = reconstruct.scale(model, opt.scale, x, opt.crop_size)
|
||||||
elseif opt.m == "noise_scale" then
|
elseif opt.m == "noise_scale" then
|
||||||
local noise_model = torch.load(path.join(opt.model_dir, ("noise%d_model.t7"):format(opt.noise_level)), "ascii")
|
local noise_model_path = path.join(opt.model_dir, ("noise%d_model.t7"):format(opt.noise_level))
|
||||||
local scale_model = torch.load(path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale)), "ascii")
|
local noise_model = torch.load(noise_model_path, "ascii")
|
||||||
noise_model:evaluate()
|
local scale_model_path = path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale))
|
||||||
scale_model:evaluate()
|
local scale_model = torch.load(scale_model_path, "ascii")
|
||||||
|
|
||||||
|
if not noise_model then
|
||||||
|
error("Load Error: " .. noise_model_path)
|
||||||
|
end
|
||||||
|
if not scale_model then
|
||||||
|
error("Load Error: " .. scale_model_path)
|
||||||
|
end
|
||||||
x = reconstruct.image(noise_model, x)
|
x = reconstruct.image(noise_model, x)
|
||||||
new_x = reconstruct.scale(scale_model, opt.scale, x, opt.crop_size)
|
new_x = reconstruct.scale(scale_model, opt.scale, x, opt.crop_size)
|
||||||
else
|
else
|
||||||
|
@ -43,15 +54,30 @@ local function convert_image(opt)
|
||||||
print(opt.o .. ": " .. (sys.clock() - t) .. " sec")
|
print(opt.o .. ": " .. (sys.clock() - t) .. " sec")
|
||||||
end
|
end
|
||||||
local function convert_frames(opt)
|
local function convert_frames(opt)
|
||||||
local noise1_model = torch.load(path.join(opt.model_dir, "noise1_model.t7"), "ascii")
|
local noise1_model, noise2_model, scale_model
|
||||||
local noise2_model = torch.load(path.join(opt.model_dir, "noise2_model.t7"), "ascii")
|
if opt.m == "scale" then
|
||||||
local scale_model = torch.load(path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale)), "ascii")
|
local model_path = path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale))
|
||||||
|
scale_model = torch.load(model_path, "ascii")
|
||||||
noise1_model:evaluate()
|
if not scale_model then
|
||||||
noise2_model:evaluate()
|
error("Load Error: " .. model_path)
|
||||||
scale_model:evaluate()
|
end
|
||||||
|
elseif opt.m == "noise" and opt.noise_level == 1 then
|
||||||
|
local model_path = path.join(opt.model_dir, "noise1_model.t7")
|
||||||
|
noise1_model = torch.load(model_path, "ascii")
|
||||||
|
if not noise1_model then
|
||||||
|
error("Load Error: " .. model_path)
|
||||||
|
end
|
||||||
|
elseif opt.m == "noise" and opt.noise_level == 2 then
|
||||||
|
local model_path = path.join(opt.model_dir, "noise2_model.t7")
|
||||||
|
noise2_model = torch.load(model_path, "ascii")
|
||||||
|
if not noise2_model then
|
||||||
|
error("Load Error: " .. model_path)
|
||||||
|
end
|
||||||
|
end
|
||||||
local fp = io.open(opt.l)
|
local fp = io.open(opt.l)
|
||||||
|
if not fp then
|
||||||
|
error("Open Error: " .. opt.l)
|
||||||
|
end
|
||||||
local count = 0
|
local count = 0
|
||||||
local lines = {}
|
local lines = {}
|
||||||
for line in fp:lines() do
|
for line in fp:lines() do
|
||||||
|
|
Loading…
Reference in a new issue