Minimize the weighted huber loss instead of the weighted mean square error
Huber loss is less sensitive to outliers(i.e. noise) in data than the squared error loss.
This commit is contained in:
parent
243d8821be
commit
490eb33a6b
|
@ -21,7 +21,6 @@ local function minibatch_adam(model, criterion,
|
|||
input_size[1], input_size[2], input_size[3])
|
||||
local targets_tmp = torch.Tensor(batch_size,
|
||||
target_size[1] * target_size[2] * target_size[3])
|
||||
|
||||
for t = 1, #train_x do
|
||||
xlua.progress(t, #train_x)
|
||||
local xy = transformer(train_x[shuffle[t]], false, batch_size)
|
||||
|
@ -31,7 +30,6 @@ local function minibatch_adam(model, criterion,
|
|||
end
|
||||
inputs:copy(inputs_tmp)
|
||||
targets:copy(targets_tmp)
|
||||
|
||||
local feval = function(x)
|
||||
if x ~= parameters then
|
||||
parameters:copy(x)
|
||||
|
@ -53,7 +51,7 @@ local function minibatch_adam(model, criterion,
|
|||
end
|
||||
xlua.progress(#train_x, #train_x)
|
||||
|
||||
return { mse = sum_loss / count_loss}
|
||||
return { loss = sum_loss / count_loss}
|
||||
end
|
||||
|
||||
return minibatch_adam
|
||||
|
|
30
train.lua
30
train.lua
|
@ -35,18 +35,19 @@ local function split_data(x, test_size)
|
|||
end
|
||||
return train_x, valid_x
|
||||
end
|
||||
local function make_validation_set(x, transformer, n)
|
||||
local function make_validation_set(x, transformer, n, batch_size)
|
||||
n = n or 4
|
||||
local data = {}
|
||||
for i = 1, #x do
|
||||
for k = 1, math.max(n / 8, 1) do
|
||||
local xy = transformer(x[i], true, 8)
|
||||
for k = 1, math.max(n / batch_size, 1) do
|
||||
local xy = transformer(x[i], true, batch_size)
|
||||
local tx = torch.Tensor(batch_size, xy[1][1]:size(1), xy[1][1]:size(2), xy[1][1]:size(3))
|
||||
local ty = torch.Tensor(batch_size, xy[1][2]:size(1), xy[1][2]:size(2), xy[1][2]:size(3))
|
||||
for j = 1, #xy do
|
||||
local x = xy[j][1]
|
||||
local y = xy[j][2]
|
||||
table.insert(data, {x = x:reshape(1, x:size(1), x:size(2), x:size(3)),
|
||||
y = y:reshape(1, y:size(1), y:size(2), y:size(3))})
|
||||
tx[j]:copy(xy[j][1])
|
||||
ty[j]:copy(xy[j][2])
|
||||
end
|
||||
table.insert(data, {x = tx, y = ty})
|
||||
end
|
||||
xlua.progress(i, #x)
|
||||
collectgarbage()
|
||||
|
@ -58,11 +59,12 @@ local function validate(model, criterion, data)
|
|||
for i = 1, #data do
|
||||
local z = model:forward(data[i].x:cuda())
|
||||
loss = loss + criterion:forward(z, data[i].y:cuda())
|
||||
if i % 100 == 0 then
|
||||
xlua.progress(i, #data)
|
||||
if i % 10 == 0 then
|
||||
collectgarbage()
|
||||
end
|
||||
end
|
||||
xlua.progress(#data, #data)
|
||||
return loss / #data
|
||||
end
|
||||
|
||||
|
@ -71,10 +73,10 @@ local function create_criterion(model)
|
|||
local offset = reconstruct.offset_size(model)
|
||||
local output_w = settings.crop_size - offset * 2
|
||||
local weight = torch.Tensor(3, output_w * output_w)
|
||||
weight[1]:fill(0.299 * 3) -- R
|
||||
weight[2]:fill(0.587 * 3) -- G
|
||||
weight[3]:fill(0.114 * 3) -- B
|
||||
return w2nn.WeightedMSECriterion(weight):cuda()
|
||||
weight[1]:fill(0.29891 * 3) -- R
|
||||
weight[2]:fill(0.58661 * 3) -- G
|
||||
weight[3]:fill(0.11448 * 3) -- B
|
||||
return w2nn.WeightedHuberCriterion(weight, 0.1):cuda()
|
||||
else
|
||||
return nn.MSECriterion():cuda()
|
||||
end
|
||||
|
@ -151,7 +153,9 @@ local function train()
|
|||
end
|
||||
local best_score = 100000.0
|
||||
print("# make validation-set")
|
||||
local valid_xy = make_validation_set(valid_x, pairwise_func, settings.validation_crops)
|
||||
local valid_xy = make_validation_set(valid_x, pairwise_func,
|
||||
settings.validation_crops,
|
||||
settings.batch_size)
|
||||
valid_x = nil
|
||||
|
||||
collectgarbage()
|
||||
|
|
Loading…
Reference in a new issue