lbp loss: multilayer support
This commit is contained in:
parent
804896276e
commit
b3fb286258
|
@ -1,21 +1,35 @@
|
|||
local RandomBinaryCriterion, parent = torch.class('w2nn.RandomBinaryCriterion','nn.Criterion')
|
||||
|
||||
local function create_filters(ch, n, k)
|
||||
local filter = w2nn.RandomBinaryConvolution(ch, n, k, k)
|
||||
-- channel identify
|
||||
for i = 1, ch do
|
||||
filter.weight[i]:fill(0)
|
||||
filter.weight[i][i][math.floor(k/2)+1][math.floor(k/2)+1] = 1
|
||||
local function create_filters(ch, n, k, layers)
|
||||
local model = nn.Sequential()
|
||||
for i = 1, layers do
|
||||
local n_input = ch
|
||||
if i > 1 then
|
||||
n_input = n
|
||||
end
|
||||
local filter = w2nn.RandomBinaryConvolution(n_input, n, k, k)
|
||||
if i == 1 then
|
||||
-- channel identity
|
||||
for j = 1, ch do
|
||||
filter.weight[i]:fill(0)
|
||||
filter.weight[i][i][math.floor(k/2)+1][math.floor(k/2)+1] = 1
|
||||
end
|
||||
end
|
||||
model:add(filter)
|
||||
--if layers > 1 and i ~= layers then
|
||||
-- model:add(nn.Sigmoid(true))
|
||||
--end
|
||||
end
|
||||
return filter
|
||||
return model
|
||||
end
|
||||
function RandomBinaryCriterion:__init(ch, n, k)
|
||||
function RandomBinaryCriterion:__init(ch, n, k, layers)
|
||||
parent.__init(self)
|
||||
self.layers = layers or 1
|
||||
self.gamma = 0.1
|
||||
self.n = n or 32
|
||||
self.k = k or 3
|
||||
self.ch = ch
|
||||
self.filter1 = create_filters(self.ch, self.n, self.k)
|
||||
self.filter1 = create_filters(self.ch, self.n, self.k, self.layers)
|
||||
self.filter2 = self.filter1:clone()
|
||||
self.diff = torch.Tensor()
|
||||
self.diff_abs = torch.Tensor()
|
||||
|
|
12
train.lua
12
train.lua
|
@ -394,12 +394,24 @@ local function create_criterion(model)
|
|||
else
|
||||
return w2nn.RandomBinaryCriterion(1, 128):cuda()
|
||||
end
|
||||
elseif settings.loss == "lbp2" then
|
||||
if reconstruct.is_rgb(model) then
|
||||
return w2nn.RandomBinaryCriterion(3, 128, 3, 2):cuda()
|
||||
else
|
||||
return w2nn.RandomBinaryCriterion(1, 128, 3, 2):cuda()
|
||||
end
|
||||
elseif settings.loss == "aux_lbp" then
|
||||
if reconstruct.is_rgb(model) then
|
||||
return w2nn.AuxiliaryLossCriterion(w2nn.RandomBinaryCriterion, {3, 128}):cuda()
|
||||
else
|
||||
return w2nn.AuxiliaryLossCriterion(w2nn.RandomBinaryCriterion, {1, 128}):cuda()
|
||||
end
|
||||
elseif settings.loss == "aux_lbp2" then
|
||||
if reconstruct.is_rgb(model) then
|
||||
return w2nn.AuxiliaryLossCriterion(w2nn.RandomBinaryCriterion, {3, 128, 3, 2}):cuda()
|
||||
else
|
||||
return w2nn.AuxiliaryLossCriterion(w2nn.RandomBinaryCriterion, {1, 128, 3, 2}):cuda()
|
||||
end
|
||||
else
|
||||
error("unsupported loss .." .. settings.loss)
|
||||
end
|
||||
|
|
Loading…
Reference in a new issue