diff --git a/convert_data.lua b/convert_data.lua index 4cebfbd..181b292 100644 --- a/convert_data.lua +++ b/convert_data.lua @@ -1,12 +1,13 @@ +require 'pl' local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)() package.path = path.join(path.dirname(__FILE__), "lib", "?.lua;") .. package.path -require 'pl' require 'image' local compression = require 'compression' local settings = require 'settings' local image_loader = require 'image_loader' local iproc = require 'iproc' +local alpha_util = require 'alpha_util' local function crop_if_large(src, max_size) local tries = 4 @@ -35,25 +36,24 @@ local function load_images(list) local line = lines[i] local im, meta = image_loader.load_byte(line) if meta and meta.alpha then - io.stderr:write(string.format("\n%s: skip: image has alpha channel.\n", line)) - else - if settings.max_training_image_size > 0 then - im = crop_if_large(im, settings.max_training_image_size) - end - im = iproc.crop_mod4(im) - local scale = 1.0 - if settings.random_half_rate > 0.0 then - scale = 2.0 - end - if im then - if im:size(2) > (settings.crop_size * scale + MARGIN) and im:size(3) > (settings.crop_size * scale + MARGIN) then - table.insert(x, compression.compress(im)) - else - io.stderr:write(string.format("\n%s: skip: image is too small (%d > size).\n", line, settings.crop_size * scale + MARGIN)) - end + im = alpha_util.fill(im, meta.alpha, torch.random(0, 1)) + end + if settings.max_training_image_size > 0 then + im = crop_if_large(im, settings.max_training_image_size) + end + im = iproc.crop_mod4(im) + local scale = 1.0 + if settings.random_half_rate > 0.0 then + scale = 2.0 + end + if im then + if im:size(2) > (settings.crop_size * scale + MARGIN) and im:size(3) > (settings.crop_size * scale + MARGIN) then + table.insert(x, compression.compress(im)) else - io.stderr:write(string.format("\n%s: skip: load error.\n", line)) + io.stderr:write(string.format("\n%s: skip: image is too small (%d > size).\n", line, settings.crop_size * scale + MARGIN)) end + else + io.stderr:write(string.format("\n%s: skip: load error.\n", line)) end xlua.progress(i, #lines) if i % 10 == 0 then diff --git a/lib/alpha_util.lua b/lib/alpha_util.lua index 82305eb..ad8266f 100644 --- a/lib/alpha_util.lua +++ b/lib/alpha_util.lua @@ -63,6 +63,26 @@ function alpha_util.composite(rgb, alpha, model2x) out[4]:copy(alpha) return out end +function alpha_util.fill(fg, alpha, val) + assert(fg:size(2) == alpha:size(2) and fg:size(3) == alpha:size(3)) + local conversion = false + fg, conversion = iproc.byte2float(fg) + val = val or 0 + fg = fg:clone() + bg = fg:clone():fill(val) + bg[1]:cmul(1-alpha) + bg[2]:cmul(1-alpha) + bg[3]:cmul(1-alpha) + fg[1]:cmul(alpha) + fg[2]:cmul(alpha) + fg[3]:cmul(alpha) + + local ret = bg:add(fg) + if conversion then + ret = iproc.float2byte(ret) + end + return ret +end local function test() require 'sys'