Allow --user to be specified multiple times
This commit is contained in:
parent
346df4726d
commit
6b78a23484
6 changed files with 36 additions and 28 deletions
|
@ -26,7 +26,7 @@ _common_options = [
|
|||
click.option('--saved', is_flag=True, default=None),
|
||||
click.option('--search', default=None, type=str),
|
||||
click.option('--time-format', type=str, default=None),
|
||||
click.option('-u', '--user', type=str, default=None),
|
||||
click.option('-u', '--user', type=str, multiple=True, default=None),
|
||||
click.option('-t', '--time', type=click.Choice(('all', 'hour', 'day', 'week', 'month', 'year')), default=None),
|
||||
click.option('-S', '--sort', type=click.Choice(('hot', 'top', 'new',
|
||||
'controversial', 'rising', 'relevance')), default=None),
|
||||
|
|
|
@ -14,7 +14,6 @@ from bdfr.archive_entry.base_archive_entry import BaseArchiveEntry
|
|||
from bdfr.archive_entry.comment_archive_entry import CommentArchiveEntry
|
||||
from bdfr.archive_entry.submission_archive_entry import SubmissionArchiveEntry
|
||||
from bdfr.configuration import Configuration
|
||||
from bdfr.downloader import RedditDownloader
|
||||
from bdfr.connector import RedditConnector
|
||||
from bdfr.exceptions import ArchiverError
|
||||
from bdfr.resource import Resource
|
||||
|
@ -47,8 +46,9 @@ class Archiver(RedditConnector):
|
|||
results = super(Archiver, self).get_user_data()
|
||||
if self.args.user and self.args.all_comments:
|
||||
sort = self.determine_sort_function()
|
||||
logger.debug(f'Retrieving comments of user {self.args.user}')
|
||||
results.append(sort(self.reddit_instance.redditor(self.args.user).comments, limit=self.args.limit))
|
||||
for user in self.args.user:
|
||||
logger.debug(f'Retrieving comments of user {user}')
|
||||
results.append(sort(self.reddit_instance.redditor(user).comments, limit=self.args.limit))
|
||||
return results
|
||||
|
||||
@staticmethod
|
||||
|
|
|
@ -35,7 +35,7 @@ class Configuration(Namespace):
|
|||
self.time: str = 'all'
|
||||
self.time_format = None
|
||||
self.upvoted: bool = False
|
||||
self.user: Optional[str] = None
|
||||
self.user: list[str] = []
|
||||
self.verbose: int = 0
|
||||
self.make_hard_links = False
|
||||
|
||||
|
|
|
@ -74,7 +74,7 @@ class RedditConnector(metaclass=ABCMeta):
|
|||
logger.log(9, 'Create file name formatter')
|
||||
|
||||
self.create_reddit_instance()
|
||||
self.resolve_user_name()
|
||||
self.args.user = list(filter(None, [self.resolve_user_name(user) for user in self.args.user]))
|
||||
|
||||
self.excluded_submission_ids = self.read_excluded_ids()
|
||||
|
||||
|
@ -256,14 +256,16 @@ class RedditConnector(metaclass=ABCMeta):
|
|||
else:
|
||||
return []
|
||||
|
||||
def resolve_user_name(self):
|
||||
if self.args.user == 'me':
|
||||
def resolve_user_name(self, in_name: str) -> str:
|
||||
if in_name == 'me':
|
||||
if self.authenticated:
|
||||
self.args.user = self.reddit_instance.user.me().name
|
||||
logger.log(9, f'Resolved user to {self.args.user}')
|
||||
resolved_name = self.reddit_instance.user.me().name
|
||||
logger.log(9, f'Resolved user to {resolved_name}')
|
||||
return resolved_name
|
||||
else:
|
||||
self.args.user = None
|
||||
logger.warning('To use "me" as a user, an authenticated Reddit instance must be used')
|
||||
else:
|
||||
return in_name
|
||||
|
||||
def get_submissions_from_link(self) -> list[list[praw.models.Submission]]:
|
||||
supplied_submissions = []
|
||||
|
@ -289,10 +291,13 @@ class RedditConnector(metaclass=ABCMeta):
|
|||
|
||||
def get_multireddits(self) -> list[Iterator]:
|
||||
if self.args.multireddit:
|
||||
if len(self.args.user) != 1:
|
||||
logger.error(f'Only 1 user can be supplied when retrieving from multireddits')
|
||||
return []
|
||||
out = []
|
||||
for multi in self.split_args_input(self.args.multireddit):
|
||||
try:
|
||||
multi = self.reddit_instance.multireddit(self.args.user, multi)
|
||||
multi = self.reddit_instance.multireddit(self.args.user[0], multi)
|
||||
if not multi.subreddits:
|
||||
raise errors.BulkDownloaderException
|
||||
out.append(self.create_filtered_listing_generator(multi))
|
||||
|
@ -312,31 +317,31 @@ class RedditConnector(metaclass=ABCMeta):
|
|||
|
||||
def get_user_data(self) -> list[Iterator]:
|
||||
if any([self.args.submitted, self.args.upvoted, self.args.saved]):
|
||||
if self.args.user:
|
||||
try:
|
||||
self.check_user_existence(self.args.user)
|
||||
except errors.BulkDownloaderException as e:
|
||||
logger.error(e)
|
||||
if not self.args.user:
|
||||
logger.warning('At least one user must be supplied to download user data')
|
||||
return []
|
||||
generators = []
|
||||
for user in self.args.user:
|
||||
try:
|
||||
self.check_user_existence(user)
|
||||
except errors.BulkDownloaderException as e:
|
||||
logger.error(e)
|
||||
continue
|
||||
if self.args.submitted:
|
||||
logger.debug(f'Retrieving submitted posts of user {self.args.user}')
|
||||
generators.append(self.create_filtered_listing_generator(
|
||||
self.reddit_instance.redditor(self.args.user).submissions,
|
||||
self.reddit_instance.redditor(user).submissions,
|
||||
))
|
||||
if not self.authenticated and any((self.args.upvoted, self.args.saved)):
|
||||
logger.warning('Accessing user lists requires authentication')
|
||||
else:
|
||||
if self.args.upvoted:
|
||||
logger.debug(f'Retrieving upvoted posts of user {self.args.user}')
|
||||
generators.append(self.reddit_instance.redditor(self.args.user).upvoted(limit=self.args.limit))
|
||||
generators.append(self.reddit_instance.redditor(user).upvoted(limit=self.args.limit))
|
||||
if self.args.saved:
|
||||
logger.debug(f'Retrieving saved posts of user {self.args.user}')
|
||||
generators.append(self.reddit_instance.redditor(self.args.user).saved(limit=self.args.limit))
|
||||
generators.append(self.reddit_instance.redditor(user).saved(limit=self.args.limit))
|
||||
return generators
|
||||
else:
|
||||
logger.warning('A user must be supplied to download user data')
|
||||
return []
|
||||
else:
|
||||
return []
|
||||
|
||||
|
|
|
@ -34,7 +34,7 @@ def downloader_mock(args: Configuration):
|
|||
return downloader_mock
|
||||
|
||||
|
||||
def assert_all_results_are_submissions(result_limit: int, results: list[Iterator]):
|
||||
def assert_all_results_are_submissions(result_limit: int, results: list[Iterator]) -> list:
|
||||
results = [sub for res in results for sub in res]
|
||||
assert all([isinstance(res, praw.models.Submission) for res in results])
|
||||
if result_limit is not None:
|
||||
|
@ -232,7 +232,7 @@ def test_get_multireddits_public(
|
|||
downloader_mock.sort_filter = RedditTypes.SortType.HOT
|
||||
downloader_mock.args.limit = limit
|
||||
downloader_mock.args.multireddit = test_multireddits
|
||||
downloader_mock.args.user = test_user
|
||||
downloader_mock.args.user = [test_user]
|
||||
downloader_mock.reddit_instance = reddit_instance
|
||||
downloader_mock.create_filtered_listing_generator.return_value = \
|
||||
RedditConnector.create_filtered_listing_generator(
|
||||
|
@ -257,7 +257,7 @@ def test_get_user_submissions(test_user: str, limit: int, downloader_mock: Magic
|
|||
downloader_mock.determine_sort_function.return_value = praw.models.Subreddit.hot
|
||||
downloader_mock.sort_filter = RedditTypes.SortType.HOT
|
||||
downloader_mock.args.submitted = True
|
||||
downloader_mock.args.user = test_user
|
||||
downloader_mock.args.user = [test_user]
|
||||
downloader_mock.authenticated = False
|
||||
downloader_mock.reddit_instance = reddit_instance
|
||||
downloader_mock.create_filtered_listing_generator.return_value = \
|
||||
|
@ -284,11 +284,10 @@ def test_get_user_authenticated_lists(
|
|||
):
|
||||
downloader_mock.args.__dict__[test_flag] = True
|
||||
downloader_mock.reddit_instance = authenticated_reddit_instance
|
||||
downloader_mock.args.user = 'me'
|
||||
downloader_mock.args.limit = 10
|
||||
downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
|
||||
downloader_mock.sort_filter = RedditTypes.SortType.HOT
|
||||
RedditConnector.resolve_user_name(downloader_mock)
|
||||
downloader_mock.args.user = [RedditConnector.resolve_user_name(downloader_mock, 'me')]
|
||||
results = RedditConnector.get_user_data(downloader_mock)
|
||||
assert_all_results_are_submissions(10, results)
|
||||
|
||||
|
|
|
@ -117,6 +117,7 @@ def test_cli_download_multireddit_nonexistent(test_args: list[str], tmp_path: Pa
|
|||
@pytest.mark.authenticated
|
||||
@pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests')
|
||||
@pytest.mark.parametrize('test_args', (
|
||||
['--user', 'djnish', '--submitted', '--user', 'FriesWithThat', '-L', 10],
|
||||
['--user', 'me', '--upvoted', '--authenticate', '-L', 10],
|
||||
['--user', 'me', '--saved', '--authenticate', '-L', 10],
|
||||
['--user', 'me', '--submitted', '--authenticate', '-L', 10],
|
||||
|
@ -231,6 +232,7 @@ def test_cli_archive_subreddit(test_args: list[str], tmp_path: Path):
|
|||
@pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests')
|
||||
@pytest.mark.parametrize('test_args', (
|
||||
['--user', 'me', '--authenticate', '--all-comments', '-L', '10'],
|
||||
['--user', 'me', '--user', 'djnish', '--authenticate', '--all-comments', '-L', '10'],
|
||||
))
|
||||
def test_cli_archive_all_user_comments(test_args: list[str], tmp_path: Path):
|
||||
runner = CliRunner()
|
||||
|
@ -265,12 +267,14 @@ def test_cli_archive_long(test_args: list[str], tmp_path: Path):
|
|||
['--user', 'sdclhgsolgjeroij', '--upvoted', '-L', 10],
|
||||
['--subreddit', 'submitters', '-L', 10], # Private subreddit
|
||||
['--subreddit', 'donaldtrump', '-L', 10], # Banned subreddit
|
||||
['--user', 'djnish', '--user', 'helen_darten', '-m', 'cuteanimalpics', '-L', 10],
|
||||
))
|
||||
def test_cli_download_soft_fail(test_args: list[str], tmp_path: Path):
|
||||
runner = CliRunner()
|
||||
test_args = create_basic_args_for_download_runner(test_args, tmp_path)
|
||||
result = runner.invoke(cli, test_args)
|
||||
assert result.exit_code == 0
|
||||
assert 'Downloaded' not in result.output
|
||||
|
||||
|
||||
@pytest.mark.online
|
||||
|
|
Loading…
Reference in a new issue