Use MSE instead of PSNR
PSNR depends on the minibatch size and those group.
This commit is contained in:
parent
fa9355be7c
commit
68a6d4cef5
2 changed files with 6 additions and 8 deletions
|
@ -52,8 +52,7 @@ local function minibatch_adam(model, criterion, eval_metric,
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
xlua.progress(train_x:size(1), train_x:size(1))
|
xlua.progress(train_x:size(1), train_x:size(1))
|
||||||
|
return { loss = sum_loss / count_loss, MSE = sum_eval / count_loss, PSNR = 10 * math.log10(1 / (sum_eval / count_loss))}
|
||||||
return { loss = sum_loss / count_loss, PSNR = sum_eval / count_loss}
|
|
||||||
end
|
end
|
||||||
|
|
||||||
return minibatch_adam
|
return minibatch_adam
|
||||||
|
|
11
train.lua
11
train.lua
|
@ -198,7 +198,7 @@ local function train()
|
||||||
return transformer(x, is_validation, n, offset)
|
return transformer(x, is_validation, n, offset)
|
||||||
end
|
end
|
||||||
local criterion = create_criterion(model)
|
local criterion = create_criterion(model)
|
||||||
local eval_metric = w2nn.PSNRCriterion():cuda()
|
local eval_metric = nn.MSECriterion():cuda()
|
||||||
local x = torch.load(settings.images)
|
local x = torch.load(settings.images)
|
||||||
local train_x, valid_x = split_data(x, math.floor(settings.validation_rate * #x))
|
local train_x, valid_x = split_data(x, math.floor(settings.validation_rate * #x))
|
||||||
local adam_config = {
|
local adam_config = {
|
||||||
|
@ -212,7 +212,7 @@ local function train()
|
||||||
elseif settings.color == "rgb" then
|
elseif settings.color == "rgb" then
|
||||||
ch = 3
|
ch = 3
|
||||||
end
|
end
|
||||||
local best_score = 0.0
|
local best_score = 1000.0
|
||||||
print("# make validation-set")
|
print("# make validation-set")
|
||||||
local valid_xy = make_validation_set(valid_x, pairwise_func,
|
local valid_xy = make_validation_set(valid_x, pairwise_func,
|
||||||
settings.validation_crops,
|
settings.validation_crops,
|
||||||
|
@ -227,7 +227,6 @@ local function train()
|
||||||
ch, settings.crop_size, settings.crop_size)
|
ch, settings.crop_size, settings.crop_size)
|
||||||
local y = torch.Tensor(settings.patches * #train_x,
|
local y = torch.Tensor(settings.patches * #train_x,
|
||||||
ch * (settings.crop_size - offset * 2) * (settings.crop_size - offset * 2)):zero()
|
ch * (settings.crop_size - offset * 2) * (settings.crop_size - offset * 2)):zero()
|
||||||
|
|
||||||
for epoch = 1, settings.epoch do
|
for epoch = 1, settings.epoch do
|
||||||
model:training()
|
model:training()
|
||||||
print("# " .. epoch)
|
print("# " .. epoch)
|
||||||
|
@ -238,12 +237,12 @@ local function train()
|
||||||
model:evaluate()
|
model:evaluate()
|
||||||
print("# validation")
|
print("# validation")
|
||||||
local score = validate(model, eval_metric, valid_xy, adam_config.xBatchSize)
|
local score = validate(model, eval_metric, valid_xy, adam_config.xBatchSize)
|
||||||
table.insert(hist_train, train_score.PSNR)
|
table.insert(hist_train, train_score.MSE)
|
||||||
table.insert(hist_valid, score)
|
table.insert(hist_valid, score)
|
||||||
if settings.plot then
|
if settings.plot then
|
||||||
plot(hist_train, hist_valid)
|
plot(hist_train, hist_valid)
|
||||||
end
|
end
|
||||||
if score > best_score then
|
if score < best_score then
|
||||||
local test_image = image_loader.load_float(settings.test) -- reload
|
local test_image = image_loader.load_float(settings.test) -- reload
|
||||||
lrd_count = 0
|
lrd_count = 0
|
||||||
best_score = score
|
best_score = score
|
||||||
|
@ -281,7 +280,7 @@ local function train()
|
||||||
lrd_count = 0
|
lrd_count = 0
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
print("current: " .. score .. ", best: " .. best_score)
|
print("PSNR: " .. 10 * math.log10(1 / score) .. ", MSE: " .. score .. ", Best MSE: " .. best_score)
|
||||||
collectgarbage()
|
collectgarbage()
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
Loading…
Reference in a new issue