1
0
Fork 0
mirror of synced 2024-06-02 11:04:31 +12:00
waifu2x/tools/find_unet.py

117 lines
3 KiB
Python
Raw Normal View History

2018-10-28 20:03:52 +13:00
def _unet_v2_stage(s, avg_pool, print_mod, check_mod):
    """Propagate one spatial size through a single U-Net stage (v2 layout).

    Layout: conv3x3x2 -> down2x2 -> conv3x3x2 -> down2x2 -> conv3x3x2
    -> up2x2 -> conv3x3x2 -> up2x2, where each conv3x3x2 (two 3x3
    convolutions) shrinks the size by 4.

    Args:
        s: Integer input size (one spatial side).
        avg_pool: Every size right after a conv3x3x2 that follows a
            down2x2/up2x2 must be a multiple of this value.
        print_mod: If true, print each checked size with its residues
            modulo 2, 4, 6 and 8.
        check_mod: If true, reject sizes that are odd before a down2x2 or
            not a multiple of ``avg_pool`` at a checkpoint.

    Returns:
        The size after the final up2x2, or None if ``check_mod`` is
        enabled and a check failed.
    """
    def checkpoint(v):
        # Debug output first, then the divisibility test (if enabled).
        if print_mod:
            print(v, v % 2, v % 4, v % 6, v % 8)
        return not check_mod or v % avg_pool == 0

    s = s - 4  # conv3x3x2
    if check_mod and s % 2 != 0:  # an odd size cannot be halved exactly
        return None
    s = s // 2  # down2x2
    s = s - 4  # conv3x3x2
    if not checkpoint(s):
        return None
    if check_mod and s % 2 != 0:  # an odd size cannot be halved exactly
        return None
    s = s // 2  # down2x2
    s = s - 4  # conv3x3x2
    if not checkpoint(s):
        return None
    s = s * 2  # up2x2
    s = s - 4  # conv3x3x2
    if not checkpoint(s):
        return None
    return s * 2  # up2x2


def find_unet_v2(start=76, stop=512, *, avg_pool=4, print_mod=False,
                 check_mod=True):
    """Search input sizes that are valid for the cascaded two-U-Net (v2) layout.

    Each candidate size ``i`` in ``range(start, stop)`` is pushed through
    U-Net 1, a deconv step (``s * 2 - 4``), U-Net 2 and a last conv3x3
    (``- 2``). A line ``-- i`` is printed for every candidate, and
    ``ok <input> <output>`` for each size that passes all checks.

    Sizes are kept as integers via floor division. With ``check_mod``
    enabled, every division is exact because odd sizes are rejected before
    halving. With it disabled, odd sizes are floored.

    Args:
        start: First input size to try (inclusive).
        stop: End of the search range (exclusive).
        avg_pool: Required divisor at each checkpoint (see ``_unet_v2_stage``).
        print_mod: Print intermediate sizes and their residues.
        check_mod: Enable the evenness/divisibility checks.
    """
    print("cascade")
    for i in range(start, stop):
        print("-- {}".format(i))
        # unet 1
        s = _unet_v2_stage(i, avg_pool, print_mod, check_mod)
        if s is None:
            continue
        # deconv
        s = s * 2 - 4
        # unet 2
        s = _unet_v2_stage(s, avg_pool, print_mod, check_mod)
        if s is None:
            continue
        s = s - 2  # conv3x3 last
        print("ok", i, s)
def _unet_stage(s, print_size, check_mod):
    """Propagate one spatial size through a single U-Net stage.

    Layout: conv3x3x2 -> down2x2 -> conv3x3x2 -> down2x2 -> conv3x3x2
    -> up2x2 -> conv3x3x2 -> up2x2, where each conv3x3x2 (two 3x3
    convolutions) shrinks the size by 4.

    Args:
        s: Integer input size (one spatial side).
        print_size: If true, print the size entering each down2x2 ("1/2")
            and the size after each up2x2 ("2x").
        check_mod: If true, reject sizes that are odd before a down2x2.

    Returns:
        The size after the final up2x2, or None if ``check_mod`` is
        enabled and a size entering a down2x2 was odd.
    """
    s = s - 4  # conv3x3x2
    if print_size:
        print("1/2", s)
    if check_mod and s % 2 != 0:
        return None
    s = s // 2  # down2x2
    s = s - 4  # conv3x3x2
    if print_size:
        print("1/2", s)
    if check_mod and s % 2 != 0:
        return None
    s = s // 2  # down2x2
    s = s - 4  # conv3x3x2
    s = s * 2  # up2x2
    if print_size:
        print("2x", s)
    s = s - 4  # conv3x3x2
    s = s * 2  # up2x2
    if print_size:
        print("2x", s)
    return s


def find_unet(start=76, stop=512, *, check_mod=True, print_size=False):
    """Search input sizes that are valid for the cascaded two-U-Net layout.

    Each candidate size ``i`` in ``range(start, stop)`` is pushed through
    U-Net 1, a deconv step (``s - 2`` then ``s * 2 - 4``), U-Net 2 and two
    final conv3x3 layers (``- 2`` each). ``ok <input> <output>`` is printed
    for each size whose every down2x2 input is even.

    Sizes are kept as integers via floor division. With ``check_mod``
    enabled, every division is exact. With it disabled, odd sizes are
    floored.

    Args:
        start: First input size to try (inclusive).
        stop: End of the search range (exclusive).
        check_mod: Enable the evenness checks before each down2x2.
        print_size: Print intermediate sizes.
    """
    print("cascade")
    for i in range(start, stop):
        # unet 1
        s = _unet_stage(i, print_size, check_mod)
        if s is None:
            continue
        # deconv
        s = s - 2
        s = s * 2 - 4
        # unet 2
        s = _unet_stage(s, print_size, check_mod)
        if s is None:
            continue
        s = s - 2  # conv3x3
        s = s - 2  # conv3x3 last
        print("ok", i, s)
# Run the search only when executed as a script, not on import.
if __name__ == "__main__":
    find_unet()