Add method to sanitise subreddit inputs
This commit is contained in:
parent
d3c8897f6a
commit
f7989ca518
2 changed files with 33 additions and 4 deletions
|
@ -3,6 +3,7 @@
|
||||||
|
|
||||||
import configparser
|
import configparser
|
||||||
import logging
|
import logging
|
||||||
|
import re
|
||||||
import socket
|
import socket
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from enum import Enum, auto
|
from enum import Enum, auto
|
||||||
|
@ -153,9 +154,18 @@ class RedditDownloader:
|
||||||
|
|
||||||
main_logger.addHandler(file_handler)
|
main_logger.addHandler(file_handler)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _sanitise_subreddit_name(subreddit: str) -> str:
|
||||||
|
pattern = re.compile(r'^(?:https://www\.reddit\.com/)?(?:r/)?(.*?)(?:/)?$')
|
||||||
|
match = re.match(pattern, subreddit)
|
||||||
|
if not match:
|
||||||
|
raise errors.RedditAuthenticationError('')
|
||||||
|
return match.group(1)
|
||||||
|
|
||||||
def _get_subreddits(self) -> list[praw.models.ListingGenerator]:
|
def _get_subreddits(self) -> list[praw.models.ListingGenerator]:
|
||||||
if self.args.subreddit:
|
if self.args.subreddit:
|
||||||
subreddits = [self.reddit_instance.subreddit(chosen_subreddit) for chosen_subreddit in self.args.subreddit]
|
subreddits = [self._sanitise_subreddit_name(subreddit) for subreddit in self.args.subreddit]
|
||||||
|
subreddits = [self.reddit_instance.subreddit(chosen_subreddit) for chosen_subreddit in subreddits]
|
||||||
if self.args.search:
|
if self.args.search:
|
||||||
return [
|
return [
|
||||||
reddit.search(
|
reddit.search(
|
||||||
|
@ -197,10 +207,11 @@ class RedditDownloader:
|
||||||
if self.authenticated:
|
if self.authenticated:
|
||||||
if self.args.user:
|
if self.args.user:
|
||||||
sort_function = self._determine_sort_function()
|
sort_function = self._determine_sort_function()
|
||||||
|
multireddits = [self._sanitise_subreddit_name(multi) for multi in self.args.multireddit]
|
||||||
return [
|
return [
|
||||||
sort_function(self.reddit_instance.multireddit(
|
sort_function(self.reddit_instance.multireddit(
|
||||||
self.args.user,
|
self.args.user,
|
||||||
m_reddit_choice), limit=self.args.limit) for m_reddit_choice in self.args.multireddit]
|
m_reddit_choice), limit=self.args.limit) for m_reddit_choice in multireddits]
|
||||||
else:
|
else:
|
||||||
raise errors.BulkDownloaderException('A user must be provided to download a multireddit')
|
raise errors.BulkDownloaderException('A user must be provided to download a multireddit')
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -160,13 +160,13 @@ def test_get_subreddit_normal(
|
||||||
downloader_mock: MagicMock,
|
downloader_mock: MagicMock,
|
||||||
reddit_instance: praw.Reddit):
|
reddit_instance: praw.Reddit):
|
||||||
downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
|
downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
|
||||||
|
downloader_mock._sanitise_subreddit_name = RedditDownloader._sanitise_subreddit_name
|
||||||
downloader_mock.args.limit = limit
|
downloader_mock.args.limit = limit
|
||||||
downloader_mock.args.subreddit = test_subreddits
|
downloader_mock.args.subreddit = test_subreddits
|
||||||
downloader_mock.reddit_instance = reddit_instance
|
downloader_mock.reddit_instance = reddit_instance
|
||||||
downloader_mock.sort_filter = RedditTypes.SortType.HOT
|
downloader_mock.sort_filter = RedditTypes.SortType.HOT
|
||||||
results = RedditDownloader._get_subreddits(downloader_mock)
|
results = RedditDownloader._get_subreddits(downloader_mock)
|
||||||
results = assert_all_results_are_submissions(
|
results = assert_all_results_are_submissions((limit * len(test_subreddits)) if limit else None, results)
|
||||||
(limit * len(test_subreddits)) if limit else None, results)
|
|
||||||
assert all([res.subreddit.display_name in test_subreddits for res in results])
|
assert all([res.subreddit.display_name in test_subreddits for res in results])
|
||||||
|
|
||||||
|
|
||||||
|
@ -184,6 +184,7 @@ def test_get_subreddit_search(
|
||||||
downloader_mock: MagicMock,
|
downloader_mock: MagicMock,
|
||||||
reddit_instance: praw.Reddit):
|
reddit_instance: praw.Reddit):
|
||||||
downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
|
downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
|
||||||
|
downloader_mock._sanitise_subreddit_name = RedditDownloader._sanitise_subreddit_name
|
||||||
downloader_mock.args.limit = limit
|
downloader_mock.args.limit = limit
|
||||||
downloader_mock.args.search = search_term
|
downloader_mock.args.search = search_term
|
||||||
downloader_mock.args.subreddit = test_subreddits
|
downloader_mock.args.subreddit = test_subreddits
|
||||||
|
@ -209,6 +210,7 @@ def test_get_multireddits_public(
|
||||||
reddit_instance: praw.Reddit,
|
reddit_instance: praw.Reddit,
|
||||||
downloader_mock: MagicMock):
|
downloader_mock: MagicMock):
|
||||||
downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
|
downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
|
||||||
|
downloader_mock._sanitise_subreddit_name = RedditDownloader._sanitise_subreddit_name
|
||||||
downloader_mock.sort_filter = RedditTypes.SortType.HOT
|
downloader_mock.sort_filter = RedditTypes.SortType.HOT
|
||||||
downloader_mock.args.limit = limit
|
downloader_mock.args.limit = limit
|
||||||
downloader_mock.args.multireddit = test_multireddits
|
downloader_mock.args.multireddit = test_multireddits
|
||||||
|
@ -389,3 +391,19 @@ def test_download_submission_hash_exists(
|
||||||
output = capsys.readouterr()
|
output = capsys.readouterr()
|
||||||
assert len(folder_contents) == 0
|
assert len(folder_contents) == 0
|
||||||
assert re.search(r'Resource from .*? downloaded elsewhere', output.out)
|
assert re.search(r'Resource from .*? downloaded elsewhere', output.out)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(('test_name', 'expected'), (
|
||||||
|
('Mindustry', 'Mindustry'),
|
||||||
|
('Futurology', 'Futurology'),
|
||||||
|
('r/Mindustry', 'Mindustry'),
|
||||||
|
('TrollXChromosomes', 'TrollXChromosomes'),
|
||||||
|
('r/TrollXChromosomes', 'TrollXChromosomes'),
|
||||||
|
('https://www.reddit.com/r/TrollXChromosomes/', 'TrollXChromosomes'),
|
||||||
|
('https://www.reddit.com/r/TrollXChromosomes', 'TrollXChromosomes'),
|
||||||
|
('https://www.reddit.com/r/Futurology/', 'Futurology'),
|
||||||
|
('https://www.reddit.com/r/Futurology', 'Futurology'),
|
||||||
|
))
|
||||||
|
def test_sanitise_subreddit_name(test_name: str, expected: str):
|
||||||
|
result = RedditDownloader._sanitise_subreddit_name(test_name)
|
||||||
|
assert result == expected
|
||||||
|
|
Loading…
Reference in a new issue