1
0
Fork 0
mirror of synced 2024-05-16 19:02:21 +12:00

add -update_criterion option for back compatible

This commit is contained in:
nagadomi 2017-02-12 01:56:03 +09:00
parent b65bed9418
commit b111901cbb
2 changed files with 10 additions and 3 deletions

View file

@ -76,6 +76,7 @@ cmd:option("-resume", "", 'resume model file')
cmd:option("-name", "user", 'model name for user method')
cmd:option("-gpu", 1, 'Device ID')
cmd:option("-loss", "huber", 'loss function (huber|l1|mse)')
cmd:option("-update_criterion", "mse", 'mse|loss')
local function to_bool(settings, name)
if settings[name] == 1 then

View file

@ -532,9 +532,15 @@ local function train()
if settings.plot then
plot(hist_train, hist_valid)
end
if score.loss < best_score then
local score_for_update
if settings.update_criterion == "mse" then
score_for_update = score.MSE
else
score_for_update = score.loss
end
if score_for_update < best_score then
local test_image = image_loader.load_float(settings.test) -- reload
best_score = score.loss
best_score = score_for_update
print("* model has updated")
if settings.save_history then
torch.save(settings.model_file_best, model:clearState(), "ascii")
@ -583,7 +589,7 @@ local function train()
end
end
end
print("Batch-wise PSNR: " .. score.PSNR .. ", loss: " .. score.loss .. ", Minimum loss: " .. best_score .. ", MSE: " .. score.MSE)
print("Batch-wise PSNR: " .. score.PSNR .. ", loss: " .. score.loss .. ", MSE: " .. score.MSE .. ", best: " .. best_score)
collectgarbage()
end
end