1
0
Fork 0
mirror of synced 2024-06-24 01:00:41 +12:00

Stop calculate the instance loss when oracle_rate=0

This commit is contained in:
nagadomi 2016-09-11 20:59:32 +09:00
parent 33e6bc888e
commit 451ee1407f
2 changed files with 12 additions and 6 deletions

View file

@ -45,12 +45,17 @@ local function minibatch_adam(model, criterion, eval_metric,
local output = model:forward(inputs)
local f = criterion:forward(output, targets)
local se = 0
for i = 1, batch_size do
local el = eval_metric:forward(output[i], targets[i])
se = se + el
instance_loss[shuffle[t + i - 1]] = el
if config.xInstanceLoss then
for i = 1, batch_size do
local el = eval_metric:forward(output[i], targets[i])
se = se + el
instance_loss[shuffle[t + i - 1]] = el
end
se = (se / batch_size)
else
se = eval_metric:forward(output, targets)
end
sum_eval = sum_eval + (se / batch_size)
sum_eval = sum_eval + se
sum_loss = sum_loss + f
count_loss = count_loss + 1
model:backward(inputs, criterion:backward(output, targets))

View file

@ -374,7 +374,8 @@ local function train()
local adam_config = {
xLearningRate = settings.learning_rate,
xBatchSize = settings.batch_size,
xLearningRateDecay = settings.learning_rate_decay
xLearningRateDecay = settings.learning_rate_decay,
xInstanceLoss = (settings.oracle_rate > 0)
}
local ch = nil
if settings.color == "y" then