From ff63a28540f17c629f2fd51b0d7fa83927ca1117 Mon Sep 17 00:00:00 2001 From: nagadomi Date: Sat, 23 Jan 2016 09:38:40 +0900 Subject: [PATCH] Add tools/rebuild_model --- tools/rebuild_model.lua | 43 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) create mode 100644 tools/rebuild_model.lua diff --git a/tools/rebuild_model.lua b/tools/rebuild_model.lua new file mode 100644 index 0000000..878475d --- /dev/null +++ b/tools/rebuild_model.lua @@ -0,0 +1,43 @@ +require 'pl' +local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)() +package.path = path.join(path.dirname(__FILE__), "..", "lib", "?.lua;") .. package.path +require 'os' +require 'w2nn' +local srcnn = require 'srcnn' + +local function rebuild(old_model) + local new_model = srcnn.waifu2x_cunn(srcnn.channels(old_model)) + local weight_from = old_model:findModules("nn.SpatialConvolutionMM") + local weight_to = new_model:findModules("nn.SpatialConvolutionMM") + + assert(#weight_from == #weight_to) + + for i = 1, #weight_from do + local from = weight_from[i] + local to = weight_to[i] + + to.weight:copy(from.weight) + to.bias:copy(from.bias) + end + new_model:cuda() + new_model:evaluate() + return new_model +end + +local cmd = torch.CmdLine() +cmd:text() +cmd:text("waifu2x rebuild cunn model") +cmd:text("Options:") +cmd:option("-i", "", 'Specify the input model') +cmd:option("-o", "", 'Specify the output model') +cmd:option("-iformat", "ascii", 'Specify the input format (ascii|binary)') +cmd:option("-oformat", "ascii", 'Specify the output format (ascii|binary)') + +local opt = cmd:parse(arg) +if not path.isfile(opt.i) then + cmd:help() + os.exit(-1) +end +local old_model = torch.load(opt.i, opt.iformat) +local new_model = rebuild(old_model) +torch.save(opt.o, new_model, opt.oformat)