1
0
Fork 0
mirror of synced 2024-06-02 02:54:31 +12:00

Minimize the weighted huber loss instead of the weighted mean square error

Huber loss is less sensitive to outliers(i.e. noise) in data than the squared error loss.
This commit is contained in:
nagadomi 2015-10-31 21:56:20 +09:00
parent 243d8821be
commit 490eb33a6b
2 changed files with 19 additions and 17 deletions

View file

@ -21,7 +21,6 @@ local function minibatch_adam(model, criterion,
input_size[1], input_size[2], input_size[3]) input_size[1], input_size[2], input_size[3])
local targets_tmp = torch.Tensor(batch_size, local targets_tmp = torch.Tensor(batch_size,
target_size[1] * target_size[2] * target_size[3]) target_size[1] * target_size[2] * target_size[3])
for t = 1, #train_x do for t = 1, #train_x do
xlua.progress(t, #train_x) xlua.progress(t, #train_x)
local xy = transformer(train_x[shuffle[t]], false, batch_size) local xy = transformer(train_x[shuffle[t]], false, batch_size)
@ -31,7 +30,6 @@ local function minibatch_adam(model, criterion,
end end
inputs:copy(inputs_tmp) inputs:copy(inputs_tmp)
targets:copy(targets_tmp) targets:copy(targets_tmp)
local feval = function(x) local feval = function(x)
if x ~= parameters then if x ~= parameters then
parameters:copy(x) parameters:copy(x)
@ -53,7 +51,7 @@ local function minibatch_adam(model, criterion,
end end
xlua.progress(#train_x, #train_x) xlua.progress(#train_x, #train_x)
return { mse = sum_loss / count_loss} return { loss = sum_loss / count_loss}
end end
return minibatch_adam return minibatch_adam

View file

@ -35,18 +35,19 @@ local function split_data(x, test_size)
end end
return train_x, valid_x return train_x, valid_x
end end
local function make_validation_set(x, transformer, n) local function make_validation_set(x, transformer, n, batch_size)
n = n or 4 n = n or 4
local data = {} local data = {}
for i = 1, #x do for i = 1, #x do
for k = 1, math.max(n / 8, 1) do for k = 1, math.max(n / batch_size, 1) do
local xy = transformer(x[i], true, 8) local xy = transformer(x[i], true, batch_size)
local tx = torch.Tensor(batch_size, xy[1][1]:size(1), xy[1][1]:size(2), xy[1][1]:size(3))
local ty = torch.Tensor(batch_size, xy[1][2]:size(1), xy[1][2]:size(2), xy[1][2]:size(3))
for j = 1, #xy do for j = 1, #xy do
local x = xy[j][1] tx[j]:copy(xy[j][1])
local y = xy[j][2] ty[j]:copy(xy[j][2])
table.insert(data, {x = x:reshape(1, x:size(1), x:size(2), x:size(3)),
y = y:reshape(1, y:size(1), y:size(2), y:size(3))})
end end
table.insert(data, {x = tx, y = ty})
end end
xlua.progress(i, #x) xlua.progress(i, #x)
collectgarbage() collectgarbage()
@ -58,11 +59,12 @@ local function validate(model, criterion, data)
for i = 1, #data do for i = 1, #data do
local z = model:forward(data[i].x:cuda()) local z = model:forward(data[i].x:cuda())
loss = loss + criterion:forward(z, data[i].y:cuda()) loss = loss + criterion:forward(z, data[i].y:cuda())
xlua.progress(i, #data) if i % 100 == 0 then
if i % 10 == 0 then xlua.progress(i, #data)
collectgarbage() collectgarbage()
end end
end end
xlua.progress(#data, #data)
return loss / #data return loss / #data
end end
@ -71,10 +73,10 @@ local function create_criterion(model)
local offset = reconstruct.offset_size(model) local offset = reconstruct.offset_size(model)
local output_w = settings.crop_size - offset * 2 local output_w = settings.crop_size - offset * 2
local weight = torch.Tensor(3, output_w * output_w) local weight = torch.Tensor(3, output_w * output_w)
weight[1]:fill(0.299 * 3) -- R weight[1]:fill(0.29891 * 3) -- R
weight[2]:fill(0.587 * 3) -- G weight[2]:fill(0.58661 * 3) -- G
weight[3]:fill(0.114 * 3) -- B weight[3]:fill(0.11448 * 3) -- B
return w2nn.WeightedMSECriterion(weight):cuda() return w2nn.WeightedHuberCriterion(weight, 0.1):cuda()
else else
return nn.MSECriterion():cuda() return nn.MSECriterion():cuda()
end end
@ -151,7 +153,9 @@ local function train()
end end
local best_score = 100000.0 local best_score = 100000.0
print("# make validation-set") print("# make validation-set")
local valid_xy = make_validation_set(valid_x, pairwise_func, settings.validation_crops) local valid_xy = make_validation_set(valid_x, pairwise_func,
settings.validation_crops,
settings.batch_size)
valid_x = nil valid_x = nil
collectgarbage() collectgarbage()