diff --git a/train.lua b/train.lua index 0a7416b..801e3e1 100644 --- a/train.lua +++ b/train.lua @@ -58,8 +58,9 @@ local function make_validation_set(x, transformer, n, patches) data = new_data return data end -local function validate(model, criterion, data, batch_size) +local function validate(model, criterion, eval_metric, data, batch_size) local loss = 0 + local mse = 0 local loss_count = 0 local inputs_tmp = torch.Tensor(batch_size, data[1].x:size(1), @@ -83,6 +84,7 @@ local function validate(model, criterion, data, batch_size) targets:copy(targets_tmp) local z = model:forward(inputs) loss = loss + criterion:forward(z, targets) + mse = mse + eval_metric:forward(z, targets) loss_count = loss_count + 1 if loss_count % 10 == 0 then xlua.progress(t, #data) @@ -90,7 +92,7 @@ local function validate(model, criterion, data, batch_size) end end xlua.progress(#data, #data) - return loss / loss_count + return {loss = loss / loss_count, MSE = mse / loss_count, PSNR = 10 * math.log10(1 / (mse / loss_count))} end local function create_criterion(model) @@ -247,7 +249,7 @@ local function train() return transformer(model, x, is_validation, n, offset) end local criterion = create_criterion(model) - local eval_metric = nn.MSECriterion():cuda() + local eval_metric = w2nn.ClippedMSECriterion(0, 1):cuda() local x = remove_small_image(torch.load(settings.images)) local train_x, valid_x = split_data(x, math.max(math.floor(settings.validation_rate * #x), 1)) local adam_config = { @@ -312,16 +314,16 @@ local function train() print(train_score) model:evaluate() print("# validation") - local score = validate(model, eval_metric, valid_xy, adam_config.xBatchSize) - table.insert(hist_train, train_score.MSE) - table.insert(hist_valid, score) + local score = validate(model, criterion, eval_metric, valid_xy, adam_config.xBatchSize) + table.insert(hist_train, train_score.loss) + table.insert(hist_valid, score.loss) if settings.plot then plot(hist_train, hist_valid) end - if score < best_score then + if score.loss < best_score then local test_image = image_loader.load_float(settings.test) -- reload lrd_count = 0 - best_score = score + best_score = score.loss print("* update best model") if settings.save_history then torch.save(string.format(settings.model_file, epoch, i), model:clearState(), "ascii") @@ -356,7 +358,7 @@ local function train() lrd_count = 0 end end - print("PSNR: " .. 10 * math.log10(1 / score) .. ", MSE: " .. score .. ", Best MSE: " .. best_score) + print("PSNR: " .. score.PSNR .. ", loss: " .. score.loss .. ", Minimum loss: " .. best_score) collectgarbage() end end