Add -save_history option
This commit is contained in:
parent
eea4c31d7b
commit
9f935835dd
|
@ -46,22 +46,37 @@ cmd:option("-validation_crops", 80, 'number of cropping region per image in vali
|
|||
cmd:option("-active_cropping_rate", 0.5, 'active cropping rate')
|
||||
cmd:option("-active_cropping_tries", 10, 'active cropping tries')
|
||||
cmd:option("-nr_rate", 0.75, 'trade-off between reducing noise and erasing details (0.0-1.0)')
|
||||
cmd:option("-save_history", 0, 'save all model (0|1)')
|
||||
|
||||
local opt = cmd:parse(arg)
|
||||
for k, v in pairs(opt) do
|
||||
settings[k] = v
|
||||
end
|
||||
if settings.method == "noise" then
|
||||
settings.model_file = string.format("%s/noise%d_model.t7",
|
||||
settings.model_dir, settings.noise_level)
|
||||
elseif settings.method == "scale" then
|
||||
settings.model_file = string.format("%s/scale%.1fx_model.t7",
|
||||
settings.model_dir, settings.scale)
|
||||
elseif settings.method == "noise_scale" then
|
||||
settings.model_file = string.format("%s/noise%d_scale%.1fx_model.t7",
|
||||
settings.model_dir, settings.noise_level, settings.scale)
|
||||
if settings.save_history == 1 then
|
||||
settings.save_history = true
|
||||
else
|
||||
error("unknown method: " .. settings.method)
|
||||
settings.save_history = false
|
||||
end
|
||||
if settings.save_history then
|
||||
if settings.method == "noise" then
|
||||
settings.model_file = string.format("%s/noise%d_model.%%d-%%d.t7",
|
||||
settings.model_dir, settings.noise_level)
|
||||
elseif settings.method == "scale" then
|
||||
settings.model_file = string.format("%s/scale%.1fx_model.%%d-%%d.t7",
|
||||
settings.model_dir, settings.scale)
|
||||
else
|
||||
error("unknown method: " .. settings.method)
|
||||
end
|
||||
else
|
||||
if settings.method == "noise" then
|
||||
settings.model_file = string.format("%s/noise%d_model.t7",
|
||||
settings.model_dir, settings.noise_level)
|
||||
elseif settings.method == "scale" then
|
||||
settings.model_file = string.format("%s/scale%.1fx_model.t7",
|
||||
settings.model_dir, settings.scale)
|
||||
else
|
||||
error("unknown method: " .. settings.method)
|
||||
end
|
||||
end
|
||||
if not (settings.color == "rgb" or settings.color == "y") then
|
||||
error("color must be y or rgb")
|
||||
|
|
35
train.lua
35
train.lua
|
@ -205,15 +205,32 @@ local function train()
|
|||
lrd_count = 0
|
||||
best_score = score
|
||||
print("* update best model")
|
||||
torch.save(settings.model_file, model)
|
||||
if settings.method == "noise" then
|
||||
local log = path.join(settings.model_dir,
|
||||
("noise%d_best.png"):format(settings.noise_level))
|
||||
save_test_jpeg(model, test_image, log)
|
||||
elseif settings.method == "scale" then
|
||||
local log = path.join(settings.model_dir,
|
||||
("scale%.1f_best.png"):format(settings.scale))
|
||||
save_test_scale(model, test_image, log)
|
||||
if settings.save_history then
|
||||
local model_clone = model:clone()
|
||||
w2nn.cleanup_model(model_clone)
|
||||
torch.save(string.format(settings.model_file, epoch, i), model_clone)
|
||||
if settings.method == "noise" then
|
||||
local log = path.join(settings.model_dir,
|
||||
("noise%d_best.%d-%d.png"):format(settings.noise_level,
|
||||
epoch, i))
|
||||
save_test_jpeg(model, test_image, log)
|
||||
elseif settings.method == "scale" then
|
||||
local log = path.join(settings.model_dir,
|
||||
("scale%.1f_best.%d-%d.png"):format(settings.scale,
|
||||
epoch, i))
|
||||
save_test_scale(model, test_image, log)
|
||||
end
|
||||
else
|
||||
torch.save(settings.model_file, model)
|
||||
if settings.method == "noise" then
|
||||
local log = path.join(settings.model_dir,
|
||||
("noise%d_best.png"):format(settings.noise_level))
|
||||
save_test_jpeg(model, test_image, log)
|
||||
elseif settings.method == "scale" then
|
||||
local log = path.join(settings.model_dir,
|
||||
("scale%.1f_best.png"):format(settings.scale))
|
||||
save_test_scale(model, test_image, log)
|
||||
end
|
||||
end
|
||||
else
|
||||
lrd_count = lrd_count + 1
|
||||
|
|
Loading…
Reference in a new issue