Add L1 criterion. Change the criterion used for deciding when to update the saved model.
This commit is contained in:
parent
3c28debad9
commit
d2cfb8f104
27
lib/L1Criterion.lua
Normal file
27
lib/L1Criterion.lua
Normal file
|
@ -0,0 +1,27 @@
|
|||
-- ref: https://en.wikipedia.org/wiki/L1_loss
-- L1 (mean absolute error) criterion, registered as w2nn.L1Criterion.
local L1Criterion, parent = torch.class('w2nn.L1Criterion', 'nn.Criterion')

--- Set up reusable scratch tensors so each forward/backward pass
-- avoids reallocating buffers.
function L1Criterion:__init()
   parent.__init(self)
   -- holds |input - target| during updateOutput
   self.linear_loss_buff = torch.Tensor()
   -- holds input - target; reused by updateGradInput
   self.diff = torch.Tensor()
end
|
||||
--- Forward pass: mean absolute error over all elements of `input`.
-- Stores input - target in self.diff (consumed later by updateGradInput)
-- and returns sum(|diff|) / input:nElement().
-- NOTE(review): in the dim() == 1 branch only diff[1] is reduced by
-- `target` (target looks like a scalar there); elements 2..n keep raw
-- input values — confirm callers pass single-element tensors in that case.
function L1Criterion:updateOutput(input, target)
   self.diff:resizeAs(input):copy(input)
   if input:dim() == 1 then
      -- scalar-target path: adjust the first element only
      self.diff[1] = input[1] - target
   else
      -- row-wise subtraction: diff[row] = input[row] - target[row]
      for row = 1, input:size(1) do
         self.diff[row]:add(-1, target[row])
      end
   end
   -- L1 loss: sum of absolute deviations, averaged per element
   local abs_sum = self.linear_loss_buff:resizeAs(self.diff):copy(self.diff):abs():sum()
   self.output = abs_sum / input:nElement()
   return self.output
end
|
||||
--- Backward pass: d(L1)/d(input) = sign(input - target) / nElement.
-- Relies on self.diff as computed by the preceding updateOutput call;
-- `target` is unused here (kept for the nn.Criterion interface).
function L1Criterion:updateGradInput(input, target)
   local scale = 1.0 / input:nElement()
   self.gradInput:resizeAs(self.diff):copy(self.diff):sign():mul(scale)
   return self.gradInput
end
|
|
@ -75,6 +75,7 @@ cmd:option("-learning_rate_decay", 3.0e-7, 'learning rate decay (learning_rate *
|
|||
cmd:option("-resume", "", 'resume model file')
|
||||
cmd:option("-name", "user", 'model name for user method')
|
||||
cmd:option("-gpu", 1, 'Device ID')
|
||||
cmd:option("-loss", "huber", 'loss function (huber|l1)')
|
||||
|
||||
local function to_bool(settings, name)
|
||||
if settings[name] == 1 then
|
||||
|
|
|
@ -32,5 +32,6 @@ else
|
|||
require 'ClippedMSECriterion'
|
||||
require 'SSIMCriterion'
|
||||
require 'InplaceClip01'
|
||||
require 'L1Criterion'
|
||||
return w2nn
|
||||
end
|
||||
|
|
12
train.lua
12
train.lua
|
@ -298,6 +298,7 @@ local function validate(model, criterion, eval_metric, data, batch_size)
|
|||
end
|
||||
|
||||
local function create_criterion(model)
|
||||
if settings.loss == "huber" then
|
||||
if reconstruct.is_rgb(model) then
|
||||
local offset = reconstruct.offset_size(model)
|
||||
local output_w = settings.crop_size - offset * 2
|
||||
|
@ -313,6 +314,11 @@ local function create_criterion(model)
|
|||
weight[1]:fill(1.0)
|
||||
return w2nn.ClippedWeightedHuberCriterion(weight, 0.1, {0.0, 1.0}):cuda()
|
||||
end
|
||||
elseif settings.loss == "l1" then
|
||||
return w2nn.L1Criterion():cuda()
|
||||
else
|
||||
error("unsupported loss .." .. settings.loss)
|
||||
end
|
||||
end
|
||||
|
||||
local function resampling(x, y, train_x)
|
||||
|
@ -518,9 +524,9 @@ local function train()
|
|||
if settings.plot then
|
||||
plot(hist_train, hist_valid)
|
||||
end
|
||||
if score.MSE < best_score then
|
||||
if score.loss < best_score then
|
||||
local test_image = image_loader.load_float(settings.test) -- reload
|
||||
best_score = score.MSE
|
||||
best_score = score.loss
|
||||
print("* model has updated")
|
||||
if settings.save_history then
|
||||
torch.save(settings.model_file_best, model:clearState(), "ascii")
|
||||
|
@ -569,7 +575,7 @@ local function train()
|
|||
end
|
||||
end
|
||||
end
|
||||
print("Batch-wise PSNR: " .. score.PSNR .. ", loss: " .. score.loss .. ", MSE: " .. score.MSE .. ", Minimum MSE: " .. best_score)
|
||||
print("Batch-wise PSNR: " .. score.PSNR .. ", loss: " .. score.loss .. ", Minimum loss: " .. best_score .. ", MSE: " .. score.MSE)
|
||||
collectgarbage()
|
||||
end
|
||||
end
|
||||
|
|
Loading…
Reference in a new issue