Use correct criterion

Author: nagadomi
Date:   2016-05-28 10:34:59 +09:00
parent 99e6dd1a57
commit b96bc5d453

@@ -58,8 +58,9 @@ local function make_validation_set(x, transformer, n, patches)
    data = new_data
    return data
 end
-local function validate(model, criterion, data, batch_size)
+local function validate(model, criterion, eval_metric, data, batch_size)
    local loss = 0
+   local mse = 0
    local loss_count = 0
    local inputs_tmp = torch.Tensor(batch_size,
                                    data[1].x:size(1),
@@ -83,6 +84,7 @@ local function validate(model, criterion, data, batch_size)
       targets:copy(targets_tmp)
       local z = model:forward(inputs)
       loss = loss + criterion:forward(z, targets)
+      mse = mse + eval_metric:forward(z, targets)
       loss_count = loss_count + 1
       if loss_count % 10 == 0 then
          xlua.progress(t, #data)
@@ -90,7 +92,7 @@ local function validate(model, criterion, data, batch_size)
       end
    end
    xlua.progress(#data, #data)
-   return loss / loss_count
+   return {loss = loss / loss_count, MSE = mse / loss_count, PSNR = 10 * math.log10(1 / (mse / loss_count))}
 end

 local function create_criterion(model)
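Note on the new return value: validate() now reports a table instead of a single number, carrying the training criterion's loss, the evaluation MSE, and a PSNR derived from that MSE. Since the images are normalized to [0, 1], the peak signal value is 1 and PSNR = 10 * log10(1^2 / MSE) = 10 * log10(1 / MSE), exactly as in the return statement above. A minimal sketch of the same arithmetic (the helper name psnr_from_mse is illustrative, not part of the patch):

    -- PSNR in dB for signals normalized to [0, 1]; the peak value is 1,
    -- so 10 * log10(MAX^2 / MSE) reduces to 10 * log10(1 / MSE).
    local function psnr_from_mse(mse)
       return 10 * math.log10(1 / mse)
    end
    print(psnr_from_mse(0.001)) --> 30 (dB)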
@@ -247,7 +249,7 @@ local function train()
       return transformer(model, x, is_validation, n, offset)
    end
    local criterion = create_criterion(model)
-   local eval_metric = nn.MSECriterion():cuda()
+   local eval_metric = w2nn.ClippedMSECriterion(0, 1):cuda()
    local x = remove_small_image(torch.load(settings.images))
    local train_x, valid_x = split_data(x, math.max(math.floor(settings.validation_rate * #x), 1))
    local adam_config = {
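Here the evaluation metric switches from a plain nn.MSECriterion to w2nn.ClippedMSECriterion(0, 1), i.e. the squared error is measured against predictions clamped to the displayable pixel range [0, 1]. A rough sketch of that idea, assuming a simple clamp-then-MSE forward pass (illustration only, not the actual w2nn implementation):

    require 'torch'

    -- Sketch: clamp predictions to [min_v, max_v], then take the mean
    -- squared error against the target, so out-of-range outputs are
    -- only penalized up to the clipping bounds.
    local function clipped_mse(z, target, min_v, max_v)
       local clipped = z:clone():clamp(min_v, max_v)
       return clipped:add(-1, target):pow(2):mean()
    end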
@@ -312,16 +314,16 @@ local function train()
          print(train_score)
          model:evaluate()
          print("# validation")
-         local score = validate(model, eval_metric, valid_xy, adam_config.xBatchSize)
-         table.insert(hist_train, train_score.MSE)
-         table.insert(hist_valid, score)
+         local score = validate(model, criterion, eval_metric, valid_xy, adam_config.xBatchSize)
+         table.insert(hist_train, train_score.loss)
+         table.insert(hist_valid, score.loss)
          if settings.plot then
             plot(hist_train, hist_valid)
          end
-         if score < best_score then
+         if score.loss < best_score then
             local test_image = image_loader.load_float(settings.test) -- reload
             lrd_count = 0
-            best_score = score
+            best_score = score.loss
             print("* update best model")
             if settings.save_history then
                torch.save(string.format(settings.model_file, epoch, i), model:clearState(), "ascii")
@@ -356,7 +358,7 @@ local function train()
                lrd_count = 0
             end
          end
-         print("PSNR: " .. 10 * math.log10(1 / score) .. ", MSE: " .. score .. ", Best MSE: " .. best_score)
+         print("PSNR: " .. score.PSNR .. ", loss: " .. score.loss .. ", Minimum loss: " .. best_score)
          collectgarbage()
       end
    end