diff --git a/train.lua b/train.lua index 4b97ce9..b779d82 100644 --- a/train.lua +++ b/train.lua @@ -41,31 +41,49 @@ local function make_validation_set(x, transformer, n, patches) for i = 1, #x do for k = 1, math.max(n / patches, 1) do local xy = transformer(x[i], true, patches) - local tx = torch.Tensor(patches, xy[1][1]:size(1), xy[1][1]:size(2), xy[1][1]:size(3)) - local ty = torch.Tensor(patches, xy[1][2]:size(1), xy[1][2]:size(2), xy[1][2]:size(3)) for j = 1, #xy do - tx[j]:copy(xy[j][1]) - ty[j]:copy(xy[j][2]) + table.insert(data, {x = xy[j][1], y = xy[j][2]}) end - table.insert(data, {x = tx, y = ty}) end xlua.progress(i, #x) collectgarbage() end return data end -local function validate(model, criterion, data) +local function validate(model, criterion, data, batch_size) local loss = 0 - for i = 1, #data do - local z = model:forward(data[i].x:cuda()) - loss = loss + criterion:forward(z, data[i].y:cuda()) - if i % 100 == 0 then - xlua.progress(i, #data) + local loss_count = 0 + local inputs_tmp = torch.Tensor(batch_size, + data[1].x:size(1), + data[1].x:size(2), + data[1].x:size(3)):zero() + local targets_tmp = torch.Tensor(batch_size, + data[1].y:size(1), + data[1].y:size(2), + data[1].y:size(3)):zero() + local inputs = inputs_tmp:clone():cuda() + local targets = targets_tmp:clone():cuda() + + for t = 1, #data, batch_size do + if t + batch_size -1 > #data then + break + end + for i = 1, batch_size do + inputs_tmp[i]:copy(data[t + i - 1].x) + targets_tmp[i]:copy(data[t + i - 1].y) + end + inputs:copy(inputs_tmp) + targets:copy(targets_tmp) + local z = model:forward(inputs) + loss = loss + criterion:forward(z, targets) + loss_count = loss_count + 1 + if t % 10 == 0 then + xlua.progress(t, #data) collectgarbage() end end xlua.progress(#data, #data) - return loss / #data + return loss / loss_count end local function create_criterion(model) @@ -214,8 +232,7 @@ local function train() print(train_score) model:evaluate() print("# validation") - local score = validate(model, eval_metric, valid_xy) - + local score = validate(model, eval_metric, valid_xy, adam_config.xBatchSize) table.insert(hist_train, train_score.PSNR) table.insert(hist_valid, score) if settings.plot then