Update RedditDownloader tests
This commit is contained in:
parent
28b7deb6d3
commit
6aab009204
|
@ -167,16 +167,24 @@ class RedditDownloader:
|
|||
pattern = re.compile(r'^(?:https://www\.reddit\.com/)?(?:r/)?(.*?)(?:/)?$')
|
||||
match = re.match(pattern, subreddit)
|
||||
if not match:
|
||||
raise errors.RedditAuthenticationError('')
|
||||
raise errors.BulkDownloaderException(f'Could not find subreddit name in string {subreddit}')
|
||||
return match.group(1)
|
||||
|
||||
@staticmethod
|
||||
def _split_args_input(subreddit_entries: list[str]) -> set[str]:
|
||||
all_subreddits = []
|
||||
split_pattern = re.compile(r'[,;]\s?')
|
||||
for entry in subreddit_entries:
|
||||
results = re.split(split_pattern, entry)
|
||||
all_subreddits.extend([RedditDownloader._sanitise_subreddit_name(name) for name in results])
|
||||
return set(all_subreddits)
|
||||
|
||||
def _get_subreddits(self) -> list[praw.models.ListingGenerator]:
|
||||
if self.args.subreddit:
|
||||
out = []
|
||||
sort_function = self._determine_sort_function()
|
||||
for reddit in self.args.subreddit:
|
||||
for reddit in self._split_args_input(self.args.subreddit):
|
||||
try:
|
||||
reddit = self._sanitise_subreddit_name(reddit)
|
||||
reddit = self.reddit_instance.subreddit(reddit)
|
||||
if self.args.search:
|
||||
out.append(
|
||||
|
@ -228,9 +236,8 @@ class RedditDownloader:
|
|||
if self.args.multireddit:
|
||||
out = []
|
||||
sort_function = self._determine_sort_function()
|
||||
for multi in self.args.multireddit:
|
||||
for multi in self._split_args_input(self.args.multireddit):
|
||||
try:
|
||||
multi = self._sanitise_subreddit_name(multi)
|
||||
multi = self.reddit_instance.multireddit(self.args.user, multi)
|
||||
if not multi.subreddits:
|
||||
raise errors.BulkDownloaderException
|
||||
|
|
|
@ -30,6 +30,8 @@ def args() -> Configuration:
|
|||
def downloader_mock(args: argparse.Namespace):
|
||||
mock_downloader = MagicMock()
|
||||
mock_downloader.args = args
|
||||
mock_downloader._sanitise_subreddit_name = RedditDownloader._sanitise_subreddit_name
|
||||
mock_downloader._split_args_input = RedditDownloader._split_args_input
|
||||
return mock_downloader
|
||||
|
||||
|
||||
|
@ -146,6 +148,7 @@ def test_get_submissions_from_link(
|
|||
@pytest.mark.reddit
|
||||
@pytest.mark.parametrize(('test_subreddits', 'limit'), (
|
||||
(('Futurology',), 10),
|
||||
(('Futurology', 'Mindustry, Python'), 10),
|
||||
(('Futurology',), 20),
|
||||
(('Futurology', 'Python'), 10),
|
||||
(('Futurology',), 100),
|
||||
|
@ -157,13 +160,14 @@ def test_get_subreddit_normal(
|
|||
downloader_mock: MagicMock,
|
||||
reddit_instance: praw.Reddit):
|
||||
downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
|
||||
downloader_mock._sanitise_subreddit_name = RedditDownloader._sanitise_subreddit_name
|
||||
downloader_mock.args.limit = limit
|
||||
downloader_mock.args.subreddit = test_subreddits
|
||||
downloader_mock.reddit_instance = reddit_instance
|
||||
downloader_mock.sort_filter = RedditTypes.SortType.HOT
|
||||
results = RedditDownloader._get_subreddits(downloader_mock)
|
||||
results = assert_all_results_are_submissions((limit * len(test_subreddits)) if limit else None, results)
|
||||
test_subreddits = downloader_mock._split_args_input(test_subreddits)
|
||||
results = assert_all_results_are_submissions(
|
||||
(limit * len(test_subreddits)) if limit else None, results)
|
||||
assert all([res.subreddit.display_name in test_subreddits for res in results])
|
||||
|
||||
|
||||
|
@ -181,7 +185,6 @@ def test_get_subreddit_search(
|
|||
downloader_mock: MagicMock,
|
||||
reddit_instance: praw.Reddit):
|
||||
downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
|
||||
downloader_mock._sanitise_subreddit_name = RedditDownloader._sanitise_subreddit_name
|
||||
downloader_mock.args.limit = limit
|
||||
downloader_mock.args.search = search_term
|
||||
downloader_mock.args.subreddit = test_subreddits
|
||||
|
@ -207,7 +210,6 @@ def test_get_multireddits_public(
|
|||
reddit_instance: praw.Reddit,
|
||||
downloader_mock: MagicMock):
|
||||
downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
|
||||
downloader_mock._sanitise_subreddit_name = RedditDownloader._sanitise_subreddit_name
|
||||
downloader_mock.sort_filter = RedditTypes.SortType.HOT
|
||||
downloader_mock.args.limit = limit
|
||||
downloader_mock.args.multireddit = test_multireddits
|
||||
|
@ -237,30 +239,6 @@ def test_get_user_submissions(test_user: str, limit: int, downloader_mock: Magic
|
|||
assert all([res.author.name == test_user for res in results])
|
||||
|
||||
|
||||
@pytest.mark.online
|
||||
@pytest.mark.reddit
|
||||
def test_get_user_no_user(downloader_mock: MagicMock):
|
||||
downloader_mock.args.upvoted = True
|
||||
with pytest.raises(BulkDownloaderException):
|
||||
RedditDownloader._get_user_data(downloader_mock)
|
||||
|
||||
|
||||
@pytest.mark.online
|
||||
@pytest.mark.reddit
|
||||
@pytest.mark.parametrize('test_user', (
|
||||
'rockcanopicjartheme',
|
||||
'exceptionalcatfishracecarbatter',
|
||||
))
|
||||
def test_get_user_nonexistent_user(test_user: str, downloader_mock: MagicMock, reddit_instance: praw.Reddit):
|
||||
downloader_mock.reddit_instance = reddit_instance
|
||||
downloader_mock.args.user = test_user
|
||||
downloader_mock.args.upvoted = True
|
||||
downloader_mock._check_user_existence.return_value = RedditDownloader._check_user_existence(
|
||||
downloader_mock, test_user)
|
||||
with pytest.raises(RedditUserError):
|
||||
RedditDownloader._get_user_data(downloader_mock)
|
||||
|
||||
|
||||
@pytest.mark.online
|
||||
@pytest.mark.reddit
|
||||
@pytest.mark.authenticated
|
||||
|
@ -276,16 +254,6 @@ def test_get_user_upvoted(downloader_mock: MagicMock, authenticated_reddit_insta
|
|||
assert_all_results_are_submissions(10, results)
|
||||
|
||||
|
||||
@pytest.mark.online
|
||||
@pytest.mark.reddit
|
||||
def test_get_user_upvoted_unauthenticated(downloader_mock: MagicMock, reddit_instance: praw.Reddit):
|
||||
downloader_mock.args.user = 'random'
|
||||
downloader_mock.args.upvoted = True
|
||||
downloader_mock.authenticated = False
|
||||
with pytest.raises(RedditAuthenticationError):
|
||||
RedditDownloader._get_user_data(downloader_mock)
|
||||
|
||||
|
||||
@pytest.mark.online
|
||||
@pytest.mark.reddit
|
||||
@pytest.mark.authenticated
|
||||
|
@ -301,16 +269,6 @@ def test_get_user_saved(downloader_mock: MagicMock, authenticated_reddit_instanc
|
|||
assert_all_results_are_submissions(10, results)
|
||||
|
||||
|
||||
@pytest.mark.online
|
||||
@pytest.mark.reddit
|
||||
def test_get_user_saved_unauthenticated(downloader_mock: MagicMock, reddit_instance: praw.Reddit):
|
||||
downloader_mock.args.user = 'random'
|
||||
downloader_mock.args.saved = True
|
||||
downloader_mock.authenticated = False
|
||||
with pytest.raises(RedditAuthenticationError):
|
||||
RedditDownloader._get_user_data(downloader_mock)
|
||||
|
||||
|
||||
@pytest.mark.online
|
||||
@pytest.mark.reddit
|
||||
def test_download_submission(downloader_mock: MagicMock, reddit_instance: praw.Reddit, tmp_path: Path):
|
||||
|
@ -392,3 +350,15 @@ def test_search_existing_files():
|
|||
results = RedditDownloader.scan_existing_files(Path('.'))
|
||||
assert all([isinstance(result, str) for result in results])
|
||||
assert len(results) >= 40
|
||||
|
||||
|
||||
@pytest.mark.parametrize(('test_subreddit_entries', 'expected'), (
|
||||
(['test1', 'test2', 'test3'], {'test1', 'test2', 'test3'}),
|
||||
(['test1,test2', 'test3'], {'test1', 'test2', 'test3'}),
|
||||
(['test1, test2', 'test3'], {'test1', 'test2', 'test3'}),
|
||||
(['test1; test2', 'test3'], {'test1', 'test2', 'test3'}),
|
||||
(['test1, test2', 'test1,test2,test3', 'test4'], {'test1', 'test2', 'test3', 'test4'})
|
||||
))
|
||||
def test_split_subreddit_entries(test_subreddit_entries: list[str], expected: set[str]):
|
||||
results = RedditDownloader._split_args_input(test_subreddit_entries)
|
||||
assert results == expected
|
||||
|
|
Loading…
Reference in a new issue