From 6aab009204993df351bf57d120dc2bbc53468f9e Mon Sep 17 00:00:00 2001 From: Serene-Arc Date: Mon, 15 Mar 2021 13:22:41 +1000 Subject: [PATCH] Update RedditDownloader tests --- bulkredditdownloader/downloader.py | 17 +++-- bulkredditdownloader/tests/test_downloader.py | 66 +++++-------------- 2 files changed, 30 insertions(+), 53 deletions(-) diff --git a/bulkredditdownloader/downloader.py b/bulkredditdownloader/downloader.py index 3cf1f78..67c481a 100644 --- a/bulkredditdownloader/downloader.py +++ b/bulkredditdownloader/downloader.py @@ -167,16 +167,24 @@ class RedditDownloader: pattern = re.compile(r'^(?:https://www\.reddit\.com/)?(?:r/)?(.*?)(?:/)?$') match = re.match(pattern, subreddit) if not match: - raise errors.RedditAuthenticationError('') + raise errors.BulkDownloaderException(f'Could not find subreddit name in string {subreddit}') return match.group(1) + @staticmethod + def _split_args_input(subreddit_entries: list[str]) -> set[str]: + all_subreddits = [] + split_pattern = re.compile(r'[,;]\s?') + for entry in subreddit_entries: + results = re.split(split_pattern, entry) + all_subreddits.extend([RedditDownloader._sanitise_subreddit_name(name) for name in results]) + return set(all_subreddits) + def _get_subreddits(self) -> list[praw.models.ListingGenerator]: if self.args.subreddit: out = [] sort_function = self._determine_sort_function() - for reddit in self.args.subreddit: + for reddit in self._split_args_input(self.args.subreddit): try: - reddit = self._sanitise_subreddit_name(reddit) reddit = self.reddit_instance.subreddit(reddit) if self.args.search: out.append( @@ -228,9 +236,8 @@ class RedditDownloader: if self.args.multireddit: out = [] sort_function = self._determine_sort_function() - for multi in self.args.multireddit: + for multi in self._split_args_input(self.args.multireddit): try: - multi = self._sanitise_subreddit_name(multi) multi = self.reddit_instance.multireddit(self.args.user, multi) if not multi.subreddits: raise errors.BulkDownloaderException diff --git a/bulkredditdownloader/tests/test_downloader.py b/bulkredditdownloader/tests/test_downloader.py index aae921e..d4a2e03 100644 --- a/bulkredditdownloader/tests/test_downloader.py +++ b/bulkredditdownloader/tests/test_downloader.py @@ -30,6 +30,8 @@ def args() -> Configuration: def downloader_mock(args: argparse.Namespace): mock_downloader = MagicMock() mock_downloader.args = args + mock_downloader._sanitise_subreddit_name = RedditDownloader._sanitise_subreddit_name + mock_downloader._split_args_input = RedditDownloader._split_args_input return mock_downloader @@ -146,6 +148,7 @@ def test_get_submissions_from_link( @pytest.mark.reddit @pytest.mark.parametrize(('test_subreddits', 'limit'), ( (('Futurology',), 10), + (('Futurology', 'Mindustry, Python'), 10), (('Futurology',), 20), (('Futurology', 'Python'), 10), (('Futurology',), 100), @@ -157,13 +160,14 @@ def test_get_subreddit_normal( downloader_mock: MagicMock, reddit_instance: praw.Reddit): downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot - downloader_mock._sanitise_subreddit_name = RedditDownloader._sanitise_subreddit_name downloader_mock.args.limit = limit downloader_mock.args.subreddit = test_subreddits downloader_mock.reddit_instance = reddit_instance downloader_mock.sort_filter = RedditTypes.SortType.HOT results = RedditDownloader._get_subreddits(downloader_mock) - results = assert_all_results_are_submissions((limit * len(test_subreddits)) if limit else None, results) + test_subreddits = downloader_mock._split_args_input(test_subreddits) + results = assert_all_results_are_submissions( + (limit * len(test_subreddits)) if limit else None, results) assert all([res.subreddit.display_name in test_subreddits for res in results]) @@ -181,7 +185,6 @@ def test_get_subreddit_search( downloader_mock: MagicMock, reddit_instance: praw.Reddit): downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot - downloader_mock._sanitise_subreddit_name = RedditDownloader._sanitise_subreddit_name downloader_mock.args.limit = limit downloader_mock.args.search = search_term downloader_mock.args.subreddit = test_subreddits @@ -207,7 +210,6 @@ def test_get_multireddits_public( reddit_instance: praw.Reddit, downloader_mock: MagicMock): downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot - downloader_mock._sanitise_subreddit_name = RedditDownloader._sanitise_subreddit_name downloader_mock.sort_filter = RedditTypes.SortType.HOT downloader_mock.args.limit = limit downloader_mock.args.multireddit = test_multireddits @@ -237,30 +239,6 @@ def test_get_user_submissions(test_user: str, limit: int, downloader_mock: Magic assert all([res.author.name == test_user for res in results]) -@pytest.mark.online -@pytest.mark.reddit -def test_get_user_no_user(downloader_mock: MagicMock): - downloader_mock.args.upvoted = True - with pytest.raises(BulkDownloaderException): - RedditDownloader._get_user_data(downloader_mock) - - -@pytest.mark.online -@pytest.mark.reddit -@pytest.mark.parametrize('test_user', ( - 'rockcanopicjartheme', - 'exceptionalcatfishracecarbatter', -)) -def test_get_user_nonexistent_user(test_user: str, downloader_mock: MagicMock, reddit_instance: praw.Reddit): - downloader_mock.reddit_instance = reddit_instance - downloader_mock.args.user = test_user - downloader_mock.args.upvoted = True - downloader_mock._check_user_existence.return_value = RedditDownloader._check_user_existence( - downloader_mock, test_user) - with pytest.raises(RedditUserError): - RedditDownloader._get_user_data(downloader_mock) - - @pytest.mark.online @pytest.mark.reddit @pytest.mark.authenticated @@ -276,16 +254,6 @@ def test_get_user_upvoted(downloader_mock: MagicMock, authenticated_reddit_insta assert_all_results_are_submissions(10, results) -@pytest.mark.online -@pytest.mark.reddit -def test_get_user_upvoted_unauthenticated(downloader_mock: MagicMock, reddit_instance: praw.Reddit): - downloader_mock.args.user = 'random' - downloader_mock.args.upvoted = True - downloader_mock.authenticated = False - with pytest.raises(RedditAuthenticationError): - RedditDownloader._get_user_data(downloader_mock) - - @pytest.mark.online @pytest.mark.reddit @pytest.mark.authenticated @@ -301,16 +269,6 @@ def test_get_user_saved(downloader_mock: MagicMock, authenticated_reddit_instanc assert_all_results_are_submissions(10, results) -@pytest.mark.online -@pytest.mark.reddit -def test_get_user_saved_unauthenticated(downloader_mock: MagicMock, reddit_instance: praw.Reddit): - downloader_mock.args.user = 'random' - downloader_mock.args.saved = True - downloader_mock.authenticated = False - with pytest.raises(RedditAuthenticationError): - RedditDownloader._get_user_data(downloader_mock) - - @pytest.mark.online @pytest.mark.reddit def test_download_submission(downloader_mock: MagicMock, reddit_instance: praw.Reddit, tmp_path: Path): @@ -392,3 +350,15 @@ def test_search_existing_files(): results = RedditDownloader.scan_existing_files(Path('.')) assert all([isinstance(result, str) for result in results]) assert len(results) >= 40 + + +@pytest.mark.parametrize(('test_subreddit_entries', 'expected'), ( + (['test1', 'test2', 'test3'], {'test1', 'test2', 'test3'}), + (['test1,test2', 'test3'], {'test1', 'test2', 'test3'}), + (['test1, test2', 'test3'], {'test1', 'test2', 'test3'}), + (['test1; test2', 'test3'], {'test1', 'test2', 'test3'}), + (['test1, test2', 'test1,test2,test3', 'test4'], {'test1', 'test2', 'test3', 'test4'}) +)) +def test_split_subreddit_entries(test_subreddit_entries: list[str], expected: set[str]): + results = RedditDownloader._split_args_input(test_subreddit_entries) + assert results == expected