From 6b78a23484d78bb780a85370007bc98b09ff173a Mon Sep 17 00:00:00 2001 From: Serene-Arc Date: Thu, 27 May 2021 15:22:58 +1000 Subject: [PATCH] Allow --user to be specified multiple times --- bdfr/__main__.py | 2 +- bdfr/archiver.py | 6 +++--- bdfr/configuration.py | 2 +- bdfr/connector.py | 41 ++++++++++++++++++++++----------------- tests/test_connector.py | 9 ++++----- tests/test_integration.py | 4 ++++ 6 files changed, 36 insertions(+), 28 deletions(-) diff --git a/bdfr/__main__.py b/bdfr/__main__.py index 28ef207..cf039a5 100644 --- a/bdfr/__main__.py +++ b/bdfr/__main__.py @@ -26,7 +26,7 @@ _common_options = [ click.option('--saved', is_flag=True, default=None), click.option('--search', default=None, type=str), click.option('--time-format', type=str, default=None), - click.option('-u', '--user', type=str, default=None), + click.option('-u', '--user', type=str, multiple=True, default=None), click.option('-t', '--time', type=click.Choice(('all', 'hour', 'day', 'week', 'month', 'year')), default=None), click.option('-S', '--sort', type=click.Choice(('hot', 'top', 'new', 'controversial', 'rising', 'relevance')), default=None), diff --git a/bdfr/archiver.py b/bdfr/archiver.py index 3e0b907..b19a042 100644 --- a/bdfr/archiver.py +++ b/bdfr/archiver.py @@ -14,7 +14,6 @@ from bdfr.archive_entry.base_archive_entry import BaseArchiveEntry from bdfr.archive_entry.comment_archive_entry import CommentArchiveEntry from bdfr.archive_entry.submission_archive_entry import SubmissionArchiveEntry from bdfr.configuration import Configuration -from bdfr.downloader import RedditDownloader from bdfr.connector import RedditConnector from bdfr.exceptions import ArchiverError from bdfr.resource import Resource @@ -47,8 +46,9 @@ class Archiver(RedditConnector): results = super(Archiver, self).get_user_data() if self.args.user and self.args.all_comments: sort = self.determine_sort_function() - logger.debug(f'Retrieving comments of user {self.args.user}') - results.append(sort(self.reddit_instance.redditor(self.args.user).comments, limit=self.args.limit)) + for user in self.args.user: + logger.debug(f'Retrieving comments of user {user}') + results.append(sort(self.reddit_instance.redditor(user).comments, limit=self.args.limit)) return results @staticmethod diff --git a/bdfr/configuration.py b/bdfr/configuration.py index 9ab9d45..446bc82 100644 --- a/bdfr/configuration.py +++ b/bdfr/configuration.py @@ -35,7 +35,7 @@ class Configuration(Namespace): self.time: str = 'all' self.time_format = None self.upvoted: bool = False - self.user: Optional[str] = None + self.user: list[str] = [] self.verbose: int = 0 self.make_hard_links = False diff --git a/bdfr/connector.py b/bdfr/connector.py index c20b749..6aec2f5 100644 --- a/bdfr/connector.py +++ b/bdfr/connector.py @@ -74,7 +74,7 @@ class RedditConnector(metaclass=ABCMeta): logger.log(9, 'Create file name formatter') self.create_reddit_instance() - self.resolve_user_name() + self.args.user = list(filter(None, [self.resolve_user_name(user) for user in self.args.user])) self.excluded_submission_ids = self.read_excluded_ids() @@ -256,14 +256,16 @@ class RedditConnector(metaclass=ABCMeta): else: return [] - def resolve_user_name(self): - if self.args.user == 'me': + def resolve_user_name(self, in_name: str) -> str: + if in_name == 'me': if self.authenticated: - self.args.user = self.reddit_instance.user.me().name - logger.log(9, f'Resolved user to {self.args.user}') + resolved_name = self.reddit_instance.user.me().name + logger.log(9, f'Resolved user to {resolved_name}') + return resolved_name else: - self.args.user = None logger.warning('To use "me" as a user, an authenticated Reddit instance must be used') + else: + return in_name def get_submissions_from_link(self) -> list[list[praw.models.Submission]]: supplied_submissions = [] @@ -289,10 +291,13 @@ class RedditConnector(metaclass=ABCMeta): def get_multireddits(self) -> list[Iterator]: if self.args.multireddit: + if len(self.args.user) != 1: + logger.error(f'Only 1 user can be supplied when retrieving from multireddits') + return [] out = [] for multi in self.split_args_input(self.args.multireddit): try: - multi = self.reddit_instance.multireddit(self.args.user, multi) + multi = self.reddit_instance.multireddit(self.args.user[0], multi) if not multi.subreddits: raise errors.BulkDownloaderException out.append(self.create_filtered_listing_generator(multi)) @@ -312,31 +317,31 @@ class RedditConnector(metaclass=ABCMeta): def get_user_data(self) -> list[Iterator]: if any([self.args.submitted, self.args.upvoted, self.args.saved]): - if self.args.user: + if not self.args.user: + logger.warning('At least one user must be supplied to download user data') + return [] + generators = [] + for user in self.args.user: try: - self.check_user_existence(self.args.user) + self.check_user_existence(user) except errors.BulkDownloaderException as e: logger.error(e) - return [] - generators = [] + continue if self.args.submitted: logger.debug(f'Retrieving submitted posts of user {self.args.user}') generators.append(self.create_filtered_listing_generator( - self.reddit_instance.redditor(self.args.user).submissions, + self.reddit_instance.redditor(user).submissions, )) if not self.authenticated and any((self.args.upvoted, self.args.saved)): logger.warning('Accessing user lists requires authentication') else: if self.args.upvoted: logger.debug(f'Retrieving upvoted posts of user {self.args.user}') - generators.append(self.reddit_instance.redditor(self.args.user).upvoted(limit=self.args.limit)) + generators.append(self.reddit_instance.redditor(user).upvoted(limit=self.args.limit)) if self.args.saved: logger.debug(f'Retrieving saved posts of user {self.args.user}') - generators.append(self.reddit_instance.redditor(self.args.user).saved(limit=self.args.limit)) - return generators - else: - logger.warning('A user must be supplied to download user data') - return [] + generators.append(self.reddit_instance.redditor(user).saved(limit=self.args.limit)) + return generators else: return [] diff --git a/tests/test_connector.py b/tests/test_connector.py index 1078707..03d2668 100644 --- a/tests/test_connector.py +++ b/tests/test_connector.py @@ -34,7 +34,7 @@ def downloader_mock(args: Configuration): return downloader_mock -def assert_all_results_are_submissions(result_limit: int, results: list[Iterator]): +def assert_all_results_are_submissions(result_limit: int, results: list[Iterator]) -> list: results = [sub for res in results for sub in res] assert all([isinstance(res, praw.models.Submission) for res in results]) if result_limit is not None: @@ -232,7 +232,7 @@ def test_get_multireddits_public( downloader_mock.sort_filter = RedditTypes.SortType.HOT downloader_mock.args.limit = limit downloader_mock.args.multireddit = test_multireddits - downloader_mock.args.user = test_user + downloader_mock.args.user = [test_user] downloader_mock.reddit_instance = reddit_instance downloader_mock.create_filtered_listing_generator.return_value = \ RedditConnector.create_filtered_listing_generator( @@ -257,7 +257,7 @@ def test_get_user_submissions(test_user: str, limit: int, downloader_mock: Magic downloader_mock.determine_sort_function.return_value = praw.models.Subreddit.hot downloader_mock.sort_filter = RedditTypes.SortType.HOT downloader_mock.args.submitted = True - downloader_mock.args.user = test_user + downloader_mock.args.user = [test_user] downloader_mock.authenticated = False downloader_mock.reddit_instance = reddit_instance downloader_mock.create_filtered_listing_generator.return_value = \ @@ -284,11 +284,10 @@ def test_get_user_authenticated_lists( ): downloader_mock.args.__dict__[test_flag] = True downloader_mock.reddit_instance = authenticated_reddit_instance - downloader_mock.args.user = 'me' downloader_mock.args.limit = 10 downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot downloader_mock.sort_filter = RedditTypes.SortType.HOT - RedditConnector.resolve_user_name(downloader_mock) + downloader_mock.args.user = [RedditConnector.resolve_user_name(downloader_mock, 'me')] results = RedditConnector.get_user_data(downloader_mock) assert_all_results_are_submissions(10, results) diff --git a/tests/test_integration.py b/tests/test_integration.py index 7aec0eb..2ff1909 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -117,6 +117,7 @@ def test_cli_download_multireddit_nonexistent(test_args: list[str], tmp_path: Pa @pytest.mark.authenticated @pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests') @pytest.mark.parametrize('test_args', ( + ['--user', 'djnish', '--submitted', '--user', 'FriesWithThat', '-L', 10], ['--user', 'me', '--upvoted', '--authenticate', '-L', 10], ['--user', 'me', '--saved', '--authenticate', '-L', 10], ['--user', 'me', '--submitted', '--authenticate', '-L', 10], @@ -231,6 +232,7 @@ def test_cli_archive_subreddit(test_args: list[str], tmp_path: Path): @pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests') @pytest.mark.parametrize('test_args', ( ['--user', 'me', '--authenticate', '--all-comments', '-L', '10'], + ['--user', 'me', '--user', 'djnish', '--authenticate', '--all-comments', '-L', '10'], )) def test_cli_archive_all_user_comments(test_args: list[str], tmp_path: Path): runner = CliRunner() @@ -265,12 +267,14 @@ def test_cli_archive_long(test_args: list[str], tmp_path: Path): ['--user', 'sdclhgsolgjeroij', '--upvoted', '-L', 10], ['--subreddit', 'submitters', '-L', 10], # Private subreddit ['--subreddit', 'donaldtrump', '-L', 10], # Banned subreddit + ['--user', 'djnish', '--user', 'helen_darten', '-m', 'cuteanimalpics', '-L', 10], )) def test_cli_download_soft_fail(test_args: list[str], tmp_path: Path): runner = CliRunner() test_args = create_basic_args_for_download_runner(test_args, tmp_path) result = runner.invoke(cli, test_args) assert result.exit_code == 0 + assert 'Downloaded' not in result.output @pytest.mark.online