From 50fd999c38e2a78764c6c327889826e2e42c11b2 Mon Sep 17 00:00:00 2001 From: nagadomi Date: Tue, 18 Oct 2016 19:54:00 +0900 Subject: [PATCH] Add -gpu option --- lib/settings.lua | 4 ++++ waifu2x.lua | 2 ++ 2 files changed, 6 insertions(+) diff --git a/lib/settings.lua b/lib/settings.lua index ff05bb9..c150410 100644 --- a/lib/settings.lua +++ b/lib/settings.lua @@ -1,6 +1,7 @@ require 'xlua' require 'pl' require 'trepl' +require 'cutorch' -- global settings @@ -63,6 +64,7 @@ cmd:option("-oracle_drop_rate", 0.5, '') cmd:option("-learning_rate_decay", 3.0e-7, 'learning rate decay (learning_rate * 1/(1+num_of_data*patches*epoch))') cmd:option("-resume", "", 'resume model file') cmd:option("-name", "user", 'model name for user method') +cmd:option("-gpu", 1, 'Device ID') local function to_bool(settings, name) if settings[name] == 1 then @@ -152,4 +154,6 @@ end settings.images = string.format("%s/images.t7", settings.data_dir) settings.image_list = string.format("%s/image_list.txt", settings.data_dir) +cutorch.setDevice(opt.gpu) + return settings diff --git a/waifu2x.lua b/waifu2x.lua index 6ffa98d..8e48937 100644 --- a/waifu2x.lua +++ b/waifu2x.lua @@ -267,6 +267,7 @@ local function waifu2x() cmd:option("-tta_level", 8, 'TTA level (2|4|8). A higher value makes better quality output but slow') cmd:option("-force_cudnn", 0, 'use cuDNN backend (0|1)') cmd:option("-q", 0, 'quiet (0|1)') + cmd:option("-gpu", 1, 'Device ID') local opt = cmd:parse(arg) if opt.method:len() > 0 then @@ -292,5 +293,6 @@ local function waifu2x() else convert_frames(opt) end + cutorch.setDevice(opt.gpu) end waifu2x()