1
0
Fork 0
mirror of synced 2024-06-02 19:14:30 +12:00

More clearState for nn.SpatialConvolutionMM

This commit is contained in:
nagadomi 2016-03-12 07:23:42 +09:00
parent 223dcead67
commit 4a1629d046

View file

@ -17,6 +17,16 @@ if cudnn and cudnn.SpatialConvolution then
end
end
function nn.SpatialConvolutionMM:clearState()
if self.gradWeight then
self.gradWeight = torch.Tensor(self.nOutputPlane, self.nInputPlane * self.kH * self.kW):typeAs(self.gradWeight):zero()
end
if self.gradBias then
self.gradBias = torch.Tensor(self.nOutputPlane):typeAs(self.gradBias):zero()
end
return nn.utils.clear(self, 'finput', 'fgradInput', '_input', '_gradOutput', 'output', 'gradInput')
end
function srcnn.channels(model)
return model:get(model:size() - 1).weight:size(1)
end