diff --git a/lib/settings.lua b/lib/settings.lua index 4620fb6..32af3a1 100644 --- a/lib/settings.lua +++ b/lib/settings.lua @@ -47,11 +47,18 @@ cmd:option("-active_cropping_rate", 0.5, 'active cropping rate') cmd:option("-active_cropping_tries", 10, 'active cropping tries') cmd:option("-nr_rate", 0.75, 'trade-off between reducing noise and erasing details (0.0-1.0)') cmd:option("-save_history", 0, 'save all model (0|1)') +cmd:option("-plot", 0, 'plot loss chart(0|1)') local opt = cmd:parse(arg) for k, v in pairs(opt) do settings[k] = v end +if settings.plot == 1 then + settings.plot = true + require 'gnuplot' +else + settings.plot = false +end if settings.save_history == 1 then settings.save_history = true else diff --git a/train.lua b/train.lua index faea13b..91aa1e3 100644 --- a/train.lua +++ b/train.lua @@ -157,8 +157,14 @@ local function resampling(x, y, train_x, transformer, input_size, target_size) end end end - +local function plot(train, valid) + gnuplot.plot({ + {'training', torch.Tensor(train), '-'}, + {'validation', torch.Tensor(valid), '-'}}) +end local function train() + local hist_train = {} + local hist_valid = {} local LR_MIN = 1.0e-5 local model = srcnn.create(settings.method, settings.backend, settings.color) local offset = reconstruct.offset_size(model) @@ -201,10 +207,17 @@ local function train() print("# " .. epoch) resampling(x, y, train_x, pairwise_func) for i = 1, settings.inner_epoch do - print(minibatch_adam(model, criterion, eval_metric, x, y, adam_config)) + local train_score = minibatch_adam(model, criterion, eval_metric, x, y, adam_config) + print(train_score) model:evaluate() print("# validation") local score = validate(model, eval_metric, valid_xy) + + table.insert(hist_train, train_score.PSNR) + table.insert(hist_valid, score) + if settings.plot then + plot(hist_train, hist_valid) + end if score > best_score then local test_image = image_loader.load_float(settings.test) -- reload lrd_count = 0