1
0
Fork 0
mirror of synced 2024-05-03 12:32:26 +12:00
This commit is contained in:
nagadomi 2018-11-14 17:15:27 +09:00
parent dd8cb71601
commit d5c2277e0e
8 changed files with 551 additions and 12 deletions

204
appendix/arch/cunet.txt Normal file
View file

@ -0,0 +1,204 @@
nn.Sequential {
[input -> (1) -> (2) -> (3) -> (4) -> output]
(1): nn.Sequential {
[input -> (1) -> (2) -> (3) -> (4) -> (5) -> output]
(1): nn.Sequential {
[input -> (1) -> (2) -> (3) -> (4) -> output]
(1): nn.SpatialConvolutionMM(3 -> 32, 3x3)
(2): nn.LeakyReLU(0.1)
(3): nn.SpatialConvolutionMM(32 -> 64, 3x3)
(4): nn.LeakyReLU(0.1)
}
(2): nn.Sequential {
[input -> (1) -> (2) -> output]
(1): nn.ConcatTable {
input
|`-> (1): nn.Sequential {
| [input -> (1) -> (2) -> (3) -> (4) -> (5) -> output]
| (1): nn.SpatialConvolutionMM(64 -> 64, 2x2, 2,2)
| (2): nn.LeakyReLU(0.1)
| (3): nn.Sequential {
| [input -> (1) -> (2) -> (3) -> (4) -> (5) -> (6) -> output]
| (1): nn.SpatialConvolutionMM(64 -> 128, 3x3)
| (2): nn.LeakyReLU(0.1)
| (3): nn.SpatialConvolutionMM(128 -> 64, 3x3)
| (4): nn.LeakyReLU(0.1)
| (5): nn.ConcatTable {
| input
| |`-> (1): nn.Identity
| `-> (2): nn.Sequential {
| [input -> (1) -> (2) -> (3) -> (4) -> (5) -> output]
| (1): nn.Sequential {
| [input -> (1) -> (2) -> (3) -> output]
| (1): nn.Mean
| (2): nn.Mean
| (3): nn.View(-1, 64, 1, 1)
| }
| (2): nn.SpatialConvolutionMM(64 -> 8, 1x1)
| (3): nn.ReLU
| (4): nn.SpatialConvolutionMM(8 -> 64, 1x1)
| (5): nn.Sigmoid
| }
| ... -> output
| }
| (6): w2nn.ScaleTable
| }
| (4): nn.SpatialFullConvolution(64 -> 64, 2x2, 2,2)
| (5): nn.LeakyReLU(0.1)
| }
`-> (2): nn.SpatialZeroPadding(l=-4, r=-4, t=-4, b=-4)
... -> output
}
(2): nn.CAddTable
}
(3): nn.SpatialConvolutionMM(64 -> 64, 3x3)
(4): nn.LeakyReLU(0.1)
(5): nn.SpatialConvolutionMM(64 -> 3, 3x3)
}
(2): nn.ConcatTable {
input
|`-> (1): nn.Sequential {
| [input -> (1) -> (2) -> (3) -> (4) -> (5) -> output]
| (1): nn.Sequential {
| [input -> (1) -> (2) -> (3) -> (4) -> output]
| (1): nn.SpatialConvolutionMM(3 -> 32, 3x3)
| (2): nn.LeakyReLU(0.1)
| (3): nn.SpatialConvolutionMM(32 -> 64, 3x3)
| (4): nn.LeakyReLU(0.1)
| }
| (2): nn.Sequential {
| [input -> (1) -> (2) -> output]
| (1): nn.ConcatTable {
| input
| |`-> (1): nn.Sequential {
| | [input -> (1) -> (2) -> (3) -> (4) -> (5) -> output]
| | (1): nn.SpatialConvolutionMM(64 -> 64, 2x2, 2,2)
| | (2): nn.LeakyReLU(0.1)
| | (3): nn.Sequential {
| | [input -> (1) -> (2) -> (3) -> output]
| | (1): nn.Sequential {
| | [input -> (1) -> (2) -> (3) -> (4) -> (5) -> (6) -> output]
| | (1): nn.SpatialConvolutionMM(64 -> 64, 3x3)
| | (2): nn.LeakyReLU(0.1)
| | (3): nn.SpatialConvolutionMM(64 -> 128, 3x3)
| | (4): nn.LeakyReLU(0.1)
| | (5): nn.ConcatTable {
| | input
| | |`-> (1): nn.Identity
| | `-> (2): nn.Sequential {
| | [input -> (1) -> (2) -> (3) -> (4) -> (5) -> output]
| | (1): nn.Sequential {
| | [input -> (1) -> (2) -> (3) -> output]
| | (1): nn.Mean
| | (2): nn.Mean
| | (3): nn.View(-1, 128, 1, 1)
| | }
| | (2): nn.SpatialConvolutionMM(128 -> 16, 1x1)
| | (3): nn.ReLU
| | (4): nn.SpatialConvolutionMM(16 -> 128, 1x1)
| | (5): nn.Sigmoid
| | }
| | ... -> output
| | }
| | (6): w2nn.ScaleTable
| | }
| | (2): nn.Sequential {
| | [input -> (1) -> (2) -> output]
| | (1): nn.ConcatTable {
| | input
| | |`-> (1): nn.Sequential {
| | | [input -> (1) -> (2) -> (3) -> (4) -> (5) -> output]
| | | (1): nn.SpatialConvolutionMM(128 -> 128, 2x2, 2,2)
| | | (2): nn.LeakyReLU(0.1)
| | | (3): nn.Sequential {
| | | [input -> (1) -> (2) -> (3) -> (4) -> (5) -> (6) -> output]
| | | (1): nn.SpatialConvolutionMM(128 -> 256, 3x3)
| | | (2): nn.LeakyReLU(0.1)
| | | (3): nn.SpatialConvolutionMM(256 -> 128, 3x3)
| | | (4): nn.LeakyReLU(0.1)
| | | (5): nn.ConcatTable {
| | | input
| | | |`-> (1): nn.Identity
| | | `-> (2): nn.Sequential {
| | | [input -> (1) -> (2) -> (3) -> (4) -> (5) -> output]
| | | (1): nn.Sequential {
| | | [input -> (1) -> (2) -> (3) -> output]
| | | (1): nn.Mean
| | | (2): nn.Mean
| | | (3): nn.View(-1, 128, 1, 1)
| | | }
| | | (2): nn.SpatialConvolutionMM(128 -> 16, 1x1)
| | | (3): nn.ReLU
| | | (4): nn.SpatialConvolutionMM(16 -> 128, 1x1)
| | | (5): nn.Sigmoid
| | | }
| | | ... -> output
| | | }
| | | (6): w2nn.ScaleTable
| | | }
| | | (4): nn.SpatialFullConvolution(128 -> 128, 2x2, 2,2)
| | | (5): nn.LeakyReLU(0.1)
| | | }
| | `-> (2): nn.SpatialZeroPadding(l=-4, r=-4, t=-4, b=-4)
| | ... -> output
| | }
| | (2): nn.CAddTable
| | }
| | (3): nn.Sequential {
| | [input -> (1) -> (2) -> (3) -> (4) -> (5) -> (6) -> output]
| | (1): nn.SpatialConvolutionMM(128 -> 64, 3x3)
| | (2): nn.LeakyReLU(0.1)
| | (3): nn.SpatialConvolutionMM(64 -> 64, 3x3)
| | (4): nn.LeakyReLU(0.1)
| | (5): nn.ConcatTable {
| | input
| | |`-> (1): nn.Identity
| | `-> (2): nn.Sequential {
| | [input -> (1) -> (2) -> (3) -> (4) -> (5) -> output]
| | (1): nn.Sequential {
| | [input -> (1) -> (2) -> (3) -> output]
| | (1): nn.Mean
| | (2): nn.Mean
| | (3): nn.View(-1, 64, 1, 1)
| | }
| | (2): nn.SpatialConvolutionMM(64 -> 8, 1x1)
| | (3): nn.ReLU
| | (4): nn.SpatialConvolutionMM(8 -> 64, 1x1)
| | (5): nn.Sigmoid
| | }
| | ... -> output
| | }
| | (6): w2nn.ScaleTable
| | }
| | }
| | (4): nn.SpatialFullConvolution(64 -> 64, 2x2, 2,2)
| | (5): nn.LeakyReLU(0.1)
| | }
| `-> (2): nn.SpatialZeroPadding(l=-16, r=-16, t=-16, b=-16)
| ... -> output
| }
| (2): nn.CAddTable
| }
| (3): nn.SpatialConvolutionMM(64 -> 64, 3x3)
| (4): nn.LeakyReLU(0.1)
| (5): nn.SpatialConvolutionMM(64 -> 3, 3x3)
| }
`-> (2): nn.SpatialZeroPadding(l=-20, r=-20, t=-20, b=-20)
... -> output
}
(3): nn.ConcatTable {
input
|`-> (1): nn.Sequential {
| [input -> (1) -> (2) -> output]
| (1): nn.CAddTable
| (2): w2nn.InplaceClip01
| }
`-> (2): nn.Sequential {
[input -> (1) -> (2) -> output]
(1): nn.SelectTable(2)
(2): w2nn.InplaceClip01
}
... -> output
}
(4): w2nn.AuxiliaryLossTable
}

