From 451ee1407fa826b154681d1cf508f439194999b5 Mon Sep 17 00:00:00 2001 From: nagadomi Date: Sun, 11 Sep 2016 20:59:32 +0900 Subject: [PATCH] Stop calculating the instance loss when oracle_rate=0 --- lib/minibatch_adam.lua | 15 ++++++++++----- train.lua | 3 ++- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/lib/minibatch_adam.lua b/lib/minibatch_adam.lua index dbe70f4..96f189f 100644 --- a/lib/minibatch_adam.lua +++ b/lib/minibatch_adam.lua @@ -45,12 +45,17 @@ local function minibatch_adam(model, criterion, eval_metric, local output = model:forward(inputs) local f = criterion:forward(output, targets) local se = 0 - for i = 1, batch_size do - local el = eval_metric:forward(output[i], targets[i]) - se = se + el - instance_loss[shuffle[t + i - 1]] = el + if config.xInstanceLoss then + for i = 1, batch_size do + local el = eval_metric:forward(output[i], targets[i]) + se = se + el + instance_loss[shuffle[t + i - 1]] = el + end + se = (se / batch_size) + else + se = eval_metric:forward(output, targets) end - sum_eval = sum_eval + (se / batch_size) + sum_eval = sum_eval + se sum_loss = sum_loss + f count_loss = count_loss + 1 model:backward(inputs, criterion:backward(output, targets)) diff --git a/train.lua b/train.lua index c5fe873..ba3ee82 100644 --- a/train.lua +++ b/train.lua @@ -374,7 +374,8 @@ local function train() local adam_config = { xLearningRate = settings.learning_rate, xBatchSize = settings.batch_size, - xLearningRateDecay = settings.learning_rate_decay + xLearningRateDecay = settings.learning_rate_decay, + xInstanceLoss = (settings.oracle_rate > 0) } local ch = nil if settings.color == "y" then