Merge branch 'dev' of into dev
This commit is contained in:
commit
8b5ccbed08
|
@ -11,7 +11,7 @@ local function minibatch_adam(model, criterion, eval_metric,
|
||||||
config.xEvalCount = 0
|
config.xEvalCount = 0
|
||||||
config.learningRate = config.xLearningRate
|
config.learningRate = config.xLearningRate
|
||||||
end
|
end
|
||||||
|
local sum_psnr = 0
|
||||||
local sum_loss = 0
|
local sum_loss = 0
|
||||||
local sum_eval = 0
|
local sum_eval = 0
|
||||||
local count_loss = 0
|
local count_loss = 0
|
||||||
|
@ -55,6 +55,7 @@ local function minibatch_adam(model, criterion, eval_metric,
|
||||||
else
|
else
|
||||||
se = eval_metric:forward(output, targets)
|
se = eval_metric:forward(output, targets)
|
||||||
end
|
end
|
||||||
|
sum_psnr = sum_psnr + (10 * math.log10(1 / (se + 1.0e-6)))
|
||||||
sum_eval = sum_eval + se
|
sum_eval = sum_eval + se
|
||||||
sum_loss = sum_loss + f
|
sum_loss = sum_loss + f
|
||||||
count_loss = count_loss + 1
|
count_loss = count_loss + 1
|
||||||
|
@ -69,10 +70,9 @@ local function minibatch_adam(model, criterion, eval_metric,
|
||||||
collectgarbage()
|
collectgarbage()
|
||||||
xlua.progress(t, train_x:size(1))
|
xlua.progress(t, train_x:size(1))
|
||||||
end
|
end
|
||||||
|
|
||||||
end
|
end
|
||||||
xlua.progress(train_x:size(1), train_x:size(1))
|
xlua.progress(train_x:size(1), train_x:size(1))
|
||||||
return { loss = sum_loss / count_loss, MSE = sum_eval / count_loss, PSNR = 10 * math.log10(1 / (sum_eval / count_loss))}, instance_loss
|
return { loss = sum_loss / count_loss, MSE = sum_eval / count_loss, PSNR = sum_psnr / count_loss}, instance_loss
|
||||||
end
|
end
|
||||||
|
|
||||||
return minibatch_adam
|
return minibatch_adam
|
||||||
|
|
|
@ -76,6 +76,7 @@ cmd:option("-resume", "", 'resume model file')
|
||||||
cmd:option("-name", "user", 'model name for user method')
|
cmd:option("-name", "user", 'model name for user method')
|
||||||
cmd:option("-gpu", 1, 'Device ID')
|
cmd:option("-gpu", 1, 'Device ID')
|
||||||
cmd:option("-loss", "huber", 'loss function (huber|l1|mse)')
|
cmd:option("-loss", "huber", 'loss function (huber|l1|mse)')
|
||||||
|
cmd:option("-update_criterion", "mse", 'mse|loss')
|
||||||
|
|
||||||
local function to_bool(settings, name)
|
local function to_bool(settings, name)
|
||||||
if settings[name] == 1 then
|
if settings[name] == 1 then
|
||||||
|
|
19
train.lua
19
train.lua
|
@ -262,6 +262,7 @@ local function make_validation_set(x, n, patches)
|
||||||
return data
|
return data
|
||||||
end
|
end
|
||||||
local function validate(model, criterion, eval_metric, data, batch_size)
|
local function validate(model, criterion, eval_metric, data, batch_size)
|
||||||
|
local psnr = 0
|
||||||
local loss = 0
|
local loss = 0
|
||||||
local mse = 0
|
local mse = 0
|
||||||
local loss_count = 0
|
local loss_count = 0
|
||||||
|
@ -286,8 +287,10 @@ local function validate(model, criterion, eval_metric, data, batch_size)
|
||||||
inputs:copy(inputs_tmp)
|
inputs:copy(inputs_tmp)
|
||||||
targets:copy(targets_tmp)
|
targets:copy(targets_tmp)
|
||||||
local z = model:forward(inputs)
|
local z = model:forward(inputs)
|
||||||
|
local batch_mse = eval_metric:forward(z, targets)
|
||||||
loss = loss + criterion:forward(z, targets)
|
loss = loss + criterion:forward(z, targets)
|
||||||
mse = mse + eval_metric:forward(z, targets)
|
mse = mse + batch_mse
|
||||||
|
psnr = psnr + (10 * math.log10(1 / batch_mse))
|
||||||
loss_count = loss_count + 1
|
loss_count = loss_count + 1
|
||||||
if loss_count % 10 == 0 then
|
if loss_count % 10 == 0 then
|
||||||
xlua.progress(t, #data)
|
xlua.progress(t, #data)
|
||||||
|
@ -295,7 +298,7 @@ local function validate(model, criterion, eval_metric, data, batch_size)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
xlua.progress(#data, #data)
|
xlua.progress(#data, #data)
|
||||||
return {loss = loss / loss_count, MSE = mse / loss_count, PSNR = 10 * math.log10(1 / (mse / loss_count))}
|
return {loss = loss / loss_count, MSE = mse / loss_count, PSNR = psnr / loss_count}
|
||||||
end
|
end
|
||||||
|
|
||||||
local function create_criterion(model)
|
local function create_criterion(model)
|
||||||
|
@ -540,9 +543,15 @@ local function train()
|
||||||
if settings.plot then
|
if settings.plot then
|
||||||
plot(hist_train, hist_valid)
|
plot(hist_train, hist_valid)
|
||||||
end
|
end
|
||||||
if score.loss < best_score then
|
local score_for_update
|
||||||
|
if settings.update_criterion == "mse" then
|
||||||
|
score_for_update = score.MSE
|
||||||
|
else
|
||||||
|
score_for_update = score.loss
|
||||||
|
end
|
||||||
|
if score_for_update < best_score then
|
||||||
local test_image = image_loader.load_float(settings.test) -- reload
|
local test_image = image_loader.load_float(settings.test) -- reload
|
||||||
best_score = score.loss
|
best_score = score_for_update
|
||||||
print("* model has updated")
|
print("* model has updated")
|
||||||
if settings.save_history then
|
if settings.save_history then
|
||||||
torch.save(settings.model_file_best, model:clearState(), "ascii")
|
torch.save(settings.model_file_best, model:clearState(), "ascii")
|
||||||
|
@ -591,7 +600,7 @@ local function train()
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
print("Batch-wise PSNR: " .. score.PSNR .. ", loss: " .. score.loss .. ", Minimum loss: " .. best_score .. ", MSE: " .. score.MSE)
|
print("Batch-wise PSNR: " .. score.PSNR .. ", loss: " .. score.loss .. ", MSE: " .. score.MSE .. ", best: " .. best_score)
|
||||||
collectgarbage()
|
collectgarbage()
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
Loading…
Reference in a new issue