2015-11-08 22:31:46 +13:00
|
|
|
require 'pl'
|
2015-10-28 19:30:47 +13:00
|
|
|
-- Work out where this script lives and put the sibling "lib" directory at the
-- front of the module search path, so the project modules required further
-- down are looked up there first.
local __FILE__
do
   -- Level 1 queried from chunk scope is this chunk itself.
   local chunk_source = debug.getinfo(1, 'S').source
   -- Drop a leading "@" from the source name if there is one; gsub also
   -- returns a substitution count, which is discarded here.
   local stripped = string.gsub(chunk_source, "^@", "")
   __FILE__ = stripped
end
do
   local script_dir = path.dirname(__FILE__)
   local lib_pattern = path.join(script_dir, "..", "lib", "?.lua;")
   package.path = lib_pattern .. package.path
end
|
2015-05-16 17:48:05 +12:00
|
|
|
|
2015-10-28 19:30:47 +13:00
|
|
|
require 'w2nn'
|
2015-05-16 17:48:05 +12:00
|
|
|
-- Use float tensors as the default tensor type for the rest of the script.
torch.setdefaulttensortype("torch.FloatTensor")

-- Command-line interface: help text first, then the options, declared in the
-- same order as they appear in the tables below.
local cmd = torch.CmdLine()
cmd:text()
for _, help_line in ipairs({"cleanup model", "Options:"}) do
   cmd:text(help_line)
end

-- Each entry is {flag, default value, help string}.
local option_specs = {
   {"-model", "./model.t7", 'path of model file'},
   {"-iformat", "binary", 'input format'},
   {"-oformat", "binary", 'output format'},
}
for _, spec in ipairs(option_specs) do
   cmd:option(spec[1], spec[2], spec[3])
end

-- Parse the script arguments into the options table used below.
local opt = cmd:parse(arg)
|
|
|
|
-- Load the model named by -model, clean it up, and write it back to the same
-- path using the requested output format.
--
-- The existence check comes first because torch.load raises its own
-- low-level "cannot open" error for a missing file, so without it the
-- "model not found" message would never be what the user sees for a bad
-- -model path.
if not path.exists(opt.model) then
   error("model not found: " .. opt.model)
end

local model = torch.load(opt.model, opt.iformat)
if not model then
   -- The file was readable but did not yield a model object.
   error("model not found")
end

w2nn.cleanup_model(model)
-- cuda() and evaluate() are applied before saving, so the file written below
-- reflects both calls.
model:cuda()
model:evaluate()
-- Overwrites the input file in place.
torch.save(opt.model, model, opt.oformat)
|