diff --git a/lib/pairwise_transform.lua b/lib/pairwise_transform.lua
index 87b33c5..fdabae9 100644
--- a/lib/pairwise_transform.lua
+++ b/lib/pairwise_transform.lua
@@ -65,9 +65,35 @@ local function flip_augment(x, y)
       return x
    end
 end
+-- Overlay augmentation. With probability p (default 0.25), returns a blend of
+-- src and a flip_augment()-ed copy of it: flip * r + src * (1 - r), where r is
+-- drawn uniformly from [0.2, 0.8]. ByteTensor input is blended as float in
+-- [0, 1] and converted back to ByteTensor; otherwise src is returned as-is.
+local function overlay_augment(src, p)
+   p = p or 0.25
+   if torch.uniform() > (1.0 - p) then
+      local r = torch.uniform(0.2, 0.8)
+      local t = "float"
+      if src:type() == "torch.ByteTensor" then
+         src = src:float():div(255)
+         t = "byte"
+      end
+      local flip = flip_augment(src)
+      -- Blend out of place. NOTE(review): flip_augment may hand back src itself
+      -- (its body is not shown in this patch -- confirm); an in-place mul(r)
+      -- would then scale src before the second term is computed.
+      flip = (flip * r):add(src * (1.0 - r))
+      if t == "byte" then
+         flip = flip:mul(255):byte()
+      end
+      return flip
+   else
+      return src
+   end
+end
 local INTERPOLATION_PADDING = 16
 function pairwise_transform.scale(src, scale, size, offset, options)
-   options = options or {color_noise = false, random_half = true, rgb = true}
+   options = options or {color_noise = false, overlay = false, random_half = true, rgb = true}
    if options.random_half then
       src = random_half(src)
    end
@@ -92,6 +118,9 @@ function pairwise_transform.scale(src, scale, size, offset, options)
    if options.color_noise then
       y = color_noise(y)
    end
+   if options.overlay then
+      y = overlay_augment(y)
+   end
    local x = iproc.scale(y, y:size(3) * down_scale, y:size(2) * down_scale, downscale_filter)
    x = iproc.scale(x, y:size(3), y:size(2))
    y = y:float():div(255)
@@ -109,7 +138,7 @@ function pairwise_transform.scale(src, scale, size, offset, options)
    return x, y
 end
 function pairwise_transform.jpeg_(src, quality, size, offset, options)
-   options = options or {color_noise = false, random_half = true, rgb = true}
+   options = options or {color_noise = false, overlay = false, random_half = true, rgb = true}
    if options.random_half then
       src = random_half(src)
    end
@@ -121,6 +150,9 @@ function pairwise_transform.jpeg_(src, quality, size, offset, options)
    if options.color_noise then
       y = color_noise(y)
    end
+   if options.overlay then
+      y = overlay_augment(y)
+   end
    x = y
    for i = 1, #quality do
       x = gm.Image(x, "RGB", "DHW")
