diff --git a/.gitignore b/.gitignore index b20ea3d..60dcc5d 100644 --- a/.gitignore +++ b/.gitignore @@ -12,6 +12,7 @@ models/* !models/photo !models/upconv_7 !models/upconv_7l +!models/srresnet_12l !models/vgg_7 models/*/*.png models/*/*/*.png diff --git a/lib/pairwise_transform_user.lua b/lib/pairwise_transform_user.lua index d6a4f82..279b343 100644 --- a/lib/pairwise_transform_user.lua +++ b/lib/pairwise_transform_user.lua @@ -37,6 +37,15 @@ function pairwise_transform.user(x, y, size, offset, n, options) yc = iproc.rgb2y(yc) xc = iproc.rgb2y(xc) end + if options.gcn then + local mean = xc:mean() + local stdv = xc:std() + if stdv > 0 then + xc:add(-mean):div(stdv) + else + xc:add(-mean) + end + end table.insert(batch, {xc, iproc.crop(yc, offset, offset, size - offset, size - offset)}) end diff --git a/lib/reconstruct.lua b/lib/reconstruct.lua index 68d2736..6c37f15 100644 --- a/lib/reconstruct.lua +++ b/lib/reconstruct.lua @@ -40,6 +40,15 @@ local function reconstruct_nn(model, x, inner_scale, offset, block_size, batch_s break end input[j+1]:copy(x[input_indexes[i + j]]) + if model.w2nn_gcn then + local mean = input[j + 1]:mean() + local stdv = input[j + 1]:std() + if stdv > 0 then + input[j + 1]:add(-mean):div(stdv) + else + input[j + 1]:add(-mean) + end + end c = c + 1 end input_cuda:copy(input) @@ -80,7 +89,12 @@ local function padding_params(x, model, block_size) p.x_w = x:size(3) p.x_h = x:size(2) p.inner_scale = reconstruct.inner_scale(model) - local input_offset = math.ceil(offset / p.inner_scale) + local input_offset + if model.w2nn_input_offset then + input_offset = model.w2nn_input_offset + else + input_offset = math.ceil(offset / p.inner_scale) + end local input_block_size = block_size local process_size = input_block_size - input_offset * 2 local h_blocks = math.floor(p.x_h / process_size) + diff --git a/lib/srcnn.lua b/lib/srcnn.lua index 7004477..4073b31 100644 --- a/lib/srcnn.lua +++ b/lib/srcnn.lua @@ -519,12 +519,12 @@ function srcnn.fcn_v1(backend, ch) model:add(w2nn.InplaceClip01()) model:add(nn.View(-1):setNumInputDims(3)) - model.w2nn_arch_name = "fcn_v1" model.w2nn_offset = 36 model.w2nn_scale_factor = 1 model.w2nn_channels = ch model.w2nn_input_size = 120 + model.w2nn_gcn = true return model end diff --git a/train.lua b/train.lua index cf6850e..fe20f89 100644 --- a/train.lua +++ b/train.lua @@ -192,6 +192,7 @@ local function transform_pool_init(has_resize, offset) negate_x_rate = settings.random_pairwise_negate_x_rate end local conf = tablex.update({ + gcn = settings.gcn, max_size = settings.max_size, active_cropping_rate = active_cropping_rate, active_cropping_tries = active_cropping_tries, @@ -432,6 +433,11 @@ local function train() settings.crop_size = model.w2nn_input_size end end + if model.w2nn_gcn then + settings.gcn = true + else + settings.gcn = false + end dir.makepath(settings.model_dir) local offset = reconstruct.offset_size(model)