1
0
Fork 0
mirror of synced 2024-06-25 09:30:36 +12:00

Add method to sanitise subreddit inputs

This commit is contained in:
Serene-Arc 2021-03-11 12:25:21 +10:00 committed by Ali Parlakci
parent d3c8897f6a
commit f7989ca518
2 changed files with 33 additions and 4 deletions

View file

@ -3,6 +3,7 @@
import configparser
import logging
import re
import socket
from datetime import datetime
from enum import Enum, auto
@ -153,9 +154,18 @@ class RedditDownloader:
main_logger.addHandler(file_handler)
@staticmethod
def _sanitise_subreddit_name(subreddit: str) -> str:
    """Reduce a user-supplied subreddit reference to the bare name.

    Accepted forms (surrounding whitespace is ignored):

    * a bare name, e.g. ``Mindustry``
    * an ``r/``-prefixed name, e.g. ``r/Mindustry`` or ``/r/Mindustry``
    * a Reddit URL with an ``http`` or ``https`` scheme and an optional
      single subdomain, e.g. ``https://www.reddit.com/r/Mindustry/`` or
      ``http://old.reddit.com/r/Mindustry``

    A single trailing slash is dropped in every form.

    Args:
        subreddit: The raw value supplied by the user.

    Returns:
        The name with any recognised URL prefix, ``r/`` prefix and trailing
        slash removed.

    Raises:
        errors.RedditAuthenticationError: If the value cannot be matched,
            for example because it contains an embedded line break.
    """
    # The URL prefix is only stripped when a scheme is present, so a bare
    # name that merely looks like a host (e.g. "reddit.com") is left intact.
    pattern = re.compile(r'^(?:https?://(?:[\w-]+\.)?reddit\.com/)?(?:/?r/)?(.*?)/?$')
    match = pattern.match(subreddit.strip())
    if not match:
        # Same exception type as before; it now names the offending input
        # instead of carrying an empty message.
        raise errors.RedditAuthenticationError(f'Could not parse a subreddit name from {subreddit!r}')
    return match.group(1)
def _get_subreddits(self) -> list[praw.models.ListingGenerator]:
if self.args.subreddit:
subreddits = [self.reddit_instance.subreddit(chosen_subreddit) for chosen_subreddit in self.args.subreddit]
subreddits = [self._sanitise_subreddit_name(subreddit) for subreddit in self.args.subreddit]
subreddits = [self.reddit_instance.subreddit(chosen_subreddit) for chosen_subreddit in subreddits]
if self.args.search:
return [
reddit.search(
@ -197,10 +207,11 @@ class RedditDownloader:
if self.authenticated:
if self.args.user:
sort_function = self._determine_sort_function()
multireddits = [self._sanitise_subreddit_name(multi) for multi in self.args.multireddit]
return [
sort_function(self.reddit_instance.multireddit(
self.args.user,
m_reddit_choice), limit=self.args.limit) for m_reddit_choice in self.args.multireddit]
m_reddit_choice), limit=self.args.limit) for m_reddit_choice in multireddits]
else:
raise errors.BulkDownloaderException('A user must be provided to download a multireddit')
else:

View file

@ -160,13 +160,13 @@ def test_get_subreddit_normal(
downloader_mock: MagicMock,
reddit_instance: praw.Reddit):
downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
downloader_mock._sanitise_subreddit_name = RedditDownloader._sanitise_subreddit_name
downloader_mock.args.limit = limit
downloader_mock.args.subreddit = test_subreddits
downloader_mock.reddit_instance = reddit_instance
downloader_mock.sort_filter = RedditTypes.SortType.HOT
results = RedditDownloader._get_subreddits(downloader_mock)
results = assert_all_results_are_submissions(
(limit * len(test_subreddits)) if limit else None, results)
results = assert_all_results_are_submissions((limit * len(test_subreddits)) if limit else None, results)
assert all([res.subreddit.display_name in test_subreddits for res in results])
@ -184,6 +184,7 @@ def test_get_subreddit_search(
downloader_mock: MagicMock,
reddit_instance: praw.Reddit):
downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
downloader_mock._sanitise_subreddit_name = RedditDownloader._sanitise_subreddit_name
downloader_mock.args.limit = limit
downloader_mock.args.search = search_term
downloader_mock.args.subreddit = test_subreddits
@ -209,6 +210,7 @@ def test_get_multireddits_public(
reddit_instance: praw.Reddit,
downloader_mock: MagicMock):
downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
downloader_mock._sanitise_subreddit_name = RedditDownloader._sanitise_subreddit_name
downloader_mock.sort_filter = RedditTypes.SortType.HOT
downloader_mock.args.limit = limit
downloader_mock.args.multireddit = test_multireddits
@ -389,3 +391,19 @@ def test_download_submission_hash_exists(
output = capsys.readouterr()
assert len(folder_contents) == 0
assert re.search(r'Resource from .*? downloaded elsewhere', output.out)
@pytest.mark.parametrize('test_name, expected', [
    ('Mindustry', 'Mindustry'),
    ('Futurology', 'Futurology'),
    ('r/Mindustry', 'Mindustry'),
    ('TrollXChromosomes', 'TrollXChromosomes'),
    ('r/TrollXChromosomes', 'TrollXChromosomes'),
    ('https://www.reddit.com/r/TrollXChromosomes/', 'TrollXChromosomes'),
    ('https://www.reddit.com/r/TrollXChromosomes', 'TrollXChromosomes'),
    ('https://www.reddit.com/r/Futurology/', 'Futurology'),
    ('https://www.reddit.com/r/Futurology', 'Futurology'),
])
def test_sanitise_subreddit_name(test_name: str, expected: str):
    """Bare names, ``r/`` prefixes and full URLs all reduce to the bare name."""
    assert RedditDownloader._sanitise_subreddit_name(test_name) == expected