From 7af5c9443d6d49a17998209ed495cc77d9d28ea6 Mon Sep 17 00:00:00 2001 From: nagadomi Date: Sat, 23 Apr 2016 09:18:12 +0900 Subject: [PATCH] Add model option and 12 layers net --- lib/reconstruct.lua | 19 ++----- lib/settings.lua | 1 + lib/srcnn.lua | 113 ++++++++++++++++++++++++++++++++-------- tools/rebuild_model.lua | 7 +-- train.lua | 4 +- 5 files changed, 101 insertions(+), 43 deletions(-) diff --git a/lib/reconstruct.lua b/lib/reconstruct.lua index 3b23c5e..6a78926 100644 --- a/lib/reconstruct.lua +++ b/lib/reconstruct.lua @@ -1,5 +1,6 @@ require 'image' local iproc = require 'iproc' +local srcnn = require 'srcnn' local function reconstruct_y(model, x, offset, block_size) if x:dim() == 2 then @@ -50,7 +51,7 @@ local function reconstruct_rgb(model, x, offset, block_size) end local reconstruct = {} function reconstruct.is_rgb(model) - if model:get(model:size() - 1).weight:size(1) == 3 then + if srcnn.channels(model) == 3 then -- 3ch RGB return true else @@ -59,21 +60,7 @@ function reconstruct.is_rgb(model) end end function reconstruct.offset_size(model) - local conv = model:findModules("nn.SpatialConvolutionMM") - if #conv > 0 then - local offset = 0 - for i = 1, #conv do - offset = offset + (conv[i].kW - 1) / 2 - end - return math.floor(offset) - else - conv = model:findModules("cudnn.SpatialConvolution") - local offset = 0 - for i = 1, #conv do - offset = offset + (conv[i].kW - 1) / 2 - end - return math.floor(offset) - end + return srcnn.offset_size(model) end function reconstruct.image_y(model, x, offset, block_size) block_size = block_size or 128 diff --git a/lib/settings.lua b/lib/settings.lua index 858acba..15c9f26 100644 --- a/lib/settings.lua +++ b/lib/settings.lua @@ -24,6 +24,7 @@ cmd:option("-backend", "cunn", '(cunn|cudnn)') cmd:option("-test", "images/miku_small.png", 'path to test image') cmd:option("-model_dir", "./models", 'model directory') cmd:option("-method", "scale", 'method to training (noise|scale)') +cmd:option("-model", "vgg_7", 'model architecture (vgg_7|vgg_12)') cmd:option("-noise_level", 1, '(1|2|3)') cmd:option("-style", "art", '(art|photo)') cmd:option("-color", 'rgb', '(y|rgb)') diff --git a/lib/srcnn.lua b/lib/srcnn.lua index 1f6a4b2..1c4ba9f 100644 --- a/lib/srcnn.lua +++ b/lib/srcnn.lua @@ -30,63 +30,132 @@ end function srcnn.channels(model) return model:get(model:size() - 1).weight:size(1) end -function srcnn.waifu2x_cunn(ch) +function srcnn.backend(model) + local conv = model:findModules("cudnn.SpatialConvolution") + if #conv > 0 then + return "cudnn" + else + return "cunn" + end +end +function srcnn.color(model) + local ch = srcnn.channels(model) + if ch == 3 then + return "rgb" + else + return "y" + end +end +function srcnn.name(model) + local backend_cudnn = false + local conv = model:findModules("nn.SpatialConvolutionMM") + if #conv == 0 then + backend_cudnn = true + conv = model:findModules("cudnn.SpatialConvolution") + end + if #conv == 7 then + return "vgg_7" + elseif #conv == 12 then + return "vgg_12" + else + return nil + end +end +function srcnn.offset_size(model) + local conv = model:findModules("nn.SpatialConvolutionMM") + if #conv == 0 then + conv = model:findModules("cudnn.SpatialConvolution") + end + local offset = 0 + for i = 1, #conv do + offset = offset + (conv[i].kW - 1) / 2 + end + return math.floor(offset) +end + +local function SpatialConvolution(backend, nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH) + if backend == "cunn" then + return nn.SpatialConvolutionMM(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH) + elseif backend == "cudnn" then + return cudnn.SpatialConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH) + else + error("unsupported backend:" .. backend) + end +end + +-- VGG style net(7 layers) +function srcnn.vgg_7(backend, ch) local model = nn.Sequential() - model:add(nn.SpatialConvolutionMM(ch, 32, 3, 3, 1, 1, 0, 0)) + model:add(SpatialConvolution(backend, ch, 32, 3, 3, 1, 1, 0, 0)) model:add(w2nn.LeakyReLU(0.1)) - model:add(nn.SpatialConvolutionMM(32, 32, 3, 3, 1, 1, 0, 0)) + model:add(SpatialConvolution(backend, 32, 32, 3, 3, 1, 1, 0, 0)) model:add(w2nn.LeakyReLU(0.1)) - model:add(nn.SpatialConvolutionMM(32, 64, 3, 3, 1, 1, 0, 0)) + model:add(SpatialConvolution(backend, 32, 64, 3, 3, 1, 1, 0, 0)) model:add(w2nn.LeakyReLU(0.1)) - model:add(nn.SpatialConvolutionMM(64, 64, 3, 3, 1, 1, 0, 0)) + model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0)) model:add(w2nn.LeakyReLU(0.1)) - model:add(nn.SpatialConvolutionMM(64, 128, 3, 3, 1, 1, 0, 0)) + model:add(SpatialConvolution(backend, 64, 128, 3, 3, 1, 1, 0, 0)) model:add(w2nn.LeakyReLU(0.1)) - model:add(nn.SpatialConvolutionMM(128, 128, 3, 3, 1, 1, 0, 0)) + model:add(SpatialConvolution(backend, 128, 128, 3, 3, 1, 1, 0, 0)) model:add(w2nn.LeakyReLU(0.1)) - model:add(nn.SpatialConvolutionMM(128, ch, 3, 3, 1, 1, 0, 0)) + model:add(SpatialConvolution(backend, 128, ch, 3, 3, 1, 1, 0, 0)) model:add(nn.View(-1):setNumInputDims(3)) --model:cuda() --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size()) return model end -function srcnn.waifu2x_cudnn(ch) +-- VGG style net(12 layers) +function srcnn.vgg_12(backend, ch) local model = nn.Sequential() - model:add(cudnn.SpatialConvolution(ch, 32, 3, 3, 1, 1, 0, 0)) + model:add(SpatialConvolution(backend, ch, 32, 3, 3, 1, 1, 0, 0)) model:add(w2nn.LeakyReLU(0.1)) - model:add(cudnn.SpatialConvolution(32, 32, 3, 3, 1, 1, 0, 0)) + model:add(SpatialConvolution(backend, 32, 32, 3, 3, 1, 1, 0, 0)) model:add(w2nn.LeakyReLU(0.1)) - model:add(cudnn.SpatialConvolution(32, 64, 3, 3, 1, 1, 0, 0)) + model:add(SpatialConvolution(backend, 32, 64, 3, 3, 1, 1, 0, 0)) model:add(w2nn.LeakyReLU(0.1)) - model:add(cudnn.SpatialConvolution(64, 64, 3, 3, 1, 1, 0, 0)) + model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0)) model:add(w2nn.LeakyReLU(0.1)) - model:add(cudnn.SpatialConvolution(64, 128, 3, 3, 1, 1, 0, 0)) + model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0)) model:add(w2nn.LeakyReLU(0.1)) - model:add(cudnn.SpatialConvolution(128, 128, 3, 3, 1, 1, 0, 0)) + model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0)) model:add(w2nn.LeakyReLU(0.1)) - model:add(cudnn.SpatialConvolution(128, ch, 3, 3, 1, 1, 0, 0)) + model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0)) + model:add(w2nn.LeakyReLU(0.1)) + model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0)) + model:add(w2nn.LeakyReLU(0.1)) + model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0)) + model:add(w2nn.LeakyReLU(0.1)) + model:add(SpatialConvolution(backend, 64, 128, 3, 3, 1, 1, 0, 0)) + model:add(w2nn.LeakyReLU(0.1)) + model:add(SpatialConvolution(backend, 128, 128, 3, 3, 1, 1, 0, 0)) + model:add(w2nn.LeakyReLU(0.1)) + model:add(SpatialConvolution(backend, 128, ch, 3, 3, 1, 1, 0, 0)) model:add(nn.View(-1):setNumInputDims(3)) --model:cuda() --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size()) return model end + function srcnn.create(model_name, backend, color) + model_name = model_name or "vgg_7" + backend = backend or "cunn" + color = color or "rgb" local ch = 3 if color == "rgb" then ch = 3 elseif color == "y" then ch = 1 else - error("unsupported color: " + color) + error("unsupported color: " .. color) end - if backend == "cunn" then - return srcnn.waifu2x_cunn(ch) - elseif backend == "cudnn" then - return srcnn.waifu2x_cudnn(ch) + if model_name == "vgg_7" then + return srcnn.vgg_7(backend, ch) + elseif model_name == "vgg_12" then + return srcnn.vgg_12(backend, ch) else - error("unsupported backend: " + backend) + error("unsupported model_name: " .. model_name) end end return srcnn diff --git a/tools/rebuild_model.lua b/tools/rebuild_model.lua index 878475d..b46f433 100644 --- a/tools/rebuild_model.lua +++ b/tools/rebuild_model.lua @@ -5,8 +5,8 @@ require 'os' require 'w2nn' local srcnn = require 'srcnn' -local function rebuild(old_model) - local new_model = srcnn.waifu2x_cunn(srcnn.channels(old_model)) +local function rebuild(old_model, model) + local new_model = srcnn.create(model, srcnn.backend(old_model), srcnn.color(old_model)) local weight_from = old_model:findModules("nn.SpatialConvolutionMM") local weight_to = new_model:findModules("nn.SpatialConvolutionMM") @@ -30,6 +30,7 @@ cmd:text("waifu2x rebuild cunn model") cmd:text("Options:") cmd:option("-i", "", 'Specify the input model') cmd:option("-o", "", 'Specify the output model') +cmd:option("-model", "vgg_7", 'Specify the model architecture (vgg_7|vgg_12)') cmd:option("-iformat", "ascii", 'Specify the input format (ascii|binary)') cmd:option("-oformat", "ascii", 'Specify the output format (ascii|binary)') @@ -39,5 +40,5 @@ if not path.isfile(opt.i) then os.exit(-1) end local old_model = torch.load(opt.i, opt.iformat) -local new_model = rebuild(old_model) +local new_model = rebuild(old_model, opt.model) torch.save(opt.o, new_model, opt.oformat) diff --git a/train.lua b/train.lua index 077e1b0..14666a6 100644 --- a/train.lua +++ b/train.lua @@ -192,7 +192,7 @@ local function train() local hist_train = {} local hist_valid = {} local LR_MIN = 1.0e-5 - local model = srcnn.create(settings.method, settings.backend, settings.color) + local model = srcnn.create(settings.model, settings.backend, settings.color) local offset = reconstruct.offset_size(model) local pairwise_func = function(x, is_validation, n) return transformer(x, is_validation, n, offset) @@ -200,7 +200,7 @@ local function train() local criterion = create_criterion(model) local eval_metric = nn.MSECriterion():cuda() local x = torch.load(settings.images) - local train_x, valid_x = split_data(x, math.floor(settings.validation_rate * #x)) + local train_x, valid_x = split_data(x, math.max(math.floor(settings.validation_rate * #x), 1)) local adam_config = { learningRate = settings.learning_rate, xBatchSize = settings.batch_size,