Use correct criterion
This commit is contained in:
parent
99e6dd1a57
commit
b96bc5d453
1 changed files with 11 additions and 9 deletions
20
train.lua
20
train.lua
|
@ -58,8 +58,9 @@ local function make_validation_set(x, transformer, n, patches)
|
||||||
data = new_data
|
data = new_data
|
||||||
return data
|
return data
|
||||||
end
|
end
|
||||||
local function validate(model, criterion, data, batch_size)
|
local function validate(model, criterion, eval_metric, data, batch_size)
|
||||||
local loss = 0
|
local loss = 0
|
||||||
|
local mse = 0
|
||||||
local loss_count = 0
|
local loss_count = 0
|
||||||
local inputs_tmp = torch.Tensor(batch_size,
|
local inputs_tmp = torch.Tensor(batch_size,
|
||||||
data[1].x:size(1),
|
data[1].x:size(1),
|
||||||
|
@ -83,6 +84,7 @@ local function validate(model, criterion, data, batch_size)
|
||||||
targets:copy(targets_tmp)
|
targets:copy(targets_tmp)
|
||||||
local z = model:forward(inputs)
|
local z = model:forward(inputs)
|
||||||
loss = loss + criterion:forward(z, targets)
|
loss = loss + criterion:forward(z, targets)
|
||||||
|
mse = mse + eval_metric:forward(z, targets)
|
||||||
loss_count = loss_count + 1
|
loss_count = loss_count + 1
|
||||||
if loss_count % 10 == 0 then
|
if loss_count % 10 == 0 then
|
||||||
xlua.progress(t, #data)
|
xlua.progress(t, #data)
|
||||||
|
@ -90,7 +92,7 @@ local function validate(model, criterion, data, batch_size)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
xlua.progress(#data, #data)
|
xlua.progress(#data, #data)
|
||||||
return loss / loss_count
|
return {loss = loss / loss_count, MSE = mse / loss_count, PSNR = 10 * math.log10(1 / (mse / loss_count))}
|
||||||
end
|
end
|
||||||
|
|
||||||
local function create_criterion(model)
|
local function create_criterion(model)
|
||||||
|
@ -247,7 +249,7 @@ local function train()
|
||||||
return transformer(model, x, is_validation, n, offset)
|
return transformer(model, x, is_validation, n, offset)
|
||||||
end
|
end
|
||||||
local criterion = create_criterion(model)
|
local criterion = create_criterion(model)
|
||||||
local eval_metric = nn.MSECriterion():cuda()
|
local eval_metric = w2nn.ClippedMSECriterion(0, 1):cuda()
|
||||||
local x = remove_small_image(torch.load(settings.images))
|
local x = remove_small_image(torch.load(settings.images))
|
||||||
local train_x, valid_x = split_data(x, math.max(math.floor(settings.validation_rate * #x), 1))
|
local train_x, valid_x = split_data(x, math.max(math.floor(settings.validation_rate * #x), 1))
|
||||||
local adam_config = {
|
local adam_config = {
|
||||||
|
@ -312,16 +314,16 @@ local function train()
|
||||||
print(train_score)
|
print(train_score)
|
||||||
model:evaluate()
|
model:evaluate()
|
||||||
print("# validation")
|
print("# validation")
|
||||||
local score = validate(model, eval_metric, valid_xy, adam_config.xBatchSize)
|
local score = validate(model, criterion, eval_metric, valid_xy, adam_config.xBatchSize)
|
||||||
table.insert(hist_train, train_score.MSE)
|
table.insert(hist_train, train_score.loss)
|
||||||
table.insert(hist_valid, score)
|
table.insert(hist_valid, score.loss)
|
||||||
if settings.plot then
|
if settings.plot then
|
||||||
plot(hist_train, hist_valid)
|
plot(hist_train, hist_valid)
|
||||||
end
|
end
|
||||||
if score < best_score then
|
if score.loss < best_score then
|
||||||
local test_image = image_loader.load_float(settings.test) -- reload
|
local test_image = image_loader.load_float(settings.test) -- reload
|
||||||
lrd_count = 0
|
lrd_count = 0
|
||||||
best_score = score
|
best_score = score.loss
|
||||||
print("* update best model")
|
print("* update best model")
|
||||||
if settings.save_history then
|
if settings.save_history then
|
||||||
torch.save(string.format(settings.model_file, epoch, i), model:clearState(), "ascii")
|
torch.save(string.format(settings.model_file, epoch, i), model:clearState(), "ascii")
|
||||||
|
@ -356,7 +358,7 @@ local function train()
|
||||||
lrd_count = 0
|
lrd_count = 0
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
print("PSNR: " .. 10 * math.log10(1 / score) .. ", MSE: " .. score .. ", Best MSE: " .. best_score)
|
print("PSNR: " .. score.PSNR .. ", loss: " .. score.loss .. ", Minimum loss: " .. best_score)
|
||||||
collectgarbage()
|
collectgarbage()
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
Loading…
Reference in a new issue