diff --git a/lib/settings.lua b/lib/settings.lua index f8db26c..f7a37ba 100644 --- a/lib/settings.lua +++ b/lib/settings.lua @@ -56,6 +56,8 @@ cmd:option("-max_training_image_size", -1, 'if training image is larger than N, cmd:option("-use_transparent_png", 0, 'use transparent png (0|1)') cmd:option("-resize_blur_min", 0.85, 'min blur parameter for ResizeImage') cmd:option("-resize_blur_max", 1.05, 'max blur parameter for ResizeImage') +cmd:option("-oracle_rate", 0.0, '') +cmd:option("-oracle_drop_rate", 0.5, '') local function to_bool(settings, name) if settings[name] == 1 then diff --git a/train.lua b/train.lua index c6694b5..0a7416b 100644 --- a/train.lua +++ b/train.lua @@ -175,20 +175,48 @@ local function transformer(model, x, is_validation, n, offset) end local function resampling(x, y, train_x, transformer, input_size, target_size) - print("## resampling") + local c = 1 + local shuffle = torch.randperm(#train_x) for t = 1, #train_x do xlua.progress(t, #train_x) - local xy = transformer(train_x[t], false, settings.patches) + local xy = transformer(train_x[shuffle[t]], false, settings.patches) for i = 1, #xy do - local index = (t - 1) * settings.patches + i - x[index]:copy(xy[i][1]) - y[index]:copy(xy[i][2]) + x[c]:copy(xy[i][1]) + y[c]:copy(xy[i][2]) + c = c + 1 + if c > x:size(1) then + break + end + end + if c > x:size(1) then + break end if t % 50 == 0 then collectgarbage() end end + xlua.progress(#train_x, #train_x) end +local function get_oracle_data(x, y, instance_loss, k, samples) + local index = torch.LongTensor(instance_loss:size(1)) + local dummy = torch.Tensor(instance_loss:size(1)) + torch.topk(dummy, index, instance_loss, k, 1, true) + print("average loss: " ..instance_loss:mean() .. ", average oracle loss: " .. dummy:mean()) + local shuffle = torch.randperm(k) + local x_s = x:size() + local y_s = y:size() + x_s[1] = samples + y_s[1] = samples + local oracle_x = torch.Tensor(table.unpack(torch.totable(x_s))) + local oracle_y = torch.Tensor(table.unpack(torch.totable(y_s))) + + for i = 1, samples do + oracle_x[i]:copy(x[index[shuffle[i]]]) + oracle_y[i]:copy(y[index[shuffle[i]]]) + end + return oracle_x, oracle_y +end + local function remove_small_image(x) local new_x = {} for i = 1, #x do @@ -254,12 +282,33 @@ local function train() x = torch.Tensor(settings.patches * #train_x, ch, settings.crop_size, settings.crop_size) end + local instance_loss = nil + for epoch = 1, settings.epoch do model:training() print("# " .. epoch) - resampling(x, y, train_x, pairwise_func) + print("## resampling") + if instance_loss then + -- active learning + local oracle_k = math.min(x:size(1) * (settings.oracle_rate * (1 / (1 - settings.oracle_drop_rate))), x:size(1)) + local oracle_n = math.min(x:size(1) * settings.oracle_rate, x:size(1)) + if oracle_n > 0 then + local oracle_x, oracle_y = get_oracle_data(x, y, instance_loss, oracle_k, oracle_n) + resampling(x, y, train_x, pairwise_func) + x:narrow(1, 1, oracle_x:size(1)):copy(oracle_x) + y:narrow(1, 1, oracle_y:size(1)):copy(oracle_y) + else + resampling(x, y, train_x, pairwise_func) + end + else + resampling(x, y, train_x, pairwise_func) + end + collectgarbage() + instance_loss = torch.Tensor(x:size(1)):zero() + for i = 1, settings.inner_epoch do - local train_score = minibatch_adam(model, criterion, eval_metric, x, y, adam_config) + local train_score, il = minibatch_adam(model, criterion, eval_metric, x, y, adam_config) + instance_loss:copy(il) print(train_score) model:evaluate() print("# validation")