Use PSNR for evaluation
This commit is contained in:
parent
41581a0d55
commit
1900ac7500
19
lib/PSNRCriterion.lua
Normal file
19
lib/PSNRCriterion.lua
Normal file
|
@ -0,0 +1,19 @@
|
|||
local PSNRCriterion, parent = torch.class('w2nn.PSNRCriterion','nn.Criterion')
|
||||
|
||||
function PSNRCriterion:__init()
|
||||
parent.__init(self)
|
||||
self.image = torch.Tensor()
|
||||
self.diff = torch.Tensor()
|
||||
end
|
||||
function PSNRCriterion:updateOutput(input, target)
|
||||
self.image:resizeAs(input):copy(input)
|
||||
self.image:clamp(0.0, 1.0)
|
||||
self.diff:resizeAs(self.image):copy(self.image)
|
||||
|
||||
local mse = self.diff:add(-1, target):pow(2):mean()
|
||||
self.output = 10 * math.log10(1.0 / mse)
|
||||
return self.output
|
||||
end
|
||||
function PSNRCriterion:updateGradInput(input, target)
|
||||
error("PSNRCriterion does not support backward")
|
||||
end
|
|
@ -2,12 +2,13 @@ require 'optim'
|
|||
require 'cutorch'
|
||||
require 'xlua'
|
||||
|
||||
local function minibatch_adam(model, criterion,
|
||||
local function minibatch_adam(model, criterion, eval_metric,
|
||||
train_x, train_y,
|
||||
config)
|
||||
local parameters, gradParameters = model:getParameters()
|
||||
config = config or {}
|
||||
local sum_loss = 0
|
||||
local sum_eval = 0
|
||||
local count_loss = 0
|
||||
local batch_size = config.xBatchSize or 32
|
||||
local shuffle = torch.randperm(train_x:size(1))
|
||||
|
@ -39,6 +40,7 @@ local function minibatch_adam(model, criterion,
|
|||
gradParameters:zero()
|
||||
local output = model:forward(inputs)
|
||||
local f = criterion:forward(output, targets)
|
||||
sum_eval = sum_eval + eval_metric:forward(output, targets)
|
||||
sum_loss = sum_loss + f
|
||||
count_loss = count_loss + 1
|
||||
model:backward(inputs, criterion:backward(output, targets))
|
||||
|
@ -52,7 +54,7 @@ local function minibatch_adam(model, criterion,
|
|||
end
|
||||
xlua.progress(train_x:size(1), train_x:size(1))
|
||||
|
||||
return { loss = sum_loss / count_loss}
|
||||
return { loss = sum_loss / count_loss, PSNR = sum_eval / count_loss}
|
||||
end
|
||||
|
||||
return minibatch_adam
|
||||
|
|
|
@ -19,7 +19,7 @@ else
|
|||
require 'LeakyReLU'
|
||||
require 'LeakyReLU_deprecated'
|
||||
require 'DepthExpand2x'
|
||||
require 'WeightedMSECriterion'
|
||||
require 'PSNRCriterion'
|
||||
require 'ClippedWeightedHuberCriterion'
|
||||
return w2nn
|
||||
end
|
||||
|
|
|
@ -166,6 +166,7 @@ local function train()
|
|||
return transformer(x, is_validation, n, offset)
|
||||
end
|
||||
local criterion = create_criterion(model)
|
||||
local eval_metric = w2nn.PSNRCriterion():cuda()
|
||||
local x = torch.load(settings.images)
|
||||
local train_x, valid_x = split_data(x, math.floor(settings.validation_rate * #x))
|
||||
local adam_config = {
|
||||
|
@ -179,7 +180,7 @@ local function train()
|
|||
elseif settings.color == "rgb" then
|
||||
ch = 3
|
||||
end
|
||||
local best_score = 100000.0
|
||||
local best_score = 0.0
|
||||
print("# make validation-set")
|
||||
local valid_xy = make_validation_set(valid_x, pairwise_func,
|
||||
settings.validation_crops,
|
||||
|
@ -200,11 +201,11 @@ local function train()
|
|||
print("# " .. epoch)
|
||||
resampling(x, y, train_x, pairwise_func)
|
||||
for i = 1, settings.inner_epoch do
|
||||
print(minibatch_adam(model, criterion, x, y, adam_config))
|
||||
print(minibatch_adam(model, criterion, eval_metric, x, y, adam_config))
|
||||
model:evaluate()
|
||||
print("# validation")
|
||||
local score = validate(model, criterion, valid_xy)
|
||||
if score < best_score then
|
||||
local score = validate(model, eval_metric, valid_xy)
|
||||
if score > best_score then
|
||||
local test_image = image_loader.load_float(settings.test) -- reload
|
||||
lrd_count = 0
|
||||
best_score = score
|
||||
|
|
Loading…
Reference in a new issue