diff --git a/lib/pairwise_transform.lua b/lib/pairwise_transform.lua index f8130a4..79e8100 100644 --- a/lib/pairwise_transform.lua +++ b/lib/pairwise_transform.lua @@ -1,255 +1,9 @@ -require 'image' -local gm = require 'graphicsmagick' -local iproc = require 'iproc' -local data_augmentation = require 'data_augmentation' - +require 'pl' local pairwise_transform = {} -local function random_half(src, p, filters) - if torch.uniform() < p then - local filter = filters[torch.random(1, #filters)] - return iproc.scale(src, src:size(3) * 0.5, src:size(2) * 0.5, filter) - else - return src - end -end -local function crop_if_large(src, max_size) - local tries = 4 - if src:size(2) > max_size and src:size(3) > max_size then - local rect - for i = 1, tries do - local yi = torch.random(0, src:size(2) - max_size) - local xi = torch.random(0, src:size(3) - max_size) - rect = iproc.crop(src, xi, yi, xi + max_size, yi + max_size) - -- ignore simple background - if rect:float():std() >= 0 then - break - end - end - return rect - else - return src - end -end -local function preprocess(src, crop_size, options) - local dest = src - dest = random_half(dest, options.random_half_rate, options.downsampling_filters) - dest = crop_if_large(dest, math.max(crop_size * 2, options.max_size)) - dest = data_augmentation.flip(dest) - dest = data_augmentation.color_noise(dest, options.random_color_noise_rate) - dest = data_augmentation.overlay(dest, options.random_overlay_rate) - dest = data_augmentation.unsharp_mask(dest, options.random_unsharp_mask_rate) - dest = data_augmentation.shift_1px(dest) - - return dest -end -local function active_cropping(x, y, size, p, tries) - assert("x:size == y:size", x:size(2) == y:size(2) and x:size(3) == y:size(3)) - local r = torch.uniform() - local t = "float" - if x:type() == "torch.ByteTensor" then - t = "byte" - end - if p < r then - local xi = torch.random(0, y:size(3) - (size + 1)) - local yi = torch.random(0, y:size(2) - (size + 1)) - local xc = iproc.crop(x, xi, yi, xi + size, yi + size) - local yc = iproc.crop(y, xi, yi, xi + size, yi + size) - return xc, yc - else - local lowres = gm.Image(x, "RGB", "DHW"): - size(x:size(3) * 0.5, x:size(2) * 0.5, "Box"): - size(x:size(3), x:size(2), "Box"): - toTensor(t, "RGB", "DHW") - local best_se = 0.0 - local best_xc, best_yc - local m = torch.FloatTensor(x:size(1), size, size) - for i = 1, tries do - local xi = torch.random(0, y:size(3) - (size + 1)) - local yi = torch.random(0, y:size(2) - (size + 1)) - local xc = iproc.crop(x, xi, yi, xi + size, yi + size) - local lc = iproc.crop(lowres, xi, yi, xi + size, yi + size) - local xcf = iproc.byte2float(xc) - local lcf = iproc.byte2float(lc) - local se = m:copy(xcf):add(-1.0, lcf):pow(2):sum() - if se >= best_se then - best_xc = xcf - best_yc = iproc.byte2float(iproc.crop(y, xi, yi, xi + size, yi + size)) - best_se = se - end - end - return best_xc, best_yc - end -end -function pairwise_transform.scale(src, scale, size, offset, n, options) - local filters = options.downsampling_filters - local unstable_region_offset = 8 - local downsampling_filter = filters[torch.random(1, #filters)] - local y = preprocess(src, size, options) - assert(y:size(2) % 4 == 0 and y:size(3) % 4 == 0) - local down_scale = 1.0 / scale - local x - if options.gamma_correction then - x = iproc.scale(iproc.scale_with_gamma22(y, y:size(3) * down_scale, - y:size(2) * down_scale, downsampling_filter), - y:size(3), y:size(2), options.upsampling_filter) - else - x = iproc.scale(iproc.scale(y, y:size(3) * down_scale, - y:size(2) * down_scale, downsampling_filter), - y:size(3), y:size(2), options.upsampling_filter) - end - x = iproc.crop(x, unstable_region_offset, unstable_region_offset, - x:size(3) - unstable_region_offset, x:size(2) - unstable_region_offset) - y = iproc.crop(y, unstable_region_offset, unstable_region_offset, - y:size(3) - unstable_region_offset, y:size(2) - unstable_region_offset) - assert(x:size(2) % 4 == 0 and x:size(3) % 4 == 0) - assert(x:size(1) == y:size(1) and x:size(2) == y:size(2) and x:size(3) == y:size(3)) - - local batch = {} - for i = 1, n do - local xc, yc = active_cropping(x, y, - size, - options.active_cropping_rate, - options.active_cropping_tries) - xc = iproc.byte2float(xc) - yc = iproc.byte2float(yc) - if options.rgb then - else - yc = image.rgb2yuv(yc)[1]:reshape(1, yc:size(2), yc:size(3)) - xc = image.rgb2yuv(xc)[1]:reshape(1, xc:size(2), xc:size(3)) - end - table.insert(batch, {xc, iproc.crop(yc, offset, offset, size - offset, size - offset)}) - end - return batch -end -function pairwise_transform.jpeg_(src, quality, size, offset, n, options) - local unstable_region_offset = 8 - local y = preprocess(src, size, options) - local x = y +pairwise_transform = tablex.update(pairwise_transform, require('pairwise_transform_scale')) +pairwise_transform = tablex.update(pairwise_transform, require('pairwise_transform_jpeg')) - for i = 1, #quality do - x = gm.Image(x, "RGB", "DHW") - x:format("jpeg"):depth(8) - if torch.uniform() < options.jpeg_chroma_subsampling_rate then - -- YUV 420 - x:samplingFactors({2.0, 1.0, 1.0}) - else - -- YUV 444 - x:samplingFactors({1.0, 1.0, 1.0}) - end - local blob, len = x:toBlob(quality[i]) - x:fromBlob(blob, len) - x = x:toTensor("byte", "RGB", "DHW") - end - x = iproc.crop(x, unstable_region_offset, unstable_region_offset, - x:size(3) - unstable_region_offset, x:size(2) - unstable_region_offset) - y = iproc.crop(y, unstable_region_offset, unstable_region_offset, - y:size(3) - unstable_region_offset, y:size(2) - unstable_region_offset) - assert(x:size(2) % 4 == 0 and x:size(3) % 4 == 0) - assert(x:size(1) == y:size(1) and x:size(2) == y:size(2) and x:size(3) == y:size(3)) - - local batch = {} - for i = 1, n do - local xc, yc = active_cropping(x, y, size, - options.active_cropping_rate, - options.active_cropping_tries) - xc = iproc.byte2float(xc) - yc = iproc.byte2float(yc) - if options.rgb then - else - yc = image.rgb2yuv(yc)[1]:reshape(1, yc:size(2), yc:size(3)) - xc = image.rgb2yuv(xc)[1]:reshape(1, xc:size(2), xc:size(3)) - end - if torch.uniform() < options.nr_rate then - -- reducing noise - table.insert(batch, {xc, iproc.crop(yc, offset, offset, size - offset, size - offset)}) - else - -- ratain useful details - table.insert(batch, {yc, iproc.crop(yc, offset, offset, size - offset, size - offset)}) - end - end - return batch -end -function pairwise_transform.jpeg(src, style, level, size, offset, n, options) - if style == "art" then - if level == 1 then - return pairwise_transform.jpeg_(src, {torch.random(65, 85)}, - size, offset, n, options) - elseif level == 2 or level == 3 then - -- level 2/3 adjusting by -nr_rate. for level3, -nr_rate=1 - local r = torch.uniform() - if r > 0.6 then - return pairwise_transform.jpeg_(src, {torch.random(27, 70)}, - size, offset, n, options) - elseif r > 0.3 then - local quality1 = torch.random(37, 70) - local quality2 = quality1 - torch.random(5, 10) - return pairwise_transform.jpeg_(src, {quality1, quality2}, - size, offset, n, options) - else - local quality1 = torch.random(52, 70) - local quality2 = quality1 - torch.random(5, 15) - local quality3 = quality1 - torch.random(15, 25) - - return pairwise_transform.jpeg_(src, - {quality1, quality2, quality3}, - size, offset, n, options) - end - else - error("unknown noise level: " .. level) - end - elseif style == "photo" then - -- level adjusting by -nr_rate - return pairwise_transform.jpeg_(src, {torch.random(30, 70)}, - size, offset, n, - options) - else - error("unknown style: " .. style) - end -end +print(pairwise_transform) -function pairwise_transform.test_jpeg(src) - torch.setdefaulttensortype("torch.FloatTensor") - local options = {random_color_noise_rate = 0.5, - random_half_rate = 0.5, - random_overlay_rate = 0.5, - random_unsharp_mask_rate = 0.5, - jpeg_chroma_subsampling_rate = 0.5, - nr_rate = 1.0, - active_cropping_rate = 0.5, - active_cropping_tries = 10, - max_size = 256, - rgb = true - } - local image = require 'image' - local src = image.lena() - for i = 1, 9 do - local xy = pairwise_transform.jpeg(src, - "art", - torch.random(1, 2), - 128, 7, 1, options) - image.display({image = xy[1][1], legend = "y:" .. (i * 10), min=0, max=1}) - image.display({image = xy[1][2], legend = "x:" .. (i * 10), min=0, max=1}) - end -end -function pairwise_transform.test_scale(src) - torch.setdefaulttensortype("torch.FloatTensor") - local options = {random_color_noise_rate = 0.5, - random_half_rate = 0.5, - random_overlay_rate = 0.5, - random_unsharp_mask_rate = 0.5, - active_cropping_rate = 0.5, - active_cropping_tries = 10, - max_size = 256, - rgb = true - } - local image = require 'image' - local src = image.lena() - - for i = 1, 10 do - local xy = pairwise_transform.scale(src, 2.0, 128, 7, 1, options) - image.display({image = xy[1][1], legend = "y:" .. (i * 10), min = 0, max = 1}) - image.display({image = xy[1][2], legend = "x:" .. (i * 10), min = 0, max = 1}) - end -end return pairwise_transform diff --git a/lib/pairwise_transform_jpeg.lua b/lib/pairwise_transform_jpeg.lua new file mode 100644 index 0000000..170a6b6 --- /dev/null +++ b/lib/pairwise_transform_jpeg.lua @@ -0,0 +1,117 @@ +local pairwise_utils = require 'pairwise_transform_utils' +local gm = require 'graphicsmagick' +local iproc = require 'iproc' +local pairwise_transform = {} + +function pairwise_transform.jpeg_(src, quality, size, offset, n, options) + local unstable_region_offset = 8 + local y = pairwise_utils.preprocess(src, size, options) + local x = y + + for i = 1, #quality do + x = gm.Image(x, "RGB", "DHW") + x:format("jpeg"):depth(8) + if torch.uniform() < options.jpeg_chroma_subsampling_rate then + -- YUV 420 + x:samplingFactors({2.0, 1.0, 1.0}) + else + -- YUV 444 + x:samplingFactors({1.0, 1.0, 1.0}) + end + local blob, len = x:toBlob(quality[i]) + x:fromBlob(blob, len) + x = x:toTensor("byte", "RGB", "DHW") + end + x = iproc.crop(x, unstable_region_offset, unstable_region_offset, + x:size(3) - unstable_region_offset, x:size(2) - unstable_region_offset) + y = iproc.crop(y, unstable_region_offset, unstable_region_offset, + y:size(3) - unstable_region_offset, y:size(2) - unstable_region_offset) + assert(x:size(2) % 4 == 0 and x:size(3) % 4 == 0) + assert(x:size(1) == y:size(1) and x:size(2) == y:size(2) and x:size(3) == y:size(3)) + + local batch = {} + for i = 1, n do + local xc, yc = pairwise_utils.active_cropping(x, y, size, 1, + options.active_cropping_rate, + options.active_cropping_tries) + xc = iproc.byte2float(xc) + yc = iproc.byte2float(yc) + if options.rgb then + else + yc = image.rgb2yuv(yc)[1]:reshape(1, yc:size(2), yc:size(3)) + xc = image.rgb2yuv(xc)[1]:reshape(1, xc:size(2), xc:size(3)) + end + if torch.uniform() < options.nr_rate then + -- reducing noise + table.insert(batch, {xc, iproc.crop(yc, offset, offset, size - offset, size - offset)}) + else + -- ratain useful details + table.insert(batch, {yc, iproc.crop(yc, offset, offset, size - offset, size - offset)}) + end + end + return batch +end +function pairwise_transform.jpeg(src, style, level, size, offset, n, options) + if style == "art" then + if level == 1 then + return pairwise_transform.jpeg_(src, {torch.random(65, 85)}, + size, offset, n, options) + elseif level == 2 or level == 3 then + -- level 2/3 adjusting by -nr_rate. for level3, -nr_rate=1 + local r = torch.uniform() + if r > 0.6 then + return pairwise_transform.jpeg_(src, {torch.random(27, 70)}, + size, offset, n, options) + elseif r > 0.3 then + local quality1 = torch.random(37, 70) + local quality2 = quality1 - torch.random(5, 10) + return pairwise_transform.jpeg_(src, {quality1, quality2}, + size, offset, n, options) + else + local quality1 = torch.random(52, 70) + local quality2 = quality1 - torch.random(5, 15) + local quality3 = quality1 - torch.random(15, 25) + + return pairwise_transform.jpeg_(src, + {quality1, quality2, quality3}, + size, offset, n, options) + end + else + error("unknown noise level: " .. level) + end + elseif style == "photo" then + -- level adjusting by -nr_rate + return pairwise_transform.jpeg_(src, {torch.random(30, 70)}, + size, offset, n, + options) + else + error("unknown style: " .. style) + end +end + +function pairwise_transform.test_jpeg(src) + torch.setdefaulttensortype("torch.FloatTensor") + local options = {random_color_noise_rate = 0.5, + random_half_rate = 0.5, + random_overlay_rate = 0.5, + random_unsharp_mask_rate = 0.5, + jpeg_chroma_subsampling_rate = 0.5, + nr_rate = 1.0, + active_cropping_rate = 0.5, + active_cropping_tries = 10, + max_size = 256, + rgb = true + } + local image = require 'image' + local src = image.lena() + for i = 1, 9 do + local xy = pairwise_transform.jpeg(src, + "art", + torch.random(1, 2), + 128, 7, 1, options) + image.display({image = xy[1][1], legend = "y:" .. (i * 10), min=0, max=1}) + image.display({image = xy[1][2], legend = "x:" .. (i * 10), min=0, max=1}) + end +end +return pairwise_transform + diff --git a/lib/pairwise_transform_scale.lua b/lib/pairwise_transform_scale.lua new file mode 100644 index 0000000..dbef019 --- /dev/null +++ b/lib/pairwise_transform_scale.lua @@ -0,0 +1,86 @@ +local pairwise_utils = require 'pairwise_transform_utils' +local iproc = require 'iproc' +local pairwise_transform = {} + +function pairwise_transform.scale(src, scale, size, offset, n, options) + local filters = options.downsampling_filters + local unstable_region_offset = 8 + local downsampling_filter = filters[torch.random(1, #filters)] + local y = pairwise_utils.preprocess(src, size, options) + assert(y:size(2) % 4 == 0 and y:size(3) % 4 == 0) + local down_scale = 1.0 / scale + local x + if options.gamma_correction then + local small = iproc.scale_with_gamma22(y, y:size(3) * down_scale, + y:size(2) * down_scale, downsampling_filter) + if options.x_upsampling then + x = iproc.scale(small, y:size(3), y:size(2), options.upsampling_filter) + else + x = small + end + else + local small = iproc.scale(y, y:size(3) * down_scale, + y:size(2) * down_scale, downsampling_filter) + if options.x_upsampling then + x = iproc.scale(small, y:size(3), y:size(2), options.upsampling_filter) + else + x = small + end + end + + if options.x_upsampling then + x = iproc.crop(x, unstable_region_offset, unstable_region_offset, + x:size(3) - unstable_region_offset, x:size(2) - unstable_region_offset) + y = iproc.crop(y, unstable_region_offset, unstable_region_offset, + y:size(3) - unstable_region_offset, y:size(2) - unstable_region_offset) + assert(x:size(2) % 4 == 0 and x:size(3) % 4 == 0) + assert(x:size(1) == y:size(1) and x:size(2) == y:size(2) and x:size(3) == y:size(3)) + else + assert(x:size(1) == y:size(1) and x:size(2) * scale == y:size(2) and x:size(3) * scale == y:size(3)) + end + local scale_inner = scale + if options.x_upsampling then + scale_inner = 1 + end + local batch = {} + + for i = 1, n do + local xc, yc = pairwise_utils.active_cropping(x, y, + size, + scale_inner, + options.active_cropping_rate, + options.active_cropping_tries) + xc = iproc.byte2float(xc) + yc = iproc.byte2float(yc) + if options.rgb then + else + yc = image.rgb2yuv(yc)[1]:reshape(1, yc:size(2), yc:size(3)) + xc = image.rgb2yuv(xc)[1]:reshape(1, xc:size(2), xc:size(3)) + end + table.insert(batch, {xc, iproc.crop(yc, offset, offset, size - offset, size - offset)}) + end + return batch +end +function pairwise_transform.test_scale(src) + torch.setdefaulttensortype("torch.FloatTensor") + local options = {random_color_noise_rate = 0.5, + random_half_rate = 0.5, + random_overlay_rate = 0.5, + random_unsharp_mask_rate = 0.5, + active_cropping_rate = 0.5, + active_cropping_tries = 10, + max_size = 256, + x_upsampling = false, + downsampling_filters = "Box", + rgb = true + } + local image = require 'image' + local src = image.lena() + + for i = 1, 10 do + local xy = pairwise_transform.scale(src, 2.0, 128, 7, 1, options) + image.display({image = xy[1][1], legend = "y:" .. (i * 10), min = 0, max = 1}) + image.display({image = xy[1][2], legend = "x:" .. (i * 10), min = 0, max = 1}) + end +end +return pairwise_transform diff --git a/lib/pairwise_transform_utils.lua b/lib/pairwise_transform_utils.lua new file mode 100644 index 0000000..3ff7f1f --- /dev/null +++ b/lib/pairwise_transform_utils.lua @@ -0,0 +1,91 @@ +require 'image' +local gm = require 'graphicsmagick' +local iproc = require 'iproc' +local data_augmentation = require 'data_augmentation' +local pairwise_transform_utils = {} + +function pairwise_transform_utils.random_half(src, p, filters) + if torch.uniform() < p then + local filter = filters[torch.random(1, #filters)] + return iproc.scale(src, src:size(3) * 0.5, src:size(2) * 0.5, filter) + else + return src + end +end +function pairwise_transform_utils.crop_if_large(src, max_size) + local tries = 4 + if src:size(2) > max_size and src:size(3) > max_size then + local rect + for i = 1, tries do + local yi = torch.random(0, src:size(2) - max_size) + local xi = torch.random(0, src:size(3) - max_size) + rect = iproc.crop(src, xi, yi, xi + max_size, yi + max_size) + -- ignore simple background + if rect:float():std() >= 0 then + break + end + end + return rect + else + return src + end +end +function pairwise_transform_utils.preprocess(src, crop_size, options) + local dest = src + dest = pairwise_transform_utils.random_half(dest, options.random_half_rate, options.downsampling_filters) + dest = pairwise_transform_utils.crop_if_large(dest, math.max(crop_size * 2, options.max_size)) + dest = data_augmentation.flip(dest) + dest = data_augmentation.color_noise(dest, options.random_color_noise_rate) + dest = data_augmentation.overlay(dest, options.random_overlay_rate) + dest = data_augmentation.unsharp_mask(dest, options.random_unsharp_mask_rate) + dest = data_augmentation.shift_1px(dest) + + return dest +end +function pairwise_transform_utils.active_cropping(x, y, size, scale, p, tries) + assert("x:size == y:size", x:size(2) * scale == y:size(2) and x:size(3) * scale == y:size(3)) + assert("crop_size % scale == 0", size % scale == 0) + local r = torch.uniform() + local t = "float" + if x:type() == "torch.ByteTensor" then + t = "byte" + end + if p < r then + local xi = torch.random(0, x:size(3) - (size + 1)) + local yi = torch.random(0, x:size(2) - (size + 1)) + local yc = iproc.crop(y, xi * scale, yi * scale, xi * scale + size, yi * scale + size) + local xc = iproc.crop(x, xi, yi, xi + size / scale, yi + size / scale) + return xc, yc + else + local test_scale = 2 + if test_scale < scale then + test_scale = scale + end + local lowres = gm.Image(y, "RGB", "DHW"): + size(y:size(3) * 0.5, y:size(2) * 0.5, "Box"): + size(y:size(3), y:size(2), "Box"): + toTensor(t, "RGB", "DHW") + local best_se = 0.0 + local best_xi, best_yi + local m = torch.FloatTensor(y:size(1), size, size) + for i = 1, tries do + local xi = torch.random(0, x:size(3) - (size + 1)) * scale + local yi = torch.random(0, x:size(2) - (size + 1)) * scale + local xc = iproc.crop(y, xi, yi, xi + size, yi + size) + local lc = iproc.crop(lowres, xi, yi, xi + size, yi + size) + local xcf = iproc.byte2float(xc) + local lcf = iproc.byte2float(lc) + local se = m:copy(xcf):add(-1.0, lcf):pow(2):sum() + if se >= best_se then + best_xi = xi + best_yi = yi + best_se = se + end + end + local yc = iproc.crop(y, best_xi, best_yi, best_xi + size, best_yi + size) + local xc = iproc.crop(x, best_xi / scale, best_yi / scale, best_xi / scale + size / scale, best_yi / scale + size / scale) + return xc, yc + end +end + +return pairwise_transform_utils diff --git a/lib/reconstruct.lua b/lib/reconstruct.lua index 6a78926..51ee807 100644 --- a/lib/reconstruct.lua +++ b/lib/reconstruct.lua @@ -49,6 +49,32 @@ local function reconstruct_rgb(model, x, offset, block_size) end return new_x end +local function reconstruct_rgb_with_scale(model, x, scale, offset, block_size) + local new_x = torch.Tensor(x:size(1), x:size(2) * scale, x:size(3) * scale):zero() + local input_block_size = block_size / scale + local output_block_size = block_size + local output_size = output_block_size - offset * 2 + local output_size_in_input = input_block_size - offset + local input = torch.CudaTensor(1, 3, input_block_size, input_block_size) + + for i = 1, x:size(2), output_size_in_input do + for j = 1, new_x:size(3), output_size_in_input do + if i + input_block_size - 1 <= x:size(2) and j + input_block_size - 1 <= x:size(3) then + local index = {{}, + {i, i + input_block_size - 1}, + {j, j + input_block_size - 1}} + input:copy(x[index]) + local output = model:forward(input):view(3, output_size, output_size) + local ii = (i - 1) * scale + 1 + local jj = (j - 1) * scale + 1 + local output_index = {{}, { ii , ii + output_size - 1 }, + { jj, jj + output_size - 1}} + new_x[output_index]:copy(output) + end + end + end + return new_x +end local reconstruct = {} function reconstruct.is_rgb(model) if srcnn.channels(model) == 3 then @@ -62,6 +88,9 @@ end function reconstruct.offset_size(model) return srcnn.offset_size(model) end +function reconstruct.no_resize(model) + return srcnn.has_resize(model) +end function reconstruct.image_y(model, x, offset, block_size) block_size = block_size or 128 local output_size = block_size - offset * 2 @@ -95,8 +124,14 @@ end function reconstruct.scale_y(model, scale, x, offset, block_size, upsampling_filter) upsampling_filter = upsampling_filter or "Box" block_size = block_size or 128 - local x_lanczos = iproc.scale(x, x:size(3) * scale, x:size(2) * scale, "Lanczos") - x = iproc.scale(x, x:size(3) * scale, x:size(2) * scale, upsampling_filter) + + local x_lanczos + if reconstruct.no_resize(model) then + x_lanczos = x:clone() + else + x_lanczos = iproc.scale(x, x:size(3) * scale, x:size(2) * scale, "Lanczos") + x = iproc.scale(x, x:size(3) * scale, x:size(2) * scale, upsampling_filter) + end if x:size(2) * x:size(3) > 2048*2048 then collectgarbage() end @@ -162,39 +197,77 @@ function reconstruct.image_rgb(model, x, offset, block_size) return output end function reconstruct.scale_rgb(model, scale, x, offset, block_size, upsampling_filter) - upsampling_filter = upsampling_filter or "Box" - block_size = block_size or 128 - x = iproc.scale(x, x:size(3) * scale, x:size(2) * scale, upsampling_filter) - if x:size(2) * x:size(3) > 2048*2048 then - collectgarbage() - end - local output_size = block_size - offset * 2 - local h_blocks = math.floor(x:size(2) / output_size) + - ((x:size(2) % output_size == 0 and 0) or 1) - local w_blocks = math.floor(x:size(3) / output_size) + - ((x:size(3) % output_size == 0 and 0) or 1) - - local h = offset + h_blocks * output_size + offset - local w = offset + w_blocks * output_size + offset - local pad_h1 = offset - local pad_w1 = offset - local pad_h2 = (h - offset) - x:size(2) - local pad_w2 = (w - offset) - x:size(3) - x = iproc.padding(x, pad_w1, pad_w2, pad_h1, pad_h2) - if x:size(2) * x:size(3) > 2048*2048 then - collectgarbage() - end - local y = reconstruct_rgb(model, x, offset, block_size) - local output = iproc.crop(y, - pad_w1, pad_h1, - y:size(3) - pad_w2, y:size(2) - pad_h2) - output[torch.lt(output, 0)] = 0 - output[torch.gt(output, 1)] = 1 - x = nil - y = nil - collectgarbage() + if reconstruct.no_resize(model) then + block_size = block_size or 128 + local input_block_size = block_size / scale + local x_w = x:size(3) + local x_h = x:size(2) + local process_size = input_block_size - offset * 2 + -- TODO: under construction!! bug in 4x + local h_blocks = math.floor(x_h / process_size) + 2 +-- ((x_h % process_size == 0 and 0) or 1) + local w_blocks = math.floor(x_w / process_size) + 2 +-- ((x_w % process_size == 0 and 0) or 1) + local h = offset + (h_blocks * process_size) + offset + local w = offset + (w_blocks * process_size) + offset + local pad_h1 = offset + local pad_w1 = offset - return output + local pad_h2 = (h - offset) - x:size(2) + local pad_w2 = (w - offset) - x:size(3) + + x = iproc.padding(x, pad_w1, pad_w2, pad_h1, pad_h2) + if x:size(2) * x:size(3) > 2048*2048 then + collectgarbage() + end + local y + y = reconstruct_rgb_with_scale(model, x, scale, offset, block_size) + local output = iproc.crop(y, + pad_w1, pad_h1, + pad_w1 + x_w * scale, pad_h1 + x_h * scale) + output[torch.lt(output, 0)] = 0 + output[torch.gt(output, 1)] = 1 + x = nil + y = nil + collectgarbage() + + return output + else + upsampling_filter = upsampling_filter or "Box" + block_size = block_size or 128 + x = iproc.scale(x, x:size(3) * scale, x:size(2) * scale, upsampling_filter) + if x:size(2) * x:size(3) > 2048*2048 then + collectgarbage() + end + local output_size = block_size - offset * 2 + local h_blocks = math.floor(x:size(2) / output_size) + + ((x:size(2) % output_size == 0 and 0) or 1) + local w_blocks = math.floor(x:size(3) / output_size) + + ((x:size(3) % output_size == 0 and 0) or 1) + + local h = offset + h_blocks * output_size + offset + local w = offset + w_blocks * output_size + offset + local pad_h1 = offset + local pad_w1 = offset + local pad_h2 = (h - offset) - x:size(2) + local pad_w2 = (w - offset) - x:size(3) + x = iproc.padding(x, pad_w1, pad_w2, pad_h1, pad_h2) + if x:size(2) * x:size(3) > 2048*2048 then + collectgarbage() + end + local y + y = reconstruct_rgb(model, x, offset, block_size) + local output = iproc.crop(y, + pad_w1, pad_h1, + y:size(3) - pad_w2, y:size(2) - pad_h2) + output[torch.lt(output, 0)] = 0 + output[torch.gt(output, 1)] = 1 + x = nil + y = nil + collectgarbage() + + return output + end end function reconstruct.image(model, x, block_size) diff --git a/lib/settings.lua b/lib/settings.lua index 15c9f26..3ddcacd 100644 --- a/lib/settings.lua +++ b/lib/settings.lua @@ -24,7 +24,7 @@ cmd:option("-backend", "cunn", '(cunn|cudnn)') cmd:option("-test", "images/miku_small.png", 'path to test image') cmd:option("-model_dir", "./models", 'model directory') cmd:option("-method", "scale", 'method to training (noise|scale)') -cmd:option("-model", "vgg_7", 'model architecture (vgg_7|vgg_12)') +cmd:option("-model", "vgg_7", 'model architecture (vgg_7|vgg_12|upconv_7|upconv_8_4x|dilated_7)') cmd:option("-noise_level", 1, '(1|2|3)') cmd:option("-style", "art", '(art|photo)') cmd:option("-color", 'rgb', '(y|rgb)') @@ -34,7 +34,7 @@ cmd:option("-random_half_rate", 0.0, 'data augmentation using half resolution im cmd:option("-random_unsharp_mask_rate", 0.0, 'data augmentation using unsharp mask (0.0-1.0)') cmd:option("-scale", 2.0, 'scale factor (2)') cmd:option("-learning_rate", 0.0005, 'learning rate for adam') -cmd:option("-crop_size", 46, 'crop size') +cmd:option("-crop_size", 48, 'crop size') cmd:option("-max_size", 256, 'if image is larger than N, image will be crop randomly') cmd:option("-batch_size", 8, 'mini batch size') cmd:option("-patches", 16, 'number of patch samples') diff --git a/lib/srcnn.lua b/lib/srcnn.lua index 1c4ba9f..cf98b5c 100644 --- a/lib/srcnn.lua +++ b/lib/srcnn.lua @@ -9,14 +9,23 @@ function nn.SpatialConvolutionMM:reset(stdv) self.weight:normal(0, stdv) self.bias:zero() end +function nn.SpatialFullConvolution:reset(stdv) + stdv = math.sqrt(2 / ((1.0 + 0.1 * 0.1) * self.kW * self.kH * self.nOutputPlane)) + self.weight:normal(0, stdv) + self.bias:zero() +end if cudnn and cudnn.SpatialConvolution then function cudnn.SpatialConvolution:reset(stdv) stdv = math.sqrt(2 / ((1.0 + 0.1 * 0.1) * self.kW * self.kH * self.nOutputPlane)) self.weight:normal(0, stdv) self.bias:zero() end + function cudnn.SpatialFullConvolution:reset(stdv) + stdv = math.sqrt(2 / ((1.0 + 0.1 * 0.1) * self.kW * self.kH * self.nOutputPlane)) + self.weight:normal(0, stdv) + self.bias:zero() + end end - function nn.SpatialConvolutionMM:clearState() if self.gradWeight then self.gradWeight:resize(self.nOutputPlane, self.nInputPlane * self.kH * self.kW):zero() @@ -26,9 +35,12 @@ function nn.SpatialConvolutionMM:clearState() end return nn.utils.clear(self, 'finput', 'fgradInput', '_input', '_gradOutput', 'output', 'gradInput') end - function srcnn.channels(model) - return model:get(model:size() - 1).weight:size(1) + if model.w2nn_channels ~= nil then + return model.w2nn_channels + else + return model:get(model:size() - 1).weight:size(1) + end end function srcnn.backend(model) local conv = model:findModules("cudnn.SpatialConvolution") @@ -47,32 +59,54 @@ function srcnn.color(model) end end function srcnn.name(model) - local backend_cudnn = false - local conv = model:findModules("nn.SpatialConvolutionMM") - if #conv == 0 then - backend_cudnn = true - conv = model:findModules("cudnn.SpatialConvolution") - end - if #conv == 7 then - return "vgg_7" - elseif #conv == 12 then - return "vgg_12" + if model.w2nn_arch_name then + return model.w2nn_arch_name else - return nil + local conv = model:findModules("nn.SpatialConvolutionMM") + if #conv == 0 then + conv = model:findModules("cudnn.SpatialConvolution") + end + if #conv == 7 then + return "vgg_7" + elseif #conv == 12 then + return "vgg_12" + else + error("unsupported model name") + end end end function srcnn.offset_size(model) - local conv = model:findModules("nn.SpatialConvolutionMM") - if #conv == 0 then - conv = model:findModules("cudnn.SpatialConvolution") + if model.w2nn_offset ~= nil then + return model.w2nn_offset + else + local name = srcnn.name(model) + if name:match("vgg_") then + local conv = model:findModules("nn.SpatialConvolutionMM") + if #conv == 0 then + conv = model:findModules("cudnn.SpatialConvolution") + end + local offset = 0 + for i = 1, #conv do + offset = offset + (conv[i].kW - 1) / 2 + end + return math.floor(offset) + else + error("unsupported model name") + end + end +end +function srcnn.has_resize(model) + if model.w2nn_resize ~= nil then + return model.w2nn_resize + else + local name = srcnn.name(model) + if name:match("upconv") ~= nil then + return true + else + return false + end end - local offset = 0 - for i = 1, #conv do - offset = offset + (conv[i].kW - 1) / 2 - end - return math.floor(offset) end - local function SpatialConvolution(backend, nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH) if backend == "cunn" then return nn.SpatialConvolutionMM(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH) @@ -82,6 +116,15 @@ local function SpatialConvolution(backend, nInputPlane, nOutputPlane, kW, kH, dW error("unsupported backend:" .. backend) end end +local function SpatialFullConvolution(backend, nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH) + if backend == "cunn" then + return nn.SpatialFullConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH) + elseif backend == "cudnn" then + return cudnn.SpatialFullConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH) + else + error("unsupported backend:" .. backend) + end +end -- VGG style net(7 layers) function srcnn.vgg_7(backend, ch) @@ -100,6 +143,11 @@ function srcnn.vgg_7(backend, ch) model:add(w2nn.LeakyReLU(0.1)) model:add(SpatialConvolution(backend, 128, ch, 3, 3, 1, 1, 0, 0)) model:add(nn.View(-1):setNumInputDims(3)) + + model.w2nn_arch_name = "vgg_7" + model.w2nn_offset = 7 + model.w2nn_resize = false + model.w2nn_channels = ch --model:cuda() --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size()) @@ -132,12 +180,103 @@ function srcnn.vgg_12(backend, ch) model:add(w2nn.LeakyReLU(0.1)) model:add(SpatialConvolution(backend, 128, ch, 3, 3, 1, 1, 0, 0)) model:add(nn.View(-1):setNumInputDims(3)) + + model.w2nn_arch_name = "vgg_12" + model.w2nn_offset = 12 + model.w2nn_resize = false + model.w2nn_channels = ch --model:cuda() --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size()) return model end +-- Dilated Convolution (7 layers) +function srcnn.dilated_7(backend, ch) + local model = nn.Sequential() + model:add(SpatialConvolution(backend, ch, 32, 3, 3, 1, 1, 0, 0)) + model:add(w2nn.LeakyReLU(0.1)) + model:add(SpatialConvolution(backend, 32, 32, 3, 3, 1, 1, 0, 0)) + model:add(w2nn.LeakyReLU(0.1)) + model:add(nn.SpatialDilatedConvolution(32, 64, 3, 3, 1, 1, 0, 0, 2, 2)) + model:add(w2nn.LeakyReLU(0.1)) + model:add(nn.SpatialDilatedConvolution(64, 64, 3, 3, 1, 1, 0, 0, 2, 2)) + model:add(w2nn.LeakyReLU(0.1)) + model:add(nn.SpatialDilatedConvolution(64, 128, 3, 3, 1, 1, 0, 0, 4, 4)) + model:add(w2nn.LeakyReLU(0.1)) + model:add(SpatialConvolution(backend, 128, 128, 3, 3, 1, 1, 0, 0)) + model:add(w2nn.LeakyReLU(0.1)) + model:add(SpatialConvolution(backend, 128, ch, 3, 3, 1, 1, 0, 0)) + model:add(nn.View(-1):setNumInputDims(3)) + + model.w2nn_arch_name = "dilated_7" + model.w2nn_offset = 12 + model.w2nn_resize = false + model.w2nn_channels = ch + + --model:cuda() + --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size()) + + return model +end + +-- Up Convolution +function srcnn.upconv_7(backend, ch) + local model = nn.Sequential() + + model:add(SpatialConvolution(backend, ch, 32, 3, 3, 1, 1, 0, 0)) + model:add(w2nn.LeakyReLU(0.1)) + model:add(SpatialConvolution(backend, 32, 32, 3, 3, 1, 1, 0, 0)) + model:add(w2nn.LeakyReLU(0.1)) + model:add(SpatialConvolution(backend, 32, 64, 3, 3, 1, 1, 0, 0)) + model:add(w2nn.LeakyReLU(0.1)) + model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0)) + model:add(w2nn.LeakyReLU(0.1)) + model:add(SpatialConvolution(backend, 64, 128, 3, 3, 1, 1, 0, 0)) + model:add(w2nn.LeakyReLU(0.1)) + model:add(SpatialConvolution(backend, 128, 128, 3, 3, 1, 1, 0, 0)) + model:add(w2nn.LeakyReLU(0.1)) + model:add(SpatialFullConvolution(backend, 128, ch, 4, 4, 2, 2, 1, 1)) + + model.w2nn_arch_name = "upconv_7" + model.w2nn_offset = 12 + model.w2nn_resize = true + model.w2nn_channels = ch + + --model:cuda() + --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size()) + + return model +end +function srcnn.upconv_8_4x(backend, ch) + local model = nn.Sequential() + + model:add(SpatialFullConvolution(backend, ch, 32, 4, 4, 2, 2, 1, 1)) + model:add(w2nn.LeakyReLU(0.1)) + model:add(SpatialConvolution(backend, 32, 32, 3, 3, 1, 1, 0, 0)) + model:add(w2nn.LeakyReLU(0.1)) + model:add(SpatialConvolution(backend, 32, 32, 3, 3, 1, 1, 0, 0)) + model:add(w2nn.LeakyReLU(0.1)) + model:add(SpatialConvolution(backend, 32, 64, 3, 3, 1, 1, 0, 0)) + model:add(w2nn.LeakyReLU(0.1)) + model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0)) + model:add(w2nn.LeakyReLU(0.1)) + model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0)) + model:add(w2nn.LeakyReLU(0.1)) + model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0)) + model:add(w2nn.LeakyReLU(0.1)) + model:add(SpatialFullConvolution(backend, 64, 3, 4, 4, 2, 2, 1, 1)) + + model.w2nn_arch_name = "upconv_8_4x" + model.w2nn_offset = 12 + model.w2nn_resize = true + model.w2nn_channels = ch + + --model:cuda() + --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size()) + + return model +end function srcnn.create(model_name, backend, color) model_name = model_name or "vgg_7" backend = backend or "cunn" @@ -150,12 +289,14 @@ function srcnn.create(model_name, backend, color) else error("unsupported color: " .. color) end - if model_name == "vgg_7" then - return srcnn.vgg_7(backend, ch) - elseif model_name == "vgg_12" then - return srcnn.vgg_12(backend, ch) + if srcnn[model_name] then + return srcnn[model_name](backend, ch) else error("unsupported model_name: " .. model_name) end end + +--local model = srcnn.upconv_8_4x("cunn", 3):cuda() +--print(model:forward(torch.Tensor(1, 3, 64, 64):zero():cuda()):size()) + return srcnn diff --git a/train.lua b/train.lua index 381b94a..68dd0b1 100644 --- a/train.lua +++ b/train.lua @@ -15,7 +15,9 @@ local pairwise_transform = require 'pairwise_transform' local image_loader = require 'image_loader' local function save_test_scale(model, rgb, file) - local up = reconstruct.scale(model, settings.scale, rgb, 128, settings.upsampling_filter) + local up = reconstruct.scale(model, settings.scale, rgb, + settings.scale * settings.crop_size, + settings.upsampling_filter) image.save(file, up) end local function save_test_jpeg(model, rgb, file) @@ -96,6 +98,7 @@ local function create_criterion(model) local offset = reconstruct.offset_size(model) local output_w = settings.crop_size - offset * 2 local weight = torch.Tensor(3, output_w * output_w) + weight[1]:fill(0.29891 * 3) -- R weight[2]:fill(0.58661 * 3) -- G weight[3]:fill(0.11448 * 3) -- B @@ -108,7 +111,7 @@ local function create_criterion(model) return w2nn.ClippedWeightedHuberCriterion(weight, 0.1, {0.0, 1.0}):cuda() end end -local function transformer(x, is_validation, n, offset) +local function transformer(model, x, is_validation, n, offset) x = compression.decompress(x) n = n or settings.patches @@ -145,7 +148,8 @@ local function transformer(x, is_validation, n, offset) active_cropping_rate = active_cropping_rate, active_cropping_tries = active_cropping_tries, rgb = (settings.color == "rgb"), - gamma_correction = settings.gamma_correction + gamma_correction = settings.gamma_correction, + x_upsampling = not srcnn.has_resize(model) }) elseif settings.method == "noise" then return pairwise_transform.jpeg(x, @@ -183,6 +187,22 @@ local function resampling(x, y, train_x, transformer, input_size, target_size) end end end +local function remove_small_image(x) + local new_x = {} + for i = 1, #x do + local x_s = compression.size(x[i]) + if x_s[2] / settings.scale > settings.crop_size + 16 and + x_s[3] / settings.scale > settings.crop_size + 16 then + table.insert(new_x, x[i]) + end + if i % 100 == 0 then + collectgarbage() + end + end + print(string.format("removed %d small images", #x - #new_x)) + + return new_x +end local function plot(train, valid) gnuplot.plot({ {'training', torch.Tensor(train), '-'}, @@ -194,11 +214,11 @@ local function train() local model = srcnn.create(settings.model, settings.backend, settings.color) local offset = reconstruct.offset_size(model) local pairwise_func = function(x, is_validation, n) - return transformer(x, is_validation, n, offset) + return transformer(model, x, is_validation, n, offset) end local criterion = create_criterion(model) local eval_metric = nn.MSECriterion():cuda() - local x = torch.load(settings.images) + local x = remove_small_image(torch.load(settings.images)) local train_x, valid_x = split_data(x, math.max(math.floor(settings.validation_rate * #x), 1)) local adam_config = { learningRate = settings.learning_rate, @@ -222,10 +242,16 @@ local function train() model:cuda() print("load .. " .. #train_x) - local x = torch.Tensor(settings.patches * #train_x, - ch, settings.crop_size, settings.crop_size) + local x = nil local y = torch.Tensor(settings.patches * #train_x, ch * (settings.crop_size - offset * 2) * (settings.crop_size - offset * 2)):zero() + if srcnn.has_resize(model) then + x = torch.Tensor(settings.patches * #train_x, + ch, settings.crop_size / settings.scale, settings.crop_size / settings.scale) + else + x = torch.Tensor(settings.patches * #train_x, + ch, settings.crop_size, settings.crop_size) + end for epoch = 1, settings.epoch do model:training() print("# " .. epoch)