
Add L1 criterion. Change the criterion for updating the model

author nagadomi, 2016-12-05 10:32:26 +09:00
parent 3c28debad9
commit d2cfb8f104
4 changed files with 51 additions and 16 deletions

lib/L1Criterion.lua (new file, +27)

@@ -0,0 +1,27 @@
-- ref: https://en.wikipedia.org/wiki/L1_loss
local L1Criterion, parent = torch.class('w2nn.L1Criterion','nn.Criterion')
function L1Criterion:__init()
   parent.__init(self)
   self.diff = torch.Tensor()
   self.linear_loss_buff = torch.Tensor()
end
function L1Criterion:updateOutput(input, target)
   self.diff:resizeAs(input):copy(input)
   if input:dim() == 1 then
      self.diff[1] = input[1] - target
   else
      for i = 1, input:size(1) do
         self.diff[i]:add(-1, target[i])
      end
   end
   local linear_targets = self.diff
   local linear_loss = self.linear_loss_buff:resizeAs(linear_targets):copy(linear_targets):abs():sum()
   self.output = (linear_loss) / input:nElement()
   return self.output
end
function L1Criterion:updateGradInput(input, target)
   local norm = 1.0 / input:nElement()
   self.gradInput:resizeAs(self.diff):copy(self.diff):sign():mul(norm)
   return self.gradInput
end
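
A minimal usage sketch of the new criterion (not part of the commit). It assumes a Torch7 session in which the repository's w2nn module has been loaded, so that w2nn.L1Criterion is registered via the require added further down; the tensor shapes are arbitrary.

require 'w2nn'
local criterion = w2nn.L1Criterion()
local input  = torch.Tensor(4, 16):uniform()  -- e.g. a batch of 4 flattened patches
local target = torch.Tensor(4, 16):uniform()
local loss = criterion:forward(input, target)  -- sum(|input - target|) / nElement
local grad = criterion:backward(input, target) -- sign(input - target) / nElement
print(loss, grad:size())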

(modified file)

@@ -75,6 +75,7 @@ cmd:option("-learning_rate_decay", 3.0e-7, 'learning rate decay (learning_rate *
 cmd:option("-resume", "", 'resume model file')
 cmd:option("-name", "user", 'model name for user method')
 cmd:option("-gpu", 1, 'Device ID')
+cmd:option("-loss", "huber", 'loss function (huber|l1)')
 local function to_bool(settings, name)
    if settings[name] == 1 then
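
With this flag, settings.loss defaults to "huber", so existing behaviour is unchanged and the L1 loss has to be requested explicitly at launch time (e.g. by passing -loss l1). Below is a small sketch of how torch.CmdLine consumes such an option; the parse call with a literal argument table is only for illustration and is not code from this commit.

require 'torch'
local cmd = torch.CmdLine()
cmd:option("-loss", "huber", 'loss function (huber|l1)')
local parsed = cmd:parse({"-loss", "l1"})
print(parsed.loss) -- "l1"; with an empty argument table it would print the default "huber"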

(modified file)

@@ -32,5 +32,6 @@ else
    require 'ClippedMSECriterion'
    require 'SSIMCriterion'
    require 'InplaceClip01'
+   require 'L1Criterion'
    return w2nn
 end

(modified file)

@@ -298,20 +298,26 @@ local function validate(model, criterion, eval_metric, data, batch_size)
 end
 local function create_criterion(model)
-   if reconstruct.is_rgb(model) then
-      local offset = reconstruct.offset_size(model)
-      local output_w = settings.crop_size - offset * 2
-      local weight = torch.Tensor(3, output_w * output_w)
-      weight[1]:fill(0.29891 * 3) -- R
-      weight[2]:fill(0.58661 * 3) -- G
-      weight[3]:fill(0.11448 * 3) -- B
-      return w2nn.ClippedWeightedHuberCriterion(weight, 0.1, {0.0, 1.0}):cuda()
+   if settings.loss == "huber" then
+      if reconstruct.is_rgb(model) then
+         local offset = reconstruct.offset_size(model)
+         local output_w = settings.crop_size - offset * 2
+         local weight = torch.Tensor(3, output_w * output_w)
+         weight[1]:fill(0.29891 * 3) -- R
+         weight[2]:fill(0.58661 * 3) -- G
+         weight[3]:fill(0.11448 * 3) -- B
+         return w2nn.ClippedWeightedHuberCriterion(weight, 0.1, {0.0, 1.0}):cuda()
+      else
+         local offset = reconstruct.offset_size(model)
+         local output_w = settings.crop_size - offset * 2
+         local weight = torch.Tensor(1, output_w * output_w)
+         weight[1]:fill(1.0)
+         return w2nn.ClippedWeightedHuberCriterion(weight, 0.1, {0.0, 1.0}):cuda()
+      end
+   elseif settings.loss == "l1" then
+      return w2nn.L1Criterion():cuda()
    else
-      local offset = reconstruct.offset_size(model)
-      local output_w = settings.crop_size - offset * 2
-      local weight = torch.Tensor(1, output_w * output_w)
-      weight[1]:fill(1.0)
-      return w2nn.ClippedWeightedHuberCriterion(weight, 0.1, {0.0, 1.0}):cuda()
+      error("unsupported loss .." .. settings.loss)
    end
 end
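
For intuition about what the new option trades off: a Huber loss is quadratic for small residuals and linear beyond a threshold, while the L1 loss stays linear everywhere, so its gradient keeps a constant magnitude even for tiny errors. The sketch below compares a plain Huber loss with delta = 0.1 against |r| for a few scalar residuals; the plain Huber form is an assumption for illustration only, since ClippedWeightedHuberCriterion as called above also takes a per-channel weight tensor and a {0.0, 1.0} range argument that this sketch ignores.

local delta = 0.1
local function huber(r)
   r = math.abs(r)
   if r <= delta then
      return 0.5 * r * r
   else
      return delta * (r - 0.5 * delta)
   end
end
for _, r in ipairs({0.01, 0.05, 0.2, 0.5}) do
   print(string.format("r=%.2f  huber=%.5f  l1=%.5f", r, huber(r), math.abs(r)))
end
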
@@ -518,9 +524,9 @@ local function train()
       if settings.plot then
          plot(hist_train, hist_valid)
       end
-      if score.MSE < best_score then
+      if score.loss < best_score then
         local test_image = image_loader.load_float(settings.test) -- reload
-         best_score = score.MSE
+         best_score = score.loss
         print("* model has updated")
         if settings.save_history then
            torch.save(settings.model_file_best, model:clearState(), "ascii")
@@ -569,7 +575,7 @@ local function train()
             end
          end
       end
-      print("Batch-wise PSNR: " .. score.PSNR .. ", loss: " .. score.loss .. ", MSE: " .. score.MSE .. ", Minimum MSE: " .. best_score)
+      print("Batch-wise PSNR: " .. score.PSNR .. ", loss: " .. score.loss .. ", Minimum loss: " .. best_score .. ", MSE: " .. score.MSE)
       collectgarbage()
    end
 end
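
The last two hunks switch the rule for saving the "best" model from the lowest validation MSE to the lowest value of the selected loss, and extend the log line to report that minimum loss alongside MSE. A condensed sketch of the resulting rule (a paraphrase, not code from the repository):

local best_score = math.huge
local function update_best(score, save_model)
   if score.loss < best_score then
      best_score = score.loss
      save_model()  -- in the diff: torch.save(settings.model_file_best, model:clearState(), "ascii")
      print("* model has updated")
   end
end

update_best({loss = 0.012}, function() end) -- first call always updates
update_best({loss = 0.020}, function() end) -- higher loss, no update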