1
0
Fork 0
mirror of synced 2024-09-28 07:12:07 +12:00

Allow --user to be specified multiple times

This commit is contained in:
Serene-Arc 2021-05-27 15:22:58 +10:00
parent 346df4726d
commit 6b78a23484
6 changed files with 36 additions and 28 deletions

View file

@ -26,7 +26,7 @@ _common_options = [
click.option('--saved', is_flag=True, default=None), click.option('--saved', is_flag=True, default=None),
click.option('--search', default=None, type=str), click.option('--search', default=None, type=str),
click.option('--time-format', type=str, default=None), click.option('--time-format', type=str, default=None),
click.option('-u', '--user', type=str, default=None), click.option('-u', '--user', type=str, multiple=True, default=None),
click.option('-t', '--time', type=click.Choice(('all', 'hour', 'day', 'week', 'month', 'year')), default=None), click.option('-t', '--time', type=click.Choice(('all', 'hour', 'day', 'week', 'month', 'year')), default=None),
click.option('-S', '--sort', type=click.Choice(('hot', 'top', 'new', click.option('-S', '--sort', type=click.Choice(('hot', 'top', 'new',
'controversial', 'rising', 'relevance')), default=None), 'controversial', 'rising', 'relevance')), default=None),

View file

@ -14,7 +14,6 @@ from bdfr.archive_entry.base_archive_entry import BaseArchiveEntry
from bdfr.archive_entry.comment_archive_entry import CommentArchiveEntry from bdfr.archive_entry.comment_archive_entry import CommentArchiveEntry
from bdfr.archive_entry.submission_archive_entry import SubmissionArchiveEntry from bdfr.archive_entry.submission_archive_entry import SubmissionArchiveEntry
from bdfr.configuration import Configuration from bdfr.configuration import Configuration
from bdfr.downloader import RedditDownloader
from bdfr.connector import RedditConnector from bdfr.connector import RedditConnector
from bdfr.exceptions import ArchiverError from bdfr.exceptions import ArchiverError
from bdfr.resource import Resource from bdfr.resource import Resource
@ -47,8 +46,9 @@ class Archiver(RedditConnector):
results = super(Archiver, self).get_user_data() results = super(Archiver, self).get_user_data()
if self.args.user and self.args.all_comments: if self.args.user and self.args.all_comments:
sort = self.determine_sort_function() sort = self.determine_sort_function()
logger.debug(f'Retrieving comments of user {self.args.user}') for user in self.args.user:
results.append(sort(self.reddit_instance.redditor(self.args.user).comments, limit=self.args.limit)) logger.debug(f'Retrieving comments of user {user}')
results.append(sort(self.reddit_instance.redditor(user).comments, limit=self.args.limit))
return results return results
@staticmethod @staticmethod

View file

@ -35,7 +35,7 @@ class Configuration(Namespace):
self.time: str = 'all' self.time: str = 'all'
self.time_format = None self.time_format = None
self.upvoted: bool = False self.upvoted: bool = False
self.user: Optional[str] = None self.user: list[str] = []
self.verbose: int = 0 self.verbose: int = 0
self.make_hard_links = False self.make_hard_links = False

View file

@ -74,7 +74,7 @@ class RedditConnector(metaclass=ABCMeta):
logger.log(9, 'Create file name formatter') logger.log(9, 'Create file name formatter')
self.create_reddit_instance() self.create_reddit_instance()
self.resolve_user_name() self.args.user = list(filter(None, [self.resolve_user_name(user) for user in self.args.user]))
self.excluded_submission_ids = self.read_excluded_ids() self.excluded_submission_ids = self.read_excluded_ids()
@ -256,14 +256,16 @@ class RedditConnector(metaclass=ABCMeta):
else: else:
return [] return []
def resolve_user_name(self): def resolve_user_name(self, in_name: str) -> str:
if self.args.user == 'me': if in_name == 'me':
if self.authenticated: if self.authenticated:
self.args.user = self.reddit_instance.user.me().name resolved_name = self.reddit_instance.user.me().name
logger.log(9, f'Resolved user to {self.args.user}') logger.log(9, f'Resolved user to {resolved_name}')
return resolved_name
else: else:
self.args.user = None
logger.warning('To use "me" as a user, an authenticated Reddit instance must be used') logger.warning('To use "me" as a user, an authenticated Reddit instance must be used')
else:
return in_name
def get_submissions_from_link(self) -> list[list[praw.models.Submission]]: def get_submissions_from_link(self) -> list[list[praw.models.Submission]]:
supplied_submissions = [] supplied_submissions = []
@ -289,10 +291,13 @@ class RedditConnector(metaclass=ABCMeta):
def get_multireddits(self) -> list[Iterator]: def get_multireddits(self) -> list[Iterator]:
if self.args.multireddit: if self.args.multireddit:
if len(self.args.user) != 1:
logger.error(f'Only 1 user can be supplied when retrieving from multireddits')
return []
out = [] out = []
for multi in self.split_args_input(self.args.multireddit): for multi in self.split_args_input(self.args.multireddit):
try: try:
multi = self.reddit_instance.multireddit(self.args.user, multi) multi = self.reddit_instance.multireddit(self.args.user[0], multi)
if not multi.subreddits: if not multi.subreddits:
raise errors.BulkDownloaderException raise errors.BulkDownloaderException
out.append(self.create_filtered_listing_generator(multi)) out.append(self.create_filtered_listing_generator(multi))
@ -312,31 +317,31 @@ class RedditConnector(metaclass=ABCMeta):
def get_user_data(self) -> list[Iterator]: def get_user_data(self) -> list[Iterator]:
if any([self.args.submitted, self.args.upvoted, self.args.saved]): if any([self.args.submitted, self.args.upvoted, self.args.saved]):
if self.args.user: if not self.args.user:
logger.warning('At least one user must be supplied to download user data')
return []
generators = []
for user in self.args.user:
try: try:
self.check_user_existence(self.args.user) self.check_user_existence(user)
except errors.BulkDownloaderException as e: except errors.BulkDownloaderException as e:
logger.error(e) logger.error(e)
return [] continue
generators = []
if self.args.submitted: if self.args.submitted:
logger.debug(f'Retrieving submitted posts of user {self.args.user}') logger.debug(f'Retrieving submitted posts of user {self.args.user}')
generators.append(self.create_filtered_listing_generator( generators.append(self.create_filtered_listing_generator(
self.reddit_instance.redditor(self.args.user).submissions, self.reddit_instance.redditor(user).submissions,
)) ))
if not self.authenticated and any((self.args.upvoted, self.args.saved)): if not self.authenticated and any((self.args.upvoted, self.args.saved)):
logger.warning('Accessing user lists requires authentication') logger.warning('Accessing user lists requires authentication')
else: else:
if self.args.upvoted: if self.args.upvoted:
logger.debug(f'Retrieving upvoted posts of user {self.args.user}') logger.debug(f'Retrieving upvoted posts of user {self.args.user}')
generators.append(self.reddit_instance.redditor(self.args.user).upvoted(limit=self.args.limit)) generators.append(self.reddit_instance.redditor(user).upvoted(limit=self.args.limit))
if self.args.saved: if self.args.saved:
logger.debug(f'Retrieving saved posts of user {self.args.user}') logger.debug(f'Retrieving saved posts of user {self.args.user}')
generators.append(self.reddit_instance.redditor(self.args.user).saved(limit=self.args.limit)) generators.append(self.reddit_instance.redditor(user).saved(limit=self.args.limit))
return generators return generators
else:
logger.warning('A user must be supplied to download user data')
return []
else: else:
return [] return []

View file

@ -34,7 +34,7 @@ def downloader_mock(args: Configuration):
return downloader_mock return downloader_mock
def assert_all_results_are_submissions(result_limit: int, results: list[Iterator]): def assert_all_results_are_submissions(result_limit: int, results: list[Iterator]) -> list:
results = [sub for res in results for sub in res] results = [sub for res in results for sub in res]
assert all([isinstance(res, praw.models.Submission) for res in results]) assert all([isinstance(res, praw.models.Submission) for res in results])
if result_limit is not None: if result_limit is not None:
@ -232,7 +232,7 @@ def test_get_multireddits_public(
downloader_mock.sort_filter = RedditTypes.SortType.HOT downloader_mock.sort_filter = RedditTypes.SortType.HOT
downloader_mock.args.limit = limit downloader_mock.args.limit = limit
downloader_mock.args.multireddit = test_multireddits downloader_mock.args.multireddit = test_multireddits
downloader_mock.args.user = test_user downloader_mock.args.user = [test_user]
downloader_mock.reddit_instance = reddit_instance downloader_mock.reddit_instance = reddit_instance
downloader_mock.create_filtered_listing_generator.return_value = \ downloader_mock.create_filtered_listing_generator.return_value = \
RedditConnector.create_filtered_listing_generator( RedditConnector.create_filtered_listing_generator(
@ -257,7 +257,7 @@ def test_get_user_submissions(test_user: str, limit: int, downloader_mock: Magic
downloader_mock.determine_sort_function.return_value = praw.models.Subreddit.hot downloader_mock.determine_sort_function.return_value = praw.models.Subreddit.hot
downloader_mock.sort_filter = RedditTypes.SortType.HOT downloader_mock.sort_filter = RedditTypes.SortType.HOT
downloader_mock.args.submitted = True downloader_mock.args.submitted = True
downloader_mock.args.user = test_user downloader_mock.args.user = [test_user]
downloader_mock.authenticated = False downloader_mock.authenticated = False
downloader_mock.reddit_instance = reddit_instance downloader_mock.reddit_instance = reddit_instance
downloader_mock.create_filtered_listing_generator.return_value = \ downloader_mock.create_filtered_listing_generator.return_value = \
@ -284,11 +284,10 @@ def test_get_user_authenticated_lists(
): ):
downloader_mock.args.__dict__[test_flag] = True downloader_mock.args.__dict__[test_flag] = True
downloader_mock.reddit_instance = authenticated_reddit_instance downloader_mock.reddit_instance = authenticated_reddit_instance
downloader_mock.args.user = 'me'
downloader_mock.args.limit = 10 downloader_mock.args.limit = 10
downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
downloader_mock.sort_filter = RedditTypes.SortType.HOT downloader_mock.sort_filter = RedditTypes.SortType.HOT
RedditConnector.resolve_user_name(downloader_mock) downloader_mock.args.user = [RedditConnector.resolve_user_name(downloader_mock, 'me')]
results = RedditConnector.get_user_data(downloader_mock) results = RedditConnector.get_user_data(downloader_mock)
assert_all_results_are_submissions(10, results) assert_all_results_are_submissions(10, results)

View file

@ -117,6 +117,7 @@ def test_cli_download_multireddit_nonexistent(test_args: list[str], tmp_path: Pa
@pytest.mark.authenticated @pytest.mark.authenticated
@pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests') @pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests')
@pytest.mark.parametrize('test_args', ( @pytest.mark.parametrize('test_args', (
['--user', 'djnish', '--submitted', '--user', 'FriesWithThat', '-L', 10],
['--user', 'me', '--upvoted', '--authenticate', '-L', 10], ['--user', 'me', '--upvoted', '--authenticate', '-L', 10],
['--user', 'me', '--saved', '--authenticate', '-L', 10], ['--user', 'me', '--saved', '--authenticate', '-L', 10],
['--user', 'me', '--submitted', '--authenticate', '-L', 10], ['--user', 'me', '--submitted', '--authenticate', '-L', 10],
@ -231,6 +232,7 @@ def test_cli_archive_subreddit(test_args: list[str], tmp_path: Path):
@pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests') @pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests')
@pytest.mark.parametrize('test_args', ( @pytest.mark.parametrize('test_args', (
['--user', 'me', '--authenticate', '--all-comments', '-L', '10'], ['--user', 'me', '--authenticate', '--all-comments', '-L', '10'],
['--user', 'me', '--user', 'djnish', '--authenticate', '--all-comments', '-L', '10'],
)) ))
def test_cli_archive_all_user_comments(test_args: list[str], tmp_path: Path): def test_cli_archive_all_user_comments(test_args: list[str], tmp_path: Path):
runner = CliRunner() runner = CliRunner()
@ -265,12 +267,14 @@ def test_cli_archive_long(test_args: list[str], tmp_path: Path):
['--user', 'sdclhgsolgjeroij', '--upvoted', '-L', 10], ['--user', 'sdclhgsolgjeroij', '--upvoted', '-L', 10],
['--subreddit', 'submitters', '-L', 10], # Private subreddit ['--subreddit', 'submitters', '-L', 10], # Private subreddit
['--subreddit', 'donaldtrump', '-L', 10], # Banned subreddit ['--subreddit', 'donaldtrump', '-L', 10], # Banned subreddit
['--user', 'djnish', '--user', 'helen_darten', '-m', 'cuteanimalpics', '-L', 10],
)) ))
def test_cli_download_soft_fail(test_args: list[str], tmp_path: Path): def test_cli_download_soft_fail(test_args: list[str], tmp_path: Path):
runner = CliRunner() runner = CliRunner()
test_args = create_basic_args_for_download_runner(test_args, tmp_path) test_args = create_basic_args_for_download_runner(test_args, tmp_path)
result = runner.invoke(cli, test_args) result = runner.invoke(cli, test_args)
assert result.exit_code == 0 assert result.exit_code == 0
assert 'Downloaded' not in result.output
@pytest.mark.online @pytest.mark.online