Add model option and 12 layers net
This commit is contained in:
parent
68a6d4cef5
commit
7af5c9443d
|
@ -1,5 +1,6 @@
|
||||||
require 'image'
|
require 'image'
|
||||||
local iproc = require 'iproc'
|
local iproc = require 'iproc'
|
||||||
|
local srcnn = require 'srcnn'
|
||||||
|
|
||||||
local function reconstruct_y(model, x, offset, block_size)
|
local function reconstruct_y(model, x, offset, block_size)
|
||||||
if x:dim() == 2 then
|
if x:dim() == 2 then
|
||||||
|
@ -50,7 +51,7 @@ local function reconstruct_rgb(model, x, offset, block_size)
|
||||||
end
|
end
|
||||||
local reconstruct = {}
|
local reconstruct = {}
|
||||||
function reconstruct.is_rgb(model)
|
function reconstruct.is_rgb(model)
|
||||||
if model:get(model:size() - 1).weight:size(1) == 3 then
|
if srcnn.channels(model) == 3 then
|
||||||
-- 3ch RGB
|
-- 3ch RGB
|
||||||
return true
|
return true
|
||||||
else
|
else
|
||||||
|
@ -59,21 +60,7 @@ function reconstruct.is_rgb(model)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
function reconstruct.offset_size(model)
|
function reconstruct.offset_size(model)
|
||||||
local conv = model:findModules("nn.SpatialConvolutionMM")
|
return srcnn.offset_size(model)
|
||||||
if #conv > 0 then
|
|
||||||
local offset = 0
|
|
||||||
for i = 1, #conv do
|
|
||||||
offset = offset + (conv[i].kW - 1) / 2
|
|
||||||
end
|
|
||||||
return math.floor(offset)
|
|
||||||
else
|
|
||||||
conv = model:findModules("cudnn.SpatialConvolution")
|
|
||||||
local offset = 0
|
|
||||||
for i = 1, #conv do
|
|
||||||
offset = offset + (conv[i].kW - 1) / 2
|
|
||||||
end
|
|
||||||
return math.floor(offset)
|
|
||||||
end
|
|
||||||
end
|
end
|
||||||
function reconstruct.image_y(model, x, offset, block_size)
|
function reconstruct.image_y(model, x, offset, block_size)
|
||||||
block_size = block_size or 128
|
block_size = block_size or 128
|
||||||
|
|
|
@ -24,6 +24,7 @@ cmd:option("-backend", "cunn", '(cunn|cudnn)')
|
||||||
cmd:option("-test", "images/miku_small.png", 'path to test image')
|
cmd:option("-test", "images/miku_small.png", 'path to test image')
|
||||||
cmd:option("-model_dir", "./models", 'model directory')
|
cmd:option("-model_dir", "./models", 'model directory')
|
||||||
cmd:option("-method", "scale", 'method to training (noise|scale)')
|
cmd:option("-method", "scale", 'method to training (noise|scale)')
|
||||||
|
cmd:option("-model", "vgg_7", 'model architecture (vgg_7|vgg_12)')
|
||||||
cmd:option("-noise_level", 1, '(1|2|3)')
|
cmd:option("-noise_level", 1, '(1|2|3)')
|
||||||
cmd:option("-style", "art", '(art|photo)')
|
cmd:option("-style", "art", '(art|photo)')
|
||||||
cmd:option("-color", 'rgb', '(y|rgb)')
|
cmd:option("-color", 'rgb', '(y|rgb)')
|
||||||
|
|
113
lib/srcnn.lua
113
lib/srcnn.lua
|
@ -30,63 +30,132 @@ end
|
||||||
function srcnn.channels(model)
|
function srcnn.channels(model)
|
||||||
return model:get(model:size() - 1).weight:size(1)
|
return model:get(model:size() - 1).weight:size(1)
|
||||||
end
|
end
|
||||||
function srcnn.waifu2x_cunn(ch)
|
function srcnn.backend(model)
|
||||||
|
local conv = model:findModules("cudnn.SpatialConvolution")
|
||||||
|
if #conv > 0 then
|
||||||
|
return "cudnn"
|
||||||
|
else
|
||||||
|
return "cunn"
|
||||||
|
end
|
||||||
|
end
|
||||||
|
function srcnn.color(model)
|
||||||
|
local ch = srcnn.channels(model)
|
||||||
|
if ch == 3 then
|
||||||
|
return "rgb"
|
||||||
|
else
|
||||||
|
return "y"
|
||||||
|
end
|
||||||
|
end
|
||||||
|
function srcnn.name(model)
|
||||||
|
local backend_cudnn = false
|
||||||
|
local conv = model:findModules("nn.SpatialConvolutionMM")
|
||||||
|
if #conv == 0 then
|
||||||
|
backend_cudnn = true
|
||||||
|
conv = model:findModules("cudnn.SpatialConvolution")
|
||||||
|
end
|
||||||
|
if #conv == 7 then
|
||||||
|
return "vgg_7"
|
||||||
|
elseif #conv == 12 then
|
||||||
|
return "vgg_12"
|
||||||
|
else
|
||||||
|
return nil
|
||||||
|
end
|
||||||
|
end
|
||||||
|
function srcnn.offset_size(model)
|
||||||
|
local conv = model:findModules("nn.SpatialConvolutionMM")
|
||||||
|
if #conv == 0 then
|
||||||
|
conv = model:findModules("cudnn.SpatialConvolution")
|
||||||
|
end
|
||||||
|
local offset = 0
|
||||||
|
for i = 1, #conv do
|
||||||
|
offset = offset + (conv[i].kW - 1) / 2
|
||||||
|
end
|
||||||
|
return math.floor(offset)
|
||||||
|
end
|
||||||
|
|
||||||
|
local function SpatialConvolution(backend, nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH)
|
||||||
|
if backend == "cunn" then
|
||||||
|
return nn.SpatialConvolutionMM(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH)
|
||||||
|
elseif backend == "cudnn" then
|
||||||
|
return cudnn.SpatialConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH)
|
||||||
|
else
|
||||||
|
error("unsupported backend:" .. backend)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
-- VGG style net(7 layers)
|
||||||
|
function srcnn.vgg_7(backend, ch)
|
||||||
local model = nn.Sequential()
|
local model = nn.Sequential()
|
||||||
model:add(nn.SpatialConvolutionMM(ch, 32, 3, 3, 1, 1, 0, 0))
|
model:add(SpatialConvolution(backend, ch, 32, 3, 3, 1, 1, 0, 0))
|
||||||
model:add(w2nn.LeakyReLU(0.1))
|
model:add(w2nn.LeakyReLU(0.1))
|
||||||
model:add(nn.SpatialConvolutionMM(32, 32, 3, 3, 1, 1, 0, 0))
|
model:add(SpatialConvolution(backend, 32, 32, 3, 3, 1, 1, 0, 0))
|
||||||
model:add(w2nn.LeakyReLU(0.1))
|
model:add(w2nn.LeakyReLU(0.1))
|
||||||
model:add(nn.SpatialConvolutionMM(32, 64, 3, 3, 1, 1, 0, 0))
|
model:add(SpatialConvolution(backend, 32, 64, 3, 3, 1, 1, 0, 0))
|
||||||
model:add(w2nn.LeakyReLU(0.1))
|
model:add(w2nn.LeakyReLU(0.1))
|
||||||
model:add(nn.SpatialConvolutionMM(64, 64, 3, 3, 1, 1, 0, 0))
|
model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
|
||||||
model:add(w2nn.LeakyReLU(0.1))
|
model:add(w2nn.LeakyReLU(0.1))
|
||||||
model:add(nn.SpatialConvolutionMM(64, 128, 3, 3, 1, 1, 0, 0))
|
model:add(SpatialConvolution(backend, 64, 128, 3, 3, 1, 1, 0, 0))
|
||||||
model:add(w2nn.LeakyReLU(0.1))
|
model:add(w2nn.LeakyReLU(0.1))
|
||||||
model:add(nn.SpatialConvolutionMM(128, 128, 3, 3, 1, 1, 0, 0))
|
model:add(SpatialConvolution(backend, 128, 128, 3, 3, 1, 1, 0, 0))
|
||||||
model:add(w2nn.LeakyReLU(0.1))
|
model:add(w2nn.LeakyReLU(0.1))
|
||||||
model:add(nn.SpatialConvolutionMM(128, ch, 3, 3, 1, 1, 0, 0))
|
model:add(SpatialConvolution(backend, 128, ch, 3, 3, 1, 1, 0, 0))
|
||||||
model:add(nn.View(-1):setNumInputDims(3))
|
model:add(nn.View(-1):setNumInputDims(3))
|
||||||
--model:cuda()
|
--model:cuda()
|
||||||
--print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
|
--print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
|
||||||
|
|
||||||
return model
|
return model
|
||||||
end
|
end
|
||||||
function srcnn.waifu2x_cudnn(ch)
|
-- VGG style net(12 layers)
|
||||||
|
function srcnn.vgg_12(backend, ch)
|
||||||
local model = nn.Sequential()
|
local model = nn.Sequential()
|
||||||
model:add(cudnn.SpatialConvolution(ch, 32, 3, 3, 1, 1, 0, 0))
|
model:add(SpatialConvolution(backend, ch, 32, 3, 3, 1, 1, 0, 0))
|
||||||
model:add(w2nn.LeakyReLU(0.1))
|
model:add(w2nn.LeakyReLU(0.1))
|
||||||
model:add(cudnn.SpatialConvolution(32, 32, 3, 3, 1, 1, 0, 0))
|
model:add(SpatialConvolution(backend, 32, 32, 3, 3, 1, 1, 0, 0))
|
||||||
model:add(w2nn.LeakyReLU(0.1))
|
model:add(w2nn.LeakyReLU(0.1))
|
||||||
model:add(cudnn.SpatialConvolution(32, 64, 3, 3, 1, 1, 0, 0))
|
model:add(SpatialConvolution(backend, 32, 64, 3, 3, 1, 1, 0, 0))
|
||||||
model:add(w2nn.LeakyReLU(0.1))
|
model:add(w2nn.LeakyReLU(0.1))
|
||||||
model:add(cudnn.SpatialConvolution(64, 64, 3, 3, 1, 1, 0, 0))
|
model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
|
||||||
model:add(w2nn.LeakyReLU(0.1))
|
model:add(w2nn.LeakyReLU(0.1))
|
||||||
model:add(cudnn.SpatialConvolution(64, 128, 3, 3, 1, 1, 0, 0))
|
model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
|
||||||
model:add(w2nn.LeakyReLU(0.1))
|
model:add(w2nn.LeakyReLU(0.1))
|
||||||
model:add(cudnn.SpatialConvolution(128, 128, 3, 3, 1, 1, 0, 0))
|
model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
|
||||||
model:add(w2nn.LeakyReLU(0.1))
|
model:add(w2nn.LeakyReLU(0.1))
|
||||||
model:add(cudnn.SpatialConvolution(128, ch, 3, 3, 1, 1, 0, 0))
|
model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
|
||||||
|
model:add(w2nn.LeakyReLU(0.1))
|
||||||
|
model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
|
||||||
|
model:add(w2nn.LeakyReLU(0.1))
|
||||||
|
model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
|
||||||
|
model:add(w2nn.LeakyReLU(0.1))
|
||||||
|
model:add(SpatialConvolution(backend, 64, 128, 3, 3, 1, 1, 0, 0))
|
||||||
|
model:add(w2nn.LeakyReLU(0.1))
|
||||||
|
model:add(SpatialConvolution(backend, 128, 128, 3, 3, 1, 1, 0, 0))
|
||||||
|
model:add(w2nn.LeakyReLU(0.1))
|
||||||
|
model:add(SpatialConvolution(backend, 128, ch, 3, 3, 1, 1, 0, 0))
|
||||||
model:add(nn.View(-1):setNumInputDims(3))
|
model:add(nn.View(-1):setNumInputDims(3))
|
||||||
--model:cuda()
|
--model:cuda()
|
||||||
--print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
|
--print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
|
||||||
|
|
||||||
return model
|
return model
|
||||||
end
|
end
|
||||||
|
|
||||||
function srcnn.create(model_name, backend, color)
|
function srcnn.create(model_name, backend, color)
|
||||||
|
model_name = model_name or "vgg_7"
|
||||||
|
backend = backend or "cunn"
|
||||||
|
color = color or "rgb"
|
||||||
local ch = 3
|
local ch = 3
|
||||||
if color == "rgb" then
|
if color == "rgb" then
|
||||||
ch = 3
|
ch = 3
|
||||||
elseif color == "y" then
|
elseif color == "y" then
|
||||||
ch = 1
|
ch = 1
|
||||||
else
|
else
|
||||||
error("unsupported color: " + color)
|
error("unsupported color: " .. color)
|
||||||
end
|
end
|
||||||
if backend == "cunn" then
|
if model_name == "vgg_7" then
|
||||||
return srcnn.waifu2x_cunn(ch)
|
return srcnn.vgg_7(backend, ch)
|
||||||
elseif backend == "cudnn" then
|
elseif model_name == "vgg_12" then
|
||||||
return srcnn.waifu2x_cudnn(ch)
|
return srcnn.vgg_12(backend, ch)
|
||||||
else
|
else
|
||||||
error("unsupported backend: " + backend)
|
error("unsupported model_name: " .. model_name)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
return srcnn
|
return srcnn
|
||||||
|
|
|
@ -5,8 +5,8 @@ require 'os'
|
||||||
require 'w2nn'
|
require 'w2nn'
|
||||||
local srcnn = require 'srcnn'
|
local srcnn = require 'srcnn'
|
||||||
|
|
||||||
local function rebuild(old_model)
|
local function rebuild(old_model, model)
|
||||||
local new_model = srcnn.waifu2x_cunn(srcnn.channels(old_model))
|
local new_model = srcnn.create(model, srcnn.backend(old_model), srcnn.color(old_model))
|
||||||
local weight_from = old_model:findModules("nn.SpatialConvolutionMM")
|
local weight_from = old_model:findModules("nn.SpatialConvolutionMM")
|
||||||
local weight_to = new_model:findModules("nn.SpatialConvolutionMM")
|
local weight_to = new_model:findModules("nn.SpatialConvolutionMM")
|
||||||
|
|
||||||
|
@ -30,6 +30,7 @@ cmd:text("waifu2x rebuild cunn model")
|
||||||
cmd:text("Options:")
|
cmd:text("Options:")
|
||||||
cmd:option("-i", "", 'Specify the input model')
|
cmd:option("-i", "", 'Specify the input model')
|
||||||
cmd:option("-o", "", 'Specify the output model')
|
cmd:option("-o", "", 'Specify the output model')
|
||||||
|
cmd:option("-model", "vgg_7", 'Specify the model architecture (vgg_7|vgg_12)')
|
||||||
cmd:option("-iformat", "ascii", 'Specify the input format (ascii|binary)')
|
cmd:option("-iformat", "ascii", 'Specify the input format (ascii|binary)')
|
||||||
cmd:option("-oformat", "ascii", 'Specify the output format (ascii|binary)')
|
cmd:option("-oformat", "ascii", 'Specify the output format (ascii|binary)')
|
||||||
|
|
||||||
|
@ -39,5 +40,5 @@ if not path.isfile(opt.i) then
|
||||||
os.exit(-1)
|
os.exit(-1)
|
||||||
end
|
end
|
||||||
local old_model = torch.load(opt.i, opt.iformat)
|
local old_model = torch.load(opt.i, opt.iformat)
|
||||||
local new_model = rebuild(old_model)
|
local new_model = rebuild(old_model, opt.model)
|
||||||
torch.save(opt.o, new_model, opt.oformat)
|
torch.save(opt.o, new_model, opt.oformat)
|
||||||
|
|
|
@ -192,7 +192,7 @@ local function train()
|
||||||
local hist_train = {}
|
local hist_train = {}
|
||||||
local hist_valid = {}
|
local hist_valid = {}
|
||||||
local LR_MIN = 1.0e-5
|
local LR_MIN = 1.0e-5
|
||||||
local model = srcnn.create(settings.method, settings.backend, settings.color)
|
local model = srcnn.create(settings.model, settings.backend, settings.color)
|
||||||
local offset = reconstruct.offset_size(model)
|
local offset = reconstruct.offset_size(model)
|
||||||
local pairwise_func = function(x, is_validation, n)
|
local pairwise_func = function(x, is_validation, n)
|
||||||
return transformer(x, is_validation, n, offset)
|
return transformer(x, is_validation, n, offset)
|
||||||
|
@ -200,7 +200,7 @@ local function train()
|
||||||
local criterion = create_criterion(model)
|
local criterion = create_criterion(model)
|
||||||
local eval_metric = nn.MSECriterion():cuda()
|
local eval_metric = nn.MSECriterion():cuda()
|
||||||
local x = torch.load(settings.images)
|
local x = torch.load(settings.images)
|
||||||
local train_x, valid_x = split_data(x, math.floor(settings.validation_rate * #x))
|
local train_x, valid_x = split_data(x, math.max(math.floor(settings.validation_rate * #x), 1))
|
||||||
local adam_config = {
|
local adam_config = {
|
||||||
learningRate = settings.learning_rate,
|
learningRate = settings.learning_rate,
|
||||||
xBatchSize = settings.batch_size,
|
xBatchSize = settings.batch_size,
|
||||||
|
|
Loading…
Reference in a new issue