diff --git a/train.lua b/train.lua index 215b4ce..b54d4a6 100644 --- a/train.lua +++ b/train.lua @@ -37,10 +37,11 @@ local function split_data(x, test_size) end local function make_validation_set(x, transformer, n, patches) n = n or 4 + local validation_patches = math.min(16, patches or 16) local data = {} for i = 1, #x do - for k = 1, math.max(n / patches, 1) do - local xy = transformer(x[i], true, patches) + for k = 1, math.max(n / validation_patches, 1) do + local xy = transformer(x[i], true, validation_patches) for j = 1, #xy do table.insert(data, {x = xy[j][1], y = xy[j][2]}) end