1
0
Fork 0
mirror of synced 2024-05-17 03:12:18 +12:00

Add GCN option for user method

This commit is contained in:
nagadomi 2016-12-25 20:17:47 +09:00
parent a95f41f464
commit 43a9b58fcb
5 changed files with 32 additions and 2 deletions

1
.gitignore vendored
View file

@ -12,6 +12,7 @@ models/*
!models/photo
!models/upconv_7
!models/upconv_7l
!models/srresnet_12l
!models/vgg_7
models/*/*.png
models/*/*/*.png

View file

@ -37,6 +37,15 @@ function pairwise_transform.user(x, y, size, offset, n, options)
yc = iproc.rgb2y(yc)
xc = iproc.rgb2y(xc)
end
if options.gcn then
local mean = xc:mean()
local stdv = xc:std()
if stdv > 0 then
xc:add(-mean):div(stdv)
else
xc:add(-mean)
end
end
table.insert(batch, {xc, iproc.crop(yc, offset, offset, size - offset, size - offset)})
end

View file

@ -40,6 +40,15 @@ local function reconstruct_nn(model, x, inner_scale, offset, block_size, batch_s
break
end
input[j+1]:copy(x[input_indexes[i + j]])
if model.w2nn_gcn then
local mean = input[j + 1]:mean()
local stdv = input[j + 1]:std()
if stdv > 0 then
input[j + 1]:add(-mean):div(stdv)
else
input[j + 1]:add(-mean)
end
end
c = c + 1
end
input_cuda:copy(input)
@ -80,7 +89,12 @@ local function padding_params(x, model, block_size)
p.x_w = x:size(3)
p.x_h = x:size(2)
p.inner_scale = reconstruct.inner_scale(model)
local input_offset = math.ceil(offset / p.inner_scale)
local input_offset
if model.w2nn_input_offset then
input_offset = model.w2nn_input_offset
else
input_offset = math.ceil(offset / p.inner_scale)
end
local input_block_size = block_size
local process_size = input_block_size - input_offset * 2
local h_blocks = math.floor(p.x_h / process_size) +

View file

@ -519,12 +519,12 @@ function srcnn.fcn_v1(backend, ch)
model:add(w2nn.InplaceClip01())
model:add(nn.View(-1):setNumInputDims(3))
model.w2nn_arch_name = "fcn_v1"
model.w2nn_offset = 36
model.w2nn_scale_factor = 1
model.w2nn_channels = ch
model.w2nn_input_size = 120
model.w2nn_gcn = true
return model
end

View file

@ -192,6 +192,7 @@ local function transform_pool_init(has_resize, offset)
negate_x_rate = settings.random_pairwise_negate_x_rate
end
local conf = tablex.update({
gcn = settings.gcn,
max_size = settings.max_size,
active_cropping_rate = active_cropping_rate,
active_cropping_tries = active_cropping_tries,
@ -432,6 +433,11 @@ local function train()
settings.crop_size = model.w2nn_input_size
end
end
if model.w2nn_gcn then
settings.gcn = true
else
settings.gcn = false
end
dir.makepath(settings.model_dir)
local offset = reconstruct.offset_size(model)