Stop calculate the instance loss when oracle_rate=0
This commit is contained in:
parent
33e6bc888e
commit
451ee1407f
2 changed files with 12 additions and 6 deletions
|
@ -45,12 +45,17 @@ local function minibatch_adam(model, criterion, eval_metric,
|
||||||
local output = model:forward(inputs)
|
local output = model:forward(inputs)
|
||||||
local f = criterion:forward(output, targets)
|
local f = criterion:forward(output, targets)
|
||||||
local se = 0
|
local se = 0
|
||||||
|
if config.xInstanceLoss then
|
||||||
for i = 1, batch_size do
|
for i = 1, batch_size do
|
||||||
local el = eval_metric:forward(output[i], targets[i])
|
local el = eval_metric:forward(output[i], targets[i])
|
||||||
se = se + el
|
se = se + el
|
||||||
instance_loss[shuffle[t + i - 1]] = el
|
instance_loss[shuffle[t + i - 1]] = el
|
||||||
end
|
end
|
||||||
sum_eval = sum_eval + (se / batch_size)
|
se = (se / batch_size)
|
||||||
|
else
|
||||||
|
se = eval_metric:forward(output, targets)
|
||||||
|
end
|
||||||
|
sum_eval = sum_eval + se
|
||||||
sum_loss = sum_loss + f
|
sum_loss = sum_loss + f
|
||||||
count_loss = count_loss + 1
|
count_loss = count_loss + 1
|
||||||
model:backward(inputs, criterion:backward(output, targets))
|
model:backward(inputs, criterion:backward(output, targets))
|
||||||
|
|
|
@ -374,7 +374,8 @@ local function train()
|
||||||
local adam_config = {
|
local adam_config = {
|
||||||
xLearningRate = settings.learning_rate,
|
xLearningRate = settings.learning_rate,
|
||||||
xBatchSize = settings.batch_size,
|
xBatchSize = settings.batch_size,
|
||||||
xLearningRateDecay = settings.learning_rate_decay
|
xLearningRateDecay = settings.learning_rate_decay,
|
||||||
|
xInstanceLoss = (settings.oracle_rate > 0)
|
||||||
}
|
}
|
||||||
local ch = nil
|
local ch = nil
|
||||||
if settings.color == "y" then
|
if settings.color == "y" then
|
||||||
|
|
Loading…
Reference in a new issue