waifu2x/lib/minibatch_adam.lua
nagadomi 8dea362bed: sync from internal repo
- Memory compression by snappy (lua-csnappy)
- Use RGB-wise weighted MSE (R*0.299, G*0.587, B*0.114) instead of plain MSE (a sketch of the idea follows below)
- Aggressive cropping for edge regions
and some other changes.
2015-10-26 09:23:52 +09:00
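
The RGB-wise weighting mentioned above scales each channel's squared error by the luma coefficients before averaging. Below is a minimal sketch of the idea in Torch, not the repository's actual criterion; the function name and the assumption of 3xHxW tensors are illustrative.

-- Minimal sketch of RGB-wise weighted MSE (not the repo's actual criterion).
-- Assumes `output` and `target` are 3xHxW RGB tensors.
local function weighted_mse(output, target)
   local weights = {0.299, 0.587, 0.114}  -- R, G, B luma weights
   local loss = 0
   for ch = 1, 3 do
      local diff = (output[ch] - target[ch]):pow(2)  -- per-pixel squared error
      loss = loss + weights[ch] * diff:mean()
   end
   return loss
end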

require 'optim'
require 'cutorch'
require 'xlua'

-- Runs one pass of Adam over the shuffled training set, one minibatch per
-- training entry. `transformer` turns a training sample into `batch_size`
-- (input, target) pairs; targets are stored flattened, one vector per sample.
local function minibatch_adam(model, criterion,
                              train_x,
                              config, transformer,
                              input_size, target_size)
   local parameters, gradParameters = model:getParameters()
   config = config or {}
   local sum_loss = 0
   local count_loss = 0
   local batch_size = config.xBatchSize or 32
   local shuffle = torch.randperm(#train_x)
   local c = 1
   -- GPU tensors fed to the model, plus CPU staging buffers.
   local inputs = torch.Tensor(batch_size,
                               input_size[1], input_size[2], input_size[3]):cuda()
   local targets = torch.Tensor(batch_size,
                                target_size[1] * target_size[2] * target_size[3]):cuda()
   local inputs_tmp = torch.Tensor(batch_size,
                                   input_size[1], input_size[2], input_size[3])
   local targets_tmp = torch.Tensor(batch_size,
                                    target_size[1] * target_size[2] * target_size[3])
   for t = 1, #train_x do
      xlua.progress(t, #train_x)
      -- Build the minibatch on the CPU, then copy it to the GPU in one shot.
      local xy = transformer(train_x[shuffle[t]], false, batch_size)
      for i = 1, #xy do
         inputs_tmp[i]:copy(xy[i][1])
         targets_tmp[i]:copy(xy[i][2])
      end
      inputs:copy(inputs_tmp)
      targets:copy(targets_tmp)
      -- Closure evaluated by optim.adam: returns the loss and the gradient
      -- of the loss with respect to the model parameters.
      local feval = function(x)
         if x ~= parameters then
            parameters:copy(x)
         end
         gradParameters:zero()
         local output = model:forward(inputs)
         local f = criterion:forward(output, targets)
         sum_loss = sum_loss + f
         count_loss = count_loss + 1
         model:backward(inputs, criterion:backward(output, targets))
         return f, gradParameters
      end
      optim.adam(feval, parameters, config)
      c = c + 1
      if c % 10 == 0 then
         collectgarbage()
      end
   end
   xlua.progress(#train_x, #train_x)
   -- Average loss over the pass (named `mse` to match the criterion in use).
   return { mse = sum_loss / count_loss }
end
return minibatch_adam
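
For context, a hypothetical call site might look like the following. The model, criterion, training set, transformer, patch sizes, and learning rate are all assumptions for illustration, not values taken from this file; `config` carries both the hyperparameters and the optimizer state that optim.adam maintains between calls.

-- Hypothetical usage sketch; model, criterion, train_x and pairwise_transform
-- are assumed to be defined elsewhere, and the patch sizes are illustrative.
local minibatch_adam = require 'minibatch_adam'

local config = { xBatchSize = 32, learningRate = 0.00025 }
for epoch = 1, 30 do
   local score = minibatch_adam(model, criterion, train_x,
                                config, pairwise_transform,
                                {3, 33, 33},   -- input patch:  3 x 33 x 33
                                {3, 29, 29})   -- target patch: 3 x 29 x 29 (flattened internally)
   print(string.format("epoch %d: loss %.6f", epoch, score.mse))
end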