204
appendix/arch/upcunet.txt Normal file
View file

@ -0,0 +1,204 @@
nn.Sequential {
[input -> (1) -> (2) -> (3) -> (4) -> output]
(1): nn.Sequential {
[input -> (1) -> (2) -> (3) -> (4) -> (5) -> output]
(1): nn.Sequential {
[input -> (1) -> (2) -> (3) -> (4) -> output]
(1): nn.SpatialConvolutionMM(3 -> 32, 3x3)
(2): nn.LeakyReLU(0.1)
(3): nn.SpatialConvolutionMM(32 -> 64, 3x3)
(4): nn.LeakyReLU(0.1)
}
(2): nn.Sequential {
[input -> (1) -> (2) -> output]
(1): nn.ConcatTable {
input
|`-> (1): nn.Sequential {
| [input -> (1) -> (2) -> (3) -> (4) -> (5) -> output]
| (1): nn.SpatialConvolutionMM(64 -> 64, 2x2, 2,2)
| (2): nn.LeakyReLU(0.1)
| (3): nn.Sequential {
| [input -> (1) -> (2) -> (3) -> (4) -> (5) -> (6) -> output]
| (1): nn.SpatialConvolutionMM(64 -> 128, 3x3)
| (2): nn.LeakyReLU(0.1)
| (3): nn.SpatialConvolutionMM(128 -> 64, 3x3)
| (4): nn.LeakyReLU(0.1)
| (5): nn.ConcatTable {
| input
| |`-> (1): nn.Identity
| `-> (2): nn.Sequential {
| [input -> (1) -> (2) -> (3) -> (4) -> (5) -> output]
| (1): nn.Sequential {
| [input -> (1) -> (2) -> (3) -> output]
| (1): nn.Mean
| (2): nn.Mean
| (3): nn.View(-1, 64, 1, 1)
| }
| (2): nn.SpatialConvolutionMM(64 -> 8, 1x1)
| (3): nn.ReLU
| (4): nn.SpatialConvolutionMM(8 -> 64, 1x1)
| (5): nn.Sigmoid
| }
| ... -> output
| }
| (6): w2nn.ScaleTable
| }
| (4): nn.SpatialFullConvolution(64 -> 64, 2x2, 2,2)
| (5): nn.LeakyReLU(0.1)
| }
`-> (2): nn.SpatialZeroPadding(l=-4, r=-4, t=-4, b=-4)
... -> output
}
(2): nn.CAddTable
}
(3): nn.SpatialConvolutionMM(64 -> 64, 3x3)
(4): nn.LeakyReLU(0.1)
(5): nn.SpatialFullConvolution(64 -> 3, 4x4, 2,2, 3,3)
}
(2): nn.ConcatTable {
input
|`-> (1): nn.Sequential {
| [input -> (1) -> (2) -> (3) -> (4) -> (5) -> output]
| (1): nn.Sequential {
| [input -> (1) -> (2) -> (3) -> (4) -> output]
| (1): nn.SpatialConvolutionMM(3 -> 32, 3x3)
| (2): nn.LeakyReLU(0.1)
| (3): nn.SpatialConvolutionMM(32 -> 64, 3x3)
| (4): nn.LeakyReLU(0.1)
| }
| (2): nn.Sequential {
| [input -> (1) -> (2) -> output]
| (1): nn.ConcatTable {
| input
| |`-> (1): nn.Sequential {
| | [input -> (1) -> (2) -> (3) -> (4) -> (5) -> output]
| | (1): nn.SpatialConvolutionMM(64 -> 64, 2x2, 2,2)
| | (2): nn.LeakyReLU(0.1)
| | (3): nn.Sequential {
| | [input -> (1) -> (2) -> (3) -> output]
| | (1): nn.Sequential {
| | [input -> (1) -> (2) -> (3) -> (4) -> (5) -> (6) -> output]
| | (1): nn.SpatialConvolutionMM(64 -> 64, 3x3)
| | (2): nn.LeakyReLU(0.1)
| | (3): nn.SpatialConvolutionMM(64 -> 128, 3x3)
| | (4): nn.LeakyReLU(0.1)
| | (5): nn.ConcatTable {
| | input
| | |`-> (1): nn.Identity
| | `-> (2): nn.Sequential {
| | [input -> (1) -> (2) -> (3) -> (4) -> (5) -> output]
| | (1): nn.Sequential {
| | [input -> (1) -> (2) -> (3) -> output]
| | (1): nn.Mean
| | (2): nn.Mean
| | (3): nn.View(-1, 128, 1, 1)
| | }
| | (2): nn.SpatialConvolutionMM(128 -> 16, 1x1)
| | (3): nn.ReLU
| | (4): nn.SpatialConvolutionMM(16 -> 128, 1x1)
| | (5): nn.Sigmoid
| | }
| | ... -> output
| | }
| | (6): w2nn.ScaleTable
| | }
| | (2): nn.Sequential {
| | [input -> (1) -> (2) -> output]
| | (1): nn.ConcatTable {
| | input
| | |`-> (1): nn.Sequential {
| | | [input -> (1) -> (2) -> (3) -> (4) -> (5) -> output]
| | | (1): nn.SpatialConvolutionMM(128 -> 128, 2x2, 2,2)
| | | (2): nn.LeakyReLU(0.1)
| | | (3): nn.Sequential {
| | | [input -> (1) -> (2) -> (3) -> (4) -> (5) -> (6) -> output]
| | | (1): nn.SpatialConvolutionMM(128 -> 256, 3x3)
| | | (2): nn.LeakyReLU(0.1)
| | | (3): nn.SpatialConvolutionMM(256 -> 128, 3x3)
| | | (4): nn.LeakyReLU(0.1)
| | | (5): nn.ConcatTable {
| | | input
| | | |`-> (1): nn.Identity
| | | `-> (2): nn.Sequential {
| | | [input -> (1) -> (2) -> (3) -> (4) -> (5) -> output]
| | | (1): nn.Sequential {
| | | [input -> (1) -> (2) -> (3) -> output]
| | | (1): nn.Mean
| | | (2): nn.Mean
| | | (3): nn.View(-1, 128, 1, 1)
| | | }
| | | (2): nn.SpatialConvolutionMM(128 -> 16, 1x1)
| | | (3): nn.ReLU
| | | (4): nn.SpatialConvolutionMM(16 -> 128, 1x1)
| | | (5): nn.Sigmoid
| | | }
| | | ... -> output
| | | }
| | | (6): w2nn.ScaleTable
| | | }
| | | (4): nn.SpatialFullConvolution(128 -> 128, 2x2, 2,2)
| | | (5): nn.LeakyReLU(0.1)
| | | }
| | `-> (2): nn.SpatialZeroPadding(l=-4, r=-4, t=-4, b=-4)
| | ... -> output
| | }
| | (2): nn.CAddTable
| | }
| | (3): nn.Sequential {
| | [input -> (1) -> (2) -> (3) -> (4) -> (5) -> (6) -> output]
| | (1): nn.SpatialConvolutionMM(128 -> 64, 3x3)
| | (2): nn.LeakyReLU(0.1)
| | (3): nn.SpatialConvolutionMM(64 -> 64, 3x3)
| | (4): nn.LeakyReLU(0.1)
| | (5): nn.ConcatTable {
| | input
| | |`-> (1): nn.Identity
| | `-> (2): nn.Sequential {
| | [input -> (1) -> (2) -> (3) -> (4) -> (5) -> output]
| | (1): nn.Sequential {
| | [input -> (1) -> (2) -> (3) -> output]
| | (1): nn.Mean
| | (2): nn.Mean
| | (3): nn.View(-1, 64, 1, 1)
| | }
| | (2): nn.SpatialConvolutionMM(64 -> 8, 1x1)
| | (3): nn.ReLU
| | (4): nn.SpatialConvolutionMM(8 -> 64, 1x1)
| | (5): nn.Sigmoid
| | }
| | ... -> output
| | }
| | (6): w2nn.ScaleTable
| | }
| | }
| | (4): nn.SpatialFullConvolution(64 -> 64, 2x2, 2,2)
| | (5): nn.LeakyReLU(0.1)
| | }
| `-> (2): nn.SpatialZeroPadding(l=-16, r=-16, t=-16, b=-16)
| ... -> output
| }
| (2): nn.CAddTable
| }
| (3): nn.SpatialConvolutionMM(64 -> 64, 3x3)
| (4): nn.LeakyReLU(0.1)
| (5): nn.SpatialConvolutionMM(64 -> 3, 3x3)
| }
`-> (2): nn.SpatialZeroPadding(l=-20, r=-20, t=-20, b=-20)
... -> output
}
(3): nn.ConcatTable {
input
|`-> (1): nn.Sequential {
| [input -> (1) -> (2) -> output]
| (1): nn.CAddTable
| (2): w2nn.InplaceClip01
| }
`-> (2): nn.Sequential {
[input -> (1) -> (2) -> output]
(1): nn.SelectTable(2)
(2): w2nn.InplaceClip01
}
... -> output
}
(4): w2nn.AuxiliaryLossTable
}

