Stop calculate the instance loss when oracle_rate=0
This commit is contained in:
parent
33e6bc888e
commit
451ee1407f
|
@ -45,12 +45,17 @@ local function minibatch_adam(model, criterion, eval_metric,
|
|||
local output = model:forward(inputs)
|
||||
local f = criterion:forward(output, targets)
|
||||
local se = 0
|
||||
for i = 1, batch_size do
|
||||
local el = eval_metric:forward(output[i], targets[i])
|
||||
se = se + el
|
||||
instance_loss[shuffle[t + i - 1]] = el
|
||||
if config.xInstanceLoss then
|
||||
for i = 1, batch_size do
|
||||
local el = eval_metric:forward(output[i], targets[i])
|
||||
se = se + el
|
||||
instance_loss[shuffle[t + i - 1]] = el
|
||||
end
|
||||
se = (se / batch_size)
|
||||
else
|
||||
se = eval_metric:forward(output, targets)
|
||||
end
|
||||
sum_eval = sum_eval + (se / batch_size)
|
||||
sum_eval = sum_eval + se
|
||||
sum_loss = sum_loss + f
|
||||
count_loss = count_loss + 1
|
||||
model:backward(inputs, criterion:backward(output, targets))
|
||||
|
|
|
@ -374,7 +374,8 @@ local function train()
|
|||
local adam_config = {
|
||||
xLearningRate = settings.learning_rate,
|
||||
xBatchSize = settings.batch_size,
|
||||
xLearningRateDecay = settings.learning_rate_decay
|
||||
xLearningRateDecay = settings.learning_rate_decay,
|
||||
xInstanceLoss = (settings.oracle_rate > 0)
|
||||
}
|
||||
local ch = nil
|
||||
if settings.color == "y" then
|
||||
|
|
Loading…
Reference in a new issue