1
0
Fork 0
mirror of synced 2024-06-22 04:40:15 +12:00

Remove -loss option

This commit is contained in:
nagadomi 2016-07-09 15:05:11 +09:00
parent 83e5d885f5
commit 81df729a8a
2 changed files with 5 additions and 10 deletions

View file

@ -57,7 +57,6 @@ cmd:option("-resize_blur_max", 1.05, 'max blur parameter for ResizeImage')
cmd:option("-oracle_rate", 0.1, '')
cmd:option("-oracle_drop_rate", 0.5, '')
cmd:option("-learning_rate_decay", 3.0e-7, 'learning rate decay (learning_rate * 1/(1+num_of_data*patches*epoch))')
cmd:option("-loss", "y", 'loss (rgb|y)')
cmd:option("-resume", "", 'resume model file')
cmd:option("-name", "user", 'model name for user method')

View file

@ -101,18 +101,14 @@ local function validate(model, criterion, eval_metric, data, batch_size)
return {loss = loss / loss_count, MSE = mse / loss_count, PSNR = 10 * math.log10(1 / (mse / loss_count))}
end
local function create_criterion(model, loss)
local function create_criterion(model)
if reconstruct.is_rgb(model) then
local offset = reconstruct.offset_size(model)
local output_w = settings.crop_size - offset * 2
local weight = torch.Tensor(3, output_w * output_w)
if loss == "y" then
weight[1]:fill(0.29891 * 3) -- R
weight[2]:fill(0.58661 * 3) -- G
weight[3]:fill(0.11448 * 3) -- B
else
weight:fill(1)
end
weight[1]:fill(0.29891 * 3) -- R
weight[2]:fill(0.58661 * 3) -- G
weight[3]:fill(0.11448 * 3) -- B
return w2nn.ClippedWeightedHuberCriterion(weight, 0.1, {0.0, 1.0}):cuda()
else
local offset = reconstruct.offset_size(model)
@ -309,7 +305,7 @@ local function train()
local pairwise_func = function(x, is_validation, n)
return transformer(model, x, is_validation, n, offset)
end
local criterion = create_criterion(model, settings.loss)
local criterion = create_criterion(model)
local eval_metric = w2nn.ClippedMSECriterion(0, 1):cuda()
local x = remove_small_image(torch.load(settings.images))
local train_x, valid_x = split_data(x, math.max(math.floor(settings.validation_rate * #x), 1))