add -update_criterion option for back compatible
This commit is contained in:
parent
b65bed9418
commit
b111901cbb
|
@ -76,6 +76,7 @@ cmd:option("-resume", "", 'resume model file')
|
|||
cmd:option("-name", "user", 'model name for user method')
|
||||
cmd:option("-gpu", 1, 'Device ID')
|
||||
cmd:option("-loss", "huber", 'loss function (huber|l1|mse)')
|
||||
cmd:option("-update_criterion", "mse", 'mse|loss')
|
||||
|
||||
local function to_bool(settings, name)
|
||||
if settings[name] == 1 then
|
||||
|
|
12
train.lua
12
train.lua
|
@ -532,9 +532,15 @@ local function train()
|
|||
if settings.plot then
|
||||
plot(hist_train, hist_valid)
|
||||
end
|
||||
if score.loss < best_score then
|
||||
local score_for_update
|
||||
if settings.update_criterion == "mse" then
|
||||
score_for_update = score.MSE
|
||||
else
|
||||
score_for_update = score.loss
|
||||
end
|
||||
if score_for_update < best_score then
|
||||
local test_image = image_loader.load_float(settings.test) -- reload
|
||||
best_score = score.loss
|
||||
best_score = score_for_update
|
||||
print("* model has updated")
|
||||
if settings.save_history then
|
||||
torch.save(settings.model_file_best, model:clearState(), "ascii")
|
||||
|
@ -583,7 +589,7 @@ local function train()
|
|||
end
|
||||
end
|
||||
end
|
||||
print("Batch-wise PSNR: " .. score.PSNR .. ", loss: " .. score.loss .. ", Minimum loss: " .. best_score .. ", MSE: " .. score.MSE)
|
||||
print("Batch-wise PSNR: " .. score.PSNR .. ", loss: " .. score.loss .. ", MSE: " .. score.MSE .. ", best: " .. best_score)
|
||||
collectgarbage()
|
||||
end
|
||||
end
|
||||
|
|
Loading…
Reference in a new issue