1
0
Fork 0
mirror of synced 2024-06-02 18:34:37 +12:00

Merge pull request #432 from Serene-Arc/enhancement_429

Allow --user to be specified multiple times
This commit is contained in:
Ali Parlakçı 2021-06-06 13:25:00 +03:00 committed by GitHub
commit a2f010c40d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 38 additions and 28 deletions

View file

@ -123,6 +123,8 @@ The following options are common between both the `archive` and `download` comma
- `-u, --user`
    - This specifies the user to scrape in concert with other options
    - When using `--authenticate`, `--user me` can be used to refer to the authenticated user
    - Can be specified multiple times for multiple users
    - If downloading a multireddit, only one user can be specified
- `-v, --verbose`
    - Increases the verbosity of the program
    - Can be specified multiple times

View file

@ -26,7 +26,7 @@ _common_options = [
click.option('--saved', is_flag=True, default=None),
click.option('--search', default=None, type=str),
click.option('--time-format', type=str, default=None),
click.option('-u', '--user', type=str, default=None),
click.option('-u', '--user', type=str, multiple=True, default=None),
click.option('-t', '--time', type=click.Choice(('all', 'hour', 'day', 'week', 'month', 'year')), default=None),
click.option('-S', '--sort', type=click.Choice(('hot', 'top', 'new',
'controversial', 'rising', 'relevance')), default=None),

View file

@ -14,7 +14,6 @@ from bdfr.archive_entry.base_archive_entry import BaseArchiveEntry
from bdfr.archive_entry.comment_archive_entry import CommentArchiveEntry
from bdfr.archive_entry.submission_archive_entry import SubmissionArchiveEntry
from bdfr.configuration import Configuration
from bdfr.downloader import RedditDownloader
from bdfr.connector import RedditConnector
from bdfr.exceptions import ArchiverError
from bdfr.resource import Resource
@ -47,8 +46,9 @@ class Archiver(RedditConnector):
results = super(Archiver, self).get_user_data()
if self.args.user and self.args.all_comments:
sort = self.determine_sort_function()
logger.debug(f'Retrieving comments of user {self.args.user}')
results.append(sort(self.reddit_instance.redditor(self.args.user).comments, limit=self.args.limit))
for user in self.args.user:
logger.debug(f'Retrieving comments of user {user}')
results.append(sort(self.reddit_instance.redditor(user).comments, limit=self.args.limit))
return results
@staticmethod

View file

@ -35,7 +35,7 @@ class Configuration(Namespace):
self.time: str = 'all'
self.time_format = None
self.upvoted: bool = False
self.user: Optional[str] = None
self.user: list[str] = []
self.verbose: int = 0
self.make_hard_links = False

View file

@ -74,7 +74,7 @@ class RedditConnector(metaclass=ABCMeta):
logger.log(9, 'Create file name formatter')
self.create_reddit_instance()
self.resolve_user_name()
self.args.user = list(filter(None, [self.resolve_user_name(user) for user in self.args.user]))
self.excluded_submission_ids = self.read_excluded_ids()
@ -256,14 +256,16 @@ class RedditConnector(metaclass=ABCMeta):
else:
return []
def resolve_user_name(self):
if self.args.user == 'me':
def resolve_user_name(self, in_name: str) -> str:
if in_name == 'me':
if self.authenticated:
self.args.user = self.reddit_instance.user.me().name
logger.log(9, f'Resolved user to {self.args.user}')
resolved_name = self.reddit_instance.user.me().name
logger.log(9, f'Resolved user to {resolved_name}')
return resolved_name
else:
self.args.user = None
logger.warning('To use "me" as a user, an authenticated Reddit instance must be used')
else:
return in_name
def get_submissions_from_link(self) -> list[list[praw.models.Submission]]:
supplied_submissions = []
@ -289,10 +291,13 @@ class RedditConnector(metaclass=ABCMeta):
def get_multireddits(self) -> list[Iterator]:
if self.args.multireddit:
if len(self.args.user) != 1:
logger.error(f'Only 1 user can be supplied when retrieving from multireddits')
return []
out = []
for multi in self.split_args_input(self.args.multireddit):
try:
multi = self.reddit_instance.multireddit(self.args.user, multi)
multi = self.reddit_instance.multireddit(self.args.user[0], multi)
if not multi.subreddits:
raise errors.BulkDownloaderException
out.append(self.create_filtered_listing_generator(multi))
@ -312,31 +317,31 @@ class RedditConnector(metaclass=ABCMeta):
def get_user_data(self) -> list[Iterator]:
if any([self.args.submitted, self.args.upvoted, self.args.saved]):
if self.args.user:
if not self.args.user:
logger.warning('At least one user must be supplied to download user data')
return []
generators = []
for user in self.args.user:
try:
self.check_user_existence(self.args.user)
self.check_user_existence(user)
except errors.BulkDownloaderException as e:
logger.error(e)
return []
generators = []
continue
if self.args.submitted:
logger.debug(f'Retrieving submitted posts of user {self.args.user}')
generators.append(self.create_filtered_listing_generator(
self.reddit_instance.redditor(self.args.user).submissions,
self.reddit_instance.redditor(user).submissions,
))
if not self.authenticated and any((self.args.upvoted, self.args.saved)):
logger.warning('Accessing user lists requires authentication')
else:
if self.args.upvoted:
logger.debug(f'Retrieving upvoted posts of user {self.args.user}')
generators.append(self.reddit_instance.redditor(self.args.user).upvoted(limit=self.args.limit))
generators.append(self.reddit_instance.redditor(user).upvoted(limit=self.args.limit))
if self.args.saved:
logger.debug(f'Retrieving saved posts of user {self.args.user}')
generators.append(self.reddit_instance.redditor(self.args.user).saved(limit=self.args.limit))
return generators
else:
logger.warning('A user must be supplied to download user data')
return []
generators.append(self.reddit_instance.redditor(user).saved(limit=self.args.limit))
return generators
else:
return []

View file

@ -34,7 +34,7 @@ def downloader_mock(args: Configuration):
return downloader_mock
def assert_all_results_are_submissions(result_limit: int, results: list[Iterator]):
def assert_all_results_are_submissions(result_limit: int, results: list[Iterator]) -> list:
results = [sub for res in results for sub in res]
assert all([isinstance(res, praw.models.Submission) for res in results])
if result_limit is not None:
@ -232,7 +232,7 @@ def test_get_multireddits_public(
downloader_mock.sort_filter = RedditTypes.SortType.HOT
downloader_mock.args.limit = limit
downloader_mock.args.multireddit = test_multireddits
downloader_mock.args.user = test_user
downloader_mock.args.user = [test_user]
downloader_mock.reddit_instance = reddit_instance
downloader_mock.create_filtered_listing_generator.return_value = \
RedditConnector.create_filtered_listing_generator(
@ -257,7 +257,7 @@ def test_get_user_submissions(test_user: str, limit: int, downloader_mock: Magic
downloader_mock.determine_sort_function.return_value = praw.models.Subreddit.hot
downloader_mock.sort_filter = RedditTypes.SortType.HOT
downloader_mock.args.submitted = True
downloader_mock.args.user = test_user
downloader_mock.args.user = [test_user]
downloader_mock.authenticated = False
downloader_mock.reddit_instance = reddit_instance
downloader_mock.create_filtered_listing_generator.return_value = \
@ -284,11 +284,10 @@ def test_get_user_authenticated_lists(
):
downloader_mock.args.__dict__[test_flag] = True
downloader_mock.reddit_instance = authenticated_reddit_instance
downloader_mock.args.user = 'me'
downloader_mock.args.limit = 10
downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
downloader_mock.sort_filter = RedditTypes.SortType.HOT
RedditConnector.resolve_user_name(downloader_mock)
downloader_mock.args.user = [RedditConnector.resolve_user_name(downloader_mock, 'me')]
results = RedditConnector.get_user_data(downloader_mock)
assert_all_results_are_submissions(10, results)

View file

@ -117,6 +117,7 @@ def test_cli_download_multireddit_nonexistent(test_args: list[str], tmp_path: Pa
@pytest.mark.authenticated
@pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests')
@pytest.mark.parametrize('test_args', (
['--user', 'djnish', '--submitted', '--user', 'FriesWithThat', '-L', 10],
['--user', 'me', '--upvoted', '--authenticate', '-L', 10],
['--user', 'me', '--saved', '--authenticate', '-L', 10],
['--user', 'me', '--submitted', '--authenticate', '-L', 10],
@ -231,6 +232,7 @@ def test_cli_archive_subreddit(test_args: list[str], tmp_path: Path):
@pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests')
@pytest.mark.parametrize('test_args', (
['--user', 'me', '--authenticate', '--all-comments', '-L', '10'],
['--user', 'me', '--user', 'djnish', '--authenticate', '--all-comments', '-L', '10'],
))
def test_cli_archive_all_user_comments(test_args: list[str], tmp_path: Path):
runner = CliRunner()
@ -265,12 +267,14 @@ def test_cli_archive_long(test_args: list[str], tmp_path: Path):
['--user', 'sdclhgsolgjeroij', '--upvoted', '-L', 10],
['--subreddit', 'submitters', '-L', 10], # Private subreddit
['--subreddit', 'donaldtrump', '-L', 10], # Banned subreddit
['--user', 'djnish', '--user', 'helen_darten', '-m', 'cuteanimalpics', '-L', 10],
))
def test_cli_download_soft_fail(test_args: list[str], tmp_path: Path):
runner = CliRunner()
test_args = create_basic_args_for_download_runner(test_args, tmp_path)
result = runner.invoke(cli, test_args)
assert result.exit_code == 0
assert 'Downloaded' not in result.output
@pytest.mark.online