1
0
Fork 0
mirror of synced 2024-06-24 01:00:41 +12:00
waifu2x/appendix/caffe_prototxt/make_cunet.py
2018-11-14 06:26:24 +09:00

111 lines
4.1 KiB
Python

# Generate prototxt of waifu2x's cunet/upcunet arch. Training is not possible.
from __future__ import print_function
import sys
sys.path.insert(0, "../python") # pycaffe path
from caffe import layers as L, params as P, to_proto
from caffe.proto import caffe_pb2
import caffe
def seblock(bottom, o, r):
    """Squeeze-and-Excitation gate: global average pool -> 1x1 bottleneck
    MLP (reduction ratio r) -> sigmoid, flattened for use as a Scale input."""
    bottleneck = int(o / r)
    pooled = L.Pooling(bottom, pool=P.Pooling.AVE, global_pooling=True)
    squeeze = L.Convolution(pooled, kernel_size=1, pad=0, stride=1, num_output=bottleneck)
    squeeze = L.ReLU(squeeze, in_place=True)
    excite = L.Convolution(squeeze, kernel_size=1, pad=0, stride=1, num_output=o)
    excite = L.Sigmoid(excite, in_place=True)
    return L.Flatten(excite)
def conv_relu(bottom, o, stride=1, pad=0):
    """3x3 convolution with o output channels, followed by leaky ReLU (slope 0.1)."""
    out = L.Convolution(bottom, kernel_size=3, stride=stride, num_output=o, pad=pad)
    return L.ReLU(out, in_place=True, negative_slope=0.1)
def unet_conv(bottom, m, o, se=True):
    """Two unpadded 3x3 conv + leaky-ReLU layers (m then o channels),
    optionally gated channel-wise by a squeeze-and-excitation block."""
    x = L.Convolution(bottom, kernel_size=3, stride=1,
                      num_output=m, pad=0)
    x = L.ReLU(x, in_place=True, negative_slope=0.1)
    x = L.Convolution(x, kernel_size=3, stride=1,
                      num_output=o, pad=0)
    x = L.ReLU(x, in_place=True, negative_slope=0.1)
    if not se:
        return x
    gate = seblock(x, o, 8)
    return L.Scale(x, gate, axis=0, bias_term=False)
def unet_branch(bottom, insert_f, i, o, depad):
    """U-Net branch: strided-conv downsample, apply insert_f, deconv upsample,
    then add a skip connection cropped by depad to match the inner output size."""
    down = L.Convolution(bottom, kernel_size=2, stride=2, num_output=i, pad=0)
    down = L.ReLU(down, in_place=True, negative_slope=0.1)
    inner = insert_f(down)
    up = L.Deconvolution(inner, convolution_param=dict(num_output=o, kernel_size=2, pad=0, stride=2))
    up = L.ReLU(up, in_place=True, negative_slope=0.1)
    skip = L.Crop(bottom, up, crop_param=dict(axis=2, offset=depad))
    return L.Eltwise(skip, up, operation=P.Eltwise.SUM)
def unet1(bottom, ch, deconv):
    """Single-level U-Net producing ch output channels; the final layer is a
    2x-upsampling deconvolution when deconv is True, else an unpadded 3x3 conv."""
    def inner(x):
        return unet_conv(x, 128, 64, True)
    feat = unet_conv(bottom, 32, 64, se=False)
    feat = unet_branch(feat, inner, 64, 64, 4)
    feat = conv_relu(feat, 64)
    if deconv:
        return L.Deconvolution(feat, convolution_param=dict(num_output=ch, kernel_size=4, pad=3, stride=2))
    return L.Convolution(feat, kernel_size=3, stride=1, num_output=ch, pad=0)
def unet2(bottom, ch, deconv):
    """Two-level nested U-Net producing ch output channels; the final layer is
    a 2x-upsampling deconvolution when deconv is True, else an unpadded 3x3 conv."""
    innermost = lambda x: unet_conv(x, 256, 128, se=True)

    def middle(x):
        y = unet_conv(x, 64, 128, se=True)
        y = unet_branch(y, innermost, 128, 128, 4)
        return unet_conv(y, 64, 64, se=True)

    feat = unet_conv(bottom, 32, 64, se=False)
    feat = unet_branch(feat, middle, 64, 64, 16)
    feat = conv_relu(feat, 64)
    if deconv:
        return L.Deconvolution(feat, convolution_param=dict(num_output=ch, kernel_size=4, pad=3, stride=2))
    return L.Convolution(feat, kernel_size=3, stride=1, num_output=ch, pad=0)
def make_upcunet():
    """Build the 2x-upscaling UpCUNet (unet1 with deconv output feeding unet2,
    residual-added via a cropped skip) and return it as a NetParameter proto."""
    netoffset = 36
    ch = 3  # RGB
    # Integer division: under Python 3, (256 / 2) yields the float 128.0,
    # which would propagate a float into the prototxt input dim. ``//``
    # keeps input_size an int and is identical on Python 2.
    input_size = (256 // 2) + netoffset * 2
    assert input_size % 4 == 0  # two 2x poolings require divisibility by 4
    data = L.Input(name="input", shape=dict(dim=[1, ch, input_size, input_size]))
    u1 = unet1(data, ch=ch, deconv=True)
    u2 = unet2(u1, ch=ch, deconv=False)
    # unet2 shrinks spatially by 40px total; crop u1 by 20 per side to align.
    crop = L.Crop(u1, u2, crop_param=dict(axis=2, offset=20))
    cadd = L.Eltwise(crop, u2, operation=P.Eltwise.SUM)
    return to_proto(cadd)
def make_cunet():
    """Build the 1x (denoise-only) CUNet: unet1 feeding unet2 with a cropped
    residual addition; returns the network as a NetParameter proto."""
    netoffset = 28
    ch = 3  # RGB
    input_size = 256 + netoffset * 2
    assert input_size % 4 == 0
    data = L.Input(name="input", shape=dict(dim=[1, ch, input_size, input_size]))
    u1 = unet1(data, ch=ch, deconv=False)
    u2 = unet2(u1, ch=ch, deconv=False)
    aligned = L.Crop(u1, u2, crop_param=dict(axis=2, offset=20))
    residual = L.Eltwise(aligned, u2, operation=P.Eltwise.SUM)
    return to_proto(residual)
def make_net():
    """Generate both network definitions and write them to prototxt files."""
    for path, builder in (("upcunet.prototxt", make_upcunet),
                          ("cunet.prototxt", make_cunet)):
        with open(path, 'w') as f:
            print(builder(), file=f)
if __name__ == '__main__':
    # Emit upcunet.prototxt and cunet.prototxt into the working directory.
    make_net()
    # test loading the net
    # Smoke test: have caffe parse the generated files so definition
    # errors surface immediately rather than at deploy time.
    caffe.Net('upcunet.prototxt', caffe.TEST)
    caffe.Net('cunet.prototxt', caffe.TEST)