From 81df729a8a5f1db8dbde5f2515ef910817d6d230 Mon Sep 17 00:00:00 2001 From: nagadomi Date: Sat, 9 Jul 2016 15:05:11 +0900 Subject: [PATCH] Remove -loss option --- lib/settings.lua | 1 - train.lua | 14 +++++--------- 2 files changed, 5 insertions(+), 10 deletions(-) diff --git a/lib/settings.lua b/lib/settings.lua index 0aa98ef..1c66cb3 100644 --- a/lib/settings.lua +++ b/lib/settings.lua @@ -57,7 +57,6 @@ cmd:option("-resize_blur_max", 1.05, 'max blur parameter for ResizeImage') cmd:option("-oracle_rate", 0.1, '') cmd:option("-oracle_drop_rate", 0.5, '') cmd:option("-learning_rate_decay", 3.0e-7, 'learning rate decay (learning_rate * 1/(1+num_of_data*patches*epoch))') -cmd:option("-loss", "y", 'loss (rgb|y)') cmd:option("-resume", "", 'resume model file') cmd:option("-name", "user", 'model name for user method') diff --git a/train.lua b/train.lua index bd89d58..6974cca 100644 --- a/train.lua +++ b/train.lua @@ -101,18 +101,14 @@ local function validate(model, criterion, eval_metric, data, batch_size) return {loss = loss / loss_count, MSE = mse / loss_count, PSNR = 10 * math.log10(1 / (mse / loss_count))} end -local function create_criterion(model, loss) +local function create_criterion(model) if reconstruct.is_rgb(model) then local offset = reconstruct.offset_size(model) local output_w = settings.crop_size - offset * 2 local weight = torch.Tensor(3, output_w * output_w) - if loss == "y" then - weight[1]:fill(0.29891 * 3) -- R - weight[2]:fill(0.58661 * 3) -- G - weight[3]:fill(0.11448 * 3) -- B - else - weight:fill(1) - end + weight[1]:fill(0.29891 * 3) -- R + weight[2]:fill(0.58661 * 3) -- G + weight[3]:fill(0.11448 * 3) -- B return w2nn.ClippedWeightedHuberCriterion(weight, 0.1, {0.0, 1.0}):cuda() else local offset = reconstruct.offset_size(model) @@ -309,7 +305,7 @@ local function train() local pairwise_func = function(x, is_validation, n) return transformer(model, x, is_validation, n, offset) end - local criterion = create_criterion(model, settings.loss) + local criterion = create_criterion(model) local eval_metric = w2nn.ClippedMSECriterion(0, 1):cuda() local x = remove_small_image(torch.load(settings.images)) local train_x, valid_x = split_data(x, math.max(math.floor(settings.validation_rate * #x), 1))