1
0
Fork 0
mirror of synced 2024-06-02 11:04:31 +12:00

Use PSNR for evaluation

This commit is contained in:
nagadomi 2016-03-12 06:53:42 +09:00
parent 41581a0d55
commit 1900ac7500
4 changed files with 29 additions and 7 deletions

19
lib/PSNRCriterion.lua Normal file
View file

@@ -0,0 +1,19 @@
-- PSNR evaluation criterion for w2nn (forward pass only; it rejects backward).
local PSNRCriterion, parent = torch.class('w2nn.PSNRCriterion','nn.Criterion')

--- Initialize the criterion.
-- Allocates two reusable scratch tensors up front so that repeated
-- forward passes do not re-allocate storage on every call.
function PSNRCriterion:__init()
   parent.__init(self)
   self.diff = torch.Tensor()  -- scratch: (clamped prediction) - target
   self.image = torch.Tensor() -- scratch: clamped copy of the prediction
end
--- Compute the PSNR (in dB) between a predicted image and its target.
-- The prediction is clamped to [0, 1] before comparison; the peak signal
-- value is therefore assumed to be 1.0. The target is used as-is
-- (NOTE(review): presumably already in [0, 1] — confirm against callers).
-- @param input  predicted image tensor
-- @param target ground-truth image tensor, same shape as input
-- @return PSNR in dB (math.huge when the images are identical)
function PSNRCriterion:updateOutput(input, target)
   self.image:resizeAs(input):copy(input)
   self.image:clamp(0.0, 1.0)
   self.diff:resizeAs(self.image):copy(self.image)
   local mse = self.diff:add(-1, target):pow(2):mean()
   if mse == 0 then
      -- Identical images: PSNR is unbounded. Make that explicit rather
      -- than relying on IEEE division by zero to produce +inf.
      self.output = math.huge
   else
      self.output = 10 * math.log10(1.0 / mse)
   end
   return self.output
end
--- Backward pass is intentionally unsupported.
-- PSNR is used purely as an evaluation metric in this codebase, so any
-- attempt to backpropagate through it is rejected with an error.
function PSNRCriterion:updateGradInput(input, target)
error("PSNRCriterion does not support backward")
end

View file

@@ -2,12 +2,13 @@ require 'optim'
require 'cutorch'
require 'xlua'
local function minibatch_adam(model, criterion,
local function minibatch_adam(model, criterion, eval_metric,
train_x, train_y,
config)
local parameters, gradParameters = model:getParameters()
config = config or {}
local sum_loss = 0
local sum_eval = 0
local count_loss = 0
local batch_size = config.xBatchSize or 32
local shuffle = torch.randperm(train_x:size(1))
@@ -39,6 +40,7 @@ local function minibatch_adam(model, criterion,
gradParameters:zero()
local output = model:forward(inputs)
local f = criterion:forward(output, targets)
sum_eval = sum_eval + eval_metric:forward(output, targets)
sum_loss = sum_loss + f
count_loss = count_loss + 1
model:backward(inputs, criterion:backward(output, targets))
@@ -52,7 +54,7 @@ local function minibatch_adam(model, criterion,
end
xlua.progress(train_x:size(1), train_x:size(1))
return { loss = sum_loss / count_loss}
return { loss = sum_loss / count_loss, PSNR = sum_eval / count_loss}
end
return minibatch_adam

View file

@@ -19,7 +19,7 @@ else
require 'LeakyReLU'
require 'LeakyReLU_deprecated'
require 'DepthExpand2x'
require 'WeightedMSECriterion'
require 'PSNRCriterion'
require 'ClippedWeightedHuberCriterion'
return w2nn
end

View file

@@ -166,6 +166,7 @@ local function train()
return transformer(x, is_validation, n, offset)
end
local criterion = create_criterion(model)
local eval_metric = w2nn.PSNRCriterion():cuda()
local x = torch.load(settings.images)
local train_x, valid_x = split_data(x, math.floor(settings.validation_rate * #x))
local adam_config = {
@@ -179,7 +180,7 @@ local function train()
elseif settings.color == "rgb" then
ch = 3
end
local best_score = 100000.0
local best_score = 0.0
print("# make validation-set")
local valid_xy = make_validation_set(valid_x, pairwise_func,
settings.validation_crops,
@@ -200,11 +201,11 @@ local function train()
print("# " .. epoch)
resampling(x, y, train_x, pairwise_func)
for i = 1, settings.inner_epoch do
print(minibatch_adam(model, criterion, x, y, adam_config))
print(minibatch_adam(model, criterion, eval_metric, x, y, adam_config))
model:evaluate()
print("# validation")
local score = validate(model, criterion, valid_xy)
if score < best_score then
local score = validate(model, eval_metric, valid_xy)
if score > best_score then
local test_image = image_loader.load_float(settings.test) -- reload
lrd_count = 0
best_score = score