diff --git a/bdfr/download_filter.py b/bdfr/download_filter.py index 37a6ce9..3bbbdec 100644 --- a/bdfr/download_filter.py +++ b/bdfr/download_filter.py @@ -4,6 +4,8 @@ import logging import re +from bdfr.resource import Resource + logger = logging.getLogger(__name__) @@ -21,13 +23,20 @@ class DownloadFilter: else: return True - def _check_extension(self, url: str) -> bool: + def check_resource(self, res: Resource) -> bool: + if not self._check_extension(res.extension): + return False + elif not self._check_domain(res.url): + return False + return True + + def _check_extension(self, resource_extension: str) -> bool: if not self.excluded_extensions: return True combined_extensions = '|'.join(self.excluded_extensions) pattern = re.compile(r'.*({})$'.format(combined_extensions)) - if re.match(pattern, url): - logger.log(9, f'Url "{url}" matched with "{str(pattern)}"') + if re.match(pattern, resource_extension): + logger.log(9, f'Url "{resource_extension}" matched with "{str(pattern)}"') return False else: return True diff --git a/bdfr/downloader.py b/bdfr/downloader.py index 3348628..f0b1977 100644 --- a/bdfr/downloader.py +++ b/bdfr/downloader.py @@ -394,9 +394,6 @@ class RedditDownloader: if not isinstance(submission, praw.models.Submission): logger.warning(f'{submission.id} is not a submission') return - if not self.download_filter.check_url(submission.url): - logger.debug(f'Download filter removed submission {submission.id} with URL {submission.url}') - return try: downloader_class = DownloadFactory.pull_lever(submission.url) downloader = downloader_class(submission) @@ -413,6 +410,8 @@ class RedditDownloader: for destination, res in self.file_name_formatter.format_resource_paths(content, self.download_directory): if destination.exists(): logger.debug(f'File {destination} already exists, continuing') + elif not self.download_filter.check_resource(res): + logger.debug(f'Download filter removed {submission.id} with URL {submission.url}') else: try: res.download(self.args.max_wait_time) diff --git a/tests/test_download_filter.py b/tests/test_download_filter.py index 3c2adba..ead2b2f 100644 --- a/tests/test_download_filter.py +++ b/tests/test_download_filter.py @@ -1,9 +1,12 @@ #!/usr/bin/env python3 # coding=utf-8 +from unittest.mock import MagicMock + import pytest from bdfr.download_filter import DownloadFilter +from bdfr.resource import Resource @pytest.fixture() @@ -11,13 +14,14 @@ def download_filter() -> DownloadFilter: return DownloadFilter(['mp4', 'mp3'], ['test.com', 'reddit.com']) -@pytest.mark.parametrize(('test_url', 'expected'), ( - ('test.mp4', False), - ('test.avi', True), - ('test.random.mp3', False), +@pytest.mark.parametrize(('test_extension', 'expected'), ( + ('.mp4', False), + ('.avi', True), + ('.random.mp3', False), + ('mp4', False), )) -def test_filter_extension(test_url: str, expected: bool, download_filter: DownloadFilter): - result = download_filter._check_extension(test_url) +def test_filter_extension(test_extension: str, expected: bool, download_filter: DownloadFilter): + result = download_filter._check_extension(test_extension) assert result == expected @@ -42,7 +46,8 @@ def test_filter_domain(test_url: str, expected: bool, download_filter: DownloadF ('http://reddit.com/test.gif', False), )) def test_filter_all(test_url: str, expected: bool, download_filter: DownloadFilter): - result = download_filter.check_url(test_url) + test_resource = Resource(MagicMock(), test_url) + result = download_filter.check_resource(test_resource) assert result == expected @@ -54,5 +59,6 @@ def test_filter_all(test_url: str, expected: bool, download_filter: DownloadFilt )) def test_filter_empty_filter(test_url: str): download_filter = DownloadFilter() - result = download_filter.check_url(test_url) + test_resource = Resource(MagicMock(), test_url) + result = download_filter.check_resource(test_resource) assert result is True diff --git a/tests/test_integration.py b/tests/test_integration.py index 003a465..6345a7c 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -161,13 +161,14 @@ def test_cli_download_search_existing(test_args: list[str], tmp_path: Path): @pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests') @pytest.mark.parametrize('test_args', ( ['--subreddit', 'tumblr', '-L', '25', '--skip', 'png', '--skip', 'jpg'], + ['--subreddit', 'MaliciousCompliance', '-L', '25', '--skip', 'txt'], )) def test_cli_download_download_filters(test_args: list[str], tmp_path: Path): runner = CliRunner() test_args = create_basic_args_for_download_runner(test_args, tmp_path) result = runner.invoke(cli, test_args) assert result.exit_code == 0 - assert 'Download filter removed submission' in result.output + assert 'Download filter removed ' in result.output @pytest.mark.online