1
0
Fork 0
mirror of synced 2024-06-02 11:04:31 +12:00
waifu2x/lib/RandomBinaryConvolution.lua
2018-10-14 01:18:32 +09:00

29 lines
1.1 KiB
Lua

-- RandomBinaryConvolution.lua
-- from https://github.com/juefeix/lbcnn.torch
local THNN = require 'nn.THNN'
local RandomBinaryConvolution, parent = torch.class('w2nn.RandomBinaryConvolution', 'nn.SpatialConvolution')
function RandomBinaryConvolution:__init(nInputPlane, nOutputPlane, kW, kH, kSparsity)
self.kSparsity = kSparsity or 0.9
parent.__init(self, nInputPlane, nOutputPlane, kW, kH, 1, 1, 0, 0)
self:reset()
end
function RandomBinaryConvolution:reset()
local numElements = self.nInputPlane*self.nOutputPlane*self.kW*self.kH
self.weight:fill(0)
self.weight = torch.reshape(self.weight,numElements)
local index = torch.Tensor(torch.floor(self.kSparsity*numElements)):random(numElements)
for i = 1, index:numel() do
self.weight[index[i]] = torch.bernoulli(0.5)*2-1
end
self.weight = torch.reshape(self.weight,self.nOutputPlane,self.nInputPlane,self.kW,self.kH)
self.bias = nil
self.gradBias = nil
self.gradWeight:fill(0)
end
function RandomBinaryConvolution:accGradParameters(input, gradOutput, scale)
end
function RandomBinaryConvolution:updateParameters(learningRate)
end