1
0
Fork 0
mirror of synced 2024-05-23 14:19:38 +12:00

Add support for plotting loss chart

This commit is contained in:
nagadomi 2016-03-14 05:06:14 +09:00
parent 459c7c5e18
commit 4d115e4bdb
2 changed files with 22 additions and 2 deletions

View file

@ -47,11 +47,18 @@ cmd:option("-active_cropping_rate", 0.5, 'active cropping rate')
cmd:option("-active_cropping_tries", 10, 'active cropping tries')
cmd:option("-nr_rate", 0.75, 'trade-off between reducing noise and erasing details (0.0-1.0)')
cmd:option("-save_history", 0, 'save all model (0|1)')
cmd:option("-plot", 0, 'plot loss chart(0|1)')
local opt = cmd:parse(arg)
for k, v in pairs(opt) do
settings[k] = v
end
if settings.plot == 1 then
settings.plot = true
require 'gnuplot'
else
settings.plot = false
end
if settings.save_history == 1 then
settings.save_history = true
else

View file

@ -157,8 +157,14 @@ local function resampling(x, y, train_x, transformer, input_size, target_size)
end
end
end
local function plot(train, valid)
gnuplot.plot({
{'training', torch.Tensor(train), '-'},
{'validation', torch.Tensor(valid), '-'}})
end
local function train()
local hist_train = {}
local hist_valid = {}
local LR_MIN = 1.0e-5
local model = srcnn.create(settings.method, settings.backend, settings.color)
local offset = reconstruct.offset_size(model)
@ -201,10 +207,17 @@ local function train()
print("# " .. epoch)
resampling(x, y, train_x, pairwise_func)
for i = 1, settings.inner_epoch do
print(minibatch_adam(model, criterion, eval_metric, x, y, adam_config))
local train_score = minibatch_adam(model, criterion, eval_metric, x, y, adam_config)
print(train_score)
model:evaluate()
print("# validation")
local score = validate(model, eval_metric, valid_xy)
table.insert(hist_train, train_score.PSNR)
table.insert(hist_valid, score)
if settings.plot then
plot(hist_train, hist_valid)
end
if score > best_score then
local test_image = image_loader.load_float(settings.test) -- reload
lrd_count = 0