# waifu2x/tools/find_unet.py
# Mirrored copy, synced 2024-05-19 04:12:19 +12:00 (171 lines, 4.3 KiB, Python).
# 2018-10-29 08:07:05 +13:00
def find_upcunet_v2(start=76, stop=512, avg_pool=4, check_mod=True,
                    check_size=True, print_mod=False):
    """Search input sizes that fit the cascaded UpCUNet (v2) size model.

    Each candidate input size ``i`` in ``range(start, stop)`` is pushed
    through two U-Net stages joined by a deconv step, using this size
    arithmetic: "conv3x3x2" shrinks the size by 4, a single 3x3 conv by 2,
    "down2x2" halves it and "up2x2" doubles it.  ``-- i`` is printed for
    every candidate and ``ok i s`` for every accepted one, where ``s`` is
    the final size.

    Args:
        start: first candidate input size (inclusive).
        stop: end of the candidate range (exclusive).
        avg_pool: the size after each conv block inside a stage must be a
            multiple of this value (only enforced when ``check_mod``).
        check_mod: enable the ``avg_pool`` divisibility checks.
        check_size: reject candidates whose final size is not positive.
        print_mod: print each checked size with its residues mod 2/4/6/8.

    Returns:
        list of ``(i, s)`` tuples for the accepted candidates, in order.
    """

    def checkpoint(s):
        # Optional debug print, then the divisibility test.
        if print_mod:
            print(s, s % 2, s % 4, s % 6, s % 8)
        return not (check_mod and s % avg_pool != 0)

    def unet(s):
        # One U-Net stage: returns the output size, or None if a check fails.
        # True division is deliberate: an odd size turns fractional, and a
        # fractional size never passes an integer ``avg_pool`` check
        # (floor division would silently accept it instead).
        s = s - 4  # conv3x3x2
        s = s / 2  # down2x2
        s = s - 4  # conv3x3x2
        if not checkpoint(s):
            return None
        s = s / 2  # down2x2
        s = s - 4  # conv3x3x2
        if not checkpoint(s):
            return None
        s = s * 2  # up2x2
        s = s - 4  # conv3x3x2
        if not checkpoint(s):
            return None
        return s * 2  # up2x2

    found = []
    print("cascade")
    for i in range(start, stop):
        print("-- {}".format(i))
        s = unet(i)  # unet 1
        if s is None:
            continue
        s = s * 2 - 4  # deconv
        s = unet(s)  # unet 2
        if s is None:
            continue
        s = s - 2  # conv3x3 last
        # Every step maps a non-positive size to a non-positive size, so a
        # positive final size also rules out non-positive intermediates.
        if check_size and s <= 0:
            continue
        print("ok", i, s)
        found.append((i, s))
    return found
# 2018-10-29 08:07:05 +13:00
def find_upcunet(start=72, stop=512, check_mod=True, check_size=True,
                 print_size=False):
    """Search input sizes that fit the cascaded UpCUNet size model.

    Each candidate input size ``i`` in ``range(start, stop)`` is pushed
    through two U-Net stages joined by a deconv step, using this size
    arithmetic: "conv3x3x2" shrinks the size by 4, a single 3x3 conv by 2,
    "down2x2" halves it and "up2x2" doubles it.  ``ok i s s/i`` is printed
    for every accepted candidate, where ``s`` is the final size.

    Args:
        start: first candidate input size (inclusive).
        stop: end of the candidate range (exclusive).
        check_mod: reject a candidate whose size is odd right before a
            down2x2 step.
        check_size: reject candidates whose final size is not positive.
        print_size: print intermediate sizes for debugging.

    Returns:
        list of ``(i, s)`` tuples for the accepted candidates, in order.
    """

    def unet(s):
        # One U-Net stage: returns the output size, or None when check_mod
        # rejects an odd size that is about to be halved.
        s = s - 4  # conv3x3x2
        if print_size:
            print("1/2", s)
        if check_mod and s % 2 != 0:
            return None
        s = s / 2  # down2x2
        s = s - 4  # conv3x3x2
        if print_size:
            print("1/2", s)
        if check_mod and s % 2 != 0:
            return None
        s = s / 2  # down2x2
        s = s - 4  # conv3x3x2
        s = s * 2  # up2x2
        if print_size:
            print("2x", s)
        s = s - 4  # conv3x3x2
        s = s * 2  # up2x2
        if print_size:
            print("2x", s)
        return s

    found = []
    print("cascade")
    for i in range(start, stop):
        s = unet(i)  # unet 1
        if s is None:
            continue
        # deconv
        s = s - 2
        s = s * 2 - 4
        s = unet(s)  # unet 2
        if s is None:
            continue
        s = s - 2  # conv3x3
        s = s - 2  # conv3x3 last
        # Every step maps a non-positive size to a non-positive size, so a
        # positive final size also rules out non-positive intermediates.
        if check_size and s <= 0:
            continue
        print("ok", i, s, s / i)
        found.append((i, s))
    return found
# 2018-10-28 20:03:52 +13:00
# 2018-10-29 08:07:05 +13:00
def find_cunet(start=72, stop=512, check_mod=True, check_size=True,
               print_size=False):
    """Search input sizes that fit the cascaded CUNet size model.

    Each candidate input size ``i`` in ``range(start, stop)`` is pushed
    through two U-Net stages (no 2x upscaling between them), using this
    size arithmetic: "conv3x3x2" shrinks the size by 4, a single 3x3 conv
    by 2, "down2x2" halves it and "up2x2" doubles it.  ``ok i s s/i`` is
    printed for every accepted candidate, where ``s`` is the final size.

    Args:
        start: first candidate input size (inclusive).
        stop: end of the candidate range (exclusive).
        check_mod: reject a candidate whose size is odd right before a
            down2x2 step.
        check_size: reject candidates whose final size is not positive.
            Without it the default range reports e.g. 72 -> -8.0 as "ok".
        print_size: print intermediate sizes for debugging.

    Returns:
        list of ``(i, s)`` tuples for the accepted candidates, in order.
    """

    def unet(s):
        # One U-Net stage: returns the output size, or None when check_mod
        # rejects an odd size that is about to be halved.
        s = s - 4  # conv3x3x2
        if print_size:
            print("1/2", s)
        if check_mod and s % 2 != 0:
            return None
        s = s / 2  # down2x2
        s = s - 4  # conv3x3x2
        if print_size:
            print("1/2", s)
        if check_mod and s % 2 != 0:
            return None
        s = s / 2  # down2x2
        s = s - 4  # conv3x3x2
        s = s * 2  # up2x2
        if print_size:
            print("2x", s)
        s = s - 4  # conv3x3x2
        s = s * 2  # up2x2
        if print_size:
            print("2x", s)
        return s

    found = []
    print("cascade")
    for i in range(start, stop):
        s = unet(i)  # unet 1
        if s is None:
            continue
        # Transition between the stages: shrink by 4, no 2x upscaling.
        s = s - 4
        s = unet(s)  # unet 2
        if s is None:
            continue
        s = s - 2  # conv3x3
        s = s - 2  # conv3x3 last
        # Every step maps a non-positive size to a non-positive size, so a
        # positive final size also rules out non-positive intermediates.
        if check_size and s <= 0:
            continue
        print("ok", i, s, s / i)
        found.append((i, s))
    return found
if __name__ == "__main__":
    # Runs the CUNet search by default; find_upcunet() or find_upcunet_v2()
    # can be called here instead to search those layouts.
    find_cunet()