aux_huber and lbp loss
This commit is contained in:
parent
5365890fa8
commit
f317545732
|
@ -1,9 +1,11 @@
|
|||
local Print, parent = torch.class('w2nn.Print','nn.Module')
|
||||
|
||||
function Print:__init()
|
||||
function Print:__init(id)
|
||||
parent.__init(self)
|
||||
self.id = id
|
||||
end
|
||||
function Print:updateOutput(input)
|
||||
print("----", self.id)
|
||||
print(input:size())
|
||||
self.output:resizeAs(input)
|
||||
self.output:copy(input)
|
||||
|
|
29
train.lua
29
train.lua
|
@ -369,6 +369,31 @@ local function create_criterion(model)
|
|||
local aux = w2nn.AuxiliaryLossCriterion(nn.BCECriterion)
|
||||
aux.sizeAverage = true
|
||||
return aux:cuda()
|
||||
elseif settings.loss == "aux_huber" then
|
||||
local args = {}
|
||||
if reconstruct.is_rgb(model) then
|
||||
local offset = reconstruct.offset_size(model)
|
||||
local output_w = settings.crop_size - offset * 2
|
||||
local weight = torch.Tensor(3, output_w * output_w)
|
||||
weight[1]:fill(0.29891 * 3) -- R
|
||||
weight[2]:fill(0.58661 * 3) -- G
|
||||
weight[3]:fill(0.11448 * 3) -- B
|
||||
args = {weight, 0.1, {0.0, 1.0}}
|
||||
else
|
||||
local offset = reconstruct.offset_size(model)
|
||||
local output_w = settings.crop_size - offset * 2
|
||||
local weight = torch.Tensor(1, output_w * output_w)
|
||||
weight[1]:fill(1.0)
|
||||
args = {weight, 0.1, {0.0, 1.0}}
|
||||
end
|
||||
local aux = w2nn.AuxiliaryLossCriterion(w2nn.ClippedWeightedHuberCriterion, args)
|
||||
return aux:cuda()
|
||||
elseif settings.loss == "lbp" then
|
||||
if reconstruct.is_rgb(model) then
|
||||
return w2nn.RandomBinaryCriterion(3, 512):cuda()
|
||||
else
|
||||
return w2nn.RandomBinaryCriterion(1, 512):cuda()
|
||||
end
|
||||
else
|
||||
error("unsupported loss .." .. settings.loss)
|
||||
end
|
||||
|
@ -506,9 +531,9 @@ local function train()
|
|||
local criterion = create_criterion(model)
|
||||
local eval_metric = nil
|
||||
if settings.loss:find("aux_") ~= nil then
|
||||
eval_metric = w2nn.AuxiliaryLossCriterion(w2nn.ClippedMSECriterion)
|
||||
eval_metric = w2nn.AuxiliaryLossCriterion(w2nn.ClippedMSECriterion):cuda()
|
||||
else
|
||||
eval_metric = w2nn.ClippedMSECriterion()
|
||||
eval_metric = w2nn.ClippedMSECriterion():cuda()
|
||||
end
|
||||
local adam_config = {
|
||||
xLearningRate = settings.learning_rate,
|
||||
|
|
Loading…
Reference in a new issue