1
0
Fork 0
mirror of synced 2024-05-16 10:52:20 +12:00

Add ShakeShakeTable

This commit is contained in:
nagadomi 2017-04-13 17:35:32 +09:00
parent f54dd37848
commit 88e3322296
2 changed files with 43 additions and 0 deletions

42
lib/ShakeShakeTable.lua Normal file
View file

@ -0,0 +1,42 @@
local ShakeShakeTable, parent = torch.class('w2nn.ShakeShakeTable','nn.Module')
function ShakeShakeTable:__init()
parent.__init(self)
self.alpha = torch.Tensor()
self.beta = torch.Tensor()
self.first = torch.Tensor()
self.second = torch.Tensor()
self.train = true
end
function ShakeShakeTable:updateOutput(input)
local batch_size = input[1]:size(1)
if self.train then
self.alpha:resize(batch_size):uniform()
self.beta:resize(batch_size):uniform()
self.second:resizeAs(input[1]):copy(input[2])
for i = 1, batch_size do
self.second[i]:mul(self.alpha[i])
end
self.output:resizeAs(input[1]):copy(input[1])
for i = 1, batch_size do
self.output[i]:mul(1.0 - self.alpha[i])
end
self.output:add(self.second):mul(2)
else
self.output:resizeAs(input[1]):copy(input[1]):add(input[2])
end
return self.output
end
function ShakeShakeTable:updateGradInput(input, gradOutput)
local batch_size = input[1]:size(1)
self.first:resizeAs(gradOutput):copy(gradOutput)
for i = 1, batch_size do
self.first[i]:mul(self.beta[i])
end
self.second:resizeAs(gradOutput):copy(gradOutput)
for i = 1, batch_size do
self.second[i]:mul(1.0 - self.beta[i])
end
self.gradOutput = {self.first, self.second}
return self.gradOutput
end

View file

@ -74,5 +74,6 @@ else
require 'SSIMCriterion'
require 'InplaceClip01'
require 'L1Criterion'
require 'ShakeShakeTable'
return w2nn
end