Fix time filters (#279)
This commit is contained in:
parent
aefe8b79b6
commit
b37ff0714f
|
@ -41,19 +41,20 @@ def _calc_hash(existing_file: Path):
|
|||
|
||||
class RedditTypes:
    """Namespace grouping the enumerations used to select and filter Reddit listings."""

    class SortType(Enum):
        """Sort orders for listings.

        The lower-cased member *name* is what gets sent as the ``sort`` argument
        of a subreddit search, so member names must be spelt the way the API
        expects them.
        """
        CONTROVERSIAL = auto()
        HOT = auto()
        NEW = auto()
        RELEVANCE = auto()
        RISING = auto()
        TOP = auto()
        # Historical misspelling, kept as an alias so that existing
        # `SortType.RELEVENCE` / `SortType['RELEVENCE']` references still resolve
        # to the same member. The canonical name is now the correctly spelt one.
        RELEVENCE = RELEVANCE

    class TimeType(Enum):
        """Time windows for listings.

        Each member's *value* is the literal string handed over as the
        ``time_filter`` keyword, which is why plain strings are used here
        instead of ``auto()`` integers.
        """
        ALL = 'all'
        DAY = 'day'
        HOUR = 'hour'
        MONTH = 'month'
        WEEK = 'week'
        YEAR = 'year'
|
||||
|
||||
|
||||
class RedditDownloader:
|
||||
|
@ -229,16 +230,16 @@ class RedditDownloader:
|
|||
try:
|
||||
reddit = self.reddit_instance.subreddit(reddit)
|
||||
if self.args.search:
|
||||
out.append(
|
||||
reddit.search(
|
||||
self.args.search,
|
||||
sort=self.sort_filter.name.lower(),
|
||||
limit=self.args.limit,
|
||||
))
|
||||
out.append(reddit.search(
|
||||
self.args.search,
|
||||
sort=self.sort_filter.name.lower(),
|
||||
limit=self.args.limit,
|
||||
time_filter=self.time_filter.value,
|
||||
))
|
||||
logger.debug(
|
||||
f'Added submissions from subreddit {reddit} with the search term "{self.args.search}"')
|
||||
else:
|
||||
out.append(sort_function(reddit, limit=self.args.limit))
|
||||
out.append(self._create_filtered_listing_generator(reddit))
|
||||
logger.debug(f'Added submissions from subreddit {reddit}')
|
||||
except (errors.BulkDownloaderException, praw.exceptions.PRAWException) as e:
|
||||
logger.error(f'Failed to get submissions for subreddit {reddit}: {e}')
|
||||
|
@ -271,6 +272,8 @@ class RedditDownloader:
|
|||
sort_function = praw.models.Subreddit.rising
|
||||
elif self.sort_filter is RedditTypes.SortType.CONTROVERSIAL:
|
||||
sort_function = praw.models.Subreddit.controversial
|
||||
elif self.sort_filter is RedditTypes.SortType.TOP:
|
||||
sort_function = praw.models.Subreddit.top
|
||||
else:
|
||||
sort_function = praw.models.Subreddit.hot
|
||||
return sort_function
|
||||
|
@ -278,13 +281,12 @@ class RedditDownloader:
|
|||
def _get_multireddits(self) -> list[Iterator]:
|
||||
if self.args.multireddit:
|
||||
out = []
|
||||
sort_function = self._determine_sort_function()
|
||||
for multi in self._split_args_input(self.args.multireddit):
|
||||
try:
|
||||
multi = self.reddit_instance.multireddit(self.args.user, multi)
|
||||
if not multi.subreddits:
|
||||
raise errors.BulkDownloaderException
|
||||
out.append(sort_function(multi, limit=self.args.limit))
|
||||
out.append(self._create_filtered_listing_generator(multi))
|
||||
logger.debug(f'Added submissions from multireddit {multi}')
|
||||
except (errors.BulkDownloaderException, praw.exceptions.PRAWException, prawcore.PrawcoreException) as e:
|
||||
logger.error(f'Failed to get submissions for multireddit {multi}: {e}')
|
||||
|
@ -292,6 +294,13 @@ class RedditDownloader:
|
|||
else:
|
||||
return []
|
||||
|
||||
def _create_filtered_listing_generator(self, reddit_source) -> Iterator:
    """Build the listing iterator for ``reddit_source`` using the configured sort.

    The callable returned by ``self._determine_sort_function()`` is invoked on
    ``reddit_source`` with ``limit=self.args.limit``. When the active sort filter
    is TOP or CONTROVERSIAL, ``time_filter=self.time_filter.value`` is passed as
    well; every other sort is called without a time filter.
    """
    fetch_listing = self._determine_sort_function()
    time_filtered_sorts = (RedditTypes.SortType.TOP, RedditTypes.SortType.CONTROVERSIAL)
    if self.sort_filter not in time_filtered_sorts:
        return fetch_listing(reddit_source, limit=self.args.limit)
    # Only the sorts listed above are given the time window.
    return fetch_listing(
        reddit_source,
        limit=self.args.limit,
        time_filter=self.time_filter.value,
    )
|
||||
|
||||
def _get_user_data(self) -> list[Iterator]:
|
||||
if any([self.args.submitted, self.args.upvoted, self.args.saved]):
|
||||
if self.args.user:
|
||||
|
@ -299,14 +308,11 @@ class RedditDownloader:
|
|||
logger.error(f'User {self.args.user} does not exist')
|
||||
return []
|
||||
generators = []
|
||||
sort_function = self._determine_sort_function()
|
||||
if self.args.submitted:
|
||||
logger.debug(f'Retrieving submitted posts of user {self.args.user}')
|
||||
generators.append(
|
||||
sort_function(
|
||||
self.reddit_instance.redditor(self.args.user).submissions,
|
||||
limit=self.args.limit,
|
||||
))
|
||||
generators.append(self._create_filtered_listing_generator(
|
||||
self.reddit_instance.redditor(self.args.user).submissions,
|
||||
))
|
||||
if not self.authenticated and any((self.args.upvoted, self.args.saved)):
|
||||
logger.warning('Accessing user lists requires authentication')
|
||||
else:
|
||||
|
|
|
@ -6,6 +6,7 @@ import logging
|
|||
import re
|
||||
import time
|
||||
from typing import Optional
|
||||
import urllib.parse
|
||||
|
||||
import _hashlib
|
||||
import requests
|
||||
|
@ -64,7 +65,8 @@ class Resource:
|
|||
self.hash = hashlib.md5(self.content)
|
||||
|
||||
def _determine_extension(self) -> Optional[str]:
|
||||
extension_pattern = re.compile(r'.*(\..{3,5})(?:\?.*)?(?:#.*)?$')
|
||||
match = re.search(extension_pattern, self.url)
|
||||
extension_pattern = re.compile(r'.*(\..{3,5})$')
|
||||
stripped_url = urllib.parse.urlsplit(self.url).path
|
||||
match = re.search(extension_pattern, stripped_url)
|
||||
if match:
|
||||
return match.group(1)
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
# coding=utf-8
|
||||
|
||||
import re
|
||||
import urllib.parse
|
||||
from typing import Type
|
||||
|
||||
from bdfr.exceptions import NotADownloadableLinkError
|
||||
|
@ -21,30 +22,38 @@ from bdfr.site_downloaders.youtube import Youtube
|
|||
class DownloadFactory:
    """Selects the downloader class responsible for a given URL."""

    @staticmethod
    def pull_lever(url: str) -> 'Type[BaseDownloader]':
        """Return the downloader class matching ``url``.

        The URL is first reduced to ``host + path`` (see ``_sanitise_url``) and
        then tested against the site patterns below. The order of the checks
        matters: the first pattern that matches wins.

        Raises:
            NotADownloadableLinkError: if no pattern matches the URL.
        """
        sanitised_url = DownloadFactory._sanitise_url(url)
        if re.match(r'(i\.)?imgur.*\.gifv$', sanitised_url):
            # Checked before the generic direct-file pattern, which would
            # otherwise claim ".gifv" links.
            return Imgur
        elif re.match(r'.*/.*\.\w{3,4}(\?[\w;&=]*)?$', sanitised_url):
            return Direct
        elif re.match(r'erome\.com.*', sanitised_url):
            return Erome
        elif re.match(r'reddit\.com/gallery/.*', sanitised_url):
            return Gallery
        elif re.match(r'gfycat\.', sanitised_url):
            return Gfycat
        elif re.match(r'gifdeliverynetwork', sanitised_url):
            return GifDeliveryNetwork
        elif re.match(r'(m\.)?imgur.*', sanitised_url):
            return Imgur
        elif re.match(r'redgifs.com', sanitised_url):
            return Redgifs
        elif re.match(r'reddit\.com/r/', sanitised_url):
            return SelfPost
        elif re.match(r'v\.redd\.it', sanitised_url):
            return VReddit
        elif re.match(r'(m\.)?youtu\.?be', sanitised_url):
            return Youtube
        elif re.match(r'i\.redd\.it.*', sanitised_url):
            return Direct
        else:
            raise NotADownloadableLinkError(f'No downloader module exists for url {url}')

    @staticmethod
    def _sanitise_url(url: str) -> str:
        """Reduce ``url`` to ``host + path`` without scheme, query, fragment or leading ``www.``.

        Surrounding whitespace is removed before parsing so that a padded URL
        is split into scheme/host/path the same way as a clean one.
        """
        split_url = urllib.parse.urlsplit(url.strip())
        host_and_path = split_url.netloc + split_url.path
        # Anchored at the start and applied once: an unanchored substitution
        # would also delete any later "www." inside the host or the path.
        return re.sub(r'\A\s*(www\.?)?', '', host_and_path, count=1)
|
||||
|
|
|
@ -58,3 +58,14 @@ def test_factory_lever_good(test_submission_url: str, expected_class: BaseDownlo
|
|||
def test_factory_lever_bad(test_url: str):
|
||||
with pytest.raises(NotADownloadableLinkError):
|
||||
DownloadFactory.pull_lever(test_url)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(('test_url', 'expected'), (
    ('www.test.com/test.png', 'test.com/test.png'),
    ('www.test.com/test.png?test_value=random', 'test.com/test.png'),
    ('https://youtube.com/watch?v=Gv8Wz74FjVA', 'youtube.com/watch'),
    ('https://i.imgur.com/BuzvZwb.gifv', 'i.imgur.com/BuzvZwb.gifv'),
))
def test_sanitise_urll(test_url: str, expected: str):
    """Check that URL sanitising yields host + path without scheme, ``www.`` or query."""
    assert DownloadFactory._sanitise_url(test_url) == expected
|
||||
|
|
|
@ -148,54 +148,71 @@ def test_get_submissions_from_link(
|
|||
|
||||
@pytest.mark.online
|
||||
@pytest.mark.reddit
|
||||
@pytest.mark.parametrize(('test_subreddits', 'limit'), (
|
||||
(('Futurology',), 10),
|
||||
(('Futurology', 'Mindustry, Python'), 10),
|
||||
(('Futurology',), 20),
|
||||
(('Futurology', 'Python'), 10),
|
||||
(('Futurology',), 100),
|
||||
(('Futurology',), 0),
|
||||
@pytest.mark.parametrize(('test_subreddits', 'limit', 'sort_type', 'time_filter', 'max_expected_len'), (
|
||||
(('Futurology',), 10, 'hot', 'all', 10),
|
||||
(('Futurology', 'Mindustry, Python'), 10, 'hot', 'all', 30),
|
||||
(('Futurology',), 20, 'hot', 'all', 20),
|
||||
(('Futurology', 'Python'), 10, 'hot', 'all', 20),
|
||||
(('Futurology',), 100, 'hot', 'all', 100),
|
||||
(('Futurology',), 0, 'hot', 'all', 0),
|
||||
(('Futurology',), 10, 'top', 'all', 10),
|
||||
(('Futurology',), 10, 'top', 'week', 10),
|
||||
(('Futurology',), 10, 'hot', 'week', 10),
|
||||
))
|
||||
def test_get_subreddit_normal(
|
||||
test_subreddits: list[str],
|
||||
limit: int,
|
||||
sort_type: str,
|
||||
time_filter: str,
|
||||
max_expected_len: int,
|
||||
downloader_mock: MagicMock,
|
||||
reddit_instance: praw.Reddit):
|
||||
reddit_instance: praw.Reddit,
|
||||
):
|
||||
downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
|
||||
downloader_mock.args.limit = limit
|
||||
downloader_mock.args.sort = sort_type
|
||||
downloader_mock.args.subreddit = test_subreddits
|
||||
downloader_mock.reddit_instance = reddit_instance
|
||||
downloader_mock.sort_filter = RedditTypes.SortType.HOT
|
||||
downloader_mock.sort_filter = RedditDownloader._create_sort_filter(downloader_mock)
|
||||
results = RedditDownloader._get_subreddits(downloader_mock)
|
||||
test_subreddits = downloader_mock._split_args_input(test_subreddits)
|
||||
results = assert_all_results_are_submissions(
|
||||
(limit * len(test_subreddits)) if limit else None, results)
|
||||
results = [sub for res1 in results for sub in res1]
|
||||
assert all([isinstance(res1, praw.models.Submission) for res1 in results])
|
||||
assert all([res.subreddit.display_name in test_subreddits for res in results])
|
||||
assert len(results) <= max_expected_len
|
||||
|
||||
|
||||
@pytest.mark.online
|
||||
@pytest.mark.reddit
|
||||
@pytest.mark.parametrize(('test_subreddits', 'search_term', 'limit'), (
|
||||
(('Python',), 'scraper', 10),
|
||||
(('Python',), '', 10),
|
||||
(('Python',), 'djsdsgewef', 0),
|
||||
@pytest.mark.parametrize(('test_subreddits', 'search_term', 'limit', 'time_filter', 'max_expected_len'), (
|
||||
(('Python',), 'scraper', 10, 'all', 10),
|
||||
(('Python',), '', 10, 'all', 10),
|
||||
(('Python',), 'djsdsgewef', 10, 'all', 0),
|
||||
(('Python',), 'scraper', 10, 'year', 10),
|
||||
(('Python',), 'scraper', 10, 'hour', 1),
|
||||
))
|
||||
def test_get_subreddit_search(
|
||||
test_subreddits: list[str],
|
||||
search_term: str,
|
||||
time_filter: str,
|
||||
limit: int,
|
||||
max_expected_len: int,
|
||||
downloader_mock: MagicMock,
|
||||
reddit_instance: praw.Reddit):
|
||||
reddit_instance: praw.Reddit,
|
||||
):
|
||||
downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
|
||||
downloader_mock.args.limit = limit
|
||||
downloader_mock.args.search = search_term
|
||||
downloader_mock.args.subreddit = test_subreddits
|
||||
downloader_mock.reddit_instance = reddit_instance
|
||||
downloader_mock.sort_filter = RedditTypes.SortType.HOT
|
||||
downloader_mock.args.time = time_filter
|
||||
downloader_mock.time_filter = RedditDownloader._create_time_filter(downloader_mock)
|
||||
results = RedditDownloader._get_subreddits(downloader_mock)
|
||||
results = assert_all_results_are_submissions(
|
||||
(limit * len(test_subreddits)) if limit else None, results)
|
||||
results = [sub for res in results for sub in res]
|
||||
assert all([isinstance(res, praw.models.Submission) for res in results])
|
||||
assert all([res.subreddit.display_name in test_subreddits for res in results])
|
||||
assert len(results) <= max_expected_len
|
||||
|
||||
|
||||
@pytest.mark.online
|
||||
|
@ -210,15 +227,23 @@ def test_get_multireddits_public(
|
|||
test_multireddits: list[str],
|
||||
limit: int,
|
||||
reddit_instance: praw.Reddit,
|
||||
downloader_mock: MagicMock):
|
||||
downloader_mock: MagicMock,
|
||||
):
|
||||
downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
|
||||
downloader_mock.sort_filter = RedditTypes.SortType.HOT
|
||||
downloader_mock.args.limit = limit
|
||||
downloader_mock.args.multireddit = test_multireddits
|
||||
downloader_mock.args.user = test_user
|
||||
downloader_mock.reddit_instance = reddit_instance
|
||||
downloader_mock._create_filtered_listing_generator.return_value = \
|
||||
RedditDownloader._create_filtered_listing_generator(
|
||||
downloader_mock,
|
||||
reddit_instance.multireddit(test_user, test_multireddits[0]),
|
||||
)
|
||||
results = RedditDownloader._get_multireddits(downloader_mock)
|
||||
assert_all_results_are_submissions((limit * len(test_multireddits)) if limit else None, results)
|
||||
results = [sub for res in results for sub in res]
|
||||
assert all([isinstance(res, praw.models.Submission) for res in results])
|
||||
assert len(results) == limit
|
||||
|
||||
|
||||
@pytest.mark.online
|
||||
|
@ -236,6 +261,11 @@ def test_get_user_submissions(test_user: str, limit: int, downloader_mock: Magic
|
|||
downloader_mock.args.user = test_user
|
||||
downloader_mock.authenticated = False
|
||||
downloader_mock.reddit_instance = reddit_instance
|
||||
downloader_mock._create_filtered_listing_generator.return_value = \
|
||||
RedditDownloader._create_filtered_listing_generator(
|
||||
downloader_mock,
|
||||
reddit_instance.redditor(test_user).submissions,
|
||||
)
|
||||
results = RedditDownloader._get_user_data(downloader_mock)
|
||||
results = assert_all_results_are_submissions(limit, results)
|
||||
assert all([res.author.name == test_user for res in results])
|
||||
|
|
|
@ -101,7 +101,6 @@ def test_cli_download_multireddit_nonexistent(test_args: list[str], tmp_path: Pa
|
|||
['--user', 'djnish', '--submitted', '-L', 10],
|
||||
['--user', 'djnish', '--submitted', '-L', 10, '--time', 'month'],
|
||||
['--user', 'djnish', '--submitted', '-L', 10, '--sort', 'controversial'],
|
||||
['--user', 'djnish', '--submitted', '-L', 10, '--sort', 'controversial', '--time', 'month'],
|
||||
))
|
||||
def test_cli_download_user_data_good(test_args: list[str], tmp_path: Path):
|
||||
runner = CliRunner()
|
||||
|
|
|
@ -1,9 +1,10 @@
|
|||
#!/usr/bin/env python3
|
||||
# coding=utf-8
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from bdfr.resource import Resource
|
||||
|
||||
|
||||
|
@ -15,8 +16,9 @@ from bdfr.resource import Resource
|
|||
('https://www.resource.com/test/example.jpg', '.jpg'),
|
||||
('hard.png.mp4', '.mp4'),
|
||||
('https://preview.redd.it/7zkmr1wqqih61.png?width=237&format=png&auto=webp&s=19de214e634cbcad99', '.png'),
|
||||
('test.jpg#test','.jpg'),
|
||||
('test.jpg?width=247#test','.jpg'),
|
||||
('test.jpg#test', '.jpg'),
|
||||
('test.jpg?width=247#test', '.jpg'),
|
||||
('https://www.test.com/test/test2/example.png?random=test#thing', '.png'),
|
||||
))
|
||||
def test_resource_get_extension(test_url: str, expected: str):
|
||||
test_resource = Resource(MagicMock(), test_url)
|
||||
|
|
Loading…
Reference in a new issue