diff --git a/bulkredditdownloader/downloader.py b/bulkredditdownloader/downloader.py
index 154865b..b60ed2b 100644
--- a/bulkredditdownloader/downloader.py
+++ b/bulkredditdownloader/downloader.py
@@ -3,6 +3,7 @@
 
 import configparser
 import logging
+import re
 import socket
 from datetime import datetime
 from enum import Enum, auto
@@ -153,9 +154,26 @@ class RedditDownloader:
 
         main_logger.addHandler(file_handler)
 
+    @staticmethod
+    def _sanitise_subreddit_name(subreddit: str) -> str:
+        """Reduce a subreddit reference to its bare name.
+
+        Accepts a bare name ('Mindustry'), an 'r/'-prefixed name, or a full
+        'https://www.reddit.com/r/<name>' URL, each with or without a trailing slash.
+
+        Raises:
+            errors.BulkDownloaderException: if no non-empty name can be extracted.
+        """
+        pattern = re.compile(r'^(?:https://www\.reddit\.com/)?(?:r/)?(.*?)(?:/)?$')
+        match = pattern.match(subreddit)
+        if not match or not match.group(1):
+            raise errors.BulkDownloaderException(f'Could not find subreddit name in string {subreddit}')
+        return match.group(1)
+
     def _get_subreddits(self) -> list[praw.models.ListingGenerator]:
         if self.args.subreddit:
-            subreddits = [self.reddit_instance.subreddit(chosen_subreddit) for chosen_subreddit in self.args.subreddit]
+            subreddits = [self._sanitise_subreddit_name(subreddit) for subreddit in self.args.subreddit]
+            subreddits = [self.reddit_instance.subreddit(chosen_subreddit) for chosen_subreddit in subreddits]
             if self.args.search:
                 return [
                     reddit.search(
@@ -197,10 +215,11 @@ class RedditDownloader:
         if self.authenticated:
             if self.args.user:
                 sort_function = self._determine_sort_function()
+                multireddits = [self._sanitise_subreddit_name(multi) for multi in self.args.multireddit]
                 return [
                     sort_function(self.reddit_instance.multireddit(
                         self.args.user,
-                        m_reddit_choice), limit=self.args.limit) for m_reddit_choice in self.args.multireddit]
+                        m_reddit_choice), limit=self.args.limit) for m_reddit_choice in multireddits]
             else:
                 raise errors.BulkDownloaderException('A user must be provided to download a multireddit')
         else:
diff --git a/bulkredditdownloader/tests/test_downloader.py b/bulkredditdownloader/tests/test_downloader.py
index caec1e7..2c4208f 100644
--- a/bulkredditdownloader/tests/test_downloader.py
+++ b/bulkredditdownloader/tests/test_downloader.py
@@ -160,13 +160,13 @@ def test_get_subreddit_normal(
         downloader_mock: MagicMock,
         reddit_instance: praw.Reddit):
     downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
+    downloader_mock._sanitise_subreddit_name = RedditDownloader._sanitise_subreddit_name
     downloader_mock.args.limit = limit
     downloader_mock.args.subreddit = test_subreddits
     downloader_mock.reddit_instance = reddit_instance
     downloader_mock.sort_filter = RedditTypes.SortType.HOT
     results = RedditDownloader._get_subreddits(downloader_mock)
-    results = assert_all_results_are_submissions(
-        (limit * len(test_subreddits)) if limit else None, results)
+    results = assert_all_results_are_submissions((limit * len(test_subreddits)) if limit else None, results)
     assert all([res.subreddit.display_name in test_subreddits for res in results])
 
 
@@ -184,6 +184,7 @@ def test_get_subreddit_search(
         downloader_mock: MagicMock,
         reddit_instance: praw.Reddit):
     downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
+    downloader_mock._sanitise_subreddit_name = RedditDownloader._sanitise_subreddit_name
     downloader_mock.args.limit = limit
     downloader_mock.args.search = search_term
     downloader_mock.args.subreddit = test_subreddits
@@ -209,6 +210,7 @@ def test_get_multireddits_public(
         reddit_instance: praw.Reddit,
         downloader_mock: MagicMock):
     downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
+    downloader_mock._sanitise_subreddit_name = RedditDownloader._sanitise_subreddit_name
     downloader_mock.sort_filter = RedditTypes.SortType.HOT
     downloader_mock.args.limit = limit
     downloader_mock.args.multireddit = test_multireddits
@@ -389,3 +391,19 @@ def test_download_submission_hash_exists(
     output = capsys.readouterr()
     assert len(folder_contents) == 0
     assert re.search(r'Resource from .*? downloaded elsewhere', output.out)
+
+
+@pytest.mark.parametrize(('test_name', 'expected'), (
+    ('Mindustry', 'Mindustry'),
+    ('Futurology', 'Futurology'),
+    ('r/Mindustry', 'Mindustry'),
+    ('TrollXChromosomes', 'TrollXChromosomes'),
+    ('r/TrollXChromosomes', 'TrollXChromosomes'),
+    ('https://www.reddit.com/r/TrollXChromosomes/', 'TrollXChromosomes'),
+    ('https://www.reddit.com/r/TrollXChromosomes', 'TrollXChromosomes'),
+    ('https://www.reddit.com/r/Futurology/', 'Futurology'),
+    ('https://www.reddit.com/r/Futurology', 'Futurology'),
+))
+def test_sanitise_subreddit_name(test_name: str, expected: str):
+    result = RedditDownloader._sanitise_subreddit_name(test_name)
+    assert result == expected