1
0
Fork 0
mirror of synced 2024-05-17 18:52:46 +12:00

Fix time filters (#279)

This commit is contained in:
Serene 2021-04-18 21:24:11 +10:00 committed by Ali Parlakci
parent aefe8b79b6
commit b37ff0714f
7 changed files with 121 additions and 62 deletions

View file

@ -41,19 +41,20 @@ def _calc_hash(existing_file: Path):
class RedditTypes:
class SortType(Enum):
HOT = auto()
RISING = auto()
CONTROVERSIAL = auto()
HOT = auto()
NEW = auto()
RELEVENCE = auto()
RISING = auto()
TOP = auto()
class TimeType(Enum):
HOUR = auto()
DAY = auto()
WEEK = auto()
MONTH = auto()
YEAR = auto()
ALL = auto()
ALL = 'all'
DAY = 'day'
HOUR = 'hour'
MONTH = 'month'
WEEK = 'week'
YEAR = 'year'
class RedditDownloader:
@ -229,16 +230,16 @@ class RedditDownloader:
try:
reddit = self.reddit_instance.subreddit(reddit)
if self.args.search:
out.append(
reddit.search(
self.args.search,
sort=self.sort_filter.name.lower(),
limit=self.args.limit,
))
out.append(reddit.search(
self.args.search,
sort=self.sort_filter.name.lower(),
limit=self.args.limit,
time_filter=self.time_filter.value,
))
logger.debug(
f'Added submissions from subreddit {reddit} with the search term "{self.args.search}"')
else:
out.append(sort_function(reddit, limit=self.args.limit))
out.append(self._create_filtered_listing_generator(reddit))
logger.debug(f'Added submissions from subreddit {reddit}')
except (errors.BulkDownloaderException, praw.exceptions.PRAWException) as e:
logger.error(f'Failed to get submissions for subreddit {reddit}: {e}')
@ -271,6 +272,8 @@ class RedditDownloader:
sort_function = praw.models.Subreddit.rising
elif self.sort_filter is RedditTypes.SortType.CONTROVERSIAL:
sort_function = praw.models.Subreddit.controversial
elif self.sort_filter is RedditTypes.SortType.TOP:
sort_function = praw.models.Subreddit.top
else:
sort_function = praw.models.Subreddit.hot
return sort_function
@ -278,13 +281,12 @@ class RedditDownloader:
def _get_multireddits(self) -> list[Iterator]:
if self.args.multireddit:
out = []
sort_function = self._determine_sort_function()
for multi in self._split_args_input(self.args.multireddit):
try:
multi = self.reddit_instance.multireddit(self.args.user, multi)
if not multi.subreddits:
raise errors.BulkDownloaderException
out.append(sort_function(multi, limit=self.args.limit))
out.append(self._create_filtered_listing_generator(multi))
logger.debug(f'Added submissions from multireddit {multi}')
except (errors.BulkDownloaderException, praw.exceptions.PRAWException, prawcore.PrawcoreException) as e:
logger.error(f'Failed to get submissions for multireddit {multi}: {e}')
@ -292,6 +294,13 @@ class RedditDownloader:
else:
return []
def _create_filtered_listing_generator(self, reddit_source) -> Iterator:
sort_function = self._determine_sort_function()
if self.sort_filter in (RedditTypes.SortType.TOP, RedditTypes.SortType.CONTROVERSIAL):
return sort_function(reddit_source, limit=self.args.limit, time_filter=self.time_filter.value)
else:
return sort_function(reddit_source, limit=self.args.limit)
def _get_user_data(self) -> list[Iterator]:
if any([self.args.submitted, self.args.upvoted, self.args.saved]):
if self.args.user:
@ -299,14 +308,11 @@ class RedditDownloader:
logger.error(f'User {self.args.user} does not exist')
return []
generators = []
sort_function = self._determine_sort_function()
if self.args.submitted:
logger.debug(f'Retrieving submitted posts of user {self.args.user}')
generators.append(
sort_function(
self.reddit_instance.redditor(self.args.user).submissions,
limit=self.args.limit,
))
generators.append(self._create_filtered_listing_generator(
self.reddit_instance.redditor(self.args.user).submissions,
))
if not self.authenticated and any((self.args.upvoted, self.args.saved)):
logger.warning('Accessing user lists requires authentication')
else:

View file

@ -6,6 +6,7 @@ import logging
import re
import time
from typing import Optional
import urllib.parse
import _hashlib
import requests
@ -64,7 +65,8 @@ class Resource:
self.hash = hashlib.md5(self.content)
def _determine_extension(self) -> Optional[str]:
extension_pattern = re.compile(r'.*(\..{3,5})(?:\?.*)?(?:#.*)?$')
match = re.search(extension_pattern, self.url)
extension_pattern = re.compile(r'.*(\..{3,5})$')
stripped_url = urllib.parse.urlsplit(self.url).path
match = re.search(extension_pattern, stripped_url)
if match:
return match.group(1)

View file

@ -2,6 +2,7 @@
# coding=utf-8
import re
import urllib.parse
from typing import Type
from bdfr.exceptions import NotADownloadableLinkError
@ -21,30 +22,38 @@ from bdfr.site_downloaders.youtube import Youtube
class DownloadFactory:
@staticmethod
def pull_lever(url: str) -> Type[BaseDownloader]:
url_beginning = r'\s*(https?://(www\.)?)'
if re.match(url_beginning + r'(i\.)?imgur.*\.gifv$', url):
sanitised_url = DownloadFactory._sanitise_url(url)
if re.match(r'(i\.)?imgur.*\.gifv$', sanitised_url):
return Imgur
elif re.match(url_beginning + r'.*/.*\.\w{3,4}(\?[\w;&=]*)?$', url):
elif re.match(r'.*/.*\.\w{3,4}(\?[\w;&=]*)?$', sanitised_url):
return Direct
elif re.match(url_beginning + r'erome\.com.*', url):
elif re.match(r'erome\.com.*', sanitised_url):
return Erome
elif re.match(url_beginning + r'reddit\.com/gallery/.*', url):
elif re.match(r'reddit\.com/gallery/.*', sanitised_url):
return Gallery
elif re.match(url_beginning + r'gfycat\.', url):
elif re.match(r'gfycat\.', sanitised_url):
return Gfycat
elif re.match(url_beginning + r'gifdeliverynetwork', url):
elif re.match(r'gifdeliverynetwork', sanitised_url):
return GifDeliveryNetwork
elif re.match(url_beginning + r'(m\.)?imgur.*', url):
elif re.match(r'(m\.)?imgur.*', sanitised_url):
return Imgur
elif re.match(url_beginning + r'redgifs.com', url):
elif re.match(r'redgifs.com', sanitised_url):
return Redgifs
elif re.match(url_beginning + r'reddit\.com/r/', url):
elif re.match(r'reddit\.com/r/', sanitised_url):
return SelfPost
elif re.match(url_beginning + r'v\.redd\.it', url):
elif re.match(r'v\.redd\.it', sanitised_url):
return VReddit
elif re.match(url_beginning + r'(m\.)?youtu\.?be', url):
elif re.match(r'(m\.)?youtu\.?be', sanitised_url):
return Youtube
elif re.match(url_beginning + r'i\.redd\.it.*', url):
elif re.match(r'i\.redd\.it.*', sanitised_url):
return Direct
else:
raise NotADownloadableLinkError(f'No downloader module exists for url {url}')
@staticmethod
def _sanitise_url(url: str) -> str:
beginning_regex = re.compile(r'\s*(www\.?)?')
split_url = urllib.parse.urlsplit(url)
split_url = split_url.netloc + split_url.path
split_url = re.sub(beginning_regex, '', split_url)
return split_url

View file

@ -58,3 +58,14 @@ def test_factory_lever_good(test_submission_url: str, expected_class: BaseDownlo
def test_factory_lever_bad(test_url: str):
with pytest.raises(NotADownloadableLinkError):
DownloadFactory.pull_lever(test_url)
@pytest.mark.parametrize(('test_url', 'expected'), (
('www.test.com/test.png', 'test.com/test.png'),
('www.test.com/test.png?test_value=random', 'test.com/test.png'),
('https://youtube.com/watch?v=Gv8Wz74FjVA', 'youtube.com/watch'),
('https://i.imgur.com/BuzvZwb.gifv', 'i.imgur.com/BuzvZwb.gifv'),
))
def test_sanitise_urll(test_url: str, expected: str):
result = DownloadFactory._sanitise_url(test_url)
assert result == expected

View file

@ -148,54 +148,71 @@ def test_get_submissions_from_link(
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize(('test_subreddits', 'limit'), (
(('Futurology',), 10),
(('Futurology', 'Mindustry, Python'), 10),
(('Futurology',), 20),
(('Futurology', 'Python'), 10),
(('Futurology',), 100),
(('Futurology',), 0),
@pytest.mark.parametrize(('test_subreddits', 'limit', 'sort_type', 'time_filter', 'max_expected_len'), (
(('Futurology',), 10, 'hot', 'all', 10),
(('Futurology', 'Mindustry, Python'), 10, 'hot', 'all', 30),
(('Futurology',), 20, 'hot', 'all', 20),
(('Futurology', 'Python'), 10, 'hot', 'all', 20),
(('Futurology',), 100, 'hot', 'all', 100),
(('Futurology',), 0, 'hot', 'all', 0),
(('Futurology',), 10, 'top', 'all', 10),
(('Futurology',), 10, 'top', 'week', 10),
(('Futurology',), 10, 'hot', 'week', 10),
))
def test_get_subreddit_normal(
test_subreddits: list[str],
limit: int,
sort_type: str,
time_filter: str,
max_expected_len: int,
downloader_mock: MagicMock,
reddit_instance: praw.Reddit):
reddit_instance: praw.Reddit,
):
downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
downloader_mock.args.limit = limit
downloader_mock.args.sort = sort_type
downloader_mock.args.subreddit = test_subreddits
downloader_mock.reddit_instance = reddit_instance
downloader_mock.sort_filter = RedditTypes.SortType.HOT
downloader_mock.sort_filter = RedditDownloader._create_sort_filter(downloader_mock)
results = RedditDownloader._get_subreddits(downloader_mock)
test_subreddits = downloader_mock._split_args_input(test_subreddits)
results = assert_all_results_are_submissions(
(limit * len(test_subreddits)) if limit else None, results)
results = [sub for res1 in results for sub in res1]
assert all([isinstance(res1, praw.models.Submission) for res1 in results])
assert all([res.subreddit.display_name in test_subreddits for res in results])
assert len(results) <= max_expected_len
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize(('test_subreddits', 'search_term', 'limit'), (
(('Python',), 'scraper', 10),
(('Python',), '', 10),
(('Python',), 'djsdsgewef', 0),
@pytest.mark.parametrize(('test_subreddits', 'search_term', 'limit', 'time_filter', 'max_expected_len'), (
(('Python',), 'scraper', 10, 'all', 10),
(('Python',), '', 10, 'all', 10),
(('Python',), 'djsdsgewef', 10, 'all', 0),
(('Python',), 'scraper', 10, 'year', 10),
(('Python',), 'scraper', 10, 'hour', 1),
))
def test_get_subreddit_search(
test_subreddits: list[str],
search_term: str,
time_filter: str,
limit: int,
max_expected_len: int,
downloader_mock: MagicMock,
reddit_instance: praw.Reddit):
reddit_instance: praw.Reddit,
):
downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
downloader_mock.args.limit = limit
downloader_mock.args.search = search_term
downloader_mock.args.subreddit = test_subreddits
downloader_mock.reddit_instance = reddit_instance
downloader_mock.sort_filter = RedditTypes.SortType.HOT
downloader_mock.args.time = time_filter
downloader_mock.time_filter = RedditDownloader._create_time_filter(downloader_mock)
results = RedditDownloader._get_subreddits(downloader_mock)
results = assert_all_results_are_submissions(
(limit * len(test_subreddits)) if limit else None, results)
results = [sub for res in results for sub in res]
assert all([isinstance(res, praw.models.Submission) for res in results])
assert all([res.subreddit.display_name in test_subreddits for res in results])
assert len(results) <= max_expected_len
@pytest.mark.online
@ -210,15 +227,23 @@ def test_get_multireddits_public(
test_multireddits: list[str],
limit: int,
reddit_instance: praw.Reddit,
downloader_mock: MagicMock):
downloader_mock: MagicMock,
):
downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
downloader_mock.sort_filter = RedditTypes.SortType.HOT
downloader_mock.args.limit = limit
downloader_mock.args.multireddit = test_multireddits
downloader_mock.args.user = test_user
downloader_mock.reddit_instance = reddit_instance
downloader_mock._create_filtered_listing_generator.return_value = \
RedditDownloader._create_filtered_listing_generator(
downloader_mock,
reddit_instance.multireddit(test_user, test_multireddits[0]),
)
results = RedditDownloader._get_multireddits(downloader_mock)
assert_all_results_are_submissions((limit * len(test_multireddits)) if limit else None, results)
results = [sub for res in results for sub in res]
assert all([isinstance(res, praw.models.Submission) for res in results])
assert len(results) == limit
@pytest.mark.online
@ -236,6 +261,11 @@ def test_get_user_submissions(test_user: str, limit: int, downloader_mock: Magic
downloader_mock.args.user = test_user
downloader_mock.authenticated = False
downloader_mock.reddit_instance = reddit_instance
downloader_mock._create_filtered_listing_generator.return_value = \
RedditDownloader._create_filtered_listing_generator(
downloader_mock,
reddit_instance.redditor(test_user).submissions,
)
results = RedditDownloader._get_user_data(downloader_mock)
results = assert_all_results_are_submissions(limit, results)
assert all([res.author.name == test_user for res in results])

View file

@ -101,7 +101,6 @@ def test_cli_download_multireddit_nonexistent(test_args: list[str], tmp_path: Pa
['--user', 'djnish', '--submitted', '-L', 10],
['--user', 'djnish', '--submitted', '-L', 10, '--time', 'month'],
['--user', 'djnish', '--submitted', '-L', 10, '--sort', 'controversial'],
['--user', 'djnish', '--submitted', '-L', 10, '--sort', 'controversial', '--time', 'month'],
))
def test_cli_download_user_data_good(test_args: list[str], tmp_path: Path):
runner = CliRunner()

View file

@ -1,9 +1,10 @@
#!/usr/bin/env python3
# coding=utf-8
import pytest
from unittest.mock import MagicMock
import pytest
from bdfr.resource import Resource
@ -15,8 +16,9 @@ from bdfr.resource import Resource
('https://www.resource.com/test/example.jpg', '.jpg'),
('hard.png.mp4', '.mp4'),
('https://preview.redd.it/7zkmr1wqqih61.png?width=237&format=png&auto=webp&s=19de214e634cbcad99', '.png'),
('test.jpg#test','.jpg'),
('test.jpg?width=247#test','.jpg'),
('test.jpg#test', '.jpg'),
('test.jpg?width=247#test', '.jpg'),
('https://www.test.com/test/test2/example.png?random=test#thing', '.png'),
))
def test_resource_get_extension(test_url: str, expected: str):
test_resource = Resource(MagicMock(), test_url)