diff --git a/bdfr/__main__.py b/bdfr/__main__.py index 67e4f99..367f8c6 100644 --- a/bdfr/__main__.py +++ b/bdfr/__main__.py @@ -6,9 +6,9 @@ import sys import click from bdfr.archiver import Archiver +from bdfr.cloner import RedditCloner from bdfr.configuration import Configuration from bdfr.downloader import RedditDownloader -from bdfr.cloner import RedditCloner logger = logging.getLogger() @@ -17,6 +17,7 @@ _common_options = [ click.option('--authenticate', is_flag=True, default=None), click.option('--config', type=str, default=None), click.option('--disable-module', multiple=True, default=None, type=str), + click.option('--include-id-file', multiple=True, default=None), click.option('--log', type=str, default=None), click.option('--saved', is_flag=True, default=None), click.option('--search', default=None, type=str), @@ -26,12 +27,12 @@ _common_options = [ click.option('-L', '--limit', default=None, type=int), click.option('-l', '--link', multiple=True, default=None, type=str), click.option('-m', '--multireddit', multiple=True, default=None, type=str), + click.option('-S', '--sort', type=click.Choice(('hot', 'top', 'new', 'controversial', 'rising', 'relevance')), + default=None), click.option('-s', '--subreddit', multiple=True, default=None, type=str), - click.option('-v', '--verbose', default=None, count=True), - click.option('-u', '--user', type=str, multiple=True, default=None), click.option('-t', '--time', type=click.Choice(('all', 'hour', 'day', 'week', 'month', 'year')), default=None), - click.option('-S', '--sort', type=click.Choice(('hot', 'top', 'new', - 'controversial', 'rising', 'relevance')), default=None), + click.option('-u', '--user', type=str, multiple=True, default=None), + click.option('-v', '--verbose', default=None, count=True), ] _downloader_options = [ diff --git a/bdfr/configuration.py b/bdfr/configuration.py index 36a1860..bc4c541 100644 --- a/bdfr/configuration.py +++ b/bdfr/configuration.py @@ -18,6 +18,7 @@ class Configuration(Namespace): self.exclude_id_file = [] self.file_scheme: str = '{REDDITOR}_{TITLE}_{POSTID}' self.folder_scheme: str = '{SUBREDDIT}' + self.include_id_file = [] self.limit: Optional[int] = None self.link: list[str] = [] self.log: Optional[str] = None diff --git a/bdfr/connector.py b/bdfr/connector.py index 0e78c8c..a379847 100644 --- a/bdfr/connector.py +++ b/bdfr/connector.py @@ -3,6 +3,7 @@ import configparser import importlib.resources +import itertools import logging import logging.handlers import re @@ -78,7 +79,12 @@ class RedditConnector(metaclass=ABCMeta): self.create_reddit_instance() self.args.user = list(filter(None, [self.resolve_user_name(user) for user in self.args.user])) - self.excluded_submission_ids = self.read_excluded_ids() + self.excluded_submission_ids = set.union( + self.read_id_files(self.args.exclude_id_file), + set(self.args.exclude_id), + ) + + self.args.link = list(itertools.chain(self.args.link, self.read_id_files(self.args.include_id_file))) self.master_hash_list = {} self.authenticator = self.create_authenticator() @@ -403,13 +409,13 @@ class RedditConnector(metaclass=ABCMeta): except prawcore.Forbidden: raise errors.BulkDownloaderException(f'Source {subreddit.display_name} is private and cannot be scraped') - def read_excluded_ids(self) -> set[str]: + @staticmethod + def read_id_files(file_locations: list[str]) -> set[str]: out = [] - out.extend(self.args.exclude_id) - for id_file in self.args.exclude_id_file: + for id_file in file_locations: id_file = Path(id_file).resolve().expanduser() if not id_file.exists(): - logger.warning(f'ID exclusion file at {id_file} does not exist') + logger.warning(f'ID file at {id_file} does not exist') continue with open(id_file, 'r') as file: for line in file: diff --git a/tests/integration_tests/test_download_integration.py b/tests/integration_tests/test_download_integration.py index 305fe99..cb4a273 100644 --- a/tests/integration_tests/test_download_integration.py +++ b/tests/integration_tests/test_download_integration.py @@ -306,3 +306,17 @@ def test_cli_download_disable_modules(test_args: list[str], tmp_path: Path): assert result.exit_code == 0 assert 'skipped due to disabled module' in result.output assert 'Downloaded submission' not in result.output + + +@pytest.mark.online +@pytest.mark.reddit +@pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests') +def test_cli_download_include_id_file(tmp_path: Path): + test_file = Path(tmp_path, 'include.txt') + test_args = ['--include-id-file', str(test_file)] + test_file.write_text('odr9wg\nody576') + runner = CliRunner() + test_args = create_basic_args_for_download_runner(test_args, tmp_path) + result = runner.invoke(cli, test_args) + assert result.exit_code == 0 + assert 'Downloaded submission' in result.output diff --git a/tests/test_connector.py b/tests/test_connector.py index 15eede1..2dd76f9 100644 --- a/tests/test_connector.py +++ b/tests/test_connector.py @@ -339,11 +339,10 @@ def test_split_subreddit_entries(test_subreddit_entries: list[str], expected: se assert results == expected -def test_read_excluded_submission_ids_from_file(downloader_mock: MagicMock, tmp_path: Path): +def test_read_submission_ids_from_file(downloader_mock: MagicMock, tmp_path: Path): test_file = tmp_path / 'test.txt' test_file.write_text('aaaaaa\nbbbbbb') - downloader_mock.args.exclude_id_file = [test_file] - results = RedditConnector.read_excluded_ids(downloader_mock) + results = RedditConnector.read_id_files([str(test_file)]) assert results == {'aaaaaa', 'bbbbbb'}