1
0
Fork 0
mirror of synced 2024-05-19 20:32:22 +12:00
waifu2x/cudnn2cunn.lua
nagadomi 5b4d692f03 add support for RGB color space reconstruction
- add new RGB model (models/anime_style_art_rgb).
- RGB model can reduce color noise.
- waifu2x uses this RGB model by default.

You can use Y model with:
$ th waifu2x.lua -model_dir models/anime_style_art -i input.png -o output.png
$ th train.lua -color y ...
2015-06-23 02:55:30 +09:00

35 lines
999 B
Lua

require 'cunn'
require 'cudnn'
require 'cutorch'
require './lib/LeakyReLU'
local srcnn = require 'lib/srcnn'
--- Convert a waifu2x model built from cudnn layers into an equivalent cunn model.
--
-- A fresh model is built with srcnn.waifu2x(color). The weight and bias of
-- every cudnn.SpatialConvolution in `cudnn_model` are then copied, in order,
-- into the nn.SpatialConvolutionMM layers of that fresh model.
--
-- @param cudnn_model  model loaded from a cudnn model file
-- @param color        optional "y" or "rgb". When omitted it is inferred from
--                     the input plane count of the first convolution
--                     (3 -> "rgb", anything else -> "y").
-- @return the cunn model, moved to the GPU (:cuda()) and set to :evaluate()
-- @raise error if the input has no cudnn convolutions or the layer counts differ
local function cudnn2cunn(cudnn_model, color)
   local from_seq = cudnn_model:findModules("cudnn.SpatialConvolution")
   if #from_seq == 0 then
      -- Without this check a model with randomly initialised weights would be
      -- returned silently.
      error("cudnn2cunn: no cudnn.SpatialConvolution layers found in the input model")
   end
   if color == nil then
      -- The Y model takes one input plane; the RGB model takes three.
      -- NOTE(review): assumes srcnn.waifu2x accepts "rgb" for the 3-channel
      -- model; confirm against lib/srcnn.lua.
      color = (from_seq[1].nInputPlane == 3) and "rgb" or "y"
   end
   local cunn_model = srcnn.waifu2x(color)
   local to_seq = cunn_model:findModules("nn.SpatialConvolutionMM")
   if #from_seq ~= #to_seq then
      error(string.format(
         "cudnn2cunn: convolution layer count mismatch (cudnn: %d, cunn: %d)",
         #from_seq, #to_seq))
   end
   for i = 1, #from_seq do
      local from, to = from_seq[i], to_seq[i]
      -- Tensor:copy only requires equal element counts, not equal shapes.
      to.weight:copy(from.weight)
      to.bias:copy(from.bias)
   end
   cunn_model:cuda()
   cunn_model:evaluate()
   return cunn_model
end
-- Command-line entry point: load a cudnn model, convert it with cudnn2cunn,
-- and save the result. If -o is not given, the output overwrites the -model
-- file, which was the only behaviour before -o existed.
local cmd = torch.CmdLine()
cmd:text()
cmd:text("convert cudnn model to cunn model ")
cmd:text("Options:")
cmd:option("-model", "./model.t7", 'path of cudnn model file')
cmd:option("-o", "", 'path of output cunn model file (default: overwrite -model)')
cmd:option("-iformat", "ascii", 'input format')
cmd:option("-oformat", "ascii", 'output format')
local opt = cmd:parse(arg)

-- An empty -o keeps the original in-place behaviour for existing invocations.
local output_path = (opt.o ~= "") and opt.o or opt.model

local cudnn_model = torch.load(opt.model, opt.iformat)
local cunn_model = cudnn2cunn(cudnn_model)
torch.save(output_path, cunn_model, opt.oformat)