1
0
Fork 0
mirror of synced 2024-05-16 19:02:21 +12:00

aux_huber and lbp loss

This commit is contained in:
nagadomi 2018-10-14 01:20:03 +09:00
parent 5365890fa8
commit f317545732
2 changed files with 30 additions and 3 deletions

View file

@ -1,9 +1,11 @@
local Print, parent = torch.class('w2nn.Print','nn.Module')
function Print:__init()
function Print:__init(id)
parent.__init(self)
self.id = id
end
function Print:updateOutput(input)
print("----", self.id)
print(input:size())
self.output:resizeAs(input)
self.output:copy(input)

View file

@ -369,6 +369,31 @@ local function create_criterion(model)
local aux = w2nn.AuxiliaryLossCriterion(nn.BCECriterion)
aux.sizeAverage = true
return aux:cuda()
elseif settings.loss == "aux_huber" then
local args = {}
if reconstruct.is_rgb(model) then
local offset = reconstruct.offset_size(model)
local output_w = settings.crop_size - offset * 2
local weight = torch.Tensor(3, output_w * output_w)
weight[1]:fill(0.29891 * 3) -- R
weight[2]:fill(0.58661 * 3) -- G
weight[3]:fill(0.11448 * 3) -- B
args = {weight, 0.1, {0.0, 1.0}}
else
local offset = reconstruct.offset_size(model)
local output_w = settings.crop_size - offset * 2
local weight = torch.Tensor(1, output_w * output_w)
weight[1]:fill(1.0)
args = {weight, 0.1, {0.0, 1.0}}
end
local aux = w2nn.AuxiliaryLossCriterion(w2nn.ClippedWeightedHuberCriterion, args)
return aux:cuda()
elseif settings.loss == "lbp" then
if reconstruct.is_rgb(model) then
return w2nn.RandomBinaryCriterion(3, 512):cuda()
else
return w2nn.RandomBinaryCriterion(1, 512):cuda()
end
else
error("unsupported loss .." .. settings.loss)
end
@ -506,9 +531,9 @@ local function train()
local criterion = create_criterion(model)
local eval_metric = nil
if settings.loss:find("aux_") ~= nil then
eval_metric = w2nn.AuxiliaryLossCriterion(w2nn.ClippedMSECriterion)
eval_metric = w2nn.AuxiliaryLossCriterion(w2nn.ClippedMSECriterion):cuda()
else
eval_metric = w2nn.ClippedMSECriterion()
eval_metric = w2nn.ClippedMSECriterion():cuda()
end
local adam_config = {
xLearningRate = settings.learning_rate,