diff --git a/lib/Print.lua b/lib/Print.lua
index 83a9ffd..c1f3ae2 100644
--- a/lib/Print.lua
+++ b/lib/Print.lua
@@ -1,9 +1,11 @@
 local Print, parent = torch.class('w2nn.Print','nn.Module')
-function Print:__init()
+function Print:__init(id)
    parent.__init(self)
+   self.id = id
 end
 function Print:updateOutput(input)
+   print("----", self.id)
    print(input:size())
    self.output:resizeAs(input)
    self.output:copy(input)
    return self.output
diff --git a/train.lua b/train.lua
index 2fe9091..641b06c 100644
--- a/train.lua
+++ b/train.lua
@@ -369,6 +369,31 @@ local function create_criterion(model)
       local aux = w2nn.AuxiliaryLossCriterion(nn.BCECriterion)
       aux.sizeAverage = true
       return aux:cuda()
+   elseif settings.loss == "aux_huber" then
+      local args = {}
+      if reconstruct.is_rgb(model) then
+         local offset = reconstruct.offset_size(model)
+         local output_w = settings.crop_size - offset * 2
+         local weight = torch.Tensor(3, output_w * output_w)
+         weight[1]:fill(0.29891 * 3) -- R
+         weight[2]:fill(0.58661 * 3) -- G
+         weight[3]:fill(0.11448 * 3) -- B
+         args = {weight, 0.1, {0.0, 1.0}}
+      else
+         local offset = reconstruct.offset_size(model)
+         local output_w = settings.crop_size - offset * 2
+         local weight = torch.Tensor(1, output_w * output_w)
+         weight[1]:fill(1.0)
+         args = {weight, 0.1, {0.0, 1.0}}
+      end
+      local aux = w2nn.AuxiliaryLossCriterion(w2nn.ClippedWeightedHuberCriterion, args)
+      return aux:cuda()
+   elseif settings.loss == "lbp" then
+      if reconstruct.is_rgb(model) then
+         return w2nn.RandomBinaryCriterion(3, 512):cuda()
+      else
+         return w2nn.RandomBinaryCriterion(1, 512):cuda()
+      end
    else
       error("unsupported loss .." .. settings.loss)
    end
@@ -506,9 +531,9 @@ local function train()
    local criterion = create_criterion(model)
    local eval_metric = nil
    if settings.loss:find("aux_") ~= nil then
-      eval_metric = w2nn.AuxiliaryLossCriterion(w2nn.ClippedMSECriterion)
+      eval_metric = w2nn.AuxiliaryLossCriterion(w2nn.ClippedMSECriterion):cuda()
    else
-      eval_metric = w2nn.ClippedMSECriterion()
+      eval_metric = w2nn.ClippedMSECriterion():cuda()
    end
    local adam_config = {
       xLearningRate = settings.learning_rate,
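
Note on the "aux_huber" branch: the per-channel weights are the BT.601 luma
coefficients (0.29891, 0.58661, 0.11448) scaled by 3, so they average to 1.0
over the three channels while penalizing errors in proportion to their
contribution to perceived luminance. A minimal standalone sketch of the same
construction, assuming w2nn.ClippedWeightedHuberCriterion takes
(weight, gamma, clamp_range) positionally as the args table above suggests,
and with a hypothetical output size in place of the crop/offset arithmetic:

   local output_w = 46 -- hypothetical stand-in for settings.crop_size - offset * 2
   local weight = torch.Tensor(3, output_w * output_w)
   weight[1]:fill(0.29891 * 3) -- R
   weight[2]:fill(0.58661 * 3) -- G
   weight[3]:fill(0.11448 * 3) -- B
   -- wrap the weighted Huber loss in the auxiliary-loss container;
   -- 0.1 is presumably the Huber threshold, {0.0, 1.0} the clamp range
   local criterion = w2nn.AuxiliaryLossCriterion(
      w2nn.ClippedWeightedHuberCriterion, {weight, 0.1, {0.0, 1.0}}):cuda()

Both new losses would presumably be selected at training time through
whichever command-line option populates settings.loss (e.g. -loss aux_huber
or -loss lbp).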