diff --git a/.gitignore b/.gitignore index 7207598..3918aa5 100644 --- a/.gitignore +++ b/.gitignore @@ -139,3 +139,6 @@ cython_debug/ # Test configuration file test_config.cfg + +.vscode/ +.idea/ \ No newline at end of file diff --git a/README.md b/README.md index 15d50c2..a06e4af 100644 --- a/README.md +++ b/README.md @@ -27,16 +27,24 @@ If you want to use the source code or make contributions, refer to [CONTRIBUTING The BDFR works by taking submissions from a variety of "sources" from Reddit and then parsing them to download. These sources might be a subreddit, multireddit, a user list, or individual links. These sources are combined and downloaded to disk, according to a naming and organisational scheme defined by the user. -There are two modes to the BDFR: download, and archive. Each one has a command that performs similar but distinct functions. The `download` command will download the resource linked in the Reddit submission, such as the images, video, etc. The `archive` command will download the submission data itself and store it, such as the submission details, upvotes, text, statistics, as and all the comments on that submission. These can then be saved in a data markup language form, such as JSON, XML, or YAML. +There are three modes to the BDFR: download, archive, and clone. Each one has a command that performs similar but distinct functions. The `download` command will download the resource linked in the Reddit submission, such as the images, video, etc. The `archive` command will download the submission data itself and store it, such as the submission details, upvotes, text, statistics, as well as all the comments on that submission. These can then be saved in a data markup language form, such as JSON, XML, or YAML. Lastly, the `clone` command will perform both functions of the previous commands at once and is more efficient than running those commands sequentially. + +Note that the `clone` command is not a true, faithful clone of Reddit. It simply retrieves much of the raw data that Reddit provides. To get a true clone of Reddit, another tool such as HTTrack should be used. After installation, run the program from any directory as shown below: + ```bash python3 -m bdfr download ``` + ```bash python3 -m bdfr archive ``` +```bash +python3 -m bdfr clone +``` + However, these commands are not enough. You should chain parameters in [Options](#options) according to your use case. Don't forget that some parameters can be provided multiple times.
Some quick reference commands are: ```bash @@ -64,6 +72,10 @@ The following options are common between both the `archive` and `download` comma - `--config` - If the path to a configuration file is supplied with this option, the BDFR will use the specified config - See [Configuration Files](#configuration) for more details +- `--disable-module` + - Can be specified multiple times + - Disables certain modules from being used + - See [Disabling Modules](#disabling-modules) for more information and a list of module names - `--log` - This allows one to specify the location of the logfile - This must be done when running multiple instances of the BDFR, see [Multiple Instances](#multiple-instances) below @@ -124,6 +136,8 @@ The following options are common between both the `archive` and `download` comma - `-u, --user` - This specifies the user to scrape in concert with other options - When using `--authenticate`, `--user me` can be used to refer to the authenticated user + - Can be specified multiple times for multiple users + - If downloading a multireddit, only one user can be specified - `-v, --verbose` - Increases the verbosity of the program - Can be specified multiple times @@ -132,13 +146,6 @@ The following options are common between both the `archive` and `download` comma The following options apply only to the `download` command. This command downloads the files and resources linked to in the submission, or a text submission itself, to the disk in the specified directory. -- `--exclude-id` - - This will skip the download of any submission with the ID provided - - Can be specified multiple times -- `--exclude-id-file` - - This will skip the download of any submission with any of the IDs in the files provided - - Can be specified multiple times - - Format is one ID per line - `--make-hard-links` - This flag will create hard links to an existing file when a duplicate is downloaded - This will make the file appear in multiple directories while only taking the space of a single instance @@ -159,6 +166,13 @@ The following options apply only to the `download` command. This command downloa - Sets the scheme for folders - Default is `{SUBREDDIT}` - See [Folder and File Name Schemes](#folder-and-file-name-schemes) for more details +- `--exclude-id` + - This will skip the download of any submission with the ID provided + - Can be specified multiple times +- `--exclude-id-file` + - This will skip the download of any submission with any of the IDs in the files provided + - Can be specified multiple times + - Format is one ID per line - `--skip-domain` - This adds domains to the download filter i.e. submissions coming from these domains will not be downloaded - Can be specified multiple times @@ -183,6 +197,10 @@ The following options are for the `archive` command specifically. - `xml` - `yaml` +### Cloner Options + +The `clone` command can take all the options listed above for both the `archive` and `download` commands since it performs the functions of both. + ## Authentication and Security The BDFR uses OAuth2 authentication to connect to Reddit if authentication is required. This means that it is a secure, token-based system for making requests. This also means that the BDFR only has access to specific parts of the account authenticated, by default only saved posts, upvoted posts, and the identity of the authenticated account. Note that authentication is not required unless accessing private things like upvoted posts, saved posts, and private multireddits. 
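+As a brief illustration of the options and authentication described above, a hypothetical invocation that downloads the authenticated account's saved posts might look like the following; the target directory `./reddit-saved` is an arbitrary example: + +```bash +# Hypothetical example: the target directory name is arbitrary +python3 -m bdfr download ./reddit-saved --authenticate --user me --saved +```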
@@ -253,6 +271,7 @@ The following keys are optional, and defaults will be used if they cannot be fou - `backup_log_count` - `max_wait_time` - `time_format` + - `disabled_modules` All of these should not be modified unless you know what you're doing, as the default values will enable the BDFR to function just fine. A configuration is included in the BDFR when it is installed, and this will be placed in the configuration directory as the default. @@ -264,6 +283,22 @@ The option `time_format` will specify the format of the timestamp that replaces The format can be specified through the [format codes](https://docs.python.org/3/library/datetime.html#strftime-strptime-behavior) that are standard in the Python `datetime` library. +#### Disabling Modules + +The individual modules of the BDFR, used to download submissions from websites, can be disabled. This is especially helpful in the case of the fallback downloaders, since the `--skip-domain` option cannot be effectively used in these cases. For example, the Youtube-DL downloader can retrieve data from hundreds of websites and domains; thus the only way to fully disable it is via the `--disable-module` option. + +Modules can be disabled through the command line interface for the BDFR or more permanently in the configuration file via the `disabled_modules` option. The downloaders that can be disabled are the following; a worked example is given after the next section. Note that the names are case-insensitive. + +- `Direct` +- `Erome` +- `Gallery` (Reddit Image Galleries) +- `Gfycat` +- `Imgur` +- `Redgifs` +- `SelfPost` (Reddit Text Post) +- `Youtube` +- `YoutubeDlFallback` + ### Rate Limiting The option `max_wait_time` has to do with retrying downloads. There are certain HTTP errors that mean that no amount of requests will return the wanted data, but some errors are from rate-limiting. This is when a single client is making so many requests that the remote website cuts the client off to preserve the function of the site. This is a common situation when downloading many resources from the same site. It is polite and best practice to obey the website's wishes in these cases.
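+To tie the two sections above together, a hypothetical invocation might raise the retry ceiling for a rate-limited site while disabling the Youtube-DL fallback; the directory, subreddit, and wait time here are arbitrary examples rather than recommendations, and `--max-wait-time` is the command-line counterpart of the `max_wait_time` configuration key: + +```bash +# Hypothetical example: arbitrary directory, subreddit, and wait time (in seconds) +python3 -m bdfr download ./downloads --subreddit EarthPorn --max-wait-time 600 --disable-module YoutubeDlFallback +```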
diff --git a/bdfr/__main__.py b/bdfr/__main__.py index 372c7c3..1103581 100644 --- a/bdfr/__main__.py +++ b/bdfr/__main__.py @@ -8,35 +8,58 @@ import click from bdfr.archiver import Archiver from bdfr.configuration import Configuration from bdfr.downloader import RedditDownloader +from bdfr.cloner import RedditCloner logger = logging.getLogger() _common_options = [ click.argument('directory', type=str), - click.option('--config', type=str, default=None), - click.option('-v', '--verbose', default=None, count=True), - click.option('-l', '--link', multiple=True, default=None, type=str), - click.option('-s', '--subreddit', multiple=True, default=None, type=str), - click.option('-m', '--multireddit', multiple=True, default=None, type=str), - click.option('-L', '--limit', default=None, type=int), click.option('--authenticate', is_flag=True, default=None), + click.option('--config', type=str, default=None), + click.option('--disable-module', multiple=True, default=None, type=str), click.option('--log', type=str, default=None), - click.option('--submitted', is_flag=True, default=None), - click.option('--upvoted', is_flag=True, default=None), click.option('--saved', is_flag=True, default=None), click.option('--search', default=None, type=str), + click.option('--submitted', is_flag=True, default=None), click.option('--time-format', type=str, default=None), - click.option('-u', '--user', type=str, default=None), + click.option('--upvoted', is_flag=True, default=None), + click.option('-L', '--limit', default=None, type=int), + click.option('-l', '--link', multiple=True, default=None, type=str), + click.option('-m', '--multireddit', multiple=True, default=None, type=str), + click.option('-s', '--subreddit', multiple=True, default=None, type=str), + click.option('-v', '--verbose', default=None, count=True), + click.option('-u', '--user', type=str, multiple=True, default=None), click.option('-t', '--time', type=click.Choice(('all', 'hour', 'day', 'week', 'month', 'year')), default=None), click.option('-S', '--sort', type=click.Choice(('hot', 'top', 'new', 'controversial', 'rising', 'relevance')), default=None), ] +_downloader_options = [ + click.option('--file-scheme', default=None, type=str), + click.option('--folder-scheme', default=None, type=str), + click.option('--make-hard-links', is_flag=True, default=None), + click.option('--max-wait-time', type=int, default=None), + click.option('--no-dupes', is_flag=True, default=None), + click.option('--search-existing', is_flag=True, default=None), + click.option('--exclude-id', default=None, multiple=True), + click.option('--exclude-id-file', default=None, multiple=True), + click.option('--skip', default=None, multiple=True), + click.option('--skip-domain', default=None, multiple=True), + click.option('--skip-subreddit', default=None, multiple=True), +] -def _add_common_options(func): - for opt in _common_options: - func = opt(func) - return func +_archiver_options = [ + click.option('--all-comments', is_flag=True, default=None), + click.option('-f', '--format', type=click.Choice(('xml', 'json', 'yaml')), default=None), +] + + +def _add_options(opts: list): + def wrap(func): + for opt in opts: + func = opt(func) + return func + return wrap @click.group() @@ -45,18 +68,8 @@ def cli(): @cli.command('download') -@click.option('--exclude-id', default=None, multiple=True) -@click.option('--exclude-id-file', default=None, multiple=True) -@click.option('--file-scheme', default=None, type=str) -@click.option('--folder-scheme', default=None, type=str) 
-@click.option('--make-hard-links', is_flag=True, default=None) -@click.option('--max-wait-time', type=int, default=None) -@click.option('--no-dupes', is_flag=True, default=None) -@click.option('--search-existing', is_flag=True, default=None) -@click.option('--skip', default=None, multiple=True) -@click.option('--skip-domain', default=None, multiple=True) -@click.option('--skip-subreddit', default=None, multiple=True) -@_add_common_options +@_add_options(_common_options) +@_add_options(_downloader_options) @click.pass_context def cli_download(context: click.Context, **_): config = Configuration() @@ -73,9 +86,8 @@ def cli_download(context: click.Context, **_): @cli.command('archive') -@_add_common_options -@click.option('--all-comments', is_flag=True, default=None) -@click.option('-f', '--format', type=click.Choice(('xml', 'json', 'yaml')), default=None) +@_add_options(_common_options) +@_add_options(_archiver_options) @click.pass_context def cli_archive(context: click.Context, **_): config = Configuration() @@ -85,7 +97,26 @@ def cli_archive(context: click.Context, **_): reddit_archiver = Archiver(config) reddit_archiver.download() except Exception: - logger.exception('Downloader exited unexpectedly') + logger.exception('Archiver exited unexpectedly') raise else: logger.info('Program complete') + + +@cli.command('clone') +@_add_options(_common_options) +@_add_options(_archiver_options) +@_add_options(_downloader_options) +@click.pass_context +def cli_clone(context: click.Context, **_): + config = Configuration() + config.process_click_arguments(context) + setup_logging(config.verbose) + try: + reddit_cloner = RedditCloner(config) + reddit_cloner.download() + except Exception: + logger.exception('Cloner exited unexpectedly') + raise + else: + logger.info('Program complete') diff --git a/bdfr/archive_entry/base_archive_entry.py b/bdfr/archive_entry/base_archive_entry.py index 775ed68..7b84fbe 100644 --- a/bdfr/archive_entry/base_archive_entry.py +++ b/bdfr/archive_entry/base_archive_entry.py @@ -26,6 +26,7 @@ class BaseArchiveEntry(ABC): 'stickied': in_comment.stickied, 'body': in_comment.body, 'is_submitter': in_comment.is_submitter, + 'distinguished': in_comment.distinguished, 'created_utc': in_comment.created_utc, 'parent_id': in_comment.parent_id, 'replies': [], diff --git a/bdfr/archive_entry/submission_archive_entry.py b/bdfr/archive_entry/submission_archive_entry.py index aaa423b..538aea8 100644 --- a/bdfr/archive_entry/submission_archive_entry.py +++ b/bdfr/archive_entry/submission_archive_entry.py @@ -35,6 +35,10 @@ class SubmissionArchiveEntry(BaseArchiveEntry): 'link_flair_text': self.source.link_flair_text, 'num_comments': self.source.num_comments, 'over_18': self.source.over_18, + 'spoiler': self.source.spoiler, + 'pinned': self.source.pinned, + 'locked': self.source.locked, + 'distinguished': self.source.distinguished, 'created_utc': self.source.created_utc, } diff --git a/bdfr/archiver.py b/bdfr/archiver.py index 1945dfe..b19a042 100644 --- a/bdfr/archiver.py +++ b/bdfr/archiver.py @@ -14,14 +14,14 @@ from bdfr.archive_entry.base_archive_entry import BaseArchiveEntry from bdfr.archive_entry.comment_archive_entry import CommentArchiveEntry from bdfr.archive_entry.submission_archive_entry import SubmissionArchiveEntry from bdfr.configuration import Configuration -from bdfr.downloader import RedditDownloader +from bdfr.connector import RedditConnector from bdfr.exceptions import ArchiverError from bdfr.resource import Resource logger = logging.getLogger(__name__) -class
Archiver(RedditDownloader): +class Archiver(RedditConnector): def __init__(self, args: Configuration): super(Archiver, self).__init__(args) @@ -29,9 +29,9 @@ class Archiver(RedditDownloader): for generator in self.reddit_lists: for submission in generator: logger.debug(f'Attempting to archive submission {submission.id}') - self._write_entry(submission) + self.write_entry(submission) - def _get_submissions_from_link(self) -> list[list[praw.models.Submission]]: + def get_submissions_from_link(self) -> list[list[praw.models.Submission]]: supplied_submissions = [] for sub_id in self.args.link: if len(sub_id) == 6: @@ -42,12 +42,13 @@ class Archiver(RedditDownloader): supplied_submissions.append(self.reddit_instance.submission(url=sub_id)) return [supplied_submissions] - def _get_user_data(self) -> list[Iterator]: - results = super(Archiver, self)._get_user_data() + def get_user_data(self) -> list[Iterator]: + results = super(Archiver, self).get_user_data() if self.args.user and self.args.all_comments: - sort = self._determine_sort_function() - logger.debug(f'Retrieving comments of user {self.args.user}') - results.append(sort(self.reddit_instance.redditor(self.args.user).comments, limit=self.args.limit)) + sort = self.determine_sort_function() + for user in self.args.user: + logger.debug(f'Retrieving comments of user {user}') + results.append(sort(self.reddit_instance.redditor(user).comments, limit=self.args.limit)) return results @staticmethod @@ -59,7 +60,7 @@ class Archiver(RedditDownloader): else: raise ArchiverError(f'Factory failed to classify item of type {type(praw_item).__name__}') - def _write_entry(self, praw_item: (praw.models.Submission, praw.models.Comment)): + def write_entry(self, praw_item: (praw.models.Submission, praw.models.Comment)): archive_entry = self._pull_lever_entry_factory(praw_item) if self.args.format == 'json': self._write_entry_json(archive_entry) diff --git a/bdfr/cloner.py b/bdfr/cloner.py new file mode 100644 index 0000000..979f50f --- /dev/null +++ b/bdfr/cloner.py @@ -0,0 +1,21 @@ +#!/usr/bin/env python3 +# coding=utf-8 + +import logging + +from bdfr.archiver import Archiver +from bdfr.configuration import Configuration +from bdfr.downloader import RedditDownloader + +logger = logging.getLogger(__name__) + + +class RedditCloner(RedditDownloader, Archiver): + def __init__(self, args: Configuration): + super(RedditCloner, self).__init__(args) + + def download(self): + for generator in self.reddit_lists: + for submission in generator: + self._download_submission(submission) + self.write_entry(submission) diff --git a/bdfr/configuration.py b/bdfr/configuration.py index 9ab9d45..327a453 100644 --- a/bdfr/configuration.py +++ b/bdfr/configuration.py @@ -13,19 +13,21 @@ class Configuration(Namespace): self.authenticate = False self.config = None self.directory: str = '.' 
+ self.disable_module: list[str] = [] self.exclude_id = [] self.exclude_id_file = [] + self.file_scheme: str = '{REDDITOR}_{TITLE}_{POSTID}' + self.folder_scheme: str = '{SUBREDDIT}' self.limit: Optional[int] = None self.link: list[str] = [] self.log: Optional[str] = None + self.make_hard_links = False self.max_wait_time = None self.multireddit: list[str] = [] self.no_dupes: bool = False self.saved: bool = False self.search: Optional[str] = None self.search_existing: bool = False - self.file_scheme: str = '{REDDITOR}_{TITLE}_{POSTID}' - self.folder_scheme: str = '{SUBREDDIT}' self.skip: list[str] = [] self.skip_domain: list[str] = [] self.skip_subreddit: list[str] = [] @@ -35,9 +37,8 @@ class Configuration(Namespace): self.time: str = 'all' self.time_format = None self.upvoted: bool = False - self.user: Optional[str] = None + self.user: list[str] = [] self.verbose: int = 0 - self.make_hard_links = False # Archiver-specific options self.format = 'json' diff --git a/bdfr/connector.py b/bdfr/connector.py new file mode 100644 index 0000000..68efc0c --- /dev/null +++ b/bdfr/connector.py @@ -0,0 +1,417 @@ +#!/usr/bin/env python3 +# coding=utf-8 + +import configparser +import importlib.resources +import logging +import logging.handlers +import re +import shutil +import socket +from abc import ABCMeta, abstractmethod +from datetime import datetime +from enum import Enum, auto +from pathlib import Path +from typing import Callable, Iterator + +import appdirs +import praw +import praw.exceptions +import praw.models +import prawcore + +from bdfr import exceptions as errors +from bdfr.configuration import Configuration +from bdfr.download_filter import DownloadFilter +from bdfr.file_name_formatter import FileNameFormatter +from bdfr.oauth2 import OAuth2Authenticator, OAuth2TokenManager +from bdfr.site_authenticator import SiteAuthenticator + +logger = logging.getLogger(__name__) + + +class RedditTypes: + class SortType(Enum): + CONTROVERSIAL = auto() + HOT = auto() + NEW = auto() + RELEVANCE = auto() + RISING = auto() + TOP = auto() + + class TimeType(Enum): + ALL = 'all' + DAY = 'day' + HOUR = 'hour' + MONTH = 'month' + WEEK = 'week' + YEAR = 'year' + + +class RedditConnector(metaclass=ABCMeta): + def __init__(self, args: Configuration): + self.args = args + self.config_directories = appdirs.AppDirs('bdfr', 'BDFR') + self.run_time = datetime.now().isoformat() + self._setup_internal_objects() + + self.reddit_lists = self.retrieve_reddit_lists() + + def _setup_internal_objects(self): + self.determine_directories() + self.load_config() + self.create_file_logger() + + self.read_config() + + self.parse_disabled_modules() + + self.download_filter = self.create_download_filter() + logger.log(9, 'Created download filter') + self.time_filter = self.create_time_filter() + logger.log(9, 'Created time filter') + self.sort_filter = self.create_sort_filter() + logger.log(9, 'Created sort filter') + self.file_name_formatter = self.create_file_name_formatter() + logger.log(9, 'Created file name formatter') + + self.create_reddit_instance() + self.args.user = list(filter(None, [self.resolve_user_name(user) for user in self.args.user])) + + self.excluded_submission_ids = self.read_excluded_ids() + + self.master_hash_list = {} + self.authenticator = self.create_authenticator() + logger.log(9, 'Created site authenticator') + + self.args.skip_subreddit = self.split_args_input(self.args.skip_subreddit) + self.args.skip_subreddit = set([sub.lower() for sub in self.args.skip_subreddit]) + + def read_config(self): + """Read
any cfg values that need to be processed""" + if self.args.max_wait_time is None: + if not self.cfg_parser.has_option('DEFAULT', 'max_wait_time'): + self.cfg_parser.set('DEFAULT', 'max_wait_time', '120') + logger.log(9, 'Wrote default download wait time to config file') + self.args.max_wait_time = self.cfg_parser.getint('DEFAULT', 'max_wait_time') + logger.debug(f'Setting maximum download wait time to {self.args.max_wait_time} seconds') + if self.args.time_format is None: + option = self.cfg_parser.get('DEFAULT', 'time_format', fallback='ISO') + if re.match(r'^[ \'\"]*$', option): + option = 'ISO' + logger.debug(f'Setting datetime format string to {option}') + self.args.time_format = option + if not self.args.disable_module: + self.args.disable_module = [self.cfg_parser.get('DEFAULT', 'disabled_modules', fallback='')] + # Update config on disk + with open(self.config_location, 'w') as file: + self.cfg_parser.write(file) + + def parse_disabled_modules(self): + disabled_modules = self.args.disable_module + disabled_modules = self.split_args_input(disabled_modules) + disabled_modules = set([name.strip().lower() for name in disabled_modules]) + self.args.disable_module = disabled_modules + logger.debug(f'Disabling the following modules: {", ".join(self.args.disable_module)}') + + def create_reddit_instance(self): + if self.args.authenticate: + logger.debug('Using authenticated Reddit instance') + if not self.cfg_parser.has_option('DEFAULT', 'user_token'): + logger.log(9, 'Commencing OAuth2 authentication') + scopes = self.cfg_parser.get('DEFAULT', 'scopes') + scopes = OAuth2Authenticator.split_scopes(scopes) + oauth2_authenticator = OAuth2Authenticator( + scopes, + self.cfg_parser.get('DEFAULT', 'client_id'), + self.cfg_parser.get('DEFAULT', 'client_secret'), + ) + token = oauth2_authenticator.retrieve_new_token() + self.cfg_parser['DEFAULT']['user_token'] = token + with open(self.config_location, 'w') as file: + self.cfg_parser.write(file, True) + token_manager = OAuth2TokenManager(self.cfg_parser, self.config_location) + + self.authenticated = True + self.reddit_instance = praw.Reddit( + client_id=self.cfg_parser.get('DEFAULT', 'client_id'), + client_secret=self.cfg_parser.get('DEFAULT', 'client_secret'), + user_agent=socket.gethostname(), + token_manager=token_manager, + ) + else: + logger.debug('Using unauthenticated Reddit instance') + self.authenticated = False + self.reddit_instance = praw.Reddit( + client_id=self.cfg_parser.get('DEFAULT', 'client_id'), + client_secret=self.cfg_parser.get('DEFAULT', 'client_secret'), + user_agent=socket.gethostname(), + ) + + def retrieve_reddit_lists(self) -> list[praw.models.ListingGenerator]: + master_list = [] + master_list.extend(self.get_subreddits()) + logger.log(9, 'Retrieved subreddits') + master_list.extend(self.get_multireddits()) + logger.log(9, 'Retrieved multireddits') + master_list.extend(self.get_user_data()) + logger.log(9, 'Retrieved user data') + master_list.extend(self.get_submissions_from_link()) + logger.log(9, 'Retrieved submissions for given links') + return master_list + + def determine_directories(self): + self.download_directory = Path(self.args.directory).resolve().expanduser() + self.config_directory = Path(self.config_directories.user_config_dir) + + self.download_directory.mkdir(exist_ok=True, parents=True) + self.config_directory.mkdir(exist_ok=True, parents=True) + + def load_config(self): + self.cfg_parser = configparser.ConfigParser() + if self.args.config: + if (cfg_path := Path(self.args.config)).exists(): +
self.cfg_parser.read(cfg_path) + self.config_location = cfg_path + return + possible_paths = [ + Path('./config.cfg'), + Path('./default_config.cfg'), + Path(self.config_directory, 'config.cfg'), + Path(self.config_directory, 'default_config.cfg'), + ] + self.config_location = None + for path in possible_paths: + if path.resolve().expanduser().exists(): + self.config_location = path + logger.debug(f'Loading configuration from {path}') + break + if not self.config_location: + self.config_location = list(importlib.resources.path('bdfr', 'default_config.cfg').gen)[0] + shutil.copy(self.config_location, Path(self.config_directory, 'default_config.cfg')) + if not self.config_location: + raise errors.BulkDownloaderException('Could not find a configuration file to load') + self.cfg_parser.read(self.config_location) + + def create_file_logger(self): + main_logger = logging.getLogger() + if self.args.log is None: + log_path = Path(self.config_directory, 'log_output.txt') + else: + log_path = Path(self.args.log).resolve().expanduser() + if not log_path.parent.exists(): + raise errors.BulkDownloaderException(f'Designated location for logfile does not exist') + backup_count = self.cfg_parser.getint('DEFAULT', 'backup_log_count', fallback=3) + file_handler = logging.handlers.RotatingFileHandler( + log_path, + mode='a', + backupCount=backup_count, + ) + if log_path.exists(): + try: + file_handler.doRollover() + except PermissionError as e: + logger.critical( + 'Cannot rollover logfile, make sure this is the only ' + 'BDFR process or specify alternate logfile location') + raise + formatter = logging.Formatter('[%(asctime)s - %(name)s - %(levelname)s] - %(message)s') + file_handler.setFormatter(formatter) + file_handler.setLevel(0) + + main_logger.addHandler(file_handler) + + @staticmethod + def sanitise_subreddit_name(subreddit: str) -> str: + pattern = re.compile(r'^(?:https://www\.reddit\.com/)?(?:r/)?(.*?)/?$') + match = re.match(pattern, subreddit) + if not match: + raise errors.BulkDownloaderException(f'Could not find subreddit name in string {subreddit}') + return match.group(1) + + @staticmethod + def split_args_input(entries: list[str]) -> set[str]: + all_entries = [] + split_pattern = re.compile(r'[,;]\s?') + for entry in entries: + results = re.split(split_pattern, entry) + all_entries.extend([RedditConnector.sanitise_subreddit_name(name) for name in results]) + return set(all_entries) + + def get_subreddits(self) -> list[praw.models.ListingGenerator]: + if self.args.subreddit: + out = [] + for reddit in self.split_args_input(self.args.subreddit): + try: + reddit = self.reddit_instance.subreddit(reddit) + try: + self.check_subreddit_status(reddit) + except errors.BulkDownloaderException as e: + logger.error(e) + continue + if self.args.search: + out.append(reddit.search( + self.args.search, + sort=self.sort_filter.name.lower(), + limit=self.args.limit, + time_filter=self.time_filter.value, + )) + logger.debug( + f'Added submissions from subreddit {reddit} with the search term "{self.args.search}"') + else: + out.append(self.create_filtered_listing_generator(reddit)) + logger.debug(f'Added submissions from subreddit {reddit}') + except (errors.BulkDownloaderException, praw.exceptions.PRAWException) as e: + logger.error(f'Failed to get submissions for subreddit {reddit}: {e}') + return out + else: + return [] + + def resolve_user_name(self, in_name: str) -> str: + if in_name == 'me': + if self.authenticated: + resolved_name = self.reddit_instance.user.me().name + logger.log(9, f'Resolved user to 
{resolved_name}') + return resolved_name + else: + logger.warning('To use "me" as a user, an authenticated Reddit instance must be used') + else: + return in_name + + def get_submissions_from_link(self) -> list[list[praw.models.Submission]]: + supplied_submissions = [] + for sub_id in self.args.link: + if len(sub_id) == 6: + supplied_submissions.append(self.reddit_instance.submission(id=sub_id)) + else: + supplied_submissions.append(self.reddit_instance.submission(url=sub_id)) + return [supplied_submissions] + + def determine_sort_function(self) -> Callable: + if self.sort_filter is RedditTypes.SortType.NEW: + sort_function = praw.models.Subreddit.new + elif self.sort_filter is RedditTypes.SortType.RISING: + sort_function = praw.models.Subreddit.rising + elif self.sort_filter is RedditTypes.SortType.CONTROVERSIAL: + sort_function = praw.models.Subreddit.controversial + elif self.sort_filter is RedditTypes.SortType.TOP: + sort_function = praw.models.Subreddit.top + else: + sort_function = praw.models.Subreddit.hot + return sort_function + + def get_multireddits(self) -> list[Iterator]: + if self.args.multireddit: + if len(self.args.user) != 1: + logger.error(f'Only 1 user can be supplied when retrieving from multireddits') + return [] + out = [] + for multi in self.split_args_input(self.args.multireddit): + try: + multi = self.reddit_instance.multireddit(self.args.user[0], multi) + if not multi.subreddits: + raise errors.BulkDownloaderException + out.append(self.create_filtered_listing_generator(multi)) + logger.debug(f'Added submissions from multireddit {multi}') + except (errors.BulkDownloaderException, praw.exceptions.PRAWException, prawcore.PrawcoreException) as e: + logger.error(f'Failed to get submissions for multireddit {multi}: {e}') + return out + else: + return [] + + def create_filtered_listing_generator(self, reddit_source) -> Iterator: + sort_function = self.determine_sort_function() + if self.sort_filter in (RedditTypes.SortType.TOP, RedditTypes.SortType.CONTROVERSIAL): + return sort_function(reddit_source, limit=self.args.limit, time_filter=self.time_filter.value) + else: + return sort_function(reddit_source, limit=self.args.limit) + + def get_user_data(self) -> list[Iterator]: + if any([self.args.submitted, self.args.upvoted, self.args.saved]): + if not self.args.user: + logger.warning('At least one user must be supplied to download user data') + return [] + generators = [] + for user in self.args.user: + try: + self.check_user_existence(user) + except errors.BulkDownloaderException as e: + logger.error(e) + continue + if self.args.submitted: + logger.debug(f'Retrieving submitted posts of user {user}') + generators.append(self.create_filtered_listing_generator( + self.reddit_instance.redditor(user).submissions, + )) + if not self.authenticated and any((self.args.upvoted, self.args.saved)): + logger.warning('Accessing user lists requires authentication') + else: + if self.args.upvoted: + logger.debug(f'Retrieving upvoted posts of user {user}') + generators.append(self.reddit_instance.redditor(user).upvoted(limit=self.args.limit)) + if self.args.saved: + logger.debug(f'Retrieving saved posts of user {user}') + generators.append(self.reddit_instance.redditor(user).saved(limit=self.args.limit)) + return generators + else: + return [] + + def check_user_existence(self, name: str): + user = self.reddit_instance.redditor(name=name) + try: + if user.id: + return + except prawcore.exceptions.NotFound: + raise errors.BulkDownloaderException(f'Could not
find user {name}') + except AttributeError: + if hasattr(user, 'is_suspended'): + raise errors.BulkDownloaderException(f'User {name} is banned') + + def create_file_name_formatter(self) -> FileNameFormatter: + return FileNameFormatter(self.args.file_scheme, self.args.folder_scheme, self.args.time_format) + + def create_time_filter(self) -> RedditTypes.TimeType: + try: + return RedditTypes.TimeType[self.args.time.upper()] + except (KeyError, AttributeError): + return RedditTypes.TimeType.ALL + + def create_sort_filter(self) -> RedditTypes.SortType: + try: + return RedditTypes.SortType[self.args.sort.upper()] + except (KeyError, AttributeError): + return RedditTypes.SortType.HOT + + def create_download_filter(self) -> DownloadFilter: + return DownloadFilter(self.args.skip, self.args.skip_domain) + + def create_authenticator(self) -> SiteAuthenticator: + return SiteAuthenticator(self.cfg_parser) + + @abstractmethod + def download(self): + pass + + @staticmethod + def check_subreddit_status(subreddit: praw.models.Subreddit): + if subreddit.display_name == 'all': + return + try: + assert subreddit.id + except prawcore.NotFound: + raise errors.BulkDownloaderException(f'Source {subreddit.display_name} does not exist or cannot be found') + except prawcore.Forbidden: + raise errors.BulkDownloaderException(f'Source {subreddit.display_name} is private and cannot be scraped') + + def read_excluded_ids(self) -> set[str]: + out = [] + out.extend(self.args.exclude_id) + for id_file in self.args.exclude_id_file: + id_file = Path(id_file).resolve().expanduser() + if not id_file.exists(): + logger.warning(f'ID exclusion file at {id_file} does not exist') + continue + with open(id_file, 'r') as file: + for line in file: + out.append(line.strip()) + return set(out) diff --git a/bdfr/downloader.py b/bdfr/downloader.py index 1625c8f..61158a3 100644 --- a/bdfr/downloader.py +++ b/bdfr/downloader.py @@ -1,405 +1,61 @@ #!/usr/bin/env python3 # coding=utf-8 -import configparser import hashlib -import importlib.resources -import logging import logging.handlers import os -import re -import shutil -import socket +import time from datetime import datetime -from enum import Enum, auto from multiprocessing import Pool from pathlib import Path -from typing import Callable, Iterator -import appdirs import praw import praw.exceptions import praw.models -import prawcore -import bdfr.exceptions as errors +from bdfr import exceptions as errors from bdfr.configuration import Configuration -from bdfr.download_filter import DownloadFilter -from bdfr.file_name_formatter import FileNameFormatter -from bdfr.oauth2 import OAuth2Authenticator, OAuth2TokenManager -from bdfr.site_authenticator import SiteAuthenticator +from bdfr.connector import RedditConnector from bdfr.site_downloaders.download_factory import DownloadFactory logger = logging.getLogger(__name__) def _calc_hash(existing_file: Path): + chunk_size = 1024 * 1024 + md5_hash = hashlib.md5() with open(existing_file, 'rb') as file: - file_hash = hashlib.md5(file.read()).hexdigest() - return existing_file, file_hash + chunk = file.read(chunk_size) + while chunk: + md5_hash.update(chunk) + chunk = file.read(chunk_size) + file_hash = md5_hash.hexdigest() + return existing_file, file_hash -class RedditTypes: - class SortType(Enum): - CONTROVERSIAL = auto() - HOT = auto() - NEW = auto() - RELEVENCE = auto() - RISING = auto() - TOP = auto() - - class TimeType(Enum): - ALL = 'all' - DAY = 'day' - HOUR = 'hour' - MONTH = 'month' - WEEK = 'week' - YEAR = 'year' - - -class 
RedditDownloader: +class RedditDownloader(RedditConnector): def __init__(self, args: Configuration): - self.args = args - self.config_directories = appdirs.AppDirs('bdfr', 'BDFR') - self.run_time = datetime.now().isoformat() - self._setup_internal_objects() - - self.reddit_lists = self._retrieve_reddit_lists() - - def _setup_internal_objects(self): - self._determine_directories() - self._load_config() - self._create_file_logger() - - self._read_config() - - self.download_filter = self._create_download_filter() - logger.log(9, 'Created download filter') - self.time_filter = self._create_time_filter() - logger.log(9, 'Created time filter') - self.sort_filter = self._create_sort_filter() - logger.log(9, 'Created sort filter') - self.file_name_formatter = self._create_file_name_formatter() - logger.log(9, 'Create file name formatter') - - self._create_reddit_instance() - self._resolve_user_name() - - self.excluded_submission_ids = self._read_excluded_ids() - + super(RedditDownloader, self).__init__(args) if self.args.search_existing: self.master_hash_list = self.scan_existing_files(self.download_directory) - else: - self.master_hash_list = {} - self.authenticator = self._create_authenticator() - logger.log(9, 'Created site authenticator') - - self.args.skip_subreddit = self._split_args_input(self.args.skip_subreddit) - self.args.skip_subreddit = set([sub.lower() for sub in self.args.skip_subreddit]) - - def _read_config(self): - """Read any cfg values that need to be processed""" - if self.args.max_wait_time is None: - if not self.cfg_parser.has_option('DEFAULT', 'max_wait_time'): - self.cfg_parser.set('DEFAULT', 'max_wait_time', '120') - logger.log(9, 'Wrote default download wait time download to config file') - self.args.max_wait_time = self.cfg_parser.getint('DEFAULT', 'max_wait_time') - logger.debug(f'Setting maximum download wait time to {self.args.max_wait_time} seconds') - if self.args.time_format is None: - option = self.cfg_parser.get('DEFAULT', 'time_format', fallback='ISO') - if re.match(r'^[ \'\"]*$', option): - option = 'ISO' - logger.debug(f'Setting datetime format string to {option}') - self.args.time_format = option - # Update config on disk - with open(self.config_location, 'w') as file: - self.cfg_parser.write(file) - - def _create_reddit_instance(self): - if self.args.authenticate: - logger.debug('Using authenticated Reddit instance') - if not self.cfg_parser.has_option('DEFAULT', 'user_token'): - logger.log(9, 'Commencing OAuth2 authentication') - scopes = self.cfg_parser.get('DEFAULT', 'scopes') - scopes = OAuth2Authenticator.split_scopes(scopes) - oauth2_authenticator = OAuth2Authenticator( - scopes, - self.cfg_parser.get('DEFAULT', 'client_id'), - self.cfg_parser.get('DEFAULT', 'client_secret'), - ) - token = oauth2_authenticator.retrieve_new_token() - self.cfg_parser['DEFAULT']['user_token'] = token - with open(self.config_location, 'w') as file: - self.cfg_parser.write(file, True) - token_manager = OAuth2TokenManager(self.cfg_parser, self.config_location) - - self.authenticated = True - self.reddit_instance = praw.Reddit( - client_id=self.cfg_parser.get('DEFAULT', 'client_id'), - client_secret=self.cfg_parser.get('DEFAULT', 'client_secret'), - user_agent=socket.gethostname(), - token_manager=token_manager, - ) - else: - logger.debug('Using unauthenticated Reddit instance') - self.authenticated = False - self.reddit_instance = praw.Reddit( - client_id=self.cfg_parser.get('DEFAULT', 'client_id'), - client_secret=self.cfg_parser.get('DEFAULT', 'client_secret'), - 
user_agent=socket.gethostname(), - ) - - def _retrieve_reddit_lists(self) -> list[praw.models.ListingGenerator]: - master_list = [] - master_list.extend(self._get_subreddits()) - logger.log(9, 'Retrieved subreddits') - master_list.extend(self._get_multireddits()) - logger.log(9, 'Retrieved multireddits') - master_list.extend(self._get_user_data()) - logger.log(9, 'Retrieved user data') - master_list.extend(self._get_submissions_from_link()) - logger.log(9, 'Retrieved submissions for given links') - return master_list - - def _determine_directories(self): - self.download_directory = Path(self.args.directory).resolve().expanduser() - self.config_directory = Path(self.config_directories.user_config_dir) - - self.download_directory.mkdir(exist_ok=True, parents=True) - self.config_directory.mkdir(exist_ok=True, parents=True) - - def _load_config(self): - self.cfg_parser = configparser.ConfigParser() - if self.args.config: - if (cfg_path := Path(self.args.config)).exists(): - self.cfg_parser.read(cfg_path) - self.config_location = cfg_path - return - possible_paths = [ - Path('./config.cfg'), - Path('./default_config.cfg'), - Path(self.config_directory, 'config.cfg'), - Path(self.config_directory, 'default_config.cfg'), - ] - self.config_location = None - for path in possible_paths: - if path.resolve().expanduser().exists(): - self.config_location = path - logger.debug(f'Loading configuration from {path}') - break - if not self.config_location: - self.config_location = list(importlib.resources.path('bdfr', 'default_config.cfg').gen)[0] - shutil.copy(self.config_location, Path(self.config_directory, 'default_config.cfg')) - if not self.config_location: - raise errors.BulkDownloaderException('Could not find a configuration file to load') - self.cfg_parser.read(self.config_location) - - def _create_file_logger(self): - main_logger = logging.getLogger() - if self.args.log is None: - log_path = Path(self.config_directory, 'log_output.txt') - else: - log_path = Path(self.args.log).resolve().expanduser() - if not log_path.parent.exists(): - raise errors.BulkDownloaderException(f'Designated location for logfile does not exist') - backup_count = self.cfg_parser.getint('DEFAULT', 'backup_log_count', fallback=3) - file_handler = logging.handlers.RotatingFileHandler( - log_path, - mode='a', - backupCount=backup_count, - ) - if log_path.exists(): - try: - file_handler.doRollover() - except PermissionError as e: - logger.critical( - 'Cannot rollover logfile, make sure this is the only ' - 'BDFR process or specify alternate logfile location') - raise - formatter = logging.Formatter('[%(asctime)s - %(name)s - %(levelname)s] - %(message)s') - file_handler.setFormatter(formatter) - file_handler.setLevel(0) - - main_logger.addHandler(file_handler) - - @staticmethod - def _sanitise_subreddit_name(subreddit: str) -> str: - pattern = re.compile(r'^(?:https://www\.reddit\.com/)?(?:r/)?(.*?)/?$') - match = re.match(pattern, subreddit) - if not match: - raise errors.BulkDownloaderException(f'Could not find subreddit name in string {subreddit}') - return match.group(1) - - @staticmethod - def _split_args_input(entries: list[str]) -> set[str]: - all_entries = [] - split_pattern = re.compile(r'[,;]\s?') - for entry in entries: - results = re.split(split_pattern, entry) - all_entries.extend([RedditDownloader._sanitise_subreddit_name(name) for name in results]) - return set(all_entries) - - def _get_subreddits(self) -> list[praw.models.ListingGenerator]: - if self.args.subreddit: - out = [] - for reddit in 
self._split_args_input(self.args.subreddit): - try: - reddit = self.reddit_instance.subreddit(reddit) - try: - self._check_subreddit_status(reddit) - except errors.BulkDownloaderException as e: - logger.error(e) - continue - if self.args.search: - out.append(reddit.search( - self.args.search, - sort=self.sort_filter.name.lower(), - limit=self.args.limit, - time_filter=self.time_filter.value, - )) - logger.debug( - f'Added submissions from subreddit {reddit} with the search term "{self.args.search}"') - else: - out.append(self._create_filtered_listing_generator(reddit)) - logger.debug(f'Added submissions from subreddit {reddit}') - except (errors.BulkDownloaderException, praw.exceptions.PRAWException) as e: - logger.error(f'Failed to get submissions for subreddit {reddit}: {e}') - return out - else: - return [] - - def _resolve_user_name(self): - if self.args.user == 'me': - if self.authenticated: - self.args.user = self.reddit_instance.user.me().name - logger.log(9, f'Resolved user to {self.args.user}') - else: - self.args.user = None - logger.warning('To use "me" as a user, an authenticated Reddit instance must be used') - - def _get_submissions_from_link(self) -> list[list[praw.models.Submission]]: - supplied_submissions = [] - for sub_id in self.args.link: - if len(sub_id) == 6: - supplied_submissions.append(self.reddit_instance.submission(id=sub_id)) - else: - supplied_submissions.append(self.reddit_instance.submission(url=sub_id)) - return [supplied_submissions] - - def _determine_sort_function(self) -> Callable: - if self.sort_filter is RedditTypes.SortType.NEW: - sort_function = praw.models.Subreddit.new - elif self.sort_filter is RedditTypes.SortType.RISING: - sort_function = praw.models.Subreddit.rising - elif self.sort_filter is RedditTypes.SortType.CONTROVERSIAL: - sort_function = praw.models.Subreddit.controversial - elif self.sort_filter is RedditTypes.SortType.TOP: - sort_function = praw.models.Subreddit.top - else: - sort_function = praw.models.Subreddit.hot - return sort_function - - def _get_multireddits(self) -> list[Iterator]: - if self.args.multireddit: - out = [] - for multi in self._split_args_input(self.args.multireddit): - try: - multi = self.reddit_instance.multireddit(self.args.user, multi) - if not multi.subreddits: - raise errors.BulkDownloaderException - out.append(self._create_filtered_listing_generator(multi)) - logger.debug(f'Added submissions from multireddit {multi}') - except (errors.BulkDownloaderException, praw.exceptions.PRAWException, prawcore.PrawcoreException) as e: - logger.error(f'Failed to get submissions for multireddit {multi}: {e}') - return out - else: - return [] - - def _create_filtered_listing_generator(self, reddit_source) -> Iterator: - sort_function = self._determine_sort_function() - if self.sort_filter in (RedditTypes.SortType.TOP, RedditTypes.SortType.CONTROVERSIAL): - return sort_function(reddit_source, limit=self.args.limit, time_filter=self.time_filter.value) - else: - return sort_function(reddit_source, limit=self.args.limit) - - def _get_user_data(self) -> list[Iterator]: - if any([self.args.submitted, self.args.upvoted, self.args.saved]): - if self.args.user: - try: - self._check_user_existence(self.args.user) - except errors.BulkDownloaderException as e: - logger.error(e) - return [] - generators = [] - if self.args.submitted: - logger.debug(f'Retrieving submitted posts of user {self.args.user}') - generators.append(self._create_filtered_listing_generator( - self.reddit_instance.redditor(self.args.user).submissions, - )) - if 
not self.authenticated and any((self.args.upvoted, self.args.saved)): - logger.warning('Accessing user lists requires authentication') - else: - if self.args.upvoted: - logger.debug(f'Retrieving upvoted posts of user {self.args.user}') - generators.append(self.reddit_instance.redditor(self.args.user).upvoted(limit=self.args.limit)) - if self.args.saved: - logger.debug(f'Retrieving saved posts of user {self.args.user}') - generators.append(self.reddit_instance.redditor(self.args.user).saved(limit=self.args.limit)) - return generators - else: - logger.warning('A user must be supplied to download user data') - return [] - else: - return [] - - def _check_user_existence(self, name: str): - user = self.reddit_instance.redditor(name=name) - try: - if user.id: - return - except prawcore.exceptions.NotFound: - raise errors.BulkDownloaderException(f'Could not find user {name}') - except AttributeError: - if hasattr(user, 'is_suspended'): - raise errors.BulkDownloaderException(f'User {name} is banned') - - def _create_file_name_formatter(self) -> FileNameFormatter: - return FileNameFormatter(self.args.file_scheme, self.args.folder_scheme, self.args.time_format) - - def _create_time_filter(self) -> RedditTypes.TimeType: - try: - return RedditTypes.TimeType[self.args.time.upper()] - except (KeyError, AttributeError): - return RedditTypes.TimeType.ALL - - def _create_sort_filter(self) -> RedditTypes.SortType: - try: - return RedditTypes.SortType[self.args.sort.upper()] - except (KeyError, AttributeError): - return RedditTypes.SortType.HOT - - def _create_download_filter(self) -> DownloadFilter: - return DownloadFilter(self.args.skip, self.args.skip_domain) - - def _create_authenticator(self) -> SiteAuthenticator: - return SiteAuthenticator(self.cfg_parser) def download(self): for generator in self.reddit_lists: for submission in generator: - if submission.id in self.excluded_submission_ids: - logger.debug(f'Object {submission.id} in exclusion list, skipping') - continue - elif submission.subreddit.display_name.lower() in self.args.skip_subreddit: - logger.debug(f'Submission {submission.id} in {submission.subreddit.display_name} in skip list') - else: - logger.debug(f'Attempting to download submission {submission.id}') - self._download_submission(submission) + self._download_submission(submission) def _download_submission(self, submission: praw.models.Submission): - if not isinstance(submission, praw.models.Submission): + if submission.id in self.excluded_submission_ids: + logger.debug(f'Object {submission.id} in exclusion list, skipping') + return + elif submission.subreddit.display_name.lower() in self.args.skip_subreddit: + logger.debug(f'Submission {submission.id} in {submission.subreddit.display_name} in skip list') + return + elif not isinstance(submission, praw.models.Submission): logger.warning(f'{submission.id} is not a submission') return + + logger.debug(f'Attempting to download submission {submission.id}') try: downloader_class = DownloadFactory.pull_lever(submission.url) downloader = downloader_class(submission) @@ -407,7 +63,9 @@ class RedditDownloader: except errors.NotADownloadableLinkError as e: logger.error(f'Could not download submission {submission.id}: {e}') return - + if downloader_class.__name__.lower() in self.args.disable_module: + logger.debug(f'Submission {submission.id} skipped due to disabled module {downloader_class.__name__}') + return try: content = downloader.find_resources(self.authenticator) except errors.SiteDownloaderError as e: @@ -415,34 +73,42 @@ class 
RedditDownloader: return for destination, res in self.file_name_formatter.format_resource_paths(content, self.download_directory): if destination.exists(): - logger.debug(f'File {destination} already exists, continuing') + logger.debug(f'File {destination} from submission {submission.id} already exists, continuing') + continue elif not self.download_filter.check_resource(res): logger.debug(f'Download filter removed {submission.id} with URL {submission.url}') - else: - try: - res.download(self.args.max_wait_time) - except errors.BulkDownloaderException as e: - logger.error(f'Failed to download resource {res.url} in submission {submission.id} ' - f'with downloader {downloader_class.__name__}: {e}') + continue + try: + res.download(self.args.max_wait_time) + except errors.BulkDownloaderException as e: + logger.error(f'Failed to download resource {res.url} in submission {submission.id} ' + f'with downloader {downloader_class.__name__}: {e}') + return + resource_hash = res.hash.hexdigest() + destination.parent.mkdir(parents=True, exist_ok=True) + if resource_hash in self.master_hash_list: + if self.args.no_dupes: + logger.info( + f'Resource hash {resource_hash} from submission {submission.id} downloaded elsewhere') return - resource_hash = res.hash.hexdigest() - destination.parent.mkdir(parents=True, exist_ok=True) - if resource_hash in self.master_hash_list: - if self.args.no_dupes: - logger.info( - f'Resource hash {resource_hash} from submission {submission.id} downloaded elsewhere') - return - elif self.args.make_hard_links: - self.master_hash_list[resource_hash].link_to(destination) - logger.info( - f'Hard link made linking {destination} to {self.master_hash_list[resource_hash]}') - return + elif self.args.make_hard_links: + self.master_hash_list[resource_hash].link_to(destination) + logger.info( + f'Hard link made linking {destination} to {self.master_hash_list[resource_hash]}' + f' in submission {submission.id}') + return + try: with open(destination, 'wb') as file: file.write(res.content) logger.debug(f'Written file to {destination}') - self.master_hash_list[resource_hash] = destination - logger.debug(f'Hash added to master list: {resource_hash}') - logger.info(f'Downloaded submission {submission.id} from {submission.subreddit.display_name}') + except OSError as e: + logger.exception(e) + logger.error(f'Failed to write file to {destination} in submission {submission.id}: {e}') + return + creation_time = time.mktime(datetime.fromtimestamp(submission.created_utc).timetuple()) + os.utime(destination, (creation_time, creation_time)) + self.master_hash_list[resource_hash] = destination + logger.debug(f'Hash added to master list: {resource_hash}') + logger.info(f'Downloaded submission {submission.id} from {submission.subreddit.display_name}') @staticmethod def scan_existing_files(directory: Path) -> dict[str, Path]: @@ -457,27 +123,3 @@ class RedditDownloader: hash_list = {res[1]: res[0] for res in results} return hash_list - - def _read_excluded_ids(self) -> set[str]: - out = [] - out.extend(self.args.exclude_id) - for id_file in self.args.exclude_id_file: - id_file = Path(id_file).resolve().expanduser() - if not id_file.exists(): - logger.warning(f'ID exclusion file at {id_file} does not exist') - continue - with open(id_file, 'r') as file: - for line in file: - out.append(line.strip()) - return set(out) - - @staticmethod - def _check_subreddit_status(subreddit: praw.models.Subreddit): - if subreddit.display_name == 'all': - return - try: - assert subreddit.id - except prawcore.NotFound: - raise
errors.BulkDownloaderException(f'Source {subreddit.display_name} does not exist or cannot be found') - except prawcore.Forbidden: - raise errors.BulkDownloaderException(f'Source {subreddit.display_name} is private and cannot be scraped') diff --git a/bdfr/file_name_formatter.py b/bdfr/file_name_formatter.py index c6c13c2..2fbf95f 100644 --- a/bdfr/file_name_formatter.py +++ b/bdfr/file_name_formatter.py @@ -4,6 +4,7 @@ import datetime import logging import platform import re +import subprocess from pathlib import Path from typing import Optional @@ -104,32 +105,46 @@ class FileNameFormatter: ) -> Path: subfolder = Path( destination_directory, - *[self._format_name(resource.source_submission, part) for part in self.directory_format_string] + *[self._format_name(resource.source_submission, part) for part in self.directory_format_string], ) index = f'_{str(index)}' if index else '' if not resource.extension: raise BulkDownloaderException(f'Resource from {resource.url} has no extension') ending = index + resource.extension file_name = str(self._format_name(resource.source_submission, self.file_format_string)) - file_name = self._limit_file_name_length(file_name, ending) try: - file_path = Path(subfolder, file_name) + file_path = self._limit_file_name_length(file_name, ending, subfolder) except TypeError: raise BulkDownloaderException(f'Could not determine path name: {subfolder}, {index}, {resource.extension}') return file_path @staticmethod - def _limit_file_name_length(filename: str, ending: str) -> str: + def _limit_file_name_length(filename: str, ending: str, root: Path) -> Path: + root = root.resolve().expanduser() possible_id = re.search(r'((?:_\w{6})?$)', filename) if possible_id: ending = possible_id.group(1) + ending filename = filename[:possible_id.start()] + max_path = FileNameFormatter.find_max_path_length() max_length_chars = 255 - len(ending) max_length_bytes = 255 - len(ending.encode('utf-8')) - while len(filename) > max_length_chars or len(filename.encode('utf-8')) > max_length_bytes: + max_path_length = max_path - len(ending) - len(str(root)) - 1 + while len(filename) > max_length_chars or \ + len(filename.encode('utf-8')) > max_length_bytes or \ + len(filename) > max_path_length: filename = filename[:-1] - return filename + ending + return Path(root, filename + ending) + + @staticmethod + def find_max_path_length() -> int: + try: + return int(subprocess.check_output(['getconf', 'PATH_MAX', '/'])) + except (ValueError, subprocess.CalledProcessError, OSError): + if platform.system() == 'Windows': + return 260 + else: + return 4096 def format_resource_paths( self, diff --git a/bdfr/resource.py b/bdfr/resource.py index 966f5ba..e8f9fd1 100644 --- a/bdfr/resource.py +++ b/bdfr/resource.py @@ -5,8 +5,8 @@ import hashlib import logging import re import time -from typing import Optional import urllib.parse +from typing import Optional import _hashlib import requests @@ -28,8 +28,7 @@ class Resource: self.extension = self._determine_extension() @staticmethod - def retry_download(url: str, max_wait_time: int) -> Optional[bytes]: - wait_time = 60 + def retry_download(url: str, max_wait_time: int, current_wait_time: int = 60) -> Optional[bytes]: try: response = requests.get(url) if re.match(r'^2\d{2}', str(response.status_code)) and response.content: @@ -39,11 +38,12 @@ class Resource: else: raise BulkDownloaderException( f'Unrecoverable error requesting resource: HTTP Code {response.status_code}') - except requests.exceptions.ConnectionError as e: - logger.warning(f'Error occured 
diff --git a/bdfr/resource.py b/bdfr/resource.py
index 966f5ba..e8f9fd1 100644
--- a/bdfr/resource.py
+++ b/bdfr/resource.py
@@ -5,8 +5,8 @@ import hashlib
 import logging
 import re
 import time
-from typing import Optional
 import urllib.parse
+from typing import Optional

 import _hashlib
 import requests
@@ -28,8 +28,7 @@ class Resource:
         self.extension = self._determine_extension()

     @staticmethod
-    def retry_download(url: str, max_wait_time: int) -> Optional[bytes]:
-        wait_time = 60
+    def retry_download(url: str, max_wait_time: int, current_wait_time: int = 60) -> Optional[bytes]:
         try:
             response = requests.get(url)
             if re.match(r'^2\d{2}', str(response.status_code)) and response.content:
@@ -39,11 +38,12 @@
             else:
                 raise BulkDownloaderException(
                     f'Unrecoverable error requesting resource: HTTP Code {response.status_code}')
-        except requests.exceptions.ConnectionError as e:
-            logger.warning(f'Error occured downloading from {url}, waiting {wait_time} seconds: {e}')
-            time.sleep(wait_time)
-            if wait_time < max_wait_time:
-                return Resource.retry_download(url, max_wait_time)
+        except (requests.exceptions.ConnectionError, requests.exceptions.ChunkedEncodingError) as e:
+            logger.warning(f'Error occurred downloading from {url}, waiting {current_wait_time} seconds: {e}')
+            time.sleep(current_wait_time)
+            if current_wait_time < max_wait_time:
+                current_wait_time += 60
+                return Resource.retry_download(url, max_wait_time, current_wait_time)
             else:
                 logger.error(f'Max wait time exceeded for resource at url {url}')
                 raise
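With the new `current_wait_time` parameter threaded through the recursion, each failed attempt now sleeps 60 seconds longer than the last (60, 120, 180, ...) until the wait exceeds `max_wait_time`, instead of retrying on a fixed 60-second interval indefinitely. A toy illustration of the resulting schedule, with no network calls and an assumed `backoff_schedule` name:

```python
def backoff_schedule(max_wait_time: int, step: int = 60) -> list[int]:
    """Reproduce the linear backoff waits retry_download would sleep through."""
    waits = []
    current = step
    while current < max_wait_time:
        waits.append(current)
        current += step
    waits.append(current)  # the final wait before the exception is re-raised
    return waits


print(backoff_schedule(300))  # [60, 120, 180, 240, 300]
```

Catching `ChunkedEncodingError` alongside `ConnectionError` also means a response that dies mid-transfer is retried rather than treated as fatal.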
diff --git a/bdfr/site_downloaders/download_factory.py b/bdfr/site_downloaders/download_factory.py
index 7035dc2..41813f9 100644
--- a/bdfr/site_downloaders/download_factory.py
+++ b/bdfr/site_downloaders/download_factory.py
@@ -21,10 +21,11 @@ from bdfr.site_downloaders.youtube import Youtube

 class DownloadFactory:
     @staticmethod
     def pull_lever(url: str) -> Type[BaseDownloader]:
-        sanitised_url = DownloadFactory._sanitise_url(url)
+        sanitised_url = DownloadFactory.sanitise_url(url)
         if re.match(r'(i\.)?imgur.*\.gifv$', sanitised_url):
             return Imgur
-        elif re.match(r'.*/.*\.\w{3,4}(\?[\w;&=]*)?$', sanitised_url):
+        elif re.match(r'.*/.*\.\w{3,4}(\?[\w;&=]*)?$', sanitised_url) and \
+                not DownloadFactory.is_web_resource(sanitised_url):
             return Direct
         elif re.match(r'erome\.com.*', sanitised_url):
             return Erome
@@ -49,9 +50,29 @@
             f'No downloader module exists for url {url}')

     @staticmethod
-    def _sanitise_url(url: str) -> str:
+    def sanitise_url(url: str) -> str:
         beginning_regex = re.compile(r'\s*(www\.?)?')
         split_url = urllib.parse.urlsplit(url)
         split_url = split_url.netloc + split_url.path
         split_url = re.sub(beginning_regex, '', split_url)
         return split_url
+
+    @staticmethod
+    def is_web_resource(url: str) -> bool:
+        web_extensions = (
+            'asp',
+            'aspx',
+            'cfm',
+            'cfml',
+            'css',
+            'htm',
+            'html',
+            'js',
+            'php',
+            'php3',
+            'xhtml',
+        )
+        if re.match(rf'(?i).*/.*\.({"|".join(web_extensions)})$', url):
+            return True
+        else:
+            return False
diff --git a/bdfr/site_downloaders/imgur.py b/bdfr/site_downloaders/imgur.py
index 6ae8a5e..3d071d4 100644
--- a/bdfr/site_downloaders/imgur.py
+++ b/bdfr/site_downloaders/imgur.py
@@ -71,6 +71,7 @@ class Imgur(BaseDownloader):

     @staticmethod
     def _validate_extension(extension_suffix: str) -> str:
+        extension_suffix = extension_suffix.strip('?1')
         possible_extensions = ('.jpg', '.png', '.mp4', '.gif')
         selection = [ext for ext in possible_extensions if ext == extension_suffix]
         if len(selection) == 1:
diff --git a/scripts/extract_failed_ids.sh b/scripts/extract_failed_ids.sh
index cdf1f21..89f1896 100755
--- a/scripts/extract_failed_ids.sh
+++ b/scripts/extract_failed_ids.sh
@@ -14,5 +14,9 @@ else
     output="failed.txt"
 fi

-grep 'Could not download submission' "$file" | awk '{ print $12 }' | rev | cut -c 2- | rev >>"$output"
-grep 'Failed to download resource' "$file" | awk '{ print $15 }' >>"$output"
+{
+    grep 'Could not download submission' "$file" | awk '{ print $12 }' | rev | cut -c 2- | rev ;
+    grep 'Failed to download resource' "$file" | awk '{ print $15 }' ;
+    grep 'failed to download submission' "$file" | awk '{ print $14 }' | rev | cut -c 2- | rev ;
+    grep 'Failed to write file' "$file" | awk '{ print $16 }' | rev | cut -c 2- | rev ;
+} >>"$output"
diff --git a/scripts/extract_successful_ids.sh b/scripts/extract_successful_ids.sh
index 3b6f7bc..19e8bd7 100755
--- a/scripts/extract_successful_ids.sh
+++ b/scripts/extract_successful_ids.sh
@@ -14,4 +14,10 @@ else
     output="successful.txt"
 fi

-grep 'Downloaded submission' "$file" | awk '{ print $(NF-2) }' >> "$output"
+{
+    grep 'Downloaded submission' "$file" | awk '{ print $(NF-2) }' ;
+    grep 'Resource hash' "$file" | awk '{ print $(NF-2) }' ;
+    grep 'Download filter' "$file" | awk '{ print $(NF-3) }' ;
+    grep 'already exists, continuing' "$file" | awk '{ print $(NF-3) }' ;
+    grep 'Hard link made' "$file" | awk '{ print $(NF) }' ;
+} >> "$output"
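The `is_web_resource` guard added to `DownloadFactory` above exists because the Direct downloader's extension regex would otherwise treat pages like `example.com/index.php` as downloadable media. A quick check of the behaviour, assuming a BDFR checkout on the import path:

```python
from bdfr.site_downloaders.download_factory import DownloadFactory

# Markup and script extensions are routed away from the Direct downloader...
assert DownloadFactory.is_web_resource('www.example.com/test.html')
assert DownloadFactory.is_web_resource('www.example.com/test.PHP')  # the (?i) flag makes this case-insensitive
# ...while genuine media extensions still qualify for a direct download
assert not DownloadFactory.is_web_resource('www.example.com/test.mp4')
```

The shell scripts gain matching patterns for the new and reworded log lines (the hard-link and already-exists messages now carry the submission ID), so log post-processing keeps working against the updated output.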
diff --git a/setup.cfg b/setup.cfg
index b1345d9..2969fe0 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -4,7 +4,7 @@ description_file = README.md
 description_content_type = text/markdown
 home_page = https://github.com/aliparlakci/bulk-downloader-for-reddit
 keywords = reddit, download, archive
-version = 2.1.1
+version = 2.2.0
 author = Ali Parlakci
 author_email = parlakciali@gmail.com
 maintainer = Serene Arc
diff --git a/tests/archive_entry/test_comment_archive_entry.py b/tests/archive_entry/test_comment_archive_entry.py
index 27dfcb3..e453d27 100644
--- a/tests/archive_entry/test_comment_archive_entry.py
+++ b/tests/archive_entry/test_comment_archive_entry.py
@@ -15,6 +15,7 @@ from bdfr.archive_entry.comment_archive_entry import CommentArchiveEntry
         'subreddit': 'Python',
         'submission': 'mgi4op',
         'submission_title': '76% Faster CPython',
+        'distinguished': None,
     }),
 ))
 def test_get_comment_details(test_comment_id: str, expected_dict: dict, reddit_instance: praw.Reddit):
diff --git a/tests/archive_entry/test_submission_archive_entry.py b/tests/archive_entry/test_submission_archive_entry.py
index 2b1bb72..60f47b5 100644
--- a/tests/archive_entry/test_submission_archive_entry.py
+++ b/tests/archive_entry/test_submission_archive_entry.py
@@ -26,6 +26,13 @@ def test_get_comments(test_submission_id: str, min_comments: int, reddit_instanc
         'author': 'sinjen-tos',
         'id': 'm3reby',
         'link_flair_text': 'image',
+        'pinned': False,
+        'spoiler': False,
+        'over_18': False,
+        'locked': False,
+        'distinguished': None,
+        'created_utc': 1615583837,
+        'permalink': '/r/australia/comments/m3reby/this_little_guy_fell_out_of_a_tree_and_in_front/'
     }),
     ('m3kua3', {'author': 'DELETED'}),
 ))
diff --git a/tests/site_downloaders/fallback_downloaders/youtubedl_fallback.py b/tests/site_downloaders/fallback_downloaders/test_youtubedl_fallback.py
similarity index 100%
rename from tests/site_downloaders/fallback_downloaders/youtubedl_fallback.py
rename to tests/site_downloaders/fallback_downloaders/test_youtubedl_fallback.py
diff --git a/tests/site_downloaders/test_download_factory.py b/tests/site_downloaders/test_download_factory.py
index f02e9f7..4b5356c 100644
--- a/tests/site_downloaders/test_download_factory.py
+++ b/tests/site_downloaders/test_download_factory.py
@@ -69,6 +69,19 @@ def test_factory_lever_bad(test_url: str):
     ('https://youtube.com/watch?v=Gv8Wz74FjVA', 'youtube.com/watch'),
     ('https://i.imgur.com/BuzvZwb.gifv', 'i.imgur.com/BuzvZwb.gifv'),
 ))
-def test_sanitise_urll(test_url: str, expected: str):
-    result = DownloadFactory._sanitise_url(test_url)
+def test_sanitise_url(test_url: str, expected: str):
+    result = DownloadFactory.sanitise_url(test_url)
     assert result == expected
+
+
+@pytest.mark.parametrize(('test_url', 'expected'), (
+    ('www.example.com/test.asp', True),
+    ('www.example.com/test.html', True),
+    ('www.example.com/test.js', True),
+    ('www.example.com/test.xhtml', True),
+    ('www.example.com/test.mp4', False),
+    ('www.example.com/test.png', False),
+))
+def test_is_web_resource(test_url: str, expected: bool):
+    result = DownloadFactory.is_web_resource(test_url)
+    assert result == expected
diff --git a/tests/site_downloaders/test_erome.py b/tests/site_downloaders/test_erome.py
index 1de9afd..84546c4 100644
--- a/tests/site_downloaders/test_erome.py
+++ b/tests/site_downloaders/test_erome.py
@@ -34,9 +34,6 @@ def test_get_link(test_url: str, expected_urls: tuple[str]):
     ('https://www.erome.com/a/vqtPuLXh', {
         '5da2a8d60d87bed279431fdec8e7d72f'
     }),
-    ('https://www.erome.com/i/ItASD33e', {
-        'b0d73fedc9ce6995c2f2c4fdb6f11eff'
-    }),
     ('https://www.erome.com/a/lGrcFxmb', {
         '0e98f9f527a911dcedde4f846bb5b69f',
         '25696ae364750a5303fc7d7dc78b35c1',
diff --git a/tests/site_downloaders/test_imgur.py b/tests/site_downloaders/test_imgur.py
index ee98c42..792926a 100644
--- a/tests/site_downloaders/test_imgur.py
+++ b/tests/site_downloaders/test_imgur.py
@@ -122,6 +122,14 @@ def test_imgur_extension_validation_bad(test_extension: str):
             '029c475ce01b58fdf1269d8771d33913',
         ),
     ),
+    (
+        'https://imgur.com/a/eemHCCK',
+        (
+            '9cb757fd8f055e7ef7aa88addc9d9fa5',
+            'b6cb6c918e2544e96fb7c07d828774b5',
+            'fb6c913d721c0bbb96aa65d7f560d385',
+        ),
+    ),
 ))
 def test_find_resources(test_url: str, expected_hashes: list[str]):
     mock_download = Mock()
@@ -131,5 +139,4 @@ def test_find_resources(test_url: str, expected_hashes: list[str]):
     assert all([isinstance(res, Resource) for res in results])
     [res.download(120) for res in results]
     hashes = set([res.hash.hexdigest() for res in results])
-    assert len(results) == len(expected_hashes)
     assert hashes == set(expected_hashes)
diff --git a/tests/test_archiver.py b/tests/test_archiver.py
index 622c555..627caee 100644
--- a/tests/test_archiver.py
+++ b/tests/test_archiver.py
@@ -7,51 +7,20 @@ from unittest.mock import MagicMock
 import praw
 import pytest

-from bdfr.archive_entry.submission_archive_entry import SubmissionArchiveEntry
 from bdfr.archiver import Archiver


 @pytest.mark.online
 @pytest.mark.reddit
-@pytest.mark.parametrize('test_submission_id', (
-    'm3reby',
+@pytest.mark.parametrize(('test_submission_id', 'test_format'), (
+    ('m3reby', 'xml'),
+    ('m3reby', 'json'),
+    ('m3reby', 'yaml'),
 ))
-def test_write_submission_json(test_submission_id: str, tmp_path: Path, reddit_instance: praw.Reddit):
+def test_write_submission_json(test_submission_id: str, tmp_path: Path, test_format: str, reddit_instance: praw.Reddit):
     archiver_mock = MagicMock()
-    test_path = Path(tmp_path, 'test.json')
+    archiver_mock.args.format = test_format
+    test_path = Path(tmp_path, 'test')
     test_submission = reddit_instance.submission(id=test_submission_id)
     archiver_mock.file_name_formatter.format_path.return_value = test_path
-    test_entry = SubmissionArchiveEntry(test_submission)
-    Archiver._write_entry_json(archiver_mock, test_entry)
-    archiver_mock._write_content_to_disk.assert_called_once()
-
-
-@pytest.mark.online
-@pytest.mark.reddit
-@pytest.mark.parametrize('test_submission_id', (
-    'm3reby',
-))
-def test_write_submission_xml(test_submission_id: str, tmp_path: Path, reddit_instance: praw.Reddit):
-    archiver_mock = MagicMock()
-    test_path = Path(tmp_path, 'test.xml')
-    test_submission = reddit_instance.submission(id=test_submission_id)
-    archiver_mock.file_name_formatter.format_path.return_value = test_path
-    test_entry = SubmissionArchiveEntry(test_submission)
-    Archiver._write_entry_xml(archiver_mock, test_entry)
-    archiver_mock._write_content_to_disk.assert_called_once()
-
-
-@pytest.mark.online
-@pytest.mark.reddit
-@pytest.mark.parametrize('test_submission_id', (
-    'm3reby',
-))
-def test_write_submission_yaml(test_submission_id: str, tmp_path: Path, reddit_instance: praw.Reddit):
-    archiver_mock = MagicMock()
-    archiver_mock.download_directory = tmp_path
-    test_path = Path(tmp_path, 'test.yaml')
-    test_submission = reddit_instance.submission(id=test_submission_id)
-    archiver_mock.file_name_formatter.format_path.return_value = test_path
-    test_entry = SubmissionArchiveEntry(test_submission)
-    Archiver._write_entry_yaml(archiver_mock, test_entry)
-    archiver_mock._write_content_to_disk.assert_called_once()
+    Archiver.write_entry(archiver_mock, test_submission)
diff --git a/tests/test_connector.py b/tests/test_connector.py
new file mode 100644
index 0000000..2249b96
--- /dev/null
+++ b/tests/test_connector.py
@@ -0,0 +1,402 @@
+#!/usr/bin/env python3
+# coding=utf-8
+
+from pathlib import Path
+from typing import Iterator
+from unittest.mock import MagicMock
+
+import praw
+import praw.models
+import pytest
+
+from bdfr.configuration import Configuration
+from bdfr.connector import RedditConnector, RedditTypes
+from bdfr.download_filter import DownloadFilter
+from bdfr.exceptions import BulkDownloaderException
+from bdfr.file_name_formatter import FileNameFormatter
+from bdfr.site_authenticator import SiteAuthenticator
+
+
+@pytest.fixture()
+def args() -> Configuration:
+    args = Configuration()
+    args.time_format = 'ISO'
+    return args
+
+
+@pytest.fixture()
+def downloader_mock(args: Configuration):
+    downloader_mock = MagicMock()
+    downloader_mock.args = args
+    downloader_mock.sanitise_subreddit_name = RedditConnector.sanitise_subreddit_name
+    downloader_mock.split_args_input = RedditConnector.split_args_input
+    downloader_mock.master_hash_list = {}
+    return downloader_mock
+
+
+def assert_all_results_are_submissions(result_limit: int, results: list[Iterator]) -> list:
+    results = [sub for res in results for sub in res]
+    assert all([isinstance(res, praw.models.Submission) for res in results])
+    if result_limit is not None:
+        assert len(results) == result_limit
+    return results
+
+
+def test_determine_directories(tmp_path: Path, downloader_mock: MagicMock):
+    downloader_mock.args.directory = tmp_path / 'test'
+    downloader_mock.config_directories.user_config_dir = tmp_path
+    RedditConnector.determine_directories(downloader_mock)
+    assert Path(tmp_path / 'test').exists()
+
+
+@pytest.mark.parametrize(('skip_extensions', 'skip_domains'), (
+    ([], []),
+    (['.test'], ['test.com'],),
+))
+def test_create_download_filter(skip_extensions: list[str], skip_domains: list[str], downloader_mock: MagicMock):
+    downloader_mock.args.skip = skip_extensions
+    downloader_mock.args.skip_domain = skip_domains
+    result = RedditConnector.create_download_filter(downloader_mock)
+
+    assert isinstance(result, DownloadFilter)
+    assert result.excluded_domains == skip_domains
+    assert result.excluded_extensions == skip_extensions
+
+
+@pytest.mark.parametrize(('test_time', 'expected'), (
+    ('all', 'all'),
+    ('hour', 'hour'),
+    ('day', 'day'),
+    ('week', 'week'),
+    ('random', 'all'),
+    ('', 'all'),
+))
+def test_create_time_filter(test_time: str, expected: str, downloader_mock: MagicMock):
+    downloader_mock.args.time = test_time
+    result = RedditConnector.create_time_filter(downloader_mock)
+
+    assert isinstance(result, RedditTypes.TimeType)
+    assert result.name.lower() == expected
+
+
+@pytest.mark.parametrize(('test_sort', 'expected'), (
+    ('', 'hot'),
+    ('hot', 'hot'),
+    ('controversial', 'controversial'),
+    ('new', 'new'),
+))
+def test_create_sort_filter(test_sort: str, expected: str, downloader_mock: MagicMock):
+    downloader_mock.args.sort = test_sort
+    result = RedditConnector.create_sort_filter(downloader_mock)
+
+    assert isinstance(result, RedditTypes.SortType)
+    assert result.name.lower() == expected
+
+
+@pytest.mark.parametrize(('test_file_scheme', 'test_folder_scheme'), (
+    ('{POSTID}', '{SUBREDDIT}'),
+    ('{REDDITOR}_{TITLE}_{POSTID}', '{SUBREDDIT}'),
+    ('{POSTID}', 'test'),
+    ('{POSTID}', ''),
+    ('{POSTID}', '{SUBREDDIT}/{REDDITOR}'),
+))
+def test_create_file_name_formatter(test_file_scheme: str, test_folder_scheme: str, downloader_mock: MagicMock):
+    downloader_mock.args.file_scheme = test_file_scheme
+    downloader_mock.args.folder_scheme = test_folder_scheme
+    result = RedditConnector.create_file_name_formatter(downloader_mock)
+
+    assert isinstance(result, FileNameFormatter)
+    assert result.file_format_string == test_file_scheme
+    assert result.directory_format_string == test_folder_scheme.split('/')
+
+
+@pytest.mark.parametrize(('test_file_scheme', 'test_folder_scheme'), (
+    ('', ''),
+    ('', '{SUBREDDIT}'),
+    ('test', '{SUBREDDIT}'),
+))
+def test_create_file_name_formatter_bad(test_file_scheme: str, test_folder_scheme: str, downloader_mock: MagicMock):
+    downloader_mock.args.file_scheme = test_file_scheme
+    downloader_mock.args.folder_scheme = test_folder_scheme
+    with pytest.raises(BulkDownloaderException):
+        RedditConnector.create_file_name_formatter(downloader_mock)
+
+
+def test_create_authenticator(downloader_mock: MagicMock):
+    result = RedditConnector.create_authenticator(downloader_mock)
+    assert isinstance(result, SiteAuthenticator)
+
+
+@pytest.mark.online
+@pytest.mark.reddit
+@pytest.mark.parametrize('test_submission_ids', (
+    ('lvpf4l',),
+    ('lvpf4l', 'lvqnsn'),
+    ('lvpf4l', 'lvqnsn', 'lvl9kd'),
+))
+def test_get_submissions_from_link(
+        test_submission_ids: list[str],
+        reddit_instance: praw.Reddit,
+        downloader_mock: MagicMock):
+    downloader_mock.args.link = test_submission_ids
+    downloader_mock.reddit_instance = reddit_instance
+    results = RedditConnector.get_submissions_from_link(downloader_mock)
+    assert all([isinstance(sub, praw.models.Submission) for res in results for sub in res])
+    assert len(results[0]) == len(test_submission_ids)
+
+
+@pytest.mark.online
+@pytest.mark.reddit
+@pytest.mark.parametrize(('test_subreddits', 'limit', 'sort_type', 'time_filter', 'max_expected_len'), (
+    (('Futurology',), 10, 'hot', 'all', 10),
+    (('Futurology', 'Mindustry, Python'), 10, 'hot', 'all', 30),
+    (('Futurology',), 20, 'hot', 'all', 20),
+    (('Futurology', 'Python'), 10, 'hot', 'all', 20),
+    (('Futurology',), 100, 'hot', 'all', 100),
+    (('Futurology',), 0, 'hot', 'all', 0),
+    (('Futurology',), 10, 'top', 'all', 10),
+    (('Futurology',), 10, 'top', 'week', 10),
+    (('Futurology',), 10, 'hot', 'week', 10),
+))
+def test_get_subreddit_normal(
+        test_subreddits: list[str],
+        limit: int,
+        sort_type: str,
+        time_filter: str,
+        max_expected_len: int,
+        downloader_mock: MagicMock,
+        reddit_instance: praw.Reddit,
+):
+    downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
+    downloader_mock.args.limit = limit
+    downloader_mock.args.sort = sort_type
+    downloader_mock.args.subreddit = test_subreddits
+    downloader_mock.reddit_instance = reddit_instance
+    downloader_mock.sort_filter = RedditConnector.create_sort_filter(downloader_mock)
+    results = RedditConnector.get_subreddits(downloader_mock)
+    test_subreddits = downloader_mock._split_args_input(test_subreddits)
+    results = [sub for res1 in results for sub in res1]
+    assert all([isinstance(res1, praw.models.Submission) for res1 in results])
+    assert all([res.subreddit.display_name in test_subreddits for res in results])
+    assert len(results) <= max_expected_len
+
+
+@pytest.mark.online
+@pytest.mark.reddit
+@pytest.mark.parametrize(('test_subreddits', 'search_term', 'limit', 'time_filter', 'max_expected_len'), (
+    (('Python',), 'scraper', 10, 'all', 10),
+    (('Python',), '', 10, 'all', 10),
+    (('Python',), 'djsdsgewef', 10, 'all', 0),
+    (('Python',), 'scraper', 10, 'year', 10),
+    (('Python',), 'scraper', 10, 'hour', 1),
+))
+def test_get_subreddit_search(
+        test_subreddits: list[str],
+        search_term: str,
+        time_filter: str,
+        limit: int,
+        max_expected_len: int,
+        downloader_mock: MagicMock,
+        reddit_instance: praw.Reddit,
+):
+    downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
+    downloader_mock.args.limit = limit
+    downloader_mock.args.search = search_term
+    downloader_mock.args.subreddit = test_subreddits
+    downloader_mock.reddit_instance = reddit_instance
+    downloader_mock.sort_filter = RedditTypes.SortType.HOT
+    downloader_mock.args.time = time_filter
+    downloader_mock.time_filter = RedditConnector.create_time_filter(downloader_mock)
+    results = RedditConnector.get_subreddits(downloader_mock)
+    results = [sub for res in results for sub in res]
+    assert all([isinstance(res, praw.models.Submission) for res in results])
+    assert all([res.subreddit.display_name in test_subreddits for res in results])
+    assert len(results) <= max_expected_len
+
+
+@pytest.mark.online
+@pytest.mark.reddit
+@pytest.mark.parametrize(('test_user', 'test_multireddits', 'limit'), (
+    ('helen_darten', ('cuteanimalpics',), 10),
+    ('korfor', ('chess',), 100),
+))
+# Good sources at https://www.reddit.com/r/multihub/
+def test_get_multireddits_public(
+        test_user: str,
+        test_multireddits: list[str],
+        limit: int,
+        reddit_instance: praw.Reddit,
+        downloader_mock: MagicMock,
+):
+    downloader_mock.determine_sort_function.return_value = praw.models.Subreddit.hot
+    downloader_mock.sort_filter = RedditTypes.SortType.HOT
+    downloader_mock.args.limit = limit
+    downloader_mock.args.multireddit = test_multireddits
+    downloader_mock.args.user = [test_user]
+    downloader_mock.reddit_instance = reddit_instance
+    downloader_mock.create_filtered_listing_generator.return_value = \
+        RedditConnector.create_filtered_listing_generator(
+            downloader_mock,
+            reddit_instance.multireddit(test_user, test_multireddits[0]),
+        )
+    results = RedditConnector.get_multireddits(downloader_mock)
+    results = [sub for res in results for sub in res]
+    assert all([isinstance(res, praw.models.Submission) for res in results])
+    assert len(results) == limit
+
+
+@pytest.mark.online
+@pytest.mark.reddit
+@pytest.mark.parametrize(('test_user', 'limit'), (
+    ('danigirl3694', 10),
+    ('danigirl3694', 50),
+    ('CapitanHam', None),
+))
+def test_get_user_submissions(test_user: str, limit: int, downloader_mock: MagicMock, reddit_instance: praw.Reddit):
+    downloader_mock.args.limit = limit
+    downloader_mock.determine_sort_function.return_value = praw.models.Subreddit.hot
+    downloader_mock.sort_filter = RedditTypes.SortType.HOT
+    downloader_mock.args.submitted = True
+    downloader_mock.args.user = [test_user]
+    downloader_mock.authenticated = False
+    downloader_mock.reddit_instance = reddit_instance
+    downloader_mock.create_filtered_listing_generator.return_value = \
+        RedditConnector.create_filtered_listing_generator(
+            downloader_mock,
+            reddit_instance.redditor(test_user).submissions,
+        )
+    results = RedditConnector.get_user_data(downloader_mock)
+    results = assert_all_results_are_submissions(limit, results)
+    assert all([res.author.name == test_user for res in results])
+
+
+@pytest.mark.online
+@pytest.mark.reddit
+@pytest.mark.authenticated
+@pytest.mark.parametrize('test_flag', (
+    'upvoted',
+    'saved',
+))
+def test_get_user_authenticated_lists(
+        test_flag: str,
+        downloader_mock: MagicMock,
+        authenticated_reddit_instance: praw.Reddit,
+):
+    downloader_mock.args.__dict__[test_flag] = True
+    downloader_mock.reddit_instance = authenticated_reddit_instance
+    downloader_mock.args.limit = 10
+    downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
+    downloader_mock.sort_filter = RedditTypes.SortType.HOT
+    downloader_mock.args.user = [RedditConnector.resolve_user_name(downloader_mock, 'me')]
+    results = RedditConnector.get_user_data(downloader_mock)
+    assert_all_results_are_submissions(10, results)
+
+
+@pytest.mark.parametrize(('test_name', 'expected'), (
+    ('Mindustry', 'Mindustry'),
+    ('Futurology', 'Futurology'),
+    ('r/Mindustry', 'Mindustry'),
+    ('TrollXChromosomes', 'TrollXChromosomes'),
+    ('r/TrollXChromosomes', 'TrollXChromosomes'),
+    ('https://www.reddit.com/r/TrollXChromosomes/', 'TrollXChromosomes'),
+    ('https://www.reddit.com/r/TrollXChromosomes', 'TrollXChromosomes'),
+    ('https://www.reddit.com/r/Futurology/', 'Futurology'),
+    ('https://www.reddit.com/r/Futurology', 'Futurology'),
+))
+def test_sanitise_subreddit_name(test_name: str, expected: str):
+    result = RedditConnector.sanitise_subreddit_name(test_name)
+    assert result == expected
+
+
+@pytest.mark.parametrize(('test_subreddit_entries', 'expected'), (
+    (['test1', 'test2', 'test3'], {'test1', 'test2', 'test3'}),
+    (['test1,test2', 'test3'], {'test1', 'test2', 'test3'}),
+    (['test1, test2', 'test3'], {'test1', 'test2', 'test3'}),
+    (['test1; test2', 'test3'], {'test1', 'test2', 'test3'}),
+    (['test1, test2', 'test1,test2,test3', 'test4'], {'test1', 'test2', 'test3', 'test4'}),
+    ([''], {''}),
+    (['test'], {'test'}),
+))
+def test_split_subreddit_entries(test_subreddit_entries: list[str], expected: set[str]):
+    results = RedditConnector.split_args_input(test_subreddit_entries)
+    assert results == expected
+
+
+def test_read_excluded_submission_ids_from_file(downloader_mock: MagicMock, tmp_path: Path):
+    test_file = tmp_path / 'test.txt'
+    test_file.write_text('aaaaaa\nbbbbbb')
+    downloader_mock.args.exclude_id_file = [test_file]
+    results = RedditConnector.read_excluded_ids(downloader_mock)
+    assert results == {'aaaaaa', 'bbbbbb'}
+
+
+@pytest.mark.online
+@pytest.mark.reddit
+@pytest.mark.parametrize('test_redditor_name', (
+    'Paracortex',
+    'crowdstrike',
+    'HannibalGoddamnit',
+))
+def test_check_user_existence_good(
+        test_redditor_name: str,
+        reddit_instance: praw.Reddit,
+        downloader_mock: MagicMock,
+):
+    downloader_mock.reddit_instance = reddit_instance
+    RedditConnector.check_user_existence(downloader_mock, test_redditor_name)
+
+
+@pytest.mark.online
+@pytest.mark.reddit
+@pytest.mark.parametrize('test_redditor_name', (
+    'lhnhfkuhwreolo',
+    'adlkfmnhglojh',
+))
+def test_check_user_existence_nonexistent(
+        test_redditor_name: str,
+        reddit_instance: praw.Reddit,
+        downloader_mock: MagicMock,
+):
+    downloader_mock.reddit_instance = reddit_instance
+    with pytest.raises(BulkDownloaderException, match='Could not find'):
+        RedditConnector.check_user_existence(downloader_mock, test_redditor_name)
+
+
+@pytest.mark.online
+@pytest.mark.reddit
+@pytest.mark.parametrize('test_redditor_name', (
+    'Bree-Boo',
+))
+def test_check_user_existence_banned(
+        test_redditor_name: str,
+        reddit_instance: praw.Reddit,
+        downloader_mock: MagicMock,
+):
+    downloader_mock.reddit_instance = reddit_instance
+    with pytest.raises(BulkDownloaderException, match='is banned'):
+        RedditConnector.check_user_existence(downloader_mock, test_redditor_name)
+
+
+@pytest.mark.online
+@pytest.mark.reddit
+@pytest.mark.parametrize(('test_subreddit_name', 'expected_message'), (
+    ('donaldtrump', 'cannot be found'),
+    ('submitters', 'private and cannot be scraped')
+))
+def test_check_subreddit_status_bad(test_subreddit_name: str, expected_message: str, reddit_instance: praw.Reddit):
+    test_subreddit = reddit_instance.subreddit(test_subreddit_name)
+    with pytest.raises(BulkDownloaderException, match=expected_message):
+        RedditConnector.check_subreddit_status(test_subreddit)
+
+
+@pytest.mark.online
+@pytest.mark.reddit
+@pytest.mark.parametrize('test_subreddit_name', (
+    'Python',
+    'Mindustry',
+    'TrollXChromosomes',
+    'all',
+))
+def test_check_subreddit_status_good(test_subreddit_name: str, reddit_instance: praw.Reddit):
+    test_subreddit = reddit_instance.subreddit(test_subreddit_name)
+    RedditConnector.check_subreddit_status(test_subreddit)
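A pattern worth noting throughout the new `test_connector.py`: class methods are exercised as unbound functions with a `MagicMock` standing in for `self`, so only the attributes a method actually touches need to be set up. A minimal illustration of the idiom with a toy class (not from BDFR):

```python
from unittest.mock import MagicMock


class Greeter:
    def greet(self):
        return f'hello {self.args.name}'


mock = MagicMock()
mock.args.name = 'bdfr'
# Call the unbound method with the mock as self; no real Greeter is ever constructed
assert Greeter.greet(mock) == 'hello bdfr'
```

This keeps each test focused on one method of `RedditConnector` without wiring up a full Reddit session, configuration, and filesystem state.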
diff --git a/tests/test_downloader.py b/tests/test_downloader.py
index f1a20fc..d67aee6 100644
--- a/tests/test_downloader.py
+++ b/tests/test_downloader.py
@@ -1,22 +1,19 @@
 #!/usr/bin/env python3
 # coding=utf-8
+import os
 import re
 from pathlib import Path
-from typing import Iterator
-from unittest.mock import MagicMock
+from unittest.mock import MagicMock, patch

-import praw
 import praw.models
 import pytest

+import bdfr.site_downloaders.download_factory
 from bdfr.__main__ import setup_logging
 from bdfr.configuration import Configuration
-from bdfr.download_filter import DownloadFilter
-from bdfr.downloader import RedditDownloader, RedditTypes
-from bdfr.exceptions import BulkDownloaderException
-from bdfr.file_name_formatter import FileNameFormatter
-from bdfr.site_authenticator import SiteAuthenticator
+from bdfr.connector import RedditConnector
+from bdfr.downloader import RedditDownloader


 @pytest.fixture()
@@ -30,314 +27,105 @@ def downloader_mock(args: Configuration):
     downloader_mock = MagicMock()
     downloader_mock.args = args
-    downloader_mock._sanitise_subreddit_name = RedditDownloader._sanitise_subreddit_name
-    downloader_mock._split_args_input = RedditDownloader._split_args_input
+    downloader_mock._sanitise_subreddit_name = RedditConnector.sanitise_subreddit_name
+    downloader_mock._split_args_input = RedditConnector.split_args_input
     downloader_mock.master_hash_list = {}
     return downloader_mock


-def assert_all_results_are_submissions(result_limit: int, results: list[Iterator]):
-    results = [sub for res in results for sub in res]
-    assert all([isinstance(res, praw.models.Submission) for res in results])
-    if result_limit is not None:
-        assert len(results) == result_limit
-    return results
-
-
-def test_determine_directories(tmp_path: Path, downloader_mock: MagicMock):
-    downloader_mock.args.directory = tmp_path / 'test'
-    downloader_mock.config_directories.user_config_dir = tmp_path
-    RedditDownloader._determine_directories(downloader_mock)
-    assert Path(tmp_path / 'test').exists()
-
-
-@pytest.mark.parametrize(('skip_extensions', 'skip_domains'), (
-    ([], []),
-    (['.test'], ['test.com'],),
+@pytest.mark.parametrize(('test_ids', 'test_excluded', 'expected_len'), (
+    (('aaaaaa',), (), 1),
+    (('aaaaaa',), ('aaaaaa',), 0),
+    ((), ('aaaaaa',), 0),
+    (('aaaaaa', 'bbbbbb'), ('aaaaaa',), 1),
+    (('aaaaaa', 'bbbbbb', 'cccccc'), ('aaaaaa',), 2),
 ))
-def test_create_download_filter(skip_extensions: list[str], skip_domains: list[str], downloader_mock: MagicMock):
-    downloader_mock.args.skip = skip_extensions
-    downloader_mock.args.skip_domain = skip_domains
-    result = RedditDownloader._create_download_filter(downloader_mock)
-
-    assert isinstance(result, DownloadFilter)
-    assert result.excluded_domains == skip_domains
-    assert result.excluded_extensions == skip_extensions
-
-
-@pytest.mark.parametrize(('test_time', 'expected'), (
-    ('all', 'all'),
-    ('hour', 'hour'),
-    ('day', 'day'),
-    ('week', 'week'),
-    ('random', 'all'),
-    ('', 'all'),
-))
-def test_create_time_filter(test_time: str, expected: str, downloader_mock: MagicMock):
-    downloader_mock.args.time = test_time
-    result = RedditDownloader._create_time_filter(downloader_mock)
-
-    assert isinstance(result, RedditTypes.TimeType)
-    assert result.name.lower() == expected
-
-
-@pytest.mark.parametrize(('test_sort', 'expected'), (
-    ('', 'hot'),
-    ('hot', 'hot'),
-    ('controversial', 'controversial'),
-    ('new', 'new'),
-))
-def test_create_sort_filter(test_sort: str, expected: str, downloader_mock: MagicMock):
-    downloader_mock.args.sort = test_sort
-    result = RedditDownloader._create_sort_filter(downloader_mock)
-
-    assert isinstance(result, RedditTypes.SortType)
-    assert result.name.lower() == expected
-
-
-@pytest.mark.parametrize(('test_file_scheme', 'test_folder_scheme'), (
-    ('{POSTID}', '{SUBREDDIT}'),
-    ('{REDDITOR}_{TITLE}_{POSTID}', '{SUBREDDIT}'),
-    ('{POSTID}', 'test'),
-    ('{POSTID}', ''),
-    ('{POSTID}', '{SUBREDDIT}/{REDDITOR}'),
-))
-def test_create_file_name_formatter(test_file_scheme: str, test_folder_scheme: str, downloader_mock: MagicMock):
-    downloader_mock.args.file_scheme = test_file_scheme
-    downloader_mock.args.folder_scheme = test_folder_scheme
-    result = RedditDownloader._create_file_name_formatter(downloader_mock)
-
-    assert isinstance(result, FileNameFormatter)
-    assert result.file_format_string == test_file_scheme
-    assert result.directory_format_string == test_folder_scheme.split('/')
-
-
-@pytest.mark.parametrize(('test_file_scheme', 'test_folder_scheme'), (
-    ('', ''),
-    ('', '{SUBREDDIT}'),
-    ('test', '{SUBREDDIT}'),
-))
-def test_create_file_name_formatter_bad(test_file_scheme: str, test_folder_scheme: str, downloader_mock: MagicMock):
-    downloader_mock.args.file_scheme = test_file_scheme
-    downloader_mock.args.folder_scheme = test_folder_scheme
-    with pytest.raises(BulkDownloaderException):
-        RedditDownloader._create_file_name_formatter(downloader_mock)
-
-
-def test_create_authenticator(downloader_mock: MagicMock):
-    result = RedditDownloader._create_authenticator(downloader_mock)
-    assert isinstance(result, SiteAuthenticator)
-
-
-@pytest.mark.online
-@pytest.mark.reddit
-@pytest.mark.parametrize('test_submission_ids', (
-    ('lvpf4l',),
-    ('lvpf4l', 'lvqnsn'),
-    ('lvpf4l', 'lvqnsn', 'lvl9kd'),
-))
-def test_get_submissions_from_link(
-        test_submission_ids: list[str],
-        reddit_instance: praw.Reddit,
-        downloader_mock: MagicMock):
-    downloader_mock.args.link = test_submission_ids
-    downloader_mock.reddit_instance = reddit_instance
-    results = RedditDownloader._get_submissions_from_link(downloader_mock)
-    assert all([isinstance(sub, praw.models.Submission) for res in results for sub in res])
-    assert len(results[0]) == len(test_submission_ids)
-
-
-@pytest.mark.online
-@pytest.mark.reddit
-@pytest.mark.parametrize(('test_subreddits', 'limit', 'sort_type', 'time_filter', 'max_expected_len'), (
-    (('Futurology',), 10, 'hot', 'all', 10),
-    (('Futurology', 'Mindustry, Python'), 10, 'hot', 'all', 30),
-    (('Futurology',), 20, 'hot', 'all', 20),
-    (('Futurology', 'Python'), 10, 'hot', 'all', 20),
-    (('Futurology',), 100, 'hot', 'all', 100),
-    (('Futurology',), 0, 'hot', 'all', 0),
-    (('Futurology',), 10, 'top', 'all', 10),
-    (('Futurology',), 10, 'top', 'week', 10),
-    (('Futurology',), 10, 'hot', 'week', 10),
-))
-def test_get_subreddit_normal(
-        test_subreddits: list[str],
-        limit: int,
-        sort_type: str,
-        time_filter: str,
-        max_expected_len: int,
-        downloader_mock: MagicMock,
-        reddit_instance: praw.Reddit,
-):
-    downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
-    downloader_mock.args.limit = limit
-    downloader_mock.args.sort = sort_type
-    downloader_mock.args.subreddit = test_subreddits
-    downloader_mock.reddit_instance = reddit_instance
-    downloader_mock.sort_filter = RedditDownloader._create_sort_filter(downloader_mock)
-    results = RedditDownloader._get_subreddits(downloader_mock)
-    test_subreddits = downloader_mock._split_args_input(test_subreddits)
-    results = [sub for res1 in results for sub in res1]
-    assert all([isinstance(res1, praw.models.Submission) for res1 in results])
-    assert all([res.subreddit.display_name in test_subreddits for res in results])
-    assert len(results) <= max_expected_len
-
-
-@pytest.mark.online
-@pytest.mark.reddit
-@pytest.mark.parametrize(('test_subreddits', 'search_term', 'limit', 'time_filter', 'max_expected_len'), (
-    (('Python',), 'scraper', 10, 'all', 10),
-    (('Python',), '', 10, 'all', 10),
-    (('Python',), 'djsdsgewef', 10, 'all', 0),
-    (('Python',), 'scraper', 10, 'year', 10),
-    (('Python',), 'scraper', 10, 'hour', 1),
-))
-def test_get_subreddit_search(
-        test_subreddits: list[str],
-        search_term: str,
-        time_filter: str,
-        limit: int,
-        max_expected_len: int,
-        downloader_mock: MagicMock,
-        reddit_instance: praw.Reddit,
-):
-    downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
-    downloader_mock.args.limit = limit
-    downloader_mock.args.search = search_term
-    downloader_mock.args.subreddit = test_subreddits
-    downloader_mock.reddit_instance = reddit_instance
-    downloader_mock.sort_filter = RedditTypes.SortType.HOT
-    downloader_mock.args.time = time_filter
-    downloader_mock.time_filter = RedditDownloader._create_time_filter(downloader_mock)
-    results = RedditDownloader._get_subreddits(downloader_mock)
-    results = [sub for res in results for sub in res]
-    assert all([isinstance(res, praw.models.Submission) for res in results])
-    assert all([res.subreddit.display_name in test_subreddits for res in results])
-    assert len(results) <= max_expected_len
-
-
-@pytest.mark.online
-@pytest.mark.reddit
-@pytest.mark.parametrize(('test_user', 'test_multireddits', 'limit'), (
-    ('helen_darten', ('cuteanimalpics',), 10),
-    ('korfor', ('chess',), 100),
-))
-# Good sources at https://www.reddit.com/r/multihub/
-def test_get_multireddits_public(
-        test_user: str,
-        test_multireddits: list[str],
-        limit: int,
-        reddit_instance: praw.Reddit,
+@patch('bdfr.site_downloaders.download_factory.DownloadFactory.pull_lever')
+def test_excluded_ids(
+        mock_function: MagicMock,
+        test_ids: tuple[str],
+        test_excluded: tuple[str],
+        expected_len: int,
         downloader_mock: MagicMock,
 ):
-    downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
-    downloader_mock.sort_filter = RedditTypes.SortType.HOT
-    downloader_mock.args.limit = limit
-    downloader_mock.args.multireddit = test_multireddits
-    downloader_mock.args.user = test_user
-    downloader_mock.reddit_instance = reddit_instance
-    downloader_mock._create_filtered_listing_generator.return_value = \
-        RedditDownloader._create_filtered_listing_generator(
-            downloader_mock,
-            reddit_instance.multireddit(test_user, test_multireddits[0]),
-        )
-    results = RedditDownloader._get_multireddits(downloader_mock)
-    results = [sub for res in results for sub in res]
-    assert all([isinstance(res, praw.models.Submission) for res in results])
-    assert len(results) == limit
+    downloader_mock.excluded_submission_ids = test_excluded
+    mock_function.return_value = MagicMock()
+    mock_function.return_value.__name__ = 'test'
+    test_submissions = []
+    for test_id in test_ids:
+        m = MagicMock()
+        m.id = test_id
+        m.subreddit.display_name.return_value = 'https://www.example.com/'
+        m.__class__ = praw.models.Submission
+        test_submissions.append(m)
+    downloader_mock.reddit_lists = [test_submissions]
+    for submission in test_submissions:
+        RedditDownloader._download_submission(downloader_mock, submission)
+    assert mock_function.call_count == expected_len


 @pytest.mark.online
 @pytest.mark.reddit
-@pytest.mark.parametrize(('test_user', 'limit'), (
-    ('danigirl3694', 10),
-    ('danigirl3694', 50),
-    ('CapitanHam', None),
+@pytest.mark.parametrize('test_submission_id', (
+    'm1hqw6',
 ))
-def test_get_user_submissions(test_user: str, limit: int, downloader_mock: MagicMock, reddit_instance: praw.Reddit):
-    downloader_mock.args.limit = limit
-    downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
-    downloader_mock.sort_filter = RedditTypes.SortType.HOT
-    downloader_mock.args.submitted = True
-    downloader_mock.args.user = test_user
-    downloader_mock.authenticated = False
-    downloader_mock.reddit_instance = reddit_instance
-    downloader_mock._create_filtered_listing_generator.return_value = \
-        RedditDownloader._create_filtered_listing_generator(
-            downloader_mock,
-            reddit_instance.redditor(test_user).submissions,
-        )
-    results = RedditDownloader._get_user_data(downloader_mock)
-    results = assert_all_results_are_submissions(limit, results)
-    assert all([res.author.name == test_user for res in results])
-
-
-@pytest.mark.online
-@pytest.mark.reddit
-@pytest.mark.authenticated
-@pytest.mark.parametrize('test_flag', (
-    'upvoted',
-    'saved',
-))
-def test_get_user_authenticated_lists(
-        test_flag: str,
-        downloader_mock: MagicMock,
-        authenticated_reddit_instance: praw.Reddit,
-):
-    downloader_mock.args.__dict__[test_flag] = True
-    downloader_mock.reddit_instance = authenticated_reddit_instance
-    downloader_mock.args.user = 'me'
-    downloader_mock.args.limit = 10
-    downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
-    downloader_mock.sort_filter = RedditTypes.SortType.HOT
-    RedditDownloader._resolve_user_name(downloader_mock)
-    results = RedditDownloader._get_user_data(downloader_mock)
-    assert_all_results_are_submissions(10, results)
-
-
-@pytest.mark.online
-@pytest.mark.reddit
-@pytest.mark.parametrize(('test_submission_id', 'expected_files_len'), (
-    ('ljyy27', 4),
-))
-def test_download_submission(
+def test_mark_hard_link(
         test_submission_id: str,
-        expected_files_len: int,
         downloader_mock: MagicMock,
-        reddit_instance: praw.Reddit,
-        tmp_path: Path):
+        tmp_path: Path,
+        reddit_instance: praw.Reddit
+):
     downloader_mock.reddit_instance = reddit_instance
-    downloader_mock.download_filter.check_url.return_value = True
-    downloader_mock.args.folder_scheme = ''
-    downloader_mock.file_name_formatter = RedditDownloader._create_file_name_formatter(downloader_mock)
+    downloader_mock.args.make_hard_links = True
     downloader_mock.download_directory = tmp_path
+    downloader_mock.args.folder_scheme = ''
+    downloader_mock.args.file_scheme = '{POSTID}'
+    downloader_mock.file_name_formatter = RedditConnector.create_file_name_formatter(downloader_mock)
     submission = downloader_mock.reddit_instance.submission(id=test_submission_id)
+    original = Path(tmp_path, f'{test_submission_id}.png')
+
     RedditDownloader._download_submission(downloader_mock, submission)
-    folder_contents = list(tmp_path.iterdir())
-    assert len(folder_contents) == expected_files_len
+    assert original.exists()
+
+    downloader_mock.args.file_scheme = 'test2_{POSTID}'
+    downloader_mock.file_name_formatter = RedditConnector.create_file_name_formatter(downloader_mock)
+    RedditDownloader._download_submission(downloader_mock, submission)
+    test_file_1_stats = original.stat()
+    test_file_2_inode = Path(tmp_path, f'test2_{test_submission_id}.png').stat().st_ino
+
+    assert test_file_1_stats.st_nlink == 2
+    assert test_file_1_stats.st_ino == test_file_2_inode


 @pytest.mark.online
 @pytest.mark.reddit
-def test_download_submission_file_exists(
+@pytest.mark.parametrize(('test_submission_id', 'test_creation_date'), (
+    ('ndzz50', 1621204841.0),
+))
+def test_file_creation_date(
+        test_submission_id: str,
+        test_creation_date: float,
         downloader_mock: MagicMock,
-        reddit_instance: praw.Reddit,
         tmp_path: Path,
-        capsys: pytest.CaptureFixture
+        reddit_instance: praw.Reddit
 ):
-    setup_logging(3)
     downloader_mock.reddit_instance = reddit_instance
-    downloader_mock.download_filter.check_url.return_value = True
-    downloader_mock.args.folder_scheme = ''
-    downloader_mock.file_name_formatter = RedditDownloader._create_file_name_formatter(downloader_mock)
     downloader_mock.download_directory = tmp_path
-    submission = downloader_mock.reddit_instance.submission(id='m1hqw6')
-    Path(tmp_path, 'Arneeman_Metagaming isn\'t always a bad thing_m1hqw6.png').touch()
+    downloader_mock.args.folder_scheme = ''
+    downloader_mock.args.file_scheme = '{POSTID}'
+    downloader_mock.file_name_formatter = RedditConnector.create_file_name_formatter(downloader_mock)
+    submission = downloader_mock.reddit_instance.submission(id=test_submission_id)
+
     RedditDownloader._download_submission(downloader_mock, submission)
-    folder_contents = list(tmp_path.iterdir())
-    output = capsys.readouterr()
-    assert len(folder_contents) == 1
-    assert 'Arneeman_Metagaming isn\'t always a bad thing_m1hqw6.png already exists' in output.out
+
+    for file_path in Path(tmp_path).iterdir():
+        file_stats = os.stat(file_path)
+        assert file_stats.st_mtime == test_creation_date
+
+
+def test_search_existing_files():
+    results = RedditDownloader.scan_existing_files(Path('.'))
+    assert len(results.keys()) != 0


 @pytest.mark.online
@@ -358,7 +146,7 @@ def test_download_submission_hash_exists(
     downloader_mock.download_filter.check_url.return_value = True
     downloader_mock.args.folder_scheme = ''
     downloader_mock.args.no_dupes = True
-    downloader_mock.file_name_formatter = RedditDownloader._create_file_name_formatter(downloader_mock)
+    downloader_mock.file_name_formatter = RedditConnector.create_file_name_formatter(downloader_mock)
     downloader_mock.download_directory = tmp_path
     downloader_mock.master_hash_list = {test_hash: None}
     submission = downloader_mock.reddit_instance.submission(id=test_submission_id)
@@ -369,165 +157,47 @@ def test_download_submission_hash_exists(
     assert re.search(r'Resource hash .*? downloaded elsewhere', output.out)


-@pytest.mark.parametrize(('test_name', 'expected'), (
-    ('Mindustry', 'Mindustry'),
-    ('Futurology', 'Futurology'),
-    ('r/Mindustry', 'Mindustry'),
-    ('TrollXChromosomes', 'TrollXChromosomes'),
-    ('r/TrollXChromosomes', 'TrollXChromosomes'),
-    ('https://www.reddit.com/r/TrollXChromosomes/', 'TrollXChromosomes'),
-    ('https://www.reddit.com/r/TrollXChromosomes', 'TrollXChromosomes'),
-    ('https://www.reddit.com/r/Futurology/', 'Futurology'),
-    ('https://www.reddit.com/r/Futurology', 'Futurology'),
-))
-def test_sanitise_subreddit_name(test_name: str, expected: str):
-    result = RedditDownloader._sanitise_subreddit_name(test_name)
-    assert result == expected
-
-
-def test_search_existing_files():
-    results = RedditDownloader.scan_existing_files(Path('.'))
-    assert len(results.keys()) >= 40
-
-
-@pytest.mark.parametrize(('test_subreddit_entries', 'expected'), (
-    (['test1', 'test2', 'test3'], {'test1', 'test2', 'test3'}),
-    (['test1,test2', 'test3'], {'test1', 'test2', 'test3'}),
-    (['test1, test2', 'test3'], {'test1', 'test2', 'test3'}),
-    (['test1; test2', 'test3'], {'test1', 'test2', 'test3'}),
-    (['test1, test2', 'test1,test2,test3', 'test4'], {'test1', 'test2', 'test3', 'test4'})
-))
-def test_split_subreddit_entries(test_subreddit_entries: list[str], expected: set[str]):
-    results = RedditDownloader._split_args_input(test_subreddit_entries)
-    assert results == expected
-
-
 @pytest.mark.online
 @pytest.mark.reddit
-@pytest.mark.parametrize('test_submission_id', (
-    'm1hqw6',
-))
-def test_mark_hard_link(
-        test_submission_id: str,
+def test_download_submission_file_exists(
         downloader_mock: MagicMock,
+        reddit_instance: praw.Reddit,
         tmp_path: Path,
-        reddit_instance: praw.Reddit
+        capsys: pytest.CaptureFixture
 ):
+    setup_logging(3)
     downloader_mock.reddit_instance = reddit_instance
-    downloader_mock.args.make_hard_links = True
-    downloader_mock.download_directory = tmp_path
+    downloader_mock.download_filter.check_url.return_value = True
     downloader_mock.args.folder_scheme = ''
-    downloader_mock.args.file_scheme = '{POSTID}'
-    downloader_mock.file_name_formatter = RedditDownloader._create_file_name_formatter(downloader_mock)
+    downloader_mock.file_name_formatter = RedditConnector.create_file_name_formatter(downloader_mock)
+    downloader_mock.download_directory = tmp_path
+    submission = downloader_mock.reddit_instance.submission(id='m1hqw6')
+    Path(tmp_path, 'Arneeman_Metagaming isn\'t always a bad thing_m1hqw6.png').touch()
+    RedditDownloader._download_submission(downloader_mock, submission)
+    folder_contents = list(tmp_path.iterdir())
+    output = capsys.readouterr()
+    assert len(folder_contents) == 1
+    assert 'Arneeman_Metagaming isn\'t always a bad thing_m1hqw6.png'\
+           ' from submission m1hqw6 already exists' in output.out
+
+
+@pytest.mark.online
+@pytest.mark.reddit
+@pytest.mark.parametrize(('test_submission_id', 'expected_files_len'), (
+    ('ljyy27', 4),
+))
+def test_download_submission(
+        test_submission_id: str,
+        expected_files_len: int,
+        downloader_mock: MagicMock,
+        reddit_instance: praw.Reddit,
+        tmp_path: Path):
+    downloader_mock.reddit_instance = reddit_instance
+    downloader_mock.download_filter.check_url.return_value = True
+    downloader_mock.args.folder_scheme = ''
+    downloader_mock.file_name_formatter = RedditConnector.create_file_name_formatter(downloader_mock)
+    downloader_mock.download_directory = tmp_path
     submission = downloader_mock.reddit_instance.submission(id=test_submission_id)
-    original = Path(tmp_path, f'{test_submission_id}.png')
-
     RedditDownloader._download_submission(downloader_mock, submission)
-    assert original.exists()
-
-    downloader_mock.args.file_scheme = 'test2_{POSTID}'
-    downloader_mock.file_name_formatter = RedditDownloader._create_file_name_formatter(downloader_mock)
-    RedditDownloader._download_submission(downloader_mock, submission)
-    test_file_1_stats = original.stat()
-    test_file_2_inode = Path(tmp_path, f'test2_{test_submission_id}.png').stat().st_ino
-
-    assert test_file_1_stats.st_nlink == 2
-    assert test_file_1_stats.st_ino == test_file_2_inode
-
-
-@pytest.mark.parametrize(('test_ids', 'test_excluded', 'expected_len'), (
-    (('aaaaaa',), (), 1),
-    (('aaaaaa',), ('aaaaaa',), 0),
-    ((), ('aaaaaa',), 0),
-    (('aaaaaa', 'bbbbbb'), ('aaaaaa',), 1),
-))
-def test_excluded_ids(test_ids: tuple[str], test_excluded: tuple[str], expected_len: int, downloader_mock: MagicMock):
-    downloader_mock.excluded_submission_ids = test_excluded
-    test_submissions = []
-    for test_id in test_ids:
-        m = MagicMock()
-        m.id = test_id
-        test_submissions.append(m)
-    downloader_mock.reddit_lists = [test_submissions]
-    RedditDownloader.download(downloader_mock)
-    assert downloader_mock._download_submission.call_count == expected_len
-
-
-def test_read_excluded_submission_ids_from_file(downloader_mock: MagicMock, tmp_path: Path):
-    test_file = tmp_path / 'test.txt'
-    test_file.write_text('aaaaaa\nbbbbbb')
-    downloader_mock.args.exclude_id_file = [test_file]
-    results = RedditDownloader._read_excluded_ids(downloader_mock)
-    assert results == {'aaaaaa', 'bbbbbb'}
-
-
-@pytest.mark.online
-@pytest.mark.reddit
-@pytest.mark.parametrize('test_redditor_name', (
-    'Paracortex',
-    'crowdstrike',
-    'HannibalGoddamnit',
-))
-def test_check_user_existence_good(
-        test_redditor_name: str,
-        reddit_instance: praw.Reddit,
-        downloader_mock: MagicMock,
-):
-    downloader_mock.reddit_instance = reddit_instance
-    RedditDownloader._check_user_existence(downloader_mock, test_redditor_name)
-
-
-@pytest.mark.online
-@pytest.mark.reddit
-@pytest.mark.parametrize('test_redditor_name', (
-    'lhnhfkuhwreolo',
-    'adlkfmnhglojh',
-))
-def test_check_user_existence_nonexistent(
-        test_redditor_name: str,
-        reddit_instance: praw.Reddit,
-        downloader_mock: MagicMock,
-):
-    downloader_mock.reddit_instance = reddit_instance
-    with pytest.raises(BulkDownloaderException, match='Could not find'):
-        RedditDownloader._check_user_existence(downloader_mock, test_redditor_name)
-
-
-@pytest.mark.online
-@pytest.mark.reddit
-@pytest.mark.parametrize('test_redditor_name', (
-    'Bree-Boo',
-))
-def test_check_user_existence_banned(
-        test_redditor_name: str,
-        reddit_instance: praw.Reddit,
-        downloader_mock: MagicMock,
-):
-    downloader_mock.reddit_instance = reddit_instance
-    with pytest.raises(BulkDownloaderException, match='is banned'):
-        RedditDownloader._check_user_existence(downloader_mock, test_redditor_name)
-
-
-@pytest.mark.online
-@pytest.mark.reddit
-@pytest.mark.parametrize(('test_subreddit_name', 'expected_message'), (
-    ('donaldtrump', 'cannot be found'),
-    ('submitters', 'private and cannot be scraped')
-))
-def test_check_subreddit_status_bad(test_subreddit_name: str, expected_message: str, reddit_instance: praw.Reddit):
-    test_subreddit = reddit_instance.subreddit(test_subreddit_name)
-    with pytest.raises(BulkDownloaderException, match=expected_message):
-        RedditDownloader._check_subreddit_status(test_subreddit)
-
-
-@pytest.mark.online
-@pytest.mark.reddit
-@pytest.mark.parametrize('test_subreddit_name', (
-    'Python',
-    'Mindustry',
-    'TrollXChromosomes',
-    'all',
-))
-def test_check_subreddit_status_good(test_subreddit_name: str, reddit_instance: praw.Reddit):
-    test_subreddit = reddit_instance.subreddit(test_subreddit_name)
-    RedditDownloader._check_subreddit_status(test_subreddit)
+    folder_contents = list(tmp_path.iterdir())
+    assert len(folder_contents) == expected_files_len
diff --git a/tests/test_file_name_formatter.py b/tests/test_file_name_formatter.py
index b1faf86..e4c82ac 100644
--- a/tests/test_file_name_formatter.py
+++ b/tests/test_file_name_formatter.py
@@ -1,11 +1,12 @@
 #!/usr/bin/env python3
 # coding=utf-8
+import platform
+import unittest.mock
 from datetime import datetime
 from pathlib import Path
 from typing import Optional
 from unittest.mock import MagicMock

-import platform
 import praw.models
 import pytest
@@ -28,10 +29,10 @@ def submission() -> MagicMock:
     return test


-def do_test_string_equality(result: str, expected: str) -> bool:
+def do_test_string_equality(result: [Path, str], expected: str) -> bool:
     if platform.system() == 'Windows':
         expected = FileNameFormatter._format_for_windows(expected)
-    return expected == result
+    return str(result).endswith(expected)


 def do_test_path_equality(result: Path, expected: str) -> bool:
@@ -41,7 +42,7 @@ def do_test_path_equality(result: Path, expected: str) -> bool:
         expected = Path(*expected)
     else:
         expected = Path(expected)
-    return result == expected
+    return str(result).endswith(str(expected))


 @pytest.fixture(scope='session')
@@ -172,8 +173,9 @@ def test_format_multiple_resources():
         mocks.append(new_mock)
     test_formatter = FileNameFormatter('{TITLE}', '', 'ISO')
     results = test_formatter.format_resource_paths(mocks, Path('.'))
-    results = set([str(res[0]) for res in results])
-    assert results == {'test_1.png', 'test_2.png', 'test_3.png', 'test_4.png'}
+    results = set([str(res[0].name) for res in results])
+    expected = {'test_1.png', 'test_2.png', 'test_3.png', 'test_4.png'}
+    assert results == expected


 @pytest.mark.parametrize(('test_filename', 'test_ending'), (
@@ -183,10 +185,11 @@ def test_format_multiple_resources():
     ('😍💕✨' * 100, '_1.png'),
 ))
 def test_limit_filename_length(test_filename: str, test_ending: str):
-    result = FileNameFormatter._limit_file_name_length(test_filename, test_ending)
-    assert len(result) <= 255
-    assert len(result.encode('utf-8')) <= 255
-    assert isinstance(result, str)
+    result = FileNameFormatter._limit_file_name_length(test_filename, test_ending, Path('.'))
+    assert len(result.name) <= 255
+    assert len(result.name.encode('utf-8')) <= 255
+    assert len(str(result)) <= FileNameFormatter.find_max_path_length()
+    assert isinstance(result, Path)


 @pytest.mark.parametrize(('test_filename', 'test_ending', 'expected_end'), (
@@ -201,11 +204,11 @@ def test_limit_filename_length(test_filename: str, test_ending: str):
     ('😍💕✨' * 100 + '_aaa1aa', '_1.png', '_aaa1aa_1.png'),
 ))
 def test_preserve_id_append_when_shortening(test_filename: str, test_ending: str, expected_end: str):
-    result = FileNameFormatter._limit_file_name_length(test_filename, test_ending)
-    assert len(result) <= 255
-    assert len(result.encode('utf-8')) <= 255
-    assert isinstance(result, str)
-    assert result.endswith(expected_end)
+    result = FileNameFormatter._limit_file_name_length(test_filename, test_ending, Path('.'))
+    assert len(result.name) <= 255
+    assert len(result.name.encode('utf-8')) <= 255
+    assert result.name.endswith(expected_end)
+    assert len(str(result)) <= FileNameFormatter.find_max_path_length()


 def test_shorten_filenames(submission: MagicMock, tmp_path: Path):
@@ -295,7 +298,7 @@ def test_format_archive_entry_comment(
     test_formatter = FileNameFormatter(test_file_scheme, test_folder_scheme, 'ISO')
     test_entry = Resource(test_comment, '', '.json')
     result = test_formatter.format_path(test_entry, tmp_path)
-    assert do_test_string_equality(result.name, expected_name)
+    assert do_test_string_equality(result, expected_name)


 @pytest.mark.parametrize(('test_folder_scheme', 'expected'), (
@@ -364,3 +367,16 @@ def test_time_string_formats(test_time_format: str, expected: str):
     test_formatter = FileNameFormatter('{TITLE}', '', test_time_format)
     result = test_formatter._convert_timestamp(test_time.timestamp())
     assert result == expected
+
+
+def test_get_max_path_length():
+    result = FileNameFormatter.find_max_path_length()
+    assert result in (4096, 260, 1024)
+
+
+def test_windows_max_path(tmp_path: Path):
+    with unittest.mock.patch('platform.system', return_value='Windows'):
+        with unittest.mock.patch('bdfr.file_name_formatter.FileNameFormatter.find_max_path_length', return_value=260):
+            result = FileNameFormatter._limit_file_name_length('test' * 100, '_1.png', tmp_path)
+            assert len(str(result)) <= 260
+            assert len(result.name) <= (260 - len(str(tmp_path)))
diff --git a/tests/test_integration.py b/tests/test_integration.py
index 7aec0eb..6a9e52b 100644
--- a/tests/test_integration.py
+++ b/tests/test_integration.py
@@ -33,6 +33,17 @@ def create_basic_args_for_archive_runner(test_args: list[str], tmp_path: Path):
     return out


+def create_basic_args_for_cloner_runner(test_args: list[str], tmp_path: Path):
+    out = [
+        'clone',
+        str(tmp_path),
+        '-v',
+        '--config', 'test_config.cfg',
+        '--log', str(Path(tmp_path, 'test_log.txt')),
+    ] + test_args
+    return out
+
+
 @pytest.mark.online
 @pytest.mark.reddit
 @pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests')
@@ -117,6 +128,7 @@ def test_cli_download_multireddit_nonexistent(test_args: list[str], tmp_path: Pa
 @pytest.mark.authenticated
 @pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests')
 @pytest.mark.parametrize('test_args', (
+    ['--user', 'djnish', '--submitted', '--user', 'FriesWithThat', '-L', 10],
     ['--user', 'me', '--upvoted', '--authenticate', '-L', 10],
     ['--user', 'me', '--saved', '--authenticate', '-L', 10],
     ['--user', 'me', '--submitted', '--authenticate', '-L', 10],
@@ -231,6 +243,7 @@ def test_cli_archive_subreddit(test_args: list[str], tmp_path: Path):
 @pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests')
 @pytest.mark.parametrize('test_args', (
     ['--user', 'me', '--authenticate', '--all-comments', '-L', '10'],
+    ['--user', 'me', '--user', 'djnish', '--authenticate', '--all-comments', '-L', '10'],
 ))
 def test_cli_archive_all_user_comments(test_args: list[str], tmp_path: Path):
     runner = CliRunner()
@@ -265,12 +278,14 @@ def test_cli_archive_long(test_args: list[str], tmp_path: Path):
     ['--user', 'sdclhgsolgjeroij', '--upvoted', '-L', 10],
     ['--subreddit', 'submitters', '-L', 10],  # Private subreddit
     ['--subreddit', 'donaldtrump', '-L', 10],  # Banned subreddit
+    ['--user', 'djnish', '--user', 'helen_darten', '-m', 'cuteanimalpics', '-L', 10],
 ))
 def test_cli_download_soft_fail(test_args: list[str], tmp_path: Path):
     runner = CliRunner()
     test_args = create_basic_args_for_download_runner(test_args, tmp_path)
     result = runner.invoke(cli, test_args)
     assert result.exit_code == 0
+    assert 'Downloaded' not in result.output


 @pytest.mark.online
@@ -333,9 +348,42 @@ def test_cli_download_subreddit_exclusion(test_args: list[str], tmp_path: Path):
     ['--file-scheme', '{TITLE}'],
     ['--file-scheme', '{TITLE}_test_{SUBREDDIT}'],
 ))
-def test_cli_file_scheme_warning(test_args: list[str], tmp_path: Path):
+def test_cli_download_file_scheme_warning(test_args: list[str], tmp_path: Path):
     runner = CliRunner()
     test_args = create_basic_args_for_download_runner(test_args, tmp_path)
     result = runner.invoke(cli, test_args)
     assert result.exit_code == 0
     assert 'Some files might not be downloaded due to name conflicts' in result.output
+
+
+@pytest.mark.online
+@pytest.mark.reddit
+@pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests')
+@pytest.mark.parametrize('test_args', (
+    ['-l', 'm2601g', '--disable-module', 'Direct'],
+    ['-l', 'nnb9vs', '--disable-module', 'YoutubeDlFallback'],
+    ['-l', 'nnb9vs', '--disable-module', 'youtubedlfallback'],
+))
+def test_cli_download_disable_modules(test_args: list[str], tmp_path: Path):
+    runner = CliRunner()
+    test_args = create_basic_args_for_download_runner(test_args, tmp_path)
+    result = runner.invoke(cli, test_args)
+    assert result.exit_code == 0
+    assert 'skipped due to disabled module' in result.output
+    assert 'Downloaded submission' not in result.output
+
+
+@pytest.mark.online
+@pytest.mark.reddit
+@pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests')
+@pytest.mark.parametrize('test_args', (
+    ['-l', 'm2601g'],
+    ['-s', 'TrollXChromosomes/', '-L', 1],
+))
+def test_cli_scrape_general(test_args: list[str], tmp_path: Path):
+    runner = CliRunner()
+    test_args = create_basic_args_for_cloner_runner(test_args, tmp_path)
+    result = runner.invoke(cli, test_args)
+    assert result.exit_code == 0
+    assert 'Downloaded submission' in result.output
+    assert 'Record for entry item' in result.output
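The new `create_basic_args_for_cloner_runner` helper mirrors the existing download and archive helpers so the `clone` command gets the same integration coverage. Outside pytest, the same invocation can be sketched with Click's test runner; this assumes a valid `test_config.cfg`, network access, and that `cli` is exported from `bdfr.__main__` as the tests imply, so treat it as illustrative:

```python
from click.testing import CliRunner

from bdfr.__main__ import cli

runner = CliRunner()
# Clone a single submission: download the linked media and archive the metadata in one pass
result = runner.invoke(
    cli, ['clone', './output', '-v', '--config', 'test_config.cfg', '-l', 'm2601g'])
print(result.exit_code)
print(result.output)  # expect both 'Downloaded submission' and 'Record for entry item' lines
```

The two assertions in `test_cli_scrape_general` check exactly those two log lines, one from the downloader half and one from the archiver half, confirming that `clone` really performs both functions in a single run.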