
Merge branch 'dev' of into dev

nagadomi 2017-02-12 17:48:44 +09:00
commit 8b5ccbed08
3 changed files with 18 additions and 8 deletions

View file

@@ -11,7 +11,7 @@ local function minibatch_adam(model, criterion, eval_metric,
config.xEvalCount = 0
config.learningRate = config.xLearningRate
end
+ local sum_psnr = 0
local sum_loss = 0
local sum_eval = 0
local count_loss = 0
@@ -55,6 +55,7 @@ local function minibatch_adam(model, criterion, eval_metric,
else
se = eval_metric:forward(output, targets)
end
+ sum_psnr = sum_psnr + (10 * math.log10(1 / (se + 1.0e-6)))
sum_eval = sum_eval + se
sum_loss = sum_loss + f
count_loss = count_loss + 1
@@ -69,10 +70,9 @@ local function minibatch_adam(model, criterion, eval_metric,
collectgarbage()
xlua.progress(t, train_x:size(1))
end
end
xlua.progress(train_x:size(1), train_x:size(1))
- return { loss = sum_loss / count_loss, MSE = sum_eval / count_loss, PSNR = 10 * math.log10(1 / (sum_eval / count_loss))}, instance_loss
+ return { loss = sum_loss / count_loss, MSE = sum_eval / count_loss, PSNR = sum_psnr / count_loss}, instance_loss
end
return minibatch_adam
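
In this first file (the minibatch Adam loop), the epoch PSNR is now accumulated batch by batch in sum_psnr rather than derived from the averaged MSE in the return value. The two figures are generally not equal; a minimal Lua sketch with made-up per-batch MSE values (illustration only, not from this commit):

-- hypothetical per-batch squared errors, chosen only to show the gap
local mse = {0.01, 0.0001}
local sum_mse, sum_psnr = 0, 0
for _, se in ipairs(mse) do
   sum_mse = sum_mse + se
   sum_psnr = sum_psnr + 10 * math.log10(1 / (se + 1.0e-6))
end
print(10 * math.log10(1 / (sum_mse / #mse))) -- PSNR of the mean MSE: about 23.0 dB
print(sum_psnr / #mse)                       -- mean of the per-batch PSNRs: about 30.0 dB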

View file

@@ -76,6 +76,7 @@ cmd:option("-resume", "", 'resume model file')
cmd:option("-name", "user", 'model name for user method')
cmd:option("-gpu", 1, 'Device ID')
cmd:option("-loss", "huber", 'loss function (huber|l1|mse)')
+ cmd:option("-update_criterion", "mse", 'mse|loss')
local function to_bool(settings, name)
if settings[name] == 1 then
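
The second file adds an -update_criterion option that selects which validation score decides whether the best model file is overwritten (see the training-loop change in the next file). Hypothetical invocations, assuming the usual th train.lua entry point with the remaining options left at their defaults:

th train.lua -update_criterion mse   # keep the model with the lowest validation MSE (new default)
th train.lua -update_criterion loss  # keep the model with the lowest validation loss (previous behaviour)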

View file

@@ -262,6 +262,7 @@ local function make_validation_set(x, n, patches)
return data
end
local function validate(model, criterion, eval_metric, data, batch_size)
+ local psnr = 0
local loss = 0
local mse = 0
local loss_count = 0
@@ -286,8 +287,10 @@ local function validate(model, criterion, eval_metric, data, batch_size)
inputs:copy(inputs_tmp)
targets:copy(targets_tmp)
local z = model:forward(inputs)
+ local batch_mse = eval_metric:forward(z, targets)
loss = loss + criterion:forward(z, targets)
- mse = mse + eval_metric:forward(z, targets)
+ mse = mse + batch_mse
+ psnr = psnr + (10 * math.log10(1 / batch_mse))
loss_count = loss_count + 1
if loss_count % 10 == 0 then
xlua.progress(t, #data)
@@ -295,7 +298,7 @@ local function validate(model, criterion, eval_metric, data, batch_size)
end
end
xlua.progress(#data, #data)
- return {loss = loss / loss_count, MSE = mse / loss_count, PSNR = 10 * math.log10(1 / (mse / loss_count))}
+ return {loss = loss / loss_count, MSE = mse / loss_count, PSNR = psnr / loss_count}
end
local function create_criterion(model)
@@ -540,9 +543,15 @@ local function train()
if settings.plot then
plot(hist_train, hist_valid)
end
- if score.loss < best_score then
+ local score_for_update
+ if settings.update_criterion == "mse" then
+ score_for_update = score.MSE
+ else
+ score_for_update = score.loss
+ end
+ if score_for_update < best_score then
local test_image = image_loader.load_float(settings.test) -- reload
- best_score = score.loss
+ best_score = score_for_update
print("* model has updated")
if settings.save_history then
torch.save(settings.model_file_best, model:clearState(), "ascii")
@@ -591,7 +600,7 @@ local function train()
end
end
end
- print("Batch-wise PSNR: " .. score.PSNR .. ", loss: " .. score.loss .. ", Minimum loss: " .. best_score .. ", MSE: " .. score.MSE)
+ print("Batch-wise PSNR: " .. score.PSNR .. ", loss: " .. score.loss .. ", MSE: " .. score.MSE .. ", best: " .. best_score)
collectgarbage()
end
end
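
Taken together, the training loop now compares whichever score -update_criterion selects against best_score before saving. A condensed paraphrase of the new logic (a sketch, not verbatim source):

local score_for_update = (settings.update_criterion == "mse") and score.MSE or score.loss
if score_for_update < best_score then
   best_score = score_for_update
   -- reload the test image and save the best model, as in the diff above
end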