8
appendix/cudnn2cunn.sh Executable file
View file

@ -0,0 +1,8 @@
#!/bin/bash
# Convert the released cuDNN-trained CUNet models to cuNN models.
# NOTE(review): the scale2.0x input path uses "models/test/test/..." while the
# noise models use "models/test/..." — looks like a doubled path segment; confirm.
src=models/test
dst=models/cunet/art
th tools/cudnn2cunn.lua -i ${src}/test/upcunet_release/scale2.0x_model.t7 -o ${dst}/scale2.0x_model.t7
for i in 0 1 2 3
do
    th tools/cudnn2cunn.lua -i ${src}/cunet_release/noise${i}_model.t7 -o ${dst}/noise${i}_model.t7
    th tools/cudnn2cunn.lua -i ${src}/cunet_release/noise${i}_scale2.0x_model.t7 -o ${dst}/noise${i}_scale2.0x_model.t7
done

View file

@ -1,3 +1,4 @@
-- Random Generated Local Binary Pattern Loss
local LBPCriterion, parent = torch.class('w2nn.LBPCriterion','nn.Criterion')
local function create_filters(ch, n, k, layers)
@ -26,7 +27,7 @@ function LBPCriterion:__init(ch, n, k, layers)
parent.__init(self)
self.layers = layers or 1
self.gamma = 0.1
self.n = n or 32
self.n = n or 128
self.k = k or 3
self.ch = ch
self.filter1 = create_filters(self.ch, self.n, self.k, self.layers)

