1
0
Fork 0
mirror of synced 2024-05-18 11:52:17 +12:00

Use MSE instead of PSNR

PSNR depends on the minibatch size and how the batches are grouped.
This commit is contained in:
nagadomi 2016-04-17 02:08:07 +09:00
parent fa9355be7c
commit 68a6d4cef5
2 changed files with 6 additions and 8 deletions

View file

@@ -52,8 +52,7 @@ local function minibatch_adam(model, criterion, eval_metric,
end
end
xlua.progress(train_x:size(1), train_x:size(1))
return { loss = sum_loss / count_loss, PSNR = sum_eval / count_loss}
return { loss = sum_loss / count_loss, MSE = sum_eval / count_loss, PSNR = 10 * math.log10(1 / (sum_eval / count_loss))}
end
return minibatch_adam

View file

@@ -198,7 +198,7 @@ local function train()
return transformer(x, is_validation, n, offset)
end
local criterion = create_criterion(model)
local eval_metric = w2nn.PSNRCriterion():cuda()
local eval_metric = nn.MSECriterion():cuda()
local x = torch.load(settings.images)
local train_x, valid_x = split_data(x, math.floor(settings.validation_rate * #x))
local adam_config = {
@@ -212,7 +212,7 @@ local function train()
elseif settings.color == "rgb" then
ch = 3
end
local best_score = 0.0
local best_score = 1000.0
print("# make validation-set")
local valid_xy = make_validation_set(valid_x, pairwise_func,
settings.validation_crops,
@@ -227,7 +227,6 @@ local function train()
ch, settings.crop_size, settings.crop_size)
local y = torch.Tensor(settings.patches * #train_x,
ch * (settings.crop_size - offset * 2) * (settings.crop_size - offset * 2)):zero()
for epoch = 1, settings.epoch do
model:training()
print("# " .. epoch)
@@ -238,12 +237,12 @@ local function train()
model:evaluate()
print("# validation")
local score = validate(model, eval_metric, valid_xy, adam_config.xBatchSize)
table.insert(hist_train, train_score.PSNR)
table.insert(hist_train, train_score.MSE)
table.insert(hist_valid, score)
if settings.plot then
plot(hist_train, hist_valid)
end
if score > best_score then
if score < best_score then
local test_image = image_loader.load_float(settings.test) -- reload
lrd_count = 0
best_score = score
@@ -281,7 +280,7 @@ local function train()
lrd_count = 0
end
end
print("current: " .. score .. ", best: " .. best_score)
print("PSNR: " .. 10 * math.log10(1 / score) .. ", MSE: " .. score .. ", Best MSE: " .. best_score)
collectgarbage()
end
end