individual filters and box-only support
This commit is contained in:
parent
5e222a3981
commit
d0630d3a20
|
@ -35,7 +35,13 @@ local function load_images(list)
|
|||
local skip_notice = false
|
||||
for i = 1, #lines do
|
||||
local line = lines[i]
|
||||
local im, meta = image_loader.load_byte(line)
|
||||
local v = utils.split(line, ",")
|
||||
local filename = v[1]
|
||||
local filters = v[2]
|
||||
if filters then
|
||||
filters = utils.split(filters, ":")
|
||||
end
|
||||
local im, meta = image_loader.load_byte(filename)
|
||||
local skip = false
|
||||
if meta and meta.alpha then
|
||||
if settings.use_transparent_png then
|
||||
|
@ -60,12 +66,12 @@ local function load_images(list)
|
|||
end
|
||||
if im then
|
||||
if im:size(2) > (settings.crop_size * scale + MARGIN) and im:size(3) > (settings.crop_size * scale + MARGIN) then
|
||||
table.insert(x, compression.compress(im))
|
||||
table.insert(x, {compression.compress(im), {data = {filters = filters}}})
|
||||
else
|
||||
io.stderr:write(string.format("\n%s: skip: image is too small (%d > size).\n", line, settings.crop_size * scale + MARGIN))
|
||||
io.stderr:write(string.format("\n%s: skip: image is too small (%d > size).\n", filename, settings.crop_size * scale + MARGIN))
|
||||
end
|
||||
else
|
||||
io.stderr:write(string.format("\n%s: skip: load error.\n", line))
|
||||
io.stderr:write(string.format("\n%s: skip: load error.\n", filename))
|
||||
end
|
||||
end
|
||||
xlua.progress(i, #lines)
|
||||
|
|
103
image_generators/dots/gen.lua
Normal file
103
image_generators/dots/gen.lua
Normal file
|
@ -0,0 +1,103 @@
|
|||
require 'pl'
|
||||
require 'image'
|
||||
require 'trepl'
|
||||
|
||||
local gm = require 'graphicsmagick'
|
||||
torch.setdefaulttensortype("torch.FloatTensor")
|
||||
|
||||
local function color(black)
|
||||
local r, g, b
|
||||
if torch.uniform() > 0.8 then
|
||||
if black then
|
||||
return {0, 0, 0}
|
||||
else
|
||||
return {1, 1, 1}
|
||||
end
|
||||
else
|
||||
if torch.uniform() > 0.7 then
|
||||
r = torch.random(0, 1)
|
||||
else
|
||||
r = torch.uniform(0, 1)
|
||||
end
|
||||
if torch.uniform() > 0.7 then
|
||||
g = torch.random(0, 1)
|
||||
else
|
||||
g = torch.uniform(0, 1)
|
||||
end
|
||||
if torch.uniform() > 0.7 then
|
||||
b = torch.random(0, 1)
|
||||
else
|
||||
b = torch.uniform(0, 1)
|
||||
end
|
||||
end
|
||||
return {r,g,b}
|
||||
end
|
||||
|
||||
local function gen_mod()
|
||||
local f = function()
|
||||
local xm = torch.random(2, 4)
|
||||
local ym = torch.random(2, 4)
|
||||
return function(x, y) return x % xm == 0 and y % ym == 0 end
|
||||
end
|
||||
return f()
|
||||
end
|
||||
local function dot()
|
||||
local sp = 1
|
||||
local blocks = {}
|
||||
local n = 64
|
||||
local s = 24
|
||||
for i = 1, n do
|
||||
local block = torch.Tensor(3, s, s)
|
||||
local margin = torch.random(1, 3)
|
||||
local size = torch.random(1, 4)
|
||||
local mod = gen_mod()
|
||||
local fg = color(true)
|
||||
local bg = color()
|
||||
for j = 1, 3 do
|
||||
block[j]:fill(bg[j])
|
||||
end
|
||||
for y = margin, s - margin do
|
||||
for x = margin, s - margin do
|
||||
local yc = math.floor(y / size)
|
||||
local xc = math.floor(x / size)
|
||||
if mod(yc, xc) then
|
||||
block[1][y][x] = fg[1]
|
||||
block[2][y][x] = fg[2]
|
||||
block[3][y][x] = fg[3]
|
||||
end
|
||||
end
|
||||
end
|
||||
block = image.scale(block, s * 2, s * 2, "simple")
|
||||
if size >= 3 and torch.uniform() > 0.5 then
|
||||
block = image.rotate(block, math.pi / 4, "bilinear")
|
||||
end
|
||||
blocks[i] = block
|
||||
end
|
||||
local img = torch.Tensor(#blocks, 3, s * 2, s * 2)
|
||||
for i = 1, #blocks do
|
||||
img[i]:copy(blocks[i])
|
||||
end
|
||||
img = image.toDisplayTensor({input = img, padding = 0, nrow = math.pow(n, 0.5), min = 0, max = 1})
|
||||
return img
|
||||
end
|
||||
local function gen()
|
||||
return dot()
|
||||
end
|
||||
|
||||
local cmd = torch.CmdLine()
|
||||
cmd:text()
|
||||
cmd:text("dot image generator")
|
||||
cmd:text("Options:")
|
||||
cmd:option("-o", "", 'output directory')
|
||||
cmd:option("-n", 64, 'number of images')
|
||||
|
||||
local opt = cmd:parse(arg)
|
||||
if opt.o:len() == 0 then
|
||||
cmd:help()
|
||||
os.exit(1)
|
||||
end
|
||||
|
||||
for i = 1, opt.n do
|
||||
local img = gen()
|
||||
image.save(path.join(opt.o, i .. ".png"), img)
|
||||
end
|
|
@ -100,7 +100,6 @@ function data_augmentation.flip(src)
|
|||
local tr = torch.random(1, 2)
|
||||
local src, conversion = iproc.byte2float(src)
|
||||
local dest
|
||||
|
||||
src = src:contiguous()
|
||||
if tr == 1 then
|
||||
-- pass
|
||||
|
|
|
@ -5,6 +5,9 @@ local pairwise_transform = {}
|
|||
|
||||
function pairwise_transform.scale(src, scale, size, offset, n, options)
|
||||
local filters = options.downsampling_filters
|
||||
if options.data.filters then
|
||||
filters = options.data.filters
|
||||
end
|
||||
local unstable_region_offset = 8
|
||||
local downsampling_filter = filters[torch.random(1, #filters)]
|
||||
local blur = torch.uniform(options.resize_blur_min, options.resize_blur_max)
|
||||
|
@ -48,8 +51,41 @@ function pairwise_transform.scale(src, scale, size, offset, n, options)
|
|||
size(y:size(3) * 0.5, y:size(2) * 0.5, "Box"):
|
||||
size(y:size(3), y:size(2), "Box"):
|
||||
toTensor(t, "RGB", "DHW")
|
||||
local xs = {}
|
||||
local ys = {}
|
||||
local lowreses = {}
|
||||
|
||||
for j = 1, 2 do
|
||||
-- TTA
|
||||
local xi, yi, ri
|
||||
if j == 1 then
|
||||
xi = x
|
||||
yi = y
|
||||
ri = lowres_y
|
||||
else
|
||||
xi = x:transpose(2, 3):contiguous()
|
||||
yi = y:transpose(2, 3):contiguous()
|
||||
ri = lowres_y:transpose(2, 3):contiguous()
|
||||
end
|
||||
local xv = image.vflip(xi)
|
||||
local yv = image.vflip(yi)
|
||||
local rv = image.vflip(ri)
|
||||
table.insert(xs, xi)
|
||||
table.insert(ys, yi)
|
||||
table.insert(lowreses, ri)
|
||||
table.insert(xs, xv)
|
||||
table.insert(ys, yv)
|
||||
table.insert(lowreses, rv)
|
||||
table.insert(xs, image.hflip(xi))
|
||||
table.insert(ys, image.hflip(yi))
|
||||
table.insert(lowreses, image.hflip(ri))
|
||||
table.insert(xs, image.hflip(xv))
|
||||
table.insert(ys, image.hflip(yv))
|
||||
table.insert(lowreses, image.hflip(rv))
|
||||
end
|
||||
for i = 1, n do
|
||||
local xc, yc = pairwise_utils.active_cropping(x, y, lowres_y,
|
||||
local t = (i % #xs) + 1
|
||||
local xc, yc = pairwise_utils.active_cropping(xs[t], ys[t], lowreses[t],
|
||||
size,
|
||||
scale_inner,
|
||||
options.active_cropping_rate,
|
||||
|
|
|
@ -11,13 +11,18 @@ function pairwise_transform_utils.random_half(src, p, filters)
|
|||
return src
|
||||
end
|
||||
end
|
||||
function pairwise_transform_utils.crop_if_large(src, max_size)
|
||||
function pairwise_transform_utils.crop_if_large(src, max_size, mod)
|
||||
local tries = 4
|
||||
if src:size(2) > max_size and src:size(3) > max_size then
|
||||
assert(max_size % 4 == 0)
|
||||
local rect
|
||||
for i = 1, tries do
|
||||
local yi = torch.random(0, src:size(2) - max_size)
|
||||
local xi = torch.random(0, src:size(3) - max_size)
|
||||
if mod then
|
||||
yi = yi - (yi % mod)
|
||||
xi = xi - (xi % mod)
|
||||
end
|
||||
rect = iproc.crop(src, xi, yi, xi + max_size, yi + max_size)
|
||||
-- ignore simple background
|
||||
if rect:float():std() >= 0 then
|
||||
|
@ -31,14 +36,28 @@ function pairwise_transform_utils.crop_if_large(src, max_size)
|
|||
end
|
||||
function pairwise_transform_utils.preprocess(src, crop_size, options)
|
||||
local dest = src
|
||||
dest = pairwise_transform_utils.random_half(dest, options.random_half_rate, options.downsampling_filters)
|
||||
dest = pairwise_transform_utils.crop_if_large(dest, math.max(crop_size * 2, options.max_size))
|
||||
dest = data_augmentation.flip(dest)
|
||||
dest = data_augmentation.color_noise(dest, options.random_color_noise_rate)
|
||||
dest = data_augmentation.overlay(dest, options.random_overlay_rate)
|
||||
dest = data_augmentation.unsharp_mask(dest, options.random_unsharp_mask_rate)
|
||||
dest = data_augmentation.shift_1px(dest)
|
||||
|
||||
local box_only = false
|
||||
if options.data.filters then
|
||||
if #options.data.filters == 1 and options.data.filters[1] == "Box" then
|
||||
box_only = true
|
||||
end
|
||||
end
|
||||
if box_only then
|
||||
local mod = 2 -- assert pos % 2 == 0
|
||||
dest = pairwise_transform_utils.crop_if_large(dest, math.max(crop_size * 2, options.max_size), mod)
|
||||
dest = data_augmentation.flip(dest)
|
||||
dest = data_augmentation.color_noise(dest, options.random_color_noise_rate)
|
||||
dest = data_augmentation.overlay(dest, options.random_overlay_rate)
|
||||
dest = data_augmentation.unsharp_mask(dest, options.random_unsharp_mask_rate)
|
||||
else
|
||||
dest = pairwise_transform_utils.random_half(dest, options.random_half_rate, options.downsampling_filters)
|
||||
dest = pairwise_transform_utils.crop_if_large(dest, math.max(crop_size * 2, options.max_size))
|
||||
dest = data_augmentation.flip(dest)
|
||||
dest = data_augmentation.color_noise(dest, options.random_color_noise_rate)
|
||||
dest = data_augmentation.overlay(dest, options.random_overlay_rate)
|
||||
dest = data_augmentation.unsharp_mask(dest, options.random_unsharp_mask_rate)
|
||||
dest = data_augmentation.shift_1px(dest)
|
||||
end
|
||||
return dest
|
||||
end
|
||||
function pairwise_transform_utils.active_cropping(x, y, lowres_y, size, scale, p, tries)
|
||||
|
|
|
@ -59,6 +59,7 @@ cmd:option("-resize_blur_max", 1.05, 'max blur parameter for ResizeImage')
|
|||
cmd:option("-oracle_rate", 0.0, '')
|
||||
cmd:option("-oracle_drop_rate", 0.5, '')
|
||||
cmd:option("-learning_rate_decay", 3.0e-7, 'learning rate decay (learning_rate * 1/(1+num_of_data*patches*epoch))')
|
||||
cmd:option("-loss", "rgb", 'loss (rgb|y)')
|
||||
|
||||
local function to_bool(settings, name)
|
||||
if settings[name] == 1 then
|
||||
|
@ -84,9 +85,13 @@ if settings.save_history then
|
|||
if settings.method == "noise" then
|
||||
settings.model_file = string.format("%s/noise%d_model.%%d-%%d.t7",
|
||||
settings.model_dir, settings.noise_level)
|
||||
settings.model_file_best = string.format("%s/noise%d_model.t7",
|
||||
settings.model_dir, settings.noise_level)
|
||||
elseif settings.method == "scale" then
|
||||
settings.model_file = string.format("%s/scale%.1fx_model.%%d-%%d.t7",
|
||||
settings.model_dir, settings.scale)
|
||||
settings.model_file_best = string.format("%s/scale%.1fx_model.t7",
|
||||
settings.model_dir, settings.scale)
|
||||
else
|
||||
error("unknown method: " .. settings.method)
|
||||
end
|
||||
|
|
112
train.lua
112
train.lua
|
@ -95,14 +95,18 @@ local function validate(model, criterion, eval_metric, data, batch_size)
|
|||
return {loss = loss / loss_count, MSE = mse / loss_count, PSNR = 10 * math.log10(1 / (mse / loss_count))}
|
||||
end
|
||||
|
||||
local function create_criterion(model)
|
||||
local function create_criterion(model, loss)
|
||||
if reconstruct.is_rgb(model) then
|
||||
local offset = reconstruct.offset_size(model)
|
||||
local output_w = settings.crop_size - offset * 2
|
||||
local weight = torch.Tensor(3, output_w * output_w)
|
||||
weight[1]:fill(0.29891 * 3) -- R
|
||||
weight[2]:fill(0.58661 * 3) -- G
|
||||
weight[3]:fill(0.11448 * 3) -- B
|
||||
if loss == "y" then
|
||||
weight[1]:fill(0.29891 * 3) -- R
|
||||
weight[2]:fill(0.58661 * 3) -- G
|
||||
weight[3]:fill(0.11448 * 3) -- B
|
||||
else
|
||||
weight:fill(1)
|
||||
end
|
||||
return w2nn.ClippedWeightedHuberCriterion(weight, 0.1, {0.0, 1.0}):cuda()
|
||||
else
|
||||
local offset = reconstruct.offset_size(model)
|
||||
|
@ -113,9 +117,14 @@ local function create_criterion(model)
|
|||
end
|
||||
end
|
||||
local function transformer(model, x, is_validation, n, offset)
|
||||
x = compression.decompress(x)
|
||||
local meta = {data = {}}
|
||||
if type(x) == "table" and type(x[2]) == "table" then
|
||||
meta = x[2]
|
||||
x = compression.decompress(x[1])
|
||||
else
|
||||
x = compression.decompress(x)
|
||||
end
|
||||
n = n or settings.patches
|
||||
|
||||
if is_validation == nil then is_validation = false end
|
||||
local random_color_noise_rate = nil
|
||||
local random_overlay_rate = nil
|
||||
|
@ -132,46 +141,43 @@ local function transformer(model, x, is_validation, n, offset)
|
|||
random_color_noise_rate = settings.random_color_noise_rate
|
||||
random_overlay_rate = settings.random_overlay_rate
|
||||
end
|
||||
|
||||
if settings.method == "scale" then
|
||||
local conf = tablex.update({
|
||||
downsampling_filters = settings.downsampling_filters,
|
||||
upsampling_filter = settings.upsampling_filter,
|
||||
random_half_rate = settings.random_half_rate,
|
||||
random_color_noise_rate = random_color_noise_rate,
|
||||
random_overlay_rate = random_overlay_rate,
|
||||
random_unsharp_mask_rate = settings.random_unsharp_mask_rate,
|
||||
max_size = settings.max_size,
|
||||
active_cropping_rate = active_cropping_rate,
|
||||
active_cropping_tries = active_cropping_tries,
|
||||
rgb = (settings.color == "rgb"),
|
||||
gamma_correction = settings.gamma_correction,
|
||||
x_upsampling = not reconstruct.has_resize(model),
|
||||
resize_blur_min = settings.resize_blur_min,
|
||||
resize_blur_max = settings.resize_blur_max}, meta)
|
||||
return pairwise_transform.scale(x,
|
||||
settings.scale,
|
||||
settings.crop_size, offset,
|
||||
n,
|
||||
{
|
||||
downsampling_filters = settings.downsampling_filters,
|
||||
upsampling_filter = settings.upsampling_filter,
|
||||
random_half_rate = settings.random_half_rate,
|
||||
random_color_noise_rate = random_color_noise_rate,
|
||||
random_overlay_rate = random_overlay_rate,
|
||||
random_unsharp_mask_rate = settings.random_unsharp_mask_rate,
|
||||
max_size = settings.max_size,
|
||||
active_cropping_rate = active_cropping_rate,
|
||||
active_cropping_tries = active_cropping_tries,
|
||||
rgb = (settings.color == "rgb"),
|
||||
gamma_correction = settings.gamma_correction,
|
||||
x_upsampling = not reconstruct.has_resize(model),
|
||||
resize_blur_min = settings.resize_blur_min,
|
||||
resize_blur_max = settings.resize_blur_max,
|
||||
})
|
||||
n, conf)
|
||||
elseif settings.method == "noise" then
|
||||
local conf = tablex.update({
|
||||
random_half_rate = settings.random_half_rate,
|
||||
random_color_noise_rate = random_color_noise_rate,
|
||||
random_overlay_rate = random_overlay_rate,
|
||||
random_unsharp_mask_rate = settings.random_unsharp_mask_rate,
|
||||
max_size = settings.max_size,
|
||||
jpeg_chroma_subsampling_rate = settings.jpeg_chroma_subsampling_rate,
|
||||
active_cropping_rate = active_cropping_rate,
|
||||
active_cropping_tries = active_cropping_tries,
|
||||
nr_rate = settings.nr_rate,
|
||||
rgb = (settings.color == "rgb")}, meta)
|
||||
return pairwise_transform.jpeg(x,
|
||||
settings.style,
|
||||
settings.noise_level,
|
||||
settings.crop_size, offset,
|
||||
n,
|
||||
{
|
||||
random_half_rate = settings.random_half_rate,
|
||||
random_color_noise_rate = random_color_noise_rate,
|
||||
random_overlay_rate = random_overlay_rate,
|
||||
random_unsharp_mask_rate = settings.random_unsharp_mask_rate,
|
||||
max_size = settings.max_size,
|
||||
jpeg_chroma_subsampling_rate = settings.jpeg_chroma_subsampling_rate,
|
||||
active_cropping_rate = active_cropping_rate,
|
||||
active_cropping_tries = active_cropping_tries,
|
||||
nr_rate = settings.nr_rate,
|
||||
rgb = (settings.color == "rgb")
|
||||
})
|
||||
n, conf)
|
||||
end
|
||||
end
|
||||
|
||||
|
@ -221,7 +227,13 @@ end
|
|||
local function remove_small_image(x)
|
||||
local new_x = {}
|
||||
for i = 1, #x do
|
||||
local x_s = compression.size(x[i])
|
||||
local xe, meta, x_s
|
||||
xe = x[i]
|
||||
if type(xe) == "table" and type(xe[2]) == "table" then
|
||||
x_s = compression.size(xe[1])
|
||||
else
|
||||
x_s = compression.size(xe)
|
||||
end
|
||||
if x_s[2] / settings.scale > settings.crop_size + 32 and
|
||||
x_s[3] / settings.scale > settings.crop_size + 32 then
|
||||
table.insert(new_x, x[i])
|
||||
|
@ -247,7 +259,7 @@ local function train()
|
|||
local pairwise_func = function(x, is_validation, n)
|
||||
return transformer(model, x, is_validation, n, offset)
|
||||
end
|
||||
local criterion = create_criterion(model)
|
||||
local criterion = create_criterion(model, settings.loss)
|
||||
local eval_metric = w2nn.ClippedMSECriterion(0, 1):cuda()
|
||||
local x = remove_small_image(torch.load(settings.images))
|
||||
local train_x, valid_x = split_data(x, math.max(math.floor(settings.validation_rate * #x), 1))
|
||||
|
@ -297,9 +309,22 @@ local function train()
|
|||
local oracle_n = math.min(x:size(1) * settings.oracle_rate, x:size(1))
|
||||
if oracle_n > 0 then
|
||||
local oracle_x, oracle_y = get_oracle_data(x, y, instance_loss, oracle_k, oracle_n)
|
||||
resampling(x, y, train_x, pairwise_func)
|
||||
resampling(x:narrow(1, oracle_x:size(1) + 1, x:size(1)-oracle_x:size(1)),
|
||||
y:narrow(1, oracle_x:size(1) + 1, x:size(1) - oracle_x:size(1)), train_x, pairwise_func)
|
||||
x:narrow(1, 1, oracle_x:size(1)):copy(oracle_x)
|
||||
y:narrow(1, 1, oracle_y:size(1)):copy(oracle_y)
|
||||
|
||||
local draw_n = math.floor(math.sqrt(oracle_x:size(1), 0.5))
|
||||
if draw_n > 100 then
|
||||
draw_n = 100
|
||||
end
|
||||
image.save(path.join(settings.model_dir, "oracle_x.png"),
|
||||
image.toDisplayTensor({
|
||||
input = oracle_x:narrow(1, 1, draw_n * draw_n),
|
||||
padding = 2,
|
||||
nrow = draw_n,
|
||||
min = 0,
|
||||
max = 1}))
|
||||
else
|
||||
resampling(x, y, train_x, pairwise_func)
|
||||
end
|
||||
|
@ -322,11 +347,12 @@ local function train()
|
|||
if settings.plot then
|
||||
plot(hist_train, hist_valid)
|
||||
end
|
||||
if score.loss < best_score then
|
||||
if score.MSE < best_score then
|
||||
local test_image = image_loader.load_float(settings.test) -- reload
|
||||
best_score = score.loss
|
||||
best_score = score.MSE
|
||||
print("* update best model")
|
||||
if settings.save_history then
|
||||
torch.save(settings.model_file_best, model:clearState(), "ascii")
|
||||
torch.save(string.format(settings.model_file, epoch, i), model:clearState(), "ascii")
|
||||
if settings.method == "noise" then
|
||||
local log = path.join(settings.model_dir,
|
||||
|
@ -352,7 +378,7 @@ local function train()
|
|||
end
|
||||
end
|
||||
end
|
||||
print("PSNR: " .. score.PSNR .. ", loss: " .. score.loss .. ", Minimum loss: " .. best_score)
|
||||
print("PSNR: " .. score.PSNR .. ", loss: " .. score.loss .. ", MSE: " .. score.MSE .. ", Minimum MSE: " .. best_score)
|
||||
collectgarbage()
|
||||
end
|
||||
end
|
||||
|
|
Loading…
Reference in a new issue