Add download filter class
This commit is contained in:
parent
f2415b6bd0
commit
b1f0632a80
39
bulkredditdownloader/download_filter.py
Normal file
39
bulkredditdownloader/download_filter.py
Normal file
|
@ -0,0 +1,39 @@
|
|||
#!/usr/bin/env python3
|
||||
# coding=utf-8
|
||||
|
||||
import re
|
||||
|
||||
|
||||
class DownloadFilter:
|
||||
def __init__(self, excluded_extensions: list[str] = None, excluded_domains: list[str] = None):
|
||||
self.excluded_extensions = excluded_extensions
|
||||
self.excluded_domains = excluded_domains
|
||||
|
||||
def check_url(self, url: str) -> bool:
|
||||
"""Return whether a URL is allowed or not"""
|
||||
if not self._check_extension(url):
|
||||
return False
|
||||
elif not self._check_domain(url):
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
def _check_extension(self, url: str) -> bool:
|
||||
if not self.excluded_extensions:
|
||||
return True
|
||||
combined_extensions = '|'.join(self.excluded_extensions)
|
||||
pattern = re.compile(r'.*({})$'.format(combined_extensions))
|
||||
if re.match(pattern, url):
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
def _check_domain(self, url: str) -> bool:
|
||||
if not self.excluded_domains:
|
||||
return True
|
||||
combined_domains = '|'.join(self.excluded_domains)
|
||||
pattern = re.compile(r'https?://.*({}).*'.format(combined_domains))
|
||||
if re.match(pattern, url):
|
||||
return False
|
||||
else:
|
||||
return True
|
54
bulkredditdownloader/tests/test_download_filter.py
Normal file
54
bulkredditdownloader/tests/test_download_filter.py
Normal file
|
@ -0,0 +1,54 @@
|
|||
#!/usr/bin/env python3
|
||||
# coding=utf-8
|
||||
|
||||
import pytest
|
||||
|
||||
from bulkredditdownloader.download_filter import DownloadFilter
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def download_filter() -> DownloadFilter:
|
||||
return DownloadFilter(['mp4', 'mp3'], ['test.com', 'reddit.com'])
|
||||
|
||||
|
||||
@pytest.mark.parametrize(('test_url', 'expected'), (('test.mp4', False),
|
||||
('test.avi', True),
|
||||
('test.random.mp3', False)
|
||||
))
|
||||
def test_filter_extension(test_url: str, expected: bool, download_filter: DownloadFilter):
|
||||
result = download_filter._check_extension(test_url)
|
||||
assert result == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(('test_url', 'expected'), (('test.mp4', True),
|
||||
('http://reddit.com/test.mp4', False),
|
||||
('http://reddit.com/test.gif', False),
|
||||
('https://www.example.com/test.mp4', True),
|
||||
('https://www.example.com/test.png', True),
|
||||
))
|
||||
def test_filter_domain(test_url: str, expected: bool, download_filter: DownloadFilter):
|
||||
result = download_filter._check_domain(test_url)
|
||||
assert result == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(('test_url', 'expected'), (('test.mp4', False),
|
||||
('test.gif', True),
|
||||
('https://www.example.com/test.mp4', False),
|
||||
('https://www.example.com/test.png', True),
|
||||
('http://reddit.com/test.mp4', False),
|
||||
('http://reddit.com/test.gif', False),
|
||||
))
|
||||
def test_filter_all(test_url: str, expected: bool, download_filter: DownloadFilter):
|
||||
result = download_filter.check_url(test_url)
|
||||
assert result == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize('test_url', ('test.mp3',
|
||||
'test.mp4',
|
||||
'http://reddit.com/test.mp4',
|
||||
't',
|
||||
))
|
||||
def test_filter_empty_filter(test_url: str):
|
||||
download_filter = DownloadFilter()
|
||||
result = download_filter.check_url(test_url)
|
||||
assert result is True
|
Loading…
Reference in a new issue