-- waifu2x/lib/minibatch_adam.lua

require 'optim'
require 'cutorch'
require 'xlua'
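
-- Run one epoch of Adam over (train_x, train_y) in shuffled mini-batches.
-- Returns epoch statistics (loss, MSE, PSNR) and a per-sample loss tensor.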
local function minibatch_adam(model, criterion, eval_metric,
                              train_x, train_y,
                              config)
   local parameters, gradParameters = model:getParameters()
   config = config or {}
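   -- first call: initialize the sample counter and seed the learning rate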
   if config.xEvalCount == nil then
      config.xEvalCount = 0
      config.learningRate = config.xLearningRate
   end
   local sum_loss = 0
   local sum_eval = 0
   local count_loss = 0
   local batch_size = config.xBatchSize or 32
   local shuffle = torch.randperm(train_x:size(1))
   local c = 1
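   -- CPU staging buffers; each batch is assembled here, then copied to the GPU in one transfer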
   local inputs_tmp = torch.Tensor(batch_size,
                                   train_x:size(2), train_x:size(3), train_x:size(4)):zero()
   local targets_tmp = torch.Tensor(batch_size,
                                    train_y:size(2)):zero()
   local inputs = inputs_tmp:clone():cuda()
   local targets = targets_tmp:clone():cuda()
   local instance_loss = torch.Tensor(train_x:size(1)):zero()
print("## update")
for t = 1, train_x:size(1), batch_size do
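      -- drop the final partial batch so every batch holds exactly batch_size samples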
      if t + batch_size - 1 > train_x:size(1) then
         break
      end
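      -- gather a shuffled mini-batch into the CPU buffers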
      for i = 1, batch_size do
         inputs_tmp[i]:copy(train_x[shuffle[t + i - 1]])
         targets_tmp[i]:copy(train_y[shuffle[t + i - 1]])
      end
      inputs:copy(inputs_tmp)
      targets:copy(targets_tmp)
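      -- closure for optim.adam: returns the loss and the gradient at x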
      local feval = function(x)
         if x ~= parameters then
            parameters:copy(x)
         end
         gradParameters:zero()
         local output = model:forward(inputs)
         local f = criterion:forward(output, targets)
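         -- per-sample evaluation metric, also recorded per training instance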
         local se = 0
         for i = 1, batch_size do
            local el = eval_metric:forward(output[i], targets[i])
            se = se + el
            instance_loss[shuffle[t + i - 1]] = el
         end
         sum_eval = sum_eval + (se / batch_size)
         sum_loss = sum_loss + f
         count_loss = count_loss + 1
         model:backward(inputs, criterion:backward(output, targets))
         return f, gradParameters
      end
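      -- one Adam step, then anneal the learning rate: lr = lr0 / (1 + n * decay)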
      optim.adam(feval, parameters, config)
      config.xEvalCount = config.xEvalCount + batch_size
      config.learningRate = config.xLearningRate / (1 + config.xEvalCount * config.xLearningRateDecay)
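      -- every 50 batches, reclaim memory and refresh the progress bar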
      c = c + 1
      if c % 50 == 0 then
         collectgarbage()
         xlua.progress(t, train_x:size(1))
      end
   end
   xlua.progress(train_x:size(1), train_x:size(1))
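   -- PSNR here assumes signal values in [0, 1], i.e. 10 * log10(1 / MSE)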
   return { loss = sum_loss / count_loss,
            MSE = sum_eval / count_loss,
            PSNR = 10 * math.log10(1 / (sum_eval / count_loss)) }, instance_loss
end
return minibatch_adam
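
-- A minimal usage sketch (not part of the original file): it assumes a CUDA
-- model, MSE objectives, and hypothetical train_x/train_y tensors already on
-- the CPU; the x-prefixed config keys are the ones this function reads.
--
--   local minibatch_adam = require 'minibatch_adam'
--   local criterion = nn.MSECriterion():cuda()
--   local eval_metric = nn.MSECriterion():cuda()
--   local config = { xLearningRate = 0.001,
--                    xLearningRateDecay = 1e-7,
--                    xBatchSize = 32 }
--   local score, instance_loss = minibatch_adam(model, criterion, eval_metric,
--                                               train_x, train_y, config)
--   print(score.loss, score.MSE, score.PSNR)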