diff --git a/lib/L1Criterion.lua b/lib/L1Criterion.lua new file mode 100644 index 0000000..c706ece --- /dev/null +++ b/lib/L1Criterion.lua @@ -0,0 +1,27 @@ +-- ref: https://en.wikipedia.org/wiki/L1_loss +local L1Criterion, parent = torch.class('w2nn.L1Criterion','nn.Criterion') + +function L1Criterion:__init() + parent.__init(self) + self.diff = torch.Tensor() + self.linear_loss_buff = torch.Tensor() +end +function L1Criterion:updateOutput(input, target) + self.diff:resizeAs(input):copy(input) + if input:dim() == 1 then + self.diff[1] = input[1] - target + else + for i = 1, input:size(1) do + self.diff[i]:add(-1, target[i]) + end + end + local linear_targets = self.diff + local linear_loss = self.linear_loss_buff:resizeAs(linear_targets):copy(linear_targets):abs():sum() + self.output = (linear_loss) / input:nElement() + return self.output +end +function L1Criterion:updateGradInput(input, target) + local norm = 1.0 / input:nElement() + self.gradInput:resizeAs(self.diff):copy(self.diff):sign():mul(norm) + return self.gradInput +end diff --git a/lib/settings.lua b/lib/settings.lua index 96ec7b7..ee3dcb3 100644 --- a/lib/settings.lua +++ b/lib/settings.lua @@ -75,6 +75,7 @@ cmd:option("-learning_rate_decay", 3.0e-7, 'learning rate decay (learning_rate * cmd:option("-resume", "", 'resume model file') cmd:option("-name", "user", 'model name for user method') cmd:option("-gpu", 1, 'Device ID') +cmd:option("-loss", "huber", 'loss function (huber|l1)') local function to_bool(settings, name) if settings[name] == 1 then diff --git a/lib/w2nn.lua b/lib/w2nn.lua index 6d047d7..a7f93e4 100644 --- a/lib/w2nn.lua +++ b/lib/w2nn.lua @@ -32,5 +32,6 @@ else require 'ClippedMSECriterion' require 'SSIMCriterion' require 'InplaceClip01' + require 'L1Criterion' return w2nn end diff --git a/train.lua b/train.lua index 43777b9..cf6850e 100644 --- a/train.lua +++ b/train.lua @@ -298,20 +298,26 @@ local function validate(model, criterion, eval_metric, data, batch_size) end local function create_criterion(model) - if reconstruct.is_rgb(model) then - local offset = reconstruct.offset_size(model) - local output_w = settings.crop_size - offset * 2 - local weight = torch.Tensor(3, output_w * output_w) - weight[1]:fill(0.29891 * 3) -- R - weight[2]:fill(0.58661 * 3) -- G - weight[3]:fill(0.11448 * 3) -- B - return w2nn.ClippedWeightedHuberCriterion(weight, 0.1, {0.0, 1.0}):cuda() + if settings.loss == "huber" then + if reconstruct.is_rgb(model) then + local offset = reconstruct.offset_size(model) + local output_w = settings.crop_size - offset * 2 + local weight = torch.Tensor(3, output_w * output_w) + weight[1]:fill(0.29891 * 3) -- R + weight[2]:fill(0.58661 * 3) -- G + weight[3]:fill(0.11448 * 3) -- B + return w2nn.ClippedWeightedHuberCriterion(weight, 0.1, {0.0, 1.0}):cuda() + else + local offset = reconstruct.offset_size(model) + local output_w = settings.crop_size - offset * 2 + local weight = torch.Tensor(1, output_w * output_w) + weight[1]:fill(1.0) + return w2nn.ClippedWeightedHuberCriterion(weight, 0.1, {0.0, 1.0}):cuda() + end + elseif settings.loss == "l1" then + return w2nn.L1Criterion():cuda() else - local offset = reconstruct.offset_size(model) - local output_w = settings.crop_size - offset * 2 - local weight = torch.Tensor(1, output_w * output_w) - weight[1]:fill(1.0) - return w2nn.ClippedWeightedHuberCriterion(weight, 0.1, {0.0, 1.0}):cuda() + error("unsupported loss .." .. settings.loss) end end @@ -518,9 +524,9 @@ local function train() if settings.plot then plot(hist_train, hist_valid) end - if score.MSE < best_score then + if score.loss < best_score then local test_image = image_loader.load_float(settings.test) -- reload - best_score = score.MSE + best_score = score.loss print("* model has updated") if settings.save_history then torch.save(settings.model_file_best, model:clearState(), "ascii") @@ -569,7 +575,7 @@ local function train() end end end - print("Batch-wise PSNR: " .. score.PSNR .. ", loss: " .. score.loss .. ", MSE: " .. score.MSE .. ", Minimum MSE: " .. best_score) + print("Batch-wise PSNR: " .. score.PSNR .. ", loss: " .. score.loss .. ", Minimum loss: " .. best_score .. ", MSE: " .. score.MSE) collectgarbage() end end