diff --git a/.gitignore b/.gitignore index 7d15096..60dcc5d 100644 --- a/.gitignore +++ b/.gitignore @@ -11,6 +11,8 @@ models/* !models/ukbench !models/photo !models/upconv_7 +!models/upconv_7l +!models/srresnet_12l !models/vgg_7 models/*/*.png models/*/*/*.png diff --git a/README.md b/README.md index 10b4adb..ae9abc1 100644 --- a/README.md +++ b/README.md @@ -89,6 +89,7 @@ See: [Getting started with Torch](http://torch.ch/docs/getting-started.html) And install luarocks packages. ``` luarocks install graphicsmagick # upgrade +luarocks install threads # upgrade luarocks install lua-csnappy luarocks install md5 luarocks install uuid diff --git a/appendix/benchmark.md b/appendix/benchmark.md index 05b2592..4330731 100644 --- a/appendix/benchmark.md +++ b/appendix/benchmark.md @@ -1,45 +1,79 @@ -# Benchmark results +# Benchmarks -Warning: This benchmark results is outdated. I will update soon. +## Photo -## Usage +Note: waifu2x's photo models was trained on the blending dataset of [kou's photo collection](http://photosku.com/photo/category/%E6%92%AE%E5%BD%B1%E8%80%85/kou/) and [ukbench](http://vis.uky.edu/~stewe/ukbench/). -``` -th tools/benchmark.lua -dir path/to/dataset_dir -method scale -color y -model1_dir path/to/model_dir -``` +Note: PSNR in this benchmark uses a [MATLAB's rgb2ycbcr](https://jp.mathworks.com/help/images/ref/rgb2ycbcr.html?lang=en) compatible function (dynamic range [16 235], not [0 255]) for converting grayscale image. I think it's not correct PSNR. But many paper used this metric. -## Dataset +command: +`th tools/benchmark.lua -dir -model1_dir -method scale -filter Catrom -color y -range_bug 1 -tta <0|1> -force_cudnn 1` - photo_test: 300 various photos. - art_test : 90 artworks (PNG only). +### Datasets -## 2x upscaling model +BSD100: https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/segbench/ (100 test images in BSDS300) +Urban100: https://github.com/jbhuang0604/SelfExSR -| Dataset/Model | anime\_style\_art(Y) | anime\_style\_art\_rgb | photo | ukbench| -|---------------|----------------------|------------------------|---------|--------| -| photo\_test | 29.83 | 29.81 |**29.89**| 29.86 | -| art\_test | 36.02 | **36.24**| 34.92 | 34.85 | +### 2x - PSNR -The evaluation metric is PSNR(Y only), higher is better. +| Dataset/Model | Bicubic | vgg\_7/photo | upconv\_7/photo | upconv\_7l/photo | resnet_14l/photo | +|---------------|---------------|---------------|------------------|------------------|--------------------| +| BSD100 | 29.558 | 31.427 | 31.640 | 31.749 | 31.847 | +| Urban100 | 26.852 | 30.057 | 30.477 | 30.759 | 31.016 | -## Denosing level 1 model +### 2x with TTA - PSNR -| Dataset/Model | anime\_style\_art | anime\_style\_art\_rgb | photo | -|--------------------------|-------------------|------------------------|---------| -| photo\_test Quality 80 | 36.07 | **36.20**| 36.01 | -| photo\_test Quality 50,45| 31.72 | 32.01 |**32.31**| -| art\_test Quality 80 | 40.39 | **42.48**| 40.35 | -| art\_test Quality 50,45 | 35.45 | **36.70**| 36.27 | +Note: TTA is an ensemble technique that is supported by waifu2x. TTA method is 8x slower than non TTA method but it improves PSNR (~+0.1 on photo, ~+0.4 on art). -The evaluation metric is PSNR(RGB), higher is better. +| Dataset/Model | Bicubic | vgg\_7/photo | upconv\_7/photo | upconv\_7l/photo | resnet_14l/photo | +|---------------|---------------|---------------|------------------|------------------|--------------------| +| BSD100 | 29.558 | 31.474 | 31.705 | 31.812 | 31.915 | +| Urban100 | 26.852 | 30.140 | 30.599 | 30.868 | 31.162 | -## Denosing level 2 model +### 2x - benchmark elapsed time (sec) -| Dataset/Model | anime\_style\_art | anime\_style\_art\_rgb | photo | -|--------------------------|-------------------|------------------------|---------| -| photo\_test Quality 80 | 34.03 | 34.42 |**36.06**| -| photo\_test Quality 50,45| 31.95 | 32.31 |**32.42**| -| art\_test Quality 80 | 39.20 | **41.12**| 40.48 | -| art\_test Quality 50,45 | 36.14 | **37.78**| 36.55 | +| Dataset/Model | vgg\_7/photo | upconv\_7/photo | upconv\_7l/photo | resnet_14l/photo | +|---------------|---------------|------------------|------------------|--------------------| +| BSD100 | 4.057 | 2.509 | 4.947 | 6.86 | +| Urban100 | 16.349 | 7.083 | 14.178 | 27.87 | + +### 2x with TTA - benchmark elapsed time (sec) + +| Dataset/Model | vgg\_7/photo | upconv\_7/photo | upconv\_7l/photo | resnet_14l/photo | +|---------------|---------------|------------------|------------------|--------------------| +| BSD100 | 36.611 | 20.219 | 42.486 | 60.38 | +| Urban100 | 132.416 | 65.125 | 129.916 | 255.20 | + +## Art + +command: +`th tools/benchmark.lua -dir -model1_dir -method scale -filter Lanczos -color y -range_bug 1 -tta <0|1> -force_cudnn 1` + +### Dataset + +art_test: This dataset contains 85 various fan-arts. Sorry, This dataset is private. + +### 2x - PSNR + +| Dataset/Model | Bicubic | vgg\_7/art | upconv\_7/art | upconv\_7l/art | +|---------------|---------------|-------------|----------------|----------------| +| art_test | 31.022 | 37.495 | 38.330 | 39.140 | + +### 2x with TTA - PSNR + +| Dataset/Model | Bicubic | vgg\_7/art | upconv\_7/art | upconv\_7l/art | +|---------------|---------------|-------------|----------------|----------------| +| art_test | 31.022 | 37.777 | 38.677 | 39.510 | + +### 2x - benchmark elapsed time (sec) + +| Dataset/Model | vgg\_7/art | upconv\_7/art | upconv\_7l/art | +|---------------|-------------|----------------|----------------| +| art_test | 20.681 | 7.683 | 17.667 | + +### 2x with TTA - benchmark elapsed time (sec) + +| Dataset/Model | vgg\_7/art | upconv\_7/art | upconv\_7l/art | +|---------------|-------------|----------------|----------------| +| art_test | 174.674 | 77.716 | 163.932 | -The evaluation metric is PSNR(RGB), higher is better. diff --git a/appendix/benchmark.sh b/appendix/benchmark.sh new file mode 100755 index 0000000..8fdbad5 --- /dev/null +++ b/appendix/benchmark.sh @@ -0,0 +1,34 @@ +#!/bin/sh +set -x + +benchmark_photo() { + dir=./benchmarks/${1}/${2}/${3} + mkdir -p ${dir} + th tools/benchmark.lua -dir data/${1} -model1\_dir models/${2}/photo -method scale -filter Catrom -color y -range\_bug 1 -tta ${3} -force_cudnn 1 -output_dir ${dir} -save_info 1 -show_progress 0 +} +run_benchmark_photo() { + for tta in 0 1 + do + for dataset in bsd100 urban100 + do + benchmark_photo ${dataset} vgg_7 ${tta} + benchmark_photo ${dataset} upconv_7 ${tta} + benchmark_photo ${dataset} upconv_7l ${tta} + done + done +} +benchmark_art() { + dir=./benchmarks/${1}/${2}/${3} + mkdir -p ${dir} + th tools/benchmark.lua -dir data/${1} -model1\_dir models/${2}/art -method scale -filter Lanczos -color y -range\_bug 1 -tta ${3} -force_cudnn 1 -output_dir ${dir} -save_info 1 -show_progress 0 +} +run_benchmark_art() { + for tta in 0 1 + do + benchmark_art art_test vgg_7 ${tta} + benchmark_art art_test upconv_7 ${tta} + benchmark_art art_test upconv_7l ${tta} + done +} +#run_benchmark_photo +run_benchmark_art diff --git a/appendix/caffe_prototxt/resnet_14l.prototxt b/appendix/caffe_prototxt/resnet_14l.prototxt new file mode 100644 index 0000000..4c0221b --- /dev/null +++ b/appendix/caffe_prototxt/resnet_14l.prototxt @@ -0,0 +1,524 @@ +name: "resnet_14l" +layer { + name: "input" + type: "Input" + top: "input" + input_param { shape: { dim: 1 dim: 3 dim: 156 dim: 156 } } +} +layer { + name: "Convolution1" + type: "Convolution" + bottom: "input" + top: "Convolution1" + convolution_param { + num_output: 32 + bias_term: true + pad: 0 + kernel_size: 3 + stride: 1 + weight_filler { + type: "msra" + } + } +} +layer { + name: "ReLU1" + type: "ReLU" + bottom: "Convolution1" + top: "Convolution1" + relu_param { + negative_slope: 0.1 + } +} +layer { + name: "Convolution2" + type: "Convolution" + bottom: "Convolution1" + top: "Convolution2" + convolution_param { + num_output: 64 + bias_term: true + pad: 0 + kernel_size: 3 + stride: 1 + weight_filler { + type: "msra" + } + } +} +layer { + name: "ReLU2" + type: "ReLU" + bottom: "Convolution2" + top: "Convolution2" + relu_param { + negative_slope: 0.1 + } +} +layer { + name: "Convolution3" + type: "Convolution" + bottom: "Convolution2" + top: "Convolution3" + convolution_param { + num_output: 64 + bias_term: true + pad: 0 + kernel_size: 3 + stride: 1 + weight_filler { + type: "msra" + } + } +} +layer { + name: "ReLU3" + type: "ReLU" + bottom: "Convolution3" + top: "Convolution3" + relu_param { + negative_slope: 0.1 + } +} +layer { + name: "Convolution4" + type: "Convolution" + bottom: "Convolution1" + top: "Convolution4" + convolution_param { + num_output: 64 + bias_term: true + pad: 0 + kernel_size: 1 + stride: 1 + weight_filler { + type: "msra" + } + } +} +layer { + name: "Crop1" + type: "Crop" + bottom: "Convolution4" + bottom: "Convolution3" + top: "Crop1" + crop_param { + axis: 2 + offset: 2 + offset: 2 + } +} +layer { + name: "Eltwise1" + type: "Eltwise" + bottom: "Convolution3" + bottom: "Crop1" + top: "Eltwise1" + eltwise_param { + operation: SUM + } +} +layer { + name: "Convolution5" + type: "Convolution" + bottom: "Eltwise1" + top: "Convolution5" + convolution_param { + num_output: 64 + bias_term: true + pad: 0 + kernel_size: 3 + stride: 1 + weight_filler { + type: "msra" + } + } +} +layer { + name: "ReLU4" + type: "ReLU" + bottom: "Convolution5" + top: "Convolution5" + relu_param { + negative_slope: 0.1 + } +} +layer { + name: "Convolution6" + type: "Convolution" + bottom: "Convolution5" + top: "Convolution6" + convolution_param { + num_output: 64 + bias_term: true + pad: 0 + kernel_size: 3 + stride: 1 + weight_filler { + type: "msra" + } + } +} +layer { + name: "ReLU5" + type: "ReLU" + bottom: "Convolution6" + top: "Convolution6" + relu_param { + negative_slope: 0.1 + } +} +layer { + name: "Crop2" + type: "Crop" + bottom: "Eltwise1" + bottom: "Convolution6" + top: "Crop2" + crop_param { + axis: 2 + offset: 2 + offset: 2 + } +} +layer { + name: "Eltwise2" + type: "Eltwise" + bottom: "Convolution6" + bottom: "Crop2" + top: "Eltwise2" + eltwise_param { + operation: SUM + } +} +layer { + name: "Convolution7" + type: "Convolution" + bottom: "Eltwise2" + top: "Convolution7" + convolution_param { + num_output: 128 + bias_term: true + pad: 0 + kernel_size: 3 + stride: 1 + weight_filler { + type: "msra" + } + } +} +layer { + name: "ReLU6" + type: "ReLU" + bottom: "Convolution7" + top: "Convolution7" + relu_param { + negative_slope: 0.1 + } +} +layer { + name: "Convolution8" + type: "Convolution" + bottom: "Convolution7" + top: "Convolution8" + convolution_param { + num_output: 128 + bias_term: true + pad: 0 + kernel_size: 3 + stride: 1 + weight_filler { + type: "msra" + } + } +} +layer { + name: "ReLU7" + type: "ReLU" + bottom: "Convolution8" + top: "Convolution8" + relu_param { + negative_slope: 0.1 + } +} +layer { + name: "Convolution9" + type: "Convolution" + bottom: "Eltwise2" + top: "Convolution9" + convolution_param { + num_output: 128 + bias_term: true + pad: 0 + kernel_size: 1 + stride: 1 + weight_filler { + type: "msra" + } + } +} +layer { + name: "Crop3" + type: "Crop" + bottom: "Convolution9" + bottom: "Convolution8" + top: "Crop3" + crop_param { + axis: 2 + offset: 2 + offset: 2 + } +} +layer { + name: "Eltwise3" + type: "Eltwise" + bottom: "Convolution8" + bottom: "Crop3" + top: "Eltwise3" + eltwise_param { + operation: SUM + } +} +layer { + name: "Convolution10" + type: "Convolution" + bottom: "Eltwise3" + top: "Convolution10" + convolution_param { + num_output: 128 + bias_term: true + pad: 0 + kernel_size: 3 + stride: 1 + weight_filler { + type: "msra" + } + } +} +layer { + name: "ReLU8" + type: "ReLU" + bottom: "Convolution10" + top: "Convolution10" + relu_param { + negative_slope: 0.1 + } +} +layer { + name: "Convolution11" + type: "Convolution" + bottom: "Convolution10" + top: "Convolution11" + convolution_param { + num_output: 128 + bias_term: true + pad: 0 + kernel_size: 3 + stride: 1 + weight_filler { + type: "msra" + } + } +} +layer { + name: "ReLU9" + type: "ReLU" + bottom: "Convolution11" + top: "Convolution11" + relu_param { + negative_slope: 0.1 + } +} +layer { + name: "Crop4" + type: "Crop" + bottom: "Eltwise3" + bottom: "Convolution11" + top: "Crop4" + crop_param { + axis: 2 + offset: 2 + offset: 2 + } +} +layer { + name: "Eltwise4" + type: "Eltwise" + bottom: "Convolution11" + bottom: "Crop4" + top: "Eltwise4" + eltwise_param { + operation: SUM + } +} +layer { + name: "Convolution12" + type: "Convolution" + bottom: "Eltwise4" + top: "Convolution12" + convolution_param { + num_output: 256 + bias_term: true + pad: 0 + kernel_size: 3 + stride: 1 + weight_filler { + type: "msra" + } + } +} +layer { + name: "ReLU10" + type: "ReLU" + bottom: "Convolution12" + top: "Convolution12" + relu_param { + negative_slope: 0.1 + } +} +layer { + name: "Convolution13" + type: "Convolution" + bottom: "Convolution12" + top: "Convolution13" + convolution_param { + num_output: 256 + bias_term: true + pad: 0 + kernel_size: 3 + stride: 1 + weight_filler { + type: "msra" + } + } +} +layer { + name: "ReLU11" + type: "ReLU" + bottom: "Convolution13" + top: "Convolution13" + relu_param { + negative_slope: 0.1 + } +} +layer { + name: "Convolution14" + type: "Convolution" + bottom: "Eltwise4" + top: "Convolution14" + convolution_param { + num_output: 256 + bias_term: true + pad: 0 + kernel_size: 1 + stride: 1 + weight_filler { + type: "msra" + } + } +} +layer { + name: "Crop5" + type: "Crop" + bottom: "Convolution14" + bottom: "Convolution13" + top: "Crop5" + crop_param { + axis: 2 + offset: 2 + offset: 2 + } +} +layer { + name: "Eltwise5" + type: "Eltwise" + bottom: "Convolution13" + bottom: "Crop5" + top: "Eltwise5" + eltwise_param { + operation: SUM + } +} +layer { + name: "Convolution15" + type: "Convolution" + bottom: "Eltwise5" + top: "Convolution15" + convolution_param { + num_output: 256 + bias_term: true + pad: 0 + kernel_size: 3 + stride: 1 + weight_filler { + type: "msra" + } + } +} +layer { + name: "ReLU12" + type: "ReLU" + bottom: "Convolution15" + top: "Convolution15" + relu_param { + negative_slope: 0.1 + } +} +layer { + name: "Convolution16" + type: "Convolution" + bottom: "Convolution15" + top: "Convolution16" + convolution_param { + num_output: 256 + bias_term: true + pad: 0 + kernel_size: 3 + stride: 1 + weight_filler { + type: "msra" + } + } +} +layer { + name: "ReLU13" + type: "ReLU" + bottom: "Convolution16" + top: "Convolution16" + relu_param { + negative_slope: 0.1 + } +} +layer { + name: "Crop6" + type: "Crop" + bottom: "Eltwise5" + bottom: "Convolution16" + top: "Crop6" + crop_param { + axis: 2 + offset: 2 + offset: 2 + } +} +layer { + name: "Eltwise6" + type: "Eltwise" + bottom: "Convolution16" + bottom: "Crop6" + top: "Eltwise6" + eltwise_param { + operation: SUM + } +} +layer { + name: "Deconvolution1" + type: "Deconvolution" + bottom: "Eltwise6" + top: "Deconvolution1" + convolution_param { + num_output: 3 + pad: 3 + kernel_size: 4 + stride: 2 + } +} diff --git a/convert_data.lua b/convert_data.lua index f055741..ef620e8 100644 --- a/convert_data.lua +++ b/convert_data.lua @@ -82,46 +82,50 @@ local function load_images(list) local skip = false local alpha_color = torch.random(0, 1) - if meta and meta.alpha then - if settings.use_transparent_png then - im = alpha_util.fill(im, meta.alpha, alpha_color) - else - skip = true - end - end - if skip then - if not skip_notice then - io.stderr:write("skip transparent png (settings.use_transparent_png=0)\n") - skip_notice = true - end - else - if csv_meta and csv_meta.x then - -- method == user - local yy = im - local xx, meta2 = image_loader.load_byte(csv_meta.x) - if meta2 and meta2.alpha then - xx = alpha_util.fill(xx, meta2.alpha, alpha_color) + if im then + if meta and meta.alpha then + if settings.use_transparent_png then + im = alpha_util.fill(im, meta.alpha, alpha_color) + else + skip = true end - xx, yy = crop_if_large_pair(xx, yy, settings.max_training_image_size) - table.insert(x, {{y = compression.compress(yy), x = compression.compress(xx)}, - {data = {filters = filters, has_x = true}}}) - else - im = crop_if_large(im, settings.max_training_image_size) - im = iproc.crop_mod4(im) - local scale = 1.0 - if settings.random_half_rate > 0.0 then - scale = 2.0 + end + if skip then + if not skip_notice then + io.stderr:write("skip transparent png (settings.use_transparent_png=0)\n") + skip_notice = true end - if im then + else + if csv_meta and csv_meta.x then + -- method == user + local yy = im + local xx, meta2 = image_loader.load_byte(csv_meta.x) + if xx then + if meta2 and meta2.alpha then + xx = alpha_util.fill(xx, meta2.alpha, alpha_color) + end + xx, yy = crop_if_large_pair(xx, yy, settings.max_training_image_size) + table.insert(x, {{y = compression.compress(yy), x = compression.compress(xx)}, + {data = {filters = filters, has_x = true}}}) + else + io.stderr:write(string.format("\n%s: skip: load error.\n", csv_meta.x)) + end + else + im = crop_if_large(im, settings.max_training_image_size) + im = iproc.crop_mod4(im) + local scale = 1.0 + if settings.random_half_rate > 0.0 then + scale = 2.0 + end if im:size(2) > (settings.crop_size * scale + MARGIN) and im:size(3) > (settings.crop_size * scale + MARGIN) then table.insert(x, {compression.compress(im), {data = {filters = filters}}}) else io.stderr:write(string.format("\n%s: skip: image is too small (%d > size).\n", filename, settings.crop_size * scale + MARGIN)) end - else - io.stderr:write(string.format("\n%s: skip: load error.\n", filename)) end end + else + io.stderr:write(string.format("\n%s: skip: load error.\n", filename)) end xlua.progress(i, #csv) if i % 10 == 0 then diff --git a/lib/ClippedMSECriterion.lua b/lib/ClippedMSECriterion.lua index 19336ee..204d156 100644 --- a/lib/ClippedMSECriterion.lua +++ b/lib/ClippedMSECriterion.lua @@ -5,12 +5,14 @@ function ClippedMSECriterion:__init(min, max) self.min = min self.max = max self.diff = torch.Tensor() + self.diff_pow2 = torch.Tensor() end function ClippedMSECriterion:updateOutput(input, target) self.diff:resizeAs(input):copy(input) self.diff:clamp(self.min, self.max) self.diff:add(-1, target) - self.output = self.diff:pow(2):sum() / input:nElement() + self.diff_pow2:resizeAs(self.diff):copy(self.diff):pow(2) + self.output = self.diff_pow2:sum() / input:nElement() return self.output end function ClippedMSECriterion:updateGradInput(input, target) diff --git a/lib/InplaceClip01.lua b/lib/InplaceClip01.lua new file mode 100644 index 0000000..93095d0 --- /dev/null +++ b/lib/InplaceClip01.lua @@ -0,0 +1,13 @@ +local Clip01, parent = torch.class("w2nn.InplaceClip01", "nn.Module") + +function Clip01:__init() + parent.__init(self) +end +function Clip01:updateOutput(input) + self.output:set(input:clamp(0, 1)) + return self.output +end +function Clip01:updateGradInput(input, gradOutput) + self.gradInput:set(gradOutput) + return self.gradInput +end diff --git a/lib/L1Criterion.lua b/lib/L1Criterion.lua new file mode 100644 index 0000000..c706ece --- /dev/null +++ b/lib/L1Criterion.lua @@ -0,0 +1,27 @@ +-- ref: https://en.wikipedia.org/wiki/L1_loss +local L1Criterion, parent = torch.class('w2nn.L1Criterion','nn.Criterion') + +function L1Criterion:__init() + parent.__init(self) + self.diff = torch.Tensor() + self.linear_loss_buff = torch.Tensor() +end +function L1Criterion:updateOutput(input, target) + self.diff:resizeAs(input):copy(input) + if input:dim() == 1 then + self.diff[1] = input[1] - target + else + for i = 1, input:size(1) do + self.diff[i]:add(-1, target[i]) + end + end + local linear_targets = self.diff + local linear_loss = self.linear_loss_buff:resizeAs(linear_targets):copy(linear_targets):abs():sum() + self.output = (linear_loss) / input:nElement() + return self.output +end +function L1Criterion:updateGradInput(input, target) + local norm = 1.0 / input:nElement() + self.gradInput:resizeAs(self.diff):copy(self.diff):sign():mul(norm) + return self.gradInput +end diff --git a/lib/SSIMCriterion.lua b/lib/SSIMCriterion.lua new file mode 100644 index 0000000..9137920 --- /dev/null +++ b/lib/SSIMCriterion.lua @@ -0,0 +1,67 @@ +-- SSIM Index, ref: http://www.cns.nyu.edu/~lcv/ssim/ssim_index.m +local SSIMCriterion, parent = torch.class('w2nn.SSIMCriterion','nn.Criterion') +function SSIMCriterion:__init(ch, kernel_size, sigma) + parent.__init(self) + local function gaussian2d(kernel_size, sigma) + sigma = sigma or 1 + local kernel = torch.Tensor(kernel_size, kernel_size) + local u = math.floor(kernel_size / 2) + 1 + local amp = (1 / math.sqrt(2 * math.pi * sigma^2)) + for x = 1, kernel_size do + for y = 1, kernel_size do + kernel[x][y] = amp * math.exp(-((x - u)^2 + (y - u)^2) / (2 * sigma^2)) + end + end + kernel:div(kernel:sum()) + return kernel + end + ch = ch or 1 + kernel_size = kernel_size or 11 + sigma = sigma or 1.5 + local kernel = gaussian2d(kernel_size, sigma) + if ch > 1 then + local kernel_nd = torch.Tensor(ch, ch, kernel_size, kernel_size) + for i = 1, ch do + for j = 1, ch do + kernel_nd[i][j]:copy(kernel) + if i ~= j then + kernel_nd[i][j]:zero() + end + end + end + kernel = kernel_nd + end + self.c1 = 0.01^2 + self.c2 = 0.03^2 + self.ch = ch + self.conv = nn.SpatialConvolution(ch, ch, kernel_size, kernel_size, 1, 1, 0, 0):noBias() + self.conv.weight:copy(kernel) + self.mu1 = torch.Tensor() + self.mu2 = torch.Tensor() + self.mu1_sq = torch.Tensor() + self.mu2_sq = torch.Tensor() + self.mu1_mu2 = torch.Tensor() + self.sigma1_sq = torch.Tensor() + self.sigma2_sq = torch.Tensor() + self.sigma12 = torch.Tensor() + self.ssim_map = torch.Tensor() +end +function SSIMCriterion:updateOutput(input, target)-- dynamic range: 0-1 + assert(input:nElement() == target:nElement()) + local valid = self.conv:forward(input) + self.mu1:resizeAs(valid):copy(valid) + self.mu2:resizeAs(valid):copy(self.conv:forward(target)) + self.mu1_sq:resizeAs(self.mu1):copy(self.mu1):cmul(self.mu1) + self.mu2_sq:resizeAs(self.mu2):copy(self.mu2):cmul(self.mu2) + self.mu1_mu2:resizeAs(self.mu1):copy(self.mu1):cmul(self.mu2) + self.sigma1_sq:resizeAs(valid):copy(self.conv:forward(torch.cmul(input, input)):add(-1, self.mu1_sq)) + self.sigma2_sq:resizeAs(valid):copy(self.conv:forward(torch.cmul(target, target)):add(-1, self.mu2_sq)) + self.sigma12:resizeAs(valid):copy(self.conv:forward(torch.cmul(input, target)):add(-1, self.mu1_mu2)) + + local ssim = self.mu1_mu2:mul(2):add(self.c1):cmul(self.sigma12:mul(2):add(self.c2)): + cdiv(self.mu1_sq:add(self.mu2_sq):add(self.c1):cmul(self.sigma1_sq:add(self.sigma2_sq):add(self.c2))):mean() + return ssim +end +function SSIMCriterion:updateGradInput(input, target) + error("not implemented") +end diff --git a/lib/alpha_util.lua b/lib/alpha_util.lua index ad8266f..c80c06d 100644 --- a/lib/alpha_util.lua +++ b/lib/alpha_util.lua @@ -40,8 +40,7 @@ function alpha_util.make_border(rgb, alpha, offset) collectgarbage() end end - rgb[torch.gt(rgb, 1.0)] = 1.0 - rgb[torch.lt(rgb, 0.0)] = 0.0 + rgb:clamp(0.0, 1.0) return rgb end diff --git a/lib/data_augmentation.lua b/lib/data_augmentation.lua index 7bde0e1..a8f9b60 100644 --- a/lib/data_augmentation.lua +++ b/lib/data_augmentation.lua @@ -1,7 +1,8 @@ -require 'image' +require 'pl' +require 'cunn' local iproc = require 'iproc' -local gm = require 'graphicsmagick' - +local gm = {} +gm.Image = require 'graphicsmagick.Image' local data_augmentation = {} local function pcacov(x) @@ -25,8 +26,7 @@ function data_augmentation.color_noise(src, p, factor) pca_space[i]:mul(color_scale[i]) end local dest = torch.mm(pca_space:t(), cv:t()):t():contiguous():resizeAs(src) - dest[torch.lt(dest, 0.0)] = 0.0 - dest[torch.gt(dest, 1.0)] = 1.0 + dest:clamp(0.0, 1.0) if conversion then dest = iproc.float2byte(dest) @@ -70,6 +70,75 @@ function data_augmentation.unsharp_mask(src, p) return src end end +function data_augmentation.blur(src, p, size, sigma_min, sigma_max) + size = size or "3" + filters = utils.split(size, ",") + for i = 1, #filters do + local s = tonumber(filters[i]) + filters[i] = s + end + if torch.uniform() < p then + local src, conversion = iproc.byte2float(src) + local kernel_size = filters[torch.random(1, #filters)] + local sigma + if sigma_min == sigma_max then + sigma = sigma_min + else + sigma = torch.uniform(sigma_min, sigma_max) + end + local kernel = iproc.gaussian2d(kernel_size, sigma) + local dest = image.convolve(src, kernel, 'same') + if conversion then + dest = iproc.float2byte(dest) + end + return dest + else + return src + end +end +function data_augmentation.pairwise_scale(x, y, p, scale_min, scale_max) + if torch.uniform() < p then + assert(x:size(2) == y:size(2) and x:size(3) == y:size(3)) + local scale = torch.uniform(scale_min, scale_max) + local h = math.floor(x:size(2) * scale) + local w = math.floor(x:size(3) * scale) + x = iproc.scale(x, w, h, "Triangle") + y = iproc.scale(y, w, h, "Triangle") + return x, y + else + return x, y + end +end +function data_augmentation.pairwise_rotate(x, y, p, r_min, r_max) + if torch.uniform() < p then + assert(x:size(2) == y:size(2) and x:size(3) == y:size(3)) + local r = torch.uniform(r_min, r_max) / 360.0 * math.pi + x = iproc.rotate(x, r) + y = iproc.rotate(y, r) + return x, y + else + return x, y + end +end +function data_augmentation.pairwise_negate(x, y, p) + if torch.uniform() < p then + assert(x:size(2) == y:size(2) and x:size(3) == y:size(3)) + x = iproc.negate(x) + y = iproc.negate(y) + return x, y + else + return x, y + end +end +function data_augmentation.pairwise_negate_x(x, y, p) + if torch.uniform() < p then + assert(x:size(2) == y:size(2) and x:size(3) == y:size(3)) + x = iproc.negate(x) + return x, y + else + return x, y + end +end function data_augmentation.shift_1px(src) -- reducing the even/odd issue in nearest neighbor scaler. local direction = torch.random(1, 4) @@ -107,11 +176,11 @@ function data_augmentation.flip(src) src = src:transpose(2, 3):contiguous() end if flip == 1 then - dest = image.hflip(src) + dest = iproc.hflip(src) elseif flip == 2 then - dest = image.vflip(src) + dest = iproc.vflip(src) elseif flip == 3 then - dest = image.hflip(image.vflip(src)) + dest = iproc.hflip(iproc.vflip(src)) elseif flip == 4 then dest = src end @@ -120,4 +189,20 @@ function data_augmentation.flip(src) end return dest end + +local function test_blur() + torch.setdefaulttensortype("torch.FloatTensor") + local image =require 'image' + local src = image.lena() + + image.display({image = src, min=0, max=1}) + local dest = data_augmentation.blur(src, 1.0, "3,5", 0.5, 0.6) + image.display({image = dest, min=0, max=1}) + dest = data_augmentation.blur(src, 1.0, "3", 1.0, 1.0) + image.display({image = dest, min=0, max=1}) + dest = data_augmentation.blur(src, 1.0, "5", 0.75, 0.75) + image.display({image = dest, min=0, max=1}) +end +--test_blur() + return data_augmentation diff --git a/lib/image_loader.lua b/lib/image_loader.lua index 76fec1b..56a994f 100644 --- a/lib/image_loader.lua +++ b/lib/image_loader.lua @@ -22,8 +22,7 @@ function image_loader.encode_png(rgb, options) else rgb = rgb:clone():add(clip_eps8) end - rgb[torch.lt(rgb, 0.0)] = 0.0 - rgb[torch.gt(rgb, 1.0)] = 1.0 + rgb:clamp(0.0, 1.0) rgb = rgb:mul(255):floor():div(255) else if options.inplace then @@ -31,8 +30,7 @@ function image_loader.encode_png(rgb, options) else rgb = rgb:clone():add(clip_eps16) end - rgb[torch.lt(rgb, 0.0)] = 0.0 - rgb[torch.gt(rgb, 1.0)] = 1.0 + rgb:clamp(0.0, 1.0) rgb = rgb:mul(65535):floor():div(65535) end local im diff --git a/lib/iproc.lua b/lib/iproc.lua index 240bb9e..b4e6e17 100644 --- a/lib/iproc.lua +++ b/lib/iproc.lua @@ -1,6 +1,7 @@ -local gm = require 'graphicsmagick' +local gm = {} +gm.Image = require 'graphicsmagick.Image' +require 'dok' local image = require 'image' - local iproc = {} local clip_eps8 = (1.0 / 255.0) * 0.5 - (1.0e-7 * (1.0 / 255.0) * 0.5) @@ -42,8 +43,7 @@ function iproc.float2byte(src) if src:type() == "torch.FloatTensor" then conversion = true dest = (src + clip_eps8):mul(255.0) - dest[torch.lt(dest, 0.0)] = 0 - dest[torch.gt(dest, 255.0)] = 255.0 + dest:clamp(0, 255.0) dest = dest:byte() end return dest, conversion @@ -80,6 +80,7 @@ function iproc.scale_with_gamma22(src, width, height, filter, blur) return dest end function iproc.padding(img, w1, w2, h1, h2) + image = image or require 'image' local dst_height = img:size(2) + h1 + h2 local dst_width = img:size(3) + w1 + w2 local flow = torch.Tensor(2, dst_height, dst_width) @@ -90,6 +91,7 @@ function iproc.padding(img, w1, w2, h1, h2) return image.warp(img, flow, "simple", false, "clamp") end function iproc.zero_padding(img, w1, w2, h1, h2) + image = image or require 'image' local dst_height = img:size(2) + h1 + h2 local dst_width = img:size(3) + w1 + w2 local flow = torch.Tensor(2, dst_height, dst_width) @@ -115,8 +117,7 @@ function iproc.white_noise(src, std, rgb_weights, gamma) local dest if gamma ~= 0 then dest = src:clone():pow(gamma):add(noise) - dest[torch.lt(dest, 0.0)] = 0.0 - dest[torch.gt(dest, 1.0)] = 1.0 + dest:clamp(0.0, 1.0) dest:pow(1.0 / gamma) else dest = src + noise @@ -126,6 +127,101 @@ function iproc.white_noise(src, std, rgb_weights, gamma) end return dest end +function iproc.hflip(src) + local t + if src:type() == "torch.ByteTensor" then + t = "byte" + else + t = "float" + end + if src:size(1) == 3 then + color = "RGB" + else + color = "I" + end + local im = gm.Image(src, color, "DHW") + return im:flop():toTensor(t, color, "DHW") +end +function iproc.vflip(src) + local t + if src:type() == "torch.ByteTensor" then + t = "byte" + else + t = "float" + end + if src:size(1) == 3 then + color = "RGB" + else + color = "I" + end + local im = gm.Image(src, color, "DHW") + return im:flip():toTensor(t, color, "DHW") +end +local function rotate_with_warp(src, dst, theta, mode) + local height + local width + if src:dim() == 2 then + height = src:size(1) + width = src:size(2) + elseif src:dim() == 3 then + height = src:size(2) + width = src:size(3) + else + dok.error('src image must be 2D or 3D', 'image.rotate') + end + local flow = torch.Tensor(2, height, width) + local kernel = torch.Tensor({{math.cos(-theta), -math.sin(-theta)}, + {math.sin(-theta), math.cos(-theta)}}) + flow[1] = torch.ger(torch.linspace(0, 1, height), torch.ones(width)) + flow[1]:mul(-(height -1)):add(math.floor(height / 2 + 0.5)) + flow[2] = torch.ger(torch.ones(height), torch.linspace(0, 1, width)) + flow[2]:mul(-(width -1)):add(math.floor(width / 2 + 0.5)) + flow:add(-1, torch.mm(kernel, flow:view(2, height * width))) + dst:resizeAs(src) + return image.warp(dst, src, flow, mode, true, 'clamp') +end +function iproc.rotate(src, theta) + local conversion + src, conversion = iproc.byte2float(src) + local dest = torch.Tensor():typeAs(src):resizeAs(src) + rotate_with_warp(src, dest, theta, 'bilinear') + dest:clamp(0, 1) + if conversion then + dest = iproc.float2byte(dest) + end + return dest +end +function iproc.negate(src) + if src:type() == "torch.ByteTensor" then + return -src + 255 + else + return -src + 1 + end +end + +function iproc.gaussian2d(kernel_size, sigma) + sigma = sigma or 1 + local kernel = torch.Tensor(kernel_size, kernel_size) + local u = math.floor(kernel_size / 2) + 1 + local amp = (1 / math.sqrt(2 * math.pi * sigma^2)) + for x = 1, kernel_size do + for y = 1, kernel_size do + kernel[x][y] = amp * math.exp(-((x - u)^2 + (y - u)^2) / (2 * sigma^2)) + end + end + kernel:div(kernel:sum()) + return kernel +end +function iproc.rgb2y(src) + local conversion + src, conversion = iproc.byte2float(src) + local dest = torch.FloatTensor(1, src:size(2), src:size(3)):zero() + dest:add(0.299, src[1]):add(0.587, src[2]):add(0.114, src[3]) + if conversion then + dest = iproc.float2byte(dest) + end + return dest +end local function test_conversion() local a = torch.linspace(0, 255, 256):float():div(255.0) @@ -144,6 +240,46 @@ local function test_conversion() print(b) assert(b:float():sum() == 254.0 * 3) end +local function test_flip() + require 'sys' + require 'torch' + torch.setdefaulttensortype("torch.FloatTensor") + image = require 'image' + local src = image.lena() + local src_byte = src:clone():mul(255):byte() + + print(src:size()) + print((image.hflip(src) - iproc.hflip(src)):sum()) + print((image.hflip(src_byte) - iproc.hflip(src_byte)):sum()) + print((image.vflip(src) - iproc.vflip(src)):sum()) + print((image.vflip(src_byte) - iproc.vflip(src_byte)):sum()) +end +local function test_gaussian2d() + local t = {3, 5, 7} + for i = 1, #t do + local kp = iproc.gaussian2d(t[i], 0.5) + print(kp) + end +end +local function test_conv() + local image = require 'image' + local src = image.lena() + local kernel = torch.Tensor(3, 3):fill(1) + kernel:div(kernel:sum()) + --local blur = image.convolve(iproc.padding(src, 1, 1, 1, 1), kernel, 'valid') + local blur = image.convolve(src, kernel, 'same') + print(src:size(), blur:size()) + local diff = (blur - src):abs() + image.save("diff.png", diff) + image.display({image = blur, min=0, max=1}) + image.display({image = diff, min=0, max=1}) +end + --test_conversion() +--test_flip() +--test_gaussian2d() +--test_conv() return iproc + + diff --git a/lib/minibatch_adam.lua b/lib/minibatch_adam.lua index dbe70f4..96f189f 100644 --- a/lib/minibatch_adam.lua +++ b/lib/minibatch_adam.lua @@ -45,12 +45,17 @@ local function minibatch_adam(model, criterion, eval_metric, local output = model:forward(inputs) local f = criterion:forward(output, targets) local se = 0 - for i = 1, batch_size do - local el = eval_metric:forward(output[i], targets[i]) - se = se + el - instance_loss[shuffle[t + i - 1]] = el + if config.xInstanceLoss then + for i = 1, batch_size do + local el = eval_metric:forward(output[i], targets[i]) + se = se + el + instance_loss[shuffle[t + i - 1]] = el + end + se = (se / batch_size) + else + se = eval_metric:forward(output, targets) end - sum_eval = sum_eval + (se / batch_size) + sum_eval = sum_eval + se sum_loss = sum_loss + f count_loss = count_loss + 1 model:backward(inputs, criterion:backward(output, targets)) diff --git a/lib/pairwise_transform_jpeg.lua b/lib/pairwise_transform_jpeg.lua index e9cecc1..23cde45 100644 --- a/lib/pairwise_transform_jpeg.lua +++ b/lib/pairwise_transform_jpeg.lua @@ -1,5 +1,6 @@ local pairwise_utils = require 'pairwise_transform_utils' -local gm = require 'graphicsmagick' +local gm = {} +gm.Image = require 'graphicsmagick.Image' local iproc = require 'iproc' local pairwise_transform = {} @@ -42,8 +43,8 @@ function pairwise_transform.jpeg_(src, quality, size, offset, n, options) yc = iproc.byte2float(yc) if options.rgb then else - yc = image.rgb2yuv(yc)[1]:reshape(1, yc:size(2), yc:size(3)) - xc = image.rgb2yuv(xc)[1]:reshape(1, xc:size(2), xc:size(3)) + yc = iproc.rgb2y(yc) + xc = iproc.rgb2y(xc) end if torch.uniform() < options.nr_rate then -- reducing noise diff --git a/lib/pairwise_transform_jpeg_scale.lua b/lib/pairwise_transform_jpeg_scale.lua index 0b3000f..3740bb6 100644 --- a/lib/pairwise_transform_jpeg_scale.lua +++ b/lib/pairwise_transform_jpeg_scale.lua @@ -1,6 +1,7 @@ local pairwise_utils = require 'pairwise_transform_utils' local iproc = require 'iproc' -local gm = require 'graphicsmagick' +local gm = {} +gm.Image = require 'graphicsmagick.Image' local pairwise_transform = {} local function add_jpeg_noise_(x, quality, options) @@ -117,8 +118,8 @@ function pairwise_transform.jpeg_scale(src, scale, style, noise_level, size, off yc = iproc.byte2float(yc) if options.rgb then else - yc = image.rgb2yuv(yc)[1]:reshape(1, yc:size(2), yc:size(3)) - xc = image.rgb2yuv(xc)[1]:reshape(1, xc:size(2), xc:size(3)) + yc = iproc.rgb2y(yc) + xc = iproc.rgb2y(xc) end table.insert(batch, {xc, iproc.crop(yc, offset, offset, size - offset, size - offset)}) end diff --git a/lib/pairwise_transform_scale.lua b/lib/pairwise_transform_scale.lua index e2d7674..7dd18f1 100644 --- a/lib/pairwise_transform_scale.lua +++ b/lib/pairwise_transform_scale.lua @@ -1,6 +1,7 @@ local pairwise_utils = require 'pairwise_transform_utils' local iproc = require 'iproc' -local gm = require 'graphicsmagick' +local gm = {} +gm.Image = require 'graphicsmagick.Image' local pairwise_transform = {} function pairwise_transform.scale(src, scale, size, offset, n, options) @@ -50,8 +51,8 @@ function pairwise_transform.scale(src, scale, size, offset, n, options) yc = iproc.byte2float(yc) if options.rgb then else - yc = image.rgb2yuv(yc)[1]:reshape(1, yc:size(2), yc:size(3)) - xc = image.rgb2yuv(xc)[1]:reshape(1, xc:size(2), xc:size(3)) + yc = iproc.rgb2y(yc) + xc = iproc.rgb2y(xc) end table.insert(batch, {xc, iproc.crop(yc, offset, offset, size - offset, size - offset)}) end diff --git a/lib/pairwise_transform_user.lua b/lib/pairwise_transform_user.lua index 6d00993..279b343 100644 --- a/lib/pairwise_transform_user.lua +++ b/lib/pairwise_transform_user.lua @@ -1,43 +1,30 @@ local pairwise_utils = require 'pairwise_transform_utils' local iproc = require 'iproc' -local gm = require 'graphicsmagick' +local gm = {} +gm.Image = require 'graphicsmagick.Image' local pairwise_transform = {} -local function crop_if_large(x, y, scale_y, max_size, mod) - local tries = 4 - if y:size(2) > max_size and y:size(3) > max_size then - assert(max_size % 4 == 0) - local rect_x, rect_y - for i = 1, tries do - local yi = torch.random(0, y:size(2) - max_size) - local xi = torch.random(0, y:size(3) - max_size) - if mod then - yi = yi - (yi % mod) - xi = xi - (xi % mod) - end - rect_y = iproc.crop(y, xi, yi, xi + max_size, yi + max_size) - rect_x = iproc.crop(x, xi / scale_y, yi / scale_y, xi / scale_y + max_size / scale_y, yi / scale_y + max_size / scale_y) - -- ignore simple background - if rect_y:float():std() >= 0 then - break - end - end - return rect_x, rect_y - else - return x, y - end -end function pairwise_transform.user(x, y, size, offset, n, options) assert(x:size(1) == y:size(1)) local scale_y = y:size(2) / x:size(2) assert(x:size(3) == y:size(3) / scale_y) - x, y = crop_if_large(x, y, scale_y, options.max_size, scale_y) + x, y = pairwise_utils.preprocess_user(x, y, scale_y, size, options) assert(x:size(3) == y:size(3) / scale_y and x:size(2) == y:size(2) / scale_y) local batch = {} - local lowres_y = pairwise_utils.low_resolution(y) - local xs, ys, ls = pairwise_utils.flip_augmentation(x, y, lowres_y) + local lowres_y = nil + local xs ={x} + local ys = {y} + local ls = {} + + if options.active_cropping_rate > 0 then + lowres_y = pairwise_utils.low_resolution(y) + end + if options.pairwise_flip then + xs, ys, ls = pairwise_utils.flip_augmentation(x, y, lowres_y) + end + assert(#xs == #ys) for i = 1, n do local t = (i % #xs) + 1 local xc, yc = pairwise_utils.active_cropping(xs[t], ys[t], ls[t], size, scale_y, @@ -47,8 +34,17 @@ function pairwise_transform.user(x, y, size, offset, n, options) yc = iproc.byte2float(yc) if options.rgb then else - yc = image.rgb2yuv(yc)[1]:reshape(1, yc:size(2), yc:size(3)) - xc = image.rgb2yuv(xc)[1]:reshape(1, xc:size(2), xc:size(3)) + yc = iproc.rgb2y(yc) + xc = iproc.rgb2y(xc) + end + if options.gcn then + local mean = xc:mean() + local stdv = xc:std() + if stdv > 0 then + xc:add(-mean):div(stdv) + else + xc:add(-mean) + end end table.insert(batch, {xc, iproc.crop(yc, offset, offset, size - offset, size - offset)}) end diff --git a/lib/pairwise_transform_utils.lua b/lib/pairwise_transform_utils.lua index 0bc2bdd..5938a86 100644 --- a/lib/pairwise_transform_utils.lua +++ b/lib/pairwise_transform_utils.lua @@ -1,7 +1,7 @@ -require 'image' require 'cunn' local iproc = require 'iproc' -local gm = require 'graphicsmagick' +local gm = {} +gm.Image = require 'graphicsmagick.Image' local data_augmentation = require 'data_augmentation' local pairwise_transform_utils = {} @@ -36,6 +36,30 @@ function pairwise_transform_utils.crop_if_large(src, max_size, mod) return src end end +function pairwise_transform_utils.crop_if_large_pair(x, y, scale_y, max_size, mod) + local tries = 4 + if y:size(2) > max_size and y:size(3) > max_size then + assert(max_size % 4 == 0) + local rect_x, rect_y + for i = 1, tries do + local yi = torch.random(0, y:size(2) - max_size) + local xi = torch.random(0, y:size(3) - max_size) + if mod then + yi = yi - (yi % mod) + xi = xi - (xi % mod) + end + rect_y = iproc.crop(y, xi, yi, xi + max_size, yi + max_size) + rect_x = iproc.crop(x, xi / scale_y, yi / scale_y, xi / scale_y + max_size / scale_y, yi / scale_y + max_size / scale_y) + -- ignore simple background + if rect_y:float():std() >= 0 then + break + end + end + return rect_x, rect_y + else + return x, y + end +end function pairwise_transform_utils.preprocess(src, crop_size, options) local dest = src local box_only = false @@ -47,7 +71,6 @@ function pairwise_transform_utils.preprocess(src, crop_size, options) if box_only then local mod = 2 -- assert pos % 2 == 0 dest = pairwise_transform_utils.crop_if_large(dest, math.max(crop_size * 2, options.max_size), mod) - dest = data_augmentation.flip(dest) dest = data_augmentation.color_noise(dest, options.random_color_noise_rate) dest = data_augmentation.overlay(dest, options.random_overlay_rate) dest = data_augmentation.unsharp_mask(dest, options.random_unsharp_mask_rate) @@ -55,7 +78,10 @@ function pairwise_transform_utils.preprocess(src, crop_size, options) else dest = pairwise_transform_utils.random_half(dest, options.random_half_rate, options.downsampling_filters) dest = pairwise_transform_utils.crop_if_large(dest, math.max(crop_size * 2, options.max_size)) - dest = data_augmentation.flip(dest) + dest = data_augmentation.blur(dest, options.random_blur_rate, + options.random_blur_size, + options.random_blur_sigma_min, + options.random_blur_sigma_max) dest = data_augmentation.color_noise(dest, options.random_color_noise_rate) dest = data_augmentation.overlay(dest, options.random_overlay_rate) dest = data_augmentation.unsharp_mask(dest, options.random_unsharp_mask_rate) @@ -63,6 +89,33 @@ function pairwise_transform_utils.preprocess(src, crop_size, options) end return dest end +function pairwise_transform_utils.preprocess_user(x, y, scale_y, size, options) + + x, y = pairwise_transform_utils.crop_if_large_pair(x, y, scale_y, options.max_size, scale_y) + x, y = data_augmentation.pairwise_rotate(x, y, + options.random_pairwise_rotate_rate, + options.random_pairwise_rotate_min, + options.random_pairwise_rotate_max) + + local scale_min = math.max(options.random_pairwise_scale_min, size / (1 + math.min(x:size(2), x:size(3)))) + local scale_max = math.max(scale_min, options.random_pairwise_scale_max) + x, y = data_augmentation.pairwise_scale(x, y, + options.random_pairwise_scale_rate, + scale_min, + scale_max) + x, y = data_augmentation.pairwise_negate(x, y, options.random_pairwise_negate_rate) + x, y = data_augmentation.pairwise_negate_x(x, y, options.random_pairwise_negate_x_rate) + + x = iproc.crop_mod4(x) + y = iproc.crop_mod4(y) + + if options.pairwise_y_binary then + y[torch.lt(y, 128)] = 0 + y[torch.gt(y, 0)] = 255 + end + + return x, y +end function pairwise_transform_utils.active_cropping(x, y, lowres_y, size, scale, p, tries) assert("x:size == y:size", x:size(2) * scale == y:size(2) and x:size(3) * scale == y:size(3)) assert("crop_size % scale == 0", size % scale == 0) @@ -111,7 +164,7 @@ function pairwise_transform_utils.flip_augmentation(x, y, lowres_y, x_noise) for j = 1, 2 do -- TTA - local xi, yi, ri + local xi, yi, ri, ni if j == 1 then xi = x ni = x_noise @@ -123,42 +176,55 @@ function pairwise_transform_utils.flip_augmentation(x, y, lowres_y, x_noise) ni = x_noise:transpose(2, 3):contiguous() end yi = y:transpose(2, 3):contiguous() - ri = lowres_y:transpose(2, 3):contiguous() + if lowres_y then + ri = lowres_y:transpose(2, 3):contiguous() + end end - local xv = image.vflip(xi) + local xv = iproc.vflip(xi) local nv if x_noise then - nv = image.vflip(ni) + nv = iproc.vflip(ni) + end + local yv = iproc.vflip(yi) + local rv + if ri then + rv = iproc.vflip(ri) end - local yv = image.vflip(yi) - local rv = image.vflip(ri) table.insert(xs, xi) if ni then table.insert(ns, ni) end table.insert(ys, yi) - table.insert(ls, ri) + if ri then + table.insert(ls, ri) + end table.insert(xs, xv) if nv then table.insert(ns, nv) end table.insert(ys, yv) - table.insert(ls, rv) + if rv then + table.insert(ls, rv) + end - table.insert(xs, image.hflip(xi)) + table.insert(xs, iproc.hflip(xi)) if ni then - table.insert(ns, image.hflip(ni)) + table.insert(ns, iproc.hflip(ni)) + end + table.insert(ys, iproc.hflip(yi)) + if ri then + table.insert(ls, iproc.hflip(ri)) end - table.insert(ys, image.hflip(yi)) - table.insert(ls, image.hflip(ri)) - table.insert(xs, image.hflip(xv)) + table.insert(xs, iproc.hflip(xv)) if nv then - table.insert(ns, image.hflip(nv)) + table.insert(ns, iproc.hflip(nv)) + end + table.insert(ys, iproc.hflip(yv)) + if rv then + table.insert(ls, iproc.hflip(rv)) end - table.insert(ys, image.hflip(yv)) - table.insert(ls, image.hflip(rv)) end return xs, ys, ls, ns end @@ -171,6 +237,9 @@ end local g_lowres_model = nil local g_lowres_gpu = nil function pairwise_transform_utils.low_resolution(src) +--[[ + -- I am not sure that the following process is thraed-safe + g_lowres_model = g_lowres_model or lowres_model() if g_lowres_gpu == nil then --benchmark @@ -203,6 +272,11 @@ function pairwise_transform_utils.low_resolution(src) size(src:size(3), src:size(2), "Box"): toTensor("byte", "RGB", "DHW") end +--]] + return gm.Image(src, "RGB", "DHW"): + size(src:size(3) * 0.5, src:size(2) * 0.5, "Box"): + size(src:size(3), src:size(2), "Box"): + toTensor("byte", "RGB", "DHW") end return pairwise_transform_utils diff --git a/lib/reconstruct.lua b/lib/reconstruct.lua index 138bb6f..6c37f15 100644 --- a/lib/reconstruct.lua +++ b/lib/reconstruct.lua @@ -40,6 +40,15 @@ local function reconstruct_nn(model, x, inner_scale, offset, block_size, batch_s break end input[j+1]:copy(x[input_indexes[i + j]]) + if model.w2nn_gcn then + local mean = input[j + 1]:mean() + local stdv = input[j + 1]:std() + if stdv > 0 then + input[j + 1]:add(-mean):div(stdv) + else + input[j + 1]:add(-mean) + end + end c = c + 1 end input_cuda:copy(input) @@ -80,7 +89,12 @@ local function padding_params(x, model, block_size) p.x_w = x:size(3) p.x_h = x:size(2) p.inner_scale = reconstruct.inner_scale(model) - local input_offset = math.ceil(offset / p.inner_scale) + local input_offset + if model.w2nn_input_offset then + input_offset = model.w2nn_input_offset + else + input_offset = math.ceil(offset / p.inner_scale) + end local input_block_size = block_size local process_size = input_block_size - input_offset * 2 local h_blocks = math.floor(p.x_h / process_size) + @@ -172,6 +186,9 @@ function reconstruct.scale_rgb(model, scale, x, offset, block_size, batch_size) return output end function reconstruct.image(model, x, block_size) + if model.w2nn_input_size then + block_size = model.w2nn_input_size + end local i2rgb = false if x:size(1) == 1 then local new_x = torch.Tensor(3, x:size(2), x:size(3)) @@ -194,6 +211,9 @@ function reconstruct.image(model, x, block_size) return x end function reconstruct.scale(model, scale, x, block_size) + if model.w2nn_input_size then + block_size = model.w2nn_input_size + end local i2rgb = false if x:size(1) == 1 then local new_x = torch.Tensor(3, x:size(2), x:size(3)) @@ -287,6 +307,9 @@ local function tta(f, n, model, x, block_size) return average:div(#augments) end function reconstruct.image_tta(model, n, x, block_size) + if model.w2nn_input_size then + block_size = model.w2nn_input_size + end if reconstruct.is_rgb(model) then return tta(reconstruct.image_rgb, n, model, x, block_size) else @@ -294,6 +317,9 @@ function reconstruct.image_tta(model, n, x, block_size) end end function reconstruct.scale_tta(model, n, scale, x, block_size) + if model.w2nn_input_size then + block_size = model.w2nn_input_size + end if reconstruct.is_rgb(model) then local f = function (model, x, offset, block_size) return reconstruct.scale_rgb(model, scale, x, offset, block_size) diff --git a/lib/settings.lua b/lib/settings.lua index 1c66cb3..4507711 100644 --- a/lib/settings.lua +++ b/lib/settings.lua @@ -1,6 +1,7 @@ require 'xlua' require 'pl' require 'trepl' +require 'cutorch' -- global settings @@ -18,7 +19,7 @@ cmd:text() cmd:text("waifu2x-training") cmd:text("Options:") cmd:option("-gpu", -1, 'GPU Device ID') -cmd:option("-seed", 11, 'RNG seed') +cmd:option("-seed", 11, 'RNG seed (note: it only able to reproduce the training results with `-thread 1`)') cmd:option("-data_dir", "./data", 'path to data directory') cmd:option("-backend", "cunn", '(cunn|cudnn)') cmd:option("-test", "images/miku_small.png", 'path to test image') @@ -32,6 +33,20 @@ cmd:option("-random_color_noise_rate", 0.0, 'data augmentation using color noise cmd:option("-random_overlay_rate", 0.0, 'data augmentation using flipped image overlay (0.0-1.0)') cmd:option("-random_half_rate", 0.0, 'data augmentation using half resolution image (0.0-1.0)') cmd:option("-random_unsharp_mask_rate", 0.0, 'data augmentation using unsharp mask (0.0-1.0)') +cmd:option("-random_blur_rate", 0.0, 'data augmentation using gaussian blur (0.0-1.0)') +cmd:option("-random_blur_size", "3,5", 'filter size for random gaussian blur (comma separated)') +cmd:option("-random_blur_sigma_min", 0.5, 'min sigma for random gaussian blur') +cmd:option("-random_blur_sigma_max", 1.0, 'max sigma for random gaussian blur') +cmd:option("-random_pairwise_scale_rate", 0.0, 'data augmentation using pairwise resize for user method') +cmd:option("-random_pairwise_scale_min", 0.85, 'min scale factor for random pairwise scale') +cmd:option("-random_pairwise_scale_max", 1.176, 'max scale factor for random pairwise scale') +cmd:option("-random_pairwise_rotate_rate", 0.0, 'data augmentation using pairwise resize for user method') +cmd:option("-random_pairwise_rotate_min", -6, 'min rotate angle for random pairwise rotate') +cmd:option("-random_pairwise_rotate_max", 6, 'max rotate angle for random pairwise rotate') +cmd:option("-random_pairwise_negate_rate", 0.0, 'data augmentation using nagate image for user method') +cmd:option("-random_pairwise_negate_x_rate", 0.0, 'data augmentation using nagate image only x side for user method') +cmd:option("-pairwise_y_binary", 0, 'binarize y after data augmentation(0|1)') +cmd:option("-pairwise_flip", 1, 'use flip(0|1)') cmd:option("-scale", 2.0, 'scale factor (2)') cmd:option("-learning_rate", 0.00025, 'learning rate for adam') cmd:option("-crop_size", 48, 'crop size') @@ -59,6 +74,8 @@ cmd:option("-oracle_drop_rate", 0.5, '') cmd:option("-learning_rate_decay", 3.0e-7, 'learning rate decay (learning_rate * 1/(1+num_of_data*patches*epoch))') cmd:option("-resume", "", 'resume model file') cmd:option("-name", "user", 'model name for user method') +cmd:option("-gpu", 1, 'Device ID') +cmd:option("-loss", "huber", 'loss function (huber|l1|mse)') local function to_bool(settings, name) if settings[name] == 1 then @@ -75,6 +92,8 @@ end to_bool(settings, "plot") to_bool(settings, "save_history") to_bool(settings, "use_transparent_png") +to_bool(settings, "pairwise_y_binary") +to_bool(settings, "pairwise_flip") if settings.plot then require 'gnuplot' @@ -148,4 +167,6 @@ end settings.images = string.format("%s/images.t7", settings.data_dir) settings.image_list = string.format("%s/image_list.txt", settings.data_dir) +cutorch.setDevice(opt.gpu) + return settings diff --git a/lib/srcnn.lua b/lib/srcnn.lua index af06def..00c493f 100644 --- a/lib/srcnn.lua +++ b/lib/srcnn.lua @@ -136,6 +136,24 @@ local function SpatialFullConvolution(backend, nInputPlane, nOutputPlane, kW, kH error("unsupported backend:" .. backend) end end +local function ReLU(backend) + if backend == "cunn" then + return nn.ReLU(true) + elseif backend == "cudnn" then + return cudnn.ReLU(true) + else + error("unsupported backend:" .. backend) + end +end +local function SpatialMaxPooling(backend, kW, kH, dW, dH, padW, padH) + if backend == "cunn" then + return nn.SpatialMaxPooling(kW, kH, dW, dH, padW, padH) + elseif backend == "cudnn" then + return cudnn.SpatialMaxPooling(kW, kH, dW, dH, padW, padH) + else + error("unsupported backend:" .. backend) + end +end -- VGG style net(7 layers) function srcnn.vgg_7(backend, ch) @@ -153,6 +171,7 @@ function srcnn.vgg_7(backend, ch) model:add(SpatialConvolution(backend, 128, 128, 3, 3, 1, 1, 0, 0)) model:add(nn.LeakyReLU(0.1, true)) model:add(SpatialConvolution(backend, 128, ch, 3, 3, 1, 1, 0, 0)) + model:add(w2nn.InplaceClip01()) model:add(nn.View(-1):setNumInputDims(3)) model.w2nn_arch_name = "vgg_7" @@ -190,6 +209,7 @@ function srcnn.vgg_12(backend, ch) model:add(SpatialConvolution(backend, 128, 128, 3, 3, 1, 1, 0, 0)) model:add(nn.LeakyReLU(0.1, true)) model:add(SpatialConvolution(backend, 128, ch, 3, 3, 1, 1, 0, 0)) + model:add(w2nn.InplaceClip01()) model:add(nn.View(-1):setNumInputDims(3)) model.w2nn_arch_name = "vgg_12" @@ -219,6 +239,7 @@ function srcnn.dilated_7(backend, ch) model:add(SpatialConvolution(backend, 128, 128, 3, 3, 1, 1, 0, 0)) model:add(nn.LeakyReLU(0.1, true)) model:add(SpatialConvolution(backend, 128, ch, 3, 3, 1, 1, 0, 0)) + model:add(w2nn.InplaceClip01()) model:add(nn.View(-1):setNumInputDims(3)) model.w2nn_arch_name = "dilated_7" @@ -249,6 +270,7 @@ function srcnn.upconv_7(backend, ch) model:add(SpatialConvolution(backend, 128, 256, 3, 3, 1, 1, 0, 0)) model:add(nn.LeakyReLU(0.1, true)) model:add(SpatialFullConvolution(backend, 256, ch, 4, 4, 2, 2, 3, 3):noBias()) + model:add(w2nn.InplaceClip01()) model:add(nn.View(-1):setNumInputDims(3)) model.w2nn_arch_name = "upconv_7" @@ -257,11 +279,255 @@ function srcnn.upconv_7(backend, ch) model.w2nn_resize = true model.w2nn_channels = ch + return model +end + +-- large version of upconv_7 +-- This model able to beat upconv_7 (PSNR: +0.3 ~ +0.8) but this model is 2x slower than upconv_7. +function srcnn.upconv_7l(backend, ch) + local model = nn.Sequential() + model:add(SpatialConvolution(backend, ch, 32, 3, 3, 1, 1, 0, 0)) + model:add(nn.LeakyReLU(0.1, true)) + model:add(SpatialConvolution(backend, 32, 64, 3, 3, 1, 1, 0, 0)) + model:add(nn.LeakyReLU(0.1, true)) + model:add(SpatialConvolution(backend, 64, 128, 3, 3, 1, 1, 0, 0)) + model:add(nn.LeakyReLU(0.1, true)) + model:add(SpatialConvolution(backend, 128, 192, 3, 3, 1, 1, 0, 0)) + model:add(nn.LeakyReLU(0.1, true)) + model:add(SpatialConvolution(backend, 192, 256, 3, 3, 1, 1, 0, 0)) + model:add(nn.LeakyReLU(0.1, true)) + model:add(SpatialConvolution(backend, 256, 512, 3, 3, 1, 1, 0, 0)) + model:add(nn.LeakyReLU(0.1, true)) + model:add(SpatialFullConvolution(backend, 512, ch, 4, 4, 2, 2, 3, 3):noBias()) + model:add(w2nn.InplaceClip01()) + model:add(nn.View(-1):setNumInputDims(3)) + + model.w2nn_arch_name = "upconv_7l" + model.w2nn_offset = 14 + model.w2nn_scale_factor = 2 + model.w2nn_resize = true + model.w2nn_channels = ch + --model:cuda() --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size()) return model end + +-- layerwise linear blending with skip connections +-- Note: PSNR: upconv_7 < skiplb_7 < upconv_7l +function srcnn.skiplb_7(backend, ch) + local function skip(backend, i, o) + local con = nn.Concat(2) + local conv = nn.Sequential() + conv:add(SpatialConvolution(backend, i, o, 3, 3, 1, 1, 1, 1)) + conv:add(nn.LeakyReLU(0.1, true)) + + -- depth concat + con:add(conv) + con:add(nn.Identity()) -- skip + return con + end + local model = nn.Sequential() + model:add(skip(backend, ch, 16)) + model:add(skip(backend, 16+ch, 32)) + model:add(skip(backend, 32+16+ch, 64)) + model:add(skip(backend, 64+32+16+ch, 128)) + model:add(skip(backend, 128+64+32+16+ch, 128)) + model:add(skip(backend, 128+128+64+32+16+ch, 256)) + -- input of last layer = [all layerwise output(contains input layer)].flatten + model:add(SpatialFullConvolution(backend, 256+128+128+64+32+16+ch, ch, 4, 4, 2, 2, 3, 3):noBias()) -- linear blend + model:add(w2nn.InplaceClip01()) + model:add(nn.View(-1):setNumInputDims(3)) + model.w2nn_arch_name = "skiplb_7" + model.w2nn_offset = 14 + model.w2nn_scale_factor = 2 + model.w2nn_resize = true + model.w2nn_channels = ch + + --model:cuda() + --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size()) + + return model +end + +-- dilated convolution + deconvolution +-- Note: This model is not better than upconv_7. Maybe becuase of under-fitting. +function srcnn.dilated_upconv_7(backend, ch) + local model = nn.Sequential() + model:add(SpatialConvolution(backend, ch, 16, 3, 3, 1, 1, 0, 0)) + model:add(nn.LeakyReLU(0.1, true)) + model:add(SpatialConvolution(backend, 16, 32, 3, 3, 1, 1, 0, 0)) + model:add(nn.LeakyReLU(0.1, true)) + model:add(nn.SpatialDilatedConvolution(32, 64, 3, 3, 1, 1, 0, 0, 2, 2)) + model:add(nn.LeakyReLU(0.1, true)) + model:add(nn.SpatialDilatedConvolution(64, 128, 3, 3, 1, 1, 0, 0, 2, 2)) + model:add(nn.LeakyReLU(0.1, true)) + model:add(nn.SpatialDilatedConvolution(128, 128, 3, 3, 1, 1, 0, 0, 2, 2)) + model:add(nn.LeakyReLU(0.1, true)) + model:add(SpatialConvolution(backend, 128, 256, 3, 3, 1, 1, 0, 0)) + model:add(nn.LeakyReLU(0.1, true)) + model:add(SpatialFullConvolution(backend, 256, ch, 4, 4, 2, 2, 3, 3):noBias()) + model:add(w2nn.InplaceClip01()) + model:add(nn.View(-1):setNumInputDims(3)) + + model.w2nn_arch_name = "dilated_upconv_7" + model.w2nn_offset = 20 + model.w2nn_scale_factor = 2 + model.w2nn_resize = true + model.w2nn_channels = ch + + --model:cuda() + --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size()) + + return model +end + +-- ref: https://arxiv.org/abs/1609.04802 +-- note: no batch-norm, no zero-paading +function srcnn.srresnet_2x(backend, ch) + local function resblock(backend) + local seq = nn.Sequential() + local con = nn.ConcatTable() + local conv = nn.Sequential() + conv:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0)) + conv:add(ReLU(backend)) + conv:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0)) + conv:add(ReLU(backend)) + con:add(conv) + con:add(nn.SpatialZeroPadding(-2, -2, -2, -2)) -- identity + de-padding + seq:add(con) + seq:add(nn.CAddTable()) + return seq + end + local model = nn.Sequential() + --model:add(skip(backend, ch, 64 - ch)) + model:add(SpatialConvolution(backend, ch, 64, 3, 3, 1, 1, 0, 0)) + model:add(nn.LeakyReLU(0.1, true)) + model:add(resblock(backend)) + model:add(resblock(backend)) + model:add(resblock(backend)) + model:add(resblock(backend)) + model:add(resblock(backend)) + model:add(resblock(backend)) + model:add(SpatialFullConvolution(backend, 64, 64, 4, 4, 2, 2, 2, 2)) + model:add(ReLU(backend)) + model:add(SpatialConvolution(backend, 64, ch, 3, 3, 1, 1, 0, 0)) + + model:add(w2nn.InplaceClip01()) + --model:add(nn.View(-1):setNumInputDims(3)) + model.w2nn_arch_name = "srresnet_2x" + model.w2nn_offset = 28 + model.w2nn_scale_factor = 2 + model.w2nn_resize = true + model.w2nn_channels = ch + + --model:cuda() + --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size()) + + return model +end + +-- large version of srresnet_2x. It's current best model but slow. +function srcnn.resnet_14l(backend, ch) + local function resblock(backend, i, o) + local seq = nn.Sequential() + local con = nn.ConcatTable() + local conv = nn.Sequential() + conv:add(SpatialConvolution(backend, i, o, 3, 3, 1, 1, 0, 0)) + conv:add(nn.LeakyReLU(0.1, true)) + conv:add(SpatialConvolution(backend, o, o, 3, 3, 1, 1, 0, 0)) + conv:add(nn.LeakyReLU(0.1, true)) + con:add(conv) + if i == o then + con:add(nn.SpatialZeroPadding(-2, -2, -2, -2)) -- identity + de-padding + else + local seq = nn.Sequential() + seq:add(SpatialConvolution(backend, i, o, 1, 1, 1, 1, 0, 0)) + seq:add(nn.SpatialZeroPadding(-2, -2, -2, -2)) + con:add(seq) + end + seq:add(con) + seq:add(nn.CAddTable()) + return seq + end + local model = nn.Sequential() + model:add(SpatialConvolution(backend, ch, 32, 3, 3, 1, 1, 0, 0)) + model:add(nn.LeakyReLU(0.1, true)) + model:add(resblock(backend, 32, 64)) + model:add(resblock(backend, 64, 64)) + model:add(resblock(backend, 64, 128)) + model:add(resblock(backend, 128, 128)) + model:add(resblock(backend, 128, 256)) + model:add(resblock(backend, 256, 256)) + model:add(SpatialFullConvolution(backend, 256, ch, 4, 4, 2, 2, 3, 3):noBias()) + model:add(w2nn.InplaceClip01()) + model:add(nn.View(-1):setNumInputDims(3)) + model.w2nn_arch_name = "resnet_14l" + model.w2nn_offset = 28 + model.w2nn_scale_factor = 2 + model.w2nn_resize = true + model.w2nn_channels = ch + + --model:cuda() + --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size()) + + return model +end + +-- for segmentation +function srcnn.fcn_v1(backend, ch) + -- input_size = 120 + local model = nn.Sequential() + --i = 120 + --model:cuda() + --print(model:forward(torch.Tensor(32, ch, i, i):uniform():cuda()):size()) + + model:add(SpatialConvolution(backend, ch, 32, 5, 5, 2, 2, 0, 0)) + model:add(nn.LeakyReLU(0.1, true)) + model:add(SpatialConvolution(backend, 32, 32, 3, 3, 1, 1, 0, 0)) + model:add(nn.LeakyReLU(0.1, true)) + model:add(SpatialMaxPooling(backend, 2, 2, 2, 2)) + + model:add(SpatialConvolution(backend, 32, 64, 3, 3, 1, 1, 0, 0)) + model:add(nn.LeakyReLU(0.1, true)) + model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0)) + model:add(nn.LeakyReLU(0.1, true)) + model:add(SpatialMaxPooling(backend, 2, 2, 2, 2)) + + model:add(SpatialConvolution(backend, 64, 128, 3, 3, 1, 1, 0, 0)) + model:add(nn.LeakyReLU(0.1, true)) + model:add(SpatialConvolution(backend, 128, 128, 3, 3, 1, 1, 0, 0)) + model:add(nn.LeakyReLU(0.1, true)) + model:add(SpatialMaxPooling(backend, 2, 2, 2, 2)) + + model:add(SpatialConvolution(backend, 128, 256, 1, 1, 1, 1, 0, 0)) + model:add(nn.LeakyReLU(0.1, true)) + model:add(nn.Dropout(0.5, false, true)) + + model:add(SpatialFullConvolution(backend, 256, 128, 2, 2, 2, 2, 0, 0)) + model:add(nn.LeakyReLU(0.1, true)) + model:add(SpatialFullConvolution(backend, 128, 128, 2, 2, 2, 2, 0, 0)) + model:add(nn.LeakyReLU(0.1, true)) + model:add(SpatialConvolution(backend, 128, 64, 3, 3, 1, 1, 0, 0)) + model:add(nn.LeakyReLU(0.1, true)) + model:add(SpatialFullConvolution(backend, 64, 64, 2, 2, 2, 2, 0, 0)) + model:add(nn.LeakyReLU(0.1, true)) + model:add(SpatialConvolution(backend, 64, 32, 3, 3, 1, 1, 0, 0)) + model:add(nn.LeakyReLU(0.1, true)) + model:add(SpatialFullConvolution(backend, 32, ch, 4, 4, 2, 2, 3, 3)) + + model:add(w2nn.InplaceClip01()) + model:add(nn.View(-1):setNumInputDims(3)) + model.w2nn_arch_name = "fcn_v1" + model.w2nn_offset = 36 + model.w2nn_scale_factor = 1 + model.w2nn_channels = ch + model.w2nn_input_size = 120 + --model.w2nn_gcn = true + + return model +end function srcnn.create(model_name, backend, color) model_name = model_name or "vgg_7" backend = backend or "cunn" @@ -282,8 +548,10 @@ function srcnn.create(model_name, backend, color) error("unsupported model_name: " .. model_name) end end - ---local model = srcnn.upconv_6("cunn", 3):cuda() ---print(model:forward(torch.Tensor(1, 3, 64, 64):zero():cuda()):size()) +--[[ +local model = srcnn.fcn_v1("cunn", 3):cuda() +print(model:forward(torch.Tensor(1, 3, 108, 108):zero():cuda()):size()) +print(model) +--]] return srcnn diff --git a/lib/w2nn.lua b/lib/w2nn.lua index c59513d..a7f93e4 100644 --- a/lib/w2nn.lua +++ b/lib/w2nn.lua @@ -30,5 +30,8 @@ else require 'LeakyReLU' require 'ClippedWeightedHuberCriterion' require 'ClippedMSECriterion' + require 'SSIMCriterion' + require 'InplaceClip01' + require 'L1Criterion' return w2nn end diff --git a/models/resnet_14l/README.md b/models/resnet_14l/README.md new file mode 100644 index 0000000..ed5c9a2 --- /dev/null +++ b/models/resnet_14l/README.md @@ -0,0 +1 @@ +Currently, this models are for the benchmark. diff --git a/models/resnet_14l/photo/scale2.0x_model.t7 b/models/resnet_14l/photo/scale2.0x_model.t7 new file mode 100644 index 0000000..4f69d53 Binary files /dev/null and b/models/resnet_14l/photo/scale2.0x_model.t7 differ diff --git a/models/upconv_7l/README.md b/models/upconv_7l/README.md new file mode 100644 index 0000000..ed5c9a2 --- /dev/null +++ b/models/upconv_7l/README.md @@ -0,0 +1 @@ +Currently, this models are for the benchmark. diff --git a/models/upconv_7l/art/scale2.0x_model.json b/models/upconv_7l/art/scale2.0x_model.json new file mode 100644 index 0000000..b82ff0f Binary files /dev/null and b/models/upconv_7l/art/scale2.0x_model.json differ diff --git a/models/upconv_7l/art/scale2.0x_model.t7 b/models/upconv_7l/art/scale2.0x_model.t7 new file mode 100644 index 0000000..ddb3eb5 Binary files /dev/null and b/models/upconv_7l/art/scale2.0x_model.t7 differ diff --git a/models/upconv_7l/photo/scale2.0x_model.json b/models/upconv_7l/photo/scale2.0x_model.json new file mode 100644 index 0000000..bbc536a Binary files /dev/null and b/models/upconv_7l/photo/scale2.0x_model.json differ diff --git a/models/upconv_7l/photo/scale2.0x_model.t7 b/models/upconv_7l/photo/scale2.0x_model.t7 new file mode 100644 index 0000000..4ac54a3 Binary files /dev/null and b/models/upconv_7l/photo/scale2.0x_model.t7 differ diff --git a/tools/benchmark.lua b/tools/benchmark.lua index a9aa70f..7e93ac6 100644 --- a/tools/benchmark.lua +++ b/tools/benchmark.lua @@ -18,7 +18,7 @@ cmd:option("-dir", "./data/test", 'test image directory') cmd:option("-file", "", 'test image file list') cmd:option("-model1_dir", "./models/anime_style_art_rgb", 'model1 directory') cmd:option("-model2_dir", "", 'model2 directory (optional)') -cmd:option("-method", "scale", '(scale|noise|noise_scale|user|diff)') +cmd:option("-method", "scale", '(scale|noise|noise_scale|user|diff|scale4)') cmd:option("-filter", "Catrom", "downscaling filter (Box|Lanczos|Catrom(Bicubic))") cmd:option("-resize_blur", 1.0, 'blur parameter for resize') cmd:option("-color", "y", '(rgb|y|r|g|b)') @@ -46,6 +46,7 @@ cmd:option("-x_dir", "", 'input image for user method') cmd:option("-y_dir", "", 'groundtruth image for user method. filename must be the same as x_dir') cmd:option("-x_file", "", 'input image for user method') cmd:option("-y_file", "", 'groundtruth image for user method. filename must be the same as x_file') +cmd:option("-border", 0, 'border px that will removed') local function to_bool(settings, name) if settings[name] == 1 then @@ -153,12 +154,24 @@ local function baseline_scale(x, filter) x:size(2) * 2.0, filter) end +local function baseline_scale4(x, filter) + return iproc.scale(x, + x:size(3) * 4.0, + x:size(2) * 4.0, + filter) +end local function transform_scale(x, opt) return iproc.scale(x, x:size(3) * 0.5, x:size(2) * 0.5, opt.filter, opt.resize_blur) end +local function transform_scale4(x, opt) + return iproc.scale(x, + x:size(3) * 0.25, + x:size(2) * 0.25, + opt.filter, opt.resize_blur) +end local function transform_scale_jpeg(x, opt) x = iproc.scale(x, @@ -179,9 +192,15 @@ local function transform_scale_jpeg(x, opt) end return iproc.byte2float(x) end - +local function remove_border(x, border) + return iproc.crop(x, + border, border, + x:size(3) - border, + x:size(2) - border) +end local function benchmark(opt, x, model1, model2) - local mse + local mse1, mse2 + local won = {0, 0} local model1_mse = 0 local model2_mse = 0 local baseline_mse = 0 @@ -192,6 +211,10 @@ local function benchmark(opt, x, model1, model2) local model2_time = 0 local scale_f = reconstruct.scale local image_f = reconstruct.image + local detail_fp = nil + if opt.save_info then + detail_fp = io.open(path.join(opt.output_dir, "benchmark_details.txt"), "w") + end if opt.tta then scale_f = function(model, scale, x, block_size, batch_size) return reconstruct.scale_tta(model, opt.tta_level, @@ -204,12 +227,15 @@ local function benchmark(opt, x, model1, model2) end for i = 1, #x do + if i % 10 == 0 then + collectgarbage() + end local basename = x[i].basename local input, model1_output, model2_output, baseline_output, ground_truth if opt.method == "scale" then - input = transform_scale(x[i].y, opt) - ground_truth = x[i].y + input = transform_scale(iproc.byte2float(x[i].y), opt) + ground_truth = iproc.byte2float(x[i].y) if opt.force_cudnn and i == 1 then -- run cuDNN benchmark first model1_output = scale_f(model1, 2.0, input, opt.crop_size, opt.batch_size) @@ -226,9 +252,29 @@ local function benchmark(opt, x, model1, model2) model2_time = model2_time + (sys.clock() - t) end baseline_output = baseline_scale(input, opt.baseline_filter) + elseif opt.method == "scale4" then + input = transform_scale4(iproc.byte2float(x[i].y), opt) + ground_truth = iproc.byte2float(x[i].y) + if opt.force_cudnn and i == 1 then -- run cuDNN benchmark first + model1_output = scale_f(model1, 2.0, input, opt.crop_size, opt.batch_size) + if model2 then + model2_output = scale_f(model2, 2.0, input, opt.crop_size, opt.batch_size) + end + end + t = sys.clock() + model1_output = scale_f(model1, 2.0, input, opt.crop_size, opt.batch_size) + model1_output = scale_f(model1, 2.0, model1_output, opt.crop_size, opt.batch_size) + model1_time = model1_time + (sys.clock() - t) + if model2 then + t = sys.clock() + model2_output = scale_f(model2, 2.0, input, opt.crop_size, opt.batch_size) + model2_output = scale_f(model2, 2.0, model2_output, opt.crop_size, opt.batch_size) + model2_time = model2_time + (sys.clock() - t) + end + baseline_output = baseline_scale4(input, opt.baseline_filter) elseif opt.method == "noise" then - input = transform_jpeg(x[i].y, opt) - ground_truth = x[i].y + input = transform_jpeg(iproc.byte2float(x[i].y), opt) + ground_truth = iproc.byte2float(x[i].y) if opt.force_cudnn and i == 1 then model1_output = image_f(model1, input, opt.crop_size, opt.batch_size) @@ -246,8 +292,8 @@ local function benchmark(opt, x, model1, model2) end baseline_output = input elseif opt.method == "noise_scale" then - input = transform_scale_jpeg(x[i].y, opt) - ground_truth = x[i].y + input = transform_scale_jpeg(iproc.byte2float(x[i].y), opt) + ground_truth = iproc.byte2float(x[i].y) if opt.force_cudnn and i == 1 then if model1.noise_scale_model then @@ -312,8 +358,8 @@ local function benchmark(opt, x, model1, model2) end baseline_output = baseline_scale(input, opt.baseline_filter) elseif opt.method == "user" then - input = x[i].x - ground_truth = x[i].y + input = iproc.byte2float(x[i].x) + ground_truth = iproc.byte2float(x[i].y) local y_scale = ground_truth:size(2) / input:size(2) if y_scale > 1 then if opt.force_cudnn and i == 1 then @@ -347,19 +393,44 @@ local function benchmark(opt, x, model1, model2) end end elseif opt.method == "diff" then - input = x[i].x - ground_truth = x[i].y + input = iproc.byte2float(x[i].x) + ground_truth = iproc.byte2float(x[i].y) model1_output = input end - mse = MSE(ground_truth, model1_output, opt.color) - model1_mse = model1_mse + mse - model1_psnr = model1_psnr + MSE2PSNR(mse) + if opt.border > 0 then + ground_truth = remove_border(ground_truth, opt.border) + model1_output = remove_border(model1_output, opt.border) + end + mse1 = MSE(ground_truth, model1_output, opt.color) + model1_mse = model1_mse + mse1 + model1_psnr = model1_psnr + MSE2PSNR(mse1) + + local won_model = 1 if model2 then - mse = MSE(ground_truth, model2_output, opt.color) - model2_mse = model2_mse + mse - model2_psnr = model2_psnr + MSE2PSNR(mse) + if opt.border > 0 then + model2_output = remove_border(model2_output, opt.border) + end + mse2 = MSE(ground_truth, model2_output, opt.color) + model2_mse = model2_mse + mse2 + model2_psnr = model2_psnr + MSE2PSNR(mse2) + + if mse1 < mse2 then + won[1] = won[1] + 1 + elseif mse1 > mse2 then + won[2] = won[2] + 1 + won_model = 2 + end + if detail_fp then + detail_fp:write(string.format("%s,%f,%f,%d\n", x[i].basename, + MSE2PSNR(mse1), MSE2PSNR(mse2), won_model)) + end + else + if detail_fp then + detail_fp:write(string.format("%s,%f\n", x[i].basename, MSE2PSNR(mse1))) + end end if baseline_output then + baseline_output = remove_border(baseline_output, opt.border) mse = MSE(ground_truth, baseline_output, opt.color) baseline_mse = baseline_mse + mse baseline_psnr = baseline_psnr + MSE2PSNR(mse) @@ -382,29 +453,31 @@ local function benchmark(opt, x, model1, model2) if model2 then if baseline_output then io.stdout:write( - string.format("%d/%d; model1_time=%.2f, model2_time=%.2f, baseline_rmse=%f, model1_rmse=%f, model2_rmse=%f, baseline_psnr=%f, model1_psnr=%f, model2_psnr=%f \r", + string.format("%d/%d; model1_time=%.2f, model2_time=%.2f, baseline_rmse=%.3f, model1_rmse=%.3f, model2_rmse=%.3f, baseline_psnr=%.3f, model1_psnr=%.3f, model2_psnr=%.3f, model1_won=%d, model2_won=%d \r", i, #x, model1_time, model2_time, math.sqrt(baseline_mse / i), math.sqrt(model1_mse / i), math.sqrt(model2_mse / i), baseline_psnr / i, - model1_psnr / i, model2_psnr / i + model1_psnr / i, model2_psnr / i, + won[1], won[2] )) else io.stdout:write( - string.format("%d/%d; model1_time=%.2f, model2_time=%.2f, model1_rmse=%f, model2_rmse=%f, model1_psnr=%f, model2_psnr=%f \r", + string.format("%d/%d; model1_time=%.2f, model2_time=%.2f, model1_rmse=%.3f, model2_rmse=%.3f, model1_psnr=%.3f, model2_psnr=%.3f, model1_own=%d, model2_won=%d \r", i, #x, model1_time, model2_time, math.sqrt(model1_mse / i), math.sqrt(model2_mse / i), - model1_psnr / i, model2_psnr / i + model1_psnr / i, model2_psnr / i, + won[1], won[2] )) end else if baseline_output then io.stdout:write( - string.format("%d/%d; model1_time=%.2f, baseline_rmse=%f, model1_rmse=%f, baseline_psnr=%f, model1_psnr=%f \r", + string.format("%d/%d; model1_time=%.2f, baseline_rmse=%.3f, model1_rmse=%.3f, baseline_psnr=%.3f, model1_psnr=%.3f \r", i, #x, model1_time, math.sqrt(baseline_mse / i), math.sqrt(model1_mse / i), @@ -412,7 +485,7 @@ local function benchmark(opt, x, model1, model2) )) else io.stdout:write( - string.format("%d/%d; model1_time=%.2f, model1_rmse=%f, model1_psnr=%f \r", + string.format("%d/%d; model1_time=%.2f, model1_rmse=%.3f, model1_psnr=%.3f \r", i, #x, model1_time, math.sqrt(model1_mse / i), model1_psnr / i @@ -438,6 +511,9 @@ local function benchmark(opt, x, model1, model2) math.sqrt(model2_mse / #x), model2_psnr / #x, model2_time)) end fp:close() + if detail_fp then + detail_fp:close() + end end io.stdout:write("\n") end @@ -448,7 +524,7 @@ local function load_data_from_dir(test_dir) local name = path.basename(files[i]) local e = path.extension(name) local base = name:sub(0, name:len() - e:len()) - local img = image_loader.load_float(files[i]) + local img = image_loader.load_byte(files[i]) if img then table.insert(test_x, {y = iproc.crop_mod4(img), basename = base}) @@ -456,6 +532,9 @@ local function load_data_from_dir(test_dir) if opt.show_progress then xlua.progress(i, #files) end + if i % 10 == 0 then + collectgarbage() + end end return test_x end @@ -466,7 +545,7 @@ local function load_data_from_file(test_file) local name = path.basename(files[i]) local e = path.extension(name) local base = name:sub(0, name:len() - e:len()) - local img = image_loader.load_float(files[i]) + local img = image_loader.load_byte(files[i]) if img then table.insert(test_x, {y = iproc.crop_mod4(img), basename = base}) @@ -474,6 +553,9 @@ local function load_data_from_file(test_file) if opt.show_progress then xlua.progress(i, #files) end + if i % 10 == 0 then + collectgarbage() + end end return test_x end @@ -519,16 +601,19 @@ local function load_user_data(y_dir, y_file, x_dir, x_file) end for i = 1, #y_files do local key = get_basename(y_files[i]) - local x = image_loader.load_float(basename_db[key].x) - local y = image_loader.load_float(basename_db[key].y) + local x = image_loader.load_byte(basename_db[key].x) + local y = image_loader.load_byte(basename_db[key].y) if x and y then table.insert(test, {y = y, x = x, - basename = base}) + basename = key}) end if opt.show_progress then xlua.progress(i, #y_files) end + if i % 10 == 0 then + collectgarbage() + end end return test end @@ -563,7 +648,7 @@ if opt.show_progress then print(opt) end -if opt.method == "scale" then +if opt.method == "scale" or opt.method == "scale4" then local f1 = path.join(opt.model1_dir, "scale2.0x_model.t7") local f2 = path.join(opt.model2_dir, "scale2.0x_model.t7") local s1, model1 = pcall(w2nn.load_model, f1, opt.force_cudnn) diff --git a/tools/export_all.sh b/tools/export_all.sh index d3e44b3..198f747 100755 --- a/tools/export_all.sh +++ b/tools/export_all.sh @@ -33,5 +33,7 @@ export_model() { } export_model vgg_7/art export_model upconv_7/art +export_model upconv_7l/art export_model vgg_7/photo export_model upconv_7/photo +export_model upconv_7l/photo diff --git a/tools/export_model.lua b/tools/export_model.lua index f3c90f9..e8d5819 100644 --- a/tools/export_model.lua +++ b/tools/export_model.lua @@ -22,7 +22,6 @@ local function includes(s, a) end return false end - local function get_bias(mod) if mod.bias then return mod.bias:float() @@ -31,20 +30,18 @@ local function get_bias(mod) return torch.FloatTensor(mod.nOutputPlane):zero() end end -local function export(model, output) +local function export_weight(jmodules, seq) local targets = {"nn.SpatialConvolutionMM", "cudnn.SpatialConvolution", "nn.SpatialFullConvolution", "cudnn.SpatialFullConvolution" } - local jmodules = {} - local model_config = meta_data(model) - local first_layer = true - - for k = 1, #model.modules do - local mod = model.modules[k] + for k = 1, #seq.modules do + local mod = seq.modules[k] local name = torch.typename(mod) - if includes(name, targets) then + if name == "nn.Sequential" or name == "nn.ConcatTable" then + export_weight(jmodules, mod) + elseif includes(name, targets) then local weight = mod.weight:float() if name:match("FullConvolution") then weight = torch.totable(weight:reshape(mod.nInputPlane, mod.nOutputPlane, mod.kH, mod.kW)) @@ -71,6 +68,14 @@ local function export(model, output) table.insert(jmodules, jmod) end end +end +local function export(model, output) + local jmodules = {} + local model_config = meta_data(model) + local first_layer = true + + export_weight(jmodules, model) + local fp = io.open(output, "w") if not fp then error("IO Error: " .. output) diff --git a/train.lua b/train.lua index 2e04a50..77536f5 100644 --- a/train.lua +++ b/train.lua @@ -3,15 +3,14 @@ local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^ package.path = path.join(path.dirname(__FILE__), "lib", "?.lua;") .. package.path require 'optim' require 'xlua' - +require 'image' require 'w2nn' +local threads = require 'threads' local settings = require 'settings' local srcnn = require 'srcnn' local minibatch_adam = require 'minibatch_adam' local iproc = require 'iproc' local reconstruct = require 'reconstruct' -local compression = require 'compression' -local pairwise_transform = require 'pairwise_transform' local image_loader = require 'image_loader' local function save_test_scale(model, rgb, file) @@ -42,20 +41,218 @@ local function split_data(x, test_size) end return train_x, valid_x end -local function make_validation_set(x, transformer, n, patches) + +local g_transform_pool = nil +local g_mutex = nil +local g_mutex_id = nil +local function transform_pool_init(has_resize, offset) + local nthread = torch.getnumthreads() + if (settings.thread > 0) then + nthread = settings.thread + end + g_mutex = threads.Mutex() + g_mutex_id = g_mutex:id() + g_transform_pool = threads.Threads( + nthread, + threads.safe( + function(threadid) + require 'pl' + local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)() + package.path = path.join(path.dirname(__FILE__), "lib", "?.lua;") .. package.path + require 'torch' + require 'nn' + require 'cunn' + + torch.setnumthreads(1) + torch.setdefaulttensortype("torch.FloatTensor") + + local threads = require 'threads' + local compression = require 'compression' + local pairwise_transform = require 'pairwise_transform' + + function transformer(x, is_validation, n) + local mutex = threads.Mutex(g_mutex_id) + local meta = {data = {}} + local y = nil + if type(x) == "table" and type(x[2]) == "table" then + meta = x[2] + if x[1].x and x[1].y then + y = compression.decompress(x[1].y) + x = compression.decompress(x[1].x) + else + x = compression.decompress(x[1]) + end + else + x = compression.decompress(x) + end + n = n or settings.patches + if is_validation == nil then is_validation = false end + local random_color_noise_rate = nil + local random_overlay_rate = nil + local active_cropping_rate = nil + local active_cropping_tries = nil + if is_validation then + active_cropping_rate = settings.active_cropping_rate + active_cropping_tries = settings.active_cropping_tries + random_color_noise_rate = 0.0 + random_overlay_rate = 0.0 + else + active_cropping_rate = settings.active_cropping_rate + active_cropping_tries = settings.active_cropping_tries + random_color_noise_rate = settings.random_color_noise_rate + random_overlay_rate = settings.random_overlay_rate + end + if settings.method == "scale" then + local conf = tablex.update({ + mutex = mutex, + downsampling_filters = settings.downsampling_filters, + random_half_rate = settings.random_half_rate, + random_color_noise_rate = random_color_noise_rate, + random_overlay_rate = random_overlay_rate, + random_unsharp_mask_rate = settings.random_unsharp_mask_rate, + random_blur_rate = settings.random_blur_rate, + random_blur_size = settings.random_blur_size, + random_blur_sigma_min = settings.random_blur_sigma_min, + random_blur_sigma_max = settings.random_blur_sigma_max, + max_size = settings.max_size, + active_cropping_rate = active_cropping_rate, + active_cropping_tries = active_cropping_tries, + rgb = (settings.color == "rgb"), + x_upsampling = not has_resize, + resize_blur_min = settings.resize_blur_min, + resize_blur_max = settings.resize_blur_max}, meta) + return pairwise_transform.scale(x, + settings.scale, + settings.crop_size, offset, + n, conf) + elseif settings.method == "noise" then + local conf = tablex.update({ + mutex = mutex, + random_half_rate = settings.random_half_rate, + random_color_noise_rate = random_color_noise_rate, + random_overlay_rate = random_overlay_rate, + random_unsharp_mask_rate = settings.random_unsharp_mask_rate, + random_blur_rate = settings.random_blur_rate, + random_blur_size = settings.random_blur_size, + random_blur_sigma_min = settings.random_blur_sigma_min, + random_blur_sigma_max = settings.random_blur_sigma_max, + max_size = settings.max_size, + jpeg_chroma_subsampling_rate = settings.jpeg_chroma_subsampling_rate, + active_cropping_rate = active_cropping_rate, + active_cropping_tries = active_cropping_tries, + nr_rate = settings.nr_rate, + rgb = (settings.color == "rgb")}, meta) + return pairwise_transform.jpeg(x, + settings.style, + settings.noise_level, + settings.crop_size, offset, + n, conf) + elseif settings.method == "noise_scale" then + local conf = tablex.update({ + mutex = mutex, + downsampling_filters = settings.downsampling_filters, + random_half_rate = settings.random_half_rate, + random_color_noise_rate = random_color_noise_rate, + random_overlay_rate = random_overlay_rate, + random_unsharp_mask_rate = settings.random_unsharp_mask_rate, + random_blur_rate = settings.random_blur_rate, + random_blur_size = settings.random_blur_size, + random_blur_sigma_min = settings.random_blur_sigma_min, + random_blur_sigma_max = settings.random_blur_sigma_max, + max_size = settings.max_size, + jpeg_chroma_subsampling_rate = settings.jpeg_chroma_subsampling_rate, + nr_rate = settings.nr_rate, + active_cropping_rate = active_cropping_rate, + active_cropping_tries = active_cropping_tries, + rgb = (settings.color == "rgb"), + x_upsampling = not has_resize, + resize_blur_min = settings.resize_blur_min, + resize_blur_max = settings.resize_blur_max}, meta) + return pairwise_transform.jpeg_scale(x, + settings.scale, + settings.style, + settings.noise_level, + settings.crop_size, offset, + n, conf) + elseif settings.method == "user" then + if is_validation == nil then is_validation = false end + local rotate_rate = nil + local scale_rate = nil + local negate_rate = nil + local negate_x_rate = nil + if is_validation then + rotate_rate = 0 + scale_rate = 0 + negate_rate = 0 + negate_x_rate = 0 + else + rotate_rate = settings.random_pairwise_rotate_rate + scale_rate = settings.random_pairwise_scale_rate + negate_rate = settings.random_pairwise_negate_rate + negate_x_rate = settings.random_pairwise_negate_x_rate + end + local conf = tablex.update({ + gcn = settings.gcn, + max_size = settings.max_size, + active_cropping_rate = active_cropping_rate, + active_cropping_tries = active_cropping_tries, + random_pairwise_rotate_rate = rotate_rate, + random_pairwise_rotate_min = settings.random_pairwise_rotate_min, + random_pairwise_rotate_max = settings.random_pairwise_rotate_max, + random_pairwise_scale_rate = scale_rate, + random_pairwise_scale_min = settings.random_pairwise_scale_min, + random_pairwise_scale_max = settings.random_pairwise_scale_max, + random_pairwise_negate_rate = negate_rate, + random_pairwise_negate_x_rate = negate_x_rate, + pairwise_y_binary = settings.pairwise_y_binary, + pairwise_flip = settings.pairwise_flip, + rgb = (settings.color == "rgb")}, meta) + return pairwise_transform.user(x, y, + settings.crop_size, offset, + n, conf) + end + end + end) + ) + g_transform_pool:synchronize() +end + +local function make_validation_set(x, n, patches) + local nthread = torch.getnumthreads() + if (settings.thread > 0) then + nthread = settings.thread + end n = n or 4 local validation_patches = math.min(16, patches or 16) local data = {} + + g_transform_pool:synchronize() + torch.setnumthreads(1) -- 1 + for i = 1, #x do for k = 1, math.max(n / validation_patches, 1) do - local xy = transformer(x[i], true, validation_patches) - for j = 1, #xy do - table.insert(data, {x = xy[j][1], y = xy[j][2]}) - end + local input = x[i] + g_transform_pool:addjob( + function() + local xy = transformer(input, true, validation_patches) + return xy + end, + function(xy) + for j = 1, #xy do + table.insert(data, {x = xy[j][1], y = xy[j][2]}) + end + end + ) + end + if i % 20 == 0 then + collectgarbage() + g_transform_pool:synchronize() + xlua.progress(i, #x) end - xlua.progress(i, #x) - collectgarbage() end + g_transform_pool:synchronize() + torch.setnumthreads(nthread) -- revert + local new_data = {} local perm = torch.randperm(#data) for i = 1, perm:size(1) do @@ -102,144 +299,71 @@ local function validate(model, criterion, eval_metric, data, batch_size) end local function create_criterion(model) - if reconstruct.is_rgb(model) then - local offset = reconstruct.offset_size(model) - local output_w = settings.crop_size - offset * 2 - local weight = torch.Tensor(3, output_w * output_w) - weight[1]:fill(0.29891 * 3) -- R - weight[2]:fill(0.58661 * 3) -- G - weight[3]:fill(0.11448 * 3) -- B - return w2nn.ClippedWeightedHuberCriterion(weight, 0.1, {0.0, 1.0}):cuda() - else - local offset = reconstruct.offset_size(model) - local output_w = settings.crop_size - offset * 2 - local weight = torch.Tensor(1, output_w * output_w) - weight[1]:fill(1.0) - return w2nn.ClippedWeightedHuberCriterion(weight, 0.1, {0.0, 1.0}):cuda() - end -end -local function transformer(model, x, is_validation, n, offset) - local meta = {data = {}} - local y = nil - if type(x) == "table" and type(x[2]) == "table" then - meta = x[2] - if x[1].x and x[1].y then - y = compression.decompress(x[1].y) - x = compression.decompress(x[1].x) + if settings.loss == "huber" then + if reconstruct.is_rgb(model) then + local offset = reconstruct.offset_size(model) + local output_w = settings.crop_size - offset * 2 + local weight = torch.Tensor(3, output_w * output_w) + weight[1]:fill(0.29891 * 3) -- R + weight[2]:fill(0.58661 * 3) -- G + weight[3]:fill(0.11448 * 3) -- B + return w2nn.ClippedWeightedHuberCriterion(weight, 0.1, {0.0, 1.0}):cuda() else - x = compression.decompress(x[1]) + local offset = reconstruct.offset_size(model) + local output_w = settings.crop_size - offset * 2 + local weight = torch.Tensor(1, output_w * output_w) + weight[1]:fill(1.0) + return w2nn.ClippedWeightedHuberCriterion(weight, 0.1, {0.0, 1.0}):cuda() end + elseif settings.loss == "l1" then + return w2nn.L1Criterion():cuda() + elseif settings.loss == "mse" then + return w2nn.ClippedMSECriterion(0, 1.0):cuda() else - x = compression.decompress(x) - end - n = n or settings.patches - if is_validation == nil then is_validation = false end - local random_color_noise_rate = nil - local random_overlay_rate = nil - local active_cropping_rate = nil - local active_cropping_tries = nil - if is_validation then - active_cropping_rate = settings.active_cropping_rate - active_cropping_tries = settings.active_cropping_tries - random_color_noise_rate = 0.0 - random_overlay_rate = 0.0 - else - active_cropping_rate = settings.active_cropping_rate - active_cropping_tries = settings.active_cropping_tries - random_color_noise_rate = settings.random_color_noise_rate - random_overlay_rate = settings.random_overlay_rate - end - if settings.method == "scale" then - local conf = tablex.update({ - downsampling_filters = settings.downsampling_filters, - random_half_rate = settings.random_half_rate, - random_color_noise_rate = random_color_noise_rate, - random_overlay_rate = random_overlay_rate, - random_unsharp_mask_rate = settings.random_unsharp_mask_rate, - max_size = settings.max_size, - active_cropping_rate = active_cropping_rate, - active_cropping_tries = active_cropping_tries, - rgb = (settings.color == "rgb"), - x_upsampling = not reconstruct.has_resize(model), - resize_blur_min = settings.resize_blur_min, - resize_blur_max = settings.resize_blur_max}, meta) - return pairwise_transform.scale(x, - settings.scale, - settings.crop_size, offset, - n, conf) - elseif settings.method == "noise" then - local conf = tablex.update({ - random_half_rate = settings.random_half_rate, - random_color_noise_rate = random_color_noise_rate, - random_overlay_rate = random_overlay_rate, - random_unsharp_mask_rate = settings.random_unsharp_mask_rate, - max_size = settings.max_size, - jpeg_chroma_subsampling_rate = settings.jpeg_chroma_subsampling_rate, - active_cropping_rate = active_cropping_rate, - active_cropping_tries = active_cropping_tries, - nr_rate = settings.nr_rate, - rgb = (settings.color == "rgb")}, meta) - return pairwise_transform.jpeg(x, - settings.style, - settings.noise_level, - settings.crop_size, offset, - n, conf) - elseif settings.method == "noise_scale" then - local conf = tablex.update({ - downsampling_filters = settings.downsampling_filters, - random_half_rate = settings.random_half_rate, - random_color_noise_rate = random_color_noise_rate, - random_overlay_rate = random_overlay_rate, - random_unsharp_mask_rate = settings.random_unsharp_mask_rate, - max_size = settings.max_size, - jpeg_chroma_subsampling_rate = settings.jpeg_chroma_subsampling_rate, - nr_rate = settings.nr_rate, - active_cropping_rate = active_cropping_rate, - active_cropping_tries = active_cropping_tries, - rgb = (settings.color == "rgb"), - x_upsampling = not reconstruct.has_resize(model), - resize_blur_min = settings.resize_blur_min, - resize_blur_max = settings.resize_blur_max}, meta) - return pairwise_transform.jpeg_scale(x, - settings.scale, - settings.style, - settings.noise_level, - settings.crop_size, offset, - n, conf) - elseif settings.method == "user" then - local conf = tablex.update({ - max_size = settings.max_size, - active_cropping_rate = active_cropping_rate, - active_cropping_tries = active_cropping_tries, - rgb = (settings.color == "rgb")}, meta) - return pairwise_transform.user(x, y, - settings.crop_size, offset, - n, conf) + error("unsupported loss .." .. settings.loss) end end -local function resampling(x, y, train_x, transformer, input_size, target_size) +local function resampling(x, y, train_x) local c = 1 local shuffle = torch.randperm(#train_x) + local nthread = torch.getnumthreads() + if (settings.thread > 0) then + nthread = settings.thread + end + torch.setnumthreads(1) -- 1 + for t = 1, #train_x do - xlua.progress(t, #train_x) - local xy = transformer(train_x[shuffle[t]], false, settings.patches) - for i = 1, #xy do - x[c]:copy(xy[i][1]) - y[c]:copy(xy[i][2]) - c = c + 1 - if c > x:size(1) then - break + local input = train_x[shuffle[t]] + g_transform_pool:addjob( + function() + local xy = transformer(input, false, settings.patches) + return xy + end, + function(xy) + for i = 1, #xy do + if c <= x:size(1) then + x[c]:copy(xy[i][1]) + y[c]:copy(xy[i][2]) + c = c + 1 + else + break + end + end end + ) + if t % 50 == 0 then + collectgarbage() + g_transform_pool:synchronize() + xlua.progress(t, #train_x) end if c > x:size(1) then break end - if t % 50 == 0 then - collectgarbage() - end end + g_transform_pool:synchronize() xlua.progress(#train_x, #train_x) + torch.setnumthreads(nthread) -- revert end local function get_oracle_data(x, y, instance_loss, k, samples) local index = torch.LongTensor(instance_loss:size(1)) @@ -262,6 +386,7 @@ local function get_oracle_data(x, y, instance_loss, k, samples) end local function remove_small_image(x) + local compression = require 'compression' local new_x = {} for i = 1, #x do local xe, meta, x_s @@ -293,6 +418,8 @@ local function plot(train, valid) {'validation', torch.Tensor(valid), '-'}}) end local function train() + local x = remove_small_image(torch.load(settings.images)) + local train_x, valid_x = split_data(x, math.max(math.floor(settings.validation_rate * #x), 1)) local hist_train = {} local hist_valid = {} local model @@ -301,20 +428,30 @@ local function train() else model = srcnn.create(settings.model, settings.backend, settings.color) end + if model.w2nn_input_size then + if settings.crop_size ~= model.w2nn_input_size then + io.stderr:write(string.format("warning: crop_size is replaced with %d\n", + model.w2nn_input_size)) + settings.crop_size = model.w2nn_input_size + end + end + if model.w2nn_gcn then + settings.gcn = true + else + settings.gcn = false + end dir.makepath(settings.model_dir) local offset = reconstruct.offset_size(model) - local pairwise_func = function(x, is_validation, n) - return transformer(model, x, is_validation, n, offset) - end + transform_pool_init(reconstruct.has_resize(model), offset) + local criterion = create_criterion(model) local eval_metric = w2nn.ClippedMSECriterion(0, 1):cuda() - local x = remove_small_image(torch.load(settings.images)) - local train_x, valid_x = split_data(x, math.max(math.floor(settings.validation_rate * #x), 1)) local adam_config = { xLearningRate = settings.learning_rate, xBatchSize = settings.batch_size, - xLearningRateDecay = settings.learning_rate_decay + xLearningRateDecay = settings.learning_rate_decay, + xInstanceLoss = (settings.oracle_rate > 0) } local ch = nil if settings.color == "y" then @@ -324,7 +461,7 @@ local function train() end local best_score = 1000.0 print("# make validation-set") - local valid_xy = make_validation_set(valid_x, pairwise_func, + local valid_xy = make_validation_set(valid_x, settings.validation_crops, settings.patches) valid_x = nil @@ -358,7 +495,7 @@ local function train() if oracle_n > 0 then local oracle_x, oracle_y = get_oracle_data(x, y, instance_loss, oracle_k, oracle_n) resampling(x:narrow(1, oracle_x:size(1) + 1, x:size(1)-oracle_x:size(1)), - y:narrow(1, oracle_x:size(1) + 1, x:size(1) - oracle_x:size(1)), train_x, pairwise_func) + y:narrow(1, oracle_x:size(1) + 1, x:size(1) - oracle_x:size(1)), train_x) x:narrow(1, 1, oracle_x:size(1)):copy(oracle_x) y:narrow(1, 1, oracle_y:size(1)):copy(oracle_y) @@ -374,7 +511,7 @@ local function train() min = 0, max = 1})) else - resampling(x, y, train_x, pairwise_func) + resampling(x, y, train_x) end else resampling(x, y, train_x, pairwise_func) @@ -395,9 +532,9 @@ local function train() if settings.plot then plot(hist_train, hist_valid) end - if score.MSE < best_score then + if score.loss < best_score then local test_image = image_loader.load_float(settings.test) -- reload - best_score = score.MSE + best_score = score.loss print("* model has updated") if settings.save_history then torch.save(settings.model_file_best, model:clearState(), "ascii") @@ -446,7 +583,7 @@ local function train() end end end - print("Batch-wise PSNR: " .. score.PSNR .. ", loss: " .. score.loss .. ", MSE: " .. score.MSE .. ", Minimum MSE: " .. best_score) + print("Batch-wise PSNR: " .. score.PSNR .. ", loss: " .. score.loss .. ", Minimum loss: " .. best_score .. ", MSE: " .. score.MSE) collectgarbage() end end diff --git a/waifu2x.lua b/waifu2x.lua index 6ffa98d..8e48937 100644 --- a/waifu2x.lua +++ b/waifu2x.lua @@ -267,6 +267,7 @@ local function waifu2x() cmd:option("-tta_level", 8, 'TTA level (2|4|8). A higher value makes better quality output but slow') cmd:option("-force_cudnn", 0, 'use cuDNN backend (0|1)') cmd:option("-q", 0, 'quiet (0|1)') + cmd:option("-gpu", 1, 'Device ID') local opt = cmd:parse(arg) if opt.method:len() > 0 then @@ -292,5 +293,6 @@ local function waifu2x() else convert_frames(opt) end + cutorch.setDevice(opt.gpu) end waifu2x()