Add support for plotting loss chart
This commit is contained in:
parent
459c7c5e18
commit
4d115e4bdb
|
@ -47,11 +47,18 @@ cmd:option("-active_cropping_rate", 0.5, 'active cropping rate')
|
|||
cmd:option("-active_cropping_tries", 10, 'active cropping tries')
|
||||
cmd:option("-nr_rate", 0.75, 'trade-off between reducing noise and erasing details (0.0-1.0)')
|
||||
cmd:option("-save_history", 0, 'save all model (0|1)')
|
||||
cmd:option("-plot", 0, 'plot loss chart(0|1)')
|
||||
|
||||
local opt = cmd:parse(arg)
|
||||
for k, v in pairs(opt) do
|
||||
settings[k] = v
|
||||
end
|
||||
if settings.plot == 1 then
|
||||
settings.plot = true
|
||||
require 'gnuplot'
|
||||
else
|
||||
settings.plot = false
|
||||
end
|
||||
if settings.save_history == 1 then
|
||||
settings.save_history = true
|
||||
else
|
||||
|
|
17
train.lua
17
train.lua
|
@ -157,8 +157,14 @@ local function resampling(x, y, train_x, transformer, input_size, target_size)
|
|||
end
|
||||
end
|
||||
end
|
||||
|
||||
local function plot(train, valid)
|
||||
gnuplot.plot({
|
||||
{'training', torch.Tensor(train), '-'},
|
||||
{'validation', torch.Tensor(valid), '-'}})
|
||||
end
|
||||
local function train()
|
||||
local hist_train = {}
|
||||
local hist_valid = {}
|
||||
local LR_MIN = 1.0e-5
|
||||
local model = srcnn.create(settings.method, settings.backend, settings.color)
|
||||
local offset = reconstruct.offset_size(model)
|
||||
|
@ -201,10 +207,17 @@ local function train()
|
|||
print("# " .. epoch)
|
||||
resampling(x, y, train_x, pairwise_func)
|
||||
for i = 1, settings.inner_epoch do
|
||||
print(minibatch_adam(model, criterion, eval_metric, x, y, adam_config))
|
||||
local train_score = minibatch_adam(model, criterion, eval_metric, x, y, adam_config)
|
||||
print(train_score)
|
||||
model:evaluate()
|
||||
print("# validation")
|
||||
local score = validate(model, eval_metric, valid_xy)
|
||||
|
||||
table.insert(hist_train, train_score.PSNR)
|
||||
table.insert(hist_valid, score)
|
||||
if settings.plot then
|
||||
plot(hist_train, hist_valid)
|
||||
end
|
||||
if score > best_score then
|
||||
local test_image = image_loader.load_float(settings.test) -- reload
|
||||
lrd_count = 0
|
||||
|
|
Loading…
Reference in a new issue