1
0
Fork 0
mirror of synced 2024-05-25 15:19:33 +12:00

Add model option and 12 layers net

This commit is contained in:
nagadomi 2016-04-23 09:18:12 +09:00
parent 68a6d4cef5
commit 7af5c9443d
5 changed files with 101 additions and 43 deletions

View file

@ -1,5 +1,6 @@
require 'image'
local iproc = require 'iproc'
local srcnn = require 'srcnn'
local function reconstruct_y(model, x, offset, block_size)
if x:dim() == 2 then
@ -50,7 +51,7 @@ local function reconstruct_rgb(model, x, offset, block_size)
end
local reconstruct = {}
function reconstruct.is_rgb(model)
if model:get(model:size() - 1).weight:size(1) == 3 then
if srcnn.channels(model) == 3 then
-- 3ch RGB
return true
else
@ -59,21 +60,7 @@ function reconstruct.is_rgb(model)
end
end
function reconstruct.offset_size(model)
local conv = model:findModules("nn.SpatialConvolutionMM")
if #conv > 0 then
local offset = 0
for i = 1, #conv do
offset = offset + (conv[i].kW - 1) / 2
end
return math.floor(offset)
else
conv = model:findModules("cudnn.SpatialConvolution")
local offset = 0
for i = 1, #conv do
offset = offset + (conv[i].kW - 1) / 2
end
return math.floor(offset)
end
return srcnn.offset_size(model)
end
function reconstruct.image_y(model, x, offset, block_size)
block_size = block_size or 128

View file

@ -24,6 +24,7 @@ cmd:option("-backend", "cunn", '(cunn|cudnn)')
cmd:option("-test", "images/miku_small.png", 'path to test image')
cmd:option("-model_dir", "./models", 'model directory')
cmd:option("-method", "scale", 'method to training (noise|scale)')
cmd:option("-model", "vgg_7", 'model architecture (vgg_7|vgg_12)')
cmd:option("-noise_level", 1, '(1|2|3)')
cmd:option("-style", "art", '(art|photo)')
cmd:option("-color", 'rgb', '(y|rgb)')

View file

@ -30,63 +30,132 @@ end
function srcnn.channels(model)
return model:get(model:size() - 1).weight:size(1)
end
function srcnn.waifu2x_cunn(ch)
function srcnn.backend(model)
local conv = model:findModules("cudnn.SpatialConvolution")
if #conv > 0 then
return "cudnn"
else
return "cunn"
end
end
function srcnn.color(model)
local ch = srcnn.channels(model)
if ch == 3 then
return "rgb"
else
return "y"
end
end
function srcnn.name(model)
local backend_cudnn = false
local conv = model:findModules("nn.SpatialConvolutionMM")
if #conv == 0 then
backend_cudnn = true
conv = model:findModules("cudnn.SpatialConvolution")
end
if #conv == 7 then
return "vgg_7"
elseif #conv == 12 then
return "vgg_12"
else
return nil
end
end
function srcnn.offset_size(model)
local conv = model:findModules("nn.SpatialConvolutionMM")
if #conv == 0 then
conv = model:findModules("cudnn.SpatialConvolution")
end
local offset = 0
for i = 1, #conv do
offset = offset + (conv[i].kW - 1) / 2
end
return math.floor(offset)
end
local function SpatialConvolution(backend, nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH)
if backend == "cunn" then
return nn.SpatialConvolutionMM(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH)
elseif backend == "cudnn" then
return cudnn.SpatialConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH)
else
error("unsupported backend:" .. backend)
end
end
-- VGG style net(7 layers)
function srcnn.vgg_7(backend, ch)
local model = nn.Sequential()
model:add(nn.SpatialConvolutionMM(ch, 32, 3, 3, 1, 1, 0, 0))
model:add(SpatialConvolution(backend, ch, 32, 3, 3, 1, 1, 0, 0))
model:add(w2nn.LeakyReLU(0.1))
model:add(nn.SpatialConvolutionMM(32, 32, 3, 3, 1, 1, 0, 0))
model:add(SpatialConvolution(backend, 32, 32, 3, 3, 1, 1, 0, 0))
model:add(w2nn.LeakyReLU(0.1))
model:add(nn.SpatialConvolutionMM(32, 64, 3, 3, 1, 1, 0, 0))
model:add(SpatialConvolution(backend, 32, 64, 3, 3, 1, 1, 0, 0))
model:add(w2nn.LeakyReLU(0.1))
model:add(nn.SpatialConvolutionMM(64, 64, 3, 3, 1, 1, 0, 0))
model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
model:add(w2nn.LeakyReLU(0.1))
model:add(nn.SpatialConvolutionMM(64, 128, 3, 3, 1, 1, 0, 0))
model:add(SpatialConvolution(backend, 64, 128, 3, 3, 1, 1, 0, 0))
model:add(w2nn.LeakyReLU(0.1))
model:add(nn.SpatialConvolutionMM(128, 128, 3, 3, 1, 1, 0, 0))
model:add(SpatialConvolution(backend, 128, 128, 3, 3, 1, 1, 0, 0))
model:add(w2nn.LeakyReLU(0.1))
model:add(nn.SpatialConvolutionMM(128, ch, 3, 3, 1, 1, 0, 0))
model:add(SpatialConvolution(backend, 128, ch, 3, 3, 1, 1, 0, 0))
model:add(nn.View(-1):setNumInputDims(3))
--model:cuda()
--print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
return model
end
function srcnn.waifu2x_cudnn(ch)
-- VGG style net(12 layers)
function srcnn.vgg_12(backend, ch)
local model = nn.Sequential()
model:add(cudnn.SpatialConvolution(ch, 32, 3, 3, 1, 1, 0, 0))
model:add(SpatialConvolution(backend, ch, 32, 3, 3, 1, 1, 0, 0))
model:add(w2nn.LeakyReLU(0.1))
model:add(cudnn.SpatialConvolution(32, 32, 3, 3, 1, 1, 0, 0))
model:add(SpatialConvolution(backend, 32, 32, 3, 3, 1, 1, 0, 0))
model:add(w2nn.LeakyReLU(0.1))
model:add(cudnn.SpatialConvolution(32, 64, 3, 3, 1, 1, 0, 0))
model:add(SpatialConvolution(backend, 32, 64, 3, 3, 1, 1, 0, 0))
model:add(w2nn.LeakyReLU(0.1))
model:add(cudnn.SpatialConvolution(64, 64, 3, 3, 1, 1, 0, 0))
model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
model:add(w2nn.LeakyReLU(0.1))
model:add(cudnn.SpatialConvolution(64, 128, 3, 3, 1, 1, 0, 0))
model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
model:add(w2nn.LeakyReLU(0.1))
model:add(cudnn.SpatialConvolution(128, 128, 3, 3, 1, 1, 0, 0))
model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
model:add(w2nn.LeakyReLU(0.1))
model:add(cudnn.SpatialConvolution(128, ch, 3, 3, 1, 1, 0, 0))
model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
model:add(w2nn.LeakyReLU(0.1))
model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
model:add(w2nn.LeakyReLU(0.1))
model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
model:add(w2nn.LeakyReLU(0.1))
model:add(SpatialConvolution(backend, 64, 128, 3, 3, 1, 1, 0, 0))
model:add(w2nn.LeakyReLU(0.1))
model:add(SpatialConvolution(backend, 128, 128, 3, 3, 1, 1, 0, 0))
model:add(w2nn.LeakyReLU(0.1))
model:add(SpatialConvolution(backend, 128, ch, 3, 3, 1, 1, 0, 0))
model:add(nn.View(-1):setNumInputDims(3))
--model:cuda()
--print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
return model
end
function srcnn.create(model_name, backend, color)
model_name = model_name or "vgg_7"
backend = backend or "cunn"
color = color or "rgb"
local ch = 3
if color == "rgb" then
ch = 3
elseif color == "y" then
ch = 1
else
error("unsupported color: " + color)
error("unsupported color: " .. color)
end
if backend == "cunn" then
return srcnn.waifu2x_cunn(ch)
elseif backend == "cudnn" then
return srcnn.waifu2x_cudnn(ch)
if model_name == "vgg_7" then
return srcnn.vgg_7(backend, ch)
elseif model_name == "vgg_12" then
return srcnn.vgg_12(backend, ch)
else
error("unsupported backend: " + backend)
error("unsupported model_name: " .. model_name)
end
end
return srcnn

