1
0
Fork 0
mirror of synced 2024-06-25 17:40:17 +12:00

Update RedditDownloader tests

This commit is contained in:
Serene-Arc 2021-03-15 13:22:41 +10:00 committed by Ali Parlakci
parent 28b7deb6d3
commit 6aab009204
2 changed files with 30 additions and 53 deletions

View file

@ -167,16 +167,24 @@ class RedditDownloader:
pattern = re.compile(r'^(?:https://www\.reddit\.com/)?(?:r/)?(.*?)(?:/)?$')
match = re.match(pattern, subreddit)
if not match:
raise errors.RedditAuthenticationError('')
raise errors.BulkDownloaderException(f'Could not find subreddit name in string {subreddit}')
return match.group(1)
@staticmethod
def _split_args_input(subreddit_entries: list[str]) -> set[str]:
all_subreddits = []
split_pattern = re.compile(r'[,;]\s?')
for entry in subreddit_entries:
results = re.split(split_pattern, entry)
all_subreddits.extend([RedditDownloader._sanitise_subreddit_name(name) for name in results])
return set(all_subreddits)
def _get_subreddits(self) -> list[praw.models.ListingGenerator]:
if self.args.subreddit:
out = []
sort_function = self._determine_sort_function()
for reddit in self.args.subreddit:
for reddit in self._split_args_input(self.args.subreddit):
try:
reddit = self._sanitise_subreddit_name(reddit)
reddit = self.reddit_instance.subreddit(reddit)
if self.args.search:
out.append(
@ -228,9 +236,8 @@ class RedditDownloader:
if self.args.multireddit:
out = []
sort_function = self._determine_sort_function()
for multi in self.args.multireddit:
for multi in self._split_args_input(self.args.multireddit):
try:
multi = self._sanitise_subreddit_name(multi)
multi = self.reddit_instance.multireddit(self.args.user, multi)
if not multi.subreddits:
raise errors.BulkDownloaderException

View file

@ -30,6 +30,8 @@ def args() -> Configuration:
def downloader_mock(args: argparse.Namespace):
mock_downloader = MagicMock()
mock_downloader.args = args
mock_downloader._sanitise_subreddit_name = RedditDownloader._sanitise_subreddit_name
mock_downloader._split_args_input = RedditDownloader._split_args_input
return mock_downloader
@ -146,6 +148,7 @@ def test_get_submissions_from_link(
@pytest.mark.reddit
@pytest.mark.parametrize(('test_subreddits', 'limit'), (
(('Futurology',), 10),
(('Futurology', 'Mindustry, Python'), 10),
(('Futurology',), 20),
(('Futurology', 'Python'), 10),
(('Futurology',), 100),
@ -157,13 +160,14 @@ def test_get_subreddit_normal(
downloader_mock: MagicMock,
reddit_instance: praw.Reddit):
downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
downloader_mock._sanitise_subreddit_name = RedditDownloader._sanitise_subreddit_name
downloader_mock.args.limit = limit
downloader_mock.args.subreddit = test_subreddits
downloader_mock.reddit_instance = reddit_instance
downloader_mock.sort_filter = RedditTypes.SortType.HOT
results = RedditDownloader._get_subreddits(downloader_mock)
results = assert_all_results_are_submissions((limit * len(test_subreddits)) if limit else None, results)
test_subreddits = downloader_mock._split_args_input(test_subreddits)
results = assert_all_results_are_submissions(
(limit * len(test_subreddits)) if limit else None, results)
assert all([res.subreddit.display_name in test_subreddits for res in results])
@ -181,7 +185,6 @@ def test_get_subreddit_search(
downloader_mock: MagicMock,
reddit_instance: praw.Reddit):
downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
downloader_mock._sanitise_subreddit_name = RedditDownloader._sanitise_subreddit_name
downloader_mock.args.limit = limit
downloader_mock.args.search = search_term
downloader_mock.args.subreddit = test_subreddits
@ -207,7 +210,6 @@ def test_get_multireddits_public(
reddit_instance: praw.Reddit,
downloader_mock: MagicMock):
downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
downloader_mock._sanitise_subreddit_name = RedditDownloader._sanitise_subreddit_name
downloader_mock.sort_filter = RedditTypes.SortType.HOT
downloader_mock.args.limit = limit
downloader_mock.args.multireddit = test_multireddits
@ -237,30 +239,6 @@ def test_get_user_submissions(test_user: str, limit: int, downloader_mock: Magic
assert all([res.author.name == test_user for res in results])
@pytest.mark.online
@pytest.mark.reddit
def test_get_user_no_user(downloader_mock: MagicMock):
downloader_mock.args.upvoted = True
with pytest.raises(BulkDownloaderException):
RedditDownloader._get_user_data(downloader_mock)
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize('test_user', (
'rockcanopicjartheme',
'exceptionalcatfishracecarbatter',
))
def test_get_user_nonexistent_user(test_user: str, downloader_mock: MagicMock, reddit_instance: praw.Reddit):
downloader_mock.reddit_instance = reddit_instance
downloader_mock.args.user = test_user
downloader_mock.args.upvoted = True
downloader_mock._check_user_existence.return_value = RedditDownloader._check_user_existence(
downloader_mock, test_user)
with pytest.raises(RedditUserError):
RedditDownloader._get_user_data(downloader_mock)
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.authenticated
@ -276,16 +254,6 @@ def test_get_user_upvoted(downloader_mock: MagicMock, authenticated_reddit_insta
assert_all_results_are_submissions(10, results)
@pytest.mark.online
@pytest.mark.reddit
def test_get_user_upvoted_unauthenticated(downloader_mock: MagicMock, reddit_instance: praw.Reddit):
downloader_mock.args.user = 'random'
downloader_mock.args.upvoted = True
downloader_mock.authenticated = False
with pytest.raises(RedditAuthenticationError):
RedditDownloader._get_user_data(downloader_mock)
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.authenticated
@ -301,16 +269,6 @@ def test_get_user_saved(downloader_mock: MagicMock, authenticated_reddit_instanc
assert_all_results_are_submissions(10, results)
@pytest.mark.online
@pytest.mark.reddit
def test_get_user_saved_unauthenticated(downloader_mock: MagicMock, reddit_instance: praw.Reddit):
downloader_mock.args.user = 'random'
downloader_mock.args.saved = True
downloader_mock.authenticated = False
with pytest.raises(RedditAuthenticationError):
RedditDownloader._get_user_data(downloader_mock)
@pytest.mark.online
@pytest.mark.reddit
def test_download_submission(downloader_mock: MagicMock, reddit_instance: praw.Reddit, tmp_path: Path):
@ -392,3 +350,15 @@ def test_search_existing_files():
results = RedditDownloader.scan_existing_files(Path('.'))
assert all([isinstance(result, str) for result in results])
assert len(results) >= 40
@pytest.mark.parametrize(('test_subreddit_entries', 'expected'), (
(['test1', 'test2', 'test3'], {'test1', 'test2', 'test3'}),
(['test1,test2', 'test3'], {'test1', 'test2', 'test3'}),
(['test1, test2', 'test3'], {'test1', 'test2', 'test3'}),
(['test1; test2', 'test3'], {'test1', 'test2', 'test3'}),
(['test1, test2', 'test1,test2,test3', 'test4'], {'test1', 'test2', 'test3', 'test4'})
))
def test_split_subreddit_entries(test_subreddit_entries: list[str], expected: set[str]):
results = RedditDownloader._split_args_input(test_subreddit_entries)
assert results == expected