Fix time filters (#279)
This commit is contained in:
parent
aefe8b79b6
commit
b37ff0714f
7 changed files with 121 additions and 62 deletions
|
@ -41,19 +41,20 @@ def _calc_hash(existing_file: Path):
|
||||||
|
|
||||||
class RedditTypes:
|
class RedditTypes:
|
||||||
class SortType(Enum):
|
class SortType(Enum):
|
||||||
HOT = auto()
|
|
||||||
RISING = auto()
|
|
||||||
CONTROVERSIAL = auto()
|
CONTROVERSIAL = auto()
|
||||||
|
HOT = auto()
|
||||||
NEW = auto()
|
NEW = auto()
|
||||||
RELEVENCE = auto()
|
RELEVENCE = auto()
|
||||||
|
RISING = auto()
|
||||||
|
TOP = auto()
|
||||||
|
|
||||||
class TimeType(Enum):
|
class TimeType(Enum):
|
||||||
HOUR = auto()
|
ALL = 'all'
|
||||||
DAY = auto()
|
DAY = 'day'
|
||||||
WEEK = auto()
|
HOUR = 'hour'
|
||||||
MONTH = auto()
|
MONTH = 'month'
|
||||||
YEAR = auto()
|
WEEK = 'week'
|
||||||
ALL = auto()
|
YEAR = 'year'
|
||||||
|
|
||||||
|
|
||||||
class RedditDownloader:
|
class RedditDownloader:
|
||||||
|
@ -229,16 +230,16 @@ class RedditDownloader:
|
||||||
try:
|
try:
|
||||||
reddit = self.reddit_instance.subreddit(reddit)
|
reddit = self.reddit_instance.subreddit(reddit)
|
||||||
if self.args.search:
|
if self.args.search:
|
||||||
out.append(
|
out.append(reddit.search(
|
||||||
reddit.search(
|
|
||||||
self.args.search,
|
self.args.search,
|
||||||
sort=self.sort_filter.name.lower(),
|
sort=self.sort_filter.name.lower(),
|
||||||
limit=self.args.limit,
|
limit=self.args.limit,
|
||||||
|
time_filter=self.time_filter.value,
|
||||||
))
|
))
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f'Added submissions from subreddit {reddit} with the search term "{self.args.search}"')
|
f'Added submissions from subreddit {reddit} with the search term "{self.args.search}"')
|
||||||
else:
|
else:
|
||||||
out.append(sort_function(reddit, limit=self.args.limit))
|
out.append(self._create_filtered_listing_generator(reddit))
|
||||||
logger.debug(f'Added submissions from subreddit {reddit}')
|
logger.debug(f'Added submissions from subreddit {reddit}')
|
||||||
except (errors.BulkDownloaderException, praw.exceptions.PRAWException) as e:
|
except (errors.BulkDownloaderException, praw.exceptions.PRAWException) as e:
|
||||||
logger.error(f'Failed to get submissions for subreddit {reddit}: {e}')
|
logger.error(f'Failed to get submissions for subreddit {reddit}: {e}')
|
||||||
|
@ -271,6 +272,8 @@ class RedditDownloader:
|
||||||
sort_function = praw.models.Subreddit.rising
|
sort_function = praw.models.Subreddit.rising
|
||||||
elif self.sort_filter is RedditTypes.SortType.CONTROVERSIAL:
|
elif self.sort_filter is RedditTypes.SortType.CONTROVERSIAL:
|
||||||
sort_function = praw.models.Subreddit.controversial
|
sort_function = praw.models.Subreddit.controversial
|
||||||
|
elif self.sort_filter is RedditTypes.SortType.TOP:
|
||||||
|
sort_function = praw.models.Subreddit.top
|
||||||
else:
|
else:
|
||||||
sort_function = praw.models.Subreddit.hot
|
sort_function = praw.models.Subreddit.hot
|
||||||
return sort_function
|
return sort_function
|
||||||
|
@ -278,13 +281,12 @@ class RedditDownloader:
|
||||||
def _get_multireddits(self) -> list[Iterator]:
|
def _get_multireddits(self) -> list[Iterator]:
|
||||||
if self.args.multireddit:
|
if self.args.multireddit:
|
||||||
out = []
|
out = []
|
||||||
sort_function = self._determine_sort_function()
|
|
||||||
for multi in self._split_args_input(self.args.multireddit):
|
for multi in self._split_args_input(self.args.multireddit):
|
||||||
try:
|
try:
|
||||||
multi = self.reddit_instance.multireddit(self.args.user, multi)
|
multi = self.reddit_instance.multireddit(self.args.user, multi)
|
||||||
if not multi.subreddits:
|
if not multi.subreddits:
|
||||||
raise errors.BulkDownloaderException
|
raise errors.BulkDownloaderException
|
||||||
out.append(sort_function(multi, limit=self.args.limit))
|
out.append(self._create_filtered_listing_generator(multi))
|
||||||
logger.debug(f'Added submissions from multireddit {multi}')
|
logger.debug(f'Added submissions from multireddit {multi}')
|
||||||
except (errors.BulkDownloaderException, praw.exceptions.PRAWException, prawcore.PrawcoreException) as e:
|
except (errors.BulkDownloaderException, praw.exceptions.PRAWException, prawcore.PrawcoreException) as e:
|
||||||
logger.error(f'Failed to get submissions for multireddit {multi}: {e}')
|
logger.error(f'Failed to get submissions for multireddit {multi}: {e}')
|
||||||
|
@ -292,6 +294,13 @@ class RedditDownloader:
|
||||||
else:
|
else:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
def _create_filtered_listing_generator(self, reddit_source) -> Iterator:
|
||||||
|
sort_function = self._determine_sort_function()
|
||||||
|
if self.sort_filter in (RedditTypes.SortType.TOP, RedditTypes.SortType.CONTROVERSIAL):
|
||||||
|
return sort_function(reddit_source, limit=self.args.limit, time_filter=self.time_filter.value)
|
||||||
|
else:
|
||||||
|
return sort_function(reddit_source, limit=self.args.limit)
|
||||||
|
|
||||||
def _get_user_data(self) -> list[Iterator]:
|
def _get_user_data(self) -> list[Iterator]:
|
||||||
if any([self.args.submitted, self.args.upvoted, self.args.saved]):
|
if any([self.args.submitted, self.args.upvoted, self.args.saved]):
|
||||||
if self.args.user:
|
if self.args.user:
|
||||||
|
@ -299,13 +308,10 @@ class RedditDownloader:
|
||||||
logger.error(f'User {self.args.user} does not exist')
|
logger.error(f'User {self.args.user} does not exist')
|
||||||
return []
|
return []
|
||||||
generators = []
|
generators = []
|
||||||
sort_function = self._determine_sort_function()
|
|
||||||
if self.args.submitted:
|
if self.args.submitted:
|
||||||
logger.debug(f'Retrieving submitted posts of user {self.args.user}')
|
logger.debug(f'Retrieving submitted posts of user {self.args.user}')
|
||||||
generators.append(
|
generators.append(self._create_filtered_listing_generator(
|
||||||
sort_function(
|
|
||||||
self.reddit_instance.redditor(self.args.user).submissions,
|
self.reddit_instance.redditor(self.args.user).submissions,
|
||||||
limit=self.args.limit,
|
|
||||||
))
|
))
|
||||||
if not self.authenticated and any((self.args.upvoted, self.args.saved)):
|
if not self.authenticated and any((self.args.upvoted, self.args.saved)):
|
||||||
logger.warning('Accessing user lists requires authentication')
|
logger.warning('Accessing user lists requires authentication')
|
||||||
|
|
|
@ -6,6 +6,7 @@ import logging
|
||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
import urllib.parse
|
||||||
|
|
||||||
import _hashlib
|
import _hashlib
|
||||||
import requests
|
import requests
|
||||||
|
@ -64,7 +65,8 @@ class Resource:
|
||||||
self.hash = hashlib.md5(self.content)
|
self.hash = hashlib.md5(self.content)
|
||||||
|
|
||||||
def _determine_extension(self) -> Optional[str]:
|
def _determine_extension(self) -> Optional[str]:
|
||||||
extension_pattern = re.compile(r'.*(\..{3,5})(?:\?.*)?(?:#.*)?$')
|
extension_pattern = re.compile(r'.*(\..{3,5})$')
|
||||||
match = re.search(extension_pattern, self.url)
|
stripped_url = urllib.parse.urlsplit(self.url).path
|
||||||
|
match = re.search(extension_pattern, stripped_url)
|
||||||
if match:
|
if match:
|
||||||
return match.group(1)
|
return match.group(1)
|
||||||
|
|
|
@ -2,6 +2,7 @@
|
||||||
# coding=utf-8
|
# coding=utf-8
|
||||||
|
|
||||||
import re
|
import re
|
||||||
|
import urllib.parse
|
||||||
from typing import Type
|
from typing import Type
|
||||||
|
|
||||||
from bdfr.exceptions import NotADownloadableLinkError
|
from bdfr.exceptions import NotADownloadableLinkError
|
||||||
|
@ -21,30 +22,38 @@ from bdfr.site_downloaders.youtube import Youtube
|
||||||
class DownloadFactory:
|
class DownloadFactory:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def pull_lever(url: str) -> Type[BaseDownloader]:
|
def pull_lever(url: str) -> Type[BaseDownloader]:
|
||||||
url_beginning = r'\s*(https?://(www\.)?)'
|
sanitised_url = DownloadFactory._sanitise_url(url)
|
||||||
if re.match(url_beginning + r'(i\.)?imgur.*\.gifv$', url):
|
if re.match(r'(i\.)?imgur.*\.gifv$', sanitised_url):
|
||||||
return Imgur
|
return Imgur
|
||||||
elif re.match(url_beginning + r'.*/.*\.\w{3,4}(\?[\w;&=]*)?$', url):
|
elif re.match(r'.*/.*\.\w{3,4}(\?[\w;&=]*)?$', sanitised_url):
|
||||||
return Direct
|
return Direct
|
||||||
elif re.match(url_beginning + r'erome\.com.*', url):
|
elif re.match(r'erome\.com.*', sanitised_url):
|
||||||
return Erome
|
return Erome
|
||||||
elif re.match(url_beginning + r'reddit\.com/gallery/.*', url):
|
elif re.match(r'reddit\.com/gallery/.*', sanitised_url):
|
||||||
return Gallery
|
return Gallery
|
||||||
elif re.match(url_beginning + r'gfycat\.', url):
|
elif re.match(r'gfycat\.', sanitised_url):
|
||||||
return Gfycat
|
return Gfycat
|
||||||
elif re.match(url_beginning + r'gifdeliverynetwork', url):
|
elif re.match(r'gifdeliverynetwork', sanitised_url):
|
||||||
return GifDeliveryNetwork
|
return GifDeliveryNetwork
|
||||||
elif re.match(url_beginning + r'(m\.)?imgur.*', url):
|
elif re.match(r'(m\.)?imgur.*', sanitised_url):
|
||||||
return Imgur
|
return Imgur
|
||||||
elif re.match(url_beginning + r'redgifs.com', url):
|
elif re.match(r'redgifs.com', sanitised_url):
|
||||||
return Redgifs
|
return Redgifs
|
||||||
elif re.match(url_beginning + r'reddit\.com/r/', url):
|
elif re.match(r'reddit\.com/r/', sanitised_url):
|
||||||
return SelfPost
|
return SelfPost
|
||||||
elif re.match(url_beginning + r'v\.redd\.it', url):
|
elif re.match(r'v\.redd\.it', sanitised_url):
|
||||||
return VReddit
|
return VReddit
|
||||||
elif re.match(url_beginning + r'(m\.)?youtu\.?be', url):
|
elif re.match(r'(m\.)?youtu\.?be', sanitised_url):
|
||||||
return Youtube
|
return Youtube
|
||||||
elif re.match(url_beginning + r'i\.redd\.it.*', url):
|
elif re.match(r'i\.redd\.it.*', sanitised_url):
|
||||||
return Direct
|
return Direct
|
||||||
else:
|
else:
|
||||||
raise NotADownloadableLinkError(f'No downloader module exists for url {url}')
|
raise NotADownloadableLinkError(f'No downloader module exists for url {url}')
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _sanitise_url(url: str) -> str:
|
||||||
|
beginning_regex = re.compile(r'\s*(www\.?)?')
|
||||||
|
split_url = urllib.parse.urlsplit(url)
|
||||||
|
split_url = split_url.netloc + split_url.path
|
||||||
|
split_url = re.sub(beginning_regex, '', split_url)
|
||||||
|
return split_url
|
||||||
|
|
|
@ -58,3 +58,14 @@ def test_factory_lever_good(test_submission_url: str, expected_class: BaseDownlo
|
||||||
def test_factory_lever_bad(test_url: str):
|
def test_factory_lever_bad(test_url: str):
|
||||||
with pytest.raises(NotADownloadableLinkError):
|
with pytest.raises(NotADownloadableLinkError):
|
||||||
DownloadFactory.pull_lever(test_url)
|
DownloadFactory.pull_lever(test_url)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(('test_url', 'expected'), (
|
||||||
|
('www.test.com/test.png', 'test.com/test.png'),
|
||||||
|
('www.test.com/test.png?test_value=random', 'test.com/test.png'),
|
||||||
|
('https://youtube.com/watch?v=Gv8Wz74FjVA', 'youtube.com/watch'),
|
||||||
|
('https://i.imgur.com/BuzvZwb.gifv', 'i.imgur.com/BuzvZwb.gifv'),
|
||||||
|
))
|
||||||
|
def test_sanitise_urll(test_url: str, expected: str):
|
||||||
|
result = DownloadFactory._sanitise_url(test_url)
|
||||||
|
assert result == expected
|
||||||
|
|
|
@ -148,54 +148,71 @@ def test_get_submissions_from_link(
|
||||||
|
|
||||||
@pytest.mark.online
|
@pytest.mark.online
|
||||||
@pytest.mark.reddit
|
@pytest.mark.reddit
|
||||||
@pytest.mark.parametrize(('test_subreddits', 'limit'), (
|
@pytest.mark.parametrize(('test_subreddits', 'limit', 'sort_type', 'time_filter', 'max_expected_len'), (
|
||||||
(('Futurology',), 10),
|
(('Futurology',), 10, 'hot', 'all', 10),
|
||||||
(('Futurology', 'Mindustry, Python'), 10),
|
(('Futurology', 'Mindustry, Python'), 10, 'hot', 'all', 30),
|
||||||
(('Futurology',), 20),
|
(('Futurology',), 20, 'hot', 'all', 20),
|
||||||
(('Futurology', 'Python'), 10),
|
(('Futurology', 'Python'), 10, 'hot', 'all', 20),
|
||||||
(('Futurology',), 100),
|
(('Futurology',), 100, 'hot', 'all', 100),
|
||||||
(('Futurology',), 0),
|
(('Futurology',), 0, 'hot', 'all', 0),
|
||||||
|
(('Futurology',), 10, 'top', 'all', 10),
|
||||||
|
(('Futurology',), 10, 'top', 'week', 10),
|
||||||
|
(('Futurology',), 10, 'hot', 'week', 10),
|
||||||
))
|
))
|
||||||
def test_get_subreddit_normal(
|
def test_get_subreddit_normal(
|
||||||
test_subreddits: list[str],
|
test_subreddits: list[str],
|
||||||
limit: int,
|
limit: int,
|
||||||
|
sort_type: str,
|
||||||
|
time_filter: str,
|
||||||
|
max_expected_len: int,
|
||||||
downloader_mock: MagicMock,
|
downloader_mock: MagicMock,
|
||||||
reddit_instance: praw.Reddit):
|
reddit_instance: praw.Reddit,
|
||||||
|
):
|
||||||
downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
|
downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
|
||||||
downloader_mock.args.limit = limit
|
downloader_mock.args.limit = limit
|
||||||
|
downloader_mock.args.sort = sort_type
|
||||||
downloader_mock.args.subreddit = test_subreddits
|
downloader_mock.args.subreddit = test_subreddits
|
||||||
downloader_mock.reddit_instance = reddit_instance
|
downloader_mock.reddit_instance = reddit_instance
|
||||||
downloader_mock.sort_filter = RedditTypes.SortType.HOT
|
downloader_mock.sort_filter = RedditDownloader._create_sort_filter(downloader_mock)
|
||||||
results = RedditDownloader._get_subreddits(downloader_mock)
|
results = RedditDownloader._get_subreddits(downloader_mock)
|
||||||
test_subreddits = downloader_mock._split_args_input(test_subreddits)
|
test_subreddits = downloader_mock._split_args_input(test_subreddits)
|
||||||
results = assert_all_results_are_submissions(
|
results = [sub for res1 in results for sub in res1]
|
||||||
(limit * len(test_subreddits)) if limit else None, results)
|
assert all([isinstance(res1, praw.models.Submission) for res1 in results])
|
||||||
assert all([res.subreddit.display_name in test_subreddits for res in results])
|
assert all([res.subreddit.display_name in test_subreddits for res in results])
|
||||||
|
assert len(results) <= max_expected_len
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.online
|
@pytest.mark.online
|
||||||
@pytest.mark.reddit
|
@pytest.mark.reddit
|
||||||
@pytest.mark.parametrize(('test_subreddits', 'search_term', 'limit'), (
|
@pytest.mark.parametrize(('test_subreddits', 'search_term', 'limit', 'time_filter', 'max_expected_len'), (
|
||||||
(('Python',), 'scraper', 10),
|
(('Python',), 'scraper', 10, 'all', 10),
|
||||||
(('Python',), '', 10),
|
(('Python',), '', 10, 'all', 10),
|
||||||
(('Python',), 'djsdsgewef', 0),
|
(('Python',), 'djsdsgewef', 10, 'all', 0),
|
||||||
|
(('Python',), 'scraper', 10, 'year', 10),
|
||||||
|
(('Python',), 'scraper', 10, 'hour', 1),
|
||||||
))
|
))
|
||||||
def test_get_subreddit_search(
|
def test_get_subreddit_search(
|
||||||
test_subreddits: list[str],
|
test_subreddits: list[str],
|
||||||
search_term: str,
|
search_term: str,
|
||||||
|
time_filter: str,
|
||||||
limit: int,
|
limit: int,
|
||||||
|
max_expected_len: int,
|
||||||
downloader_mock: MagicMock,
|
downloader_mock: MagicMock,
|
||||||
reddit_instance: praw.Reddit):
|
reddit_instance: praw.Reddit,
|
||||||
|
):
|
||||||
downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
|
downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
|
||||||
downloader_mock.args.limit = limit
|
downloader_mock.args.limit = limit
|
||||||
downloader_mock.args.search = search_term
|
downloader_mock.args.search = search_term
|
||||||
downloader_mock.args.subreddit = test_subreddits
|
downloader_mock.args.subreddit = test_subreddits
|
||||||
downloader_mock.reddit_instance = reddit_instance
|
downloader_mock.reddit_instance = reddit_instance
|
||||||
downloader_mock.sort_filter = RedditTypes.SortType.HOT
|
downloader_mock.sort_filter = RedditTypes.SortType.HOT
|
||||||
|
downloader_mock.args.time = time_filter
|
||||||
|
downloader_mock.time_filter = RedditDownloader._create_time_filter(downloader_mock)
|
||||||
results = RedditDownloader._get_subreddits(downloader_mock)
|
results = RedditDownloader._get_subreddits(downloader_mock)
|
||||||
results = assert_all_results_are_submissions(
|
results = [sub for res in results for sub in res]
|
||||||
(limit * len(test_subreddits)) if limit else None, results)
|
assert all([isinstance(res, praw.models.Submission) for res in results])
|
||||||
assert all([res.subreddit.display_name in test_subreddits for res in results])
|
assert all([res.subreddit.display_name in test_subreddits for res in results])
|
||||||
|
assert len(results) <= max_expected_len
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.online
|
@pytest.mark.online
|
||||||
|
@ -210,15 +227,23 @@ def test_get_multireddits_public(
|
||||||
test_multireddits: list[str],
|
test_multireddits: list[str],
|
||||||
limit: int,
|
limit: int,
|
||||||
reddit_instance: praw.Reddit,
|
reddit_instance: praw.Reddit,
|
||||||
downloader_mock: MagicMock):
|
downloader_mock: MagicMock,
|
||||||
|
):
|
||||||
downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
|
downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
|
||||||
downloader_mock.sort_filter = RedditTypes.SortType.HOT
|
downloader_mock.sort_filter = RedditTypes.SortType.HOT
|
||||||
downloader_mock.args.limit = limit
|
downloader_mock.args.limit = limit
|
||||||
downloader_mock.args.multireddit = test_multireddits
|
downloader_mock.args.multireddit = test_multireddits
|
||||||
downloader_mock.args.user = test_user
|
downloader_mock.args.user = test_user
|
||||||
downloader_mock.reddit_instance = reddit_instance
|
downloader_mock.reddit_instance = reddit_instance
|
||||||
|
downloader_mock._create_filtered_listing_generator.return_value = \
|
||||||
|
RedditDownloader._create_filtered_listing_generator(
|
||||||
|
downloader_mock,
|
||||||
|
reddit_instance.multireddit(test_user, test_multireddits[0]),
|
||||||
|
)
|
||||||
results = RedditDownloader._get_multireddits(downloader_mock)
|
results = RedditDownloader._get_multireddits(downloader_mock)
|
||||||
assert_all_results_are_submissions((limit * len(test_multireddits)) if limit else None, results)
|
results = [sub for res in results for sub in res]
|
||||||
|
assert all([isinstance(res, praw.models.Submission) for res in results])
|
||||||
|
assert len(results) == limit
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.online
|
@pytest.mark.online
|
||||||
|
@ -236,6 +261,11 @@ def test_get_user_submissions(test_user: str, limit: int, downloader_mock: Magic
|
||||||
downloader_mock.args.user = test_user
|
downloader_mock.args.user = test_user
|
||||||
downloader_mock.authenticated = False
|
downloader_mock.authenticated = False
|
||||||
downloader_mock.reddit_instance = reddit_instance
|
downloader_mock.reddit_instance = reddit_instance
|
||||||
|
downloader_mock._create_filtered_listing_generator.return_value = \
|
||||||
|
RedditDownloader._create_filtered_listing_generator(
|
||||||
|
downloader_mock,
|
||||||
|
reddit_instance.redditor(test_user).submissions,
|
||||||
|
)
|
||||||
results = RedditDownloader._get_user_data(downloader_mock)
|
results = RedditDownloader._get_user_data(downloader_mock)
|
||||||
results = assert_all_results_are_submissions(limit, results)
|
results = assert_all_results_are_submissions(limit, results)
|
||||||
assert all([res.author.name == test_user for res in results])
|
assert all([res.author.name == test_user for res in results])
|
||||||
|
|
|
@ -101,7 +101,6 @@ def test_cli_download_multireddit_nonexistent(test_args: list[str], tmp_path: Pa
|
||||||
['--user', 'djnish', '--submitted', '-L', 10],
|
['--user', 'djnish', '--submitted', '-L', 10],
|
||||||
['--user', 'djnish', '--submitted', '-L', 10, '--time', 'month'],
|
['--user', 'djnish', '--submitted', '-L', 10, '--time', 'month'],
|
||||||
['--user', 'djnish', '--submitted', '-L', 10, '--sort', 'controversial'],
|
['--user', 'djnish', '--submitted', '-L', 10, '--sort', 'controversial'],
|
||||||
['--user', 'djnish', '--submitted', '-L', 10, '--sort', 'controversial', '--time', 'month'],
|
|
||||||
))
|
))
|
||||||
def test_cli_download_user_data_good(test_args: list[str], tmp_path: Path):
|
def test_cli_download_user_data_good(test_args: list[str], tmp_path: Path):
|
||||||
runner = CliRunner()
|
runner = CliRunner()
|
||||||
|
|
|
@ -1,9 +1,10 @@
|
||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
# coding=utf-8
|
# coding=utf-8
|
||||||
|
|
||||||
import pytest
|
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from bdfr.resource import Resource
|
from bdfr.resource import Resource
|
||||||
|
|
||||||
|
|
||||||
|
@ -17,6 +18,7 @@ from bdfr.resource import Resource
|
||||||
('https://preview.redd.it/7zkmr1wqqih61.png?width=237&format=png&auto=webp&s=19de214e634cbcad99', '.png'),
|
('https://preview.redd.it/7zkmr1wqqih61.png?width=237&format=png&auto=webp&s=19de214e634cbcad99', '.png'),
|
||||||
('test.jpg#test', '.jpg'),
|
('test.jpg#test', '.jpg'),
|
||||||
('test.jpg?width=247#test', '.jpg'),
|
('test.jpg?width=247#test', '.jpg'),
|
||||||
|
('https://www.test.com/test/test2/example.png?random=test#thing', '.png'),
|
||||||
))
|
))
|
||||||
def test_resource_get_extension(test_url: str, expected: str):
|
def test_resource_get_extension(test_url: str, expected: str):
|
||||||
test_resource = Resource(MagicMock(), test_url)
|
test_resource = Resource(MagicMock(), test_url)
|
||||||
|
|
Loading…
Reference in a new issue