diff --git a/lib/settings.lua b/lib/settings.lua index 5d5cbde..c1f829b 100644 --- a/lib/settings.lua +++ b/lib/settings.lua @@ -46,22 +46,37 @@ cmd:option("-validation_crops", 80, 'number of cropping region per image in vali cmd:option("-active_cropping_rate", 0.5, 'active cropping rate') cmd:option("-active_cropping_tries", 10, 'active cropping tries') cmd:option("-nr_rate", 0.75, 'trade-off between reducing noise and erasing details (0.0-1.0)') +cmd:option("-save_history", 0, 'save all model (0|1)') local opt = cmd:parse(arg) for k, v in pairs(opt) do settings[k] = v end -if settings.method == "noise" then - settings.model_file = string.format("%s/noise%d_model.t7", - settings.model_dir, settings.noise_level) -elseif settings.method == "scale" then - settings.model_file = string.format("%s/scale%.1fx_model.t7", - settings.model_dir, settings.scale) -elseif settings.method == "noise_scale" then - settings.model_file = string.format("%s/noise%d_scale%.1fx_model.t7", - settings.model_dir, settings.noise_level, settings.scale) +if settings.save_history == 1 then + settings.save_history = true else - error("unknown method: " .. settings.method) + settings.save_history = false +end +if settings.save_history then + if settings.method == "noise" then + settings.model_file = string.format("%s/noise%d_model.%%d-%%d.t7", + settings.model_dir, settings.noise_level) + elseif settings.method == "scale" then + settings.model_file = string.format("%s/scale%.1fx_model.%%d-%%d.t7", + settings.model_dir, settings.scale) + else + error("unknown method: " .. settings.method) + end +else + if settings.method == "noise" then + settings.model_file = string.format("%s/noise%d_model.t7", + settings.model_dir, settings.noise_level) + elseif settings.method == "scale" then + settings.model_file = string.format("%s/scale%.1fx_model.t7", + settings.model_dir, settings.scale) + else + error("unknown method: " .. settings.method) + end end if not (settings.color == "rgb" or settings.color == "y") then error("color must be y or rgb") diff --git a/train.lua b/train.lua index 16a9dee..82531ef 100644 --- a/train.lua +++ b/train.lua @@ -205,15 +205,32 @@ local function train() lrd_count = 0 best_score = score print("* update best model") - torch.save(settings.model_file, model) - if settings.method == "noise" then - local log = path.join(settings.model_dir, - ("noise%d_best.png"):format(settings.noise_level)) - save_test_jpeg(model, test_image, log) - elseif settings.method == "scale" then - local log = path.join(settings.model_dir, - ("scale%.1f_best.png"):format(settings.scale)) - save_test_scale(model, test_image, log) + if settings.save_history then + local model_clone = model:clone() + w2nn.cleanup_model(model_clone) + torch.save(string.format(settings.model_file, epoch, i), model_clone) + if settings.method == "noise" then + local log = path.join(settings.model_dir, + ("noise%d_best.%d-%d.png"):format(settings.noise_level, + epoch, i)) + save_test_jpeg(model, test_image, log) + elseif settings.method == "scale" then + local log = path.join(settings.model_dir, + ("scale%.1f_best.%d-%d.png"):format(settings.scale, + epoch, i)) + save_test_scale(model, test_image, log) + end + else + torch.save(settings.model_file, model) + if settings.method == "noise" then + local log = path.join(settings.model_dir, + ("noise%d_best.png"):format(settings.noise_level)) + save_test_jpeg(model, test_image, log) + elseif settings.method == "scale" then + local log = path.join(settings.model_dir, + ("scale%.1f_best.png"):format(settings.scale)) + save_test_scale(model, test_image, log) + end end else lrd_count = lrd_count + 1