1
0
Fork 0
mirror of synced 2024-05-16 10:52:20 +12:00

Merge branch 'dev'

This commit is contained in:
nagadomi 2017-02-11 01:17:49 +09:00
commit d779a9d47a
36 changed files with 1891 additions and 336 deletions

2
.gitignore vendored
View file

@ -11,6 +11,8 @@ models/*
!models/ukbench
!models/photo
!models/upconv_7
!models/upconv_7l
!models/srresnet_12l
!models/vgg_7
models/*/*.png
models/*/*/*.png

View file

@ -89,6 +89,7 @@ See: [Getting started with Torch](http://torch.ch/docs/getting-started.html)
And install luarocks packages.
```
luarocks install graphicsmagick # upgrade
luarocks install threads # upgrade
luarocks install lua-csnappy
luarocks install md5
luarocks install uuid

View file

@ -1,45 +1,79 @@
# Benchmark results
# Benchmarks
Warning: This benchmark results is outdated. I will update soon.
## Photo
## Usage
Note: waifu2x's photo models was trained on the blending dataset of [kou's photo collection](http://photosku.com/photo/category/%E6%92%AE%E5%BD%B1%E8%80%85/kou/) and [ukbench](http://vis.uky.edu/~stewe/ukbench/).
```
th tools/benchmark.lua -dir path/to/dataset_dir -method scale -color y -model1_dir path/to/model_dir
```
Note: PSNR in this benchmark uses a [MATLAB's rgb2ycbcr](https://jp.mathworks.com/help/images/ref/rgb2ycbcr.html?lang=en) compatible function (dynamic range [16 235], not [0 255]) for converting grayscale image. I think it's not correct PSNR. But many paper used this metric.
## Dataset
command:
`th tools/benchmark.lua -dir <dataset_dir> -model1_dir <model_dir> -method scale -filter Catrom -color y -range_bug 1 -tta <0|1> -force_cudnn 1`
photo_test: 300 various photos.
art_test : 90 artworks (PNG only).
### Datasets
## 2x upscaling model
BSD100: https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/segbench/ (100 test images in BSDS300)
Urban100: https://github.com/jbhuang0604/SelfExSR
| Dataset/Model | anime\_style\_art(Y) | anime\_style\_art\_rgb | photo | ukbench|
|---------------|----------------------|------------------------|---------|--------|
| photo\_test | 29.83 | 29.81 |**29.89**| 29.86 |
| art\_test | 36.02 | **36.24**| 34.92 | 34.85 |
### 2x - PSNR
The evaluation metric is PSNR(Y only), higher is better.
| Dataset/Model | Bicubic | vgg\_7/photo | upconv\_7/photo | upconv\_7l/photo | resnet_14l/photo |
|---------------|---------------|---------------|------------------|------------------|--------------------|
| BSD100 | 29.558 | 31.427 | 31.640 | 31.749 | 31.847 |
| Urban100 | 26.852 | 30.057 | 30.477 | 30.759 | 31.016 |
## Denosing level 1 model
### 2x with TTA - PSNR
| Dataset/Model | anime\_style\_art | anime\_style\_art\_rgb | photo |
|--------------------------|-------------------|------------------------|---------|
| photo\_test Quality 80 | 36.07 | **36.20**| 36.01 |
| photo\_test Quality 50,45| 31.72 | 32.01 |**32.31**|
| art\_test Quality 80 | 40.39 | **42.48**| 40.35 |
| art\_test Quality 50,45 | 35.45 | **36.70**| 36.27 |
Note: TTA is an ensemble technique that is supported by waifu2x. TTA method is 8x slower than non TTA method but it improves PSNR (~+0.1 on photo, ~+0.4 on art).
The evaluation metric is PSNR(RGB), higher is better.
| Dataset/Model | Bicubic | vgg\_7/photo | upconv\_7/photo | upconv\_7l/photo | resnet_14l/photo |
|---------------|---------------|---------------|------------------|------------------|--------------------|
| BSD100 | 29.558 | 31.474 | 31.705 | 31.812 | 31.915 |
| Urban100 | 26.852 | 30.140 | 30.599 | 30.868 | 31.162 |
## Denosing level 2 model
### 2x - benchmark elapsed time (sec)
| Dataset/Model | anime\_style\_art | anime\_style\_art\_rgb | photo |
|--------------------------|-------------------|------------------------|---------|
| photo\_test Quality 80 | 34.03 | 34.42 |**36.06**|
| photo\_test Quality 50,45| 31.95 | 32.31 |**32.42**|
| art\_test Quality 80 | 39.20 | **41.12**| 40.48 |
| art\_test Quality 50,45 | 36.14 | **37.78**| 36.55 |
| Dataset/Model | vgg\_7/photo | upconv\_7/photo | upconv\_7l/photo | resnet_14l/photo |
|---------------|---------------|------------------|------------------|--------------------|
| BSD100 | 4.057 | 2.509 | 4.947 | 6.86 |
| Urban100 | 16.349 | 7.083 | 14.178 | 27.87 |
### 2x with TTA - benchmark elapsed time (sec)
| Dataset/Model | vgg\_7/photo | upconv\_7/photo | upconv\_7l/photo | resnet_14l/photo |
|---------------|---------------|------------------|------------------|--------------------|
| BSD100 | 36.611 | 20.219 | 42.486 | 60.38 |
| Urban100 | 132.416 | 65.125 | 129.916 | 255.20 |
## Art
command:
`th tools/benchmark.lua -dir <dataset_dir> -model1_dir <model_dir> -method scale -filter Lanczos -color y -range_bug 1 -tta <0|1> -force_cudnn 1`
### Dataset
art_test: This dataset contains 85 various fan-arts. Sorry, This dataset is private.
### 2x - PSNR
| Dataset/Model | Bicubic | vgg\_7/art | upconv\_7/art | upconv\_7l/art |
|---------------|---------------|-------------|----------------|----------------|
| art_test | 31.022 | 37.495 | 38.330 | 39.140 |
### 2x with TTA - PSNR
| Dataset/Model | Bicubic | vgg\_7/art | upconv\_7/art | upconv\_7l/art |
|---------------|---------------|-------------|----------------|----------------|
| art_test | 31.022 | 37.777 | 38.677 | 39.510 |
### 2x - benchmark elapsed time (sec)
| Dataset/Model | vgg\_7/art | upconv\_7/art | upconv\_7l/art |
|---------------|-------------|----------------|----------------|
| art_test | 20.681 | 7.683 | 17.667 |
### 2x with TTA - benchmark elapsed time (sec)
| Dataset/Model | vgg\_7/art | upconv\_7/art | upconv\_7l/art |
|---------------|-------------|----------------|----------------|
| art_test | 174.674 | 77.716 | 163.932 |
The evaluation metric is PSNR(RGB), higher is better.

34
appendix/benchmark.sh Executable file
View file

@ -0,0 +1,34 @@
#!/bin/sh
set -x
benchmark_photo() {
dir=./benchmarks/${1}/${2}/${3}
mkdir -p ${dir}
th tools/benchmark.lua -dir data/${1} -model1\_dir models/${2}/photo -method scale -filter Catrom -color y -range\_bug 1 -tta ${3} -force_cudnn 1 -output_dir ${dir} -save_info 1 -show_progress 0
}
run_benchmark_photo() {
for tta in 0 1
do
for dataset in bsd100 urban100
do
benchmark_photo ${dataset} vgg_7 ${tta}
benchmark_photo ${dataset} upconv_7 ${tta}
benchmark_photo ${dataset} upconv_7l ${tta}
done
done
}
benchmark_art() {
dir=./benchmarks/${1}/${2}/${3}
mkdir -p ${dir}
th tools/benchmark.lua -dir data/${1} -model1\_dir models/${2}/art -method scale -filter Lanczos -color y -range\_bug 1 -tta ${3} -force_cudnn 1 -output_dir ${dir} -save_info 1 -show_progress 0
}
run_benchmark_art() {
for tta in 0 1
do
benchmark_art art_test vgg_7 ${tta}
benchmark_art art_test upconv_7 ${tta}
benchmark_art art_test upconv_7l ${tta}
done
}
#run_benchmark_photo
run_benchmark_art

View file

@ -0,0 +1,524 @@
name: "resnet_14l"
layer {
name: "input"
type: "Input"
top: "input"
input_param { shape: { dim: 1 dim: 3 dim: 156 dim: 156 } }
}
layer {
name: "Convolution1"
type: "Convolution"
bottom: "input"
top: "Convolution1"
convolution_param {
num_output: 32
bias_term: true
pad: 0
kernel_size: 3
stride: 1
weight_filler {
type: "msra"
}
}
}
layer {
name: "ReLU1"
type: "ReLU"
bottom: "Convolution1"
top: "Convolution1"
relu_param {
negative_slope: 0.1
}
}
layer {
name: "Convolution2"
type: "Convolution"
bottom: "Convolution1"
top: "Convolution2"
convolution_param {
num_output: 64
bias_term: true
pad: 0
kernel_size: 3
stride: 1
weight_filler {
type: "msra"
}
}
}
layer {
name: "ReLU2"
type: "ReLU"
bottom: "Convolution2"
top: "Convolution2"
relu_param {
negative_slope: 0.1
}
}
layer {
name: "Convolution3"
type: "Convolution"
bottom: "Convolution2"
top: "Convolution3"
convolution_param {
num_output: 64
bias_term: true
pad: 0
kernel_size: 3
stride: 1
weight_filler {
type: "msra"
}
}
}
layer {
name: "ReLU3"
type: "ReLU"
bottom: "Convolution3"
top: "Convolution3"
relu_param {
negative_slope: 0.1
}
}
layer {
name: "Convolution4"
type: "Convolution"
bottom: "Convolution1"
top: "Convolution4"
convolution_param {
num_output: 64
bias_term: true
pad: 0
kernel_size: 1
stride: 1
weight_filler {
type: "msra"
}
}
}
layer {
name: "Crop1"
type: "Crop"
bottom: "Convolution4"
bottom: "Convolution3"
top: "Crop1"
crop_param {
axis: 2
offset: 2
offset: 2
}
}
layer {
name: "Eltwise1"
type: "Eltwise"
bottom: "Convolution3"
bottom: "Crop1"
top: "Eltwise1"
eltwise_param {
operation: SUM
}
}
layer {
name: "Convolution5"
type: "Convolution"
bottom: "Eltwise1"
top: "Convolution5"
convolution_param {
num_output: 64
bias_term: true
pad: 0
kernel_size: 3
stride: 1
weight_filler {
type: "msra"
}
}
}
layer {
name: "ReLU4"
type: "ReLU"
bottom: "Convolution5"
top: "Convolution5"
relu_param {
negative_slope: 0.1
}
}
layer {
name: "Convolution6"
type: "Convolution"
bottom: "Convolution5"
top: "Convolution6"
convolution_param {
num_output: 64
bias_term: true
pad: 0
kernel_size: 3
stride: 1
weight_filler {
type: "msra"
}
}
}
layer {
name: "ReLU5"
type: "ReLU"
bottom: "Convolution6"
top: "Convolution6"
relu_param {
negative_slope: 0.1
}
}
layer {
name: "Crop2"
type: "Crop"
bottom: "Eltwise1"
bottom: "Convolution6"
top: "Crop2"
crop_param {
axis: 2
offset: 2
offset: 2
}
}
layer {
name: "Eltwise2"
type: "Eltwise"
bottom: "Convolution6"
bottom: "Crop2"
top: "Eltwise2"
eltwise_param {
operation: SUM
}
}
layer {
name: "Convolution7"
type: "Convolution"
bottom: "Eltwise2"
top: "Convolution7"
convolution_param {
num_output: 128
bias_term: true
pad: 0
kernel_size: 3
stride: 1
weight_filler {
type: "msra"
}
}
}
layer {
name: "ReLU6"
type: "ReLU"
bottom: "Convolution7"
top: "Convolution7"
relu_param {
negative_slope: 0.1
}
}
layer {
name: "Convolution8"
type: "Convolution"
bottom: "Convolution7"
top: "Convolution8"
convolution_param {
num_output: 128
bias_term: true
pad: 0
kernel_size: 3
stride: 1
weight_filler {
type: "msra"
}
}
}
layer {
name: "ReLU7"
type: "ReLU"
bottom: "Convolution8"
top: "Convolution8"
relu_param {
negative_slope: 0.1
}
}
layer {
name: "Convolution9"
type: "Convolution"
bottom: "Eltwise2"
top: "Convolution9"
convolution_param {
num_output: 128
bias_term: true
pad: 0
kernel_size: 1
stride: 1
weight_filler {
type: "msra"
}
}
}
layer {
name: "Crop3"
type: "Crop"
bottom: "Convolution9"
bottom: "Convolution8"
top: "Crop3"
crop_param {
axis: 2
offset: 2
offset: 2
}
}
layer {
name: "Eltwise3"
type: "Eltwise"
bottom: "Convolution8"
bottom: "Crop3"
top: "Eltwise3"
eltwise_param {
operation: SUM
}
}
layer {
name: "Convolution10"
type: "Convolution"
bottom: "Eltwise3"
top: "Convolution10"
convolution_param {
num_output: 128
bias_term: true
pad: 0
kernel_size: 3
stride: 1
weight_filler {
type: "msra"
}
}
}
layer {
name: "ReLU8"
type: "ReLU"
bottom: "Convolution10"
top: "Convolution10"
relu_param {
negative_slope: 0.1
}
}
layer {
name: "Convolution11"
type: "Convolution"
bottom: "Convolution10"
top: "Convolution11"
convolution_param {
num_output: 128
bias_term: true
pad: 0
kernel_size: 3
stride: 1
weight_filler {
type: "msra"
}
}
}
layer {
name: "ReLU9"
type: "ReLU"
bottom: "Convolution11"
top: "Convolution11"
relu_param {
negative_slope: 0.1
}
}
layer {
name: "Crop4"
type: "Crop"
bottom: "Eltwise3"
bottom: "Convolution11"
top: "Crop4"
crop_param {
axis: 2
offset: 2
offset: 2
}
}
layer {
name: "Eltwise4"
type: "Eltwise"
bottom: "Convolution11"
bottom: "Crop4"
top: "Eltwise4"
eltwise_param {
operation: SUM
}
}
layer {
name: "Convolution12"
type: "Convolution"
bottom: "Eltwise4"
top: "Convolution12"
convolution_param {
num_output: 256
bias_term: true
pad: 0
kernel_size: 3
stride: 1
weight_filler {
type: "msra"
}
}
}
layer {
name: "ReLU10"
type: "ReLU"
bottom: "Convolution12"
top: "Convolution12"
relu_param {
negative_slope: 0.1
}
}
layer {
name: "Convolution13"
type: "Convolution"
bottom: "Convolution12"
top: "Convolution13"
convolution_param {
num_output: 256
bias_term: true
pad: 0
kernel_size: 3
stride: 1
weight_filler {
type: "msra"
}
}
}
layer {
name: "ReLU11"
type: "ReLU"
bottom: "Convolution13"
top: "Convolution13"
relu_param {
negative_slope: 0.1
}
}
layer {
name: "Convolution14"
type: "Convolution"
bottom: "Eltwise4"
top: "Convolution14"
convolution_param {
num_output: 256
bias_term: true
pad: 0
kernel_size: 1
stride: 1
weight_filler {
type: "msra"
}
}
}
layer {
name: "Crop5"
type: "Crop"
bottom: "Convolution14"
bottom: "Convolution13"
top: "Crop5"
crop_param {
axis: 2
offset: 2
offset: 2
}
}
layer {
name: "Eltwise5"
type: "Eltwise"
bottom: "Convolution13"
bottom: "Crop5"
top: "Eltwise5"
eltwise_param {
operation: SUM
}
}
layer {
name: "Convolution15"
type: "Convolution"
bottom: "Eltwise5"
top: "Convolution15"
convolution_param {
num_output: 256
bias_term: true
pad: 0
kernel_size: 3
stride: 1
weight_filler {
type: "msra"
}
}
}
layer {
name: "ReLU12"
type: "ReLU"
bottom: "Convolution15"
top: "Convolution15"
relu_param {
negative_slope: 0.1
}
}
layer {
name: "Convolution16"
type: "Convolution"
bottom: "Convolution15"
top: "Convolution16"
convolution_param {
num_output: 256
bias_term: true
pad: 0
kernel_size: 3
stride: 1
weight_filler {
type: "msra"
}
}
}
layer {
name: "ReLU13"
type: "ReLU"
bottom: "Convolution16"
top: "Convolution16"
relu_param {
negative_slope: 0.1
}
}
layer {
name: "Crop6"
type: "Crop"
bottom: "Eltwise5"
bottom: "Convolution16"
top: "Crop6"
crop_param {
axis: 2
offset: 2
offset: 2
}
}
layer {
name: "Eltwise6"
type: "Eltwise"
bottom: "Convolution16"
bottom: "Crop6"
top: "Eltwise6"
eltwise_param {
operation: SUM
}
}
layer {
name: "Deconvolution1"
type: "Deconvolution"
bottom: "Eltwise6"
top: "Deconvolution1"
convolution_param {
num_output: 3
pad: 3
kernel_size: 4
stride: 2
}
}

View file

@ -82,46 +82,50 @@ local function load_images(list)
local skip = false
local alpha_color = torch.random(0, 1)
if meta and meta.alpha then
if settings.use_transparent_png then
im = alpha_util.fill(im, meta.alpha, alpha_color)
else
skip = true
end
end
if skip then
if not skip_notice then
io.stderr:write("skip transparent png (settings.use_transparent_png=0)\n")
skip_notice = true
end
else
if csv_meta and csv_meta.x then
-- method == user
local yy = im
local xx, meta2 = image_loader.load_byte(csv_meta.x)
if meta2 and meta2.alpha then
xx = alpha_util.fill(xx, meta2.alpha, alpha_color)
if im then
if meta and meta.alpha then
if settings.use_transparent_png then
im = alpha_util.fill(im, meta.alpha, alpha_color)
else
skip = true
end
xx, yy = crop_if_large_pair(xx, yy, settings.max_training_image_size)
table.insert(x, {{y = compression.compress(yy), x = compression.compress(xx)},
{data = {filters = filters, has_x = true}}})
else
im = crop_if_large(im, settings.max_training_image_size)
im = iproc.crop_mod4(im)
local scale = 1.0
if settings.random_half_rate > 0.0 then
scale = 2.0
end
if skip then
if not skip_notice then
io.stderr:write("skip transparent png (settings.use_transparent_png=0)\n")
skip_notice = true
end
if im then
else
if csv_meta and csv_meta.x then
-- method == user
local yy = im
local xx, meta2 = image_loader.load_byte(csv_meta.x)
if xx then
if meta2 and meta2.alpha then
xx = alpha_util.fill(xx, meta2.alpha, alpha_color)
end
xx, yy = crop_if_large_pair(xx, yy, settings.max_training_image_size)
table.insert(x, {{y = compression.compress(yy), x = compression.compress(xx)},
{data = {filters = filters, has_x = true}}})
else
io.stderr:write(string.format("\n%s: skip: load error.\n", csv_meta.x))
end
else
im = crop_if_large(im, settings.max_training_image_size)
im = iproc.crop_mod4(im)
local scale = 1.0
if settings.random_half_rate > 0.0 then
scale = 2.0
end
if im:size(2) > (settings.crop_size * scale + MARGIN) and im:size(3) > (settings.crop_size * scale + MARGIN) then
table.insert(x, {compression.compress(im), {data = {filters = filters}}})
else
io.stderr:write(string.format("\n%s: skip: image is too small (%d > size).\n", filename, settings.crop_size * scale + MARGIN))
end
else
io.stderr:write(string.format("\n%s: skip: load error.\n", filename))
end
end
else
io.stderr:write(string.format("\n%s: skip: load error.\n", filename))
end
xlua.progress(i, #csv)
if i % 10 == 0 then

View file

@ -5,12 +5,14 @@ function ClippedMSECriterion:__init(min, max)
self.min = min
self.max = max
self.diff = torch.Tensor()
self.diff_pow2 = torch.Tensor()
end
function ClippedMSECriterion:updateOutput(input, target)
self.diff:resizeAs(input):copy(input)
self.diff:clamp(self.min, self.max)
self.diff:add(-1, target)
self.output = self.diff:pow(2):sum() / input:nElement()
self.diff_pow2:resizeAs(self.diff):copy(self.diff):pow(2)
self.output = self.diff_pow2:sum() / input:nElement()
return self.output
end
function ClippedMSECriterion:updateGradInput(input, target)

13
lib/InplaceClip01.lua Normal file
View file

@ -0,0 +1,13 @@
local Clip01, parent = torch.class("w2nn.InplaceClip01", "nn.Module")
function Clip01:__init()
parent.__init(self)
end
function Clip01:updateOutput(input)
self.output:set(input:clamp(0, 1))
return self.output
end
function Clip01:updateGradInput(input, gradOutput)
self.gradInput:set(gradOutput)
return self.gradInput
end

27
lib/L1Criterion.lua Normal file
View file

@ -0,0 +1,27 @@
-- ref: https://en.wikipedia.org/wiki/L1_loss
local L1Criterion, parent = torch.class('w2nn.L1Criterion','nn.Criterion')
function L1Criterion:__init()
parent.__init(self)
self.diff = torch.Tensor()
self.linear_loss_buff = torch.Tensor()
end
function L1Criterion:updateOutput(input, target)
self.diff:resizeAs(input):copy(input)
if input:dim() == 1 then
self.diff[1] = input[1] - target
else
for i = 1, input:size(1) do
self.diff[i]:add(-1, target[i])
end
end
local linear_targets = self.diff
local linear_loss = self.linear_loss_buff:resizeAs(linear_targets):copy(linear_targets):abs():sum()
self.output = (linear_loss) / input:nElement()
return self.output
end
function L1Criterion:updateGradInput(input, target)
local norm = 1.0 / input:nElement()
self.gradInput:resizeAs(self.diff):copy(self.diff):sign():mul(norm)
return self.gradInput
end

67
lib/SSIMCriterion.lua Normal file
View file

@ -0,0 +1,67 @@
-- SSIM Index, ref: http://www.cns.nyu.edu/~lcv/ssim/ssim_index.m
local SSIMCriterion, parent = torch.class('w2nn.SSIMCriterion','nn.Criterion')
function SSIMCriterion:__init(ch, kernel_size, sigma)
parent.__init(self)
local function gaussian2d(kernel_size, sigma)
sigma = sigma or 1
local kernel = torch.Tensor(kernel_size, kernel_size)
local u = math.floor(kernel_size / 2) + 1
local amp = (1 / math.sqrt(2 * math.pi * sigma^2))
for x = 1, kernel_size do
for y = 1, kernel_size do
kernel[x][y] = amp * math.exp(-((x - u)^2 + (y - u)^2) / (2 * sigma^2))
end
end
kernel:div(kernel:sum())
return kernel
end
ch = ch or 1
kernel_size = kernel_size or 11
sigma = sigma or 1.5
local kernel = gaussian2d(kernel_size, sigma)
if ch > 1 then
local kernel_nd = torch.Tensor(ch, ch, kernel_size, kernel_size)
for i = 1, ch do
for j = 1, ch do
kernel_nd[i][j]:copy(kernel)
if i ~= j then
kernel_nd[i][j]:zero()
end
end
end
kernel = kernel_nd
end
self.c1 = 0.01^2
self.c2 = 0.03^2
self.ch = ch
self.conv = nn.SpatialConvolution(ch, ch, kernel_size, kernel_size, 1, 1, 0, 0):noBias()
self.conv.weight:copy(kernel)
self.mu1 = torch.Tensor()
self.mu2 = torch.Tensor()
self.mu1_sq = torch.Tensor()
self.mu2_sq = torch.Tensor()
self.mu1_mu2 = torch.Tensor()
self.sigma1_sq = torch.Tensor()
self.sigma2_sq = torch.Tensor()
self.sigma12 = torch.Tensor()
self.ssim_map = torch.Tensor()
end
function SSIMCriterion:updateOutput(input, target)-- dynamic range: 0-1
assert(input:nElement() == target:nElement())
local valid = self.conv:forward(input)
self.mu1:resizeAs(valid):copy(valid)
self.mu2:resizeAs(valid):copy(self.conv:forward(target))
self.mu1_sq:resizeAs(self.mu1):copy(self.mu1):cmul(self.mu1)
self.mu2_sq:resizeAs(self.mu2):copy(self.mu2):cmul(self.mu2)
self.mu1_mu2:resizeAs(self.mu1):copy(self.mu1):cmul(self.mu2)
self.sigma1_sq:resizeAs(valid):copy(self.conv:forward(torch.cmul(input, input)):add(-1, self.mu1_sq))
self.sigma2_sq:resizeAs(valid):copy(self.conv:forward(torch.cmul(target, target)):add(-1, self.mu2_sq))
self.sigma12:resizeAs(valid):copy(self.conv:forward(torch.cmul(input, target)):add(-1, self.mu1_mu2))
local ssim = self.mu1_mu2:mul(2):add(self.c1):cmul(self.sigma12:mul(2):add(self.c2)):
cdiv(self.mu1_sq:add(self.mu2_sq):add(self.c1):cmul(self.sigma1_sq:add(self.sigma2_sq):add(self.c2))):mean()
return ssim
end
function SSIMCriterion:updateGradInput(input, target)
error("not implemented")
end

View file

@ -40,8 +40,7 @@ function alpha_util.make_border(rgb, alpha, offset)
collectgarbage()
end
end
rgb[torch.gt(rgb, 1.0)] = 1.0
rgb[torch.lt(rgb, 0.0)] = 0.0
rgb:clamp(0.0, 1.0)
return rgb
end

View file

@ -1,7 +1,8 @@
require 'image'
require 'pl'
require 'cunn'
local iproc = require 'iproc'
local gm = require 'graphicsmagick'
local gm = {}
gm.Image = require 'graphicsmagick.Image'
local data_augmentation = {}
local function pcacov(x)
@ -25,8 +26,7 @@ function data_augmentation.color_noise(src, p, factor)
pca_space[i]:mul(color_scale[i])
end
local dest = torch.mm(pca_space:t(), cv:t()):t():contiguous():resizeAs(src)
dest[torch.lt(dest, 0.0)] = 0.0
dest[torch.gt(dest, 1.0)] = 1.0
dest:clamp(0.0, 1.0)
if conversion then
dest = iproc.float2byte(dest)
@ -70,6 +70,75 @@ function data_augmentation.unsharp_mask(src, p)
return src
end
end
function data_augmentation.blur(src, p, size, sigma_min, sigma_max)
size = size or "3"
filters = utils.split(size, ",")
for i = 1, #filters do
local s = tonumber(filters[i])
filters[i] = s
end
if torch.uniform() < p then
local src, conversion = iproc.byte2float(src)
local kernel_size = filters[torch.random(1, #filters)]
local sigma
if sigma_min == sigma_max then
sigma = sigma_min
else
sigma = torch.uniform(sigma_min, sigma_max)
end
local kernel = iproc.gaussian2d(kernel_size, sigma)
local dest = image.convolve(src, kernel, 'same')
if conversion then
dest = iproc.float2byte(dest)
end
return dest
else
return src
end
end
function data_augmentation.pairwise_scale(x, y, p, scale_min, scale_max)
if torch.uniform() < p then
assert(x:size(2) == y:size(2) and x:size(3) == y:size(3))
local scale = torch.uniform(scale_min, scale_max)
local h = math.floor(x:size(2) * scale)
local w = math.floor(x:size(3) * scale)
x = iproc.scale(x, w, h, "Triangle")
y = iproc.scale(y, w, h, "Triangle")
return x, y
else
return x, y
end
end
function data_augmentation.pairwise_rotate(x, y, p, r_min, r_max)
if torch.uniform() < p then
assert(x:size(2) == y:size(2) and x:size(3) == y:size(3))
local r = torch.uniform(r_min, r_max) / 360.0 * math.pi
x = iproc.rotate(x, r)
y = iproc.rotate(y, r)
return x, y
else
return x, y
end
end
function data_augmentation.pairwise_negate(x, y, p)
if torch.uniform() < p then
assert(x:size(2) == y:size(2) and x:size(3) == y:size(3))
x = iproc.negate(x)
y = iproc.negate(y)
return x, y
else
return x, y
end
end
function data_augmentation.pairwise_negate_x(x, y, p)
if torch.uniform() < p then
assert(x:size(2) == y:size(2) and x:size(3) == y:size(3))
x = iproc.negate(x)
return x, y
else
return x, y
end
end
function data_augmentation.shift_1px(src)
-- reducing the even/odd issue in nearest neighbor scaler.
local direction = torch.random(1, 4)
@ -107,11 +176,11 @@ function data_augmentation.flip(src)
src = src:transpose(2, 3):contiguous()
end
if flip == 1 then
dest = image.hflip(src)
dest = iproc.hflip(src)
elseif flip == 2 then
dest = image.vflip(src)
dest = iproc.vflip(src)
elseif flip == 3 then
dest = image.hflip(image.vflip(src))
dest = iproc.hflip(iproc.vflip(src))
elseif flip == 4 then
dest = src
end
@ -120,4 +189,20 @@ function data_augmentation.flip(src)
end
return dest
end
local function test_blur()
torch.setdefaulttensortype("torch.FloatTensor")
local image =require 'image'
local src = image.lena()
image.display({image = src, min=0, max=1})
local dest = data_augmentation.blur(src, 1.0, "3,5", 0.5, 0.6)
image.display({image = dest, min=0, max=1})
dest = data_augmentation.blur(src, 1.0, "3", 1.0, 1.0)
image.display({image = dest, min=0, max=1})
dest = data_augmentation.blur(src, 1.0, "5", 0.75, 0.75)
image.display({image = dest, min=0, max=1})
end
--test_blur()
return data_augmentation

View file

@ -22,8 +22,7 @@ function image_loader.encode_png(rgb, options)
else
rgb = rgb:clone():add(clip_eps8)
end
rgb[torch.lt(rgb, 0.0)] = 0.0
rgb[torch.gt(rgb, 1.0)] = 1.0
rgb:clamp(0.0, 1.0)
rgb = rgb:mul(255):floor():div(255)
else
if options.inplace then
@ -31,8 +30,7 @@ function image_loader.encode_png(rgb, options)
else
rgb = rgb:clone():add(clip_eps16)
end
rgb[torch.lt(rgb, 0.0)] = 0.0
rgb[torch.gt(rgb, 1.0)] = 1.0
rgb:clamp(0.0, 1.0)
rgb = rgb:mul(65535):floor():div(65535)
end
local im

View file

@ -1,6 +1,7 @@
local gm = require 'graphicsmagick'
local gm = {}
gm.Image = require 'graphicsmagick.Image'
require 'dok'
local image = require 'image'
local iproc = {}
local clip_eps8 = (1.0 / 255.0) * 0.5 - (1.0e-7 * (1.0 / 255.0) * 0.5)
@ -42,8 +43,7 @@ function iproc.float2byte(src)
if src:type() == "torch.FloatTensor" then
conversion = true
dest = (src + clip_eps8):mul(255.0)
dest[torch.lt(dest, 0.0)] = 0
dest[torch.gt(dest, 255.0)] = 255.0
dest:clamp(0, 255.0)
dest = dest:byte()
end
return dest, conversion
@ -80,6 +80,7 @@ function iproc.scale_with_gamma22(src, width, height, filter, blur)
return dest
end
function iproc.padding(img, w1, w2, h1, h2)
image = image or require 'image'
local dst_height = img:size(2) + h1 + h2
local dst_width = img:size(3) + w1 + w2
local flow = torch.Tensor(2, dst_height, dst_width)
@ -90,6 +91,7 @@ function iproc.padding(img, w1, w2, h1, h2)
return image.warp(img, flow, "simple", false, "clamp")
end
function iproc.zero_padding(img, w1, w2, h1, h2)
image = image or require 'image'
local dst_height = img:size(2) + h1 + h2
local dst_width = img:size(3) + w1 + w2
local flow = torch.Tensor(2, dst_height, dst_width)
@ -115,8 +117,7 @@ function iproc.white_noise(src, std, rgb_weights, gamma)
local dest
if gamma ~= 0 then
dest = src:clone():pow(gamma):add(noise)
dest[torch.lt(dest, 0.0)] = 0.0
dest[torch.gt(dest, 1.0)] = 1.0
dest:clamp(0.0, 1.0)
dest:pow(1.0 / gamma)
else
dest = src + noise
@ -126,6 +127,101 @@ function iproc.white_noise(src, std, rgb_weights, gamma)
end
return dest
end
function iproc.hflip(src)
local t
if src:type() == "torch.ByteTensor" then
t = "byte"
else
t = "float"
end
if src:size(1) == 3 then
color = "RGB"
else
color = "I"
end
local im = gm.Image(src, color, "DHW")
return im:flop():toTensor(t, color, "DHW")
end
function iproc.vflip(src)
local t
if src:type() == "torch.ByteTensor" then
t = "byte"
else
t = "float"
end
if src:size(1) == 3 then
color = "RGB"
else
color = "I"
end
local im = gm.Image(src, color, "DHW")
return im:flip():toTensor(t, color, "DHW")
end
local function rotate_with_warp(src, dst, theta, mode)
local height
local width
if src:dim() == 2 then
height = src:size(1)
width = src:size(2)
elseif src:dim() == 3 then
height = src:size(2)
width = src:size(3)
else
dok.error('src image must be 2D or 3D', 'image.rotate')
end
local flow = torch.Tensor(2, height, width)
local kernel = torch.Tensor({{math.cos(-theta), -math.sin(-theta)},
{math.sin(-theta), math.cos(-theta)}})
flow[1] = torch.ger(torch.linspace(0, 1, height), torch.ones(width))
flow[1]:mul(-(height -1)):add(math.floor(height / 2 + 0.5))
flow[2] = torch.ger(torch.ones(height), torch.linspace(0, 1, width))
flow[2]:mul(-(width -1)):add(math.floor(width / 2 + 0.5))
flow:add(-1, torch.mm(kernel, flow:view(2, height * width)))
dst:resizeAs(src)
return image.warp(dst, src, flow, mode, true, 'clamp')
end
function iproc.rotate(src, theta)
local conversion
src, conversion = iproc.byte2float(src)
local dest = torch.Tensor():typeAs(src):resizeAs(src)
rotate_with_warp(src, dest, theta, 'bilinear')
dest:clamp(0, 1)
if conversion then
dest = iproc.float2byte(dest)
end
return dest
end
function iproc.negate(src)
if src:type() == "torch.ByteTensor" then
return -src + 255
else
return -src + 1
end
end
function iproc.gaussian2d(kernel_size, sigma)
sigma = sigma or 1
local kernel = torch.Tensor(kernel_size, kernel_size)
local u = math.floor(kernel_size / 2) + 1
local amp = (1 / math.sqrt(2 * math.pi * sigma^2))
for x = 1, kernel_size do
for y = 1, kernel_size do
kernel[x][y] = amp * math.exp(-((x - u)^2 + (y - u)^2) / (2 * sigma^2))
end
end
kernel:div(kernel:sum())
return kernel
end
function iproc.rgb2y(src)
local conversion
src, conversion = iproc.byte2float(src)
local dest = torch.FloatTensor(1, src:size(2), src:size(3)):zero()
dest:add(0.299, src[1]):add(0.587, src[2]):add(0.114, src[3])
if conversion then
dest = iproc.float2byte(dest)
end
return dest
end
local function test_conversion()
local a = torch.linspace(0, 255, 256):float():div(255.0)
@ -144,6 +240,46 @@ local function test_conversion()
print(b)
assert(b:float():sum() == 254.0 * 3)
end
local function test_flip()
require 'sys'
require 'torch'
torch.setdefaulttensortype("torch.FloatTensor")
image = require 'image'
local src = image.lena()
local src_byte = src:clone():mul(255):byte()
print(src:size())
print((image.hflip(src) - iproc.hflip(src)):sum())
print((image.hflip(src_byte) - iproc.hflip(src_byte)):sum())
print((image.vflip(src) - iproc.vflip(src)):sum())
print((image.vflip(src_byte) - iproc.vflip(src_byte)):sum())
end
local function test_gaussian2d()
local t = {3, 5, 7}
for i = 1, #t do
local kp = iproc.gaussian2d(t[i], 0.5)
print(kp)
end
end
local function test_conv()
local image = require 'image'
local src = image.lena()
local kernel = torch.Tensor(3, 3):fill(1)
kernel:div(kernel:sum())
--local blur = image.convolve(iproc.padding(src, 1, 1, 1, 1), kernel, 'valid')
local blur = image.convolve(src, kernel, 'same')
print(src:size(), blur:size())
local diff = (blur - src):abs()
image.save("diff.png", diff)
image.display({image = blur, min=0, max=1})
image.display({image = diff, min=0, max=1})
end
--test_conversion()
--test_flip()
--test_gaussian2d()
--test_conv()
return iproc

View file

@ -45,12 +45,17 @@ local function minibatch_adam(model, criterion, eval_metric,
local output = model:forward(inputs)
local f = criterion:forward(output, targets)
local se = 0
for i = 1, batch_size do
local el = eval_metric:forward(output[i], targets[i])
se = se + el
instance_loss[shuffle[t + i - 1]] = el
if config.xInstanceLoss then
for i = 1, batch_size do
local el = eval_metric:forward(output[i], targets[i])
se = se + el
instance_loss[shuffle[t + i - 1]] = el
end
se = (se / batch_size)
else
se = eval_metric:forward(output, targets)
end
sum_eval = sum_eval + (se / batch_size)
sum_eval = sum_eval + se
sum_loss = sum_loss + f
count_loss = count_loss + 1
model:backward(inputs, criterion:backward(output, targets))

View file

@ -1,5 +1,6 @@
local pairwise_utils = require 'pairwise_transform_utils'
local gm = require 'graphicsmagick'
local gm = {}
gm.Image = require 'graphicsmagick.Image'
local iproc = require 'iproc'
local pairwise_transform = {}
@ -42,8 +43,8 @@ function pairwise_transform.jpeg_(src, quality, size, offset, n, options)
yc = iproc.byte2float(yc)
if options.rgb then
else
yc = image.rgb2yuv(yc)[1]:reshape(1, yc:size(2), yc:size(3))
xc = image.rgb2yuv(xc)[1]:reshape(1, xc:size(2), xc:size(3))
yc = iproc.rgb2y(yc)
xc = iproc.rgb2y(xc)
end
if torch.uniform() < options.nr_rate then
-- reducing noise

View file

@ -1,6 +1,7 @@
local pairwise_utils = require 'pairwise_transform_utils'
local iproc = require 'iproc'
local gm = require 'graphicsmagick'
local gm = {}
gm.Image = require 'graphicsmagick.Image'
local pairwise_transform = {}
local function add_jpeg_noise_(x, quality, options)
@ -117,8 +118,8 @@ function pairwise_transform.jpeg_scale(src, scale, style, noise_level, size, off
yc = iproc.byte2float(yc)
if options.rgb then
else
yc = image.rgb2yuv(yc)[1]:reshape(1, yc:size(2), yc:size(3))
xc = image.rgb2yuv(xc)[1]:reshape(1, xc:size(2), xc:size(3))
yc = iproc.rgb2y(yc)
xc = iproc.rgb2y(xc)
end
table.insert(batch, {xc, iproc.crop(yc, offset, offset, size - offset, size - offset)})
end

View file

@ -1,6 +1,7 @@
local pairwise_utils = require 'pairwise_transform_utils'
local iproc = require 'iproc'
local gm = require 'graphicsmagick'
local gm = {}
gm.Image = require 'graphicsmagick.Image'
local pairwise_transform = {}
function pairwise_transform.scale(src, scale, size, offset, n, options)
@ -50,8 +51,8 @@ function pairwise_transform.scale(src, scale, size, offset, n, options)
yc = iproc.byte2float(yc)
if options.rgb then
else
yc = image.rgb2yuv(yc)[1]:reshape(1, yc:size(2), yc:size(3))
xc = image.rgb2yuv(xc)[1]:reshape(1, xc:size(2), xc:size(3))
yc = iproc.rgb2y(yc)
xc = iproc.rgb2y(xc)
end
table.insert(batch, {xc, iproc.crop(yc, offset, offset, size - offset, size - offset)})
end

View file

@ -1,43 +1,30 @@
local pairwise_utils = require 'pairwise_transform_utils'
local iproc = require 'iproc'
local gm = require 'graphicsmagick'
local gm = {}
gm.Image = require 'graphicsmagick.Image'
local pairwise_transform = {}
local function crop_if_large(x, y, scale_y, max_size, mod)
local tries = 4
if y:size(2) > max_size and y:size(3) > max_size then
assert(max_size % 4 == 0)
local rect_x, rect_y
for i = 1, tries do
local yi = torch.random(0, y:size(2) - max_size)
local xi = torch.random(0, y:size(3) - max_size)
if mod then
yi = yi - (yi % mod)
xi = xi - (xi % mod)
end
rect_y = iproc.crop(y, xi, yi, xi + max_size, yi + max_size)
rect_x = iproc.crop(x, xi / scale_y, yi / scale_y, xi / scale_y + max_size / scale_y, yi / scale_y + max_size / scale_y)
-- ignore simple background
if rect_y:float():std() >= 0 then
break
end
end
return rect_x, rect_y
else
return x, y
end
end
function pairwise_transform.user(x, y, size, offset, n, options)
assert(x:size(1) == y:size(1))
local scale_y = y:size(2) / x:size(2)
assert(x:size(3) == y:size(3) / scale_y)
x, y = crop_if_large(x, y, scale_y, options.max_size, scale_y)
x, y = pairwise_utils.preprocess_user(x, y, scale_y, size, options)
assert(x:size(3) == y:size(3) / scale_y and x:size(2) == y:size(2) / scale_y)
local batch = {}
local lowres_y = pairwise_utils.low_resolution(y)
local xs, ys, ls = pairwise_utils.flip_augmentation(x, y, lowres_y)
local lowres_y = nil
local xs ={x}
local ys = {y}
local ls = {}
if options.active_cropping_rate > 0 then
lowres_y = pairwise_utils.low_resolution(y)
end
if options.pairwise_flip then
xs, ys, ls = pairwise_utils.flip_augmentation(x, y, lowres_y)
end
assert(#xs == #ys)
for i = 1, n do
local t = (i % #xs) + 1
local xc, yc = pairwise_utils.active_cropping(xs[t], ys[t], ls[t], size, scale_y,
@ -47,8 +34,17 @@ function pairwise_transform.user(x, y, size, offset, n, options)
yc = iproc.byte2float(yc)
if options.rgb then
else
yc = image.rgb2yuv(yc)[1]:reshape(1, yc:size(2), yc:size(3))
xc = image.rgb2yuv(xc)[1]:reshape(1, xc:size(2), xc:size(3))
yc = iproc.rgb2y(yc)
xc = iproc.rgb2y(xc)
end
if options.gcn then
local mean = xc:mean()
local stdv = xc:std()
if stdv > 0 then
xc:add(-mean):div(stdv)
else
xc:add(-mean)
end
end
table.insert(batch, {xc, iproc.crop(yc, offset, offset, size - offset, size - offset)})
end

View file

@ -1,7 +1,7 @@
require 'image'
require 'cunn'
local iproc = require 'iproc'
local gm = require 'graphicsmagick'
local gm = {}
gm.Image = require 'graphicsmagick.Image'
local data_augmentation = require 'data_augmentation'
local pairwise_transform_utils = {}
@ -36,6 +36,30 @@ function pairwise_transform_utils.crop_if_large(src, max_size, mod)
return src
end
end
function pairwise_transform_utils.crop_if_large_pair(x, y, scale_y, max_size, mod)
local tries = 4
if y:size(2) > max_size and y:size(3) > max_size then
assert(max_size % 4 == 0)
local rect_x, rect_y
for i = 1, tries do
local yi = torch.random(0, y:size(2) - max_size)
local xi = torch.random(0, y:size(3) - max_size)
if mod then
yi = yi - (yi % mod)
xi = xi - (xi % mod)
end
rect_y = iproc.crop(y, xi, yi, xi + max_size, yi + max_size)
rect_x = iproc.crop(x, xi / scale_y, yi / scale_y, xi / scale_y + max_size / scale_y, yi / scale_y + max_size / scale_y)
-- ignore simple background
if rect_y:float():std() >= 0 then
break
end
end
return rect_x, rect_y
else
return x, y
end
end
function pairwise_transform_utils.preprocess(src, crop_size, options)
local dest = src
local box_only = false
@ -47,7 +71,6 @@ function pairwise_transform_utils.preprocess(src, crop_size, options)
if box_only then
local mod = 2 -- assert pos % 2 == 0
dest = pairwise_transform_utils.crop_if_large(dest, math.max(crop_size * 2, options.max_size), mod)
dest = data_augmentation.flip(dest)
dest = data_augmentation.color_noise(dest, options.random_color_noise_rate)
dest = data_augmentation.overlay(dest, options.random_overlay_rate)
dest = data_augmentation.unsharp_mask(dest, options.random_unsharp_mask_rate)
@ -55,7 +78,10 @@ function pairwise_transform_utils.preprocess(src, crop_size, options)
else
dest = pairwise_transform_utils.random_half(dest, options.random_half_rate, options.downsampling_filters)
dest = pairwise_transform_utils.crop_if_large(dest, math.max(crop_size * 2, options.max_size))
dest = data_augmentation.flip(dest)
dest = data_augmentation.blur(dest, options.random_blur_rate,
options.random_blur_size,
options.random_blur_sigma_min,
options.random_blur_sigma_max)
dest = data_augmentation.color_noise(dest, options.random_color_noise_rate)
dest = data_augmentation.overlay(dest, options.random_overlay_rate)
dest = data_augmentation.unsharp_mask(dest, options.random_unsharp_mask_rate)
@ -63,6 +89,33 @@ function pairwise_transform_utils.preprocess(src, crop_size, options)
end
return dest
end
function pairwise_transform_utils.preprocess_user(x, y, scale_y, size, options)
x, y = pairwise_transform_utils.crop_if_large_pair(x, y, scale_y, options.max_size, scale_y)
x, y = data_augmentation.pairwise_rotate(x, y,
options.random_pairwise_rotate_rate,
options.random_pairwise_rotate_min,
options.random_pairwise_rotate_max)
local scale_min = math.max(options.random_pairwise_scale_min, size / (1 + math.min(x:size(2), x:size(3))))
local scale_max = math.max(scale_min, options.random_pairwise_scale_max)
x, y = data_augmentation.pairwise_scale(x, y,
options.random_pairwise_scale_rate,
scale_min,
scale_max)
x, y = data_augmentation.pairwise_negate(x, y, options.random_pairwise_negate_rate)
x, y = data_augmentation.pairwise_negate_x(x, y, options.random_pairwise_negate_x_rate)
x = iproc.crop_mod4(x)
y = iproc.crop_mod4(y)
if options.pairwise_y_binary then
y[torch.lt(y, 128)] = 0
y[torch.gt(y, 0)] = 255
end
return x, y
end
function pairwise_transform_utils.active_cropping(x, y, lowres_y, size, scale, p, tries)
assert("x:size == y:size", x:size(2) * scale == y:size(2) and x:size(3) * scale == y:size(3))
assert("crop_size % scale == 0", size % scale == 0)
@ -111,7 +164,7 @@ function pairwise_transform_utils.flip_augmentation(x, y, lowres_y, x_noise)
for j = 1, 2 do
-- TTA
local xi, yi, ri
local xi, yi, ri, ni
if j == 1 then
xi = x
ni = x_noise
@ -123,42 +176,55 @@ function pairwise_transform_utils.flip_augmentation(x, y, lowres_y, x_noise)
ni = x_noise:transpose(2, 3):contiguous()
end
yi = y:transpose(2, 3):contiguous()
ri = lowres_y:transpose(2, 3):contiguous()
if lowres_y then
ri = lowres_y:transpose(2, 3):contiguous()
end
end
local xv = image.vflip(xi)
local xv = iproc.vflip(xi)
local nv
if x_noise then
nv = image.vflip(ni)
nv = iproc.vflip(ni)
end
local yv = iproc.vflip(yi)
local rv
if ri then
rv = iproc.vflip(ri)
end
local yv = image.vflip(yi)
local rv = image.vflip(ri)
table.insert(xs, xi)
if ni then
table.insert(ns, ni)
end
table.insert(ys, yi)
table.insert(ls, ri)
if ri then
table.insert(ls, ri)
end
table.insert(xs, xv)
if nv then
table.insert(ns, nv)
end
table.insert(ys, yv)
table.insert(ls, rv)
if rv then
table.insert(ls, rv)
end
table.insert(xs, image.hflip(xi))
table.insert(xs, iproc.hflip(xi))
if ni then
table.insert(ns, image.hflip(ni))
table.insert(ns, iproc.hflip(ni))
end
table.insert(ys, iproc.hflip(yi))
if ri then
table.insert(ls, iproc.hflip(ri))
end
table.insert(ys, image.hflip(yi))
table.insert(ls, image.hflip(ri))
table.insert(xs, image.hflip(xv))
table.insert(xs, iproc.hflip(xv))
if nv then
table.insert(ns, image.hflip(nv))
table.insert(ns, iproc.hflip(nv))
end
table.insert(ys, iproc.hflip(yv))
if rv then
table.insert(ls, iproc.hflip(rv))
end
table.insert(ys, image.hflip(yv))
table.insert(ls, image.hflip(rv))
end
return xs, ys, ls, ns
end
@ -171,6 +237,9 @@ end
local g_lowres_model = nil
local g_lowres_gpu = nil
function pairwise_transform_utils.low_resolution(src)
--[[
-- I am not sure that the following process is thraed-safe
g_lowres_model = g_lowres_model or lowres_model()
if g_lowres_gpu == nil then
--benchmark
@ -203,6 +272,11 @@ function pairwise_transform_utils.low_resolution(src)
size(src:size(3), src:size(2), "Box"):
toTensor("byte", "RGB", "DHW")
end
--]]
return gm.Image(src, "RGB", "DHW"):
size(src:size(3) * 0.5, src:size(2) * 0.5, "Box"):
size(src:size(3), src:size(2), "Box"):
toTensor("byte", "RGB", "DHW")
end
return pairwise_transform_utils

View file

@ -40,6 +40,15 @@ local function reconstruct_nn(model, x, inner_scale, offset, block_size, batch_s
break
end
input[j+1]:copy(x[input_indexes[i + j]])
if model.w2nn_gcn then
local mean = input[j + 1]:mean()
local stdv = input[j + 1]:std()
if stdv > 0 then
input[j + 1]:add(-mean):div(stdv)
else
input[j + 1]:add(-mean)
end
end
c = c + 1
end
input_cuda:copy(input)
@ -80,7 +89,12 @@ local function padding_params(x, model, block_size)
p.x_w = x:size(3)
p.x_h = x:size(2)
p.inner_scale = reconstruct.inner_scale(model)
local input_offset = math.ceil(offset / p.inner_scale)
local input_offset
if model.w2nn_input_offset then
input_offset = model.w2nn_input_offset
else
input_offset = math.ceil(offset / p.inner_scale)
end
local input_block_size = block_size
local process_size = input_block_size - input_offset * 2
local h_blocks = math.floor(p.x_h / process_size) +
@ -172,6 +186,9 @@ function reconstruct.scale_rgb(model, scale, x, offset, block_size, batch_size)
return output
end
function reconstruct.image(model, x, block_size)
if model.w2nn_input_size then
block_size = model.w2nn_input_size
end
local i2rgb = false
if x:size(1) == 1 then
local new_x = torch.Tensor(3, x:size(2), x:size(3))
@ -194,6 +211,9 @@ function reconstruct.image(model, x, block_size)
return x
end
function reconstruct.scale(model, scale, x, block_size)
if model.w2nn_input_size then
block_size = model.w2nn_input_size
end
local i2rgb = false
if x:size(1) == 1 then
local new_x = torch.Tensor(3, x:size(2), x:size(3))
@ -287,6 +307,9 @@ local function tta(f, n, model, x, block_size)
return average:div(#augments)
end
function reconstruct.image_tta(model, n, x, block_size)
if model.w2nn_input_size then
block_size = model.w2nn_input_size
end
if reconstruct.is_rgb(model) then
return tta(reconstruct.image_rgb, n, model, x, block_size)
else
@ -294,6 +317,9 @@ function reconstruct.image_tta(model, n, x, block_size)
end
end
function reconstruct.scale_tta(model, n, scale, x, block_size)
if model.w2nn_input_size then
block_size = model.w2nn_input_size
end
if reconstruct.is_rgb(model) then
local f = function (model, x, offset, block_size)
return reconstruct.scale_rgb(model, scale, x, offset, block_size)

View file

@ -1,6 +1,7 @@
require 'xlua'
require 'pl'
require 'trepl'
require 'cutorch'
-- global settings
@ -18,7 +19,7 @@ cmd:text()
cmd:text("waifu2x-training")
cmd:text("Options:")
cmd:option("-gpu", -1, 'GPU Device ID')
cmd:option("-seed", 11, 'RNG seed')
cmd:option("-seed", 11, 'RNG seed (note: it only able to reproduce the training results with `-thread 1`)')
cmd:option("-data_dir", "./data", 'path to data directory')
cmd:option("-backend", "cunn", '(cunn|cudnn)')
cmd:option("-test", "images/miku_small.png", 'path to test image')
@ -32,6 +33,20 @@ cmd:option("-random_color_noise_rate", 0.0, 'data augmentation using color noise
cmd:option("-random_overlay_rate", 0.0, 'data augmentation using flipped image overlay (0.0-1.0)')
cmd:option("-random_half_rate", 0.0, 'data augmentation using half resolution image (0.0-1.0)')
cmd:option("-random_unsharp_mask_rate", 0.0, 'data augmentation using unsharp mask (0.0-1.0)')
cmd:option("-random_blur_rate", 0.0, 'data augmentation using gaussian blur (0.0-1.0)')
cmd:option("-random_blur_size", "3,5", 'filter size for random gaussian blur (comma separated)')
cmd:option("-random_blur_sigma_min", 0.5, 'min sigma for random gaussian blur')
cmd:option("-random_blur_sigma_max", 1.0, 'max sigma for random gaussian blur')
cmd:option("-random_pairwise_scale_rate", 0.0, 'data augmentation using pairwise resize for user method')
cmd:option("-random_pairwise_scale_min", 0.85, 'min scale factor for random pairwise scale')
cmd:option("-random_pairwise_scale_max", 1.176, 'max scale factor for random pairwise scale')
cmd:option("-random_pairwise_rotate_rate", 0.0, 'data augmentation using pairwise resize for user method')
cmd:option("-random_pairwise_rotate_min", -6, 'min rotate angle for random pairwise rotate')
cmd:option("-random_pairwise_rotate_max", 6, 'max rotate angle for random pairwise rotate')
cmd:option("-random_pairwise_negate_rate", 0.0, 'data augmentation using nagate image for user method')
cmd:option("-random_pairwise_negate_x_rate", 0.0, 'data augmentation using nagate image only x side for user method')
cmd:option("-pairwise_y_binary", 0, 'binarize y after data augmentation(0|1)')
cmd:option("-pairwise_flip", 1, 'use flip(0|1)')
cmd:option("-scale", 2.0, 'scale factor (2)')
cmd:option("-learning_rate", 0.00025, 'learning rate for adam')
cmd:option("-crop_size", 48, 'crop size')
@ -59,6 +74,8 @@ cmd:option("-oracle_drop_rate", 0.5, '')
cmd:option("-learning_rate_decay", 3.0e-7, 'learning rate decay (learning_rate * 1/(1+num_of_data*patches*epoch))')
cmd:option("-resume", "", 'resume model file')
cmd:option("-name", "user", 'model name for user method')
cmd:option("-gpu", 1, 'Device ID')
cmd:option("-loss", "huber", 'loss function (huber|l1|mse)')
local function to_bool(settings, name)
if settings[name] == 1 then
@ -75,6 +92,8 @@ end
to_bool(settings, "plot")
to_bool(settings, "save_history")
to_bool(settings, "use_transparent_png")
to_bool(settings, "pairwise_y_binary")
to_bool(settings, "pairwise_flip")
if settings.plot then
require 'gnuplot'
@ -148,4 +167,6 @@ end
settings.images = string.format("%s/images.t7", settings.data_dir)
settings.image_list = string.format("%s/image_list.txt", settings.data_dir)
cutorch.setDevice(opt.gpu)
return settings

View file

@ -136,6 +136,24 @@ local function SpatialFullConvolution(backend, nInputPlane, nOutputPlane, kW, kH
error("unsupported backend:" .. backend)
end
end
local function ReLU(backend)
if backend == "cunn" then
return nn.ReLU(true)
elseif backend == "cudnn" then
return cudnn.ReLU(true)
else
error("unsupported backend:" .. backend)
end
end
local function SpatialMaxPooling(backend, kW, kH, dW, dH, padW, padH)
if backend == "cunn" then
return nn.SpatialMaxPooling(kW, kH, dW, dH, padW, padH)
elseif backend == "cudnn" then
return cudnn.SpatialMaxPooling(kW, kH, dW, dH, padW, padH)
else
error("unsupported backend:" .. backend)
end
end
-- VGG style net(7 layers)
function srcnn.vgg_7(backend, ch)
@ -153,6 +171,7 @@ function srcnn.vgg_7(backend, ch)
model:add(SpatialConvolution(backend, 128, 128, 3, 3, 1, 1, 0, 0))
model:add(nn.LeakyReLU(0.1, true))
model:add(SpatialConvolution(backend, 128, ch, 3, 3, 1, 1, 0, 0))
model:add(w2nn.InplaceClip01())
model:add(nn.View(-1):setNumInputDims(3))
model.w2nn_arch_name = "vgg_7"
@ -190,6 +209,7 @@ function srcnn.vgg_12(backend, ch)
model:add(SpatialConvolution(backend, 128, 128, 3, 3, 1, 1, 0, 0))
model:add(nn.LeakyReLU(0.1, true))
model:add(SpatialConvolution(backend, 128, ch, 3, 3, 1, 1, 0, 0))
model:add(w2nn.InplaceClip01())
model:add(nn.View(-1):setNumInputDims(3))
model.w2nn_arch_name = "vgg_12"
@ -219,6 +239,7 @@ function srcnn.dilated_7(backend, ch)
model:add(SpatialConvolution(backend, 128, 128, 3, 3, 1, 1, 0, 0))
model:add(nn.LeakyReLU(0.1, true))
model:add(SpatialConvolution(backend, 128, ch, 3, 3, 1, 1, 0, 0))
model:add(w2nn.InplaceClip01())
model:add(nn.View(-1):setNumInputDims(3))
model.w2nn_arch_name = "dilated_7"
@ -249,6 +270,7 @@ function srcnn.upconv_7(backend, ch)
model:add(SpatialConvolution(backend, 128, 256, 3, 3, 1, 1, 0, 0))
model:add(nn.LeakyReLU(0.1, true))
model:add(SpatialFullConvolution(backend, 256, ch, 4, 4, 2, 2, 3, 3):noBias())
model:add(w2nn.InplaceClip01())
model:add(nn.View(-1):setNumInputDims(3))
model.w2nn_arch_name = "upconv_7"
@ -257,11 +279,255 @@ function srcnn.upconv_7(backend, ch)
model.w2nn_resize = true
model.w2nn_channels = ch
return model
end
-- large version of upconv_7
-- This model able to beat upconv_7 (PSNR: +0.3 ~ +0.8) but this model is 2x slower than upconv_7.
function srcnn.upconv_7l(backend, ch)
local model = nn.Sequential()
model:add(SpatialConvolution(backend, ch, 32, 3, 3, 1, 1, 0, 0))
model:add(nn.LeakyReLU(0.1, true))
model:add(SpatialConvolution(backend, 32, 64, 3, 3, 1, 1, 0, 0))
model:add(nn.LeakyReLU(0.1, true))
model:add(SpatialConvolution(backend, 64, 128, 3, 3, 1, 1, 0, 0))
model:add(nn.LeakyReLU(0.1, true))
model:add(SpatialConvolution(backend, 128, 192, 3, 3, 1, 1, 0, 0))
model:add(nn.LeakyReLU(0.1, true))
model:add(SpatialConvolution(backend, 192, 256, 3, 3, 1, 1, 0, 0))
model:add(nn.LeakyReLU(0.1, true))
model:add(SpatialConvolution(backend, 256, 512, 3, 3, 1, 1, 0, 0))
model:add(nn.LeakyReLU(0.1, true))
model:add(SpatialFullConvolution(backend, 512, ch, 4, 4, 2, 2, 3, 3):noBias())
model:add(w2nn.InplaceClip01())
model:add(nn.View(-1):setNumInputDims(3))
model.w2nn_arch_name = "upconv_7l"
model.w2nn_offset = 14
model.w2nn_scale_factor = 2
model.w2nn_resize = true
model.w2nn_channels = ch
--model:cuda()
--print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
return model
end
-- layerwise linear blending with skip connections
-- Note: PSNR: upconv_7 < skiplb_7 < upconv_7l
function srcnn.skiplb_7(backend, ch)
local function skip(backend, i, o)
local con = nn.Concat(2)
local conv = nn.Sequential()
conv:add(SpatialConvolution(backend, i, o, 3, 3, 1, 1, 1, 1))
conv:add(nn.LeakyReLU(0.1, true))
-- depth concat
con:add(conv)
con:add(nn.Identity()) -- skip
return con
end
local model = nn.Sequential()
model:add(skip(backend, ch, 16))
model:add(skip(backend, 16+ch, 32))
model:add(skip(backend, 32+16+ch, 64))
model:add(skip(backend, 64+32+16+ch, 128))
model:add(skip(backend, 128+64+32+16+ch, 128))
model:add(skip(backend, 128+128+64+32+16+ch, 256))
-- input of last layer = [all layerwise output(contains input layer)].flatten
model:add(SpatialFullConvolution(backend, 256+128+128+64+32+16+ch, ch, 4, 4, 2, 2, 3, 3):noBias()) -- linear blend
model:add(w2nn.InplaceClip01())
model:add(nn.View(-1):setNumInputDims(3))
model.w2nn_arch_name = "skiplb_7"
model.w2nn_offset = 14
model.w2nn_scale_factor = 2
model.w2nn_resize = true
model.w2nn_channels = ch
--model:cuda()
--print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
return model
end
-- dilated convolution + deconvolution
-- Note: This model is not better than upconv_7. Maybe becuase of under-fitting.
function srcnn.dilated_upconv_7(backend, ch)
local model = nn.Sequential()
model:add(SpatialConvolution(backend, ch, 16, 3, 3, 1, 1, 0, 0))
model:add(nn.LeakyReLU(0.1, true))
model:add(SpatialConvolution(backend, 16, 32, 3, 3, 1, 1, 0, 0))
model:add(nn.LeakyReLU(0.1, true))
model:add(nn.SpatialDilatedConvolution(32, 64, 3, 3, 1, 1, 0, 0, 2, 2))
model:add(nn.LeakyReLU(0.1, true))
model:add(nn.SpatialDilatedConvolution(64, 128, 3, 3, 1, 1, 0, 0, 2, 2))
model:add(nn.LeakyReLU(0.1, true))
model:add(nn.SpatialDilatedConvolution(128, 128, 3, 3, 1, 1, 0, 0, 2, 2))
model:add(nn.LeakyReLU(0.1, true))
model:add(SpatialConvolution(backend, 128, 256, 3, 3, 1, 1, 0, 0))
model:add(nn.LeakyReLU(0.1, true))
model:add(SpatialFullConvolution(backend, 256, ch, 4, 4, 2, 2, 3, 3):noBias())
model:add(w2nn.InplaceClip01())
model:add(nn.View(-1):setNumInputDims(3))
model.w2nn_arch_name = "dilated_upconv_7"
model.w2nn_offset = 20
model.w2nn_scale_factor = 2
model.w2nn_resize = true
model.w2nn_channels = ch
--model:cuda()
--print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
return model
end
-- ref: https://arxiv.org/abs/1609.04802
-- note: no batch-norm, no zero-paading
function srcnn.srresnet_2x(backend, ch)
local function resblock(backend)
local seq = nn.Sequential()
local con = nn.ConcatTable()
local conv = nn.Sequential()
conv:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
conv:add(ReLU(backend))
conv:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
conv:add(ReLU(backend))
con:add(conv)
con:add(nn.SpatialZeroPadding(-2, -2, -2, -2)) -- identity + de-padding
seq:add(con)
seq:add(nn.CAddTable())
return seq
end
local model = nn.Sequential()
--model:add(skip(backend, ch, 64 - ch))
model:add(SpatialConvolution(backend, ch, 64, 3, 3, 1, 1, 0, 0))
model:add(nn.LeakyReLU(0.1, true))
model:add(resblock(backend))
model:add(resblock(backend))
model:add(resblock(backend))
model:add(resblock(backend))
model:add(resblock(backend))
model:add(resblock(backend))
model:add(SpatialFullConvolution(backend, 64, 64, 4, 4, 2, 2, 2, 2))
model:add(ReLU(backend))
model:add(SpatialConvolution(backend, 64, ch, 3, 3, 1, 1, 0, 0))
model:add(w2nn.InplaceClip01())
--model:add(nn.View(-1):setNumInputDims(3))
model.w2nn_arch_name = "srresnet_2x"
model.w2nn_offset = 28
model.w2nn_scale_factor = 2
model.w2nn_resize = true
model.w2nn_channels = ch
--model:cuda()
--print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
return model
end
-- large version of srresnet_2x. It's current best model but slow.
function srcnn.resnet_14l(backend, ch)
local function resblock(backend, i, o)
local seq = nn.Sequential()
local con = nn.ConcatTable()
local conv = nn.Sequential()
conv:add(SpatialConvolution(backend, i, o, 3, 3, 1, 1, 0, 0))
conv:add(nn.LeakyReLU(0.1, true))
conv:add(SpatialConvolution(backend, o, o, 3, 3, 1, 1, 0, 0))
conv:add(nn.LeakyReLU(0.1, true))
con:add(conv)
if i == o then
con:add(nn.SpatialZeroPadding(-2, -2, -2, -2)) -- identity + de-padding
else
local seq = nn.Sequential()
seq:add(SpatialConvolution(backend, i, o, 1, 1, 1, 1, 0, 0))
seq:add(nn.SpatialZeroPadding(-2, -2, -2, -2))
con:add(seq)
end
seq:add(con)
seq:add(nn.CAddTable())
return seq
end
local model = nn.Sequential()
model:add(SpatialConvolution(backend, ch, 32, 3, 3, 1, 1, 0, 0))
model:add(nn.LeakyReLU(0.1, true))
model:add(resblock(backend, 32, 64))
model:add(resblock(backend, 64, 64))
model:add(resblock(backend, 64, 128))
model:add(resblock(backend, 128, 128))
model:add(resblock(backend, 128, 256))
model:add(resblock(backend, 256, 256))
model:add(SpatialFullConvolution(backend, 256, ch, 4, 4, 2, 2, 3, 3):noBias())
model:add(w2nn.InplaceClip01())
model:add(nn.View(-1):setNumInputDims(3))
model.w2nn_arch_name = "resnet_14l"
model.w2nn_offset = 28
model.w2nn_scale_factor = 2
model.w2nn_resize = true
model.w2nn_channels = ch
--model:cuda()
--print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
return model
end
-- for segmentation
function srcnn.fcn_v1(backend, ch)
-- input_size = 120
local model = nn.Sequential()
--i = 120
--model:cuda()
--print(model:forward(torch.Tensor(32, ch, i, i):uniform():cuda()):size())
model:add(SpatialConvolution(backend, ch, 32, 5, 5, 2, 2, 0, 0))
model:add(nn.LeakyReLU(0.1, true))
model:add(SpatialConvolution(backend, 32, 32, 3, 3, 1, 1, 0, 0))
model:add(nn.LeakyReLU(0.1, true))
model:add(SpatialMaxPooling(backend, 2, 2, 2, 2))
model:add(SpatialConvolution(backend, 32, 64, 3, 3, 1, 1, 0, 0))
model:add(nn.LeakyReLU(0.1, true))
model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
model:add(nn.LeakyReLU(0.1, true))
model:add(SpatialMaxPooling(backend, 2, 2, 2, 2))
model:add(SpatialConvolution(backend, 64, 128, 3, 3, 1, 1, 0, 0))
model:add(nn.LeakyReLU(0.1, true))
model:add(SpatialConvolution(backend, 128, 128, 3, 3, 1, 1, 0, 0))
model:add(nn.LeakyReLU(0.1, true))
model:add(SpatialMaxPooling(backend, 2, 2, 2, 2))
model:add(SpatialConvolution(backend, 128, 256, 1, 1, 1, 1, 0, 0))
model:add(nn.LeakyReLU(0.1, true))
model:add(nn.Dropout(0.5, false, true))
model:add(SpatialFullConvolution(backend, 256, 128, 2, 2, 2, 2, 0, 0))
model:add(nn.LeakyReLU(0.1, true))
model:add(SpatialFullConvolution(backend, 128, 128, 2, 2, 2, 2, 0, 0))
model:add(nn.LeakyReLU(0.1, true))
model:add(SpatialConvolution(backend, 128, 64, 3, 3, 1, 1, 0, 0))
model:add(nn.LeakyReLU(0.1, true))
model:add(SpatialFullConvolution(backend, 64, 64, 2, 2, 2, 2, 0, 0))
model:add(nn.LeakyReLU(0.1, true))
model:add(SpatialConvolution(backend, 64, 32, 3, 3, 1, 1, 0, 0))
model:add(nn.LeakyReLU(0.1, true))
model:add(SpatialFullConvolution(backend, 32, ch, 4, 4, 2, 2, 3, 3))
model:add(w2nn.InplaceClip01())
model:add(nn.View(-1):setNumInputDims(3))
model.w2nn_arch_name = "fcn_v1"
model.w2nn_offset = 36
model.w2nn_scale_factor = 1
model.w2nn_channels = ch
model.w2nn_input_size = 120
--model.w2nn_gcn = true
return model
end
function srcnn.create(model_name, backend, color)
model_name = model_name or "vgg_7"
backend = backend or "cunn"
@ -282,8 +548,10 @@ function srcnn.create(model_name, backend, color)
error("unsupported model_name: " .. model_name)
end
end
--local model = srcnn.upconv_6("cunn", 3):cuda()
--print(model:forward(torch.Tensor(1, 3, 64, 64):zero():cuda()):size())
--[[
local model = srcnn.fcn_v1("cunn", 3):cuda()
print(model:forward(torch.Tensor(1, 3, 108, 108):zero():cuda()):size())
print(model)
--]]
return srcnn

View file

@ -30,5 +30,8 @@ else
require 'LeakyReLU'
require 'ClippedWeightedHuberCriterion'
require 'ClippedMSECriterion'
require 'SSIMCriterion'
require 'InplaceClip01'
require 'L1Criterion'
return w2nn
end

View file

@ -0,0 +1 @@
Currently, this models are for the benchmark.

Binary file not shown.

View file

@ -0,0 +1 @@
Currently, this models are for the benchmark.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View file

@ -18,7 +18,7 @@ cmd:option("-dir", "./data/test", 'test image directory')
cmd:option("-file", "", 'test image file list')
cmd:option("-model1_dir", "./models/anime_style_art_rgb", 'model1 directory')
cmd:option("-model2_dir", "", 'model2 directory (optional)')
cmd:option("-method", "scale", '(scale|noise|noise_scale|user|diff)')
cmd:option("-method", "scale", '(scale|noise|noise_scale|user|diff|scale4)')
cmd:option("-filter", "Catrom", "downscaling filter (Box|Lanczos|Catrom(Bicubic))")
cmd:option("-resize_blur", 1.0, 'blur parameter for resize')
cmd:option("-color", "y", '(rgb|y|r|g|b)')
@ -46,6 +46,7 @@ cmd:option("-x_dir", "", 'input image for user method')
cmd:option("-y_dir", "", 'groundtruth image for user method. filename must be the same as x_dir')
cmd:option("-x_file", "", 'input image for user method')
cmd:option("-y_file", "", 'groundtruth image for user method. filename must be the same as x_file')
cmd:option("-border", 0, 'border px that will removed')
local function to_bool(settings, name)
if settings[name] == 1 then
@ -153,12 +154,24 @@ local function baseline_scale(x, filter)
x:size(2) * 2.0,
filter)
end
local function baseline_scale4(x, filter)
return iproc.scale(x,
x:size(3) * 4.0,
x:size(2) * 4.0,
filter)
end
local function transform_scale(x, opt)
return iproc.scale(x,
x:size(3) * 0.5,
x:size(2) * 0.5,
opt.filter, opt.resize_blur)
end
local function transform_scale4(x, opt)
return iproc.scale(x,
x:size(3) * 0.25,
x:size(2) * 0.25,
opt.filter, opt.resize_blur)
end
local function transform_scale_jpeg(x, opt)
x = iproc.scale(x,
@ -179,9 +192,15 @@ local function transform_scale_jpeg(x, opt)
end
return iproc.byte2float(x)
end
local function remove_border(x, border)
return iproc.crop(x,
border, border,
x:size(3) - border,
x:size(2) - border)
end
local function benchmark(opt, x, model1, model2)
local mse
local mse1, mse2
local won = {0, 0}
local model1_mse = 0
local model2_mse = 0
local baseline_mse = 0
@ -192,6 +211,10 @@ local function benchmark(opt, x, model1, model2)
local model2_time = 0
local scale_f = reconstruct.scale
local image_f = reconstruct.image
local detail_fp = nil
if opt.save_info then
detail_fp = io.open(path.join(opt.output_dir, "benchmark_details.txt"), "w")
end
if opt.tta then
scale_f = function(model, scale, x, block_size, batch_size)
return reconstruct.scale_tta(model, opt.tta_level,
@ -204,12 +227,15 @@ local function benchmark(opt, x, model1, model2)
end
for i = 1, #x do
if i % 10 == 0 then
collectgarbage()
end
local basename = x[i].basename
local input, model1_output, model2_output, baseline_output, ground_truth
if opt.method == "scale" then
input = transform_scale(x[i].y, opt)
ground_truth = x[i].y
input = transform_scale(iproc.byte2float(x[i].y), opt)
ground_truth = iproc.byte2float(x[i].y)
if opt.force_cudnn and i == 1 then -- run cuDNN benchmark first
model1_output = scale_f(model1, 2.0, input, opt.crop_size, opt.batch_size)
@ -226,9 +252,29 @@ local function benchmark(opt, x, model1, model2)
model2_time = model2_time + (sys.clock() - t)
end
baseline_output = baseline_scale(input, opt.baseline_filter)
elseif opt.method == "scale4" then
input = transform_scale4(iproc.byte2float(x[i].y), opt)
ground_truth = iproc.byte2float(x[i].y)
if opt.force_cudnn and i == 1 then -- run cuDNN benchmark first
model1_output = scale_f(model1, 2.0, input, opt.crop_size, opt.batch_size)
if model2 then
model2_output = scale_f(model2, 2.0, input, opt.crop_size, opt.batch_size)
end
end
t = sys.clock()
model1_output = scale_f(model1, 2.0, input, opt.crop_size, opt.batch_size)
model1_output = scale_f(model1, 2.0, model1_output, opt.crop_size, opt.batch_size)
model1_time = model1_time + (sys.clock() - t)
if model2 then
t = sys.clock()
model2_output = scale_f(model2, 2.0, input, opt.crop_size, opt.batch_size)
model2_output = scale_f(model2, 2.0, model2_output, opt.crop_size, opt.batch_size)
model2_time = model2_time + (sys.clock() - t)
end
baseline_output = baseline_scale4(input, opt.baseline_filter)
elseif opt.method == "noise" then
input = transform_jpeg(x[i].y, opt)
ground_truth = x[i].y
input = transform_jpeg(iproc.byte2float(x[i].y), opt)
ground_truth = iproc.byte2float(x[i].y)
if opt.force_cudnn and i == 1 then
model1_output = image_f(model1, input, opt.crop_size, opt.batch_size)
@ -246,8 +292,8 @@ local function benchmark(opt, x, model1, model2)
end
baseline_output = input
elseif opt.method == "noise_scale" then
input = transform_scale_jpeg(x[i].y, opt)
ground_truth = x[i].y
input = transform_scale_jpeg(iproc.byte2float(x[i].y), opt)
ground_truth = iproc.byte2float(x[i].y)
if opt.force_cudnn and i == 1 then
if model1.noise_scale_model then
@ -312,8 +358,8 @@ local function benchmark(opt, x, model1, model2)
end
baseline_output = baseline_scale(input, opt.baseline_filter)
elseif opt.method == "user" then
input = x[i].x
ground_truth = x[i].y
input = iproc.byte2float(x[i].x)
ground_truth = iproc.byte2float(x[i].y)
local y_scale = ground_truth:size(2) / input:size(2)
if y_scale > 1 then
if opt.force_cudnn and i == 1 then
@ -347,19 +393,44 @@ local function benchmark(opt, x, model1, model2)
end
end
elseif opt.method == "diff" then
input = x[i].x
ground_truth = x[i].y
input = iproc.byte2float(x[i].x)
ground_truth = iproc.byte2float(x[i].y)
model1_output = input
end
mse = MSE(ground_truth, model1_output, opt.color)
model1_mse = model1_mse + mse
model1_psnr = model1_psnr + MSE2PSNR(mse)
if opt.border > 0 then
ground_truth = remove_border(ground_truth, opt.border)
model1_output = remove_border(model1_output, opt.border)
end
mse1 = MSE(ground_truth, model1_output, opt.color)
model1_mse = model1_mse + mse1
model1_psnr = model1_psnr + MSE2PSNR(mse1)
local won_model = 1
if model2 then
mse = MSE(ground_truth, model2_output, opt.color)
model2_mse = model2_mse + mse
model2_psnr = model2_psnr + MSE2PSNR(mse)
if opt.border > 0 then
model2_output = remove_border(model2_output, opt.border)
end
mse2 = MSE(ground_truth, model2_output, opt.color)
model2_mse = model2_mse + mse2
model2_psnr = model2_psnr + MSE2PSNR(mse2)
if mse1 < mse2 then
won[1] = won[1] + 1
elseif mse1 > mse2 then
won[2] = won[2] + 1
won_model = 2
end
if detail_fp then
detail_fp:write(string.format("%s,%f,%f,%d\n", x[i].basename,
MSE2PSNR(mse1), MSE2PSNR(mse2), won_model))
end
else
if detail_fp then
detail_fp:write(string.format("%s,%f\n", x[i].basename, MSE2PSNR(mse1)))
end
end
if baseline_output then
baseline_output = remove_border(baseline_output, opt.border)
mse = MSE(ground_truth, baseline_output, opt.color)
baseline_mse = baseline_mse + mse
baseline_psnr = baseline_psnr + MSE2PSNR(mse)
@ -382,29 +453,31 @@ local function benchmark(opt, x, model1, model2)
if model2 then
if baseline_output then
io.stdout:write(
string.format("%d/%d; model1_time=%.2f, model2_time=%.2f, baseline_rmse=%f, model1_rmse=%f, model2_rmse=%f, baseline_psnr=%f, model1_psnr=%f, model2_psnr=%f \r",
string.format("%d/%d; model1_time=%.2f, model2_time=%.2f, baseline_rmse=%.3f, model1_rmse=%.3f, model2_rmse=%.3f, baseline_psnr=%.3f, model1_psnr=%.3f, model2_psnr=%.3f, model1_won=%d, model2_won=%d \r",
i, #x,
model1_time,
model2_time,
math.sqrt(baseline_mse / i),
math.sqrt(model1_mse / i), math.sqrt(model2_mse / i),
baseline_psnr / i,
model1_psnr / i, model2_psnr / i
model1_psnr / i, model2_psnr / i,
won[1], won[2]
))
else
io.stdout:write(
string.format("%d/%d; model1_time=%.2f, model2_time=%.2f, model1_rmse=%f, model2_rmse=%f, model1_psnr=%f, model2_psnr=%f \r",
string.format("%d/%d; model1_time=%.2f, model2_time=%.2f, model1_rmse=%.3f, model2_rmse=%.3f, model1_psnr=%.3f, model2_psnr=%.3f, model1_own=%d, model2_won=%d \r",
i, #x,
model1_time,
model2_time,
math.sqrt(model1_mse / i), math.sqrt(model2_mse / i),
model1_psnr / i, model2_psnr / i
model1_psnr / i, model2_psnr / i,
won[1], won[2]
))
end
else
if baseline_output then
io.stdout:write(
string.format("%d/%d; model1_time=%.2f, baseline_rmse=%f, model1_rmse=%f, baseline_psnr=%f, model1_psnr=%f \r",
string.format("%d/%d; model1_time=%.2f, baseline_rmse=%.3f, model1_rmse=%.3f, baseline_psnr=%.3f, model1_psnr=%.3f \r",
i, #x,
model1_time,
math.sqrt(baseline_mse / i), math.sqrt(model1_mse / i),
@ -412,7 +485,7 @@ local function benchmark(opt, x, model1, model2)
))
else
io.stdout:write(
string.format("%d/%d; model1_time=%.2f, model1_rmse=%f, model1_psnr=%f \r",
string.format("%d/%d; model1_time=%.2f, model1_rmse=%.3f, model1_psnr=%.3f \r",
i, #x,
model1_time,
math.sqrt(model1_mse / i), model1_psnr / i
@ -438,6 +511,9 @@ local function benchmark(opt, x, model1, model2)
math.sqrt(model2_mse / #x), model2_psnr / #x, model2_time))
end
fp:close()
if detail_fp then
detail_fp:close()
end
end
io.stdout:write("\n")
end
@ -448,7 +524,7 @@ local function load_data_from_dir(test_dir)
local name = path.basename(files[i])
local e = path.extension(name)
local base = name:sub(0, name:len() - e:len())
local img = image_loader.load_float(files[i])
local img = image_loader.load_byte(files[i])
if img then
table.insert(test_x, {y = iproc.crop_mod4(img),
basename = base})
@ -456,6 +532,9 @@ local function load_data_from_dir(test_dir)
if opt.show_progress then
xlua.progress(i, #files)
end
if i % 10 == 0 then
collectgarbage()
end
end
return test_x
end
@ -466,7 +545,7 @@ local function load_data_from_file(test_file)
local name = path.basename(files[i])
local e = path.extension(name)
local base = name:sub(0, name:len() - e:len())
local img = image_loader.load_float(files[i])
local img = image_loader.load_byte(files[i])
if img then
table.insert(test_x, {y = iproc.crop_mod4(img),
basename = base})
@ -474,6 +553,9 @@ local function load_data_from_file(test_file)
if opt.show_progress then
xlua.progress(i, #files)
end
if i % 10 == 0 then
collectgarbage()
end
end
return test_x
end
@ -519,16 +601,19 @@ local function load_user_data(y_dir, y_file, x_dir, x_file)
end
for i = 1, #y_files do
local key = get_basename(y_files[i])
local x = image_loader.load_float(basename_db[key].x)
local y = image_loader.load_float(basename_db[key].y)
local x = image_loader.load_byte(basename_db[key].x)
local y = image_loader.load_byte(basename_db[key].y)
if x and y then
table.insert(test, {y = y,
x = x,
basename = base})
basename = key})
end
if opt.show_progress then
xlua.progress(i, #y_files)
end
if i % 10 == 0 then
collectgarbage()
end
end
return test
end
@ -563,7 +648,7 @@ if opt.show_progress then
print(opt)
end
if opt.method == "scale" then
if opt.method == "scale" or opt.method == "scale4" then
local f1 = path.join(opt.model1_dir, "scale2.0x_model.t7")
local f2 = path.join(opt.model2_dir, "scale2.0x_model.t7")
local s1, model1 = pcall(w2nn.load_model, f1, opt.force_cudnn)

View file

@ -33,5 +33,7 @@ export_model() {
}
export_model vgg_7/art
export_model upconv_7/art
export_model upconv_7l/art
export_model vgg_7/photo
export_model upconv_7/photo
export_model upconv_7l/photo

View file

@ -22,7 +22,6 @@ local function includes(s, a)
end
return false
end
local function get_bias(mod)
if mod.bias then
return mod.bias:float()
@ -31,20 +30,18 @@ local function get_bias(mod)
return torch.FloatTensor(mod.nOutputPlane):zero()
end
end
local function export(model, output)
local function export_weight(jmodules, seq)
local targets = {"nn.SpatialConvolutionMM",
"cudnn.SpatialConvolution",
"nn.SpatialFullConvolution",
"cudnn.SpatialFullConvolution"
}
local jmodules = {}
local model_config = meta_data(model)
local first_layer = true
for k = 1, #model.modules do
local mod = model.modules[k]
for k = 1, #seq.modules do
local mod = seq.modules[k]
local name = torch.typename(mod)
if includes(name, targets) then
if name == "nn.Sequential" or name == "nn.ConcatTable" then
export_weight(jmodules, mod)
elseif includes(name, targets) then
local weight = mod.weight:float()
if name:match("FullConvolution") then
weight = torch.totable(weight:reshape(mod.nInputPlane, mod.nOutputPlane, mod.kH, mod.kW))
@ -71,6 +68,14 @@ local function export(model, output)
table.insert(jmodules, jmod)
end
end
end
local function export(model, output)
local jmodules = {}
local model_config = meta_data(model)
local first_layer = true
export_weight(jmodules, model)
local fp = io.open(output, "w")
if not fp then
error("IO Error: " .. output)

425
train.lua
View file

@ -3,15 +3,14 @@ local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^
package.path = path.join(path.dirname(__FILE__), "lib", "?.lua;") .. package.path
require 'optim'
require 'xlua'
require 'image'
require 'w2nn'
local threads = require 'threads'
local settings = require 'settings'
local srcnn = require 'srcnn'
local minibatch_adam = require 'minibatch_adam'
local iproc = require 'iproc'
local reconstruct = require 'reconstruct'
local compression = require 'compression'
local pairwise_transform = require 'pairwise_transform'
local image_loader = require 'image_loader'
local function save_test_scale(model, rgb, file)
@ -42,20 +41,218 @@ local function split_data(x, test_size)
end
return train_x, valid_x
end
local function make_validation_set(x, transformer, n, patches)
local g_transform_pool = nil
local g_mutex = nil
local g_mutex_id = nil
local function transform_pool_init(has_resize, offset)
local nthread = torch.getnumthreads()
if (settings.thread > 0) then
nthread = settings.thread
end
g_mutex = threads.Mutex()
g_mutex_id = g_mutex:id()
g_transform_pool = threads.Threads(
nthread,
threads.safe(
function(threadid)
require 'pl'
local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
package.path = path.join(path.dirname(__FILE__), "lib", "?.lua;") .. package.path
require 'torch'
require 'nn'
require 'cunn'
torch.setnumthreads(1)
torch.setdefaulttensortype("torch.FloatTensor")
local threads = require 'threads'
local compression = require 'compression'
local pairwise_transform = require 'pairwise_transform'
function transformer(x, is_validation, n)
local mutex = threads.Mutex(g_mutex_id)
local meta = {data = {}}
local y = nil
if type(x) == "table" and type(x[2]) == "table" then
meta = x[2]
if x[1].x and x[1].y then
y = compression.decompress(x[1].y)
x = compression.decompress(x[1].x)
else
x = compression.decompress(x[1])
end
else
x = compression.decompress(x)
end
n = n or settings.patches
if is_validation == nil then is_validation = false end
local random_color_noise_rate = nil
local random_overlay_rate = nil
local active_cropping_rate = nil
local active_cropping_tries = nil
if is_validation then
active_cropping_rate = settings.active_cropping_rate
active_cropping_tries = settings.active_cropping_tries
random_color_noise_rate = 0.0
random_overlay_rate = 0.0
else
active_cropping_rate = settings.active_cropping_rate
active_cropping_tries = settings.active_cropping_tries
random_color_noise_rate = settings.random_color_noise_rate
random_overlay_rate = settings.random_overlay_rate
end
if settings.method == "scale" then
local conf = tablex.update({
mutex = mutex,
downsampling_filters = settings.downsampling_filters,
random_half_rate = settings.random_half_rate,
random_color_noise_rate = random_color_noise_rate,
random_overlay_rate = random_overlay_rate,
random_unsharp_mask_rate = settings.random_unsharp_mask_rate,
random_blur_rate = settings.random_blur_rate,
random_blur_size = settings.random_blur_size,
random_blur_sigma_min = settings.random_blur_sigma_min,
random_blur_sigma_max = settings.random_blur_sigma_max,
max_size = settings.max_size,
active_cropping_rate = active_cropping_rate,
active_cropping_tries = active_cropping_tries,
rgb = (settings.color == "rgb"),
x_upsampling = not has_resize,
resize_blur_min = settings.resize_blur_min,
resize_blur_max = settings.resize_blur_max}, meta)
return pairwise_transform.scale(x,
settings.scale,
settings.crop_size, offset,
n, conf)
elseif settings.method == "noise" then
local conf = tablex.update({
mutex = mutex,
random_half_rate = settings.random_half_rate,
random_color_noise_rate = random_color_noise_rate,
random_overlay_rate = random_overlay_rate,
random_unsharp_mask_rate = settings.random_unsharp_mask_rate,
random_blur_rate = settings.random_blur_rate,
random_blur_size = settings.random_blur_size,
random_blur_sigma_min = settings.random_blur_sigma_min,
random_blur_sigma_max = settings.random_blur_sigma_max,
max_size = settings.max_size,
jpeg_chroma_subsampling_rate = settings.jpeg_chroma_subsampling_rate,
active_cropping_rate = active_cropping_rate,
active_cropping_tries = active_cropping_tries,
nr_rate = settings.nr_rate,
rgb = (settings.color == "rgb")}, meta)
return pairwise_transform.jpeg(x,
settings.style,
settings.noise_level,
settings.crop_size, offset,
n, conf)
elseif settings.method == "noise_scale" then
local conf = tablex.update({
mutex = mutex,
downsampling_filters = settings.downsampling_filters,
random_half_rate = settings.random_half_rate,
random_color_noise_rate = random_color_noise_rate,
random_overlay_rate = random_overlay_rate,
random_unsharp_mask_rate = settings.random_unsharp_mask_rate,
random_blur_rate = settings.random_blur_rate,
random_blur_size = settings.random_blur_size,
random_blur_sigma_min = settings.random_blur_sigma_min,
random_blur_sigma_max = settings.random_blur_sigma_max,
max_size = settings.max_size,
jpeg_chroma_subsampling_rate = settings.jpeg_chroma_subsampling_rate,
nr_rate = settings.nr_rate,
active_cropping_rate = active_cropping_rate,
active_cropping_tries = active_cropping_tries,
rgb = (settings.color == "rgb"),
x_upsampling = not has_resize,
resize_blur_min = settings.resize_blur_min,
resize_blur_max = settings.resize_blur_max}, meta)
return pairwise_transform.jpeg_scale(x,
settings.scale,
settings.style,
settings.noise_level,
settings.crop_size, offset,
n, conf)
elseif settings.method == "user" then
if is_validation == nil then is_validation = false end
local rotate_rate = nil
local scale_rate = nil
local negate_rate = nil
local negate_x_rate = nil
if is_validation then
rotate_rate = 0
scale_rate = 0
negate_rate = 0
negate_x_rate = 0
else
rotate_rate = settings.random_pairwise_rotate_rate
scale_rate = settings.random_pairwise_scale_rate
negate_rate = settings.random_pairwise_negate_rate
negate_x_rate = settings.random_pairwise_negate_x_rate
end
local conf = tablex.update({
gcn = settings.gcn,
max_size = settings.max_size,
active_cropping_rate = active_cropping_rate,
active_cropping_tries = active_cropping_tries,
random_pairwise_rotate_rate = rotate_rate,
random_pairwise_rotate_min = settings.random_pairwise_rotate_min,
random_pairwise_rotate_max = settings.random_pairwise_rotate_max,
random_pairwise_scale_rate = scale_rate,
random_pairwise_scale_min = settings.random_pairwise_scale_min,
random_pairwise_scale_max = settings.random_pairwise_scale_max,
random_pairwise_negate_rate = negate_rate,
random_pairwise_negate_x_rate = negate_x_rate,
pairwise_y_binary = settings.pairwise_y_binary,
pairwise_flip = settings.pairwise_flip,
rgb = (settings.color == "rgb")}, meta)
return pairwise_transform.user(x, y,
settings.crop_size, offset,
n, conf)
end
end
end)
)
g_transform_pool:synchronize()
end
local function make_validation_set(x, n, patches)
local nthread = torch.getnumthreads()
if (settings.thread > 0) then
nthread = settings.thread
end
n = n or 4
local validation_patches = math.min(16, patches or 16)
local data = {}
g_transform_pool:synchronize()
torch.setnumthreads(1) -- 1
for i = 1, #x do
for k = 1, math.max(n / validation_patches, 1) do
local xy = transformer(x[i], true, validation_patches)
for j = 1, #xy do
table.insert(data, {x = xy[j][1], y = xy[j][2]})
end
local input = x[i]
g_transform_pool:addjob(
function()
local xy = transformer(input, true, validation_patches)
return xy
end,
function(xy)
for j = 1, #xy do
table.insert(data, {x = xy[j][1], y = xy[j][2]})
end
end
)
end
if i % 20 == 0 then
collectgarbage()
g_transform_pool:synchronize()
xlua.progress(i, #x)
end
xlua.progress(i, #x)
collectgarbage()
end
g_transform_pool:synchronize()
torch.setnumthreads(nthread) -- revert
local new_data = {}
local perm = torch.randperm(#data)
for i = 1, perm:size(1) do
@ -102,144 +299,71 @@ local function validate(model, criterion, eval_metric, data, batch_size)
end
local function create_criterion(model)
if reconstruct.is_rgb(model) then
local offset = reconstruct.offset_size(model)
local output_w = settings.crop_size - offset * 2
local weight = torch.Tensor(3, output_w * output_w)
weight[1]:fill(0.29891 * 3) -- R
weight[2]:fill(0.58661 * 3) -- G
weight[3]:fill(0.11448 * 3) -- B
return w2nn.ClippedWeightedHuberCriterion(weight, 0.1, {0.0, 1.0}):cuda()
else
local offset = reconstruct.offset_size(model)
local output_w = settings.crop_size - offset * 2
local weight = torch.Tensor(1, output_w * output_w)
weight[1]:fill(1.0)
return w2nn.ClippedWeightedHuberCriterion(weight, 0.1, {0.0, 1.0}):cuda()
end
end
local function transformer(model, x, is_validation, n, offset)
local meta = {data = {}}
local y = nil
if type(x) == "table" and type(x[2]) == "table" then
meta = x[2]
if x[1].x and x[1].y then
y = compression.decompress(x[1].y)
x = compression.decompress(x[1].x)
if settings.loss == "huber" then
if reconstruct.is_rgb(model) then
local offset = reconstruct.offset_size(model)
local output_w = settings.crop_size - offset * 2
local weight = torch.Tensor(3, output_w * output_w)
weight[1]:fill(0.29891 * 3) -- R
weight[2]:fill(0.58661 * 3) -- G
weight[3]:fill(0.11448 * 3) -- B
return w2nn.ClippedWeightedHuberCriterion(weight, 0.1, {0.0, 1.0}):cuda()
else
x = compression.decompress(x[1])
local offset = reconstruct.offset_size(model)
local output_w = settings.crop_size - offset * 2
local weight = torch.Tensor(1, output_w * output_w)
weight[1]:fill(1.0)
return w2nn.ClippedWeightedHuberCriterion(weight, 0.1, {0.0, 1.0}):cuda()
end
elseif settings.loss == "l1" then
return w2nn.L1Criterion():cuda()
elseif settings.loss == "mse" then
return w2nn.ClippedMSECriterion(0, 1.0):cuda()
else
x = compression.decompress(x)
end
n = n or settings.patches
if is_validation == nil then is_validation = false end
local random_color_noise_rate = nil
local random_overlay_rate = nil
local active_cropping_rate = nil
local active_cropping_tries = nil
if is_validation then
active_cropping_rate = settings.active_cropping_rate
active_cropping_tries = settings.active_cropping_tries
random_color_noise_rate = 0.0
random_overlay_rate = 0.0
else
active_cropping_rate = settings.active_cropping_rate
active_cropping_tries = settings.active_cropping_tries
random_color_noise_rate = settings.random_color_noise_rate
random_overlay_rate = settings.random_overlay_rate
end
if settings.method == "scale" then
local conf = tablex.update({
downsampling_filters = settings.downsampling_filters,
random_half_rate = settings.random_half_rate,
random_color_noise_rate = random_color_noise_rate,
random_overlay_rate = random_overlay_rate,
random_unsharp_mask_rate = settings.random_unsharp_mask_rate,
max_size = settings.max_size,
active_cropping_rate = active_cropping_rate,
active_cropping_tries = active_cropping_tries,
rgb = (settings.color == "rgb"),
x_upsampling = not reconstruct.has_resize(model),
resize_blur_min = settings.resize_blur_min,
resize_blur_max = settings.resize_blur_max}, meta)
return pairwise_transform.scale(x,
settings.scale,
settings.crop_size, offset,
n, conf)
elseif settings.method == "noise" then
local conf = tablex.update({
random_half_rate = settings.random_half_rate,
random_color_noise_rate = random_color_noise_rate,
random_overlay_rate = random_overlay_rate,
random_unsharp_mask_rate = settings.random_unsharp_mask_rate,
max_size = settings.max_size,
jpeg_chroma_subsampling_rate = settings.jpeg_chroma_subsampling_rate,
active_cropping_rate = active_cropping_rate,
active_cropping_tries = active_cropping_tries,
nr_rate = settings.nr_rate,
rgb = (settings.color == "rgb")}, meta)
return pairwise_transform.jpeg(x,
settings.style,
settings.noise_level,
settings.crop_size, offset,
n, conf)
elseif settings.method == "noise_scale" then
local conf = tablex.update({
downsampling_filters = settings.downsampling_filters,
random_half_rate = settings.random_half_rate,
random_color_noise_rate = random_color_noise_rate,
random_overlay_rate = random_overlay_rate,
random_unsharp_mask_rate = settings.random_unsharp_mask_rate,
max_size = settings.max_size,
jpeg_chroma_subsampling_rate = settings.jpeg_chroma_subsampling_rate,
nr_rate = settings.nr_rate,
active_cropping_rate = active_cropping_rate,
active_cropping_tries = active_cropping_tries,
rgb = (settings.color == "rgb"),
x_upsampling = not reconstruct.has_resize(model),
resize_blur_min = settings.resize_blur_min,
resize_blur_max = settings.resize_blur_max}, meta)
return pairwise_transform.jpeg_scale(x,
settings.scale,
settings.style,
settings.noise_level,
settings.crop_size, offset,
n, conf)
elseif settings.method == "user" then
local conf = tablex.update({
max_size = settings.max_size,
active_cropping_rate = active_cropping_rate,
active_cropping_tries = active_cropping_tries,
rgb = (settings.color == "rgb")}, meta)
return pairwise_transform.user(x, y,
settings.crop_size, offset,
n, conf)
error("unsupported loss .." .. settings.loss)
end
end
local function resampling(x, y, train_x, transformer, input_size, target_size)
local function resampling(x, y, train_x)
local c = 1
local shuffle = torch.randperm(#train_x)
local nthread = torch.getnumthreads()
if (settings.thread > 0) then
nthread = settings.thread
end
torch.setnumthreads(1) -- 1
for t = 1, #train_x do
xlua.progress(t, #train_x)
local xy = transformer(train_x[shuffle[t]], false, settings.patches)
for i = 1, #xy do
x[c]:copy(xy[i][1])
y[c]:copy(xy[i][2])
c = c + 1
if c > x:size(1) then
break
local input = train_x[shuffle[t]]
g_transform_pool:addjob(
function()
local xy = transformer(input, false, settings.patches)
return xy
end,
function(xy)
for i = 1, #xy do
if c <= x:size(1) then
x[c]:copy(xy[i][1])
y[c]:copy(xy[i][2])
c = c + 1
else
break
end
end
end
)
if t % 50 == 0 then
collectgarbage()
g_transform_pool:synchronize()
xlua.progress(t, #train_x)
end
if c > x:size(1) then
break
end
if t % 50 == 0 then
collectgarbage()
end
end
g_transform_pool:synchronize()
xlua.progress(#train_x, #train_x)
torch.setnumthreads(nthread) -- revert
end
local function get_oracle_data(x, y, instance_loss, k, samples)
local index = torch.LongTensor(instance_loss:size(1))
@ -262,6 +386,7 @@ local function get_oracle_data(x, y, instance_loss, k, samples)
end
local function remove_small_image(x)
local compression = require 'compression'
local new_x = {}
for i = 1, #x do
local xe, meta, x_s
@ -293,6 +418,8 @@ local function plot(train, valid)
{'validation', torch.Tensor(valid), '-'}})
end
local function train()
local x = remove_small_image(torch.load(settings.images))
local train_x, valid_x = split_data(x, math.max(math.floor(settings.validation_rate * #x), 1))
local hist_train = {}
local hist_valid = {}
local model
@ -301,20 +428,30 @@ local function train()
else
model = srcnn.create(settings.model, settings.backend, settings.color)
end
if model.w2nn_input_size then
if settings.crop_size ~= model.w2nn_input_size then
io.stderr:write(string.format("warning: crop_size is replaced with %d\n",
model.w2nn_input_size))
settings.crop_size = model.w2nn_input_size
end
end
if model.w2nn_gcn then
settings.gcn = true
else
settings.gcn = false
end
dir.makepath(settings.model_dir)
local offset = reconstruct.offset_size(model)
local pairwise_func = function(x, is_validation, n)
return transformer(model, x, is_validation, n, offset)
end
transform_pool_init(reconstruct.has_resize(model), offset)
local criterion = create_criterion(model)
local eval_metric = w2nn.ClippedMSECriterion(0, 1):cuda()
local x = remove_small_image(torch.load(settings.images))
local train_x, valid_x = split_data(x, math.max(math.floor(settings.validation_rate * #x), 1))
local adam_config = {
xLearningRate = settings.learning_rate,
xBatchSize = settings.batch_size,
xLearningRateDecay = settings.learning_rate_decay
xLearningRateDecay = settings.learning_rate_decay,
xInstanceLoss = (settings.oracle_rate > 0)
}
local ch = nil
if settings.color == "y" then
@ -324,7 +461,7 @@ local function train()
end
local best_score = 1000.0
print("# make validation-set")
local valid_xy = make_validation_set(valid_x, pairwise_func,
local valid_xy = make_validation_set(valid_x,
settings.validation_crops,
settings.patches)
valid_x = nil
@ -358,7 +495,7 @@ local function train()
if oracle_n > 0 then
local oracle_x, oracle_y = get_oracle_data(x, y, instance_loss, oracle_k, oracle_n)
resampling(x:narrow(1, oracle_x:size(1) + 1, x:size(1)-oracle_x:size(1)),
y:narrow(1, oracle_x:size(1) + 1, x:size(1) - oracle_x:size(1)), train_x, pairwise_func)
y:narrow(1, oracle_x:size(1) + 1, x:size(1) - oracle_x:size(1)), train_x)
x:narrow(1, 1, oracle_x:size(1)):copy(oracle_x)
y:narrow(1, 1, oracle_y:size(1)):copy(oracle_y)
@ -374,7 +511,7 @@ local function train()
min = 0,
max = 1}))
else
resampling(x, y, train_x, pairwise_func)
resampling(x, y, train_x)
end
else
resampling(x, y, train_x, pairwise_func)
@ -395,9 +532,9 @@ local function train()
if settings.plot then
plot(hist_train, hist_valid)
end
if score.MSE < best_score then
if score.loss < best_score then
local test_image = image_loader.load_float(settings.test) -- reload
best_score = score.MSE
best_score = score.loss
print("* model has updated")
if settings.save_history then
torch.save(settings.model_file_best, model:clearState(), "ascii")
@ -446,7 +583,7 @@ local function train()
end
end
end
print("Batch-wise PSNR: " .. score.PSNR .. ", loss: " .. score.loss .. ", MSE: " .. score.MSE .. ", Minimum MSE: " .. best_score)
print("Batch-wise PSNR: " .. score.PSNR .. ", loss: " .. score.loss .. ", Minimum loss: " .. best_score .. ", MSE: " .. score.MSE)
collectgarbage()
end
end

View file

@ -267,6 +267,7 @@ local function waifu2x()
cmd:option("-tta_level", 8, 'TTA level (2|4|8). A higher value makes better quality output but slow')
cmd:option("-force_cudnn", 0, 'use cuDNN backend (0|1)')
cmd:option("-q", 0, 'quiet (0|1)')
cmd:option("-gpu", 1, 'Device ID')
local opt = cmd:parse(arg)
if opt.method:len() > 0 then
@ -292,5 +293,6 @@ local function waifu2x()
else
convert_frames(opt)
end
cutorch.setDevice(opt.gpu)
end
waifu2x()