Merge branch 'dev'
This commit is contained in:
commit
d779a9d47a
2
.gitignore
vendored
2
.gitignore
vendored
|
@ -11,6 +11,8 @@ models/*
|
|||
!models/ukbench
|
||||
!models/photo
|
||||
!models/upconv_7
|
||||
!models/upconv_7l
|
||||
!models/srresnet_12l
|
||||
!models/vgg_7
|
||||
models/*/*.png
|
||||
models/*/*/*.png
|
||||
|
|
|
@ -89,6 +89,7 @@ See: [Getting started with Torch](http://torch.ch/docs/getting-started.html)
|
|||
And install luarocks packages.
|
||||
```
|
||||
luarocks install graphicsmagick # upgrade
|
||||
luarocks install threads # upgrade
|
||||
luarocks install lua-csnappy
|
||||
luarocks install md5
|
||||
luarocks install uuid
|
||||
|
|
|
@ -1,45 +1,79 @@
|
|||
# Benchmark results
|
||||
# Benchmarks
|
||||
|
||||
Warning: These benchmark results are outdated. I will update them soon.
|
||||
## Photo
|
||||
|
||||
## Usage
|
||||
Note: waifu2x's photo models were trained on a blended dataset of [kou's photo collection](http://photosku.com/photo/category/%E6%92%AE%E5%BD%B1%E8%80%85/kou/) and [ukbench](http://vis.uky.edu/~stewe/ukbench/).
|
||||
|
||||
```
|
||||
th tools/benchmark.lua -dir path/to/dataset_dir -method scale -color y -model1_dir path/to/model_dir
|
||||
```
|
||||
Note: PSNR in this benchmark uses a [MATLAB's rgb2ycbcr](https://jp.mathworks.com/help/images/ref/rgb2ycbcr.html?lang=en) compatible function (dynamic range [16 235], not [0 255]) to convert images to grayscale. I don't think this is the correct PSNR, but many papers use this metric.
|
||||
|
||||
## Dataset
|
||||
command:
|
||||
`th tools/benchmark.lua -dir <dataset_dir> -model1_dir <model_dir> -method scale -filter Catrom -color y -range_bug 1 -tta <0|1> -force_cudnn 1`
|
||||
|
||||
photo_test: 300 various photos.
|
||||
art_test : 90 artworks (PNG only).
|
||||
### Datasets
|
||||
|
||||
## 2x upscaling model
|
||||
BSD100: https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/segbench/ (100 test images in BSDS300)
|
||||
Urban100: https://github.com/jbhuang0604/SelfExSR
|
||||
|
||||
| Dataset/Model | anime\_style\_art(Y) | anime\_style\_art\_rgb | photo | ukbench|
|
||||
|---------------|----------------------|------------------------|---------|--------|
|
||||
| photo\_test | 29.83 | 29.81 |**29.89**| 29.86 |
|
||||
| art\_test | 36.02 | **36.24**| 34.92 | 34.85 |
|
||||
### 2x - PSNR
|
||||
|
||||
The evaluation metric is PSNR(Y only), higher is better.
|
||||
| Dataset/Model | Bicubic | vgg\_7/photo | upconv\_7/photo | upconv\_7l/photo | resnet_14l/photo |
|
||||
|---------------|---------------|---------------|------------------|------------------|--------------------|
|
||||
| BSD100 | 29.558 | 31.427 | 31.640 | 31.749 | 31.847 |
|
||||
| Urban100 | 26.852 | 30.057 | 30.477 | 30.759 | 31.016 |
|
||||
|
||||
## Denoising level 1 model
|
||||
### 2x with TTA - PSNR
|
||||
|
||||
| Dataset/Model | anime\_style\_art | anime\_style\_art\_rgb | photo |
|
||||
|--------------------------|-------------------|------------------------|---------|
|
||||
| photo\_test Quality 80 | 36.07 | **36.20**| 36.01 |
|
||||
| photo\_test Quality 50,45| 31.72 | 32.01 |**32.31**|
|
||||
| art\_test Quality 80 | 40.39 | **42.48**| 40.35 |
|
||||
| art\_test Quality 50,45 | 35.45 | **36.70**| 36.27 |
|
||||
Note: TTA is an ensemble technique that is supported by waifu2x. The TTA method is 8x slower than the non-TTA method, but it improves PSNR (~+0.1 on photo, ~+0.4 on art).
|
||||
|
||||
The evaluation metric is PSNR(RGB), higher is better.
|
||||
| Dataset/Model | Bicubic | vgg\_7/photo | upconv\_7/photo | upconv\_7l/photo | resnet_14l/photo |
|
||||
|---------------|---------------|---------------|------------------|------------------|--------------------|
|
||||
| BSD100 | 29.558 | 31.474 | 31.705 | 31.812 | 31.915 |
|
||||
| Urban100 | 26.852 | 30.140 | 30.599 | 30.868 | 31.162 |
|
||||
|
||||
## Denoising level 2 model
|
||||
### 2x - benchmark elapsed time (sec)
|
||||
|
||||
| Dataset/Model | anime\_style\_art | anime\_style\_art\_rgb | photo |
|
||||
|--------------------------|-------------------|------------------------|---------|
|
||||
| photo\_test Quality 80 | 34.03 | 34.42 |**36.06**|
|
||||
| photo\_test Quality 50,45| 31.95 | 32.31 |**32.42**|
|
||||
| art\_test Quality 80 | 39.20 | **41.12**| 40.48 |
|
||||
| art\_test Quality 50,45 | 36.14 | **37.78**| 36.55 |
|
||||
| Dataset/Model | vgg\_7/photo | upconv\_7/photo | upconv\_7l/photo | resnet_14l/photo |
|
||||
|---------------|---------------|------------------|------------------|--------------------|
|
||||
| BSD100 | 4.057 | 2.509 | 4.947 | 6.86 |
|
||||
| Urban100 | 16.349 | 7.083 | 14.178 | 27.87 |
|
||||
|
||||
### 2x with TTA - benchmark elapsed time (sec)
|
||||
|
||||
| Dataset/Model | vgg\_7/photo | upconv\_7/photo | upconv\_7l/photo | resnet_14l/photo |
|
||||
|---------------|---------------|------------------|------------------|--------------------|
|
||||
| BSD100 | 36.611 | 20.219 | 42.486 | 60.38 |
|
||||
| Urban100 | 132.416 | 65.125 | 129.916 | 255.20 |
|
||||
|
||||
## Art
|
||||
|
||||
command:
|
||||
`th tools/benchmark.lua -dir <dataset_dir> -model1_dir <model_dir> -method scale -filter Lanczos -color y -range_bug 1 -tta <0|1> -force_cudnn 1`
|
||||
|
||||
### Dataset
|
||||
|
||||
art_test: This dataset contains 85 various fan-arts. Sorry, this dataset is private.
|
||||
|
||||
### 2x - PSNR
|
||||
|
||||
| Dataset/Model | Bicubic | vgg\_7/art | upconv\_7/art | upconv\_7l/art |
|
||||
|---------------|---------------|-------------|----------------|----------------|
|
||||
| art_test | 31.022 | 37.495 | 38.330 | 39.140 |
|
||||
|
||||
### 2x with TTA - PSNR
|
||||
|
||||
| Dataset/Model | Bicubic | vgg\_7/art | upconv\_7/art | upconv\_7l/art |
|
||||
|---------------|---------------|-------------|----------------|----------------|
|
||||
| art_test | 31.022 | 37.777 | 38.677 | 39.510 |
|
||||
|
||||
### 2x - benchmark elapsed time (sec)
|
||||
|
||||
| Dataset/Model | vgg\_7/art | upconv\_7/art | upconv\_7l/art |
|
||||
|---------------|-------------|----------------|----------------|
|
||||
| art_test | 20.681 | 7.683 | 17.667 |
|
||||
|
||||
### 2x with TTA - benchmark elapsed time (sec)
|
||||
|
||||
| Dataset/Model | vgg\_7/art | upconv\_7/art | upconv\_7l/art |
|
||||
|---------------|-------------|----------------|----------------|
|
||||
| art_test | 174.674 | 77.716 | 163.932 |
|
||||
|
||||
The evaluation metric is PSNR(RGB), higher is better.
|
||||
|
|
34
appendix/benchmark.sh
Executable file
34
appendix/benchmark.sh
Executable file
|
@ -0,0 +1,34 @@
|
|||
#!/bin/sh
|
||||
set -x
|
||||
|
||||
# Run the 2x photo benchmark for one dataset / model / TTA combination.
#   $1: dataset directory name under data/ (e.g. bsd100)
#   $2: model architecture directory under models/ (e.g. vgg_7)
#   $3: value passed to -tta (0 or 1)
# Output is written to ./benchmarks/<dataset>/<model>/<tta>.
benchmark_photo() {
    dir="./benchmarks/${1}/${2}/${3}"
    mkdir -p "${dir}"
    # Option names are spelled without the Markdown-style "\_" escapes the
    # original carried; the shell strips those backslashes anyway.
    th tools/benchmark.lua -dir "data/${1}" -model1_dir "models/${2}/photo" \
        -method scale -filter Catrom -color y -range_bug 1 -tta "${3}" \
        -force_cudnn 1 -output_dir "${dir}" -save_info 1 -show_progress 0
}
|
||||
# Benchmark each photo model on bsd100 and urban100, once with -tta 0 and
# once with -tta 1 (outer loop: tta, then dataset, then model).
run_benchmark_photo() {
    for tta in 0 1; do
        for dataset in bsd100 urban100; do
            for model in vgg_7 upconv_7 upconv_7l; do
                benchmark_photo ${dataset} ${model} ${tta}
            done
        done
    done
}
|
||||
# Run the 2x art benchmark for one dataset / model / TTA combination.
#   $1: dataset directory name under data/ (e.g. art_test)
#   $2: model architecture directory under models/ (e.g. vgg_7)
#   $3: value passed to -tta (0 or 1)
# Output is written to ./benchmarks/<dataset>/<model>/<tta>.
benchmark_art() {
    dir="./benchmarks/${1}/${2}/${3}"
    mkdir -p "${dir}"
    # Option names are spelled without the Markdown-style "\_" escapes the
    # original carried; the shell strips those backslashes anyway.
    th tools/benchmark.lua -dir "data/${1}" -model1_dir "models/${2}/art" \
        -method scale -filter Lanczos -color y -range_bug 1 -tta "${3}" \
        -force_cudnn 1 -output_dir "${dir}" -save_info 1 -show_progress 0
}
|
||||
# Benchmark each art model on art_test, once with -tta 0 and once with -tta 1.
run_benchmark_art() {
    for tta in 0 1; do
        for model in vgg_7 upconv_7 upconv_7l; do
            benchmark_art art_test ${model} ${tta}
        done
    done
}
|
||||
#run_benchmark_photo
|
||||
run_benchmark_art
|
524
appendix/caffe_prototxt/resnet_14l.prototxt
Normal file
524
appendix/caffe_prototxt/resnet_14l.prototxt
Normal file
|
@ -0,0 +1,524 @@
|
|||
name: "resnet_14l"
|
||||
layer {
|
||||
name: "input"
|
||||
type: "Input"
|
||||
top: "input"
|
||||
input_param { shape: { dim: 1 dim: 3 dim: 156 dim: 156 } }
|
||||
}
|
||||
layer {
|
||||
name: "Convolution1"
|
||||
type: "Convolution"
|
||||
bottom: "input"
|
||||
top: "Convolution1"
|
||||
convolution_param {
|
||||
num_output: 32
|
||||
bias_term: true
|
||||
pad: 0
|
||||
kernel_size: 3
|
||||
stride: 1
|
||||
weight_filler {
|
||||
type: "msra"
|
||||
}
|
||||
}
|
||||
}
|
||||
layer {
|
||||
name: "ReLU1"
|
||||
type: "ReLU"
|
||||
bottom: "Convolution1"
|
||||
top: "Convolution1"
|
||||
relu_param {
|
||||
negative_slope: 0.1
|
||||
}
|
||||
}
|
||||
layer {
|
||||
name: "Convolution2"
|
||||
type: "Convolution"
|
||||
bottom: "Convolution1"
|
||||
top: "Convolution2"
|
||||
convolution_param {
|
||||
num_output: 64
|
||||
bias_term: true
|
||||
pad: 0
|
||||
kernel_size: 3
|
||||
stride: 1
|
||||
weight_filler {
|
||||
type: "msra"
|
||||
}
|
||||
}
|
||||
}
|
||||
layer {
|
||||
name: "ReLU2"
|
||||
type: "ReLU"
|
||||
bottom: "Convolution2"
|
||||
top: "Convolution2"
|
||||
relu_param {
|
||||
negative_slope: 0.1
|
||||
}
|
||||
}
|
||||
layer {
|
||||
name: "Convolution3"
|
||||
type: "Convolution"
|
||||
bottom: "Convolution2"
|
||||
top: "Convolution3"
|
||||
convolution_param {
|
||||
num_output: 64
|
||||
bias_term: true
|
||||
pad: 0
|
||||
kernel_size: 3
|
||||
stride: 1
|
||||
weight_filler {
|
||||
type: "msra"
|
||||
}
|
||||
}
|
||||
}
|
||||
layer {
|
||||
name: "ReLU3"
|
||||
type: "ReLU"
|
||||
bottom: "Convolution3"
|
||||
top: "Convolution3"
|
||||
relu_param {
|
||||
negative_slope: 0.1
|
||||
}
|
||||
}
|
||||
layer {
|
||||
name: "Convolution4"
|
||||
type: "Convolution"
|
||||
bottom: "Convolution1"
|
||||
top: "Convolution4"
|
||||
convolution_param {
|
||||
num_output: 64
|
||||
bias_term: true
|
||||
pad: 0
|
||||
kernel_size: 1
|
||||
stride: 1
|
||||
weight_filler {
|
||||
type: "msra"
|
||||
}
|
||||
}
|
||||
}
|
||||
layer {
|
||||
name: "Crop1"
|
||||
type: "Crop"
|
||||
bottom: "Convolution4"
|
||||
bottom: "Convolution3"
|
||||
top: "Crop1"
|
||||
crop_param {
|
||||
axis: 2
|
||||
offset: 2
|
||||
offset: 2
|
||||
}
|
||||
}
|
||||
layer {
|
||||
name: "Eltwise1"
|
||||
type: "Eltwise"
|
||||
bottom: "Convolution3"
|
||||
bottom: "Crop1"
|
||||
top: "Eltwise1"
|
||||
eltwise_param {
|
||||
operation: SUM
|
||||
}
|
||||
}
|
||||
layer {
|
||||
name: "Convolution5"
|
||||
type: "Convolution"
|
||||
bottom: "Eltwise1"
|
||||
top: "Convolution5"
|
||||
convolution_param {
|
||||
num_output: 64
|
||||
bias_term: true
|
||||
pad: 0
|
||||
kernel_size: 3
|
||||
stride: 1
|
||||
weight_filler {
|
||||
type: "msra"
|
||||
}
|
||||
}
|
||||
}
|
||||
layer {
|
||||
name: "ReLU4"
|
||||
type: "ReLU"
|
||||
bottom: "Convolution5"
|
||||
top: "Convolution5"
|
||||
relu_param {
|
||||
negative_slope: 0.1
|
||||
}
|
||||
}
|
||||
layer {
|
||||
name: "Convolution6"
|
||||
type: "Convolution"
|
||||
bottom: "Convolution5"
|
||||
top: "Convolution6"
|
||||
convolution_param {
|
||||
num_output: 64
|
||||
bias_term: true
|
||||
pad: 0
|
||||
kernel_size: 3
|
||||
stride: 1
|
||||
weight_filler {
|
||||
type: "msra"
|
||||
}
|
||||
}
|
||||
}
|
||||
layer {
|
||||
name: "ReLU5"
|
||||
type: "ReLU"
|
||||
bottom: "Convolution6"
|
||||
top: "Convolution6"
|
||||
relu_param {
|
||||
negative_slope: 0.1
|
||||
}
|
||||
}
|
||||
layer {
|
||||
name: "Crop2"
|
||||
type: "Crop"
|
||||
bottom: "Eltwise1"
|
||||
bottom: "Convolution6"
|
||||
top: "Crop2"
|
||||
crop_param {
|
||||
axis: 2
|
||||
offset: 2
|
||||
offset: 2
|
||||
}
|
||||
}
|
||||
layer {
|
||||
name: "Eltwise2"
|
||||
type: "Eltwise"
|
||||
bottom: "Convolution6"
|
||||
bottom: "Crop2"
|
||||
top: "Eltwise2"
|
||||
eltwise_param {
|
||||
operation: SUM
|
||||
}
|
||||
}
|
||||
layer {
|
||||
name: "Convolution7"
|
||||
type: "Convolution"
|
||||
bottom: "Eltwise2"
|
||||
top: "Convolution7"
|
||||
convolution_param {
|
||||
num_output: 128
|
||||
bias_term: true
|
||||
pad: 0
|
||||
kernel_size: 3
|
||||
stride: 1
|
||||
weight_filler {
|
||||
type: "msra"
|
||||
}
|
||||
}
|
||||
}
|
||||
layer {
|
||||
name: "ReLU6"
|
||||
type: "ReLU"
|
||||
bottom: "Convolution7"
|
||||
top: "Convolution7"
|
||||
relu_param {
|
||||
negative_slope: 0.1
|
||||
}
|
||||
}
|
||||
layer {
|
||||
name: "Convolution8"
|
||||
type: "Convolution"
|
||||
bottom: "Convolution7"
|
||||
top: "Convolution8"
|
||||
convolution_param {
|
||||
num_output: 128
|
||||
bias_term: true
|
||||
pad: 0
|
||||
kernel_size: 3
|
||||
stride: 1
|
||||
weight_filler {
|
||||
type: "msra"
|
||||
}
|
||||
}
|
||||
}
|
||||
layer {
|
||||
name: "ReLU7"
|
||||
type: "ReLU"
|
||||
bottom: "Convolution8"
|
||||
top: "Convolution8"
|
||||
relu_param {
|
||||
negative_slope: 0.1
|
||||
}
|
||||
}
|
||||
layer {
|
||||
name: "Convolution9"
|
||||
type: "Convolution"
|
||||
bottom: "Eltwise2"
|
||||
top: "Convolution9"
|
||||
convolution_param {
|
||||
num_output: 128
|
||||
bias_term: true
|
||||
pad: 0
|
||||
kernel_size: 1
|
||||
stride: 1
|
||||
weight_filler {
|
||||
type: "msra"
|
||||
}
|
||||
}
|
||||
}
|
||||
layer {
|
||||
name: "Crop3"
|
||||
type: "Crop"
|
||||
bottom: "Convolution9"
|
||||
bottom: "Convolution8"
|
||||
top: "Crop3"
|
||||
crop_param {
|
||||
axis: 2
|
||||
offset: 2
|
||||
offset: 2
|
||||
}
|
||||
}
|
||||
layer {
|
||||
name: "Eltwise3"
|
||||
type: "Eltwise"
|
||||
bottom: "Convolution8"
|
||||
bottom: "Crop3"
|
||||
top: "Eltwise3"
|
||||
eltwise_param {
|
||||
operation: SUM
|
||||
}
|
||||
}
|
||||
layer {
|
||||
name: "Convolution10"
|
||||
type: "Convolution"
|
||||
bottom: "Eltwise3"
|
||||
top: "Convolution10"
|
||||
convolution_param {
|
||||
num_output: 128
|
||||
bias_term: true
|
||||
pad: 0
|
||||
kernel_size: 3
|
||||
stride: 1
|
||||
weight_filler {
|
||||
type: "msra"
|
||||
}
|
||||
}
|
||||
}
|
||||
layer {
|
||||
name: "ReLU8"
|
||||
type: "ReLU"
|
||||
bottom: "Convolution10"
|
||||
top: "Convolution10"
|
||||
relu_param {
|
||||
negative_slope: 0.1
|
||||
}
|
||||
}
|
||||
layer {
|
||||
name: "Convolution11"
|
||||
type: "Convolution"
|
||||
bottom: "Convolution10"
|
||||
top: "Convolution11"
|
||||
convolution_param {
|
||||
num_output: 128
|
||||
bias_term: true
|
||||
pad: 0
|
||||
kernel_size: 3
|
||||
stride: 1
|
||||
weight_filler {
|
||||
type: "msra"
|
||||
}
|
||||
}
|
||||
}
|
||||
layer {
|
||||
name: "ReLU9"
|
||||
type: "ReLU"
|
||||
bottom: "Convolution11"
|
||||
top: "Convolution11"
|
||||
relu_param {
|
||||
negative_slope: 0.1
|
||||
}
|
||||
}
|
||||
layer {
|
||||
name: "Crop4"
|
||||
type: "Crop"
|
||||
bottom: "Eltwise3"
|
||||
bottom: "Convolution11"
|
||||
top: "Crop4"
|
||||
crop_param {
|
||||
axis: 2
|
||||
offset: 2
|
||||
offset: 2
|
||||
}
|
||||
}
|
||||
layer {
|
||||
name: "Eltwise4"
|
||||
type: "Eltwise"
|
||||
bottom: "Convolution11"
|
||||
bottom: "Crop4"
|
||||
top: "Eltwise4"
|
||||
eltwise_param {
|
||||
operation: SUM
|
||||
}
|
||||
}
|
||||
layer {
|
||||
name: "Convolution12"
|
||||
type: "Convolution"
|
||||
bottom: "Eltwise4"
|
||||
top: "Convolution12"
|
||||
convolution_param {
|
||||
num_output: 256
|
||||
bias_term: true
|
||||
pad: 0
|
||||
kernel_size: 3
|
||||
stride: 1
|
||||
weight_filler {
|
||||
type: "msra"
|
||||
}
|
||||
}
|
||||
}
|
||||
layer {
|
||||
name: "ReLU10"
|
||||
type: "ReLU"
|
||||
bottom: "Convolution12"
|
||||
top: "Convolution12"
|
||||
relu_param {
|
||||
negative_slope: 0.1
|
||||
}
|
||||
}
|
||||
layer {
|
||||
name: "Convolution13"
|
||||
type: "Convolution"
|
||||
bottom: "Convolution12"
|
||||
top: "Convolution13"
|
||||
convolution_param {
|
||||
num_output: 256
|
||||
bias_term: true
|
||||
pad: 0
|
||||
kernel_size: 3
|
||||
stride: 1
|
||||
weight_filler {
|
||||
type: "msra"
|
||||
}
|
||||
}
|
||||
}
|
||||
layer {
|
||||
name: "ReLU11"
|
||||
type: "ReLU"
|
||||
bottom: "Convolution13"
|
||||
top: "Convolution13"
|
||||
relu_param {
|
||||
negative_slope: 0.1
|
||||
}
|
||||
}
|
||||
layer {
|
||||
name: "Convolution14"
|
||||
type: "Convolution"
|
||||
bottom: "Eltwise4"
|
||||
top: "Convolution14"
|
||||
convolution_param {
|
||||
num_output: 256
|
||||
bias_term: true
|
||||
pad: 0
|
||||
kernel_size: 1
|
||||
stride: 1
|
||||
weight_filler {
|
||||
type: "msra"
|
||||
}
|
||||
}
|
||||
}
|
||||
layer {
|
||||
name: "Crop5"
|
||||
type: "Crop"
|
||||
bottom: "Convolution14"
|
||||
bottom: "Convolution13"
|
||||
top: "Crop5"
|
||||
crop_param {
|
||||
axis: 2
|
||||
offset: 2
|
||||
offset: 2
|
||||
}
|
||||
}
|
||||
layer {
|
||||
name: "Eltwise5"
|
||||
type: "Eltwise"
|
||||
bottom: "Convolution13"
|
||||
bottom: "Crop5"
|
||||
top: "Eltwise5"
|
||||
eltwise_param {
|
||||
operation: SUM
|
||||
}
|
||||
}
|
||||
layer {
|
||||
name: "Convolution15"
|
||||
type: "Convolution"
|
||||
bottom: "Eltwise5"
|
||||
top: "Convolution15"
|
||||
convolution_param {
|
||||
num_output: 256
|
||||
bias_term: true
|
||||
pad: 0
|
||||
kernel_size: 3
|
||||
stride: 1
|
||||
weight_filler {
|
||||
type: "msra"
|
||||
}
|
||||
}
|
||||
}
|
||||
layer {
|
||||
name: "ReLU12"
|
||||
type: "ReLU"
|
||||
bottom: "Convolution15"
|
||||
top: "Convolution15"
|
||||
relu_param {
|
||||
negative_slope: 0.1
|
||||
}
|
||||
}
|
||||
layer {
|
||||
name: "Convolution16"
|
||||
type: "Convolution"
|
||||
bottom: "Convolution15"
|
||||
top: "Convolution16"
|
||||
convolution_param {
|
||||
num_output: 256
|
||||
bias_term: true
|
||||
pad: 0
|
||||
kernel_size: 3
|
||||
stride: 1
|
||||
weight_filler {
|
||||
type: "msra"
|
||||
}
|
||||
}
|
||||
}
|
||||
layer {
|
||||
name: "ReLU13"
|
||||
type: "ReLU"
|
||||
bottom: "Convolution16"
|
||||
top: "Convolution16"
|
||||
relu_param {
|
||||
negative_slope: 0.1
|
||||
}
|
||||
}
|
||||
layer {
|
||||
name: "Crop6"
|
||||
type: "Crop"
|
||||
bottom: "Eltwise5"
|
||||
bottom: "Convolution16"
|
||||
top: "Crop6"
|
||||
crop_param {
|
||||
axis: 2
|
||||
offset: 2
|
||||
offset: 2
|
||||
}
|
||||
}
|
||||
layer {
|
||||
name: "Eltwise6"
|
||||
type: "Eltwise"
|
||||
bottom: "Convolution16"
|
||||
bottom: "Crop6"
|
||||
top: "Eltwise6"
|
||||
eltwise_param {
|
||||
operation: SUM
|
||||
}
|
||||
}
|
||||
layer {
|
||||
name: "Deconvolution1"
|
||||
type: "Deconvolution"
|
||||
bottom: "Eltwise6"
|
||||
top: "Deconvolution1"
|
||||
convolution_param {
|
||||
num_output: 3
|
||||
pad: 3
|
||||
kernel_size: 4
|
||||
stride: 2
|
||||
}
|
||||
}
|
|
@ -82,46 +82,50 @@ local function load_images(list)
|
|||
local skip = false
|
||||
local alpha_color = torch.random(0, 1)
|
||||
|
||||
if meta and meta.alpha then
|
||||
if settings.use_transparent_png then
|
||||
im = alpha_util.fill(im, meta.alpha, alpha_color)
|
||||
else
|
||||
skip = true
|
||||
end
|
||||
end
|
||||
if skip then
|
||||
if not skip_notice then
|
||||
io.stderr:write("skip transparent png (settings.use_transparent_png=0)\n")
|
||||
skip_notice = true
|
||||
end
|
||||
else
|
||||
if csv_meta and csv_meta.x then
|
||||
-- method == user
|
||||
local yy = im
|
||||
local xx, meta2 = image_loader.load_byte(csv_meta.x)
|
||||
if meta2 and meta2.alpha then
|
||||
xx = alpha_util.fill(xx, meta2.alpha, alpha_color)
|
||||
if im then
|
||||
if meta and meta.alpha then
|
||||
if settings.use_transparent_png then
|
||||
im = alpha_util.fill(im, meta.alpha, alpha_color)
|
||||
else
|
||||
skip = true
|
||||
end
|
||||
xx, yy = crop_if_large_pair(xx, yy, settings.max_training_image_size)
|
||||
table.insert(x, {{y = compression.compress(yy), x = compression.compress(xx)},
|
||||
{data = {filters = filters, has_x = true}}})
|
||||
else
|
||||
im = crop_if_large(im, settings.max_training_image_size)
|
||||
im = iproc.crop_mod4(im)
|
||||
local scale = 1.0
|
||||
if settings.random_half_rate > 0.0 then
|
||||
scale = 2.0
|
||||
end
|
||||
if skip then
|
||||
if not skip_notice then
|
||||
io.stderr:write("skip transparent png (settings.use_transparent_png=0)\n")
|
||||
skip_notice = true
|
||||
end
|
||||
if im then
|
||||
else
|
||||
if csv_meta and csv_meta.x then
|
||||
-- method == user
|
||||
local yy = im
|
||||
local xx, meta2 = image_loader.load_byte(csv_meta.x)
|
||||
if xx then
|
||||
if meta2 and meta2.alpha then
|
||||
xx = alpha_util.fill(xx, meta2.alpha, alpha_color)
|
||||
end
|
||||
xx, yy = crop_if_large_pair(xx, yy, settings.max_training_image_size)
|
||||
table.insert(x, {{y = compression.compress(yy), x = compression.compress(xx)},
|
||||
{data = {filters = filters, has_x = true}}})
|
||||
else
|
||||
io.stderr:write(string.format("\n%s: skip: load error.\n", csv_meta.x))
|
||||
end
|
||||
else
|
||||
im = crop_if_large(im, settings.max_training_image_size)
|
||||
im = iproc.crop_mod4(im)
|
||||
local scale = 1.0
|
||||
if settings.random_half_rate > 0.0 then
|
||||
scale = 2.0
|
||||
end
|
||||
if im:size(2) > (settings.crop_size * scale + MARGIN) and im:size(3) > (settings.crop_size * scale + MARGIN) then
|
||||
table.insert(x, {compression.compress(im), {data = {filters = filters}}})
|
||||
else
|
||||
io.stderr:write(string.format("\n%s: skip: image is too small (%d > size).\n", filename, settings.crop_size * scale + MARGIN))
|
||||
end
|
||||
else
|
||||
io.stderr:write(string.format("\n%s: skip: load error.\n", filename))
|
||||
end
|
||||
end
|
||||
else
|
||||
io.stderr:write(string.format("\n%s: skip: load error.\n", filename))
|
||||
end
|
||||
xlua.progress(i, #csv)
|
||||
if i % 10 == 0 then
|
||||
|
|
|
@ -5,12 +5,14 @@ function ClippedMSECriterion:__init(min, max)
|
|||
self.min = min
|
||||
self.max = max
|
||||
self.diff = torch.Tensor()
|
||||
self.diff_pow2 = torch.Tensor()
|
||||
end
|
||||
function ClippedMSECriterion:updateOutput(input, target)
|
||||
self.diff:resizeAs(input):copy(input)
|
||||
self.diff:clamp(self.min, self.max)
|
||||
self.diff:add(-1, target)
|
||||
self.output = self.diff:pow(2):sum() / input:nElement()
|
||||
self.diff_pow2:resizeAs(self.diff):copy(self.diff):pow(2)
|
||||
self.output = self.diff_pow2:sum() / input:nElement()
|
||||
return self.output
|
||||
end
|
||||
function ClippedMSECriterion:updateGradInput(input, target)
|
||||
|
|
13
lib/InplaceClip01.lua
Normal file
13
lib/InplaceClip01.lua
Normal file
|
@ -0,0 +1,13 @@
|
|||
-- w2nn.InplaceClip01: module that clamps its input to the range [0, 1].
local Clip01, parent = torch.class("w2nn.InplaceClip01", "nn.Module")

function Clip01:__init()
   parent.__init(self)
end

-- Forward pass. clamp(0, 1) is called on `input` itself (no clone is made),
-- and self.output is then set from the result.
function Clip01:updateOutput(input)
   local clipped = input:clamp(0, 1)
   self.output:set(clipped)
   return self.output
end

-- Backward pass: gradInput is set directly from gradOutput; no masking of
-- the clipped positions is applied here.
function Clip01:updateGradInput(input, gradOutput)
   local grad = self.gradInput
   grad:set(gradOutput)
   return grad
end
|
27
lib/L1Criterion.lua
Normal file
27
lib/L1Criterion.lua
Normal file
|
@ -0,0 +1,27 @@
|
|||
-- ref: https://en.wikipedia.org/wiki/L1_loss
-- Mean absolute error criterion:
--   output = sum(|input - target|) / input:nElement()
local L1Criterion, parent = torch.class('w2nn.L1Criterion','nn.Criterion')

function L1Criterion:__init()
   parent.__init(self)
   -- scratch buffers reused between calls
   self.diff = torch.Tensor()
   self.linear_loss_buff = torch.Tensor()
end

-- Computes the mean absolute difference between `input` and `target`.
-- For 1-D input, `target` may be a plain number or a tensor with the same
-- number of elements. Otherwise `target` is indexed along the first
-- dimension, one entry per entry of `input`.
-- self.diff keeps (input - target) for use by updateGradInput.
function L1Criterion:updateOutput(input, target)
   self.diff:resizeAs(input):copy(input)
   if input:dim() == 1 then
      -- Subtract the target from every element. The previous code only
      -- updated element 1, leaving raw input values in the rest of diff.
      if type(target) == "number" then
         self.diff:add(-target)
      else
         self.diff:add(-1, target)
      end
   else
      for i = 1, input:size(1) do
         self.diff[i]:add(-1, target[i])
      end
   end
   local linear_targets = self.diff
   -- abs() is applied to a copy so the signed differences stay in self.diff
   local linear_loss = self.linear_loss_buff:resizeAs(linear_targets):copy(linear_targets):abs():sum()
   self.output = (linear_loss) / input:nElement()
   return self.output
end

-- Gradient of the loss w.r.t. input: sign(input - target) / nElement.
-- Reads self.diff as left by the preceding updateOutput call.
function L1Criterion:updateGradInput(input, target)
   local norm = 1.0 / input:nElement()
   self.gradInput:resizeAs(self.diff):copy(self.diff):sign():mul(norm)
   return self.gradInput
end
|
67
lib/SSIMCriterion.lua
Normal file
67
lib/SSIMCriterion.lua
Normal file
|
@ -0,0 +1,67 @@
|
|||
-- SSIM Index, ref: http://www.cns.nyu.edu/~lcv/ssim/ssim_index.m
-- Evaluation-only criterion: updateOutput returns the mean of the SSIM map
-- computed with a Gaussian window; updateGradInput is not implemented.
local SSIMCriterion, parent = torch.class('w2nn.SSIMCriterion','nn.Criterion')

-- ch:          number of image channels (default 1)
-- kernel_size: side length of the Gaussian window (default 11)
-- sigma:       standard deviation of the Gaussian window (default 1.5)
function SSIMCriterion:__init(ch, kernel_size, sigma)
   parent.__init(self)
   -- Builds a kernel_size x kernel_size Gaussian window normalized to sum 1.
   local function gaussian2d(kernel_size, sigma)
      sigma = sigma or 1
      local kernel = torch.Tensor(kernel_size, kernel_size)
      local u = math.floor(kernel_size / 2) + 1
      local amp = (1 / math.sqrt(2 * math.pi * sigma^2))
      for x = 1, kernel_size do
         for y = 1, kernel_size do
            kernel[x][y] = amp * math.exp(-((x - u)^2 + (y - u)^2) / (2 * sigma^2))
         end
      end
      kernel:div(kernel:sum())
      return kernel
   end
   ch = ch or 1
   kernel_size = kernel_size or 11
   sigma = sigma or 1.5
   local kernel = gaussian2d(kernel_size, sigma)
   if ch > 1 then
      -- One window per channel on the diagonal, zeros elsewhere, so that
      -- channels are filtered independently by a single convolution.
      local kernel_nd = torch.Tensor(ch, ch, kernel_size, kernel_size)
      for i = 1, ch do
         for j = 1, ch do
            kernel_nd[i][j]:copy(kernel)
            if i ~= j then
               kernel_nd[i][j]:zero()
            end
         end
      end
      kernel = kernel_nd
   end
   self.c1 = 0.01^2
   self.c2 = 0.03^2
   self.ch = ch
   self.conv = nn.SpatialConvolution(ch, ch, kernel_size, kernel_size, 1, 1, 0, 0):noBias()
   self.conv.weight:copy(kernel)
   -- scratch buffers reused between calls
   self.mu1 = torch.Tensor()
   self.mu2 = torch.Tensor()
   self.mu1_sq = torch.Tensor()
   self.mu2_sq = torch.Tensor()
   self.mu1_mu2 = torch.Tensor()
   self.sigma1_sq = torch.Tensor()
   self.sigma2_sq = torch.Tensor()
   self.sigma12 = torch.Tensor()
   self.ssim_map = torch.Tensor()
end

-- Returns the mean SSIM between `input` and `target` (higher means more
-- similar). The result is also stored in self.output.
function SSIMCriterion:updateOutput(input, target)-- dynamic range: 0-1
   assert(input:nElement() == target:nElement())
   -- self.conv reuses its output buffer, so each result is copied out
   -- before the next forward call.
   local valid = self.conv:forward(input)
   self.mu1:resizeAs(valid):copy(valid)
   self.mu2:resizeAs(valid):copy(self.conv:forward(target))
   self.mu1_sq:resizeAs(self.mu1):copy(self.mu1):cmul(self.mu1)
   self.mu2_sq:resizeAs(self.mu2):copy(self.mu2):cmul(self.mu2)
   self.mu1_mu2:resizeAs(self.mu1):copy(self.mu1):cmul(self.mu2)
   self.sigma1_sq:resizeAs(valid):copy(self.conv:forward(torch.cmul(input, input)):add(-1, self.mu1_sq))
   self.sigma2_sq:resizeAs(valid):copy(self.conv:forward(torch.cmul(target, target)):add(-1, self.mu2_sq))
   self.sigma12:resizeAs(valid):copy(self.conv:forward(torch.cmul(input, target)):add(-1, self.mu1_mu2))

   -- ((2*mu1*mu2 + c1) * (2*sigma12 + c2)) /
   -- ((mu1^2 + mu2^2 + c1) * (sigma1^2 + sigma2^2 + c2)), computed in place
   -- on the scratch buffers filled above.
   local ssim = self.mu1_mu2:mul(2):add(self.c1):cmul(self.sigma12:mul(2):add(self.c2)):
      cdiv(self.mu1_sq:add(self.mu2_sq):add(self.c1):cmul(self.sigma1_sq:add(self.sigma2_sq):add(self.c2))):mean()
   -- Store the result so self.output reflects the last evaluation.
   self.output = ssim
   return self.output
end

function SSIMCriterion:updateGradInput(input, target)
   error("not implemented")
end
|
|
@ -40,8 +40,7 @@ function alpha_util.make_border(rgb, alpha, offset)
|
|||
collectgarbage()
|
||||
end
|
||||
end
|
||||
rgb[torch.gt(rgb, 1.0)] = 1.0
|
||||
rgb[torch.lt(rgb, 0.0)] = 0.0
|
||||
rgb:clamp(0.0, 1.0)
|
||||
|
||||
return rgb
|
||||
end
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
require 'image'
|
||||
require 'pl'
|
||||
require 'cunn'
|
||||
local iproc = require 'iproc'
|
||||
local gm = require 'graphicsmagick'
|
||||
|
||||
local gm = {}
|
||||
gm.Image = require 'graphicsmagick.Image'
|
||||
local data_augmentation = {}
|
||||
|
||||
local function pcacov(x)
|
||||
|
@ -25,8 +26,7 @@ function data_augmentation.color_noise(src, p, factor)
|
|||
pca_space[i]:mul(color_scale[i])
|
||||
end
|
||||
local dest = torch.mm(pca_space:t(), cv:t()):t():contiguous():resizeAs(src)
|
||||
dest[torch.lt(dest, 0.0)] = 0.0
|
||||
dest[torch.gt(dest, 1.0)] = 1.0
|
||||
dest:clamp(0.0, 1.0)
|
||||
|
||||
if conversion then
|
||||
dest = iproc.float2byte(dest)
|
||||
|
@ -70,6 +70,75 @@ function data_augmentation.unsharp_mask(src, p)
|
|||
return src
|
||||
end
|
||||
end
|
||||
-- Randomly applies a Gaussian blur to `src`.
--   p:         probability of blurring; otherwise `src` is returned as-is
--   size:      comma-separated list of candidate kernel sizes (default "3")
--   sigma_min, sigma_max: sigma is sigma_min when both are equal, otherwise
--              drawn with torch.uniform(sigma_min, sigma_max)
-- The result is converted back with iproc.float2byte when iproc.byte2float
-- reported a conversion.
function data_augmentation.blur(src, p, size, sigma_min, sigma_max)
   size = size or "3"
   -- `local` is required here: the bare assignment leaked a global `filters`.
   local filters = utils.split(size, ",")
   for i = 1, #filters do
      filters[i] = tonumber(filters[i])
   end
   if torch.uniform() < p then
      local float_src, conversion = iproc.byte2float(src)
      local kernel_size = filters[torch.random(1, #filters)]
      local sigma
      if sigma_min == sigma_max then
         sigma = sigma_min
      else
         sigma = torch.uniform(sigma_min, sigma_max)
      end
      local kernel = iproc.gaussian2d(kernel_size, sigma)
      local dest = image.convolve(float_src, kernel, 'same')
      if conversion then
         dest = iproc.float2byte(dest)
      end
      return dest
   else
      return src
   end
end
|
||||
-- With probability `p`, rescales x and y by one shared random factor drawn
-- from torch.uniform(scale_min, scale_max), using the "Triangle" filter.
-- Otherwise returns x and y unchanged.
function data_augmentation.pairwise_scale(x, y, p, scale_min, scale_max)
   if not (torch.uniform() < p) then
      return x, y
   end
   assert(x:size(2) == y:size(2) and x:size(3) == y:size(3))
   local factor = torch.uniform(scale_min, scale_max)
   local new_h = math.floor(x:size(2) * factor)
   local new_w = math.floor(x:size(3) * factor)
   local scaled_x = iproc.scale(x, new_w, new_h, "Triangle")
   local scaled_y = iproc.scale(y, new_w, new_h, "Triangle")
   return scaled_x, scaled_y
end
|
||||
-- With probability `p`, rotates x and y by one shared random angle derived
-- from torch.uniform(r_min, r_max). Otherwise returns x and y unchanged.
function data_augmentation.pairwise_rotate(x, y, p, r_min, r_max)
   if not (torch.uniform() < p) then
      return x, y
   end
   assert(x:size(2) == y:size(2) and x:size(3) == y:size(3))
   -- NOTE(review): if r_min/r_max are degrees and iproc.rotate expects
   -- radians, "/ 360.0 * math.pi" yields half the requested angle
   -- (degrees -> radians is / 180.0 * math.pi). Confirm the intent.
   local angle = torch.uniform(r_min, r_max) / 360.0 * math.pi
   local rotated_x = iproc.rotate(x, angle)
   local rotated_y = iproc.rotate(y, angle)
   return rotated_x, rotated_y
end
|
||||
-- With probability `p`, applies iproc.negate to both x and y.
-- Otherwise returns x and y unchanged.
function data_augmentation.pairwise_negate(x, y, p)
   if not (torch.uniform() < p) then
      return x, y
   end
   assert(x:size(2) == y:size(2) and x:size(3) == y:size(3))
   local negated_x = iproc.negate(x)
   local negated_y = iproc.negate(y)
   return negated_x, negated_y
end
|
||||
-- With probability `p`, applies iproc.negate to x only; y is always
-- returned as given.
function data_augmentation.pairwise_negate_x(x, y, p)
   if not (torch.uniform() < p) then
      return x, y
   end
   assert(x:size(2) == y:size(2) and x:size(3) == y:size(3))
   local negated_x = iproc.negate(x)
   return negated_x, y
end
|
||||
function data_augmentation.shift_1px(src)
|
||||
-- reducing the even/odd issue in nearest neighbor scaler.
|
||||
local direction = torch.random(1, 4)
|
||||
|
@ -107,11 +176,11 @@ function data_augmentation.flip(src)
|
|||
src = src:transpose(2, 3):contiguous()
|
||||
end
|
||||
if flip == 1 then
|
||||
dest = image.hflip(src)
|
||||
dest = iproc.hflip(src)
|
||||
elseif flip == 2 then
|
||||
dest = image.vflip(src)
|
||||
dest = iproc.vflip(src)
|
||||
elseif flip == 3 then
|
||||
dest = image.hflip(image.vflip(src))
|
||||
dest = iproc.hflip(iproc.vflip(src))
|
||||
elseif flip == 4 then
|
||||
dest = src
|
||||
end
|
||||
|
@ -120,4 +189,20 @@ function data_augmentation.flip(src)
|
|||
end
|
||||
return dest
|
||||
end
|
||||
|
||||
-- Manual visual check: displays the lena test image followed by three
-- blurred variants produced by data_augmentation.blur.
local function test_blur()
   torch.setdefaulttensortype("torch.FloatTensor")
   local image =require 'image'
   local src = image.lena()

   local function show(img)
      image.display({image = img, min=0, max=1})
   end

   -- blur settings, shown in this order
   local cases = {
      {size = "3,5", sigma_min = 0.5, sigma_max = 0.6},
      {size = "3", sigma_min = 1.0, sigma_max = 1.0},
      {size = "5", sigma_min = 0.75, sigma_max = 0.75},
   }

   show(src)
   for i = 1, #cases do
      local c = cases[i]
      local dest = data_augmentation.blur(src, 1.0, c.size, c.sigma_min, c.sigma_max)
      show(dest)
   end
end
|
||||
--test_blur()
|
||||
|
||||
return data_augmentation
|
||||
|
|
|
@ -22,8 +22,7 @@ function image_loader.encode_png(rgb, options)
|
|||
else
|
||||
rgb = rgb:clone():add(clip_eps8)
|
||||
end
|
||||
rgb[torch.lt(rgb, 0.0)] = 0.0
|
||||
rgb[torch.gt(rgb, 1.0)] = 1.0
|
||||
rgb:clamp(0.0, 1.0)
|
||||
rgb = rgb:mul(255):floor():div(255)
|
||||
else
|
||||
if options.inplace then
|
||||
|
@ -31,8 +30,7 @@ function image_loader.encode_png(rgb, options)
|
|||
else
|
||||
rgb = rgb:clone():add(clip_eps16)
|
||||
end
|
||||
rgb[torch.lt(rgb, 0.0)] = 0.0
|
||||
rgb[torch.gt(rgb, 1.0)] = 1.0
|
||||
rgb:clamp(0.0, 1.0)
|
||||
rgb = rgb:mul(65535):floor():div(65535)
|
||||
end
|
||||
local im
|
||||
|
|
148
lib/iproc.lua
148
lib/iproc.lua
|
@ -1,6 +1,7 @@
|
|||
local gm = require 'graphicsmagick'
|
||||
local gm = {}
|
||||
gm.Image = require 'graphicsmagick.Image'
|
||||
require 'dok'
|
||||
local image = require 'image'
|
||||
|
||||
local iproc = {}
|
||||
local clip_eps8 = (1.0 / 255.0) * 0.5 - (1.0e-7 * (1.0 / 255.0) * 0.5)
|
||||
|
||||
|
@ -42,8 +43,7 @@ function iproc.float2byte(src)
|
|||
if src:type() == "torch.FloatTensor" then
|
||||
conversion = true
|
||||
dest = (src + clip_eps8):mul(255.0)
|
||||
dest[torch.lt(dest, 0.0)] = 0
|
||||
dest[torch.gt(dest, 255.0)] = 255.0
|
||||
dest:clamp(0, 255.0)
|
||||
dest = dest:byte()
|
||||
end
|
||||
return dest, conversion
|
||||
|
@ -80,6 +80,7 @@ function iproc.scale_with_gamma22(src, width, height, filter, blur)
|
|||
return dest
|
||||
end
|
||||
function iproc.padding(img, w1, w2, h1, h2)
|
||||
image = image or require 'image'
|
||||
local dst_height = img:size(2) + h1 + h2
|
||||
local dst_width = img:size(3) + w1 + w2
|
||||
local flow = torch.Tensor(2, dst_height, dst_width)
|
||||
|
@ -90,6 +91,7 @@ function iproc.padding(img, w1, w2, h1, h2)
|
|||
return image.warp(img, flow, "simple", false, "clamp")
|
||||
end
|
||||
function iproc.zero_padding(img, w1, w2, h1, h2)
|
||||
image = image or require 'image'
|
||||
local dst_height = img:size(2) + h1 + h2
|
||||
local dst_width = img:size(3) + w1 + w2
|
||||
local flow = torch.Tensor(2, dst_height, dst_width)
|
||||
|
@ -115,8 +117,7 @@ function iproc.white_noise(src, std, rgb_weights, gamma)
|
|||
local dest
|
||||
if gamma ~= 0 then
|
||||
dest = src:clone():pow(gamma):add(noise)
|
||||
dest[torch.lt(dest, 0.0)] = 0.0
|
||||
dest[torch.gt(dest, 1.0)] = 1.0
|
||||
dest:clamp(0.0, 1.0)
|
||||
dest:pow(1.0 / gamma)
|
||||
else
|
||||
dest = src + noise
|
||||
|
@ -126,6 +127,101 @@ function iproc.white_noise(src, std, rgb_weights, gamma)
|
|||
end
|
||||
return dest
|
||||
end
|
||||
-- Flip an image with GraphicsMagick's flop().
-- src: DHW tensor. Three planes are treated as "RGB", anything else as "I".
-- Returns a new tensor: byte when src is a ByteTensor, float otherwise.
function iproc.hflip(src)
   local t
   if src:type() == "torch.ByteTensor" then
      t = "byte"
   else
      t = "float"
   end
   -- Declared local: this used to be assigned as an accidental global.
   local color
   if src:size(1) == 3 then
      color = "RGB"
   else
      color = "I"
   end
   local im = gm.Image(src, color, "DHW")
   return im:flop():toTensor(t, color, "DHW")
end
|
||||
-- Flip an image with GraphicsMagick's flip().
-- src: DHW tensor. Three planes are treated as "RGB", anything else as "I".
-- Returns a new tensor: byte when src is a ByteTensor, float otherwise.
function iproc.vflip(src)
   local t
   if src:type() == "torch.ByteTensor" then
      t = "byte"
   else
      t = "float"
   end
   -- Declared local: this used to be assigned as an accidental global.
   local color
   if src:size(1) == 3 then
      color = "RGB"
   else
      color = "I"
   end
   local im = gm.Image(src, color, "DHW")
   return im:flip():toTensor(t, color, "DHW")
end
|
||||
-- Rotate `src` by `theta` (radians, as consumed by math.cos/math.sin) into `dst`
-- by building a flow field and handing it to image.warp with 'clamp' borders.
-- src:  2D (H x W) or 3D (C x H x W) tensor; anything else raises via dok.error.
-- dst:  tensor that is resized to match src and passed to image.warp as output.
-- mode: interpolation mode string forwarded to image.warp.
local function rotate_with_warp(src, dst, theta, mode)
   local ndim = src:dim()
   local height, width
   if ndim == 2 then
      height, width = src:size(1), src:size(2)
   elseif ndim == 3 then
      height, width = src:size(2), src:size(3)
   else
      dok.error('src image must be 2D or 3D', 'image.rotate')
   end

   -- 2x2 rotation matrix for the angle -theta.
   local cos_t = math.cos(-theta)
   local sin_t = math.sin(-theta)
   local kernel = torch.Tensor({{cos_t, -sin_t},
                                {sin_t, cos_t}})

   local flow = torch.Tensor(2, height, width)
   -- Plane 1 holds centered row coordinates, plane 2 centered column coordinates.
   flow[1] = torch.ger(torch.linspace(0, 1, height), torch.ones(width))
   flow[1]:mul(-(height - 1)):add(math.floor(height / 2 + 0.5))
   flow[2] = torch.ger(torch.ones(height), torch.linspace(0, 1, width))
   flow[2]:mul(-(width - 1)):add(math.floor(width / 2 + 0.5))

   -- flow <- flow - R * flow, i.e. the displacement produced by the rotation.
   local rotated = torch.mm(kernel, flow:view(2, height * width))
   flow:add(-1, rotated)

   dst:resizeAs(src)
   return image.warp(dst, src, flow, mode, true, 'clamp')
end
|
||||
-- Rotate an image by `theta` using bilinear warping.
-- The input goes through iproc.byte2float first; when that reports a
-- conversion, the clamped result is converted back with iproc.float2byte.
function iproc.rotate(src, theta)
   local float_src, converted = iproc.byte2float(src)
   local rotated = torch.Tensor():typeAs(float_src):resizeAs(float_src)
   rotate_with_warp(float_src, rotated, theta, 'bilinear')
   rotated:clamp(0, 1)
   if not converted then
      return rotated
   end
   -- Keep only the first return value of float2byte, as the original did.
   local as_byte = iproc.float2byte(rotated)
   return as_byte
end
|
||||
-- Invert an image: returns -src + 255 for ByteTensor input and -src + 1 for
-- every other tensor type. The input tensor is not modified.
function iproc.negate(src)
   local white = 1
   if src:type() == "torch.ByteTensor" then
      white = 255
   end
   return -src + white
end
|
||||
|
||||
-- Build a kernel_size x kernel_size Gaussian kernel normalized to sum to 1.
-- sigma defaults to 1 when omitted. The peak is at index
-- floor(kernel_size / 2) + 1 on both axes.
function iproc.gaussian2d(kernel_size, sigma)
   sigma = sigma or 1
   local center = math.floor(kernel_size / 2) + 1
   local amp = (1 / math.sqrt(2 * math.pi * sigma^2))
   local denom = 2 * sigma^2
   local kernel = torch.Tensor(kernel_size, kernel_size)
   for row = 1, kernel_size do
      local row_dist2 = (row - center)^2
      for col = 1, kernel_size do
         kernel[row][col] = amp * math.exp(-(row_dist2 + (col - center)^2) / denom)
      end
   end
   -- Normalize so the weights sum to one.
   kernel:div(kernel:sum())
   return kernel
end
|
||||
-- Collapse a 3-plane image into a single plane as the weighted sum
-- 0.299 * plane1 + 0.587 * plane2 + 0.114 * plane3.
-- The result is a 1 x H x W FloatTensor, converted with iproc.float2byte when
-- iproc.byte2float reported a conversion of the input.
function iproc.rgb2y(src)
   local float_src, converted = iproc.byte2float(src)
   local weights = {0.299, 0.587, 0.114}
   local luma = torch.FloatTensor(1, float_src:size(2), float_src:size(3)):zero()
   for plane = 1, 3 do
      luma:add(weights[plane], float_src[plane])
   end
   if converted then
      luma = iproc.float2byte(luma)
   end
   return luma
end
|
||||
|
||||
local function test_conversion()
|
||||
local a = torch.linspace(0, 255, 256):float():div(255.0)
|
||||
|
@ -144,6 +240,46 @@ local function test_conversion()
|
|||
print(b)
|
||||
assert(b:float():sum() == 254.0 * 3)
|
||||
end
|
||||
-- Manual check: prints the summed difference between the `image` package flips
-- and the iproc flips, for both a float and a byte copy of the test image.
local function test_flip()
   require 'sys'
   require 'torch'
   torch.setdefaulttensortype("torch.FloatTensor")
   image = require 'image'

   local src = image.lena()
   local src_byte = src:clone():mul(255):byte()

   -- Print sum(reference(img) - candidate(img)).
   local function report(reference, candidate, img)
      print((reference(img) - candidate(img)):sum())
   end

   print(src:size())
   report(image.hflip, iproc.hflip, src)
   report(image.hflip, iproc.hflip, src_byte)
   report(image.vflip, iproc.vflip, src)
   report(image.vflip, iproc.vflip, src_byte)
end
|
||||
-- Manual check: prints the sigma = 0.5 Gaussian kernel for sizes 3, 5 and 7.
local function test_gaussian2d()
   for _, kernel_size in ipairs({3, 5, 7}) do
      print(iproc.gaussian2d(kernel_size, 0.5))
   end
end
|
||||
-- Manual check: convolves the test image with a 3x3 averaging kernel ('same'
-- mode), prints both sizes, saves the absolute difference to diff.png and
-- displays the blurred image and the difference.
local function test_conv()
   local image = require 'image'

   local function show(img)
      image.display({image = img, min=0, max=1})
   end

   local src = image.lena()
   local box = torch.Tensor(3, 3):fill(1)
   box:div(box:sum())
   --local blur = image.convolve(iproc.padding(src, 1, 1, 1, 1), kernel, 'valid')
   local blur = image.convolve(src, box, 'same')
   print(src:size(), blur:size())

   local diff = (blur - src):abs()
   image.save("diff.png", diff)
   show(blur)
   show(diff)
end
|
||||
|
||||
--test_conversion()
|
||||
--test_flip()
|
||||
--test_gaussian2d()
|
||||
--test_conv()
|
||||
|
||||
return iproc
|
||||
|
||||
|
||||
|
|
|
@ -45,12 +45,17 @@ local function minibatch_adam(model, criterion, eval_metric,
|
|||
local output = model:forward(inputs)
|
||||
local f = criterion:forward(output, targets)
|
||||
local se = 0
|
||||
for i = 1, batch_size do
|
||||
local el = eval_metric:forward(output[i], targets[i])
|
||||
se = se + el
|
||||
instance_loss[shuffle[t + i - 1]] = el
|
||||
if config.xInstanceLoss then
|
||||
for i = 1, batch_size do
|
||||
local el = eval_metric:forward(output[i], targets[i])
|
||||
se = se + el
|
||||
instance_loss[shuffle[t + i - 1]] = el
|
||||
end
|
||||
se = (se / batch_size)
|
||||
else
|
||||
se = eval_metric:forward(output, targets)
|
||||
end
|
||||
sum_eval = sum_eval + (se / batch_size)
|
||||
sum_eval = sum_eval + se
|
||||
sum_loss = sum_loss + f
|
||||
count_loss = count_loss + 1
|
||||
model:backward(inputs, criterion:backward(output, targets))
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
local pairwise_utils = require 'pairwise_transform_utils'
|
||||
local gm = require 'graphicsmagick'
|
||||
local gm = {}
|
||||
gm.Image = require 'graphicsmagick.Image'
|
||||
local iproc = require 'iproc'
|
||||
local pairwise_transform = {}
|
||||
|
||||
|
@ -42,8 +43,8 @@ function pairwise_transform.jpeg_(src, quality, size, offset, n, options)
|
|||
yc = iproc.byte2float(yc)
|
||||
if options.rgb then
|
||||
else
|
||||
yc = image.rgb2yuv(yc)[1]:reshape(1, yc:size(2), yc:size(3))
|
||||
xc = image.rgb2yuv(xc)[1]:reshape(1, xc:size(2), xc:size(3))
|
||||
yc = iproc.rgb2y(yc)
|
||||
xc = iproc.rgb2y(xc)
|
||||
end
|
||||
if torch.uniform() < options.nr_rate then
|
||||
-- reducing noise
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
local pairwise_utils = require 'pairwise_transform_utils'
|
||||
local iproc = require 'iproc'
|
||||
local gm = require 'graphicsmagick'
|
||||
local gm = {}
|
||||
gm.Image = require 'graphicsmagick.Image'
|
||||
local pairwise_transform = {}
|
||||
|
||||
local function add_jpeg_noise_(x, quality, options)
|
||||
|
@ -117,8 +118,8 @@ function pairwise_transform.jpeg_scale(src, scale, style, noise_level, size, off
|
|||
yc = iproc.byte2float(yc)
|
||||
if options.rgb then
|
||||
else
|
||||
yc = image.rgb2yuv(yc)[1]:reshape(1, yc:size(2), yc:size(3))
|
||||
xc = image.rgb2yuv(xc)[1]:reshape(1, xc:size(2), xc:size(3))
|
||||
yc = iproc.rgb2y(yc)
|
||||
xc = iproc.rgb2y(xc)
|
||||
end
|
||||
table.insert(batch, {xc, iproc.crop(yc, offset, offset, size - offset, size - offset)})
|
||||
end
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
local pairwise_utils = require 'pairwise_transform_utils'
|
||||
local iproc = require 'iproc'
|
||||
local gm = require 'graphicsmagick'
|
||||
local gm = {}
|
||||
gm.Image = require 'graphicsmagick.Image'
|
||||
local pairwise_transform = {}
|
||||
|
||||
function pairwise_transform.scale(src, scale, size, offset, n, options)
|
||||
|
@ -50,8 +51,8 @@ function pairwise_transform.scale(src, scale, size, offset, n, options)
|
|||
yc = iproc.byte2float(yc)
|
||||
if options.rgb then
|
||||
else
|
||||
yc = image.rgb2yuv(yc)[1]:reshape(1, yc:size(2), yc:size(3))
|
||||
xc = image.rgb2yuv(xc)[1]:reshape(1, xc:size(2), xc:size(3))
|
||||
yc = iproc.rgb2y(yc)
|
||||
xc = iproc.rgb2y(xc)
|
||||
end
|
||||
table.insert(batch, {xc, iproc.crop(yc, offset, offset, size - offset, size - offset)})
|
||||
end
|
||||
|
|
|
@ -1,43 +1,30 @@
|
|||
local pairwise_utils = require 'pairwise_transform_utils'
|
||||
local iproc = require 'iproc'
|
||||
local gm = require 'graphicsmagick'
|
||||
local gm = {}
|
||||
gm.Image = require 'graphicsmagick.Image'
|
||||
local pairwise_transform = {}
|
||||
|
||||
-- Crop a matching random window out of the pair (x, y) when y is larger than
-- max_size in both spatial dimensions; otherwise return the pair untouched.
-- The window is max_size square in y and max_size / scale_y square in x.
-- When `mod` is given, the window origin in y is rounded down to a multiple of it.
local function crop_if_large(x, y, scale_y, max_size, mod)
   if not (y:size(2) > max_size and y:size(3) > max_size) then
      return x, y
   end
   assert(max_size % 4 == 0)

   local tries = 4
   local x_size = max_size / scale_y
   local rect_x, rect_y
   for _ = 1, tries do
      local yi = torch.random(0, y:size(2) - max_size)
      local xi = torch.random(0, y:size(3) - max_size)
      if mod then
         yi = yi - (yi % mod)
         xi = xi - (xi % mod)
      end
      rect_y = iproc.crop(y, xi, yi, xi + max_size, yi + max_size)
      local x0, y0 = xi / scale_y, yi / scale_y
      rect_x = iproc.crop(x, x0, y0, x0 + x_size, y0 + x_size)
      -- ignore simple background
      -- NOTE(review): a standard deviation is never negative, so this threshold
      -- accepts the first candidate; confirm the intended threshold.
      if rect_y:float():std() >= 0 then
         break
      end
   end
   return rect_x, rect_y
end
|
||||
function pairwise_transform.user(x, y, size, offset, n, options)
|
||||
assert(x:size(1) == y:size(1))
|
||||
|
||||
local scale_y = y:size(2) / x:size(2)
|
||||
assert(x:size(3) == y:size(3) / scale_y)
|
||||
|
||||
x, y = crop_if_large(x, y, scale_y, options.max_size, scale_y)
|
||||
x, y = pairwise_utils.preprocess_user(x, y, scale_y, size, options)
|
||||
assert(x:size(3) == y:size(3) / scale_y and x:size(2) == y:size(2) / scale_y)
|
||||
local batch = {}
|
||||
local lowres_y = pairwise_utils.low_resolution(y)
|
||||
local xs, ys, ls = pairwise_utils.flip_augmentation(x, y, lowres_y)
|
||||
local lowres_y = nil
|
||||
local xs ={x}
|
||||
local ys = {y}
|
||||
local ls = {}
|
||||
|
||||
if options.active_cropping_rate > 0 then
|
||||
lowres_y = pairwise_utils.low_resolution(y)
|
||||
end
|
||||
if options.pairwise_flip then
|
||||
xs, ys, ls = pairwise_utils.flip_augmentation(x, y, lowres_y)
|
||||
end
|
||||
assert(#xs == #ys)
|
||||
for i = 1, n do
|
||||
local t = (i % #xs) + 1
|
||||
local xc, yc = pairwise_utils.active_cropping(xs[t], ys[t], ls[t], size, scale_y,
|
||||
|
@ -47,8 +34,17 @@ function pairwise_transform.user(x, y, size, offset, n, options)
|
|||
yc = iproc.byte2float(yc)
|
||||
if options.rgb then
|
||||
else
|
||||
yc = image.rgb2yuv(yc)[1]:reshape(1, yc:size(2), yc:size(3))
|
||||
xc = image.rgb2yuv(xc)[1]:reshape(1, xc:size(2), xc:size(3))
|
||||
yc = iproc.rgb2y(yc)
|
||||
xc = iproc.rgb2y(xc)
|
||||
end
|
||||
if options.gcn then
|
||||
local mean = xc:mean()
|
||||
local stdv = xc:std()
|
||||
if stdv > 0 then
|
||||
xc:add(-mean):div(stdv)
|
||||
else
|
||||
xc:add(-mean)
|
||||
end
|
||||
end
|
||||
table.insert(batch, {xc, iproc.crop(yc, offset, offset, size - offset, size - offset)})
|
||||
end
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
require 'image'
|
||||
require 'cunn'
|
||||
local iproc = require 'iproc'
|
||||
local gm = require 'graphicsmagick'
|
||||
local gm = {}
|
||||
gm.Image = require 'graphicsmagick.Image'
|
||||
local data_augmentation = require 'data_augmentation'
|
||||
local pairwise_transform_utils = {}
|
||||
|
||||
|
@ -36,6 +36,30 @@ function pairwise_transform_utils.crop_if_large(src, max_size, mod)
|
|||
return src
|
||||
end
|
||||
end
|
||||
-- Crop a matching random window out of the pair (x, y) when y exceeds max_size
-- in both spatial dimensions; otherwise return x and y unchanged.
-- The y window is max_size square; the x window covers the same region divided
-- by scale_y. When `mod` is given, the y origin is rounded down to a multiple
-- of it. At most 4 candidates are drawn.
function pairwise_transform_utils.crop_if_large_pair(x, y, scale_y, max_size, mod)
   local tries = 4
   local is_large = y:size(2) > max_size and y:size(3) > max_size
   if not is_large then
      return x, y
   end
   assert(max_size % 4 == 0)

   local x_size = max_size / scale_y
   local rect_x, rect_y
   local attempt = 0
   repeat
      attempt = attempt + 1
      local yi = torch.random(0, y:size(2) - max_size)
      local xi = torch.random(0, y:size(3) - max_size)
      if mod then
         yi = yi - (yi % mod)
         xi = xi - (xi % mod)
      end
      rect_y = iproc.crop(y, xi, yi, xi + max_size, yi + max_size)
      local x0, y0 = xi / scale_y, yi / scale_y
      rect_x = iproc.crop(x, x0, y0, x0 + x_size, y0 + x_size)
      -- ignore simple background
      -- NOTE(review): a standard deviation is never negative, so the first
      -- candidate is always accepted; confirm the intended threshold.
   until rect_y:float():std() >= 0 or attempt >= tries
   return rect_x, rect_y
end
|
||||
function pairwise_transform_utils.preprocess(src, crop_size, options)
|
||||
local dest = src
|
||||
local box_only = false
|
||||
|
@ -47,7 +71,6 @@ function pairwise_transform_utils.preprocess(src, crop_size, options)
|
|||
if box_only then
|
||||
local mod = 2 -- assert pos % 2 == 0
|
||||
dest = pairwise_transform_utils.crop_if_large(dest, math.max(crop_size * 2, options.max_size), mod)
|
||||
dest = data_augmentation.flip(dest)
|
||||
dest = data_augmentation.color_noise(dest, options.random_color_noise_rate)
|
||||
dest = data_augmentation.overlay(dest, options.random_overlay_rate)
|
||||
dest = data_augmentation.unsharp_mask(dest, options.random_unsharp_mask_rate)
|
||||
|
@ -55,7 +78,10 @@ function pairwise_transform_utils.preprocess(src, crop_size, options)
|
|||
else
|
||||
dest = pairwise_transform_utils.random_half(dest, options.random_half_rate, options.downsampling_filters)
|
||||
dest = pairwise_transform_utils.crop_if_large(dest, math.max(crop_size * 2, options.max_size))
|
||||
dest = data_augmentation.flip(dest)
|
||||
dest = data_augmentation.blur(dest, options.random_blur_rate,
|
||||
options.random_blur_size,
|
||||
options.random_blur_sigma_min,
|
||||
options.random_blur_sigma_max)
|
||||
dest = data_augmentation.color_noise(dest, options.random_color_noise_rate)
|
||||
dest = data_augmentation.overlay(dest, options.random_overlay_rate)
|
||||
dest = data_augmentation.unsharp_mask(dest, options.random_unsharp_mask_rate)
|
||||
|
@ -63,6 +89,33 @@ function pairwise_transform_utils.preprocess(src, crop_size, options)
|
|||
end
|
||||
return dest
|
||||
end
|
||||
-- Preprocess a user-supplied (x, y) training pair.
-- Steps, in order: paired crop when y is large, paired rotate, paired scale,
-- paired negate, negate of x only, crop_mod4 on both, and optional
-- binarization of y. Every augmentation is driven by the matching
-- options.random_pairwise_* fields.
-- Returns the processed x, y.
function pairwise_transform_utils.preprocess_user(x, y, scale_y, size, options)
   local aug = data_augmentation

   x, y = pairwise_transform_utils.crop_if_large_pair(x, y, scale_y, options.max_size, scale_y)
   x, y = aug.pairwise_rotate(x, y,
                              options.random_pairwise_rotate_rate,
                              options.random_pairwise_rotate_min,
                              options.random_pairwise_rotate_max)

   -- Raise the lower scale bound to at least size / (1 + shorter side of x),
   -- and keep the upper bound no smaller than the lower bound.
   local short_side = math.min(x:size(2), x:size(3))
   local scale_min = math.max(options.random_pairwise_scale_min, size / (1 + short_side))
   local scale_max = math.max(scale_min, options.random_pairwise_scale_max)
   x, y = aug.pairwise_scale(x, y,
                             options.random_pairwise_scale_rate,
                             scale_min,
                             scale_max)

   x, y = aug.pairwise_negate(x, y, options.random_pairwise_negate_rate)
   x, y = aug.pairwise_negate_x(x, y, options.random_pairwise_negate_x_rate)

   x, y = iproc.crop_mod4(x), iproc.crop_mod4(y)

   if options.pairwise_y_binary then
      -- Values below 128 become 0, then everything still above 0 becomes 255.
      y[torch.lt(y, 128)] = 0
      y[torch.gt(y, 0)] = 255
   end

   return x, y
end
|
||||
function pairwise_transform_utils.active_cropping(x, y, lowres_y, size, scale, p, tries)
|
||||
assert("x:size == y:size", x:size(2) * scale == y:size(2) and x:size(3) * scale == y:size(3))
|
||||
assert("crop_size % scale == 0", size % scale == 0)
|
||||
|
@ -111,7 +164,7 @@ function pairwise_transform_utils.flip_augmentation(x, y, lowres_y, x_noise)
|
|||
|
||||
for j = 1, 2 do
|
||||
-- TTA
|
||||
local xi, yi, ri
|
||||
local xi, yi, ri, ni
|
||||
if j == 1 then
|
||||
xi = x
|
||||
ni = x_noise
|
||||
|
@ -123,42 +176,55 @@ function pairwise_transform_utils.flip_augmentation(x, y, lowres_y, x_noise)
|
|||
ni = x_noise:transpose(2, 3):contiguous()
|
||||
end
|
||||
yi = y:transpose(2, 3):contiguous()
|
||||
ri = lowres_y:transpose(2, 3):contiguous()
|
||||
if lowres_y then
|
||||
ri = lowres_y:transpose(2, 3):contiguous()
|
||||
end
|
||||
end
|
||||
local xv = image.vflip(xi)
|
||||
local xv = iproc.vflip(xi)
|
||||
local nv
|
||||
if x_noise then
|
||||
nv = image.vflip(ni)
|
||||
nv = iproc.vflip(ni)
|
||||
end
|
||||
local yv = iproc.vflip(yi)
|
||||
local rv
|
||||
if ri then
|
||||
rv = iproc.vflip(ri)
|
||||
end
|
||||
local yv = image.vflip(yi)
|
||||
local rv = image.vflip(ri)
|
||||
table.insert(xs, xi)
|
||||
if ni then
|
||||
table.insert(ns, ni)
|
||||
end
|
||||
table.insert(ys, yi)
|
||||
table.insert(ls, ri)
|
||||
if ri then
|
||||
table.insert(ls, ri)
|
||||
end
|
||||
|
||||
table.insert(xs, xv)
|
||||
if nv then
|
||||
table.insert(ns, nv)
|
||||
end
|
||||
table.insert(ys, yv)
|
||||
table.insert(ls, rv)
|
||||
if rv then
|
||||
table.insert(ls, rv)
|
||||
end
|
||||
|
||||
table.insert(xs, image.hflip(xi))
|
||||
table.insert(xs, iproc.hflip(xi))
|
||||
if ni then
|
||||
table.insert(ns, image.hflip(ni))
|
||||
table.insert(ns, iproc.hflip(ni))
|
||||
end
|
||||
table.insert(ys, iproc.hflip(yi))
|
||||
if ri then
|
||||
table.insert(ls, iproc.hflip(ri))
|
||||
end
|
||||
table.insert(ys, image.hflip(yi))
|
||||
table.insert(ls, image.hflip(ri))
|
||||
|
||||
table.insert(xs, image.hflip(xv))
|
||||
table.insert(xs, iproc.hflip(xv))
|
||||
if nv then
|
||||
table.insert(ns, image.hflip(nv))
|
||||
table.insert(ns, iproc.hflip(nv))
|
||||
end
|
||||
table.insert(ys, iproc.hflip(yv))
|
||||
if rv then
|
||||
table.insert(ls, iproc.hflip(rv))
|
||||
end
|
||||
table.insert(ys, image.hflip(yv))
|
||||
table.insert(ls, image.hflip(rv))
|
||||
end
|
||||
return xs, ys, ls, ns
|
||||
end
|
||||
|
@ -171,6 +237,9 @@ end
|
|||
local g_lowres_model = nil
|
||||
local g_lowres_gpu = nil
|
||||
function pairwise_transform_utils.low_resolution(src)
|
||||
--[[
|
||||
-- I am not sure that the following process is thread-safe
|
||||
|
||||
g_lowres_model = g_lowres_model or lowres_model()
|
||||
if g_lowres_gpu == nil then
|
||||
--benchmark
|
||||
|
@ -203,6 +272,11 @@ function pairwise_transform_utils.low_resolution(src)
|
|||
size(src:size(3), src:size(2), "Box"):
|
||||
toTensor("byte", "RGB", "DHW")
|
||||
end
|
||||
--]]
|
||||
return gm.Image(src, "RGB", "DHW"):
|
||||
size(src:size(3) * 0.5, src:size(2) * 0.5, "Box"):
|
||||
size(src:size(3), src:size(2), "Box"):
|
||||
toTensor("byte", "RGB", "DHW")
|
||||
end
|
||||
|
||||
return pairwise_transform_utils
|
||||
|
|
|
@ -40,6 +40,15 @@ local function reconstruct_nn(model, x, inner_scale, offset, block_size, batch_s
|
|||
break
|
||||
end
|
||||
input[j+1]:copy(x[input_indexes[i + j]])
|
||||
if model.w2nn_gcn then
|
||||
local mean = input[j + 1]:mean()
|
||||
local stdv = input[j + 1]:std()
|
||||
if stdv > 0 then
|
||||
input[j + 1]:add(-mean):div(stdv)
|
||||
else
|
||||
input[j + 1]:add(-mean)
|
||||
end
|
||||
end
|
||||
c = c + 1
|
||||
end
|
||||
input_cuda:copy(input)
|
||||
|
@ -80,7 +89,12 @@ local function padding_params(x, model, block_size)
|
|||
p.x_w = x:size(3)
|
||||
p.x_h = x:size(2)
|
||||
p.inner_scale = reconstruct.inner_scale(model)
|
||||
local input_offset = math.ceil(offset / p.inner_scale)
|
||||
local input_offset
|
||||
if model.w2nn_input_offset then
|
||||
input_offset = model.w2nn_input_offset
|
||||
else
|
||||
input_offset = math.ceil(offset / p.inner_scale)
|
||||
end
|
||||
local input_block_size = block_size
|
||||
local process_size = input_block_size - input_offset * 2
|
||||
local h_blocks = math.floor(p.x_h / process_size) +
|
||||
|
@ -172,6 +186,9 @@ function reconstruct.scale_rgb(model, scale, x, offset, block_size, batch_size)
|
|||
return output
|
||||
end
|
||||
function reconstruct.image(model, x, block_size)
|
||||
if model.w2nn_input_size then
|
||||
block_size = model.w2nn_input_size
|
||||
end
|
||||
local i2rgb = false
|
||||
if x:size(1) == 1 then
|
||||
local new_x = torch.Tensor(3, x:size(2), x:size(3))
|
||||
|
@ -194,6 +211,9 @@ function reconstruct.image(model, x, block_size)
|
|||
return x
|
||||
end
|
||||
function reconstruct.scale(model, scale, x, block_size)
|
||||
if model.w2nn_input_size then
|
||||
block_size = model.w2nn_input_size
|
||||
end
|
||||
local i2rgb = false
|
||||
if x:size(1) == 1 then
|
||||
local new_x = torch.Tensor(3, x:size(2), x:size(3))
|
||||
|
@ -287,6 +307,9 @@ local function tta(f, n, model, x, block_size)
|
|||
return average:div(#augments)
|
||||
end
|
||||
function reconstruct.image_tta(model, n, x, block_size)
|
||||
if model.w2nn_input_size then
|
||||
block_size = model.w2nn_input_size
|
||||
end
|
||||
if reconstruct.is_rgb(model) then
|
||||
return tta(reconstruct.image_rgb, n, model, x, block_size)
|
||||
else
|
||||
|
@ -294,6 +317,9 @@ function reconstruct.image_tta(model, n, x, block_size)
|
|||
end
|
||||
end
|
||||
function reconstruct.scale_tta(model, n, scale, x, block_size)
|
||||
if model.w2nn_input_size then
|
||||
block_size = model.w2nn_input_size
|
||||
end
|
||||
if reconstruct.is_rgb(model) then
|
||||
local f = function (model, x, offset, block_size)
|
||||
return reconstruct.scale_rgb(model, scale, x, offset, block_size)
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
require 'xlua'
|
||||
require 'pl'
|
||||
require 'trepl'
|
||||
require 'cutorch'
|
||||
|
||||
-- global settings
|
||||
|
||||
|
@ -18,7 +19,7 @@ cmd:text()
|
|||
cmd:text("waifu2x-training")
|
||||
cmd:text("Options:")
|
||||
cmd:option("-gpu", -1, 'GPU Device ID')
|
||||
cmd:option("-seed", 11, 'RNG seed')
|
||||
cmd:option("-seed", 11, 'RNG seed (note: it only able to reproduce the training results with `-thread 1`)')
|
||||
cmd:option("-data_dir", "./data", 'path to data directory')
|
||||
cmd:option("-backend", "cunn", '(cunn|cudnn)')
|
||||
cmd:option("-test", "images/miku_small.png", 'path to test image')
|
||||
|
@ -32,6 +33,20 @@ cmd:option("-random_color_noise_rate", 0.0, 'data augmentation using color noise
|
|||
-- Data augmentation options. Help text fixed: the rotate-rate option was
-- described as "resize", and "negate" was misspelled.
cmd:option("-random_overlay_rate", 0.0, 'data augmentation using flipped image overlay (0.0-1.0)')
cmd:option("-random_half_rate", 0.0, 'data augmentation using half resolution image (0.0-1.0)')
cmd:option("-random_unsharp_mask_rate", 0.0, 'data augmentation using unsharp mask (0.0-1.0)')
cmd:option("-random_blur_rate", 0.0, 'data augmentation using gaussian blur (0.0-1.0)')
cmd:option("-random_blur_size", "3,5", 'filter size for random gaussian blur (comma separated)')
cmd:option("-random_blur_sigma_min", 0.5, 'min sigma for random gaussian blur')
cmd:option("-random_blur_sigma_max", 1.0, 'max sigma for random gaussian blur')
cmd:option("-random_pairwise_scale_rate", 0.0, 'data augmentation using pairwise resize for user method')
cmd:option("-random_pairwise_scale_min", 0.85, 'min scale factor for random pairwise scale')
cmd:option("-random_pairwise_scale_max", 1.176, 'max scale factor for random pairwise scale')
cmd:option("-random_pairwise_rotate_rate", 0.0, 'data augmentation using pairwise rotate for user method')
cmd:option("-random_pairwise_rotate_min", -6, 'min rotate angle for random pairwise rotate')
cmd:option("-random_pairwise_rotate_max", 6, 'max rotate angle for random pairwise rotate')
cmd:option("-random_pairwise_negate_rate", 0.0, 'data augmentation using negate image for user method')
cmd:option("-random_pairwise_negate_x_rate", 0.0, 'data augmentation using negate image only x side for user method')
cmd:option("-pairwise_y_binary", 0, 'binarize y after data augmentation(0|1)')
cmd:option("-pairwise_flip", 1, 'use flip(0|1)')
cmd:option("-scale", 2.0, 'scale factor (2)')
cmd:option("-learning_rate", 0.00025, 'learning rate for adam')
cmd:option("-crop_size", 48, 'crop size')
|
||||
|
@ -59,6 +74,8 @@ cmd:option("-oracle_drop_rate", 0.5, '')
|
|||
cmd:option("-learning_rate_decay", 3.0e-7, 'learning rate decay (learning_rate * 1/(1+num_of_data*patches*epoch))')
|
||||
cmd:option("-resume", "", 'resume model file')
|
||||
cmd:option("-name", "user", 'model name for user method')
|
||||
cmd:option("-gpu", 1, 'Device ID')
|
||||
cmd:option("-loss", "huber", 'loss function (huber|l1|mse)')
|
||||
|
||||
local function to_bool(settings, name)
|
||||
if settings[name] == 1 then
|
||||
|
@ -75,6 +92,8 @@ end
|
|||
to_bool(settings, "plot")
|
||||
to_bool(settings, "save_history")
|
||||
to_bool(settings, "use_transparent_png")
|
||||
to_bool(settings, "pairwise_y_binary")
|
||||
to_bool(settings, "pairwise_flip")
|
||||
|
||||
if settings.plot then
|
||||
require 'gnuplot'
|
||||
|
@ -148,4 +167,6 @@ end
|
|||
settings.images = string.format("%s/images.t7", settings.data_dir)
|
||||
settings.image_list = string.format("%s/image_list.txt", settings.data_dir)
|
||||
|
||||
cutorch.setDevice(opt.gpu)
|
||||
|
||||
return settings
|
||||
|
|
274
lib/srcnn.lua
274
lib/srcnn.lua
|
@ -136,6 +136,24 @@ local function SpatialFullConvolution(backend, nInputPlane, nOutputPlane, kW, kH
|
|||
error("unsupported backend:" .. backend)
|
||||
end
|
||||
end
|
||||
-- Create an in-place ReLU module for the given backend ("cunn" or "cudnn").
-- Raises an error for any other backend name.
local function ReLU(backend)
   if backend == "cunn" then
      return nn.ReLU(true)
   end
   if backend == "cudnn" then
      return cudnn.ReLU(true)
   end
   error("unsupported backend:" .. backend)
end
|
||||
-- Create a SpatialMaxPooling module for the given backend ("cunn" or "cudnn"),
-- forwarding kernel size, stride and padding unchanged.
-- Raises an error for any other backend name.
local function SpatialMaxPooling(backend, kW, kH, dW, dH, padW, padH)
   if backend == "cunn" then
      return nn.SpatialMaxPooling(kW, kH, dW, dH, padW, padH)
   end
   if backend == "cudnn" then
      return cudnn.SpatialMaxPooling(kW, kH, dW, dH, padW, padH)
   end
   error("unsupported backend:" .. backend)
end
|
||||
|
||||
-- VGG style net(7 layers)
|
||||
function srcnn.vgg_7(backend, ch)
|
||||
|
@ -153,6 +171,7 @@ function srcnn.vgg_7(backend, ch)
|
|||
model:add(SpatialConvolution(backend, 128, 128, 3, 3, 1, 1, 0, 0))
|
||||
model:add(nn.LeakyReLU(0.1, true))
|
||||
model:add(SpatialConvolution(backend, 128, ch, 3, 3, 1, 1, 0, 0))
|
||||
model:add(w2nn.InplaceClip01())
|
||||
model:add(nn.View(-1):setNumInputDims(3))
|
||||
|
||||
model.w2nn_arch_name = "vgg_7"
|
||||
|
@ -190,6 +209,7 @@ function srcnn.vgg_12(backend, ch)
|
|||
model:add(SpatialConvolution(backend, 128, 128, 3, 3, 1, 1, 0, 0))
|
||||
model:add(nn.LeakyReLU(0.1, true))
|
||||
model:add(SpatialConvolution(backend, 128, ch, 3, 3, 1, 1, 0, 0))
|
||||
model:add(w2nn.InplaceClip01())
|
||||
model:add(nn.View(-1):setNumInputDims(3))
|
||||
|
||||
model.w2nn_arch_name = "vgg_12"
|
||||
|
@ -219,6 +239,7 @@ function srcnn.dilated_7(backend, ch)
|
|||
model:add(SpatialConvolution(backend, 128, 128, 3, 3, 1, 1, 0, 0))
|
||||
model:add(nn.LeakyReLU(0.1, true))
|
||||
model:add(SpatialConvolution(backend, 128, ch, 3, 3, 1, 1, 0, 0))
|
||||
model:add(w2nn.InplaceClip01())
|
||||
model:add(nn.View(-1):setNumInputDims(3))
|
||||
|
||||
model.w2nn_arch_name = "dilated_7"
|
||||
|
@ -249,6 +270,7 @@ function srcnn.upconv_7(backend, ch)
|
|||
model:add(SpatialConvolution(backend, 128, 256, 3, 3, 1, 1, 0, 0))
|
||||
model:add(nn.LeakyReLU(0.1, true))
|
||||
model:add(SpatialFullConvolution(backend, 256, ch, 4, 4, 2, 2, 3, 3):noBias())
|
||||
model:add(w2nn.InplaceClip01())
|
||||
model:add(nn.View(-1):setNumInputDims(3))
|
||||
|
||||
model.w2nn_arch_name = "upconv_7"
|
||||
|
@ -257,11 +279,255 @@ function srcnn.upconv_7(backend, ch)
|
|||
model.w2nn_resize = true
|
||||
model.w2nn_channels = ch
|
||||
|
||||
return model
|
||||
end
|
||||
|
||||
-- large version of upconv_7
|
||||
-- This model is able to beat upconv_7 (PSNR: +0.3 ~ +0.8), but it is 2x slower than upconv_7.
|
||||
-- Build the "upconv_7l" network: six unpadded 3x3 convolutions
-- (ch -> 32 -> 64 -> 128 -> 192 -> 256 -> 512), each followed by
-- LeakyReLU(0.1), then a bias-free 4x4 stride-2 full convolution back to `ch`
-- planes, a clip to [0, 1] and a flattening View.
-- backend: passed to the SpatialConvolution / SpatialFullConvolution helpers.
-- ch:      number of input and output planes.
function srcnn.upconv_7l(backend, ch)
   local model = nn.Sequential()

   local n_in = ch
   for _, n_out in ipairs({32, 64, 128, 192, 256, 512}) do
      model:add(SpatialConvolution(backend, n_in, n_out, 3, 3, 1, 1, 0, 0))
      model:add(nn.LeakyReLU(0.1, true))
      n_in = n_out
   end
   model:add(SpatialFullConvolution(backend, n_in, ch, 4, 4, 2, 2, 3, 3):noBias())
   model:add(w2nn.InplaceClip01())
   model:add(nn.View(-1):setNumInputDims(3))

   -- Metadata stored on the model.
   model.w2nn_arch_name = "upconv_7l"
   model.w2nn_offset = 14
   model.w2nn_scale_factor = 2
   model.w2nn_resize = true
   model.w2nn_channels = ch

   --model:cuda()
   --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())

   return model
end
|
||||
|
||||
-- layerwise linear blending with skip connections
|
||||
-- Note: PSNR: upconv_7 < skiplb_7 < upconv_7l
|
||||
-- Build the "skiplb_7" network: six stages, each concatenating (along
-- dimension 2) a padded 3x3 conv + LeakyReLU(0.1) branch with the stage's own
-- input, followed by a bias-free full convolution over all accumulated planes.
-- backend: passed to the SpatialConvolution / SpatialFullConvolution helpers.
-- ch:      number of input and output planes.
function srcnn.skiplb_7(backend, ch)
   -- One stage: Concat(2) of [conv(i -> o) + LeakyReLU] and Identity.
   local function skip(backend, i, o)
      local branch = nn.Sequential()
      branch:add(SpatialConvolution(backend, i, o, 3, 3, 1, 1, 1, 1))
      branch:add(nn.LeakyReLU(0.1, true))

      -- depth concat
      local concat = nn.Concat(2)
      concat:add(branch)
      concat:add(nn.Identity()) -- skip
      return concat
   end

   local model = nn.Sequential()
   -- `planes` is the running plane count: the input plus every stage output.
   local planes = ch
   for _, width in ipairs({16, 32, 64, 128, 128, 256}) do
      model:add(skip(backend, planes, width))
      planes = planes + width
   end
   -- input of last layer = [all layerwise output(contains input layer)].flatten
   model:add(SpatialFullConvolution(backend, planes, ch, 4, 4, 2, 2, 3, 3):noBias()) -- linear blend
   model:add(w2nn.InplaceClip01())
   model:add(nn.View(-1):setNumInputDims(3))

   model.w2nn_arch_name = "skiplb_7"
   model.w2nn_offset = 14
   model.w2nn_scale_factor = 2
   model.w2nn_resize = true
   model.w2nn_channels = ch

   --model:cuda()
   --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())

   return model
end
|
||||
|
||||
-- dilated convolution + deconvolution
|
||||
-- Note: This model is not better than upconv_7. Maybe because of under-fitting.
|
||||
-- Build the "dilated_upconv_7" 2x upscaling network: two plain 3x3
-- convolutions, three 3x3 convolutions with dilation 2, one more plain 3x3
-- convolution, then a bias-free transposed convolution.
-- backend: backend name forwarded to the convolution helpers of this file.
-- ch:      number of image channels (input and output).
-- Returns the nn.Sequential model annotated with w2nn_* metadata fields.
function srcnn.dilated_upconv_7(backend, ch)
   local model = nn.Sequential()

   -- Each helper appends one layer followed by its LeakyReLU, so layers are
   -- constructed in exactly the order they appear in the network.
   local function activation()
      model:add(nn.LeakyReLU(0.1, true))
   end
   local function plain(n_in, n_out)
      model:add(SpatialConvolution(backend, n_in, n_out, 3, 3, 1, 1, 0, 0))
      activation()
   end
   local function dilated(n_in, n_out)
      -- always the stock nn module, regardless of `backend`
      model:add(nn.SpatialDilatedConvolution(n_in, n_out, 3, 3, 1, 1, 0, 0, 2, 2))
      activation()
   end

   plain(ch, 16)
   plain(16, 32)
   dilated(32, 64)
   dilated(64, 128)
   dilated(128, 128)
   plain(128, 256)

   model:add(SpatialFullConvolution(backend, 256, ch, 4, 4, 2, 2, 3, 3):noBias())
   model:add(w2nn.InplaceClip01())
   model:add(nn.View(-1):setNumInputDims(3))

   model.w2nn_arch_name = "dilated_upconv_7"
   model.w2nn_offset = 20
   model.w2nn_scale_factor = 2
   model.w2nn_resize = true
   model.w2nn_channels = ch

   --model:cuda()
   --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())

   return model
end
|
||||
|
||||
-- ref: https://arxiv.org/abs/1609.04802
|
||||
-- note: no batch-norm, no zero-padding
|
||||
-- Build the "srresnet_2x" 2x upscaling network: an input convolution, six
-- 64-channel residual blocks, a transposed convolution and a final 3x3
-- convolution back to `ch` channels.
-- backend: backend name forwarded to the SpatialConvolution /
--          SpatialFullConvolution / ReLU helpers of this file.
-- ch:      number of image channels (input and output).
-- Returns the nn.Sequential model annotated with w2nn_* metadata fields.
function srcnn.srresnet_2x(backend, ch)
   -- Residual block: two unpadded 3x3 conv + ReLU layers on one branch; the
   -- other branch passes the input through a negative zero-padding of 2 on
   -- every side so both branches have matching sizes before the addition.
   local function resblock(backend)
      local body = nn.Sequential()
      body:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
      body:add(ReLU(backend))
      body:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
      body:add(ReLU(backend))

      local branches = nn.ConcatTable()
      branches:add(body)
      branches:add(nn.SpatialZeroPadding(-2, -2, -2, -2)) -- identity + de-padding

      local block = nn.Sequential()
      block:add(branches)
      block:add(nn.CAddTable())
      return block
   end

   local model = nn.Sequential()
   --model:add(skip(backend, ch, 64 - ch))
   model:add(SpatialConvolution(backend, ch, 64, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))

   for _ = 1, 6 do
      model:add(resblock(backend))
   end

   model:add(SpatialFullConvolution(backend, 64, 64, 4, 4, 2, 2, 2, 2))
   model:add(ReLU(backend))
   model:add(SpatialConvolution(backend, 64, ch, 3, 3, 1, 1, 0, 0))

   model:add(w2nn.InplaceClip01())
   -- The flattening nn.View layer is deliberately left disabled here.
   --model:add(nn.View(-1):setNumInputDims(3))

   model.w2nn_arch_name = "srresnet_2x"
   model.w2nn_offset = 28
   model.w2nn_scale_factor = 2
   model.w2nn_resize = true
   model.w2nn_channels = ch

   --model:cuda()
   --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())

   return model
end
|
||||
|
||||
-- Large version of srresnet_2x. It is currently the best model, but slow.
|
||||
-- Build the "resnet_14l" 2x upscaling network: an input convolution, six
-- residual blocks widening from 32 to 256 feature maps, and a bias-free
-- transposed convolution.
-- backend: backend name forwarded to the convolution helpers of this file.
-- ch:      number of image channels (input and output).
-- Returns the nn.Sequential model annotated with w2nn_* metadata fields.
function srcnn.resnet_14l(backend, ch)
   -- Residual block mapping `i` feature maps to `o`.
   -- Main branch: two unpadded 3x3 conv + LeakyReLU layers.
   -- Shortcut: the input de-padded by 2 on every side; when the widths
   -- differ, a 1x1 convolution first projects it from `i` to `o` maps.
   local function resblock(backend, i, o)
      -- The main branch is built before the shortcut so that layers are
      -- constructed in the same order as before.
      local main = nn.Sequential()
      main:add(SpatialConvolution(backend, i, o, 3, 3, 1, 1, 0, 0))
      main:add(nn.LeakyReLU(0.1, true))
      main:add(SpatialConvolution(backend, o, o, 3, 3, 1, 1, 0, 0))
      main:add(nn.LeakyReLU(0.1, true))

      local branches = nn.ConcatTable()
      branches:add(main)
      if i ~= o then
         local shortcut = nn.Sequential()
         shortcut:add(SpatialConvolution(backend, i, o, 1, 1, 1, 1, 0, 0))
         shortcut:add(nn.SpatialZeroPadding(-2, -2, -2, -2))
         branches:add(shortcut)
      else
         branches:add(nn.SpatialZeroPadding(-2, -2, -2, -2)) -- identity + de-padding
      end

      local block = nn.Sequential()
      block:add(branches)
      block:add(nn.CAddTable())
      return block
   end

   local model = nn.Sequential()
   model:add(SpatialConvolution(backend, ch, 32, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))

   -- {input maps, output maps} of each residual block, in order
   local plan = {
      {32, 64}, {64, 64},
      {64, 128}, {128, 128},
      {128, 256}, {256, 256},
   }
   for _, io in ipairs(plan) do
      model:add(resblock(backend, io[1], io[2]))
   end

   model:add(SpatialFullConvolution(backend, 256, ch, 4, 4, 2, 2, 3, 3):noBias())
   model:add(w2nn.InplaceClip01())
   model:add(nn.View(-1):setNumInputDims(3))

   model.w2nn_arch_name = "resnet_14l"
   model.w2nn_offset = 28
   model.w2nn_scale_factor = 2
   model.w2nn_resize = true
   model.w2nn_channels = ch

   --model:cuda()
   --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())

   return model
end
|
||||
|
||||
-- for segmentation
|
||||
-- Build the "fcn_v1" fully convolutional network (for segmentation).
-- A strided-convolution / max-pooling encoder is followed by a decoder of
-- transposed convolutions; w2nn_scale_factor is 1 and w2nn_input_size is 120.
-- backend: backend name forwarded to the SpatialConvolution /
--          SpatialFullConvolution / SpatialMaxPooling helpers of this file.
-- ch:      number of image channels (input and output).
-- Returns the nn.Sequential model annotated with w2nn_* metadata fields.
function srcnn.fcn_v1(backend, ch)
   -- input_size = 120
   local model = nn.Sequential()
   --i = 120
   --model:cuda()
   --print(model:forward(torch.Tensor(32, ch, i, i):uniform():cuda()):size())

   -- Each helper appends one layer (plus its LeakyReLU where applicable), so
   -- layers are constructed in exactly the order they appear in the network.
   local function lrelu()
      model:add(nn.LeakyReLU(0.1, true))
   end
   -- unpadded k x k convolution with stride s, followed by LeakyReLU
   local function conv(n_in, n_out, k, s)
      model:add(SpatialConvolution(backend, n_in, n_out, k, k, s, s, 0, 0))
      lrelu()
   end
   -- 2x2 transposed convolution with stride 2, followed by LeakyReLU
   local function upconv(n_in, n_out)
      model:add(SpatialFullConvolution(backend, n_in, n_out, 2, 2, 2, 2, 0, 0))
      lrelu()
   end
   local function pool()
      model:add(SpatialMaxPooling(backend, 2, 2, 2, 2))
   end

   -- encoder
   conv(ch, 32, 5, 2)
   conv(32, 32, 3, 1)
   pool()

   conv(32, 64, 3, 1)
   conv(64, 64, 3, 1)
   pool()

   conv(64, 128, 3, 1)
   conv(128, 128, 3, 1)
   pool()

   conv(128, 256, 1, 1)
   model:add(nn.Dropout(0.5, false, true))

   -- decoder
   upconv(256, 128)
   upconv(128, 128)
   conv(128, 64, 3, 1)
   upconv(64, 64)
   conv(64, 32, 3, 1)
   -- final transposed convolution keeps its bias and has no activation
   model:add(SpatialFullConvolution(backend, 32, ch, 4, 4, 2, 2, 3, 3))

   model:add(w2nn.InplaceClip01())
   model:add(nn.View(-1):setNumInputDims(3))

   model.w2nn_arch_name = "fcn_v1"
   model.w2nn_offset = 36
   model.w2nn_scale_factor = 1
   model.w2nn_channels = ch
   model.w2nn_input_size = 120
   --model.w2nn_gcn = true

   return model
end
|
||||
function srcnn.create(model_name, backend, color)
|
||||
model_name = model_name or "vgg_7"
|
||||
backend = backend or "cunn"
|
||||
|
@ -282,8 +548,10 @@ function srcnn.create(model_name, backend, color)
|
|||
error("unsupported model_name: " .. model_name)
|
||||
end
|
||||
end
|
||||
|
||||
--local model = srcnn.upconv_6("cunn", 3):cuda()
|
||||
--print(model:forward(torch.Tensor(1, 3, 64, 64):zero():cuda()):size())
|
||||
--[[
|
||||
local model = srcnn.fcn_v1("cunn", 3):cuda()
|
||||
print(model:forward(torch.Tensor(1, 3, 108, 108):zero():cuda()):size())
|
||||
print(model)
|
||||
--]]
|
||||
|
||||
return srcnn
|
||||
|
|
|
@ -30,5 +30,8 @@ else
|
|||
require 'LeakyReLU'
|
||||
require 'ClippedWeightedHuberCriterion'
|
||||
require 'ClippedMSECriterion'
|
||||
require 'SSIMCriterion'
|
||||
require 'InplaceClip01'
|
||||
require 'L1Criterion'
|
||||
return w2nn
|
||||
end
|
||||
|
|
1
models/resnet_14l/README.md
Normal file
1
models/resnet_14l/README.md
Normal file
|
@ -0,0 +1 @@
|
|||
Currently, these models are for the benchmark.
|
BIN
models/resnet_14l/photo/scale2.0x_model.t7
Normal file
BIN
models/resnet_14l/photo/scale2.0x_model.t7
Normal file
Binary file not shown.
1
models/upconv_7l/README.md
Normal file
1
models/upconv_7l/README.md
Normal file
|
@ -0,0 +1 @@
|
|||
Currently, these models are for the benchmark.
|
BIN
models/upconv_7l/art/scale2.0x_model.json
Normal file
BIN
models/upconv_7l/art/scale2.0x_model.json
Normal file
Binary file not shown.
BIN
models/upconv_7l/art/scale2.0x_model.t7
Normal file
BIN
models/upconv_7l/art/scale2.0x_model.t7
Normal file
Binary file not shown.
BIN
models/upconv_7l/photo/scale2.0x_model.json
Normal file
BIN
models/upconv_7l/photo/scale2.0x_model.json
Normal file
Binary file not shown.
BIN
models/upconv_7l/photo/scale2.0x_model.t7
Normal file
BIN
models/upconv_7l/photo/scale2.0x_model.t7
Normal file
Binary file not shown.
|
@ -18,7 +18,7 @@ cmd:option("-dir", "./data/test", 'test image directory')
|
|||
cmd:option("-file", "", 'test image file list')
|
||||
cmd:option("-model1_dir", "./models/anime_style_art_rgb", 'model1 directory')
|
||||
cmd:option("-model2_dir", "", 'model2 directory (optional)')
|
||||
cmd:option("-method", "scale", '(scale|noise|noise_scale|user|diff)')
|
||||
cmd:option("-method", "scale", '(scale|noise|noise_scale|user|diff|scale4)')
|
||||
cmd:option("-filter", "Catrom", "downscaling filter (Box|Lanczos|Catrom(Bicubic))")
|
||||
cmd:option("-resize_blur", 1.0, 'blur parameter for resize')
|
||||
cmd:option("-color", "y", '(rgb|y|r|g|b)')
|
||||
|
@ -46,6 +46,7 @@ cmd:option("-x_dir", "", 'input image for user method')
|
|||
cmd:option("-y_dir", "", 'groundtruth image for user method. filename must be the same as x_dir')
|
||||
cmd:option("-x_file", "", 'input image for user method')
|
||||
cmd:option("-y_file", "", 'groundtruth image for user method. filename must be the same as x_file')
|
||||
cmd:option("-border", 0, 'border px that will removed')
|
||||
|
||||
local function to_bool(settings, name)
|
||||
if settings[name] == 1 then
|
||||
|
@ -153,12 +154,24 @@ local function baseline_scale(x, filter)
|
|||
x:size(2) * 2.0,
|
||||
filter)
|
||||
end
|
||||
-- Baseline 4x upscaling: resize `x` to four times its size(3) x size(2)
-- with iproc.scale using the given filter.
local function baseline_scale4(x, filter)
   local target_w = x:size(3) * 4.0
   local target_h = x:size(2) * 4.0
   return iproc.scale(x, target_w, target_h, filter)
end
|
||||
-- Produce the benchmark input for the "scale" method: shrink `x` to half
-- its size(3) x size(2) using opt.filter and opt.resize_blur.
local function transform_scale(x, opt)
   local target_w = x:size(3) * 0.5
   local target_h = x:size(2) * 0.5
   return iproc.scale(x, target_w, target_h, opt.filter, opt.resize_blur)
end
|
||||
-- Produce the benchmark input for the "scale4" method: shrink `x` to a
-- quarter of its size(3) x size(2) using opt.filter and opt.resize_blur.
local function transform_scale4(x, opt)
   local target_w = x:size(3) * 0.25
   local target_h = x:size(2) * 0.25
   return iproc.scale(x, target_w, target_h, opt.filter, opt.resize_blur)
end
|
||||
|
||||
local function transform_scale_jpeg(x, opt)
|
||||
x = iproc.scale(x,
|
||||
|
@ -179,9 +192,15 @@ local function transform_scale_jpeg(x, opt)
|
|||
end
|
||||
return iproc.byte2float(x)
|
||||
end
|
||||
|
||||
-- Crop `border` pixels off every side of `x` via iproc.crop, so the
-- benchmark can ignore the image edges.
local function remove_border(x, border)
   local right = x:size(3) - border
   local bottom = x:size(2) - border
   return iproc.crop(x, border, border, right, bottom)
end
|
||||
local function benchmark(opt, x, model1, model2)
|
||||
local mse
|
||||
local mse1, mse2
|
||||
local won = {0, 0}
|
||||
local model1_mse = 0
|
||||
local model2_mse = 0
|
||||
local baseline_mse = 0
|
||||
|
@ -192,6 +211,10 @@ local function benchmark(opt, x, model1, model2)
|
|||
local model2_time = 0
|
||||
local scale_f = reconstruct.scale
|
||||
local image_f = reconstruct.image
|
||||
local detail_fp = nil
|
||||
if opt.save_info then
|
||||
detail_fp = io.open(path.join(opt.output_dir, "benchmark_details.txt"), "w")
|
||||
end
|
||||
if opt.tta then
|
||||
scale_f = function(model, scale, x, block_size, batch_size)
|
||||
return reconstruct.scale_tta(model, opt.tta_level,
|
||||
|
@ -204,12 +227,15 @@ local function benchmark(opt, x, model1, model2)
|
|||
end
|
||||
|
||||
for i = 1, #x do
|
||||
if i % 10 == 0 then
|
||||
collectgarbage()
|
||||
end
|
||||
local basename = x[i].basename
|
||||
local input, model1_output, model2_output, baseline_output, ground_truth
|
||||
|
||||
if opt.method == "scale" then
|
||||
input = transform_scale(x[i].y, opt)
|
||||
ground_truth = x[i].y
|
||||
input = transform_scale(iproc.byte2float(x[i].y), opt)
|
||||
ground_truth = iproc.byte2float(x[i].y)
|
||||
|
||||
if opt.force_cudnn and i == 1 then -- run cuDNN benchmark first
|
||||
model1_output = scale_f(model1, 2.0, input, opt.crop_size, opt.batch_size)
|
||||
|
@ -226,9 +252,29 @@ local function benchmark(opt, x, model1, model2)
|
|||
model2_time = model2_time + (sys.clock() - t)
|
||||
end
|
||||
baseline_output = baseline_scale(input, opt.baseline_filter)
|
||||
elseif opt.method == "scale4" then
|
||||
input = transform_scale4(iproc.byte2float(x[i].y), opt)
|
||||
ground_truth = iproc.byte2float(x[i].y)
|
||||
if opt.force_cudnn and i == 1 then -- run cuDNN benchmark first
|
||||
model1_output = scale_f(model1, 2.0, input, opt.crop_size, opt.batch_size)
|
||||
if model2 then
|
||||
model2_output = scale_f(model2, 2.0, input, opt.crop_size, opt.batch_size)
|
||||
end
|
||||
end
|
||||
t = sys.clock()
|
||||
model1_output = scale_f(model1, 2.0, input, opt.crop_size, opt.batch_size)
|
||||
model1_output = scale_f(model1, 2.0, model1_output, opt.crop_size, opt.batch_size)
|
||||
model1_time = model1_time + (sys.clock() - t)
|
||||
if model2 then
|
||||
t = sys.clock()
|
||||
model2_output = scale_f(model2, 2.0, input, opt.crop_size, opt.batch_size)
|
||||
model2_output = scale_f(model2, 2.0, model2_output, opt.crop_size, opt.batch_size)
|
||||
model2_time = model2_time + (sys.clock() - t)
|
||||
end
|
||||
baseline_output = baseline_scale4(input, opt.baseline_filter)
|
||||
elseif opt.method == "noise" then
|
||||
input = transform_jpeg(x[i].y, opt)
|
||||
ground_truth = x[i].y
|
||||
input = transform_jpeg(iproc.byte2float(x[i].y), opt)
|
||||
ground_truth = iproc.byte2float(x[i].y)
|
||||
|
||||
if opt.force_cudnn and i == 1 then
|
||||
model1_output = image_f(model1, input, opt.crop_size, opt.batch_size)
|
||||
|
@ -246,8 +292,8 @@ local function benchmark(opt, x, model1, model2)
|
|||
end
|
||||
baseline_output = input
|
||||
elseif opt.method == "noise_scale" then
|
||||
input = transform_scale_jpeg(x[i].y, opt)
|
||||
ground_truth = x[i].y
|
||||
input = transform_scale_jpeg(iproc.byte2float(x[i].y), opt)
|
||||
ground_truth = iproc.byte2float(x[i].y)
|
||||
|
||||
if opt.force_cudnn and i == 1 then
|
||||
if model1.noise_scale_model then
|
||||
|
@ -312,8 +358,8 @@ local function benchmark(opt, x, model1, model2)
|
|||
end
|
||||
baseline_output = baseline_scale(input, opt.baseline_filter)
|
||||
elseif opt.method == "user" then
|
||||
input = x[i].x
|
||||
ground_truth = x[i].y
|
||||
input = iproc.byte2float(x[i].x)
|
||||
ground_truth = iproc.byte2float(x[i].y)
|
||||
local y_scale = ground_truth:size(2) / input:size(2)
|
||||
if y_scale > 1 then
|
||||
if opt.force_cudnn and i == 1 then
|
||||
|
@ -347,19 +393,44 @@ local function benchmark(opt, x, model1, model2)
|
|||
end
|
||||
end
|
||||
elseif opt.method == "diff" then
|
||||
input = x[i].x
|
||||
ground_truth = x[i].y
|
||||
input = iproc.byte2float(x[i].x)
|
||||
ground_truth = iproc.byte2float(x[i].y)
|
||||
model1_output = input
|
||||
end
|
||||
mse = MSE(ground_truth, model1_output, opt.color)
|
||||
model1_mse = model1_mse + mse
|
||||
model1_psnr = model1_psnr + MSE2PSNR(mse)
|
||||
if opt.border > 0 then
|
||||
ground_truth = remove_border(ground_truth, opt.border)
|
||||
model1_output = remove_border(model1_output, opt.border)
|
||||
end
|
||||
mse1 = MSE(ground_truth, model1_output, opt.color)
|
||||
model1_mse = model1_mse + mse1
|
||||
model1_psnr = model1_psnr + MSE2PSNR(mse1)
|
||||
|
||||
local won_model = 1
|
||||
if model2 then
|
||||
mse = MSE(ground_truth, model2_output, opt.color)
|
||||
model2_mse = model2_mse + mse
|
||||
model2_psnr = model2_psnr + MSE2PSNR(mse)
|
||||
if opt.border > 0 then
|
||||
model2_output = remove_border(model2_output, opt.border)
|
||||
end
|
||||
mse2 = MSE(ground_truth, model2_output, opt.color)
|
||||
model2_mse = model2_mse + mse2
|
||||
model2_psnr = model2_psnr + MSE2PSNR(mse2)
|
||||
|
||||
if mse1 < mse2 then
|
||||
won[1] = won[1] + 1
|
||||
elseif mse1 > mse2 then
|
||||
won[2] = won[2] + 1
|
||||
won_model = 2
|
||||
end
|
||||
if detail_fp then
|
||||
detail_fp:write(string.format("%s,%f,%f,%d\n", x[i].basename,
|
||||
MSE2PSNR(mse1), MSE2PSNR(mse2), won_model))
|
||||
end
|
||||
else
|
||||
if detail_fp then
|
||||
detail_fp:write(string.format("%s,%f\n", x[i].basename, MSE2PSNR(mse1)))
|
||||
end
|
||||
end
|
||||
if baseline_output then
|
||||
baseline_output = remove_border(baseline_output, opt.border)
|
||||
mse = MSE(ground_truth, baseline_output, opt.color)
|
||||
baseline_mse = baseline_mse + mse
|
||||
baseline_psnr = baseline_psnr + MSE2PSNR(mse)
|
||||
|
@ -382,29 +453,31 @@ local function benchmark(opt, x, model1, model2)
|
|||
if model2 then
|
||||
if baseline_output then
|
||||
io.stdout:write(
|
||||
string.format("%d/%d; model1_time=%.2f, model2_time=%.2f, baseline_rmse=%f, model1_rmse=%f, model2_rmse=%f, baseline_psnr=%f, model1_psnr=%f, model2_psnr=%f \r",
|
||||
string.format("%d/%d; model1_time=%.2f, model2_time=%.2f, baseline_rmse=%.3f, model1_rmse=%.3f, model2_rmse=%.3f, baseline_psnr=%.3f, model1_psnr=%.3f, model2_psnr=%.3f, model1_won=%d, model2_won=%d \r",
|
||||
i, #x,
|
||||
model1_time,
|
||||
model2_time,
|
||||
math.sqrt(baseline_mse / i),
|
||||
math.sqrt(model1_mse / i), math.sqrt(model2_mse / i),
|
||||
baseline_psnr / i,
|
||||
model1_psnr / i, model2_psnr / i
|
||||
model1_psnr / i, model2_psnr / i,
|
||||
won[1], won[2]
|
||||
))
|
||||
else
|
||||
io.stdout:write(
|
||||
string.format("%d/%d; model1_time=%.2f, model2_time=%.2f, model1_rmse=%f, model2_rmse=%f, model1_psnr=%f, model2_psnr=%f \r",
|
||||
string.format("%d/%d; model1_time=%.2f, model2_time=%.2f, model1_rmse=%.3f, model2_rmse=%.3f, model1_psnr=%.3f, model2_psnr=%.3f, model1_own=%d, model2_won=%d \r",
|
||||
i, #x,
|
||||
model1_time,
|
||||
model2_time,
|
||||
math.sqrt(model1_mse / i), math.sqrt(model2_mse / i),
|
||||
model1_psnr / i, model2_psnr / i
|
||||
model1_psnr / i, model2_psnr / i,
|
||||
won[1], won[2]
|
||||
))
|
||||
end
|
||||
else
|
||||
if baseline_output then
|
||||
io.stdout:write(
|
||||
string.format("%d/%d; model1_time=%.2f, baseline_rmse=%f, model1_rmse=%f, baseline_psnr=%f, model1_psnr=%f \r",
|
||||
string.format("%d/%d; model1_time=%.2f, baseline_rmse=%.3f, model1_rmse=%.3f, baseline_psnr=%.3f, model1_psnr=%.3f \r",
|
||||
i, #x,
|
||||
model1_time,
|
||||
math.sqrt(baseline_mse / i), math.sqrt(model1_mse / i),
|
||||
|
@ -412,7 +485,7 @@ local function benchmark(opt, x, model1, model2)
|
|||
))
|
||||
else
|
||||
io.stdout:write(
|
||||
string.format("%d/%d; model1_time=%.2f, model1_rmse=%f, model1_psnr=%f \r",
|
||||
string.format("%d/%d; model1_time=%.2f, model1_rmse=%.3f, model1_psnr=%.3f \r",
|
||||
i, #x,
|
||||
model1_time,
|
||||
math.sqrt(model1_mse / i), model1_psnr / i
|
||||
|
@ -438,6 +511,9 @@ local function benchmark(opt, x, model1, model2)
|
|||
math.sqrt(model2_mse / #x), model2_psnr / #x, model2_time))
|
||||
end
|
||||
fp:close()
|
||||
if detail_fp then
|
||||
detail_fp:close()
|
||||
end
|
||||
end
|
||||
io.stdout:write("\n")
|
||||
end
|
||||
|
@ -448,7 +524,7 @@ local function load_data_from_dir(test_dir)
|
|||
local name = path.basename(files[i])
|
||||
local e = path.extension(name)
|
||||
local base = name:sub(0, name:len() - e:len())
|
||||
local img = image_loader.load_float(files[i])
|
||||
local img = image_loader.load_byte(files[i])
|
||||
if img then
|
||||
table.insert(test_x, {y = iproc.crop_mod4(img),
|
||||
basename = base})
|
||||
|
@ -456,6 +532,9 @@ local function load_data_from_dir(test_dir)
|
|||
if opt.show_progress then
|
||||
xlua.progress(i, #files)
|
||||
end
|
||||
if i % 10 == 0 then
|
||||
collectgarbage()
|
||||
end
|
||||
end
|
||||
return test_x
|
||||
end
|
||||
|
@ -466,7 +545,7 @@ local function load_data_from_file(test_file)
|
|||
local name = path.basename(files[i])
|
||||
local e = path.extension(name)
|
||||
local base = name:sub(0, name:len() - e:len())
|
||||
local img = image_loader.load_float(files[i])
|
||||
local img = image_loader.load_byte(files[i])
|
||||
if img then
|
||||
table.insert(test_x, {y = iproc.crop_mod4(img),
|
||||
basename = base})
|
||||
|
@ -474,6 +553,9 @@ local function load_data_from_file(test_file)
|
|||
if opt.show_progress then
|
||||
xlua.progress(i, #files)
|
||||
end
|
||||
if i % 10 == 0 then
|
||||
collectgarbage()
|
||||
end
|
||||
end
|
||||
return test_x
|
||||
end
|
||||
|
@ -519,16 +601,19 @@ local function load_user_data(y_dir, y_file, x_dir, x_file)
|
|||
end
|
||||
for i = 1, #y_files do
|
||||
local key = get_basename(y_files[i])
|
||||
local x = image_loader.load_float(basename_db[key].x)
|
||||
local y = image_loader.load_float(basename_db[key].y)
|
||||
local x = image_loader.load_byte(basename_db[key].x)
|
||||
local y = image_loader.load_byte(basename_db[key].y)
|
||||
if x and y then
|
||||
table.insert(test, {y = y,
|
||||
x = x,
|
||||
basename = base})
|
||||
basename = key})
|
||||
end
|
||||
if opt.show_progress then
|
||||
xlua.progress(i, #y_files)
|
||||
end
|
||||
if i % 10 == 0 then
|
||||
collectgarbage()
|
||||
end
|
||||
end
|
||||
return test
|
||||
end
|
||||
|
@ -563,7 +648,7 @@ if opt.show_progress then
|
|||
print(opt)
|
||||
end
|
||||
|
||||
if opt.method == "scale" then
|
||||
if opt.method == "scale" or opt.method == "scale4" then
|
||||
local f1 = path.join(opt.model1_dir, "scale2.0x_model.t7")
|
||||
local f2 = path.join(opt.model2_dir, "scale2.0x_model.t7")
|
||||
local s1, model1 = pcall(w2nn.load_model, f1, opt.force_cudnn)
|
||||
|
|
|
@ -33,5 +33,7 @@ export_model() {
|
|||
}
|
||||
export_model vgg_7/art
|
||||
export_model upconv_7/art
|
||||
export_model upconv_7l/art
|
||||
export_model vgg_7/photo
|
||||
export_model upconv_7/photo
|
||||
export_model upconv_7l/photo
|
||||
|
|
|
@ -22,7 +22,6 @@ local function includes(s, a)
|
|||
end
|
||||
return false
|
||||
end
|
||||
|
||||
local function get_bias(mod)
|
||||
if mod.bias then
|
||||
return mod.bias:float()
|
||||
|
@ -31,20 +30,18 @@ local function get_bias(mod)
|
|||
return torch.FloatTensor(mod.nOutputPlane):zero()
|
||||
end
|
||||
end
|
||||
local function export(model, output)
|
||||
local function export_weight(jmodules, seq)
|
||||
local targets = {"nn.SpatialConvolutionMM",
|
||||
"cudnn.SpatialConvolution",
|
||||
"nn.SpatialFullConvolution",
|
||||
"cudnn.SpatialFullConvolution"
|
||||
}
|
||||
local jmodules = {}
|
||||
local model_config = meta_data(model)
|
||||
local first_layer = true
|
||||
|
||||
for k = 1, #model.modules do
|
||||
local mod = model.modules[k]
|
||||
for k = 1, #seq.modules do
|
||||
local mod = seq.modules[k]
|
||||
local name = torch.typename(mod)
|
||||
if includes(name, targets) then
|
||||
if name == "nn.Sequential" or name == "nn.ConcatTable" then
|
||||
export_weight(jmodules, mod)
|
||||
elseif includes(name, targets) then
|
||||
local weight = mod.weight:float()
|
||||
if name:match("FullConvolution") then
|
||||
weight = torch.totable(weight:reshape(mod.nInputPlane, mod.nOutputPlane, mod.kH, mod.kW))
|
||||
|
@ -71,6 +68,14 @@ local function export(model, output)
|
|||
table.insert(jmodules, jmod)
|
||||
end
|
||||
end
|
||||
end
|
||||
local function export(model, output)
|
||||
local jmodules = {}
|
||||
local model_config = meta_data(model)
|
||||
local first_layer = true
|
||||
|
||||
export_weight(jmodules, model)
|
||||
|
||||
local fp = io.open(output, "w")
|
||||
if not fp then
|
||||
error("IO Error: " .. output)
|
||||
|
|
425
train.lua
425
train.lua
|
@ -3,15 +3,14 @@ local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^
|
|||
package.path = path.join(path.dirname(__FILE__), "lib", "?.lua;") .. package.path
|
||||
require 'optim'
|
||||
require 'xlua'
|
||||
|
||||
require 'image'
|
||||
require 'w2nn'
|
||||
local threads = require 'threads'
|
||||
local settings = require 'settings'
|
||||
local srcnn = require 'srcnn'
|
||||
local minibatch_adam = require 'minibatch_adam'
|
||||
local iproc = require 'iproc'
|
||||
local reconstruct = require 'reconstruct'
|
||||
local compression = require 'compression'
|
||||
local pairwise_transform = require 'pairwise_transform'
|
||||
local image_loader = require 'image_loader'
|
||||
|
||||
local function save_test_scale(model, rgb, file)
|
||||
|
@ -42,20 +41,218 @@ local function split_data(x, test_size)
|
|||
end
|
||||
return train_x, valid_x
|
||||
end
|
||||
local function make_validation_set(x, transformer, n, patches)
|
||||
|
||||
local g_transform_pool = nil
|
||||
local g_mutex = nil
|
||||
local g_mutex_id = nil
|
||||
-- Create the worker pool used for patch generation and store it in the
-- file-level g_transform_pool (together with g_mutex / g_mutex_id).
-- has_resize: whether the model resizes by itself; when false the workers
--             request pre-upsampled inputs (x_upsampling = not has_resize).
-- offset:     offset passed through to the pairwise_transform functions.
-- The pool size is settings.thread when positive, otherwise the current
-- torch thread count. Blocks until all workers are initialised.
local function transform_pool_init(has_resize, offset)
   local nthread = (settings.thread > 0) and settings.thread or torch.getnumthreads()
   g_mutex = threads.Mutex()
   g_mutex_id = g_mutex:id()

   -- Runs once inside every worker: loads the libraries the worker needs and
   -- defines the worker-global `transformer` that submitted jobs call.
   local function init_worker(threadid)
      require 'pl'
      local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
      package.path = path.join(path.dirname(__FILE__), "lib", "?.lua;") .. package.path
      require 'torch'
      require 'nn'
      require 'cunn'

      torch.setnumthreads(1)
      torch.setdefaulttensortype("torch.FloatTensor")

      local threads = require 'threads'
      local compression = require 'compression'
      local pairwise_transform = require 'pairwise_transform'

      -- Intentionally global inside the worker so that jobs can reach it.
      -- x:             compressed image, or {image, meta} where image may be
      --                a {x = ..., y = ...} pair of compressed images.
      -- is_validation: disables the random augmentations when true.
      -- n:             number of patches (defaults to settings.patches).
      -- Returns whatever the pairwise_transform function selected by
      -- settings.method returns; nothing for an unknown method.
      function transformer(x, is_validation, n)
         local mutex = threads.Mutex(g_mutex_id)
         local meta = {data = {}}
         local y = nil

         -- unpack the optional {image, meta} wrapper and decompress
         if type(x) == "table" and type(x[2]) == "table" then
            local packed = x[1]
            meta = x[2]
            if packed.x and packed.y then
               y = compression.decompress(packed.y)
               x = compression.decompress(packed.x)
            else
               x = compression.decompress(packed)
            end
         else
            x = compression.decompress(x)
         end

         n = n or settings.patches
         if is_validation == nil then is_validation = false end

         -- Active cropping is configured identically for training and
         -- validation; only the random colour noise / overlay are switched
         -- off for validation.
         local active_cropping_rate = settings.active_cropping_rate
         local active_cropping_tries = settings.active_cropping_tries
         local random_color_noise_rate = 0.0
         local random_overlay_rate = 0.0
         if not is_validation then
            random_color_noise_rate = settings.random_color_noise_rate
            random_overlay_rate = settings.random_overlay_rate
         end

         local method = settings.method
         if method == "scale" or method == "noise" or method == "noise_scale" then
            -- options shared by the three built-in methods
            local conf = {
               mutex = mutex,
               random_half_rate = settings.random_half_rate,
               random_color_noise_rate = random_color_noise_rate,
               random_overlay_rate = random_overlay_rate,
               random_unsharp_mask_rate = settings.random_unsharp_mask_rate,
               random_blur_rate = settings.random_blur_rate,
               random_blur_size = settings.random_blur_size,
               random_blur_sigma_min = settings.random_blur_sigma_min,
               random_blur_sigma_max = settings.random_blur_sigma_max,
               max_size = settings.max_size,
               active_cropping_rate = active_cropping_rate,
               active_cropping_tries = active_cropping_tries,
               rgb = (settings.color == "rgb")
            }
            if method ~= "noise" then
               -- resizing options ("scale" and "noise_scale")
               conf.downsampling_filters = settings.downsampling_filters
               conf.x_upsampling = not has_resize
               conf.resize_blur_min = settings.resize_blur_min
               conf.resize_blur_max = settings.resize_blur_max
            end
            if method ~= "scale" then
               -- JPEG options ("noise" and "noise_scale")
               conf.jpeg_chroma_subsampling_rate = settings.jpeg_chroma_subsampling_rate
               conf.nr_rate = settings.nr_rate
            end
            -- per-image meta entries override the defaults above
            conf = tablex.update(conf, meta)

            if method == "scale" then
               return pairwise_transform.scale(x,
                                               settings.scale,
                                               settings.crop_size, offset,
                                               n, conf)
            elseif method == "noise" then
               return pairwise_transform.jpeg(x,
                                              settings.style,
                                              settings.noise_level,
                                              settings.crop_size, offset,
                                              n, conf)
            else
               return pairwise_transform.jpeg_scale(x,
                                                    settings.scale,
                                                    settings.style,
                                                    settings.noise_level,
                                                    settings.crop_size, offset,
                                                    n, conf)
            end
         elseif method == "user" then
            -- pairwise augmentations are disabled for validation
            local rotate_rate = 0
            local scale_rate = 0
            local negate_rate = 0
            local negate_x_rate = 0
            if not is_validation then
               rotate_rate = settings.random_pairwise_rotate_rate
               scale_rate = settings.random_pairwise_scale_rate
               negate_rate = settings.random_pairwise_negate_rate
               negate_x_rate = settings.random_pairwise_negate_x_rate
            end
            local conf = tablex.update({
               gcn = settings.gcn,
               max_size = settings.max_size,
               active_cropping_rate = active_cropping_rate,
               active_cropping_tries = active_cropping_tries,
               random_pairwise_rotate_rate = rotate_rate,
               random_pairwise_rotate_min = settings.random_pairwise_rotate_min,
               random_pairwise_rotate_max = settings.random_pairwise_rotate_max,
               random_pairwise_scale_rate = scale_rate,
               random_pairwise_scale_min = settings.random_pairwise_scale_min,
               random_pairwise_scale_max = settings.random_pairwise_scale_max,
               random_pairwise_negate_rate = negate_rate,
               random_pairwise_negate_x_rate = negate_x_rate,
               pairwise_y_binary = settings.pairwise_y_binary,
               pairwise_flip = settings.pairwise_flip,
               rgb = (settings.color == "rgb")}, meta)
            return pairwise_transform.user(x, y,
                                           settings.crop_size, offset,
                                           n, conf)
         end
      end
   end

   g_transform_pool = threads.Threads(nthread, threads.safe(init_worker))
   g_transform_pool:synchronize()
end
|
||||
|
||||
local function make_validation_set(x, n, patches)
|
||||
local nthread = torch.getnumthreads()
|
||||
if (settings.thread > 0) then
|
||||
nthread = settings.thread
|
||||
end
|
||||
n = n or 4
|
||||
local validation_patches = math.min(16, patches or 16)
|
||||
local data = {}
|
||||
|
||||
g_transform_pool:synchronize()
|
||||
torch.setnumthreads(1) -- 1
|
||||
|
||||
for i = 1, #x do
|
||||
for k = 1, math.max(n / validation_patches, 1) do
|
||||
local xy = transformer(x[i], true, validation_patches)
|
||||
for j = 1, #xy do
|
||||
table.insert(data, {x = xy[j][1], y = xy[j][2]})
|
||||
end
|
||||
local input = x[i]
|
||||
g_transform_pool:addjob(
|
||||
function()
|
||||
local xy = transformer(input, true, validation_patches)
|
||||
return xy
|
||||
end,
|
||||
function(xy)
|
||||
for j = 1, #xy do
|
||||
table.insert(data, {x = xy[j][1], y = xy[j][2]})
|
||||
end
|
||||
end
|
||||
)
|
||||
end
|
||||
if i % 20 == 0 then
|
||||
collectgarbage()
|
||||
g_transform_pool:synchronize()
|
||||
xlua.progress(i, #x)
|
||||
end
|
||||
xlua.progress(i, #x)
|
||||
collectgarbage()
|
||||
end
|
||||
g_transform_pool:synchronize()
|
||||
torch.setnumthreads(nthread) -- revert
|
||||
|
||||
local new_data = {}
|
||||
local perm = torch.randperm(#data)
|
||||
for i = 1, perm:size(1) do
|
||||
|
@ -102,144 +299,71 @@ local function validate(model, criterion, eval_metric, data, batch_size)
|
|||
end
|
||||
|
||||
-- Returns a CUDA ClippedWeightedHuberCriterion for `model`.
-- The weight tensor has one row per channel (3 when reconstruct.is_rgb(model)
-- is true, otherwise 1) and output_w * output_w columns, where
-- output_w = settings.crop_size - 2 * reconstruct.offset_size(model).
-- NOTE(review): this text is a rendered diff without +/- markers; the rows
-- below read as one complete body, but they may be the pre-change version.
-- Confirm against the repository before relying on them.
local function create_criterion(model)
|
||||
if reconstruct.is_rgb(model) then
|
||||
local offset = reconstruct.offset_size(model)
|
||||
local output_w = settings.crop_size - offset * 2
|
||||
local weight = torch.Tensor(3, output_w * output_w)
|
||||
-- Per-channel weights: the three factors add up to 1.0, so after the `* 3`
-- the rows sum to 3. Presumably luma-style RGB weighting -- TODO confirm.
weight[1]:fill(0.29891 * 3) -- R
|
||||
weight[2]:fill(0.58661 * 3) -- G
|
||||
weight[3]:fill(0.11448 * 3) -- B
|
||||
-- 0.1 and {0.0, 1.0} are handed to the criterion unchanged; presumably the
-- Huber delta and the clipping range -- confirm in w2nn.
return w2nn.ClippedWeightedHuberCriterion(weight, 0.1, {0.0, 1.0}):cuda()
|
||||
else
|
||||
-- Non-RGB model: a single weight row filled with 1.0.
local offset = reconstruct.offset_size(model)
|
||||
local output_w = settings.crop_size - offset * 2
|
||||
local weight = torch.Tensor(1, output_w * output_w)
|
||||
weight[1]:fill(1.0)
|
||||
return w2nn.ClippedWeightedHuberCriterion(weight, 0.1, {0.0, 1.0}):cuda()
|
||||
end
|
||||
end
|
||||
-- NOTE(review): the rows from here to the final `end` do not form a single
-- function as written. This text is a rendered diff without +/- markers, and
-- two bodies are interleaved:
--   * transformer(model, x, is_validation, n, offset): decompresses the input,
--     picks augmentation rates, and dispatches on settings.method to a
--     pairwise_transform.* call;
--   * a settings.loss dispatch ("huber" / "l1" / "mse", otherwise error) whose
--     branches return a CUDA criterion.
-- Confirm against the repository which rows belong to which version.
local function transformer(model, x, is_validation, n, offset)
|
||||
-- Default metadata, used when x carries none.
local meta = {data = {}}
|
||||
local y = nil
|
||||
-- x given as {payload, meta}: take meta from x[2].
if type(x) == "table" and type(x[2]) == "table" then
|
||||
meta = x[2]
|
||||
-- Payload with separate .x / .y parts: decompress both.
if x[1].x and x[1].y then
|
||||
y = compression.decompress(x[1].y)
|
||||
x = compression.decompress(x[1].x)
|
||||
-- (loss-dispatch rows) "huber": weighted Huber criterion; the weight tensor
-- has 3 rows for RGB models and 1 row otherwise.
if settings.loss == "huber" then
|
||||
if reconstruct.is_rgb(model) then
|
||||
local offset = reconstruct.offset_size(model)
|
||||
local output_w = settings.crop_size - offset * 2
|
||||
local weight = torch.Tensor(3, output_w * output_w)
|
||||
weight[1]:fill(0.29891 * 3) -- R
|
||||
weight[2]:fill(0.58661 * 3) -- G
|
||||
weight[3]:fill(0.11448 * 3) -- B
|
||||
return w2nn.ClippedWeightedHuberCriterion(weight, 0.1, {0.0, 1.0}):cuda()
|
||||
else
|
||||
-- (transformer row) payload without .x / .y: x[1] is the compressed input.
x = compression.decompress(x[1])
|
||||
local offset = reconstruct.offset_size(model)
|
||||
local output_w = settings.crop_size - offset * 2
|
||||
local weight = torch.Tensor(1, output_w * output_w)
|
||||
weight[1]:fill(1.0)
|
||||
return w2nn.ClippedWeightedHuberCriterion(weight, 0.1, {0.0, 1.0}):cuda()
|
||||
end
|
||||
-- (loss-dispatch rows) "l1" and "mse" return criteria built without a weight tensor.
elseif settings.loss == "l1" then
|
||||
return w2nn.L1Criterion():cuda()
|
||||
elseif settings.loss == "mse" then
|
||||
return w2nn.ClippedMSECriterion(0, 1.0):cuda()
|
||||
else
|
||||
-- (transformer row) x is the compressed input itself.
x = compression.decompress(x)
|
||||
end
|
||||
-- Patch count defaults to settings.patches.
n = n or settings.patches
|
||||
if is_validation == nil then is_validation = false end
|
||||
local random_color_noise_rate = nil
|
||||
local random_overlay_rate = nil
|
||||
local active_cropping_rate = nil
|
||||
local active_cropping_tries = nil
|
||||
-- Validation sets the color-noise and overlay rates to 0.0; the active
-- cropping settings are the same in both branches.
if is_validation then
|
||||
active_cropping_rate = settings.active_cropping_rate
|
||||
active_cropping_tries = settings.active_cropping_tries
|
||||
random_color_noise_rate = 0.0
|
||||
random_overlay_rate = 0.0
|
||||
else
|
||||
active_cropping_rate = settings.active_cropping_rate
|
||||
active_cropping_tries = settings.active_cropping_tries
|
||||
random_color_noise_rate = settings.random_color_noise_rate
|
||||
random_overlay_rate = settings.random_overlay_rate
|
||||
end
|
||||
-- Each method builds a conf table from settings, passes it with meta through
-- tablex.update (presumably keys in meta overwrite same-named defaults --
-- confirm in Penlight), and returns the matching pairwise_transform result.
if settings.method == "scale" then
|
||||
local conf = tablex.update({
|
||||
downsampling_filters = settings.downsampling_filters,
|
||||
random_half_rate = settings.random_half_rate,
|
||||
random_color_noise_rate = random_color_noise_rate,
|
||||
random_overlay_rate = random_overlay_rate,
|
||||
random_unsharp_mask_rate = settings.random_unsharp_mask_rate,
|
||||
max_size = settings.max_size,
|
||||
active_cropping_rate = active_cropping_rate,
|
||||
active_cropping_tries = active_cropping_tries,
|
||||
rgb = (settings.color == "rgb"),
|
||||
-- x_upsampling is true exactly when reconstruct.has_resize(model) is false.
x_upsampling = not reconstruct.has_resize(model),
|
||||
resize_blur_min = settings.resize_blur_min,
|
||||
resize_blur_max = settings.resize_blur_max}, meta)
|
||||
return pairwise_transform.scale(x,
|
||||
settings.scale,
|
||||
settings.crop_size, offset,
|
||||
n, conf)
|
||||
elseif settings.method == "noise" then
|
||||
local conf = tablex.update({
|
||||
random_half_rate = settings.random_half_rate,
|
||||
random_color_noise_rate = random_color_noise_rate,
|
||||
random_overlay_rate = random_overlay_rate,
|
||||
random_unsharp_mask_rate = settings.random_unsharp_mask_rate,
|
||||
max_size = settings.max_size,
|
||||
jpeg_chroma_subsampling_rate = settings.jpeg_chroma_subsampling_rate,
|
||||
active_cropping_rate = active_cropping_rate,
|
||||
active_cropping_tries = active_cropping_tries,
|
||||
nr_rate = settings.nr_rate,
|
||||
rgb = (settings.color == "rgb")}, meta)
|
||||
return pairwise_transform.jpeg(x,
|
||||
settings.style,
|
||||
settings.noise_level,
|
||||
settings.crop_size, offset,
|
||||
n, conf)
|
||||
elseif settings.method == "noise_scale" then
|
||||
local conf = tablex.update({
|
||||
downsampling_filters = settings.downsampling_filters,
|
||||
random_half_rate = settings.random_half_rate,
|
||||
random_color_noise_rate = random_color_noise_rate,
|
||||
random_overlay_rate = random_overlay_rate,
|
||||
random_unsharp_mask_rate = settings.random_unsharp_mask_rate,
|
||||
max_size = settings.max_size,
|
||||
jpeg_chroma_subsampling_rate = settings.jpeg_chroma_subsampling_rate,
|
||||
nr_rate = settings.nr_rate,
|
||||
active_cropping_rate = active_cropping_rate,
|
||||
active_cropping_tries = active_cropping_tries,
|
||||
rgb = (settings.color == "rgb"),
|
||||
x_upsampling = not reconstruct.has_resize(model),
|
||||
resize_blur_min = settings.resize_blur_min,
|
||||
resize_blur_max = settings.resize_blur_max}, meta)
|
||||
return pairwise_transform.jpeg_scale(x,
|
||||
settings.scale,
|
||||
settings.style,
|
||||
settings.noise_level,
|
||||
settings.crop_size, offset,
|
||||
n, conf)
|
||||
elseif settings.method == "user" then
|
||||
-- "user" is the only method that passes the separately decompressed y.
local conf = tablex.update({
|
||||
max_size = settings.max_size,
|
||||
active_cropping_rate = active_cropping_rate,
|
||||
active_cropping_tries = active_cropping_tries,
|
||||
rgb = (settings.color == "rgb")}, meta)
|
||||
return pairwise_transform.user(x, y,
|
||||
settings.crop_size, offset,
|
||||
n, conf)
|
||||
-- (loss-dispatch row) raises an error naming settings.loss. The `..` inside
-- the string literal is part of the message text, not a concatenation.
error("unsupported loss .." .. settings.loss)
|
||||
end
|
||||
end
|
||||
|
||||
-- Fills the existing batches x and y, row by row, with patch pairs produced
-- from the entries of train_x visited in a random order (torch.randperm).
-- NOTE(review): this text is a rendered diff without +/- markers. Two
-- signature rows and two loop bodies are interleaved: one calls `transformer`
-- inline, the other queues the work on g_transform_pool. Confirm against the
-- repository which rows are live.
local function resampling(x, y, train_x, transformer, input_size, target_size)
|
||||
local function resampling(x, y, train_x)
|
||||
-- c: index of the next row of x / y to fill.
local c = 1
|
||||
local shuffle = torch.randperm(#train_x)
|
||||
-- Thread count to restore at the end; settings.thread overrides it when > 0.
local nthread = torch.getnumthreads()
|
||||
if (settings.thread > 0) then
|
||||
nthread = settings.thread
|
||||
end
|
||||
torch.setnumthreads(1) -- 1
|
||||
|
||||
for t = 1, #train_x do
|
||||
xlua.progress(t, #train_x)
|
||||
-- Inline variant: transform one entry and copy its pairs until x is full.
local xy = transformer(train_x[shuffle[t]], false, settings.patches)
|
||||
for i = 1, #xy do
|
||||
x[c]:copy(xy[i][1])
|
||||
y[c]:copy(xy[i][2])
|
||||
c = c + 1
|
||||
if c > x:size(1) then
|
||||
break
|
||||
-- Pool variant: the entry is bound to a local so the job closure captures it.
local input = train_x[shuffle[t]]
|
||||
g_transform_pool:addjob(
|
||||
-- First function produces the pairs; presumably it runs on a pool worker
-- (torch `threads` convention) -- confirm.
function()
|
||||
local xy = transformer(input, false, settings.patches)
|
||||
return xy
|
||||
end,
|
||||
-- Second function receives the pairs and copies them while rows remain in x.
function(xy)
|
||||
for i = 1, #xy do
|
||||
if c <= x:size(1) then
|
||||
x[c]:copy(xy[i][1])
|
||||
y[c]:copy(xy[i][2])
|
||||
c = c + 1
|
||||
else
|
||||
break
|
||||
end
|
||||
end
|
||||
end
|
||||
)
|
||||
-- Every 50 entries: collect garbage, synchronize the pool (presumably waits
-- for queued jobs -- confirm), and update the progress bar.
if t % 50 == 0 then
|
||||
collectgarbage()
|
||||
g_transform_pool:synchronize()
|
||||
xlua.progress(t, #train_x)
|
||||
end
|
||||
-- Stop once every row of x has been written.
if c > x:size(1) then
|
||||
break
|
||||
end
|
||||
if t % 50 == 0 then
|
||||
collectgarbage()
|
||||
end
|
||||
end
|
||||
-- Synchronize the pool once more before restoring the thread count.
g_transform_pool:synchronize()
|
||||
xlua.progress(#train_x, #train_x)
|
||||
torch.setnumthreads(nthread) -- revert
|
||||
end
|
||||
local function get_oracle_data(x, y, instance_loss, k, samples)
|
||||
local index = torch.LongTensor(instance_loss:size(1))
|
||||
|
@ -262,6 +386,7 @@ local function get_oracle_data(x, y, instance_loss, k, samples)
|
|||
end
|
||||
|
||||
local function remove_small_image(x)
|
||||
local compression = require 'compression'
|
||||
local new_x = {}
|
||||
for i = 1, #x do
|
||||
local xe, meta, x_s
|
||||
|
@ -293,6 +418,8 @@ local function plot(train, valid)
|
|||
{'validation', torch.Tensor(valid), '-'}})
|
||||
end
|
||||
local function train()
|
||||
local x = remove_small_image(torch.load(settings.images))
|
||||
local train_x, valid_x = split_data(x, math.max(math.floor(settings.validation_rate * #x), 1))
|
||||
local hist_train = {}
|
||||
local hist_valid = {}
|
||||
local model
|
||||
|
@ -301,20 +428,30 @@ local function train()
|
|||
else
|
||||
model = srcnn.create(settings.model, settings.backend, settings.color)
|
||||
end
|
||||
if model.w2nn_input_size then
|
||||
if settings.crop_size ~= model.w2nn_input_size then
|
||||
io.stderr:write(string.format("warning: crop_size is replaced with %d\n",
|
||||
model.w2nn_input_size))
|
||||
settings.crop_size = model.w2nn_input_size
|
||||
end
|
||||
end
|
||||
if model.w2nn_gcn then
|
||||
settings.gcn = true
|
||||
else
|
||||
settings.gcn = false
|
||||
end
|
||||
dir.makepath(settings.model_dir)
|
||||
|
||||
local offset = reconstruct.offset_size(model)
|
||||
local pairwise_func = function(x, is_validation, n)
|
||||
return transformer(model, x, is_validation, n, offset)
|
||||
end
|
||||
transform_pool_init(reconstruct.has_resize(model), offset)
|
||||
|
||||
local criterion = create_criterion(model)
|
||||
local eval_metric = w2nn.ClippedMSECriterion(0, 1):cuda()
|
||||
local x = remove_small_image(torch.load(settings.images))
|
||||
local train_x, valid_x = split_data(x, math.max(math.floor(settings.validation_rate * #x), 1))
|
||||
local adam_config = {
|
||||
xLearningRate = settings.learning_rate,
|
||||
xBatchSize = settings.batch_size,
|
||||
xLearningRateDecay = settings.learning_rate_decay
|
||||
xLearningRateDecay = settings.learning_rate_decay,
|
||||
xInstanceLoss = (settings.oracle_rate > 0)
|
||||
}
|
||||
local ch = nil
|
||||
if settings.color == "y" then
|
||||
|
@ -324,7 +461,7 @@ local function train()
|
|||
end
|
||||
local best_score = 1000.0
|
||||
print("# make validation-set")
|
||||
local valid_xy = make_validation_set(valid_x, pairwise_func,
|
||||
local valid_xy = make_validation_set(valid_x,
|
||||
settings.validation_crops,
|
||||
settings.patches)
|
||||
valid_x = nil
|
||||
|
@ -358,7 +495,7 @@ local function train()
|
|||
if oracle_n > 0 then
|
||||
local oracle_x, oracle_y = get_oracle_data(x, y, instance_loss, oracle_k, oracle_n)
|
||||
resampling(x:narrow(1, oracle_x:size(1) + 1, x:size(1)-oracle_x:size(1)),
|
||||
y:narrow(1, oracle_x:size(1) + 1, x:size(1) - oracle_x:size(1)), train_x, pairwise_func)
|
||||
y:narrow(1, oracle_x:size(1) + 1, x:size(1) - oracle_x:size(1)), train_x)
|
||||
x:narrow(1, 1, oracle_x:size(1)):copy(oracle_x)
|
||||
y:narrow(1, 1, oracle_y:size(1)):copy(oracle_y)
|
||||
|
||||
|
@ -374,7 +511,7 @@ local function train()
|
|||
min = 0,
|
||||
max = 1}))
|
||||
else
|
||||
resampling(x, y, train_x, pairwise_func)
|
||||
resampling(x, y, train_x)
|
||||
end
|
||||
else
|
||||
resampling(x, y, train_x, pairwise_func)
|
||||
|
@ -395,9 +532,9 @@ local function train()
|
|||
if settings.plot then
|
||||
plot(hist_train, hist_valid)
|
||||
end
|
||||
if score.MSE < best_score then
|
||||
if score.loss < best_score then
|
||||
local test_image = image_loader.load_float(settings.test) -- reload
|
||||
best_score = score.MSE
|
||||
best_score = score.loss
|
||||
print("* model has updated")
|
||||
if settings.save_history then
|
||||
torch.save(settings.model_file_best, model:clearState(), "ascii")
|
||||
|
@ -446,7 +583,7 @@ local function train()
|
|||
end
|
||||
end
|
||||
end
|
||||
print("Batch-wise PSNR: " .. score.PSNR .. ", loss: " .. score.loss .. ", MSE: " .. score.MSE .. ", Minimum MSE: " .. best_score)
|
||||
print("Batch-wise PSNR: " .. score.PSNR .. ", loss: " .. score.loss .. ", Minimum loss: " .. best_score .. ", MSE: " .. score.MSE)
|
||||
collectgarbage()
|
||||
end
|
||||
end
|
||||
|
|
|
@ -267,6 +267,7 @@ local function waifu2x()
|
|||
cmd:option("-tta_level", 8, 'TTA level (2|4|8). A higher value makes better quality output but slow')
|
||||
cmd:option("-force_cudnn", 0, 'use cuDNN backend (0|1)')
|
||||
cmd:option("-q", 0, 'quiet (0|1)')
|
||||
cmd:option("-gpu", 1, 'Device ID')
|
||||
|
||||
local opt = cmd:parse(arg)
|
||||
if opt.method:len() > 0 then
|
||||
|
@ -292,5 +293,6 @@ local function waifu2x()
|
|||
else
|
||||
convert_frames(opt)
|
||||
end
|
||||
cutorch.setDevice(opt.gpu)
|
||||
end
|
||||
waifu2x()
|
||||
|
|
Loading…
Reference in a new issue