Add method to sanitise subreddit inputs
This commit is contained in:
parent
d3c8897f6a
commit
f7989ca518
|
@ -3,6 +3,7 @@
|
|||
|
||||
import configparser
|
||||
import logging
|
||||
import re
|
||||
import socket
|
||||
from datetime import datetime
|
||||
from enum import Enum, auto
|
||||
|
@ -153,9 +154,18 @@ class RedditDownloader:
|
|||
|
||||
main_logger.addHandler(file_handler)
|
||||
|
||||
@staticmethod
def _sanitise_subreddit_name(subreddit: str) -> str:
    """Reduce a user-supplied subreddit reference to the bare name.

    Accepted forms include the plain name (``Mindustry``), the prefixed
    name (``r/Mindustry`` or ``/r/Mindustry``) and a full Reddit URL over
    http or https on any ``reddit.com`` host (``www.``, ``old.``, ...),
    each with or without trailing slashes or surrounding whitespace.

    Args:
        subreddit: The raw value to clean up.

    Returns:
        The input with the site prefix, the ``r/`` marker and any
        surrounding slashes or whitespace removed.

    Raises:
        errors.RedditAuthenticationError: If nothing is left after
            sanitising, or the value contains an embedded newline.
    """
    # The scheme is mandatory so that a bare name which merely looks like a
    # host (for example ``reddit.com``) is left untouched. The lookahead
    # stops the prefix from matching hosts such as ``reddit.community``.
    without_site = re.sub(
        r'^https?://(?:[\w-]+\.)*reddit\.com(?=/|$)',
        '',
        subreddit.strip(),
        count=1,
        flags=re.IGNORECASE,
    )
    # Order matters: the leading slash has to go before the ``r/`` marker
    # can be recognised, and trailing slashes are dropped last so that a
    # lone ``r/`` collapses to an empty (rejected) name.
    name = without_site.lstrip('/').removeprefix('r/').rstrip('/')
    if not name or '\n' in name:
        # NOTE(review): the exception type is kept as-is for callers that
        # already catch it; a more general error class may describe this
        # failure better.
        raise errors.RedditAuthenticationError(f'Could not find a subreddit name in {subreddit!r}')
    return name
|
||||
|
||||
def _get_subreddits(self) -> list[praw.models.ListingGenerator]:
|
||||
if self.args.subreddit:
|
||||
subreddits = [self.reddit_instance.subreddit(chosen_subreddit) for chosen_subreddit in self.args.subreddit]
|
||||
subreddits = [self._sanitise_subreddit_name(subreddit) for subreddit in self.args.subreddit]
|
||||
subreddits = [self.reddit_instance.subreddit(chosen_subreddit) for chosen_subreddit in subreddits]
|
||||
if self.args.search:
|
||||
return [
|
||||
reddit.search(
|
||||
|
@ -197,10 +207,11 @@ class RedditDownloader:
|
|||
if self.authenticated:
|
||||
if self.args.user:
|
||||
sort_function = self._determine_sort_function()
|
||||
multireddits = [self._sanitise_subreddit_name(multi) for multi in self.args.multireddit]
|
||||
return [
|
||||
sort_function(self.reddit_instance.multireddit(
|
||||
self.args.user,
|
||||
m_reddit_choice), limit=self.args.limit) for m_reddit_choice in self.args.multireddit]
|
||||
m_reddit_choice), limit=self.args.limit) for m_reddit_choice in multireddits]
|
||||
else:
|
||||
raise errors.BulkDownloaderException('A user must be provided to download a multireddit')
|
||||
else:
|
||||
|
|
|
@ -160,13 +160,13 @@ def test_get_subreddit_normal(
|
|||
downloader_mock: MagicMock,
|
||||
reddit_instance: praw.Reddit):
|
||||
downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
|
||||
downloader_mock._sanitise_subreddit_name = RedditDownloader._sanitise_subreddit_name
|
||||
downloader_mock.args.limit = limit
|
||||
downloader_mock.args.subreddit = test_subreddits
|
||||
downloader_mock.reddit_instance = reddit_instance
|
||||
downloader_mock.sort_filter = RedditTypes.SortType.HOT
|
||||
results = RedditDownloader._get_subreddits(downloader_mock)
|
||||
results = assert_all_results_are_submissions(
|
||||
(limit * len(test_subreddits)) if limit else None, results)
|
||||
results = assert_all_results_are_submissions((limit * len(test_subreddits)) if limit else None, results)
|
||||
assert all([res.subreddit.display_name in test_subreddits for res in results])
|
||||
|
||||
|
||||
|
@ -184,6 +184,7 @@ def test_get_subreddit_search(
|
|||
downloader_mock: MagicMock,
|
||||
reddit_instance: praw.Reddit):
|
||||
downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
|
||||
downloader_mock._sanitise_subreddit_name = RedditDownloader._sanitise_subreddit_name
|
||||
downloader_mock.args.limit = limit
|
||||
downloader_mock.args.search = search_term
|
||||
downloader_mock.args.subreddit = test_subreddits
|
||||
|
@ -209,6 +210,7 @@ def test_get_multireddits_public(
|
|||
reddit_instance: praw.Reddit,
|
||||
downloader_mock: MagicMock):
|
||||
downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
|
||||
downloader_mock._sanitise_subreddit_name = RedditDownloader._sanitise_subreddit_name
|
||||
downloader_mock.sort_filter = RedditTypes.SortType.HOT
|
||||
downloader_mock.args.limit = limit
|
||||
downloader_mock.args.multireddit = test_multireddits
|
||||
|
@ -389,3 +391,19 @@ def test_download_submission_hash_exists(
|
|||
output = capsys.readouterr()
|
||||
assert len(folder_contents) == 0
|
||||
assert re.search(r'Resource from .*? downloaded elsewhere', output.out)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ('test_name', 'expected'),
    [
        # Bare names pass through unchanged.
        ('Mindustry', 'Mindustry'),
        ('Futurology', 'Futurology'),
        # An "r/" prefix is removed.
        ('r/Mindustry', 'Mindustry'),
        ('TrollXChromosomes', 'TrollXChromosomes'),
        ('r/TrollXChromosomes', 'TrollXChromosomes'),
        # Full URLs, with and without the trailing slash.
        ('https://www.reddit.com/r/TrollXChromosomes/', 'TrollXChromosomes'),
        ('https://www.reddit.com/r/TrollXChromosomes', 'TrollXChromosomes'),
        ('https://www.reddit.com/r/Futurology/', 'Futurology'),
        ('https://www.reddit.com/r/Futurology', 'Futurology'),
    ],
)
def test_sanitise_subreddit_name(test_name: str, expected: str):
    """Check that each accepted input form is reduced to the bare subreddit name."""
    assert RedditDownloader._sanitise_subreddit_name(test_name) == expected
|
||||
|
|
Loading…
Reference in a new issue