1
0
Fork 0
mirror of synced 2024-06-01 10:39:30 +12:00

Add -random_unsharp_mask_rate option for photo

This commit is contained in:
nagadomi 2015-11-27 18:36:36 +09:00
parent 56296c25b3
commit c72ec3112b
4 changed files with 29 additions and 1 deletions

View file

@ -1,5 +1,6 @@
require 'image'
local iproc = require 'iproc'
local gm = require 'graphicsmagick'
local data_augmentation = {}
@ -50,6 +51,25 @@ function data_augmentation.overlay(src, p)
return src
end
end
-- Randomly sharpen `src` with an unsharp mask (data augmentation).
--
-- A single torch.uniform() draw is compared against `p`; when the draw is
-- below `p` the image is passed through graphicsmagick's unsharpMask with
-- randomized sigma / amount / threshold, otherwise `src` is returned as is.
--
-- src: image tensor handed to gm.Image with the "RGB" / "DHW" layout strings
--      (presumably channels x height x width -- confirm against callers).
-- p:   number the random draw is compared against.
--
-- Returns `src` itself when the augmentation is skipped. Otherwise returns
-- the sharpened image: the "float" tensor produced by toTensor, or the result
-- of iproc.float2byte on it when `src:type()` is "torch.ByteTensor".
function data_augmentation.unsharp_mask(src, p)
   -- Guard clause: leave the input untouched unless the draw falls below p.
   if not (torch.uniform() < p) then
      return src
   end

   -- The three draws below stay in the order sigma, amount, threshold so the
   -- random stream is consumed exactly as before.
   local blur_radius = 0 -- auto
   local blur_sigma = torch.uniform(0.7, 1.4)
   local strength = torch.uniform(0.5, 1.0)
   local cutoff = torch.uniform(0.0, 0.05)

   local image = gm.Image(src, "RGB", "DHW")
   local sharpened = image:unsharpMask(blur_radius, blur_sigma, strength, cutoff)
   local result = sharpened:toTensor("float", "RGB", "DHW")

   -- Byte inputs get their result converted back via iproc.float2byte.
   if src:type() == "torch.ByteTensor" then
      return iproc.float2byte(result)
   end
   return result
end
function data_augmentation.shift_1px(src)
-- reducing the even/odd issue in nearest neighbor scaler.
local direction = torch.random(1, 4)

View file

@ -7,7 +7,7 @@ local pairwise_transform = {}
local function random_half(src, p)
if torch.uniform() < p then
local filter = ({"Box","Box","Blackman","Sinc","Lanczos"})[torch.random(1, 5)]
local filter = ({"Box","Box","Blackman","Sinc","Lanczos", "Catrom"})[torch.random(1, 6)]
return iproc.scale(src, src:size(3) * 0.5, src:size(2) * 0.5, filter)
else
return src
@ -38,6 +38,7 @@ local function preprocess(src, crop_size, options)
dest = data_augmentation.flip(dest)
dest = data_augmentation.color_noise(dest, options.random_color_noise_rate)
dest = data_augmentation.overlay(dest, options.random_overlay_rate)
dest = data_augmentation.unsharp_mask(dest, options.random_unsharp_mask_rate)
dest = data_augmentation.shift_1px(dest)
return dest
@ -81,6 +82,7 @@ function pairwise_transform.scale(src, scale, size, offset, n, options)
--"Hermite", -- 0.013850225205266
"Sinc", -- 0.014095824314306
"Lanczos", -- 0.014244299255442
"Catrom"
}
local unstable_region_offset = 8
local downscale_filter = filters[torch.random(1, #filters)]
@ -211,6 +213,8 @@ function pairwise_transform.test_jpeg(src)
local options = {random_color_noise_rate = 0.5,
random_half_rate = 0.5,
random_overlay_rate = 0.5,
random_unsharp_mask_rate = 0.5,
jpeg_chroma_subsampling_rate = 0.5,
nr_rate = 1.0,
active_cropping_rate = 0.5,
active_cropping_tries = 10,
@ -233,6 +237,7 @@ function pairwise_transform.test_scale(src)
local options = {random_color_noise_rate = 0.5,
random_half_rate = 0.5,
random_overlay_rate = 0.5,
random_unsharp_mask_rate = 0.5,
active_cropping_rate = 0.5,
active_cropping_tries = 10,
max_size = 256,

View file

@ -30,6 +30,7 @@ cmd:option("-color", 'rgb', '(y|rgb)')
-- Data-augmentation rates. Each option defaults to 0.0 (second argument) and,
-- per its help text, accepts a value in the range 0.0-1.0.
cmd:option("-random_color_noise_rate", 0.0, 'data augmentation using color noise (0.0-1.0)')
cmd:option("-random_overlay_rate", 0.0, 'data augmentation using flipped image overlay (0.0-1.0)')
cmd:option("-random_half_rate", 0.0, 'data augmentation using half resolution image (0.0-1.0)')
cmd:option("-random_unsharp_mask_rate", 0.0, 'data augmentation using unsharp mask (0.0-1.0)')
-- Scale factor, Adam learning rate and crop size, with their defaults.
cmd:option("-scale", 2.0, 'scale factor (2)')
cmd:option("-learning_rate", 0.00025, 'learning rate for adam')
cmd:option("-crop_size", 46, 'crop size')

View file

@ -110,6 +110,7 @@ local function transformer(x, is_validation, n, offset)
random_half_rate = settings.random_half_rate,
random_color_noise_rate = random_color_noise_rate,
random_overlay_rate = random_overlay_rate,
random_unsharp_mask_rate = settings.random_unsharp_mask_rate,
max_size = settings.max_size,
active_cropping_rate = active_cropping_rate,
active_cropping_tries = active_cropping_tries,
@ -125,6 +126,7 @@ local function transformer(x, is_validation, n, offset)
random_half_rate = settings.random_half_rate,
random_color_noise_rate = random_color_noise_rate,
random_overlay_rate = random_overlay_rate,
random_unsharp_mask_rate = settings.random_unsharp_mask_rate,
max_size = settings.max_size,
jpeg_chroma_subsampling_rate = settings.jpeg_chroma_subsampling_rate,
active_cropping_rate = active_cropping_rate,