1
0
Fork 0
mirror of synced 2024-06-16 18:04:31 +12:00

Add support for plotting loss chart

This commit is contained in:
nagadomi 2016-03-14 05:06:14 +09:00
parent 459c7c5e18
commit 4d115e4bdb
2 changed files with 22 additions and 2 deletions

View file

@ -47,11 +47,18 @@ cmd:option("-active_cropping_rate", 0.5, 'active cropping rate')
cmd:option("-active_cropping_tries", 10, 'active cropping tries') cmd:option("-active_cropping_tries", 10, 'active cropping tries')
cmd:option("-nr_rate", 0.75, 'trade-off between reducing noise and erasing details (0.0-1.0)') cmd:option("-nr_rate", 0.75, 'trade-off between reducing noise and erasing details (0.0-1.0)')
cmd:option("-save_history", 0, 'save all model (0|1)') cmd:option("-save_history", 0, 'save all model (0|1)')
cmd:option("-plot", 0, 'plot loss chart(0|1)')
local opt = cmd:parse(arg) local opt = cmd:parse(arg)
for k, v in pairs(opt) do for k, v in pairs(opt) do
settings[k] = v settings[k] = v
end end
if settings.plot == 1 then
settings.plot = true
require 'gnuplot'
else
settings.plot = false
end
if settings.save_history == 1 then if settings.save_history == 1 then
settings.save_history = true settings.save_history = true
else else

View file

@ -157,8 +157,14 @@ local function resampling(x, y, train_x, transformer, input_size, target_size)
end end
end end
end end
local function plot(train, valid)
gnuplot.plot({
{'training', torch.Tensor(train), '-'},
{'validation', torch.Tensor(valid), '-'}})
end
local function train() local function train()
local hist_train = {}
local hist_valid = {}
local LR_MIN = 1.0e-5 local LR_MIN = 1.0e-5
local model = srcnn.create(settings.method, settings.backend, settings.color) local model = srcnn.create(settings.method, settings.backend, settings.color)
local offset = reconstruct.offset_size(model) local offset = reconstruct.offset_size(model)
@ -201,10 +207,17 @@ local function train()
print("# " .. epoch) print("# " .. epoch)
resampling(x, y, train_x, pairwise_func) resampling(x, y, train_x, pairwise_func)
for i = 1, settings.inner_epoch do for i = 1, settings.inner_epoch do
print(minibatch_adam(model, criterion, eval_metric, x, y, adam_config)) local train_score = minibatch_adam(model, criterion, eval_metric, x, y, adam_config)
print(train_score)
model:evaluate() model:evaluate()
print("# validation") print("# validation")
local score = validate(model, eval_metric, valid_xy) local score = validate(model, eval_metric, valid_xy)
table.insert(hist_train, train_score.PSNR)
table.insert(hist_valid, score)
if settings.plot then
plot(hist_train, hist_valid)
end
if score > best_score then if score > best_score then
local test_image = image_loader.load_float(settings.test) -- reload local test_image = image_loader.load_float(settings.test) -- reload
lrd_count = 0 lrd_count = 0