1
0
Fork 0
mirror of synced 2024-05-16 19:02:21 +12:00

Fix crop bug in rare case

This commit is contained in:
nagadomi 2017-02-12 17:46:07 +09:00
parent 1db2eeb788
commit 763f5ddcab
3 changed files with 25 additions and 4 deletions

View file

@ -125,8 +125,14 @@ function pairwise_transform_utils.active_cropping(x, y, lowres_y, size, scale, p
t = "byte"
end
if p < r then
local xi = torch.random(1, x:size(3) - (size + 1)) * scale
local yi = torch.random(1, x:size(2) - (size + 1)) * scale
local xi = 0
local yi = 0
if x:size(2) > size + 1 then
xi = torch.random(0, x:size(2) - (size + 1)) * scale
end
if x:size(3) > size + 1 then
yi = torch.random(0, x:size(3) - (size + 1)) * scale
end
local yc = iproc.crop(y, xi, yi, xi + size, yi + size)
local xc = iproc.crop(x, xi / scale, yi / scale, xi / scale + size / scale, yi / scale + size / scale)
return xc, yc

View file

@ -127,6 +127,8 @@ local function SpatialConvolution(backend, nInputPlane, nOutputPlane, kW, kH, dW
error("unsupported backend:" .. backend)
end
end
srcnn.SpatialConvolution = SpatialConvolution
local function SpatialFullConvolution(backend, nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, adjW, adjH)
if backend == "cunn" then
return nn.SpatialFullConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, adjW, adjH)
@ -136,6 +138,8 @@ local function SpatialFullConvolution(backend, nInputPlane, nOutputPlane, kW, kH
error("unsupported backend:" .. backend)
end
end
srcnn.SpatialFullConvolution = SpatialFullConvolution
local function ReLU(backend)
if backend == "cunn" then
return nn.ReLU(true)
@ -145,6 +149,8 @@ local function ReLU(backend)
error("unsupported backend:" .. backend)
end
end
srcnn.ReLU = ReLU
local function SpatialMaxPooling(backend, kW, kH, dW, dH, padW, padH)
if backend == "cunn" then
return nn.SpatialMaxPooling(kW, kH, dW, dH, padW, padH)
@ -154,6 +160,7 @@ local function SpatialMaxPooling(backend, kW, kH, dW, dH, padW, padH)
error("unsupported backend:" .. backend)
end
end
srcnn.SpatialMaxPooling = SpatialMaxPooling
-- VGG style net(7 layers)
function srcnn.vgg_7(backend, ch)

View file

@ -418,7 +418,10 @@ local function plot(train, valid)
{'validation', torch.Tensor(valid), '-'}})
end
local function train()
local x = remove_small_image(torch.load(settings.images))
local x = torch.load(settings.images)
if settings.method ~= "user" then
x = remove_small_image(x)
end
local train_x, valid_x = split_data(x, math.max(math.floor(settings.validation_rate * #x), 1))
local hist_train = {}
local hist_valid = {}
@ -426,7 +429,12 @@ local function train()
if settings.resume:len() > 0 then
model = torch.load(settings.resume, "ascii")
else
model = srcnn.create(settings.model, settings.backend, settings.color)
if stringx.endswith(settings.model, ".lua") then
local create_model = dofile(settings.model)
model = create_model(srcnn, settings)
else
model = srcnn.create(settings.model, settings.backend, settings.color)
end
end
if model.w2nn_input_size then
if settings.crop_size ~= model.w2nn_input_size then