-- waifu2x/lib/minibatch_adam.lua
-- (mirror snapshot synced 2024-06-26 18:20:26 +12:00)
require 'optim'
require 'cutorch'
require 'xlua'
-- Run one pass of Adam over (train_x, train_y) in shuffled mini-batches.
--
-- model       : nn module; its parameters are updated in place via optim.adam
-- criterion   : training loss used to drive the gradients
-- eval_metric : evaluation criterion accumulated per batch; the PSNR figure
--               below assumes it is MSE over values in [0, 1] -- TODO confirm
--               against callers
-- train_x     : 4D tensor of samples (N x C x H x W)
-- train_y     : 2D tensor of targets (N x D)
-- config      : optim.adam config table; config.xBatchSize sets the
--               mini-batch size (default 32)
--
-- Returns { loss = ..., MSE = ..., PSNR = ... }, averaged over the batches
-- actually processed.
-- NOTE(review): a trailing partial batch is skipped; if N < batch_size no
-- batch runs at all and the averages come out as NaN (0/0).
local function minibatch_adam(model, criterion, eval_metric,
                              train_x, train_y,
                              config)
   local parameters, gradParameters = model:getParameters()
   config = config or {}
   local sum_loss = 0
   local sum_eval = 0
   local count_loss = 0
   local batch_size = config.xBatchSize or 32
   -- Fresh random visitation order for every call.
   local shuffle = torch.randperm(train_x:size(1))
   local c = 1
   -- CPU staging buffers: each batch is assembled here, then pushed to the
   -- GPU in one bulk copy instead of batch_size small transfers.
   local inputs_tmp = torch.Tensor(batch_size,
                                   train_x:size(2), train_x:size(3), train_x:size(4)):zero()
   local targets_tmp = torch.Tensor(batch_size,
                                    train_y:size(2)):zero()
   local inputs = inputs_tmp:clone():cuda()
   local targets = targets_tmp:clone():cuda()
   -- feval only reads the *current* contents of inputs/targets through
   -- upvalues, so one closure serves every batch; the original allocated a
   -- new closure per iteration for no benefit.
   local feval = function(x)
      if x ~= parameters then
         parameters:copy(x)
      end
      gradParameters:zero()
      local output = model:forward(inputs)
      local f = criterion:forward(output, targets)
      sum_eval = sum_eval + eval_metric:forward(output, targets)
      sum_loss = sum_loss + f
      count_loss = count_loss + 1
      model:backward(inputs, criterion:backward(output, targets))
      return f, gradParameters
   end
   print("## update")
   for t = 1, train_x:size(1), batch_size do
      -- The staging tensors have a fixed first dimension, so a final
      -- partial batch is dropped rather than padded.
      if t + batch_size - 1 > train_x:size(1) then
         break
      end
      for i = 1, batch_size do
         inputs_tmp[i]:copy(train_x[shuffle[t + i - 1]])
         targets_tmp[i]:copy(train_y[shuffle[t + i - 1]])
      end
      inputs:copy(inputs_tmp)
      targets:copy(targets_tmp)
      optim.adam(feval, parameters, config)
      c = c + 1
      -- Throttle GC and progress reporting to every 50th batch.
      if c % 50 == 0 then
         collectgarbage()
         xlua.progress(t, train_x:size(1))
      end
   end
   xlua.progress(train_x:size(1), train_x:size(1))
   -- PSNR = 10*log10(1/MSE), valid only for MSE on [0, 1]-scaled data.
   return { loss = sum_loss / count_loss,
            MSE = sum_eval / count_loss,
            PSNR = 10 * math.log10(1 / (sum_eval / count_loss)) }
end

return minibatch_adam