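-- Minibatch training loop for Torch models using the Adam optimizer.
-- Stages each minibatch in CPU tensors, uploads it to the GPU, and runs
-- one optim.adam step per minibatch; returns the mean loss over the pass.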
require 'optim'
require 'cutorch'
require 'xlua'
local function minibatch_adam(model, criterion,
                              train_x,
                              config, transformer,
                              input_size, target_size)
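   -- Flatten the model's weights and gradients into single contiguous
   -- vectors, as required by the optim package.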
   local parameters, gradParameters = model:getParameters()
   config = config or {}
   local sum_loss = 0
   local count_loss = 0
   local batch_size = config.xBatchSize or 32
   local shuffle = torch.randperm(#train_x)
   local c = 1
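   -- Preallocate one minibatch worth of buffers: CPU staging tensors plus
   -- their CUDA counterparts, so no tensor is reallocated inside the loop.
   -- Targets are stored flattened (target_size[1] * target_size[2] * target_size[3]).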
   local inputs = torch.Tensor(batch_size,
                               input_size[1], input_size[2], input_size[3]):cuda()
   local targets = torch.Tensor(batch_size,
                                target_size[1] * target_size[2] * target_size[3]):cuda()
   local inputs_tmp = torch.Tensor(batch_size,
                                   input_size[1], input_size[2], input_size[3])
   local targets_tmp = torch.Tensor(batch_size,
                                    target_size[1] * target_size[2] * target_size[3])
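   -- One pass over the shuffled training set; each element of train_x is
   -- expanded into one minibatch by the caller-supplied transformer.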
   for t = 1, #train_x do
      xlua.progress(t, #train_x)
      local xy = transformer(train_x[shuffle[t]], false, batch_size)
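      -- xy is assumed to hold batch_size {input, target} pairs; stage them
      -- on the CPU, then upload to the GPU with a single copy each.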
      for i = 1, #xy do
         inputs_tmp[i]:copy(xy[i][1])
         targets_tmp[i]:copy(xy[i][2])
      end
      inputs:copy(inputs_tmp)
      targets:copy(targets_tmp)
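      -- Closure for optim: evaluates the criterion loss and the gradient of
      -- the loss with respect to the flattened parameters at x.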
      local feval = function(x)
         if x ~= parameters then
            parameters:copy(x)
         end
         gradParameters:zero()
         local output = model:forward(inputs)
         local f = criterion:forward(output, targets)
         sum_loss = sum_loss + f
         count_loss = count_loss + 1
         model:backward(inputs, criterion:backward(output, targets))
         return f, gradParameters
      end
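      -- One Adam update; optim.adam modifies parameters in place and keeps
      -- its running moment estimates in the config/state table.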
      optim.adam(feval, parameters, config)
      c = c + 1
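      -- Collect garbage periodically to free tensors allocated in earlier
      -- iterations; Lua's GC is not aware of GPU memory pressure.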
      if c % 20 == 0 then
         collectgarbage()
      end
   end
   xlua.progress(#train_x, #train_x)
   return { loss = sum_loss / count_loss }
end
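
-- A minimal usage sketch. `model`, `criterion`, `transformer`, the require
-- path, and the patch sizes below are illustrative assumptions, not part of
-- this module; the transformer is assumed to return batch_size
-- {input, target} pairs per sample.
--
--   local minibatch_adam = require 'minibatch_adam'
--   local config = {learningRate = 0.00025, xBatchSize = 32}
--   for epoch = 1, 30 do
--      local result = minibatch_adam(model, criterion, train_x, config,
--                                    transformer,
--                                    {1, 33, 33}, {1, 33, 33})
--      print(("epoch %d: loss %f"):format(epoch, result.loss))
--   end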

return minibatch_adam