1
0
Fork 0
mirror of synced 2024-06-28 19:10:41 +12:00

Add method to sanitise subreddit inputs

This commit is contained in:
Serene-Arc 2021-03-11 12:25:21 +10:00 committed by Ali Parlakci
parent d3c8897f6a
commit f7989ca518
2 changed files with 33 additions and 4 deletions

View file

@ -3,6 +3,7 @@
import configparser import configparser
import logging import logging
import re
import socket import socket
from datetime import datetime from datetime import datetime
from enum import Enum, auto from enum import Enum, auto
@ -153,9 +154,18 @@ class RedditDownloader:
main_logger.addHandler(file_handler) main_logger.addHandler(file_handler)
@staticmethod
def _sanitise_subreddit_name(subreddit: str) -> str:
    """Reduce a user-supplied subreddit reference to its bare name.

    Accepted forms (matched case-insensitively, surrounding whitespace ignored):

    * a bare name: ``Python``
    * a prefixed name: ``r/Python`` or ``/r/Python``
    * a full URL with ``http`` or ``https`` and an optional subdomain:
      ``https://www.reddit.com/r/Python/``, ``http://old.reddit.com/r/Python``
    * a scheme-less URL that leads into ``r/``: ``reddit.com/r/Python``

    A single trailing slash is dropped. The case of the name itself is preserved.

    Args:
        subreddit: The raw value as given by the user.

    Returns:
        The text remaining after the recognised prefixes and one trailing
        slash have been removed.

    Raises:
        errors.RedditAuthenticationError: If the value cannot be matched at
            all (in practice, when it contains an embedded newline). The
            exception type is kept unchanged for existing callers.
    """
    pattern = re.compile(
        r'''
        ^
        (?:
            https?://(?:[a-z0-9-]+\.)?reddit\.com/    # full URL, optional subdomain
          | (?:[a-z0-9-]+\.)?reddit\.com/(?=r/)       # scheme-less URL, only when r/ follows
        )?
        /?              # stray leading slash, as in /r/name
        (?:r/)?
        (.*?)           # the bare name
        /?
        $
        ''',
        re.IGNORECASE | re.VERBOSE,
    )
    # The scheme-less alternative insists on a following "r/" so that a bare
    # "reddit.com" (or "r/reddit.com") is still treated as a name, not a host.
    match = pattern.match(subreddit.strip())
    if not match:
        raise errors.RedditAuthenticationError(f'Could not parse a subreddit name from {subreddit!r}')
    return match.group(1)
def _get_subreddits(self) -> list[praw.models.ListingGenerator]: def _get_subreddits(self) -> list[praw.models.ListingGenerator]:
if self.args.subreddit: if self.args.subreddit:
subreddits = [self.reddit_instance.subreddit(chosen_subreddit) for chosen_subreddit in self.args.subreddit] subreddits = [self._sanitise_subreddit_name(subreddit) for subreddit in self.args.subreddit]
subreddits = [self.reddit_instance.subreddit(chosen_subreddit) for chosen_subreddit in subreddits]
if self.args.search: if self.args.search:
return [ return [
reddit.search( reddit.search(
@ -197,10 +207,11 @@ class RedditDownloader:
if self.authenticated: if self.authenticated:
if self.args.user: if self.args.user:
sort_function = self._determine_sort_function() sort_function = self._determine_sort_function()
multireddits = [self._sanitise_subreddit_name(multi) for multi in self.args.multireddit]
return [ return [
sort_function(self.reddit_instance.multireddit( sort_function(self.reddit_instance.multireddit(
self.args.user, self.args.user,
m_reddit_choice), limit=self.args.limit) for m_reddit_choice in self.args.multireddit] m_reddit_choice), limit=self.args.limit) for m_reddit_choice in multireddits]
else: else:
raise errors.BulkDownloaderException('A user must be provided to download a multireddit') raise errors.BulkDownloaderException('A user must be provided to download a multireddit')
else: else:

View file

@ -160,13 +160,13 @@ def test_get_subreddit_normal(
downloader_mock: MagicMock, downloader_mock: MagicMock,
reddit_instance: praw.Reddit): reddit_instance: praw.Reddit):
downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
downloader_mock._sanitise_subreddit_name = RedditDownloader._sanitise_subreddit_name
downloader_mock.args.limit = limit downloader_mock.args.limit = limit
downloader_mock.args.subreddit = test_subreddits downloader_mock.args.subreddit = test_subreddits
downloader_mock.reddit_instance = reddit_instance downloader_mock.reddit_instance = reddit_instance
downloader_mock.sort_filter = RedditTypes.SortType.HOT downloader_mock.sort_filter = RedditTypes.SortType.HOT
results = RedditDownloader._get_subreddits(downloader_mock) results = RedditDownloader._get_subreddits(downloader_mock)
results = assert_all_results_are_submissions( results = assert_all_results_are_submissions((limit * len(test_subreddits)) if limit else None, results)
(limit * len(test_subreddits)) if limit else None, results)
assert all([res.subreddit.display_name in test_subreddits for res in results]) assert all([res.subreddit.display_name in test_subreddits for res in results])
@ -184,6 +184,7 @@ def test_get_subreddit_search(
downloader_mock: MagicMock, downloader_mock: MagicMock,
reddit_instance: praw.Reddit): reddit_instance: praw.Reddit):
downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
downloader_mock._sanitise_subreddit_name = RedditDownloader._sanitise_subreddit_name
downloader_mock.args.limit = limit downloader_mock.args.limit = limit
downloader_mock.args.search = search_term downloader_mock.args.search = search_term
downloader_mock.args.subreddit = test_subreddits downloader_mock.args.subreddit = test_subreddits
@ -209,6 +210,7 @@ def test_get_multireddits_public(
reddit_instance: praw.Reddit, reddit_instance: praw.Reddit,
downloader_mock: MagicMock): downloader_mock: MagicMock):
downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
downloader_mock._sanitise_subreddit_name = RedditDownloader._sanitise_subreddit_name
downloader_mock.sort_filter = RedditTypes.SortType.HOT downloader_mock.sort_filter = RedditTypes.SortType.HOT
downloader_mock.args.limit = limit downloader_mock.args.limit = limit
downloader_mock.args.multireddit = test_multireddits downloader_mock.args.multireddit = test_multireddits
@ -389,3 +391,19 @@ def test_download_submission_hash_exists(
output = capsys.readouterr() output = capsys.readouterr()
assert len(folder_contents) == 0 assert len(folder_contents) == 0
assert re.search(r'Resource from .*? downloaded elsewhere', output.out) assert re.search(r'Resource from .*? downloaded elsewhere', output.out)
@pytest.mark.parametrize(('test_name', 'expected'), [
    ('Mindustry', 'Mindustry'),
    ('Futurology', 'Futurology'),
    ('r/Mindustry', 'Mindustry'),
    ('TrollXChromosomes', 'TrollXChromosomes'),
    ('r/TrollXChromosomes', 'TrollXChromosomes'),
    ('https://www.reddit.com/r/TrollXChromosomes/', 'TrollXChromosomes'),
    ('https://www.reddit.com/r/TrollXChromosomes', 'TrollXChromosomes'),
    ('https://www.reddit.com/r/Futurology/', 'Futurology'),
    ('https://www.reddit.com/r/Futurology', 'Futurology'),
])
def test_sanitise_subreddit_name(test_name: str, expected: str):
    """Check that bare names, ``r/`` prefixed names and full URLs reduce to the bare name."""
    assert RedditDownloader._sanitise_subreddit_name(test_name) == expected