Add GCN option for user method
This commit is contained in:
parent
a95f41f464
commit
43a9b58fcb
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -12,6 +12,7 @@ models/*
|
|||
!models/photo
|
||||
!models/upconv_7
|
||||
!models/upconv_7l
|
||||
!models/srresnet_12l
|
||||
!models/vgg_7
|
||||
models/*/*.png
|
||||
models/*/*/*.png
|
||||
|
|
|
@ -37,6 +37,15 @@ function pairwise_transform.user(x, y, size, offset, n, options)
|
|||
yc = iproc.rgb2y(yc)
|
||||
xc = iproc.rgb2y(xc)
|
||||
end
|
||||
if options.gcn then
|
||||
local mean = xc:mean()
|
||||
local stdv = xc:std()
|
||||
if stdv > 0 then
|
||||
xc:add(-mean):div(stdv)
|
||||
else
|
||||
xc:add(-mean)
|
||||
end
|
||||
end
|
||||
table.insert(batch, {xc, iproc.crop(yc, offset, offset, size - offset, size - offset)})
|
||||
end
|
||||
|
||||
|
|
|
@ -40,6 +40,15 @@ local function reconstruct_nn(model, x, inner_scale, offset, block_size, batch_s
|
|||
break
|
||||
end
|
||||
input[j+1]:copy(x[input_indexes[i + j]])
|
||||
if model.w2nn_gcn then
|
||||
local mean = input[j + 1]:mean()
|
||||
local stdv = input[j + 1]:std()
|
||||
if stdv > 0 then
|
||||
input[j + 1]:add(-mean):div(stdv)
|
||||
else
|
||||
input[j + 1]:add(-mean)
|
||||
end
|
||||
end
|
||||
c = c + 1
|
||||
end
|
||||
input_cuda:copy(input)
|
||||
|
@ -80,7 +89,12 @@ local function padding_params(x, model, block_size)
|
|||
p.x_w = x:size(3)
|
||||
p.x_h = x:size(2)
|
||||
p.inner_scale = reconstruct.inner_scale(model)
|
||||
local input_offset = math.ceil(offset / p.inner_scale)
|
||||
local input_offset
|
||||
if model.w2nn_input_offset then
|
||||
input_offset = model.w2nn_input_offset
|
||||
else
|
||||
input_offset = math.ceil(offset / p.inner_scale)
|
||||
end
|
||||
local input_block_size = block_size
|
||||
local process_size = input_block_size - input_offset * 2
|
||||
local h_blocks = math.floor(p.x_h / process_size) +
|
||||
|
|
|
@ -519,12 +519,12 @@ function srcnn.fcn_v1(backend, ch)
|
|||
|
||||
model:add(w2nn.InplaceClip01())
|
||||
model:add(nn.View(-1):setNumInputDims(3))
|
||||
|
||||
model.w2nn_arch_name = "fcn_v1"
|
||||
model.w2nn_offset = 36
|
||||
model.w2nn_scale_factor = 1
|
||||
model.w2nn_channels = ch
|
||||
model.w2nn_input_size = 120
|
||||
model.w2nn_gcn = true
|
||||
|
||||
return model
|
||||
end
|
||||
|
|
|
@ -192,6 +192,7 @@ local function transform_pool_init(has_resize, offset)
|
|||
negate_x_rate = settings.random_pairwise_negate_x_rate
|
||||
end
|
||||
local conf = tablex.update({
|
||||
gcn = settings.gcn,
|
||||
max_size = settings.max_size,
|
||||
active_cropping_rate = active_cropping_rate,
|
||||
active_cropping_tries = active_cropping_tries,
|
||||
|
@ -432,6 +433,11 @@ local function train()
|
|||
settings.crop_size = model.w2nn_input_size
|
||||
end
|
||||
end
|
||||
if model.w2nn_gcn then
|
||||
settings.gcn = true
|
||||
else
|
||||
settings.gcn = false
|
||||
end
|
||||
dir.makepath(settings.model_dir)
|
||||
|
||||
local offset = reconstruct.offset_size(model)
|
||||
|
|
Loading…
Reference in a new issue