diff --git a/train.lua b/train.lua index b779d82..ded3f16 100644 --- a/train.lua +++ b/train.lua @@ -63,7 +63,6 @@ local function validate(model, criterion, data, batch_size) data[1].y:size(3)):zero() local inputs = inputs_tmp:clone():cuda() local targets = targets_tmp:clone():cuda() - for t = 1, #data, batch_size do if t + batch_size -1 > #data then break @@ -77,7 +76,7 @@ local function validate(model, criterion, data, batch_size) local z = model:forward(inputs) loss = loss + criterion:forward(z, targets) loss_count = loss_count + 1 - if t % 10 == 0 then + if loss_count % 10 == 0 then xlua.progress(t, #data) collectgarbage() end