From a620ae91a18013d1e172d628d66b65980e16c0e0 Mon Sep 17 00:00:00 2001 From: Serene-Arc Date: Fri, 18 Feb 2022 10:21:52 +1000 Subject: [PATCH] Add --subscribed option --- bdfr/__main__.py | 1 + bdfr/configuration.py | 1 + bdfr/connector.py | 20 +++++++---- bdfr/default_config.cfg | 2 +- .../test_download_integration.py | 35 ++++++++++--------- tests/test_connector.py | 18 +++++++++- 6 files changed, 52 insertions(+), 25 deletions(-) diff --git a/bdfr/__main__.py b/bdfr/__main__.py index de658de..56ffb0f 100644 --- a/bdfr/__main__.py +++ b/bdfr/__main__.py @@ -23,6 +23,7 @@ _common_options = [ click.option('--saved', is_flag=True, default=None), click.option('--search', default=None, type=str), click.option('--submitted', is_flag=True, default=None), + click.option('--subscribed', is_flag=True, default=None), click.option('--time-format', type=str, default=None), click.option('--upvoted', is_flag=True, default=None), click.option('-L', '--limit', default=None, type=int), diff --git a/bdfr/configuration.py b/bdfr/configuration.py index 81fa3e4..ef24e36 100644 --- a/bdfr/configuration.py +++ b/bdfr/configuration.py @@ -35,6 +35,7 @@ class Configuration(Namespace): self.skip_subreddit: list[str] = [] self.sort: str = 'hot' self.submitted: bool = False + self.subscribed: bool = True self.subreddit: list[str] = [] self.time: str = 'all' self.time_format = None diff --git a/bdfr/connector.py b/bdfr/connector.py index 506e23f..e04d9ef 100644 --- a/bdfr/connector.py +++ b/bdfr/connector.py @@ -243,9 +243,19 @@ class RedditConnector(metaclass=ABCMeta): return set(all_entries) def get_subreddits(self) -> list[praw.models.ListingGenerator]: - if self.args.subreddit: - out = [] - for reddit in self.split_args_input(self.args.subreddit): + out = [] + subscribed_subreddits = set() + if self.args.subscribed: + if self.args.authenticate: + try: + subscribed_subreddits = list(self.reddit_instance.user.subreddits(limit=None)) + subscribed_subreddits = set([s.display_name for s in subscribed_subreddits]) + except prawcore.InsufficientScope: + logger.error('BDFR has insufficient scope to access subreddit lists') + else: + logger.error('Cannot find subscribed subreddits without an authenticated instance') + if self.args.subreddit or subscribed_subreddits: + for reddit in self.split_args_input(self.args.subreddit) | subscribed_subreddits: if reddit == 'friends' and self.authenticated is False: logger.error('Cannot read friends subreddit without an authenticated instance') continue @@ -270,9 +280,7 @@ class RedditConnector(metaclass=ABCMeta): logger.debug(f'Added submissions from subreddit {reddit}') except (errors.BulkDownloaderException, praw.exceptions.PRAWException) as e: logger.error(f'Failed to get submissions for subreddit {reddit}: {e}') - return out - else: - return [] + return out def resolve_user_name(self, in_name: str) -> str: if in_name == 'me': diff --git a/bdfr/default_config.cfg b/bdfr/default_config.cfg index b8039a9..c601152 100644 --- a/bdfr/default_config.cfg +++ b/bdfr/default_config.cfg @@ -1,7 +1,7 @@ [DEFAULT] client_id = U-6gk4ZCh3IeNQ client_secret = 7CZHY6AmKweZME5s50SfDGylaPg -scopes = identity, history, read, save +scopes = identity, history, read, save, mysubreddits backup_log_count = 3 max_wait_time = 120 time_format = ISO \ No newline at end of file diff --git a/tests/integration_tests/test_download_integration.py b/tests/integration_tests/test_download_integration.py index bd53382..75216dd 100644 --- a/tests/integration_tests/test_download_integration.py +++ b/tests/integration_tests/test_download_integration.py @@ -31,23 +31,23 @@ def create_basic_args_for_download_runner(test_args: list[str], run_path: Path): @pytest.mark.reddit @pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests') @pytest.mark.parametrize('test_args', ( - ['-s', 'Mindustry', '-L', 1], - ['-s', 'r/Mindustry', '-L', 1], - ['-s', 'r/mindustry', '-L', 1], - ['-s', 'mindustry', '-L', 1], - ['-s', 'https://www.reddit.com/r/TrollXChromosomes/', '-L', 1], - ['-s', 'r/TrollXChromosomes/', '-L', 1], - ['-s', 'TrollXChromosomes/', '-L', 1], - ['-s', 'trollxchromosomes', '-L', 1], - ['-s', 'trollxchromosomes,mindustry,python', '-L', 1], - ['-s', 'trollxchromosomes, mindustry, python', '-L', 1], - ['-s', 'trollxchromosomes', '-L', 1, '--time', 'day'], - ['-s', 'trollxchromosomes', '-L', 1, '--sort', 'new'], - ['-s', 'trollxchromosomes', '-L', 1, '--time', 'day', '--sort', 'new'], - ['-s', 'trollxchromosomes', '-L', 1, '--search', 'women'], - ['-s', 'trollxchromosomes', '-L', 1, '--time', 'day', '--search', 'women'], - ['-s', 'trollxchromosomes', '-L', 1, '--sort', 'new', '--search', 'women'], - ['-s', 'trollxchromosomes', '-L', 1, '--time', 'day', '--sort', 'new', '--search', 'women'], + ['-s', 'Mindustry', '-L', 3], + ['-s', 'r/Mindustry', '-L', 3], + ['-s', 'r/mindustry', '-L', 3], + ['-s', 'mindustry', '-L', 3], + ['-s', 'https://www.reddit.com/r/TrollXChromosomes/', '-L', 3], + ['-s', 'r/TrollXChromosomes/', '-L', 3], + ['-s', 'TrollXChromosomes/', '-L', 3], + ['-s', 'trollxchromosomes', '-L', 3], + ['-s', 'trollxchromosomes,mindustry,python', '-L', 3], + ['-s', 'trollxchromosomes, mindustry, python', '-L', 3], + ['-s', 'trollxchromosomes', '-L', 3, '--time', 'day'], + ['-s', 'trollxchromosomes', '-L', 3, '--sort', 'new'], + ['-s', 'trollxchromosomes', '-L', 3, '--time', 'day', '--sort', 'new'], + ['-s', 'trollxchromosomes', '-L', 3, '--search', 'women'], + ['-s', 'trollxchromosomes', '-L', 3, '--time', 'day', '--search', 'women'], + ['-s', 'trollxchromosomes', '-L', 3, '--sort', 'new', '--search', 'women'], + ['-s', 'trollxchromosomes', '-L', 3, '--time', 'day', '--sort', 'new', '--search', 'women'], )) def test_cli_download_subreddits(test_args: list[str], tmp_path: Path): runner = CliRunner() @@ -64,6 +64,7 @@ def test_cli_download_subreddits(test_args: list[str], tmp_path: Path): @pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests') @pytest.mark.parametrize('test_args', ( ['-s', 'hentai', '-L', 10, '--search', 'red', '--authenticate'], + ['--authenticate', '--subscribed', '-L', 10], )) def test_cli_download_search_subreddits_authenticated(test_args: list[str], tmp_path: Path): runner = CliRunner() diff --git a/tests/test_connector.py b/tests/test_connector.py index 9fe58f2..3a10757 100644 --- a/tests/test_connector.py +++ b/tests/test_connector.py @@ -336,13 +336,29 @@ def test_get_user_authenticated_lists( downloader_mock.args.__dict__[test_flag] = True downloader_mock.reddit_instance = authenticated_reddit_instance downloader_mock.args.limit = 10 - downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot + downloader_mock.determine_sort_function.return_value = praw.models.Subreddit.hot downloader_mock.sort_filter = RedditTypes.SortType.HOT downloader_mock.args.user = [RedditConnector.resolve_user_name(downloader_mock, 'me')] results = RedditConnector.get_user_data(downloader_mock) assert_all_results_are_submissions_or_comments(10, results) +@pytest.mark.online +@pytest.mark.reddit +@pytest.mark.authenticated +def test_get_subscribed_subreddits(downloader_mock: MagicMock, authenticated_reddit_instance: praw.Reddit): + downloader_mock.reddit_instance = authenticated_reddit_instance + downloader_mock.args.limit = 10 + downloader_mock.args.authenticate = True + downloader_mock.args.subscribed = True + downloader_mock.determine_sort_function.return_value = praw.models.Subreddit.hot + downloader_mock.determine_sort_function.return_value = praw.models.Subreddit.hot + downloader_mock.sort_filter = RedditTypes.SortType.HOT + results = RedditConnector.get_subreddits(downloader_mock) + assert all([isinstance(s, praw.models.ListingGenerator) for s in results]) + assert len(results) > 0 + + @pytest.mark.parametrize(('test_name', 'expected'), ( ('Mindustry', 'Mindustry'), ('Futurology', 'Futurology'),