remove noise_scale training
This commit is contained in:
parent
3abc5a03e3
commit
da786e15ba
3 changed files with 3 additions and 165 deletions
|
@ -276,125 +276,6 @@ function pairwise_transform.jpeg(src, category, level, size, offset, n, options)
|
||||||
error("unknown category: " .. category)
|
error("unknown category: " .. category)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
function pairwise_transform.jpeg_scale_(src, scale, quality, size, offset, options)
|
|
||||||
if options.random_half then
|
|
||||||
src = random_half(src)
|
|
||||||
end
|
|
||||||
src = crop_if_large(src, math.max(size * 4, 512))
|
|
||||||
local down_scale = 1.0 / scale
|
|
||||||
local filters = {
|
|
||||||
"Box", -- 0.012756949974688
|
|
||||||
"Blackman", -- 0.013191924552285
|
|
||||||
--"Cartom", -- 0.013753536746706
|
|
||||||
--"Hanning", -- 0.013761314529647
|
|
||||||
--"Hermite", -- 0.013850225205266
|
|
||||||
"SincFast", -- 0.014095824314306
|
|
||||||
"Jinc", -- 0.014244299255442
|
|
||||||
}
|
|
||||||
local downscale_filter = filters[torch.random(1, #filters)]
|
|
||||||
local yi = torch.random(INTERPOLATION_PADDING, src:size(2) - size - INTERPOLATION_PADDING)
|
|
||||||
local xi = torch.random(INTERPOLATION_PADDING, src:size(3) - size - INTERPOLATION_PADDING)
|
|
||||||
local y = src
|
|
||||||
local x
|
|
||||||
|
|
||||||
if options.color_noise then
|
|
||||||
y = color_noise(y)
|
|
||||||
end
|
|
||||||
if options.overlay then
|
|
||||||
y = overlay_augment(y)
|
|
||||||
end
|
|
||||||
|
|
||||||
x = y
|
|
||||||
x = iproc.scale(x, y:size(3) * down_scale, y:size(2) * down_scale, downscale_filter)
|
|
||||||
for i = 1, #quality do
|
|
||||||
x = gm.Image(x, "RGB", "DHW")
|
|
||||||
x:format("jpeg")
|
|
||||||
if options.jpeg_sampling_factors == 444 then
|
|
||||||
x:samplingFactors({1.0, 1.0, 1.0})
|
|
||||||
else -- 422
|
|
||||||
x:samplingFactors({2.0, 1.0, 1.0})
|
|
||||||
end
|
|
||||||
local blob, len = x:toBlob(quality[i])
|
|
||||||
x:fromBlob(blob, len)
|
|
||||||
x = x:toTensor("byte", "RGB", "DHW")
|
|
||||||
end
|
|
||||||
x = iproc.scale(x, y:size(3), y:size(2))
|
|
||||||
y = image.crop(y,
|
|
||||||
xi, yi,
|
|
||||||
xi + size, yi + size)
|
|
||||||
x = image.crop(x,
|
|
||||||
xi, yi,
|
|
||||||
xi + size, yi + size)
|
|
||||||
x = x:float():div(255)
|
|
||||||
y = y:float():div(255)
|
|
||||||
x, y = flip_augment(x, y)
|
|
||||||
|
|
||||||
if options.rgb then
|
|
||||||
else
|
|
||||||
y = image.rgb2yuv(y)[1]:reshape(1, y:size(2), y:size(3))
|
|
||||||
x = image.rgb2yuv(x)[1]:reshape(1, x:size(2), x:size(3))
|
|
||||||
end
|
|
||||||
|
|
||||||
return x, image.crop(y, offset, offset, size - offset, size - offset)
|
|
||||||
end
|
|
||||||
function pairwise_transform.jpeg_scale(src, scale, category, level, size, offset, options)
|
|
||||||
options = options or {color_noise = false, random_half = true}
|
|
||||||
if category == "anime_style_art" then
|
|
||||||
if level == 1 then
|
|
||||||
if torch.uniform() > 0.7 then
|
|
||||||
return pairwise_transform.jpeg_scale_(src, scale, {},
|
|
||||||
size, offset, options)
|
|
||||||
else
|
|
||||||
return pairwise_transform.jpeg_scale_(src, scale, {torch.random(65, 85)},
|
|
||||||
size, offset, options)
|
|
||||||
end
|
|
||||||
elseif level == 2 then
|
|
||||||
if torch.uniform() > 0.7 then
|
|
||||||
return pairwise_transform.jpeg_scale_(src, scale, {},
|
|
||||||
size, offset, options)
|
|
||||||
else
|
|
||||||
local r = torch.uniform()
|
|
||||||
if r > 0.6 then
|
|
||||||
return pairwise_transform.jpeg_scale_(src, scale, {torch.random(27, 70)},
|
|
||||||
size, offset, options)
|
|
||||||
elseif r > 0.3 then
|
|
||||||
local quality1 = torch.random(37, 70)
|
|
||||||
local quality2 = quality1 - torch.random(5, 10)
|
|
||||||
return pairwise_transform.jpeg_scale_(src, scale, {quality1, quality2},
|
|
||||||
size, offset, options)
|
|
||||||
else
|
|
||||||
local quality1 = torch.random(52, 70)
|
|
||||||
local quality2 = quality1 - torch.random(5, 15)
|
|
||||||
local quality3 = quality1 - torch.random(15, 25)
|
|
||||||
|
|
||||||
return pairwise_transform.jpeg_scale_(src, scale,
|
|
||||||
{quality1, quality2, quality3 },
|
|
||||||
size, offset, options)
|
|
||||||
end
|
|
||||||
end
|
|
||||||
else
|
|
||||||
error("unknown noise level: " .. level)
|
|
||||||
end
|
|
||||||
elseif category == "photo" then
|
|
||||||
if level == 1 then
|
|
||||||
if torch.uniform() > 0.7 then
|
|
||||||
return pairwise_transform.jpeg_scale_(src, scale, {},
|
|
||||||
size, offset, options)
|
|
||||||
else
|
|
||||||
return pairwise_transform.jpeg_scale_(src, scale, {torch.random(80, 95)},
|
|
||||||
size, offset, options)
|
|
||||||
end
|
|
||||||
elseif level == 2 then
|
|
||||||
return pairwise_transform.jpeg_scale_(src, scale, {torch.random(70, 85)},
|
|
||||||
size, offset, options)
|
|
||||||
else
|
|
||||||
error("unknown noise level: " .. level)
|
|
||||||
end
|
|
||||||
else
|
|
||||||
error("unknown category: " .. category)
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
local function test_jpeg()
|
local function test_jpeg()
|
||||||
local loader = require './image_loader'
|
local loader = require './image_loader'
|
||||||
local src = loader.load_byte("../images/miku_CC_BY-NC.jpg")
|
local src = loader.load_byte("../images/miku_CC_BY-NC.jpg")
|
||||||
|
@ -428,31 +309,6 @@ local function test_scale()
|
||||||
--print(x:mean(), y:mean())
|
--print(x:mean(), y:mean())
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
local function test_jpeg_scale()
|
|
||||||
torch.setdefaulttensortype('torch.FloatTensor')
|
|
||||||
local loader = require './image_loader'
|
|
||||||
local src = loader.load_byte("../images/miku_CC_BY-NC.jpg")
|
|
||||||
local options = {color_noise = true,
|
|
||||||
random_half = true,
|
|
||||||
overlay = true,
|
|
||||||
active_cropping_ratio = 0.5,
|
|
||||||
active_cropping_times = 10
|
|
||||||
}
|
|
||||||
for i = 1, 9 do
|
|
||||||
local y, x = pairwise_transform.jpeg_scale(src, 2.0, 1, 128, 7, options)
|
|
||||||
image.display({image = y, legend = "y1:" .. (i * 10), min = 0, max = 1})
|
|
||||||
image.display({image = x, legend = "x1:" .. (i * 10), min = 0, max = 1})
|
|
||||||
print(y:size(), x:size())
|
|
||||||
--print(x:mean(), y:mean())
|
|
||||||
end
|
|
||||||
for i = 1, 9 do
|
|
||||||
local y, x = pairwise_transform.jpeg_scale(src, 2.0, 2, 128, 7, options)
|
|
||||||
image.display({image = y, legend = "y2:" .. (i * 10), min = 0, max = 1})
|
|
||||||
image.display({image = x, legend = "x2:" .. (i * 10), min = 0, max = 1})
|
|
||||||
print(y:size(), x:size())
|
|
||||||
--print(x:mean(), y:mean())
|
|
||||||
end
|
|
||||||
end
|
|
||||||
local function test_color_noise()
|
local function test_color_noise()
|
||||||
torch.setdefaulttensortype('torch.FloatTensor')
|
torch.setdefaulttensortype('torch.FloatTensor')
|
||||||
local loader = require './image_loader'
|
local loader = require './image_loader'
|
||||||
|
|
|
@ -15,14 +15,14 @@ local settings = {}
|
||||||
|
|
||||||
local cmd = torch.CmdLine()
|
local cmd = torch.CmdLine()
|
||||||
cmd:text()
|
cmd:text()
|
||||||
cmd:text("waifu2x")
|
cmd:text("waifu2x-training")
|
||||||
cmd:text("Options:")
|
cmd:text("Options:")
|
||||||
cmd:option("-seed", 11, 'fixed input seed')
|
cmd:option("-seed", 11, 'fixed input seed')
|
||||||
cmd:option("-data_dir", "./data", 'data directory')
|
cmd:option("-data_dir", "./data", 'data directory')
|
||||||
-- cmd:option("-backend", "cunn", '(cunn|cudnn)') -- cudnn is slow than cunn
|
-- cmd:option("-backend", "cunn", '(cunn|cudnn)') -- cudnn is slower than cunn
|
||||||
cmd:option("-test", "images/miku_small.png", 'test image file')
|
cmd:option("-test", "images/miku_small.png", 'test image file')
|
||||||
cmd:option("-model_dir", "./models", 'model directory')
|
cmd:option("-model_dir", "./models", 'model directory')
|
||||||
cmd:option("-method", "scale", '(noise|scale|noise_scale)')
|
cmd:option("-method", "scale", '(noise|scale)')
|
||||||
cmd:option("-noise_level", 1, '(1|2)')
|
cmd:option("-noise_level", 1, '(1|2)')
|
||||||
cmd:option("-category", "anime_style_art", '(anime_style_art|photo)')
|
cmd:option("-category", "anime_style_art", '(anime_style_art|photo)')
|
||||||
cmd:option("-color", 'rgb', '(y|rgb)')
|
cmd:option("-color", 'rgb', '(y|rgb)')
|
||||||
|
|
18
train.lua
18
train.lua
|
@ -126,19 +126,6 @@ local function transformer(x, is_validation, n, offset)
|
||||||
jpeg_sampling_factors = settings.jpeg_sampling_factors,
|
jpeg_sampling_factors = settings.jpeg_sampling_factors,
|
||||||
rgb = (settings.color == "rgb")
|
rgb = (settings.color == "rgb")
|
||||||
})
|
})
|
||||||
elseif settings.method == "noise_scale" then
|
|
||||||
return pairwise_transform.jpeg_scale(x,
|
|
||||||
settings.scale,
|
|
||||||
settings.category,
|
|
||||||
settings.noise_level,
|
|
||||||
settings.crop_size, offset,
|
|
||||||
n,
|
|
||||||
{ color_noise = color_noise,
|
|
||||||
overlay = overlay,
|
|
||||||
jpeg_sampling_factors = settings.jpeg_sampling_factors,
|
|
||||||
random_half = settings.random_half,
|
|
||||||
rgb = (settings.color == "rgb")
|
|
||||||
})
|
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -195,11 +182,6 @@ local function train()
|
||||||
local log = path.join(settings.model_dir,
|
local log = path.join(settings.model_dir,
|
||||||
("scale%.1f_best.png"):format(settings.scale))
|
("scale%.1f_best.png"):format(settings.scale))
|
||||||
save_test_scale(model, test_image, log)
|
save_test_scale(model, test_image, log)
|
||||||
elseif settings.method == "noise_scale" then
|
|
||||||
local log = path.join(settings.model_dir,
|
|
||||||
("noise%d_scale%.1f_best.png"):format(settings.noise_level,
|
|
||||||
settings.scale))
|
|
||||||
save_test_scale(model, test_image, log)
|
|
||||||
end
|
end
|
||||||
else
|
else
|
||||||
lrd_count = lrd_count + 1
|
lrd_count = lrd_count + 1
|
||||||
|
|
Loading…
Reference in a new issue