waifu2x/lib/minibatch_adam.lua
nagadomi 490eb33a6b Minimize the weighted Huber loss instead of the weighted mean squared error
Huber loss is less sensitive to outliers (i.e. noise) in the data than the squared error loss.
2015-10-31 22:05:59 +09:00
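
For reference, a minimal Lua sketch of the (unweighted) per-element Huber loss the commit message refers to; the `huber` helper and `delta` threshold are illustrative assumptions, not code from this repository:

-- Per-element Huber loss: quadratic near zero (like squared error),
-- linear in the tails, so large residuals (outliers) are penalized less.
local function huber(residual, delta)
   local r = math.abs(residual)
   if r <= delta then
      return 0.5 * r * r
   else
      return delta * (r - 0.5 * delta)
   end
end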


require 'optim'
require 'cutorch'
require 'xlua'

-- Runs one pass of Adam over train_x in minibatches and returns the mean loss.
-- `transformer` expands a training sample into a batch of {input, target} pairs;
-- `input_size`/`target_size` give the per-sample tensor shapes.
local function minibatch_adam(model, criterion,
                              train_x,
                              config, transformer,
                              input_size, target_size)
   local parameters, gradParameters = model:getParameters()
   config = config or {}
   local sum_loss = 0
   local count_loss = 0
   local batch_size = config.xBatchSize or 32
   local shuffle = torch.randperm(#train_x)
   local c = 1
   -- GPU buffers for the current minibatch; targets are flattened to vectors.
   local inputs = torch.Tensor(batch_size,
                               input_size[1], input_size[2], input_size[3]):cuda()
   local targets = torch.Tensor(batch_size,
                                target_size[1] * target_size[2] * target_size[3]):cuda()
   -- CPU staging buffers, filled sample by sample and copied to the GPU in one transfer.
   local inputs_tmp = torch.Tensor(batch_size,
                                   input_size[1], input_size[2], input_size[3])
   local targets_tmp = torch.Tensor(batch_size,
                                    target_size[1] * target_size[2] * target_size[3])
   for t = 1, #train_x do
      xlua.progress(t, #train_x)
      -- Build one minibatch from a (shuffled) training sample.
      local xy = transformer(train_x[shuffle[t]], false, batch_size)
      for i = 1, #xy do
         inputs_tmp[i]:copy(xy[i][1])
         targets_tmp[i]:copy(xy[i][2])
      end
      inputs:copy(inputs_tmp)
      targets:copy(targets_tmp)
      -- Closure evaluated by optim.adam: returns the loss and the gradient
      -- with respect to the flattened model parameters.
      local feval = function(x)
         if x ~= parameters then
            parameters:copy(x)
         end
         gradParameters:zero()
         local output = model:forward(inputs)
         local f = criterion:forward(output, targets)
         sum_loss = sum_loss + f
         count_loss = count_loss + 1
         model:backward(inputs, criterion:backward(output, targets))
         return f, gradParameters
      end
      optim.adam(feval, parameters, config)
      c = c + 1
      if c % 10 == 0 then
         collectgarbage()
      end
   end
   xlua.progress(#train_x, #train_x)
   return { loss = sum_loss / count_loss }
end
return minibatch_adam
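
-- A hypothetical usage sketch: `model`, `criterion`, `transformer`, and the
-- size/config values below are illustrative assumptions, not defined in this file.
--
--   local minibatch_adam = require 'minibatch_adam'
--   local config = { learningRate = 0.0002, xBatchSize = 32 }
--   local result = minibatch_adam(model, criterion, train_x,
--                                 config, transformer,
--                                 {1, 32, 32},  -- per-sample input size {channels, height, width}
--                                 {1, 32, 32})  -- per-sample target size (flattened internally)
--   print("mean minibatch loss: " .. result.loss)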