diff --git a/lib/minibatch_adam.lua b/lib/minibatch_adam.lua index 4c819fd..af05f5a 100644 --- a/lib/minibatch_adam.lua +++ b/lib/minibatch_adam.lua @@ -21,7 +21,6 @@ local function minibatch_adam(model, criterion, input_size[1], input_size[2], input_size[3]) local targets_tmp = torch.Tensor(batch_size, target_size[1] * target_size[2] * target_size[3]) - for t = 1, #train_x do xlua.progress(t, #train_x) local xy = transformer(train_x[shuffle[t]], false, batch_size) @@ -31,7 +30,6 @@ local function minibatch_adam(model, criterion, end inputs:copy(inputs_tmp) targets:copy(targets_tmp) - local feval = function(x) if x ~= parameters then parameters:copy(x) @@ -53,7 +51,7 @@ local function minibatch_adam(model, criterion, end xlua.progress(#train_x, #train_x) - return { mse = sum_loss / count_loss} + return { loss = sum_loss / count_loss} end return minibatch_adam diff --git a/train.lua b/train.lua index b4da992..978dc92 100644 --- a/train.lua +++ b/train.lua @@ -35,18 +35,19 @@ local function split_data(x, test_size) end return train_x, valid_x end -local function make_validation_set(x, transformer, n) +local function make_validation_set(x, transformer, n, batch_size) n = n or 4 local data = {} for i = 1, #x do - for k = 1, math.max(n / 8, 1) do - local xy = transformer(x[i], true, 8) + for k = 1, math.max(n / batch_size, 1) do + local xy = transformer(x[i], true, batch_size) + local tx = torch.Tensor(batch_size, xy[1][1]:size(1), xy[1][1]:size(2), xy[1][1]:size(3)) + local ty = torch.Tensor(batch_size, xy[1][2]:size(1), xy[1][2]:size(2), xy[1][2]:size(3)) for j = 1, #xy do - local x = xy[j][1] - local y = xy[j][2] - table.insert(data, {x = x:reshape(1, x:size(1), x:size(2), x:size(3)), - y = y:reshape(1, y:size(1), y:size(2), y:size(3))}) + tx[j]:copy(xy[j][1]) + ty[j]:copy(xy[j][2]) end + table.insert(data, {x = tx, y = ty}) end xlua.progress(i, #x) collectgarbage() @@ -58,11 +59,12 @@ local function validate(model, criterion, data) for i = 1, #data do local z = model:forward(data[i].x:cuda()) loss = loss + criterion:forward(z, data[i].y:cuda()) - xlua.progress(i, #data) - if i % 10 == 0 then + if i % 100 == 0 then + xlua.progress(i, #data) collectgarbage() end end + xlua.progress(#data, #data) return loss / #data end @@ -71,10 +73,10 @@ local function create_criterion(model) local offset = reconstruct.offset_size(model) local output_w = settings.crop_size - offset * 2 local weight = torch.Tensor(3, output_w * output_w) - weight[1]:fill(0.299 * 3) -- R - weight[2]:fill(0.587 * 3) -- G - weight[3]:fill(0.114 * 3) -- B - return w2nn.WeightedMSECriterion(weight):cuda() + weight[1]:fill(0.29891 * 3) -- R + weight[2]:fill(0.58661 * 3) -- G + weight[3]:fill(0.11448 * 3) -- B + return w2nn.WeightedHuberCriterion(weight, 0.1):cuda() else return nn.MSECriterion():cuda() end @@ -151,7 +153,9 @@ local function train() end local best_score = 100000.0 print("# make validation-set") - local valid_xy = make_validation_set(valid_x, pairwise_func, settings.validation_crops) + local valid_xy = make_validation_set(valid_x, pairwise_func, + settings.validation_crops, + settings.batch_size) valid_x = nil collectgarbage()