Fix crop bug in rare case
This commit is contained in:
parent
1db2eeb788
commit
763f5ddcab
|
@ -125,8 +125,14 @@ function pairwise_transform_utils.active_cropping(x, y, lowres_y, size, scale, p
|
|||
t = "byte"
|
||||
end
|
||||
if p < r then
|
||||
local xi = torch.random(1, x:size(3) - (size + 1)) * scale
|
||||
local yi = torch.random(1, x:size(2) - (size + 1)) * scale
|
||||
local xi = 0
|
||||
local yi = 0
|
||||
if x:size(2) > size + 1 then
|
||||
xi = torch.random(0, x:size(2) - (size + 1)) * scale
|
||||
end
|
||||
if x:size(3) > size + 1 then
|
||||
yi = torch.random(0, x:size(3) - (size + 1)) * scale
|
||||
end
|
||||
local yc = iproc.crop(y, xi, yi, xi + size, yi + size)
|
||||
local xc = iproc.crop(x, xi / scale, yi / scale, xi / scale + size / scale, yi / scale + size / scale)
|
||||
return xc, yc
|
||||
|
|
|
@ -127,6 +127,8 @@ local function SpatialConvolution(backend, nInputPlane, nOutputPlane, kW, kH, dW
|
|||
error("unsupported backend:" .. backend)
|
||||
end
|
||||
end
|
||||
srcnn.SpatialConvolution = SpatialConvolution
|
||||
|
||||
local function SpatialFullConvolution(backend, nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, adjW, adjH)
|
||||
if backend == "cunn" then
|
||||
return nn.SpatialFullConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, adjW, adjH)
|
||||
|
@ -136,6 +138,8 @@ local function SpatialFullConvolution(backend, nInputPlane, nOutputPlane, kW, kH
|
|||
error("unsupported backend:" .. backend)
|
||||
end
|
||||
end
|
||||
srcnn.SpatialFullConvolution = SpatialFullConvolution
|
||||
|
||||
local function ReLU(backend)
|
||||
if backend == "cunn" then
|
||||
return nn.ReLU(true)
|
||||
|
@ -145,6 +149,8 @@ local function ReLU(backend)
|
|||
error("unsupported backend:" .. backend)
|
||||
end
|
||||
end
|
||||
srcnn.ReLU = ReLU
|
||||
|
||||
local function SpatialMaxPooling(backend, kW, kH, dW, dH, padW, padH)
|
||||
if backend == "cunn" then
|
||||
return nn.SpatialMaxPooling(kW, kH, dW, dH, padW, padH)
|
||||
|
@ -154,6 +160,7 @@ local function SpatialMaxPooling(backend, kW, kH, dW, dH, padW, padH)
|
|||
error("unsupported backend:" .. backend)
|
||||
end
|
||||
end
|
||||
srcnn.SpatialMaxPooling = SpatialMaxPooling
|
||||
|
||||
-- VGG style net(7 layers)
|
||||
function srcnn.vgg_7(backend, ch)
|
||||
|
|
12
train.lua
12
train.lua
|
@ -418,7 +418,10 @@ local function plot(train, valid)
|
|||
{'validation', torch.Tensor(valid), '-'}})
|
||||
end
|
||||
local function train()
|
||||
local x = remove_small_image(torch.load(settings.images))
|
||||
local x = torch.load(settings.images)
|
||||
if settings.method ~= "user" then
|
||||
x = remove_small_image(x)
|
||||
end
|
||||
local train_x, valid_x = split_data(x, math.max(math.floor(settings.validation_rate * #x), 1))
|
||||
local hist_train = {}
|
||||
local hist_valid = {}
|
||||
|
@ -426,7 +429,12 @@ local function train()
|
|||
if settings.resume:len() > 0 then
|
||||
model = torch.load(settings.resume, "ascii")
|
||||
else
|
||||
model = srcnn.create(settings.model, settings.backend, settings.color)
|
||||
if stringx.endswith(settings.model, ".lua") then
|
||||
local create_model = dofile(settings.model)
|
||||
model = create_model(srcnn, settings)
|
||||
else
|
||||
model = srcnn.create(settings.model, settings.backend, settings.color)
|
||||
end
|
||||
end
|
||||
if model.w2nn_input_size then
|
||||
if settings.crop_size ~= model.w2nn_input_size then
|
||||
|
|
Loading…
Reference in a new issue