Add oracle_rate option
This commit is contained in:
parent
8fec6f1b5a
commit
8088460a20
|
@ -56,6 +56,8 @@ cmd:option("-max_training_image_size", -1, 'if training image is larger than N,
|
||||||
cmd:option("-use_transparent_png", 0, 'use transparent png (0|1)')
|
cmd:option("-use_transparent_png", 0, 'use transparent png (0|1)')
|
||||||
cmd:option("-resize_blur_min", 0.85, 'min blur parameter for ResizeImage')
|
cmd:option("-resize_blur_min", 0.85, 'min blur parameter for ResizeImage')
|
||||||
cmd:option("-resize_blur_max", 1.05, 'max blur parameter for ResizeImage')
|
cmd:option("-resize_blur_max", 1.05, 'max blur parameter for ResizeImage')
|
||||||
|
cmd:option("-oracle_rate", 0.0, '')
|
||||||
|
cmd:option("-oracle_drop_rate", 0.5, '')
|
||||||
|
|
||||||
local function to_bool(settings, name)
|
local function to_bool(settings, name)
|
||||||
if settings[name] == 1 then
|
if settings[name] == 1 then
|
||||||
|
|
63
train.lua
63
train.lua
|
@ -175,20 +175,48 @@ local function transformer(model, x, is_validation, n, offset)
|
||||||
end
|
end
|
||||||
|
|
||||||
local function resampling(x, y, train_x, transformer, input_size, target_size)
|
local function resampling(x, y, train_x, transformer, input_size, target_size)
|
||||||
print("## resampling")
|
local c = 1
|
||||||
|
local shuffle = torch.randperm(#train_x)
|
||||||
for t = 1, #train_x do
|
for t = 1, #train_x do
|
||||||
xlua.progress(t, #train_x)
|
xlua.progress(t, #train_x)
|
||||||
local xy = transformer(train_x[t], false, settings.patches)
|
local xy = transformer(train_x[shuffle[t]], false, settings.patches)
|
||||||
for i = 1, #xy do
|
for i = 1, #xy do
|
||||||
local index = (t - 1) * settings.patches + i
|
x[c]:copy(xy[i][1])
|
||||||
x[index]:copy(xy[i][1])
|
y[c]:copy(xy[i][2])
|
||||||
y[index]:copy(xy[i][2])
|
c = c + 1
|
||||||
|
if c > x:size(1) then
|
||||||
|
break
|
||||||
|
end
|
||||||
|
end
|
||||||
|
if c > x:size(1) then
|
||||||
|
break
|
||||||
end
|
end
|
||||||
if t % 50 == 0 then
|
if t % 50 == 0 then
|
||||||
collectgarbage()
|
collectgarbage()
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
xlua.progress(#train_x, #train_x)
|
||||||
end
|
end
|
||||||
|
local function get_oracle_data(x, y, instance_loss, k, samples)
|
||||||
|
local index = torch.LongTensor(instance_loss:size(1))
|
||||||
|
local dummy = torch.Tensor(instance_loss:size(1))
|
||||||
|
torch.topk(dummy, index, instance_loss, k, 1, true)
|
||||||
|
print("average loss: " ..instance_loss:mean() .. ", average oracle loss: " .. dummy:mean())
|
||||||
|
local shuffle = torch.randperm(k)
|
||||||
|
local x_s = x:size()
|
||||||
|
local y_s = y:size()
|
||||||
|
x_s[1] = samples
|
||||||
|
y_s[1] = samples
|
||||||
|
local oracle_x = torch.Tensor(table.unpack(torch.totable(x_s)))
|
||||||
|
local oracle_y = torch.Tensor(table.unpack(torch.totable(y_s)))
|
||||||
|
|
||||||
|
for i = 1, samples do
|
||||||
|
oracle_x[i]:copy(x[index[shuffle[i]]])
|
||||||
|
oracle_y[i]:copy(y[index[shuffle[i]]])
|
||||||
|
end
|
||||||
|
return oracle_x, oracle_y
|
||||||
|
end
|
||||||
|
|
||||||
local function remove_small_image(x)
|
local function remove_small_image(x)
|
||||||
local new_x = {}
|
local new_x = {}
|
||||||
for i = 1, #x do
|
for i = 1, #x do
|
||||||
|
@ -254,12 +282,33 @@ local function train()
|
||||||
x = torch.Tensor(settings.patches * #train_x,
|
x = torch.Tensor(settings.patches * #train_x,
|
||||||
ch, settings.crop_size, settings.crop_size)
|
ch, settings.crop_size, settings.crop_size)
|
||||||
end
|
end
|
||||||
|
local instance_loss = nil
|
||||||
|
|
||||||
for epoch = 1, settings.epoch do
|
for epoch = 1, settings.epoch do
|
||||||
model:training()
|
model:training()
|
||||||
print("# " .. epoch)
|
print("# " .. epoch)
|
||||||
resampling(x, y, train_x, pairwise_func)
|
print("## resampling")
|
||||||
|
if instance_loss then
|
||||||
|
-- active learning
|
||||||
|
local oracle_k = math.min(x:size(1) * (settings.oracle_rate * (1 / (1 - settings.oracle_drop_rate))), x:size(1))
|
||||||
|
local oracle_n = math.min(x:size(1) * settings.oracle_rate, x:size(1))
|
||||||
|
if oracle_n > 0 then
|
||||||
|
local oracle_x, oracle_y = get_oracle_data(x, y, instance_loss, oracle_k, oracle_n)
|
||||||
|
resampling(x, y, train_x, pairwise_func)
|
||||||
|
x:narrow(1, 1, oracle_x:size(1)):copy(oracle_x)
|
||||||
|
y:narrow(1, 1, oracle_y:size(1)):copy(oracle_y)
|
||||||
|
else
|
||||||
|
resampling(x, y, train_x, pairwise_func)
|
||||||
|
end
|
||||||
|
else
|
||||||
|
resampling(x, y, train_x, pairwise_func)
|
||||||
|
end
|
||||||
|
collectgarbage()
|
||||||
|
instance_loss = torch.Tensor(x:size(1)):zero()
|
||||||
|
|
||||||
for i = 1, settings.inner_epoch do
|
for i = 1, settings.inner_epoch do
|
||||||
local train_score = minibatch_adam(model, criterion, eval_metric, x, y, adam_config)
|
local train_score, il = minibatch_adam(model, criterion, eval_metric, x, y, adam_config)
|
||||||
|
instance_loss:copy(il)
|
||||||
print(train_score)
|
print(train_score)
|
||||||
model:evaluate()
|
model:evaluate()
|
||||||
print("# validation")
|
print("# validation")
|
||||||
|
|
Loading…
Reference in a new issue