2015-05-16 17:48:05 +12:00
|
|
|
require './LeakyReLU'
|
|
|
|
|
2015-07-11 17:52:51 +12:00
|
|
|
-- ref: http://arxiv.org/abs/1502.01852
|
2015-06-13 18:02:02 +12:00
|
|
|
--- He et al. (2015) weight initialization for nn.SpatialConvolutionMM.
-- Defined on the class table, so it replaces `reset` for every
-- SpatialConvolutionMM instance once this file has been loaded (presumably
-- the nn constructor calls it; confirm against the installed nn version).
--
-- Weights are drawn from N(0, std) with
--   std = sqrt(2 / ((1 + a^2) * kW * kH * nOutputPlane))
-- where a = 0.1 is the negative slope of the LeakyReLU layers in this file.
-- The fan is taken over the output planes (the paper's backward-pass form).
-- The bias, when present, is set to zero.
-- @param stdv ignored; the standard deviation is always derived from the
--   layer geometry (the parameter only keeps nn's reset(stdv) signature).
function nn.SpatialConvolutionMM:reset(stdv)
   local negative_slope = 0.1 -- must match nn.LeakyReLU(0.1) used below
   local fan = self.kW * self.kH * self.nOutputPlane
   local std = math.sqrt(2 / ((1.0 + negative_slope * negative_slope) * fan))
   self.weight:normal(0, std)
   -- a convolution configured without a bias has self.bias == nil;
   -- calling :zero() on it would raise instead of re-initializing
   if self.bias then
      self.bias:zero()
   end
end
|
2015-07-11 17:52:51 +12:00
|
|
|
|
|
|
|
-- ref: http://arxiv.org/abs/1501.00092
|
2015-06-13 18:02:02 +12:00
|
|
|
-- Module table: the model constructors below are attached to it, and it is
-- returned at the end of the file.
local srcnn = {}
|
2015-06-23 05:27:28 +12:00
|
|
|
--- Build the 2x ("very deep") super-resolution network.
-- Seven unpadded 3x3, stride-1 convolutions, each followed by LeakyReLU(0.1)
-- except the last, then a View that flattens the last 3 input dimensions.
-- @param color "rgb" (3 channels) or "y" (1 channel)
-- @return model the nn.Sequential network
-- @return offset pixels trimmed from each image border by the unpadded
--   convolutions (sum of (k - 1) / 2 over all layers; 7 for this stack)
-- @raise error("unknown color: ...") for any other color value
function srcnn.waifu2x(color)
   local ch = nil
   if color == "rgb" then
      ch = 3
   elseif color == "y" then
      ch = 1
   else
      -- tostring() reports nil/false/numbers/tables correctly; a plain ".."
      -- fails on nil, booleans and tables, and false used to print as "nil"
      error("unknown color: " .. tostring(color))
   end

   -- {nInputPlane, nOutputPlane, kernel size}; all layers stride 1, pad 0
   local layers = {
      {ch, 32, 3},
      {32, 32, 3},
      {32, 64, 3},
      {64, 64, 3},
      {64, 128, 3},
      {128, 128, 3},
      {128, ch, 3},
   }

   local model = nn.Sequential()
   local offset = 0
   for i, spec in ipairs(layers) do
      local n_in, n_out, k = spec[1], spec[2], spec[3]
      model:add(nn.SpatialConvolutionMM(n_in, n_out, k, k, 1, 1, 0, 0))
      if i < #layers then
         model:add(nn.LeakyReLU(0.1))
      end
      -- an unpadded k x k convolution removes (k - 1) / 2 pixels per side;
      -- derive the offset so it cannot drift from the layer stack
      offset = offset + math.floor((k - 1) / 2)
   end
   model:add(nn.View(-1):setNumInputDims(3))

   return model, offset
end
|
2015-06-13 18:02:02 +12:00
|
|
|
|
|
|
|
-- the current 4x model performs worse than applying the 2x model twice
|
2015-06-23 05:27:28 +12:00
|
|
|
--- Build the 4x super-resolution network.
-- Seven unpadded, stride-1 convolutions with kernel sizes 9,3,5,3,5,3,5, each
-- followed by LeakyReLU(0.1) except the last, then a View that flattens the
-- last 3 input dimensions.
-- @param color "rgb" (3 channels) or "y" (1 channel)
-- @return model the nn.Sequential network
-- @return offset pixels trimmed from each image border by the unpadded
--   convolutions (sum of (k - 1) / 2 over all layers; 13 for this stack)
-- @raise error("unknown color: ...") for any other color value
function srcnn.waifu4x(color)
   local ch = nil
   if color == "rgb" then
      ch = 3
   elseif color == "y" then
      ch = 1
   else
      -- tostring() is required: "unknown color: " .. nil raises
      -- "attempt to concatenate a nil value" instead of this message
      error("unknown color: " .. tostring(color))
   end

   -- {nInputPlane, nOutputPlane, kernel size}; all layers stride 1, pad 0
   local layers = {
      {ch, 32, 9},
      {32, 32, 3},
      {32, 64, 5},
      {64, 64, 3},
      {64, 128, 5},
      {128, 128, 3},
      {128, ch, 5},
   }

   local model = nn.Sequential()
   local offset = 0
   for i, spec in ipairs(layers) do
      local n_in, n_out, k = spec[1], spec[2], spec[3]
      model:add(nn.SpatialConvolutionMM(n_in, n_out, k, k, 1, 1, 0, 0))
      if i < #layers then
         model:add(nn.LeakyReLU(0.1))
      end
      -- an unpadded k x k convolution removes (k - 1) / 2 pixels per side;
      -- derive the offset so it cannot drift from the layer stack
      offset = offset + math.floor((k - 1) / 2)
   end
   model:add(nn.View(-1):setNumInputDims(3))

   return model, offset
end
|
|
|
|
-- Export the module table holding the model constructors.
return srcnn
|