From 71a34393b847abe9d7f7c4eb6cf53f6372be8a58 Mon Sep 17 00:00:00 2001 From: nagadomi Date: Sat, 8 Apr 2017 22:01:13 +0900 Subject: [PATCH 1/3] Add support for padding in convert_data.lua --- convert_data.lua | 21 ++++++++++++++++++++- lib/iproc.lua | 16 ++++++++++++++-- lib/settings.lua | 3 +++ 3 files changed, 37 insertions(+), 3 deletions(-) diff --git a/convert_data.lua b/convert_data.lua index ef620e8..8afad03 100644 --- a/convert_data.lua +++ b/convert_data.lua @@ -63,7 +63,24 @@ local function crop_if_large_pair(x, y, max_size) return x, y end end - +local function padding_x(x, pad) + if pad > 0 then + x = iproc.padding(x, pad, pad, pad, pad) + end + return x +end +local function padding_xy(x, y, pad, y_zero) + local scale = y:size(2) / x:size(2) + if pad > 0 then + x = iproc.padding(x, pad, pad, pad, pad) + if y_zero then + y = iproc.zero_padding(y, pad * scale, pad * scale, pad * scale, pad * scale) + else + y = iproc.padding(y, pad * scale, pad * scale, pad * scale, pad * scale) + end + end + return x, y +end local function load_images(list) local MARGIN = 32 local csv = csvigo.load({path = list, verbose = false, mode = "raw"}) @@ -105,6 +122,7 @@ local function load_images(list) xx = alpha_util.fill(xx, meta2.alpha, alpha_color) end xx, yy = crop_if_large_pair(xx, yy, settings.max_training_image_size) + xx, yy = padding_xy(xx, yy, settings.padding, settings.padding_y_zero) table.insert(x, {{y = compression.compress(yy), x = compression.compress(xx)}, {data = {filters = filters, has_x = true}}}) else @@ -113,6 +131,7 @@ local function load_images(list) else im = crop_if_large(im, settings.max_training_image_size) im = iproc.crop_mod4(im) + im = padding_x(im, settings.padding) local scale = 1.0 if settings.random_half_rate > 0.0 then scale = 2.0 diff --git a/lib/iproc.lua b/lib/iproc.lua index b4e6e17..a1fb8cb 100644 --- a/lib/iproc.lua +++ b/lib/iproc.lua @@ -80,6 +80,8 @@ function iproc.scale_with_gamma22(src, width, height, filter, blur) return dest end 
function iproc.padding(img, w1, w2, h1, h2) + local conversion + img, conversion = iproc.byte2float(img) image = image or require 'image' local dst_height = img:size(2) + h1 + h2 local dst_width = img:size(3) + w1 + w2 @@ -88,9 +90,15 @@ function iproc.padding(img, w1, w2, h1, h2) flow[2] = torch.ger(torch.ones(dst_height), torch.linspace(0, dst_width - 1, dst_width)) flow[1]:add(-h1) flow[2]:add(-w1) - return image.warp(img, flow, "simple", false, "clamp") + local dest = image.warp(img, flow, "simple", false, "clamp") + if conversion then + dest = iproc.float2byte(dest) + end + return dest end function iproc.zero_padding(img, w1, w2, h1, h2) + local conversion + img, conversion = iproc.byte2float(img) image = image or require 'image' local dst_height = img:size(2) + h1 + h2 local dst_width = img:size(3) + w1 + w2 @@ -99,7 +107,11 @@ function iproc.zero_padding(img, w1, w2, h1, h2) flow[2] = torch.ger(torch.ones(dst_height), torch.linspace(0, dst_width - 1, dst_width)) flow[1]:add(-h1) flow[2]:add(-w1) - return image.warp(img, flow, "simple", false, "pad", 0) + local dest = image.warp(img, flow, "simple", false, "pad", 0) + if conversion then + dest = iproc.float2byte(dest) + end + return dest end function iproc.white_noise(src, std, rgb_weights, gamma) gamma = gamma or 0.454545 diff --git a/lib/settings.lua b/lib/settings.lua index fb0b890..dc6f24f 100644 --- a/lib/settings.lua +++ b/lib/settings.lua @@ -77,6 +77,8 @@ cmd:option("-name", "user", 'model name for user method') cmd:option("-gpu", 1, 'Device ID') cmd:option("-loss", "huber", 'loss function (huber|l1|mse|bce)') cmd:option("-update_criterion", "mse", 'mse|loss') +cmd:option("-padding", 0, 'replication padding size') +cmd:option("-padding_y_zero", 0, 'zero padding y for segmentation (0|1)') local function to_bool(settings, name) if settings[name] == 1 then @@ -95,6 +97,7 @@ to_bool(settings, "save_history") to_bool(settings, "use_transparent_png") to_bool(settings, "pairwise_y_binary") to_bool(settings, 
"pairwise_flip") +to_bool(settings, "padding_y_zero") if settings.plot then require 'gnuplot' From f0fc2c89d106ecc7153d4ed98f06cb5d9672948a Mon Sep 17 00:00:00 2001 From: nagadomi Date: Sun, 9 Apr 2017 20:53:53 +0900 Subject: [PATCH 2/3] Add support for grayscale data --- convert_data.lua | 7 +++++++ lib/iproc.lua | 1 + lib/pairwise_transform_jpeg.lua | 6 ++++-- lib/pairwise_transform_scale.lua | 6 ++++-- lib/pairwise_transform_user.lua | 6 ++++-- lib/pairwise_transform_utils.lua | 15 +++++++++++---- lib/settings.lua | 2 ++ 7 files changed, 33 insertions(+), 10 deletions(-) diff --git a/convert_data.lua b/convert_data.lua index 8afad03..11d2f62 100644 --- a/convert_data.lua +++ b/convert_data.lua @@ -123,6 +123,10 @@ local function load_images(list) end xx, yy = crop_if_large_pair(xx, yy, settings.max_training_image_size) xx, yy = padding_xy(xx, yy, settings.padding, settings.padding_y_zero) + if settings.grayscale then + xx = iproc.rgb2y(xx) + yy = iproc.rgb2y(yy) + end table.insert(x, {{y = compression.compress(yy), x = compression.compress(xx)}, {data = {filters = filters, has_x = true}}}) else @@ -137,6 +141,9 @@ local function load_images(list) scale = 2.0 end if im:size(2) > (settings.crop_size * scale + MARGIN) and im:size(3) > (settings.crop_size * scale + MARGIN) then + if settings.grayscale then + im = iproc.rgb2y(im) + end table.insert(x, {compression.compress(im), {data = {filters = filters}}}) else io.stderr:write(string.format("\n%s: skip: image is too small (%d > size).\n", filename, settings.crop_size * scale + MARGIN)) diff --git a/lib/iproc.lua b/lib/iproc.lua index a1fb8cb..afeb22a 100644 --- a/lib/iproc.lua +++ b/lib/iproc.lua @@ -229,6 +229,7 @@ function iproc.rgb2y(src) src, conversion = iproc.byte2float(src) local dest = torch.FloatTensor(1, src:size(2), src:size(3)):zero() dest:add(0.299, src[1]):add(0.587, src[2]):add(0.114, src[3]) + dest:clamp(0, 1) if conversion then dest = iproc.float2byte(dest) end diff --git 
a/lib/pairwise_transform_jpeg.lua b/lib/pairwise_transform_jpeg.lua index 23cde45..c7fc0b7 100644 --- a/lib/pairwise_transform_jpeg.lua +++ b/lib/pairwise_transform_jpeg.lua @@ -43,8 +43,10 @@ function pairwise_transform.jpeg_(src, quality, size, offset, n, options) yc = iproc.byte2float(yc) if options.rgb then else - yc = iproc.rgb2y(yc) - xc = iproc.rgb2y(xc) + if xc:size(1) > 1 then + yc = iproc.rgb2y(yc) + xc = iproc.rgb2y(xc) + end end if torch.uniform() < options.nr_rate then -- reducing noise diff --git a/lib/pairwise_transform_scale.lua b/lib/pairwise_transform_scale.lua index 7dd18f1..6a2761f 100644 --- a/lib/pairwise_transform_scale.lua +++ b/lib/pairwise_transform_scale.lua @@ -51,8 +51,10 @@ function pairwise_transform.scale(src, scale, size, offset, n, options) yc = iproc.byte2float(yc) if options.rgb then else - yc = iproc.rgb2y(yc) - xc = iproc.rgb2y(xc) + if xc:size(1) > 1 then + yc = iproc.rgb2y(yc) + xc = iproc.rgb2y(xc) + end end table.insert(batch, {xc, iproc.crop(yc, offset, offset, size - offset, size - offset)}) end diff --git a/lib/pairwise_transform_user.lua b/lib/pairwise_transform_user.lua index 924d20f..51f3de2 100644 --- a/lib/pairwise_transform_user.lua +++ b/lib/pairwise_transform_user.lua @@ -38,8 +38,10 @@ function pairwise_transform.user(x, y, size, offset, n, options) yc = iproc.byte2float(yc) if options.rgb then else - yc = iproc.rgb2y(yc) - xc = iproc.rgb2y(xc) + if xc:size(1) > 1 then + yc = iproc.rgb2y(yc) + xc = iproc.rgb2y(xc) + end end if options.gcn then local mean = xc:mean() diff --git a/lib/pairwise_transform_utils.lua b/lib/pairwise_transform_utils.lua index bb6edf9..75bce15 100644 --- a/lib/pairwise_transform_utils.lua +++ b/lib/pairwise_transform_utils.lua @@ -279,10 +279,17 @@ function pairwise_transform_utils.low_resolution(src) toTensor("byte", "RGB", "DHW") end --]] - return gm.Image(src, "RGB", "DHW"): - size(src:size(3) * 0.5, src:size(2) * 0.5, "Box"): - size(src:size(3), src:size(2), "Box"): - 
toTensor("byte", "RGB", "DHW") + if src:size(1) == 1 then + return gm.Image(src, "I", "DHW"): + size(src:size(3) * 0.5, src:size(2) * 0.5, "Box"): + size(src:size(3), src:size(2), "Box"): + toTensor("byte", "I", "DHW") + else + return gm.Image(src, "RGB", "DHW"): + size(src:size(3) * 0.5, src:size(2) * 0.5, "Box"): + size(src:size(3), src:size(2), "Box"): + toTensor("byte", "RGB", "DHW") + end end return pairwise_transform_utils diff --git a/lib/settings.lua b/lib/settings.lua index dc6f24f..b2331fa 100644 --- a/lib/settings.lua +++ b/lib/settings.lua @@ -79,6 +79,7 @@ cmd:option("-loss", "huber", 'loss function (huber|l1|mse|bce)') cmd:option("-update_criterion", "mse", 'mse|loss') cmd:option("-padding", 0, 'replication padding size') cmd:option("-padding_y_zero", 0, 'zero padding y for segmentation (0|1)') +cmd:option("-grayscale", 0, 'grayscale x&y (0|1)') local function to_bool(settings, name) if settings[name] == 1 then @@ -98,6 +99,7 @@ to_bool(settings, "use_transparent_png") to_bool(settings, "pairwise_y_binary") to_bool(settings, "pairwise_flip") to_bool(settings, "padding_y_zero") +to_bool(settings, "grayscale") if settings.plot then require 'gnuplot' From 88e3322296863f12b2f98397e8d48503262ebec9 Mon Sep 17 00:00:00 2001 From: nagadomi Date: Thu, 13 Apr 2017 17:35:32 +0900 Subject: [PATCH 3/3] Add ShakeShakeTable --- lib/ShakeShakeTable.lua | 42 +++++++++++++++++++++++++++++++++++++++++ lib/w2nn.lua | 1 + 2 files changed, 43 insertions(+) create mode 100644 lib/ShakeShakeTable.lua diff --git a/lib/ShakeShakeTable.lua b/lib/ShakeShakeTable.lua new file mode 100644 index 0000000..0920daf --- /dev/null +++ b/lib/ShakeShakeTable.lua @@ -0,0 +1,42 @@ +local ShakeShakeTable, parent = torch.class('w2nn.ShakeShakeTable','nn.Module') + +function ShakeShakeTable:__init() + parent.__init(self) + self.alpha = torch.Tensor() + self.beta = torch.Tensor() + self.first = torch.Tensor() + self.second = torch.Tensor() + self.train = true +end +function 
ShakeShakeTable:updateOutput(input) + local batch_size = input[1]:size(1) + if self.train then + self.alpha:resize(batch_size):uniform() + self.beta:resize(batch_size):uniform() + self.second:resizeAs(input[1]):copy(input[2]) + for i = 1, batch_size do + self.second[i]:mul(self.alpha[i]) + end + self.output:resizeAs(input[1]):copy(input[1]) + for i = 1, batch_size do + self.output[i]:mul(1.0 - self.alpha[i]) + end + self.output:add(self.second):mul(2) + else + self.output:resizeAs(input[1]):copy(input[1]):add(input[2]) + end + return self.output +end +function ShakeShakeTable:updateGradInput(input, gradOutput) + local batch_size = input[1]:size(1) + self.first:resizeAs(gradOutput):copy(gradOutput) + for i = 1, batch_size do + self.first[i]:mul(self.beta[i]) + end + self.second:resizeAs(gradOutput):copy(gradOutput) + for i = 1, batch_size do + self.second[i]:mul(1.0 - self.beta[i]) + end + self.gradInput = {self.first, self.second} + return self.gradInput +end diff --git a/lib/w2nn.lua b/lib/w2nn.lua index 5a9c727..91453fb 100644 --- a/lib/w2nn.lua +++ b/lib/w2nn.lua @@ -74,5 +74,6 @@ else require 'SSIMCriterion' require 'InplaceClip01' require 'L1Criterion' + require 'ShakeShakeTable' return w2nn end