View file

@ -5,8 +5,8 @@ require 'os'
require 'w2nn'
local srcnn = require 'srcnn'
local function rebuild(old_model)
local new_model = srcnn.waifu2x_cunn(srcnn.channels(old_model))
local function rebuild(old_model, model)
local new_model = srcnn.create(model, srcnn.backend(old_model), srcnn.color(old_model))
local weight_from = old_model:findModules("nn.SpatialConvolutionMM")
local weight_to = new_model:findModules("nn.SpatialConvolutionMM")
@ -30,6 +30,7 @@ cmd:text("waifu2x rebuild cunn model")
cmd:text("Options:")
cmd:option("-i", "", 'Specify the input model')
cmd:option("-o", "", 'Specify the output model')
cmd:option("-model", "vgg_7", 'Specify the model architecture (vgg_7|vgg_12)')
cmd:option("-iformat", "ascii", 'Specify the input format (ascii|binary)')
cmd:option("-oformat", "ascii", 'Specify the output format (ascii|binary)')
@ -39,5 +40,5 @@ if not path.isfile(opt.i) then
os.exit(-1)
end
local old_model = torch.load(opt.i, opt.iformat)
local new_model = rebuild(old_model)
local new_model = rebuild(old_model, opt.model)
torch.save(opt.o, new_model, opt.oformat)

View file

@ -192,7 +192,7 @@ local function train()
local hist_train = {}
local hist_valid = {}
local LR_MIN = 1.0e-5
local model = srcnn.create(settings.method, settings.backend, settings.color)
local model = srcnn.create(settings.model, settings.backend, settings.color)
local offset = reconstruct.offset_size(model)
local pairwise_func = function(x, is_validation, n)
return transformer(x, is_validation, n, offset)
@ -200,7 +200,7 @@ local function train()
local criterion = create_criterion(model)
local eval_metric = nn.MSECriterion():cuda()
local x = torch.load(settings.images)
local train_x, valid_x = split_data(x, math.floor(settings.validation_rate * #x))
local train_x, valid_x = split_data(x, math.max(math.floor(settings.validation_rate * #x), 1))
local adam_config = {
learningRate = settings.learning_rate,
xBatchSize = settings.batch_size,