Update RedditDownloader tests
This commit is contained in:
parent
28b7deb6d3
commit
6aab009204
2 changed files with 30 additions and 53 deletions
|
@ -167,16 +167,24 @@ class RedditDownloader:
|
||||||
pattern = re.compile(r'^(?:https://www\.reddit\.com/)?(?:r/)?(.*?)(?:/)?$')
|
pattern = re.compile(r'^(?:https://www\.reddit\.com/)?(?:r/)?(.*?)(?:/)?$')
|
||||||
match = re.match(pattern, subreddit)
|
match = re.match(pattern, subreddit)
|
||||||
if not match:
|
if not match:
|
||||||
raise errors.RedditAuthenticationError('')
|
raise errors.BulkDownloaderException(f'Could not find subreddit name in string {subreddit}')
|
||||||
return match.group(1)
|
return match.group(1)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _split_args_input(subreddit_entries: list[str]) -> set[str]:
|
||||||
|
all_subreddits = []
|
||||||
|
split_pattern = re.compile(r'[,;]\s?')
|
||||||
|
for entry in subreddit_entries:
|
||||||
|
results = re.split(split_pattern, entry)
|
||||||
|
all_subreddits.extend([RedditDownloader._sanitise_subreddit_name(name) for name in results])
|
||||||
|
return set(all_subreddits)
|
||||||
|
|
||||||
def _get_subreddits(self) -> list[praw.models.ListingGenerator]:
|
def _get_subreddits(self) -> list[praw.models.ListingGenerator]:
|
||||||
if self.args.subreddit:
|
if self.args.subreddit:
|
||||||
out = []
|
out = []
|
||||||
sort_function = self._determine_sort_function()
|
sort_function = self._determine_sort_function()
|
||||||
for reddit in self.args.subreddit:
|
for reddit in self._split_args_input(self.args.subreddit):
|
||||||
try:
|
try:
|
||||||
reddit = self._sanitise_subreddit_name(reddit)
|
|
||||||
reddit = self.reddit_instance.subreddit(reddit)
|
reddit = self.reddit_instance.subreddit(reddit)
|
||||||
if self.args.search:
|
if self.args.search:
|
||||||
out.append(
|
out.append(
|
||||||
|
@ -228,9 +236,8 @@ class RedditDownloader:
|
||||||
if self.args.multireddit:
|
if self.args.multireddit:
|
||||||
out = []
|
out = []
|
||||||
sort_function = self._determine_sort_function()
|
sort_function = self._determine_sort_function()
|
||||||
for multi in self.args.multireddit:
|
for multi in self._split_args_input(self.args.multireddit):
|
||||||
try:
|
try:
|
||||||
multi = self._sanitise_subreddit_name(multi)
|
|
||||||
multi = self.reddit_instance.multireddit(self.args.user, multi)
|
multi = self.reddit_instance.multireddit(self.args.user, multi)
|
||||||
if not multi.subreddits:
|
if not multi.subreddits:
|
||||||
raise errors.BulkDownloaderException
|
raise errors.BulkDownloaderException
|
||||||
|
|
|
@ -30,6 +30,8 @@ def args() -> Configuration:
|
||||||
def downloader_mock(args: argparse.Namespace):
|
def downloader_mock(args: argparse.Namespace):
|
||||||
mock_downloader = MagicMock()
|
mock_downloader = MagicMock()
|
||||||
mock_downloader.args = args
|
mock_downloader.args = args
|
||||||
|
mock_downloader._sanitise_subreddit_name = RedditDownloader._sanitise_subreddit_name
|
||||||
|
mock_downloader._split_args_input = RedditDownloader._split_args_input
|
||||||
return mock_downloader
|
return mock_downloader
|
||||||
|
|
||||||
|
|
||||||
|
@ -146,6 +148,7 @@ def test_get_submissions_from_link(
|
||||||
@pytest.mark.reddit
|
@pytest.mark.reddit
|
||||||
@pytest.mark.parametrize(('test_subreddits', 'limit'), (
|
@pytest.mark.parametrize(('test_subreddits', 'limit'), (
|
||||||
(('Futurology',), 10),
|
(('Futurology',), 10),
|
||||||
|
(('Futurology', 'Mindustry, Python'), 10),
|
||||||
(('Futurology',), 20),
|
(('Futurology',), 20),
|
||||||
(('Futurology', 'Python'), 10),
|
(('Futurology', 'Python'), 10),
|
||||||
(('Futurology',), 100),
|
(('Futurology',), 100),
|
||||||
|
@ -157,13 +160,14 @@ def test_get_subreddit_normal(
|
||||||
downloader_mock: MagicMock,
|
downloader_mock: MagicMock,
|
||||||
reddit_instance: praw.Reddit):
|
reddit_instance: praw.Reddit):
|
||||||
downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
|
downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
|
||||||
downloader_mock._sanitise_subreddit_name = RedditDownloader._sanitise_subreddit_name
|
|
||||||
downloader_mock.args.limit = limit
|
downloader_mock.args.limit = limit
|
||||||
downloader_mock.args.subreddit = test_subreddits
|
downloader_mock.args.subreddit = test_subreddits
|
||||||
downloader_mock.reddit_instance = reddit_instance
|
downloader_mock.reddit_instance = reddit_instance
|
||||||
downloader_mock.sort_filter = RedditTypes.SortType.HOT
|
downloader_mock.sort_filter = RedditTypes.SortType.HOT
|
||||||
results = RedditDownloader._get_subreddits(downloader_mock)
|
results = RedditDownloader._get_subreddits(downloader_mock)
|
||||||
results = assert_all_results_are_submissions((limit * len(test_subreddits)) if limit else None, results)
|
test_subreddits = downloader_mock._split_args_input(test_subreddits)
|
||||||
|
results = assert_all_results_are_submissions(
|
||||||
|
(limit * len(test_subreddits)) if limit else None, results)
|
||||||
assert all([res.subreddit.display_name in test_subreddits for res in results])
|
assert all([res.subreddit.display_name in test_subreddits for res in results])
|
||||||
|
|
||||||
|
|
||||||
|
@ -181,7 +185,6 @@ def test_get_subreddit_search(
|
||||||
downloader_mock: MagicMock,
|
downloader_mock: MagicMock,
|
||||||
reddit_instance: praw.Reddit):
|
reddit_instance: praw.Reddit):
|
||||||
downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
|
downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
|
||||||
downloader_mock._sanitise_subreddit_name = RedditDownloader._sanitise_subreddit_name
|
|
||||||
downloader_mock.args.limit = limit
|
downloader_mock.args.limit = limit
|
||||||
downloader_mock.args.search = search_term
|
downloader_mock.args.search = search_term
|
||||||
downloader_mock.args.subreddit = test_subreddits
|
downloader_mock.args.subreddit = test_subreddits
|
||||||
|
@ -207,7 +210,6 @@ def test_get_multireddits_public(
|
||||||
reddit_instance: praw.Reddit,
|
reddit_instance: praw.Reddit,
|
||||||
downloader_mock: MagicMock):
|
downloader_mock: MagicMock):
|
||||||
downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
|
downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
|
||||||
downloader_mock._sanitise_subreddit_name = RedditDownloader._sanitise_subreddit_name
|
|
||||||
downloader_mock.sort_filter = RedditTypes.SortType.HOT
|
downloader_mock.sort_filter = RedditTypes.SortType.HOT
|
||||||
downloader_mock.args.limit = limit
|
downloader_mock.args.limit = limit
|
||||||
downloader_mock.args.multireddit = test_multireddits
|
downloader_mock.args.multireddit = test_multireddits
|
||||||
|
@ -237,30 +239,6 @@ def test_get_user_submissions(test_user: str, limit: int, downloader_mock: Magic
|
||||||
assert all([res.author.name == test_user for res in results])
|
assert all([res.author.name == test_user for res in results])
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.online
|
|
||||||
@pytest.mark.reddit
|
|
||||||
def test_get_user_no_user(downloader_mock: MagicMock):
|
|
||||||
downloader_mock.args.upvoted = True
|
|
||||||
with pytest.raises(BulkDownloaderException):
|
|
||||||
RedditDownloader._get_user_data(downloader_mock)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.online
|
|
||||||
@pytest.mark.reddit
|
|
||||||
@pytest.mark.parametrize('test_user', (
|
|
||||||
'rockcanopicjartheme',
|
|
||||||
'exceptionalcatfishracecarbatter',
|
|
||||||
))
|
|
||||||
def test_get_user_nonexistent_user(test_user: str, downloader_mock: MagicMock, reddit_instance: praw.Reddit):
|
|
||||||
downloader_mock.reddit_instance = reddit_instance
|
|
||||||
downloader_mock.args.user = test_user
|
|
||||||
downloader_mock.args.upvoted = True
|
|
||||||
downloader_mock._check_user_existence.return_value = RedditDownloader._check_user_existence(
|
|
||||||
downloader_mock, test_user)
|
|
||||||
with pytest.raises(RedditUserError):
|
|
||||||
RedditDownloader._get_user_data(downloader_mock)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.online
|
@pytest.mark.online
|
||||||
@pytest.mark.reddit
|
@pytest.mark.reddit
|
||||||
@pytest.mark.authenticated
|
@pytest.mark.authenticated
|
||||||
|
@ -276,16 +254,6 @@ def test_get_user_upvoted(downloader_mock: MagicMock, authenticated_reddit_insta
|
||||||
assert_all_results_are_submissions(10, results)
|
assert_all_results_are_submissions(10, results)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.online
|
|
||||||
@pytest.mark.reddit
|
|
||||||
def test_get_user_upvoted_unauthenticated(downloader_mock: MagicMock, reddit_instance: praw.Reddit):
|
|
||||||
downloader_mock.args.user = 'random'
|
|
||||||
downloader_mock.args.upvoted = True
|
|
||||||
downloader_mock.authenticated = False
|
|
||||||
with pytest.raises(RedditAuthenticationError):
|
|
||||||
RedditDownloader._get_user_data(downloader_mock)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.online
|
@pytest.mark.online
|
||||||
@pytest.mark.reddit
|
@pytest.mark.reddit
|
||||||
@pytest.mark.authenticated
|
@pytest.mark.authenticated
|
||||||
|
@ -301,16 +269,6 @@ def test_get_user_saved(downloader_mock: MagicMock, authenticated_reddit_instanc
|
||||||
assert_all_results_are_submissions(10, results)
|
assert_all_results_are_submissions(10, results)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.online
|
|
||||||
@pytest.mark.reddit
|
|
||||||
def test_get_user_saved_unauthenticated(downloader_mock: MagicMock, reddit_instance: praw.Reddit):
|
|
||||||
downloader_mock.args.user = 'random'
|
|
||||||
downloader_mock.args.saved = True
|
|
||||||
downloader_mock.authenticated = False
|
|
||||||
with pytest.raises(RedditAuthenticationError):
|
|
||||||
RedditDownloader._get_user_data(downloader_mock)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.online
|
@pytest.mark.online
|
||||||
@pytest.mark.reddit
|
@pytest.mark.reddit
|
||||||
def test_download_submission(downloader_mock: MagicMock, reddit_instance: praw.Reddit, tmp_path: Path):
|
def test_download_submission(downloader_mock: MagicMock, reddit_instance: praw.Reddit, tmp_path: Path):
|
||||||
|
@ -392,3 +350,15 @@ def test_search_existing_files():
|
||||||
results = RedditDownloader.scan_existing_files(Path('.'))
|
results = RedditDownloader.scan_existing_files(Path('.'))
|
||||||
assert all([isinstance(result, str) for result in results])
|
assert all([isinstance(result, str) for result in results])
|
||||||
assert len(results) >= 40
|
assert len(results) >= 40
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(('test_subreddit_entries', 'expected'), (
|
||||||
|
(['test1', 'test2', 'test3'], {'test1', 'test2', 'test3'}),
|
||||||
|
(['test1,test2', 'test3'], {'test1', 'test2', 'test3'}),
|
||||||
|
(['test1, test2', 'test3'], {'test1', 'test2', 'test3'}),
|
||||||
|
(['test1; test2', 'test3'], {'test1', 'test2', 'test3'}),
|
||||||
|
(['test1, test2', 'test1,test2,test3', 'test4'], {'test1', 'test2', 'test3', 'test4'})
|
||||||
|
))
|
||||||
|
def test_split_subreddit_entries(test_subreddit_entries: list[str], expected: set[str]):
|
||||||
|
results = RedditDownloader._split_args_input(test_subreddit_entries)
|
||||||
|
assert results == expected
|
||||||
|
|
Loading…
Reference in a new issue