1
0
Fork 0
mirror of synced 2024-06-14 00:44:32 +12:00

remove noise_scale training

This commit is contained in:
nagadomi 2015-10-28 16:25:41 +09:00
parent 3abc5a03e3
commit da786e15ba
3 changed files with 3 additions and 165 deletions

View file

@ -276,125 +276,6 @@ function pairwise_transform.jpeg(src, category, level, size, offset, n, options)
error("unknown category: " .. category)
end
end
function pairwise_transform.jpeg_scale_(src, scale, quality, size, offset, options)
if options.random_half then
src = random_half(src)
end
src = crop_if_large(src, math.max(size * 4, 512))
local down_scale = 1.0 / scale
local filters = {
"Box", -- 0.012756949974688
"Blackman", -- 0.013191924552285
--"Cartom", -- 0.013753536746706
--"Hanning", -- 0.013761314529647
--"Hermite", -- 0.013850225205266
"SincFast", -- 0.014095824314306
"Jinc", -- 0.014244299255442
}
local downscale_filter = filters[torch.random(1, #filters)]
local yi = torch.random(INTERPOLATION_PADDING, src:size(2) - size - INTERPOLATION_PADDING)
local xi = torch.random(INTERPOLATION_PADDING, src:size(3) - size - INTERPOLATION_PADDING)
local y = src
local x
if options.color_noise then
y = color_noise(y)
end
if options.overlay then
y = overlay_augment(y)
end
x = y
x = iproc.scale(x, y:size(3) * down_scale, y:size(2) * down_scale, downscale_filter)
for i = 1, #quality do
x = gm.Image(x, "RGB", "DHW")
x:format("jpeg")
if options.jpeg_sampling_factors == 444 then
x:samplingFactors({1.0, 1.0, 1.0})
else -- 422
x:samplingFactors({2.0, 1.0, 1.0})
end
local blob, len = x:toBlob(quality[i])
x:fromBlob(blob, len)
x = x:toTensor("byte", "RGB", "DHW")
end
x = iproc.scale(x, y:size(3), y:size(2))
y = image.crop(y,
xi, yi,
xi + size, yi + size)
x = image.crop(x,
xi, yi,
xi + size, yi + size)
x = x:float():div(255)
y = y:float():div(255)
x, y = flip_augment(x, y)
if options.rgb then
else
y = image.rgb2yuv(y)[1]:reshape(1, y:size(2), y:size(3))
x = image.rgb2yuv(x)[1]:reshape(1, x:size(2), x:size(3))
end
return x, image.crop(y, offset, offset, size - offset, size - offset)
end
function pairwise_transform.jpeg_scale(src, scale, category, level, size, offset, options)
options = options or {color_noise = false, random_half = true}
if category == "anime_style_art" then
if level == 1 then
if torch.uniform() > 0.7 then
return pairwise_transform.jpeg_scale_(src, scale, {},
size, offset, options)
else
return pairwise_transform.jpeg_scale_(src, scale, {torch.random(65, 85)},
size, offset, options)
end
elseif level == 2 then
if torch.uniform() > 0.7 then
return pairwise_transform.jpeg_scale_(src, scale, {},
size, offset, options)
else
local r = torch.uniform()
if r > 0.6 then
return pairwise_transform.jpeg_scale_(src, scale, {torch.random(27, 70)},
size, offset, options)
elseif r > 0.3 then
local quality1 = torch.random(37, 70)
local quality2 = quality1 - torch.random(5, 10)
return pairwise_transform.jpeg_scale_(src, scale, {quality1, quality2},
size, offset, options)
else
local quality1 = torch.random(52, 70)
local quality2 = quality1 - torch.random(5, 15)
local quality3 = quality1 - torch.random(15, 25)
return pairwise_transform.jpeg_scale_(src, scale,
{quality1, quality2, quality3 },
size, offset, options)
end
end
else
error("unknown noise level: " .. level)
end
elseif category == "photo" then
if level == 1 then
if torch.uniform() > 0.7 then
return pairwise_transform.jpeg_scale_(src, scale, {},
size, offset, options)
else
return pairwise_transform.jpeg_scale_(src, scale, {torch.random(80, 95)},
size, offset, options)
end
elseif level == 2 then
return pairwise_transform.jpeg_scale_(src, scale, {torch.random(70, 85)},
size, offset, options)
else
error("unknown noise level: " .. level)
end
else
error("unknown category: " .. category)
end
end
local function test_jpeg()
local loader = require './image_loader'
local src = loader.load_byte("../images/miku_CC_BY-NC.jpg")
@ -428,31 +309,6 @@ local function test_scale()
--print(x:mean(), y:mean())
end
end
local function test_jpeg_scale()
torch.setdefaulttensortype('torch.FloatTensor')
local loader = require './image_loader'
local src = loader.load_byte("../images/miku_CC_BY-NC.jpg")
local options = {color_noise = true,
random_half = true,
overlay = true,
active_cropping_ratio = 0.5,
active_cropping_times = 10
}
for i = 1, 9 do
local y, x = pairwise_transform.jpeg_scale(src, 2.0, 1, 128, 7, options)
image.display({image = y, legend = "y1:" .. (i * 10), min = 0, max = 1})
image.display({image = x, legend = "x1:" .. (i * 10), min = 0, max = 1})
print(y:size(), x:size())
--print(x:mean(), y:mean())
end
for i = 1, 9 do
local y, x = pairwise_transform.jpeg_scale(src, 2.0, 2, 128, 7, options)
image.display({image = y, legend = "y2:" .. (i * 10), min = 0, max = 1})
image.display({image = x, legend = "x2:" .. (i * 10), min = 0, max = 1})
print(y:size(), x:size())
--print(x:mean(), y:mean())
end
end
local function test_color_noise()
torch.setdefaulttensortype('torch.FloatTensor')
local loader = require './image_loader'

View file

@ -15,14 +15,14 @@ local settings = {}
local cmd = torch.CmdLine()
cmd:text()
cmd:text("waifu2x")
cmd:text("waifu2x-training")
cmd:text("Options:")
cmd:option("-seed", 11, 'fixed input seed')
cmd:option("-data_dir", "./data", 'data directory')
-- cmd:option("-backend", "cunn", '(cunn|cudnn)') -- cudnn is slow than cunn
-- cmd:option("-backend", "cunn", '(cunn|cudnn)') -- cudnn is slower than cunn
cmd:option("-test", "images/miku_small.png", 'test image file')
cmd:option("-model_dir", "./models", 'model directory')
cmd:option("-method", "scale", '(noise|scale|noise_scale)')
cmd:option("-method", "scale", '(noise|scale)')
cmd:option("-noise_level", 1, '(1|2)')
cmd:option("-category", "anime_style_art", '(anime_style_art|photo)')
cmd:option("-color", 'rgb', '(y|rgb)')

View file

@ -126,19 +126,6 @@ local function transformer(x, is_validation, n, offset)
jpeg_sampling_factors = settings.jpeg_sampling_factors,
rgb = (settings.color == "rgb")
})
elseif settings.method == "noise_scale" then
return pairwise_transform.jpeg_scale(x,
settings.scale,
settings.category,
settings.noise_level,
settings.crop_size, offset,
n,
{ color_noise = color_noise,
overlay = overlay,
jpeg_sampling_factors = settings.jpeg_sampling_factors,
random_half = settings.random_half,
rgb = (settings.color == "rgb")
})
end
end
@ -195,11 +182,6 @@ local function train()
local log = path.join(settings.model_dir,
("scale%.1f_best.png"):format(settings.scale))
save_test_scale(model, test_image, log)
elseif settings.method == "noise_scale" then
local log = path.join(settings.model_dir,
("noise%d_scale%.1f_best.png"):format(settings.noise_level,
settings.scale))
save_test_scale(model, test_image, log)
end
else
lrd_count = lrd_count + 1