diff --git a/web.lua b/web.lua index a6d220b..c6d316d 100644 --- a/web.lua +++ b/web.lua @@ -12,6 +12,7 @@ local iproc = require 'iproc' local reconstruct = require 'reconstruct' local image_loader = require 'image_loader' local alpha_util = require 'alpha_util' +local compression = require 'compression' local gm = require 'graphicsmagick' -- Note: turbo and xlua has different implementation of string:split(). @@ -34,6 +35,7 @@ cmd:option("-curl_request_timeout", 60, "request_timeout for curl") cmd:option("-curl_connect_timeout", 60, "connect_timeout for curl") cmd:option("-curl_max_redirects", 2, "max_redirects for curl") cmd:option("-max_body_size", 5 * 1024 * 1024, "maximum allowed size for uploaded files") +cmd:option("-cache_max", 200, "number of cached images on RAM") local opt = cmd:parse(arg) cutorch.setDevice(opt.gpu) @@ -75,6 +77,7 @@ local CLEANUP_MODEL = false -- if you are using the low memory GPU, you could us local CACHE_DIR = path.join(ROOT, "cache") local MAX_NOISE_IMAGE = opt.max_pixels local MAX_SCALE_IMAGE = (math.sqrt(opt.max_pixels) / 2)^2 +local PNG_DEPTH = 8 local CURL_OPTIONS = { request_timeout = opt.curl_request_timeout, connect_timeout = opt.curl_connect_timeout, @@ -167,26 +170,73 @@ local function cleanup_model(model) model:clearState() -- release GPU memory end end -local function convert(x, meta, options) - local cache_file = path.join(CACHE_DIR, options.prefix .. ".png") - local alpha_cache_file = path.join(CACHE_DIR, options.alpha_prefix .. ".png") - local alpha = meta.alpha - local alpha_orig = alpha - if path.exists(alpha_cache_file) then - alpha = image_loader.load_float(alpha_cache_file) - if alpha:dim() == 2 then - alpha = alpha:reshape(1, alpha:size(1), alpha:size(2)) - end - if alpha:size(1) == 3 then - alpha = image.rgb2y(alpha) +-- cache +local g_cache = {} +local function cache_count() + local count = 0 + for _ in pairs(g_cache) do + count = count + 1 + end + return count +end +local function cache_remove_old() + local old_time = nil + local old_key = nil + for k, v in pairs(g_cache) do + if old_time == nil or old_time > v.updated_at then + old_key = k + old_time = v.updated_at end end - if path.exists(cache_file) then - x = image_loader.load_float(cache_file) + if old_key then + g_cache[old_key] = nil + end +end +local function cache_compress(raw_image) + if raw_image then + compressed_image = compression.compress(iproc.float2byte(raw_image)) + return compressed_image + else + return nil + end +end +local function cache_decompress(compressed_image) + if compressed_image then + local raw_image = compression.decompress(compressed_image) + return iproc.byte2float(raw_image) + else + return nil + end +end +local function cache_get(filename) + local cache = g_cache[filename] + if cache then + return {image = cache_decompress(cache.image), + alpha = cache_decompress(cache.alpha)} + else + return nil + end +end +local function cache_put(filename, image, alpha) + g_cache[filename] = {image = cache_compress(image), + alpha = cache_compress(alpha), + updated_at = os.time()}; + local count = cache_count(g_cache) + if count > opt.cache_max then + cache_remove_old() + end +end +local function convert(x, meta, options) + local cache_file = path.join(CACHE_DIR, options.prefix .. ".png") + local alpha = meta.alpha + local alpha_orig = alpha + local cache = cache_get(cache_file) + + if cache then meta = tablex.copy(meta) - meta.alpha = alpha - return x, meta + meta.alpha = cache.alpha + return cache.image, meta else local model = nil if options.style == "art" then @@ -209,7 +259,6 @@ local function convert(x, meta, options) if not (alpha:size(2) == x:size(2) and alpha:size(3) == x:size(3)) then alpha = reconstruct.scale(model.scale, 2.0, alpha, opt.crop_size, opt.batch_size) - image_loader.save_png(alpha_cache_file, alpha) cleanup_model(model.scale) end end @@ -223,7 +272,7 @@ local function convert(x, meta, options) x, opt.crop_size, opt.batch_size) cleanup_model(model[options.method]) end - image_loader.save_png(cache_file, x) + cache_put(cache_file, x, alpha) meta = tablex.copy(meta) meta.alpha = alpha return x, meta @@ -321,7 +370,7 @@ function APIHandler:post() name = uuid() .. ".png" end local blob = image_loader.encode_png(alpha_util.composite(x, meta.alpha), - tablex.update({depth = 8, inplace = true}, meta)) + tablex.update({depth = PNG_DEPTH, inplace = true}, meta)) self:set_header("Content-Length", string.format("%d", #blob)) if download > 0 then