diff --git a/train.lua b/train.lua index ded3f16..c4ee633 100644 --- a/train.lua +++ b/train.lua @@ -48,6 +48,12 @@ local function make_validation_set(x, transformer, n, patches) xlua.progress(i, #x) collectgarbage() end + local new_data = {} + local perm = torch.randperm(#data) + for i = 1, perm:size(1) do + new_data[i] = data[perm[i]] + end + data = new_data return data end local function validate(model, criterion, data, batch_size)