Minimize the weighted huber loss instead of the weighted mean square error
Huber loss is less sensitive to outliers(i.e. noise) in data than the squared error loss.
This commit is contained in:
parent
243d8821be
commit
490eb33a6b
|
@ -21,7 +21,6 @@ local function minibatch_adam(model, criterion,
|
||||||
input_size[1], input_size[2], input_size[3])
|
input_size[1], input_size[2], input_size[3])
|
||||||
local targets_tmp = torch.Tensor(batch_size,
|
local targets_tmp = torch.Tensor(batch_size,
|
||||||
target_size[1] * target_size[2] * target_size[3])
|
target_size[1] * target_size[2] * target_size[3])
|
||||||
|
|
||||||
for t = 1, #train_x do
|
for t = 1, #train_x do
|
||||||
xlua.progress(t, #train_x)
|
xlua.progress(t, #train_x)
|
||||||
local xy = transformer(train_x[shuffle[t]], false, batch_size)
|
local xy = transformer(train_x[shuffle[t]], false, batch_size)
|
||||||
|
@ -31,7 +30,6 @@ local function minibatch_adam(model, criterion,
|
||||||
end
|
end
|
||||||
inputs:copy(inputs_tmp)
|
inputs:copy(inputs_tmp)
|
||||||
targets:copy(targets_tmp)
|
targets:copy(targets_tmp)
|
||||||
|
|
||||||
local feval = function(x)
|
local feval = function(x)
|
||||||
if x ~= parameters then
|
if x ~= parameters then
|
||||||
parameters:copy(x)
|
parameters:copy(x)
|
||||||
|
@ -53,7 +51,7 @@ local function minibatch_adam(model, criterion,
|
||||||
end
|
end
|
||||||
xlua.progress(#train_x, #train_x)
|
xlua.progress(#train_x, #train_x)
|
||||||
|
|
||||||
return { mse = sum_loss / count_loss}
|
return { loss = sum_loss / count_loss}
|
||||||
end
|
end
|
||||||
|
|
||||||
return minibatch_adam
|
return minibatch_adam
|
||||||
|
|
32
train.lua
32
train.lua
|
@ -35,18 +35,19 @@ local function split_data(x, test_size)
|
||||||
end
|
end
|
||||||
return train_x, valid_x
|
return train_x, valid_x
|
||||||
end
|
end
|
||||||
local function make_validation_set(x, transformer, n)
|
local function make_validation_set(x, transformer, n, batch_size)
|
||||||
n = n or 4
|
n = n or 4
|
||||||
local data = {}
|
local data = {}
|
||||||
for i = 1, #x do
|
for i = 1, #x do
|
||||||
for k = 1, math.max(n / 8, 1) do
|
for k = 1, math.max(n / batch_size, 1) do
|
||||||
local xy = transformer(x[i], true, 8)
|
local xy = transformer(x[i], true, batch_size)
|
||||||
|
local tx = torch.Tensor(batch_size, xy[1][1]:size(1), xy[1][1]:size(2), xy[1][1]:size(3))
|
||||||
|
local ty = torch.Tensor(batch_size, xy[1][2]:size(1), xy[1][2]:size(2), xy[1][2]:size(3))
|
||||||
for j = 1, #xy do
|
for j = 1, #xy do
|
||||||
local x = xy[j][1]
|
tx[j]:copy(xy[j][1])
|
||||||
local y = xy[j][2]
|
ty[j]:copy(xy[j][2])
|
||||||
table.insert(data, {x = x:reshape(1, x:size(1), x:size(2), x:size(3)),
|
|
||||||
y = y:reshape(1, y:size(1), y:size(2), y:size(3))})
|
|
||||||
end
|
end
|
||||||
|
table.insert(data, {x = tx, y = ty})
|
||||||
end
|
end
|
||||||
xlua.progress(i, #x)
|
xlua.progress(i, #x)
|
||||||
collectgarbage()
|
collectgarbage()
|
||||||
|
@ -58,11 +59,12 @@ local function validate(model, criterion, data)
|
||||||
for i = 1, #data do
|
for i = 1, #data do
|
||||||
local z = model:forward(data[i].x:cuda())
|
local z = model:forward(data[i].x:cuda())
|
||||||
loss = loss + criterion:forward(z, data[i].y:cuda())
|
loss = loss + criterion:forward(z, data[i].y:cuda())
|
||||||
xlua.progress(i, #data)
|
if i % 100 == 0 then
|
||||||
if i % 10 == 0 then
|
xlua.progress(i, #data)
|
||||||
collectgarbage()
|
collectgarbage()
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
xlua.progress(#data, #data)
|
||||||
return loss / #data
|
return loss / #data
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -71,10 +73,10 @@ local function create_criterion(model)
|
||||||
local offset = reconstruct.offset_size(model)
|
local offset = reconstruct.offset_size(model)
|
||||||
local output_w = settings.crop_size - offset * 2
|
local output_w = settings.crop_size - offset * 2
|
||||||
local weight = torch.Tensor(3, output_w * output_w)
|
local weight = torch.Tensor(3, output_w * output_w)
|
||||||
weight[1]:fill(0.299 * 3) -- R
|
weight[1]:fill(0.29891 * 3) -- R
|
||||||
weight[2]:fill(0.587 * 3) -- G
|
weight[2]:fill(0.58661 * 3) -- G
|
||||||
weight[3]:fill(0.114 * 3) -- B
|
weight[3]:fill(0.11448 * 3) -- B
|
||||||
return w2nn.WeightedMSECriterion(weight):cuda()
|
return w2nn.WeightedHuberCriterion(weight, 0.1):cuda()
|
||||||
else
|
else
|
||||||
return nn.MSECriterion():cuda()
|
return nn.MSECriterion():cuda()
|
||||||
end
|
end
|
||||||
|
@ -151,7 +153,9 @@ local function train()
|
||||||
end
|
end
|
||||||
local best_score = 100000.0
|
local best_score = 100000.0
|
||||||
print("# make validation-set")
|
print("# make validation-set")
|
||||||
local valid_xy = make_validation_set(valid_x, pairwise_func, settings.validation_crops)
|
local valid_xy = make_validation_set(valid_x, pairwise_func,
|
||||||
|
settings.validation_crops,
|
||||||
|
settings.batch_size)
|
||||||
valid_x = nil
|
valid_x = nil
|
||||||
|
|
||||||
collectgarbage()
|
collectgarbage()
|
||||||
|
|
Loading…
Reference in a new issue