View file

@ -550,6 +550,7 @@ end
-- Cascaded Residual U-Net with SEBlock
-- unet utils adapted from https://gist.github.com/toshi-k/ca75e614f1ac12fa44f62014ac1d6465
local function unet_conv(backend, n_input, n_middle, n_output, se)
local model = nn.Sequential()
model:add(SpatialConvolution(backend, n_input, n_middle, 3, 3, 1, 1, 0, 0))

View file

@ -6,18 +6,29 @@ require 'w2nn'
local srcnn = require 'srcnn'
local function cudnn2cunn(cudnn_model)
local cunn_model = srcnn.waifu2x_cunn(srcnn.channels(cudnn_model))
local weight_from = cudnn_model:findModules("cudnn.SpatialConvolution")
local weight_to = cunn_model:findModules("nn.SpatialConvolutionMM")
local name = srcnn.name(cudnn_model)
local cunn_model = srcnn[name]('cunn', srcnn.channels(cudnn_model))
local param_layers = {
{cunn="nn.SpatialConvolutionMM", cudnn="cudnn.SpatialConvolution", attr={"bias", "weight"}},
{cunn="nn.SpatialDilatedConvolution", cudnn="cudnn.SpatialDilatedConvolution", attr={"bias", "weight"}},
{cunn="nn.SpatialFullConvolution", cudnn="cudnn.SpatialFullConvolution", attr={"bias", "weight"}},
{cunn="nn.Linear", cudnn="nn.Linear", attr={"bias", "weight"}}
}
for i = 1, #param_layers do
local p = param_layers[i]
local weight_from = cudnn_model:findModules(p.cudnn)
local weight_to = cunn_model:findModules(p.cunn)
print(p.cudnn, #weight_from)
assert(#weight_from == #weight_to)
assert(#weight_from == #weight_to)
for i = 1, #weight_from do
local from = weight_from[i]
local to = weight_to[i]
to.weight:copy(from.weight)
to.bias:copy(from.bias)
for i = 1, #weight_from do
local from = weight_from[i]
local to = weight_to[i]
to.weight:copy(from.weight)
if to.bias then
to.bias:copy(from.bias)
end
end
end
cunn_model:cuda()
cunn_model:evaluate()

View file

@ -0,0 +1,58 @@
-- Generate benchmark data: for each test image, save the high-res original
-- and a 0.5x downscaled copy as a PNG pair.
require 'pl'
local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
package.path = path.join(path.dirname(__FILE__), "..", "lib", "?.lua;") .. package.path
require 'xlua'
local iproc = require 'iproc'
local image_loader = require 'image_loader'
local gm = require 'graphicsmagick'
local cmd = torch.CmdLine()
cmd:text()
cmd:text("waifu2x-make benchmark data")
cmd:text("Options:")
cmd:option("-i", "./data/test", 'input dir')
-- bugfix: the defaults and help strings of -lr/-hr were crossed
-- (opt.lr is used as the lowres output dir later in this script,
--  but it defaulted to "hr" and was described as the highres dir).
cmd:option("-lr", "lr", 'lowres output dir')
cmd:option("-hr", "hr", 'highres output dir')
cmd:option("-filter", "Sinc", 'downsampling filter')
local opt = cmd:parse(arg)
torch.setdefaulttensortype('torch.FloatTensor')
-- Downscale an image to half its width/height with the configured filter.
-- `x` is an image tensor (CHW); `opt.filter` names the resampling kernel.
local function transform_scale(x, opt)
   local half_w = x:size(3) * 0.5
   local half_h = x:size(2) * 0.5
   return iproc.scale(x, half_w, half_h, opt.filter, 1)
end
-- Load every image under test_dir.
-- Returns a list of {y = cropped image, basename = filename without extension};
-- images are cropped to a multiple of 4 on each side (iproc.crop_mod4).
local function load_data_from_dir(test_dir)
   local entries = {}
   local files = dir.getfiles(test_dir, "*.*")
   for i = 1, #files do
      local filename = path.basename(files[i])
      local ext = path.extension(filename)
      local stem = filename:sub(0, filename:len() - ext:len())
      local img = image_loader.load_byte(files[i])
      if img then
         entries[#entries + 1] = {y = iproc.crop_mod4(img), basename = stem}
      end
      if i % 10 == 0 then
         -- NOTE(review): opt.show_progress is never declared via cmd:option in
         -- this script, so the progress bar appears to be effectively disabled.
         if opt.show_progress then
            xlua.progress(i, #files)
         end
         collectgarbage()
      end
   end
   return entries
end
-- Write the highres/lowres PNG pair for every loaded image.
dir.makepath(opt.lr)
dir.makepath(opt.hr)
local files = load_data_from_dir(opt.i)
for i = 1, #files do
   local entry = files[i]
   local hr_img = entry.y
   local lr_img = transform_scale(hr_img, opt)
   -- NOTE(review): `image` is not required in this file; presumably it is
   -- loaded transitively (e.g. via image_loader) — confirm.
   image.save(path.join(opt.hr, entry.basename .. ".png"), hr_img)
   image.save(path.join(opt.lr, entry.basename .. ".png"), lr_img)
end

View file

@ -0,0 +1,52 @@
require 'pl'
local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
package.path = path.join(path.dirname(__FILE__), "..", "lib", "?.lua;") .. package.path
require 'os'
require 'w2nn'
local srcnn = require 'srcnn'
-- Depth-first search for the first w2nn.AuxiliaryLossTable inside a
-- container module (nn.Sequential / nn.ConcatTable trees).
-- Returns the module, or nil when none is found.
local function find_aux(seq)
   for _, child in ipairs(seq.modules) do
      local tname = torch.typename(child)
      if tname == "w2nn.AuxiliaryLossTable" then
         return child
      elseif tname == "nn.Sequential" or tname == "nn.ConcatTable" then
         local found = find_aux(child)
         if found then
            return found
         end
      end
   end
   return nil
end
-- CLI definition: choose which output path AuxiliaryLossTable forwards.
local cmd = torch.CmdLine()
cmd:text()
cmd:text("switch the output pass of auxiliary loss")
cmd:text("Options:")
local options = {
   {"-j", 1, 'Specify the output path index (1|2)'},
   {"-i", "", 'Specify the input model'},
   {"-o", "", 'Specify the output model'},
   {"-iformat", "ascii", 'Specify the input format (ascii|binary)'},
   {"-oformat", "ascii", 'Specify the output format (ascii|binary)'},
}
for _, o in ipairs(options) do
   cmd:option(o[1], o[2], o[3])
end
local opt = cmd:parse(arg)
-- Bail out with usage when the input model path does not exist.
if not path.isfile(opt.i) then
   cmd:help()
   os.exit(-1)
end
-- Load the model, flip the auxiliary output index, and save it back.
local model = torch.load(opt.i, opt.iformat)
if model == nil then
   print("load error")
   os.exit(-1)
end
local aux = find_aux(model)
if aux ~= nil then
   print(aux)
   -- `i` selects which branch AuxiliaryLossTable emits as its output.
   aux.i = opt.j
   torch.save(opt.o, model, opt.oformat)
else
   print("AuxiliaryLossTable not found")
end