
lbp loss: multilayer support

nagadomi 2018-11-01 00:27:11 +09:00
parent 804896276e
commit b3fb286258
2 changed files with 35 additions and 9 deletions


@@ -1,21 +1,35 @@
 local RandomBinaryCriterion, parent = torch.class('w2nn.RandomBinaryCriterion','nn.Criterion')
-local function create_filters(ch, n, k)
-   local filter = w2nn.RandomBinaryConvolution(ch, n, k, k)
-   -- channel identity
-   for i = 1, ch do
-      filter.weight[i]:fill(0)
-      filter.weight[i][i][math.floor(k/2)+1][math.floor(k/2)+1] = 1
+local function create_filters(ch, n, k, layers)
+   local model = nn.Sequential()
+   for i = 1, layers do
+      local n_input = ch
+      if i > 1 then
+         n_input = n
+      end
+      local filter = w2nn.RandomBinaryConvolution(n_input, n, k, k)
+      if i == 1 then
+         -- channel identity
+         for j = 1, ch do
+            filter.weight[j]:fill(0)
+            filter.weight[j][j][math.floor(k/2)+1][math.floor(k/2)+1] = 1
+         end
+      end
+      model:add(filter)
+      --if layers > 1 and i ~= layers then
+      --   model:add(nn.Sigmoid(true))
+      --end
    end
-   return filter
+   return model
 end
-function RandomBinaryCriterion:__init(ch, n, k)
+function RandomBinaryCriterion:__init(ch, n, k, layers)
    parent.__init(self)
+   self.layers = layers or 1
    self.gamma = 0.1
    self.n = n or 32
    self.k = k or 3
    self.ch = ch
-   self.filter1 = create_filters(self.ch, self.n, self.k)
+   self.filter1 = create_filters(self.ch, self.n, self.k, self.layers)
    self.filter2 = self.filter1:clone()
    self.diff = torch.Tensor()
    self.diff_abs = torch.Tensor()
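
A minimal usage sketch of the extended constructor (not part of the commit). It assumes waifu2x's lib/ directory is on package.path so that require 'w2nn' resolves, and that a CUDA device is available, as the :cuda() calls in train.lua below imply; tensor shapes are illustrative only.

-- sketch: construct the new multilayer LBP criterion and evaluate it once
require 'w2nn'

-- the 4th argument is the new `layers` parameter; it defaults to 1 ("layers or 1"),
-- so the existing "lbp" loss is unchanged. Here: RGB input, 128 random binary
-- filters, 3x3 kernels, 2 stacked RandomBinaryConvolution layers.
local criterion = w2nn.RandomBinaryCriterion(3, 128, 3, 2):cuda()

-- compare filter-bank responses of a reconstruction and its ground truth
-- (batch x channels x height x width; values here are random placeholders)
local output = torch.CudaTensor(4, 3, 64, 64):uniform()
local target = torch.CudaTensor(4, 3, 64, 64):uniform()
print(criterion:forward(output, target))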


@@ -394,12 +394,24 @@ local function create_criterion(model)
       else
          return w2nn.RandomBinaryCriterion(1, 128):cuda()
       end
+   elseif settings.loss == "lbp2" then
+      if reconstruct.is_rgb(model) then
+         return w2nn.RandomBinaryCriterion(3, 128, 3, 2):cuda()
+      else
+         return w2nn.RandomBinaryCriterion(1, 128, 3, 2):cuda()
+      end
    elseif settings.loss == "aux_lbp" then
       if reconstruct.is_rgb(model) then
          return w2nn.AuxiliaryLossCriterion(w2nn.RandomBinaryCriterion, {3, 128}):cuda()
       else
          return w2nn.AuxiliaryLossCriterion(w2nn.RandomBinaryCriterion, {1, 128}):cuda()
       end
+   elseif settings.loss == "aux_lbp2" then
+      if reconstruct.is_rgb(model) then
+         return w2nn.AuxiliaryLossCriterion(w2nn.RandomBinaryCriterion, {3, 128, 3, 2}):cuda()
+      else
+         return w2nn.AuxiliaryLossCriterion(w2nn.RandomBinaryCriterion, {1, 128, 3, 2}):cuda()
+      end
    else
       error("unsupported loss .." .. settings.loss)
    end
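
For the aux_* variants above, the same constructor arguments are passed through AuxiliaryLossCriterion rather than used directly. A construction-only sketch; that the wrapper builds one RandomBinaryCriterion per model output from the class and argument table is an assumption based on this call site, not on AuxiliaryLossCriterion's implementation.

-- sketch: aux_lbp2 wiring, construction only (same assumptions as the sketch above)
require 'w2nn'
local aux = w2nn.AuxiliaryLossCriterion(w2nn.RandomBinaryCriterion, {3, 128, 3, 2}):cuda()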