Add support for plotting loss chart
This commit is contained in:
parent
459c7c5e18
commit
4d115e4bdb
|
@ -47,11 +47,18 @@ cmd:option("-active_cropping_rate", 0.5, 'active cropping rate')
|
||||||
cmd:option("-active_cropping_tries", 10, 'active cropping tries')
|
cmd:option("-active_cropping_tries", 10, 'active cropping tries')
|
||||||
cmd:option("-nr_rate", 0.75, 'trade-off between reducing noise and erasing details (0.0-1.0)')
|
cmd:option("-nr_rate", 0.75, 'trade-off between reducing noise and erasing details (0.0-1.0)')
|
||||||
cmd:option("-save_history", 0, 'save all model (0|1)')
|
cmd:option("-save_history", 0, 'save all model (0|1)')
|
||||||
|
cmd:option("-plot", 0, 'plot loss chart(0|1)')
|
||||||
|
|
||||||
local opt = cmd:parse(arg)
|
local opt = cmd:parse(arg)
|
||||||
for k, v in pairs(opt) do
|
for k, v in pairs(opt) do
|
||||||
settings[k] = v
|
settings[k] = v
|
||||||
end
|
end
|
||||||
|
if settings.plot == 1 then
|
||||||
|
settings.plot = true
|
||||||
|
require 'gnuplot'
|
||||||
|
else
|
||||||
|
settings.plot = false
|
||||||
|
end
|
||||||
if settings.save_history == 1 then
|
if settings.save_history == 1 then
|
||||||
settings.save_history = true
|
settings.save_history = true
|
||||||
else
|
else
|
||||||
|
|
17
train.lua
17
train.lua
|
@ -157,8 +157,14 @@ local function resampling(x, y, train_x, transformer, input_size, target_size)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
local function plot(train, valid)
|
||||||
|
gnuplot.plot({
|
||||||
|
{'training', torch.Tensor(train), '-'},
|
||||||
|
{'validation', torch.Tensor(valid), '-'}})
|
||||||
|
end
|
||||||
local function train()
|
local function train()
|
||||||
|
local hist_train = {}
|
||||||
|
local hist_valid = {}
|
||||||
local LR_MIN = 1.0e-5
|
local LR_MIN = 1.0e-5
|
||||||
local model = srcnn.create(settings.method, settings.backend, settings.color)
|
local model = srcnn.create(settings.method, settings.backend, settings.color)
|
||||||
local offset = reconstruct.offset_size(model)
|
local offset = reconstruct.offset_size(model)
|
||||||
|
@ -201,10 +207,17 @@ local function train()
|
||||||
print("# " .. epoch)
|
print("# " .. epoch)
|
||||||
resampling(x, y, train_x, pairwise_func)
|
resampling(x, y, train_x, pairwise_func)
|
||||||
for i = 1, settings.inner_epoch do
|
for i = 1, settings.inner_epoch do
|
||||||
print(minibatch_adam(model, criterion, eval_metric, x, y, adam_config))
|
local train_score = minibatch_adam(model, criterion, eval_metric, x, y, adam_config)
|
||||||
|
print(train_score)
|
||||||
model:evaluate()
|
model:evaluate()
|
||||||
print("# validation")
|
print("# validation")
|
||||||
local score = validate(model, eval_metric, valid_xy)
|
local score = validate(model, eval_metric, valid_xy)
|
||||||
|
|
||||||
|
table.insert(hist_train, train_score.PSNR)
|
||||||
|
table.insert(hist_valid, score)
|
||||||
|
if settings.plot then
|
||||||
|
plot(hist_train, hist_valid)
|
||||||
|
end
|
||||||
if score > best_score then
|
if score > best_score then
|
||||||
local test_image = image_loader.load_float(settings.test) -- reload
|
local test_image = image_loader.load_float(settings.test) -- reload
|
||||||
lrd_count = 0
|
lrd_count = 0
|
||||||
|
|
Loading…
Reference in a new issue