1
0
Fork 0
mirror, last synced 2024-05-13 16:52:46 +12:00

Add ability to read IDs from files

This commit is contained in:
Serene-Arc 2021-07-05 16:58:33 +10:00
parent b58eebb51f
commit 1a4ff07f78
5 changed files with 34 additions and 13 deletions

View file

@ -6,9 +6,9 @@ import sys
import click import click
from bdfr.archiver import Archiver from bdfr.archiver import Archiver
from bdfr.cloner import RedditCloner
from bdfr.configuration import Configuration from bdfr.configuration import Configuration
from bdfr.downloader import RedditDownloader from bdfr.downloader import RedditDownloader
from bdfr.cloner import RedditCloner
logger = logging.getLogger() logger = logging.getLogger()
@ -17,6 +17,7 @@ _common_options = [
click.option('--authenticate', is_flag=True, default=None), click.option('--authenticate', is_flag=True, default=None),
click.option('--config', type=str, default=None), click.option('--config', type=str, default=None),
click.option('--disable-module', multiple=True, default=None, type=str), click.option('--disable-module', multiple=True, default=None, type=str),
click.option('--include-id-file', multiple=True, default=None),
click.option('--log', type=str, default=None), click.option('--log', type=str, default=None),
click.option('--saved', is_flag=True, default=None), click.option('--saved', is_flag=True, default=None),
click.option('--search', default=None, type=str), click.option('--search', default=None, type=str),
@ -26,12 +27,12 @@ _common_options = [
click.option('-L', '--limit', default=None, type=int), click.option('-L', '--limit', default=None, type=int),
click.option('-l', '--link', multiple=True, default=None, type=str), click.option('-l', '--link', multiple=True, default=None, type=str),
click.option('-m', '--multireddit', multiple=True, default=None, type=str), click.option('-m', '--multireddit', multiple=True, default=None, type=str),
click.option('-S', '--sort', type=click.Choice(('hot', 'top', 'new', 'controversial', 'rising', 'relevance')),
default=None),
click.option('-s', '--subreddit', multiple=True, default=None, type=str), click.option('-s', '--subreddit', multiple=True, default=None, type=str),
click.option('-v', '--verbose', default=None, count=True),
click.option('-u', '--user', type=str, multiple=True, default=None),
click.option('-t', '--time', type=click.Choice(('all', 'hour', 'day', 'week', 'month', 'year')), default=None), click.option('-t', '--time', type=click.Choice(('all', 'hour', 'day', 'week', 'month', 'year')), default=None),
click.option('-S', '--sort', type=click.Choice(('hot', 'top', 'new', click.option('-u', '--user', type=str, multiple=True, default=None),
'controversial', 'rising', 'relevance')), default=None), click.option('-v', '--verbose', default=None, count=True),
] ]
_downloader_options = [ _downloader_options = [

View file

@ -18,6 +18,7 @@ class Configuration(Namespace):
self.exclude_id_file = [] self.exclude_id_file = []
self.file_scheme: str = '{REDDITOR}_{TITLE}_{POSTID}' self.file_scheme: str = '{REDDITOR}_{TITLE}_{POSTID}'
self.folder_scheme: str = '{SUBREDDIT}' self.folder_scheme: str = '{SUBREDDIT}'
self.include_id_file = []
self.limit: Optional[int] = None self.limit: Optional[int] = None
self.link: list[str] = [] self.link: list[str] = []
self.log: Optional[str] = None self.log: Optional[str] = None

View file

@ -3,6 +3,7 @@
import configparser import configparser
import importlib.resources import importlib.resources
import itertools
import logging import logging
import logging.handlers import logging.handlers
import re import re
@ -78,7 +79,12 @@ class RedditConnector(metaclass=ABCMeta):
self.create_reddit_instance() self.create_reddit_instance()
self.args.user = list(filter(None, [self.resolve_user_name(user) for user in self.args.user])) self.args.user = list(filter(None, [self.resolve_user_name(user) for user in self.args.user]))
self.excluded_submission_ids = self.read_excluded_ids() self.excluded_submission_ids = set.union(
self.read_id_files(self.args.exclude_id_file),
set(self.args.exclude_id),
)
self.args.link = list(itertools.chain(self.args.link, self.read_id_files(self.args.include_id_file)))
self.master_hash_list = {} self.master_hash_list = {}
self.authenticator = self.create_authenticator() self.authenticator = self.create_authenticator()
@ -403,13 +409,13 @@ class RedditConnector(metaclass=ABCMeta):
except prawcore.Forbidden: except prawcore.Forbidden:
raise errors.BulkDownloaderException(f'Source {subreddit.display_name} is private and cannot be scraped') raise errors.BulkDownloaderException(f'Source {subreddit.display_name} is private and cannot be scraped')
def read_excluded_ids(self) -> set[str]: @staticmethod
def read_id_files(file_locations: list[str]) -> set[str]:
out = [] out = []
out.extend(self.args.exclude_id) for id_file in file_locations:
for id_file in self.args.exclude_id_file:
id_file = Path(id_file).resolve().expanduser() id_file = Path(id_file).resolve().expanduser()
if not id_file.exists(): if not id_file.exists():
logger.warning(f'ID exclusion file at {id_file} does not exist') logger.warning(f'ID file at {id_file} does not exist')
continue continue
with open(id_file, 'r') as file: with open(id_file, 'r') as file:
for line in file: for line in file:

View file

@ -306,3 +306,17 @@ def test_cli_download_disable_modules(test_args: list[str], tmp_path: Path):
assert result.exit_code == 0 assert result.exit_code == 0
assert 'skipped due to disabled module' in result.output assert 'skipped due to disabled module' in result.output
assert 'Downloaded submission' not in result.output assert 'Downloaded submission' not in result.output
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests')
def test_cli_download_include_id_file(tmp_path: Path):
    """Run the download CLI with ``--include-id-file`` and expect a successful download.

    A two-line ID file is written into ``tmp_path`` and handed to the CLI via
    ``--include-id-file``. The run must exit with code 0 and report
    'Downloaded submission' in its output.
    """
    # One ID per line, no trailing newline after the last entry.
    id_file = tmp_path / 'include.txt'
    id_file.write_text('odr9wg\nody576')

    cli_args = create_basic_args_for_download_runner(['--include-id-file', str(id_file)], tmp_path)
    outcome = CliRunner().invoke(cli, cli_args)

    assert outcome.exit_code == 0
    assert 'Downloaded submission' in outcome.output

View file

@ -339,11 +339,10 @@ def test_split_subreddit_entries(test_subreddit_entries: list[str], expected: se
assert results == expected assert results == expected
def test_read_excluded_submission_ids_from_file(downloader_mock: MagicMock, tmp_path: Path): def test_read_submission_ids_from_file(downloader_mock: MagicMock, tmp_path: Path):
test_file = tmp_path / 'test.txt' test_file = tmp_path / 'test.txt'
test_file.write_text('aaaaaa\nbbbbbb') test_file.write_text('aaaaaa\nbbbbbb')
downloader_mock.args.exclude_id_file = [test_file] results = RedditConnector.read_id_files([str(test_file)])
results = RedditConnector.read_excluded_ids(downloader_mock)
assert results == {'aaaaaa', 'bbbbbb'} assert results == {'aaaaaa', 'bbbbbb'}