@@ -236,6 +268,10 @@ function pairwise_transform.jpeg_scale_(src, scale, quality, size, offset, optio
    if options.color_noise then
       y = color_noise(y)
    end
+   if options.overlay then
+      y = overlay_augment(y)
+   end
+
    x = y
    x = iproc.scale(x, y:size(3) * down_scale, y:size(2) * down_scale, downscale_filter)
    for i = 1, #quality do
@@ -325,12 +361,12 @@ end
 local function test_jpeg()
    local loader = require './image_loader'
    local src = loader.load_byte("../images/miku_CC_BY-NC.jpg")
-   local y, x = pairwise_transform.jpeg_(src, {}, 128, 0, false)
+   local y, x = pairwise_transform.jpeg_(src, {}, 128, 0, {})
    image.display({image = y, legend = "y:0"})
    image.display({image = x, legend = "x:0"})
    for i = 2, 9 do
-      local y, x = pairwise_transform.jpeg_(pairwise_transform.random_half(src),
-                                            {i * 10}, 128, 0, {color_noise = false, random_half = true})
+      local y, x = pairwise_transform.jpeg_(random_half(src),
+                                            {i * 10}, 128, 0, {color_noise = false, random_half = true, overlay = true, rgb = true})
       image.display({image = y, legend = "y:" .. (i * 10), max=1,min=0})
       image.display({image = x, legend = "x:" .. (i * 10),max=1,min=0})
       --print(x:mean(), y:mean())
@@ -342,7 +378,7 @@ local function test_scale()
    local loader = require './image_loader'
    local src = loader.load_byte("../images/miku_CC_BY-NC.jpg")
    for i = 1, 9 do
-      local y, x = pairwise_transform.scale(src, 2.0, 128, 7, {color_noise = true, random_half = true, rgb = true})
+      local y, x = pairwise_transform.scale(src, 2.0, 128, 7, {color_noise = true, random_half = true, rgb = true, overlay = true})
       image.display({image = y, legend = "y:" .. (i * 10), min = 0, max = 1})
       image.display({image = x, legend = "x:" .. (i * 10), min = 0, max = 1})
       print(y:size(), x:size())
@@ -354,14 +390,14 @@ local function test_jpeg_scale()
    local loader = require './image_loader'
    local src = loader.load_byte("../images/miku_CC_BY-NC.jpg")
    for i = 1, 9 do
-      local y, x = pairwise_transform.jpeg_scale(src, 2.0, 1, 128, 7, {color_noise = true, random_half = true})
+      local y, x = pairwise_transform.jpeg_scale(src, 2.0, 1, 128, 7, {color_noise = true, random_half = true, overlay = true})
       image.display({image = y, legend = "y1:" .. (i * 10), min = 0, max = 1})
       image.display({image = x, legend = "x1:" .. (i * 10), min = 0, max = 1})
       print(y:size(), x:size())
       --print(x:mean(), y:mean())
    end
    for i = 1, 9 do
-      local y, x = pairwise_transform.jpeg_scale(src, 2.0, 2, 128, 7, {color_noise = true, random_half = true})
+      local y, x = pairwise_transform.jpeg_scale(src, 2.0, 2, 128, 7, {color_noise = true, random_half = true, overlay = true})
       image.display({image = y, legend = "y2:" .. (i * 10), min = 0, max = 1})
       image.display({image = x, legend = "x2:" .. (i * 10), min = 0, max = 1})
       print(y:size(), x:size())
@@ -376,9 +412,20 @@ local function test_color_noise()
       image.display(color_noise(src))
    end
 end
+-- Visual check: displays 10 samples of overlay_augment with p = 1.0.
+local function test_overlay()
+   torch.setdefaulttensortype('torch.FloatTensor')
+   local loader = require './image_loader'
+   local src = loader.load_byte("../images/miku_CC_BY-NC.jpg")
+   for i = 1, 10 do
+      image.display(overlay_augment(src, 1.0))
+   end
+end
+
 --test_scale()
 --test_jpeg()
 --test_jpeg_scale()
 --test_color_noise()
+--test_overlay()
 
 return pairwise_transform
diff --git a/lib/settings.lua b/lib/settings.lua
index cf09fb1..4807870 100644
--- a/lib/settings.lua
+++ b/lib/settings.lua
@@ -25,6 +25,7 @@ cmd:option("-noise_level", 1, '(1|2)')
 cmd:option("-category", "anime_style_art", '(anime_style_art|photo)')
 cmd:option("-color", 'rgb', '(y|rgb)')
 cmd:option("-color_noise", 0, 'enable data augmentation using color noise (1|0)')
+cmd:option("-overlay", 0, 'enable data augmentation using overlay (1|0)')
 cmd:option("-scale", 2.0, 'scale')
 cmd:option("-learning_rate", 0.00025, 'learning rate for adam')
 cmd:option("-random_half", 1, 'enable data augmentation using half resolution image (0|1)')
@@ -69,6 +70,12 @@ if settings.color_noise == 1 then
 else
    settings.color_noise = false
 end
+if settings.overlay == 1 then
+   settings.overlay = true
+else
+   settings.overlay = false
+end
+
 torch.setnumthreads(settings.core)
 
 settings.images = string.format("%s/images.t7", settings.data_dir)
diff --git a/train.lua b/train.lua
index 6b4d378..a9c7d5e 100644
--- a/train.lua
+++ b/train.lua
@@ -80,11 +80,13 @@ local function train()
    local transformer = function(x, is_validation)
       if is_validation == nil then is_validation = false end
       local color_noise = (not is_validation) and settings.color_noise
+      local overlay = (not is_validation) and settings.overlay
       if settings.method == "scale" then
          return pairwise_transform.scale(x,
                                          settings.scale,
                                          settings.crop_size, offset,
                                          { color_noise = color_noise,
+                                           overlay = overlay,
                                            random_half = settings.random_half,
                                            rgb = (settings.color == "rgb")
                                          })
@@ -94,6 +96,7 @@ local function train()
                                          settings.noise_level,
                                          settings.crop_size, offset,
                                          { color_noise = color_noise,
+                                           overlay = overlay,
                                            random_half = settings.random_half,
                                            rgb = (settings.color == "rgb")
                                          })
@@ -104,6 +107,7 @@ local function train()
                                          settings.noise_level,
                                          settings.crop_size, offset,
                                          { color_noise = color_noise,
+                                           overlay = overlay,
                                            random_half = settings.random_half,
                                            rgb = (settings.color == "rgb")
                                          })