diff --git a/lib/minibatch_adam.lua b/lib/minibatch_adam.lua index dbe70f4..96f189f 100644 --- a/lib/minibatch_adam.lua +++ b/lib/minibatch_adam.lua @@ -45,12 +45,17 @@ local function minibatch_adam(model, criterion, eval_metric, local output = model:forward(inputs) local f = criterion:forward(output, targets) local se = 0 - for i = 1, batch_size do - local el = eval_metric:forward(output[i], targets[i]) - se = se + el - instance_loss[shuffle[t + i - 1]] = el + if config.xInstanceLoss then + for i = 1, batch_size do + local el = eval_metric:forward(output[i], targets[i]) + se = se + el + instance_loss[shuffle[t + i - 1]] = el + end + se = (se / batch_size) + else + se = eval_metric:forward(output, targets) end - sum_eval = sum_eval + (se / batch_size) + sum_eval = sum_eval + se sum_loss = sum_loss + f count_loss = count_loss + 1 model:backward(inputs, criterion:backward(output, targets)) diff --git a/train.lua b/train.lua index c5fe873..ba3ee82 100644 --- a/train.lua +++ b/train.lua @@ -374,7 +374,8 @@ local function train() local adam_config = { xLearningRate = settings.learning_rate, xBatchSize = settings.batch_size, - xLearningRateDecay = settings.learning_rate_decay + xLearningRateDecay = settings.learning_rate_decay, + xInstanceLoss = (settings.oracle_rate > 0) } local ch = nil if settings.color == "y" then