diff --git a/bdfr/archiver.py b/bdfr/archiver.py index 1945dfe..3e0b907 100644 --- a/bdfr/archiver.py +++ b/bdfr/archiver.py @@ -15,13 +15,14 @@ from bdfr.archive_entry.comment_archive_entry import CommentArchiveEntry from bdfr.archive_entry.submission_archive_entry import SubmissionArchiveEntry from bdfr.configuration import Configuration from bdfr.downloader import RedditDownloader +from bdfr.connector import RedditConnector from bdfr.exceptions import ArchiverError from bdfr.resource import Resource logger = logging.getLogger(__name__) -class Archiver(RedditDownloader): +class Archiver(RedditConnector): def __init__(self, args: Configuration): super(Archiver, self).__init__(args) @@ -29,9 +30,9 @@ class Archiver(RedditDownloader): for generator in self.reddit_lists: for submission in generator: logger.debug(f'Attempting to archive submission {submission.id}') - self._write_entry(submission) + self.write_entry(submission) - def _get_submissions_from_link(self) -> list[list[praw.models.Submission]]: + def get_submissions_from_link(self) -> list[list[praw.models.Submission]]: supplied_submissions = [] for sub_id in self.args.link: if len(sub_id) == 6: @@ -42,10 +43,10 @@ class Archiver(RedditDownloader): supplied_submissions.append(self.reddit_instance.submission(url=sub_id)) return [supplied_submissions] - def _get_user_data(self) -> list[Iterator]: - results = super(Archiver, self)._get_user_data() + def get_user_data(self) -> list[Iterator]: + results = super(Archiver, self).get_user_data() if self.args.user and self.args.all_comments: - sort = self._determine_sort_function() + sort = self.determine_sort_function() logger.debug(f'Retrieving comments of user {self.args.user}') results.append(sort(self.reddit_instance.redditor(self.args.user).comments, limit=self.args.limit)) return results @@ -59,7 +60,7 @@ class Archiver(RedditDownloader): else: raise ArchiverError(f'Factory failed to classify item of type {type(praw_item).__name__}') - def _write_entry(self, praw_item: (praw.models.Submission, praw.models.Comment)): + def write_entry(self, praw_item: (praw.models.Submission, praw.models.Comment)): archive_entry = self._pull_lever_entry_factory(praw_item) if self.args.format == 'json': self._write_entry_json(archive_entry) diff --git a/bdfr/connector.py b/bdfr/connector.py new file mode 100644 index 0000000..3dcc118 --- /dev/null +++ b/bdfr/connector.py @@ -0,0 +1,401 @@ +#!/usr/bin/env python3 +# coding=utf-8 + +import configparser +import importlib.resources +import logging +import logging.handlers +import re +import shutil +import socket +from abc import ABCMeta, abstractmethod +from datetime import datetime +from enum import Enum, auto +from pathlib import Path +from typing import Callable, Iterator + +import appdirs +import praw +import praw.exceptions +import praw.models +import prawcore + +from bdfr import exceptions as errors +from bdfr.configuration import Configuration +from bdfr.download_filter import DownloadFilter +from bdfr.file_name_formatter import FileNameFormatter +from bdfr.oauth2 import OAuth2Authenticator, OAuth2TokenManager +from bdfr.site_authenticator import SiteAuthenticator + +logger = logging.getLogger(__name__) + + +class RedditTypes: + class SortType(Enum): + CONTROVERSIAL = auto() + HOT = auto() + NEW = auto() + RELEVENCE = auto() + RISING = auto() + TOP = auto() + + class TimeType(Enum): + ALL = 'all' + DAY = 'day' + HOUR = 'hour' + MONTH = 'month' + WEEK = 'week' + YEAR = 'year' + + +class RedditConnector(metaclass=ABCMeta): + def __init__(self, args: Configuration): + self.args = args + self.config_directories = appdirs.AppDirs('bdfr', 'BDFR') + self.run_time = datetime.now().isoformat() + self._setup_internal_objects() + + self.reddit_lists = self.retrieve_reddit_lists() + + def _setup_internal_objects(self): + self.determine_directories() + self.load_config() + self.create_file_logger() + + self.read_config() + + self.download_filter = self.create_download_filter() + logger.log(9, 'Created download filter') + self.time_filter = self.create_time_filter() + logger.log(9, 'Created time filter') + self.sort_filter = self.create_sort_filter() + logger.log(9, 'Created sort filter') + self.file_name_formatter = self.create_file_name_formatter() + logger.log(9, 'Create file name formatter') + + self.create_reddit_instance() + self.resolve_user_name() + + self.excluded_submission_ids = self.read_excluded_ids() + + self.master_hash_list = {} + self.authenticator = self.create_authenticator() + logger.log(9, 'Created site authenticator') + + self.args.skip_subreddit = self.split_args_input(self.args.skip_subreddit) + self.args.skip_subreddit = set([sub.lower() for sub in self.args.skip_subreddit]) + + def read_config(self): + """Read any cfg values that need to be processed""" + if self.args.max_wait_time is None: + if not self.cfg_parser.has_option('DEFAULT', 'max_wait_time'): + self.cfg_parser.set('DEFAULT', 'max_wait_time', '120') + logger.log(9, 'Wrote default download wait time download to config file') + self.args.max_wait_time = self.cfg_parser.getint('DEFAULT', 'max_wait_time') + logger.debug(f'Setting maximum download wait time to {self.args.max_wait_time} seconds') + if self.args.time_format is None: + option = self.cfg_parser.get('DEFAULT', 'time_format', fallback='ISO') + if re.match(r'^[ \'\"]*$', option): + option = 'ISO' + logger.debug(f'Setting datetime format string to {option}') + self.args.time_format = option + # Update config on disk + with open(self.config_location, 'w') as file: + self.cfg_parser.write(file) + + def create_reddit_instance(self): + if self.args.authenticate: + logger.debug('Using authenticated Reddit instance') + if not self.cfg_parser.has_option('DEFAULT', 'user_token'): + logger.log(9, 'Commencing OAuth2 authentication') + scopes = self.cfg_parser.get('DEFAULT', 'scopes') + scopes = OAuth2Authenticator.split_scopes(scopes) + oauth2_authenticator = OAuth2Authenticator( + scopes, + self.cfg_parser.get('DEFAULT', 'client_id'), + self.cfg_parser.get('DEFAULT', 'client_secret'), + ) + token = oauth2_authenticator.retrieve_new_token() + self.cfg_parser['DEFAULT']['user_token'] = token + with open(self.config_location, 'w') as file: + self.cfg_parser.write(file, True) + token_manager = OAuth2TokenManager(self.cfg_parser, self.config_location) + + self.authenticated = True + self.reddit_instance = praw.Reddit( + client_id=self.cfg_parser.get('DEFAULT', 'client_id'), + client_secret=self.cfg_parser.get('DEFAULT', 'client_secret'), + user_agent=socket.gethostname(), + token_manager=token_manager, + ) + else: + logger.debug('Using unauthenticated Reddit instance') + self.authenticated = False + self.reddit_instance = praw.Reddit( + client_id=self.cfg_parser.get('DEFAULT', 'client_id'), + client_secret=self.cfg_parser.get('DEFAULT', 'client_secret'), + user_agent=socket.gethostname(), + ) + + def retrieve_reddit_lists(self) -> list[praw.models.ListingGenerator]: + master_list = [] + master_list.extend(self.get_subreddits()) + logger.log(9, 'Retrieved subreddits') + master_list.extend(self.get_multireddits()) + logger.log(9, 'Retrieved multireddits') + master_list.extend(self.get_user_data()) + logger.log(9, 'Retrieved user data') + master_list.extend(self.get_submissions_from_link()) + logger.log(9, 'Retrieved submissions for given links') + return master_list + + def determine_directories(self): + self.download_directory = Path(self.args.directory).resolve().expanduser() + self.config_directory = Path(self.config_directories.user_config_dir) + + self.download_directory.mkdir(exist_ok=True, parents=True) + self.config_directory.mkdir(exist_ok=True, parents=True) + + def load_config(self): + self.cfg_parser = configparser.ConfigParser() + if self.args.config: + if (cfg_path := Path(self.args.config)).exists(): + self.cfg_parser.read(cfg_path) + self.config_location = cfg_path + return + possible_paths = [ + Path('./config.cfg'), + Path('./default_config.cfg'), + Path(self.config_directory, 'config.cfg'), + Path(self.config_directory, 'default_config.cfg'), + ] + self.config_location = None + for path in possible_paths: + if path.resolve().expanduser().exists(): + self.config_location = path + logger.debug(f'Loading configuration from {path}') + break + if not self.config_location: + self.config_location = list(importlib.resources.path('bdfr', 'default_config.cfg').gen)[0] + shutil.copy(self.config_location, Path(self.config_directory, 'default_config.cfg')) + if not self.config_location: + raise errors.BulkDownloaderException('Could not find a configuration file to load') + self.cfg_parser.read(self.config_location) + + def create_file_logger(self): + main_logger = logging.getLogger() + if self.args.log is None: + log_path = Path(self.config_directory, 'log_output.txt') + else: + log_path = Path(self.args.log).resolve().expanduser() + if not log_path.parent.exists(): + raise errors.BulkDownloaderException(f'Designated location for logfile does not exist') + backup_count = self.cfg_parser.getint('DEFAULT', 'backup_log_count', fallback=3) + file_handler = logging.handlers.RotatingFileHandler( + log_path, + mode='a', + backupCount=backup_count, + ) + if log_path.exists(): + try: + file_handler.doRollover() + except PermissionError as e: + logger.critical( + 'Cannot rollover logfile, make sure this is the only ' + 'BDFR process or specify alternate logfile location') + raise + formatter = logging.Formatter('[%(asctime)s - %(name)s - %(levelname)s] - %(message)s') + file_handler.setFormatter(formatter) + file_handler.setLevel(0) + + main_logger.addHandler(file_handler) + + @staticmethod + def sanitise_subreddit_name(subreddit: str) -> str: + pattern = re.compile(r'^(?:https://www\.reddit\.com/)?(?:r/)?(.*?)/?$') + match = re.match(pattern, subreddit) + if not match: + raise errors.BulkDownloaderException(f'Could not find subreddit name in string {subreddit}') + return match.group(1) + + @staticmethod + def split_args_input(entries: list[str]) -> set[str]: + all_entries = [] + split_pattern = re.compile(r'[,;]\s?') + for entry in entries: + results = re.split(split_pattern, entry) + all_entries.extend([RedditConnector.sanitise_subreddit_name(name) for name in results]) + return set(all_entries) + + def get_subreddits(self) -> list[praw.models.ListingGenerator]: + if self.args.subreddit: + out = [] + for reddit in self.split_args_input(self.args.subreddit): + try: + reddit = self.reddit_instance.subreddit(reddit) + try: + self.check_subreddit_status(reddit) + except errors.BulkDownloaderException as e: + logger.error(e) + continue + if self.args.search: + out.append(reddit.search( + self.args.search, + sort=self.sort_filter.name.lower(), + limit=self.args.limit, + time_filter=self.time_filter.value, + )) + logger.debug( + f'Added submissions from subreddit {reddit} with the search term "{self.args.search}"') + else: + out.append(self.create_filtered_listing_generator(reddit)) + logger.debug(f'Added submissions from subreddit {reddit}') + except (errors.BulkDownloaderException, praw.exceptions.PRAWException) as e: + logger.error(f'Failed to get submissions for subreddit {reddit}: {e}') + return out + else: + return [] + + def resolve_user_name(self): + if self.args.user == 'me': + if self.authenticated: + self.args.user = self.reddit_instance.user.me().name + logger.log(9, f'Resolved user to {self.args.user}') + else: + self.args.user = None + logger.warning('To use "me" as a user, an authenticated Reddit instance must be used') + + def get_submissions_from_link(self) -> list[list[praw.models.Submission]]: + supplied_submissions = [] + for sub_id in self.args.link: + if len(sub_id) == 6: + supplied_submissions.append(self.reddit_instance.submission(id=sub_id)) + else: + supplied_submissions.append(self.reddit_instance.submission(url=sub_id)) + return [supplied_submissions] + + def determine_sort_function(self) -> Callable: + if self.sort_filter is RedditTypes.SortType.NEW: + sort_function = praw.models.Subreddit.new + elif self.sort_filter is RedditTypes.SortType.RISING: + sort_function = praw.models.Subreddit.rising + elif self.sort_filter is RedditTypes.SortType.CONTROVERSIAL: + sort_function = praw.models.Subreddit.controversial + elif self.sort_filter is RedditTypes.SortType.TOP: + sort_function = praw.models.Subreddit.top + else: + sort_function = praw.models.Subreddit.hot + return sort_function + + def get_multireddits(self) -> list[Iterator]: + if self.args.multireddit: + out = [] + for multi in self.split_args_input(self.args.multireddit): + try: + multi = self.reddit_instance.multireddit(self.args.user, multi) + if not multi.subreddits: + raise errors.BulkDownloaderException + out.append(self.create_filtered_listing_generator(multi)) + logger.debug(f'Added submissions from multireddit {multi}') + except (errors.BulkDownloaderException, praw.exceptions.PRAWException, prawcore.PrawcoreException) as e: + logger.error(f'Failed to get submissions for multireddit {multi}: {e}') + return out + else: + return [] + + def create_filtered_listing_generator(self, reddit_source) -> Iterator: + sort_function = self.determine_sort_function() + if self.sort_filter in (RedditTypes.SortType.TOP, RedditTypes.SortType.CONTROVERSIAL): + return sort_function(reddit_source, limit=self.args.limit, time_filter=self.time_filter.value) + else: + return sort_function(reddit_source, limit=self.args.limit) + + def get_user_data(self) -> list[Iterator]: + if any([self.args.submitted, self.args.upvoted, self.args.saved]): + if self.args.user: + try: + self.check_user_existence(self.args.user) + except errors.BulkDownloaderException as e: + logger.error(e) + return [] + generators = [] + if self.args.submitted: + logger.debug(f'Retrieving submitted posts of user {self.args.user}') + generators.append(self.create_filtered_listing_generator( + self.reddit_instance.redditor(self.args.user).submissions, + )) + if not self.authenticated and any((self.args.upvoted, self.args.saved)): + logger.warning('Accessing user lists requires authentication') + else: + if self.args.upvoted: + logger.debug(f'Retrieving upvoted posts of user {self.args.user}') + generators.append(self.reddit_instance.redditor(self.args.user).upvoted(limit=self.args.limit)) + if self.args.saved: + logger.debug(f'Retrieving saved posts of user {self.args.user}') + generators.append(self.reddit_instance.redditor(self.args.user).saved(limit=self.args.limit)) + return generators + else: + logger.warning('A user must be supplied to download user data') + return [] + else: + return [] + + def check_user_existence(self, name: str): + user = self.reddit_instance.redditor(name=name) + try: + if user.id: + return + except prawcore.exceptions.NotFound: + raise errors.BulkDownloaderException(f'Could not find user {name}') + except AttributeError: + if hasattr(user, 'is_suspended'): + raise errors.BulkDownloaderException(f'User {name} is banned') + + def create_file_name_formatter(self) -> FileNameFormatter: + return FileNameFormatter(self.args.file_scheme, self.args.folder_scheme, self.args.time_format) + + def create_time_filter(self) -> RedditTypes.TimeType: + try: + return RedditTypes.TimeType[self.args.time.upper()] + except (KeyError, AttributeError): + return RedditTypes.TimeType.ALL + + def create_sort_filter(self) -> RedditTypes.SortType: + try: + return RedditTypes.SortType[self.args.sort.upper()] + except (KeyError, AttributeError): + return RedditTypes.SortType.HOT + + def create_download_filter(self) -> DownloadFilter: + return DownloadFilter(self.args.skip_format, self.args.skip_domain) + + def create_authenticator(self) -> SiteAuthenticator: + return SiteAuthenticator(self.cfg_parser) + + @abstractmethod + def download(self): + pass + + @staticmethod + def check_subreddit_status(subreddit: praw.models.Subreddit): + if subreddit.display_name == 'all': + return + try: + assert subreddit.id + except prawcore.NotFound: + raise errors.BulkDownloaderException(f'Source {subreddit.display_name} does not exist or cannot be found') + except prawcore.Forbidden: + raise errors.BulkDownloaderException(f'Source {subreddit.display_name} is private and cannot be scraped') + + def read_excluded_ids(self) -> set[str]: + out = [] + out.extend(self.args.skip_id) + for id_file in self.args.skip_id_file: + id_file = Path(id_file).resolve().expanduser() + if not id_file.exists(): + logger.warning(f'ID exclusion file at {id_file} does not exist') + continue + with open(id_file, 'r') as file: + for line in file: + out.append(line.strip()) + return set(out) diff --git a/bdfr/downloader.py b/bdfr/downloader.py index 6fa37d6..62934a8 100644 --- a/bdfr/downloader.py +++ b/bdfr/downloader.py @@ -1,33 +1,19 @@ #!/usr/bin/env python3 # coding=utf-8 -import configparser import hashlib -import importlib.resources -import logging import logging.handlers import os -import re -import shutil -import socket -from datetime import datetime -from enum import Enum, auto from multiprocessing import Pool from pathlib import Path -from typing import Callable, Iterator -import appdirs import praw import praw.exceptions import praw.models -import prawcore -import bdfr.exceptions as errors +from bdfr import exceptions as errors from bdfr.configuration import Configuration -from bdfr.download_filter import DownloadFilter -from bdfr.file_name_formatter import FileNameFormatter -from bdfr.oauth2 import OAuth2Authenticator, OAuth2TokenManager -from bdfr.site_authenticator import SiteAuthenticator +from bdfr.connector import RedditConnector from bdfr.site_downloaders.download_factory import DownloadFactory logger = logging.getLogger(__name__) @@ -39,350 +25,11 @@ def _calc_hash(existing_file: Path): return existing_file, file_hash -class RedditTypes: - class SortType(Enum): - CONTROVERSIAL = auto() - HOT = auto() - NEW = auto() - RELEVENCE = auto() - RISING = auto() - TOP = auto() - - class TimeType(Enum): - ALL = 'all' - DAY = 'day' - HOUR = 'hour' - MONTH = 'month' - WEEK = 'week' - YEAR = 'year' - - -class RedditDownloader: +class RedditDownloader(RedditConnector): def __init__(self, args: Configuration): - self.args = args - self.config_directories = appdirs.AppDirs('bdfr', 'BDFR') - self.run_time = datetime.now().isoformat() - self._setup_internal_objects() - - self.reddit_lists = self._retrieve_reddit_lists() - - def _setup_internal_objects(self): - self._determine_directories() - self._load_config() - self._create_file_logger() - - self._read_config() - - self.download_filter = self._create_download_filter() - logger.log(9, 'Created download filter') - self.time_filter = self._create_time_filter() - logger.log(9, 'Created time filter') - self.sort_filter = self._create_sort_filter() - logger.log(9, 'Created sort filter') - self.file_name_formatter = self._create_file_name_formatter() - logger.log(9, 'Create file name formatter') - - self._create_reddit_instance() - self._resolve_user_name() - - self.excluded_submission_ids = self._read_excluded_ids() - + super(RedditDownloader, self).__init__(args) if self.args.search_existing: self.master_hash_list = self.scan_existing_files(self.download_directory) - else: - self.master_hash_list = {} - self.authenticator = self._create_authenticator() - logger.log(9, 'Created site authenticator') - - self.args.skip_subreddit = self._split_args_input(self.args.skip_subreddit) - self.args.skip_subreddit = set([sub.lower() for sub in self.args.skip_subreddit]) - - def _read_config(self): - """Read any cfg values that need to be processed""" - if self.args.max_wait_time is None: - if not self.cfg_parser.has_option('DEFAULT', 'max_wait_time'): - self.cfg_parser.set('DEFAULT', 'max_wait_time', '120') - logger.log(9, 'Wrote default download wait time download to config file') - self.args.max_wait_time = self.cfg_parser.getint('DEFAULT', 'max_wait_time') - logger.debug(f'Setting maximum download wait time to {self.args.max_wait_time} seconds') - if self.args.time_format is None: - option = self.cfg_parser.get('DEFAULT', 'time_format', fallback='ISO') - if re.match(r'^[ \'\"]*$', option): - option = 'ISO' - logger.debug(f'Setting datetime format string to {option}') - self.args.time_format = option - # Update config on disk - with open(self.config_location, 'w') as file: - self.cfg_parser.write(file) - - def _create_reddit_instance(self): - if self.args.authenticate: - logger.debug('Using authenticated Reddit instance') - if not self.cfg_parser.has_option('DEFAULT', 'user_token'): - logger.log(9, 'Commencing OAuth2 authentication') - scopes = self.cfg_parser.get('DEFAULT', 'scopes') - scopes = OAuth2Authenticator.split_scopes(scopes) - oauth2_authenticator = OAuth2Authenticator( - scopes, - self.cfg_parser.get('DEFAULT', 'client_id'), - self.cfg_parser.get('DEFAULT', 'client_secret'), - ) - token = oauth2_authenticator.retrieve_new_token() - self.cfg_parser['DEFAULT']['user_token'] = token - with open(self.config_location, 'w') as file: - self.cfg_parser.write(file, True) - token_manager = OAuth2TokenManager(self.cfg_parser, self.config_location) - - self.authenticated = True - self.reddit_instance = praw.Reddit( - client_id=self.cfg_parser.get('DEFAULT', 'client_id'), - client_secret=self.cfg_parser.get('DEFAULT', 'client_secret'), - user_agent=socket.gethostname(), - token_manager=token_manager, - ) - else: - logger.debug('Using unauthenticated Reddit instance') - self.authenticated = False - self.reddit_instance = praw.Reddit( - client_id=self.cfg_parser.get('DEFAULT', 'client_id'), - client_secret=self.cfg_parser.get('DEFAULT', 'client_secret'), - user_agent=socket.gethostname(), - ) - - def _retrieve_reddit_lists(self) -> list[praw.models.ListingGenerator]: - master_list = [] - master_list.extend(self._get_subreddits()) - logger.log(9, 'Retrieved subreddits') - master_list.extend(self._get_multireddits()) - logger.log(9, 'Retrieved multireddits') - master_list.extend(self._get_user_data()) - logger.log(9, 'Retrieved user data') - master_list.extend(self._get_submissions_from_link()) - logger.log(9, 'Retrieved submissions for given links') - return master_list - - def _determine_directories(self): - self.download_directory = Path(self.args.directory).resolve().expanduser() - self.config_directory = Path(self.config_directories.user_config_dir) - - self.download_directory.mkdir(exist_ok=True, parents=True) - self.config_directory.mkdir(exist_ok=True, parents=True) - - def _load_config(self): - self.cfg_parser = configparser.ConfigParser() - if self.args.config: - if (cfg_path := Path(self.args.config)).exists(): - self.cfg_parser.read(cfg_path) - self.config_location = cfg_path - return - possible_paths = [ - Path('./config.cfg'), - Path('./default_config.cfg'), - Path(self.config_directory, 'config.cfg'), - Path(self.config_directory, 'default_config.cfg'), - ] - self.config_location = None - for path in possible_paths: - if path.resolve().expanduser().exists(): - self.config_location = path - logger.debug(f'Loading configuration from {path}') - break - if not self.config_location: - self.config_location = list(importlib.resources.path('bdfr', 'default_config.cfg').gen)[0] - shutil.copy(self.config_location, Path(self.config_directory, 'default_config.cfg')) - if not self.config_location: - raise errors.BulkDownloaderException('Could not find a configuration file to load') - self.cfg_parser.read(self.config_location) - - def _create_file_logger(self): - main_logger = logging.getLogger() - if self.args.log is None: - log_path = Path(self.config_directory, 'log_output.txt') - else: - log_path = Path(self.args.log).resolve().expanduser() - if not log_path.parent.exists(): - raise errors.BulkDownloaderException(f'Designated location for logfile does not exist') - backup_count = self.cfg_parser.getint('DEFAULT', 'backup_log_count', fallback=3) - file_handler = logging.handlers.RotatingFileHandler( - log_path, - mode='a', - backupCount=backup_count, - ) - if log_path.exists(): - try: - file_handler.doRollover() - except PermissionError as e: - logger.critical( - 'Cannot rollover logfile, make sure this is the only ' - 'BDFR process or specify alternate logfile location') - raise - formatter = logging.Formatter('[%(asctime)s - %(name)s - %(levelname)s] - %(message)s') - file_handler.setFormatter(formatter) - file_handler.setLevel(0) - - main_logger.addHandler(file_handler) - - @staticmethod - def _sanitise_subreddit_name(subreddit: str) -> str: - pattern = re.compile(r'^(?:https://www\.reddit\.com/)?(?:r/)?(.*?)/?$') - match = re.match(pattern, subreddit) - if not match: - raise errors.BulkDownloaderException(f'Could not find subreddit name in string {subreddit}') - return match.group(1) - - @staticmethod - def _split_args_input(entries: list[str]) -> set[str]: - all_entries = [] - split_pattern = re.compile(r'[,;]\s?') - for entry in entries: - results = re.split(split_pattern, entry) - all_entries.extend([RedditDownloader._sanitise_subreddit_name(name) for name in results]) - return set(all_entries) - - def _get_subreddits(self) -> list[praw.models.ListingGenerator]: - if self.args.subreddit: - out = [] - for reddit in self._split_args_input(self.args.subreddit): - try: - reddit = self.reddit_instance.subreddit(reddit) - try: - self._check_subreddit_status(reddit) - except errors.BulkDownloaderException as e: - logger.error(e) - continue - if self.args.search: - out.append(reddit.search( - self.args.search, - sort=self.sort_filter.name.lower(), - limit=self.args.limit, - time_filter=self.time_filter.value, - )) - logger.debug( - f'Added submissions from subreddit {reddit} with the search term "{self.args.search}"') - else: - out.append(self._create_filtered_listing_generator(reddit)) - logger.debug(f'Added submissions from subreddit {reddit}') - except (errors.BulkDownloaderException, praw.exceptions.PRAWException) as e: - logger.error(f'Failed to get submissions for subreddit {reddit}: {e}') - return out - else: - return [] - - def _resolve_user_name(self): - if self.args.user == 'me': - if self.authenticated: - self.args.user = self.reddit_instance.user.me().name - logger.log(9, f'Resolved user to {self.args.user}') - else: - self.args.user = None - logger.warning('To use "me" as a user, an authenticated Reddit instance must be used') - - def _get_submissions_from_link(self) -> list[list[praw.models.Submission]]: - supplied_submissions = [] - for sub_id in self.args.link: - if len(sub_id) == 6: - supplied_submissions.append(self.reddit_instance.submission(id=sub_id)) - else: - supplied_submissions.append(self.reddit_instance.submission(url=sub_id)) - return [supplied_submissions] - - def _determine_sort_function(self) -> Callable: - if self.sort_filter is RedditTypes.SortType.NEW: - sort_function = praw.models.Subreddit.new - elif self.sort_filter is RedditTypes.SortType.RISING: - sort_function = praw.models.Subreddit.rising - elif self.sort_filter is RedditTypes.SortType.CONTROVERSIAL: - sort_function = praw.models.Subreddit.controversial - elif self.sort_filter is RedditTypes.SortType.TOP: - sort_function = praw.models.Subreddit.top - else: - sort_function = praw.models.Subreddit.hot - return sort_function - - def _get_multireddits(self) -> list[Iterator]: - if self.args.multireddit: - out = [] - for multi in self._split_args_input(self.args.multireddit): - try: - multi = self.reddit_instance.multireddit(self.args.user, multi) - if not multi.subreddits: - raise errors.BulkDownloaderException - out.append(self._create_filtered_listing_generator(multi)) - logger.debug(f'Added submissions from multireddit {multi}') - except (errors.BulkDownloaderException, praw.exceptions.PRAWException, prawcore.PrawcoreException) as e: - logger.error(f'Failed to get submissions for multireddit {multi}: {e}') - return out - else: - return [] - - def _create_filtered_listing_generator(self, reddit_source) -> Iterator: - sort_function = self._determine_sort_function() - if self.sort_filter in (RedditTypes.SortType.TOP, RedditTypes.SortType.CONTROVERSIAL): - return sort_function(reddit_source, limit=self.args.limit, time_filter=self.time_filter.value) - else: - return sort_function(reddit_source, limit=self.args.limit) - - def _get_user_data(self) -> list[Iterator]: - if any([self.args.submitted, self.args.upvoted, self.args.saved]): - if self.args.user: - try: - self._check_user_existence(self.args.user) - except errors.BulkDownloaderException as e: - logger.error(e) - return [] - generators = [] - if self.args.submitted: - logger.debug(f'Retrieving submitted posts of user {self.args.user}') - generators.append(self._create_filtered_listing_generator( - self.reddit_instance.redditor(self.args.user).submissions, - )) - if not self.authenticated and any((self.args.upvoted, self.args.saved)): - logger.warning('Accessing user lists requires authentication') - else: - if self.args.upvoted: - logger.debug(f'Retrieving upvoted posts of user {self.args.user}') - generators.append(self.reddit_instance.redditor(self.args.user).upvoted(limit=self.args.limit)) - if self.args.saved: - logger.debug(f'Retrieving saved posts of user {self.args.user}') - generators.append(self.reddit_instance.redditor(self.args.user).saved(limit=self.args.limit)) - return generators - else: - logger.warning('A user must be supplied to download user data') - return [] - else: - return [] - - def _check_user_existence(self, name: str): - user = self.reddit_instance.redditor(name=name) - try: - if user.id: - return - except prawcore.exceptions.NotFound: - raise errors.BulkDownloaderException(f'Could not find user {name}') - except AttributeError: - if hasattr(user, 'is_suspended'): - raise errors.BulkDownloaderException(f'User {name} is banned') - - def _create_file_name_formatter(self) -> FileNameFormatter: - return FileNameFormatter(self.args.file_scheme, self.args.folder_scheme, self.args.time_format) - - def _create_time_filter(self) -> RedditTypes.TimeType: - try: - return RedditTypes.TimeType[self.args.time.upper()] - except (KeyError, AttributeError): - return RedditTypes.TimeType.ALL - - def _create_sort_filter(self) -> RedditTypes.SortType: - try: - return RedditTypes.SortType[self.args.sort.upper()] - except (KeyError, AttributeError): - return RedditTypes.SortType.HOT - - def _create_download_filter(self) -> DownloadFilter: - return DownloadFilter(self.args.skip_format, self.args.skip_domain) - - def _create_authenticator(self) -> SiteAuthenticator: - return SiteAuthenticator(self.cfg_parser) def download(self): for generator in self.reddit_lists: @@ -457,27 +104,3 @@ class RedditDownloader: hash_list = {res[1]: res[0] for res in results} return hash_list - - def _read_excluded_ids(self) -> set[str]: - out = [] - out.extend(self.args.skip_id) - for id_file in self.args.skip_id_file: - id_file = Path(id_file).resolve().expanduser() - if not id_file.exists(): - logger.warning(f'ID exclusion file at {id_file} does not exist') - continue - with open(id_file, 'r') as file: - for line in file: - out.append(line.strip()) - return set(out) - - @staticmethod - def _check_subreddit_status(subreddit: praw.models.Subreddit): - if subreddit.display_name == 'all': - return - try: - assert subreddit.id - except prawcore.NotFound: - raise errors.BulkDownloaderException(f'Source {subreddit.display_name} does not exist or cannot be found') - except prawcore.Forbidden: - raise errors.BulkDownloaderException(f'Source {subreddit.display_name} is private and cannot be scraped') diff --git a/tests/test_connector.py b/tests/test_connector.py new file mode 100644 index 0000000..41d9115 --- /dev/null +++ b/tests/test_connector.py @@ -0,0 +1,401 @@ +#!/usr/bin/env python3 +# coding=utf-8 + +from pathlib import Path +from typing import Iterator +from unittest.mock import MagicMock + +import praw +import praw.models +import pytest + +from bdfr.configuration import Configuration +from bdfr.connector import RedditConnector, RedditTypes +from bdfr.download_filter import DownloadFilter +from bdfr.exceptions import BulkDownloaderException +from bdfr.file_name_formatter import FileNameFormatter +from bdfr.site_authenticator import SiteAuthenticator + + +@pytest.fixture() +def args() -> Configuration: + args = Configuration() + args.time_format = 'ISO' + return args + + +@pytest.fixture() +def downloader_mock(args: Configuration): + downloader_mock = MagicMock() + downloader_mock.args = args + downloader_mock.sanitise_subreddit_name = RedditConnector.sanitise_subreddit_name + downloader_mock.split_args_input = RedditConnector.split_args_input + downloader_mock.master_hash_list = {} + return downloader_mock + + +def assert_all_results_are_submissions(result_limit: int, results: list[Iterator]): + results = [sub for res in results for sub in res] + assert all([isinstance(res, praw.models.Submission) for res in results]) + if result_limit is not None: + assert len(results) == result_limit + return results + + +def test_determine_directories(tmp_path: Path, downloader_mock: MagicMock): + downloader_mock.args.directory = tmp_path / 'test' + downloader_mock.config_directories.user_config_dir = tmp_path + RedditConnector.determine_directories(downloader_mock) + assert Path(tmp_path / 'test').exists() + + +@pytest.mark.parametrize(('skip_extensions', 'skip_domains'), ( + ([], []), + (['.test'], ['test.com'],), +)) +def test_create_download_filter(skip_extensions: list[str], skip_domains: list[str], downloader_mock: MagicMock): + downloader_mock.args.skip_format = skip_extensions + downloader_mock.args.skip_domain = skip_domains + result = RedditConnector.create_download_filter(downloader_mock) + + assert isinstance(result, DownloadFilter) + assert result.excluded_domains == skip_domains + assert result.excluded_extensions == skip_extensions + + +@pytest.mark.parametrize(('test_time', 'expected'), ( + ('all', 'all'), + ('hour', 'hour'), + ('day', 'day'), + ('week', 'week'), + ('random', 'all'), + ('', 'all'), +)) +def test_create_time_filter(test_time: str, expected: str, downloader_mock: MagicMock): + downloader_mock.args.time = test_time + result = RedditConnector.create_time_filter(downloader_mock) + + assert isinstance(result, RedditTypes.TimeType) + assert result.name.lower() == expected + + +@pytest.mark.parametrize(('test_sort', 'expected'), ( + ('', 'hot'), + ('hot', 'hot'), + ('controversial', 'controversial'), + ('new', 'new'), +)) +def test_create_sort_filter(test_sort: str, expected: str, downloader_mock: MagicMock): + downloader_mock.args.sort = test_sort + result = RedditConnector.create_sort_filter(downloader_mock) + + assert isinstance(result, RedditTypes.SortType) + assert result.name.lower() == expected + + +@pytest.mark.parametrize(('test_file_scheme', 'test_folder_scheme'), ( + ('{POSTID}', '{SUBREDDIT}'), + ('{REDDITOR}_{TITLE}_{POSTID}', '{SUBREDDIT}'), + ('{POSTID}', 'test'), + ('{POSTID}', ''), + ('{POSTID}', '{SUBREDDIT}/{REDDITOR}'), +)) +def test_create_file_name_formatter(test_file_scheme: str, test_folder_scheme: str, downloader_mock: MagicMock): + downloader_mock.args.file_scheme = test_file_scheme + downloader_mock.args.folder_scheme = test_folder_scheme + result = RedditConnector.create_file_name_formatter(downloader_mock) + + assert isinstance(result, FileNameFormatter) + assert result.file_format_string == test_file_scheme + assert result.directory_format_string == test_folder_scheme.split('/') + + +@pytest.mark.parametrize(('test_file_scheme', 'test_folder_scheme'), ( + ('', ''), + ('', '{SUBREDDIT}'), + ('test', '{SUBREDDIT}'), +)) +def test_create_file_name_formatter_bad(test_file_scheme: str, test_folder_scheme: str, downloader_mock: MagicMock): + downloader_mock.args.file_scheme = test_file_scheme + downloader_mock.args.folder_scheme = test_folder_scheme + with pytest.raises(BulkDownloaderException): + RedditConnector.create_file_name_formatter(downloader_mock) + + +def test_create_authenticator(downloader_mock: MagicMock): + result = RedditConnector.create_authenticator(downloader_mock) + assert isinstance(result, SiteAuthenticator) + + +@pytest.mark.online +@pytest.mark.reddit +@pytest.mark.parametrize('test_submission_ids', ( + ('lvpf4l',), + ('lvpf4l', 'lvqnsn'), + ('lvpf4l', 'lvqnsn', 'lvl9kd'), +)) +def test_get_submissions_from_link( + test_submission_ids: list[str], + reddit_instance: praw.Reddit, + downloader_mock: MagicMock): + downloader_mock.args.link = test_submission_ids + downloader_mock.reddit_instance = reddit_instance + results = RedditConnector.get_submissions_from_link(downloader_mock) + assert all([isinstance(sub, praw.models.Submission) for res in results for sub in res]) + assert len(results[0]) == len(test_submission_ids) + + +@pytest.mark.online +@pytest.mark.reddit +@pytest.mark.parametrize(('test_subreddits', 'limit', 'sort_type', 'time_filter', 'max_expected_len'), ( + (('Futurology',), 10, 'hot', 'all', 10), + (('Futurology', 'Mindustry, Python'), 10, 'hot', 'all', 30), + (('Futurology',), 20, 'hot', 'all', 20), + (('Futurology', 'Python'), 10, 'hot', 'all', 20), + (('Futurology',), 100, 'hot', 'all', 100), + (('Futurology',), 0, 'hot', 'all', 0), + (('Futurology',), 10, 'top', 'all', 10), + (('Futurology',), 10, 'top', 'week', 10), + (('Futurology',), 10, 'hot', 'week', 10), +)) +def test_get_subreddit_normal( + test_subreddits: list[str], + limit: int, + sort_type: str, + time_filter: str, + max_expected_len: int, + downloader_mock: MagicMock, + reddit_instance: praw.Reddit, +): + downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot + downloader_mock.args.limit = limit + downloader_mock.args.sort = sort_type + downloader_mock.args.subreddit = test_subreddits + downloader_mock.reddit_instance = reddit_instance + downloader_mock.sort_filter = RedditConnector.create_sort_filter(downloader_mock) + results = RedditConnector.get_subreddits(downloader_mock) + test_subreddits = downloader_mock._split_args_input(test_subreddits) + results = [sub for res1 in results for sub in res1] + assert all([isinstance(res1, praw.models.Submission) for res1 in results]) + assert all([res.subreddit.display_name in test_subreddits for res in results]) + assert len(results) <= max_expected_len + + +@pytest.mark.online +@pytest.mark.reddit +@pytest.mark.parametrize(('test_subreddits', 'search_term', 'limit', 'time_filter', 'max_expected_len'), ( + (('Python',), 'scraper', 10, 'all', 10), + (('Python',), '', 10, 'all', 10), + (('Python',), 'djsdsgewef', 10, 'all', 0), + (('Python',), 'scraper', 10, 'year', 10), + (('Python',), 'scraper', 10, 'hour', 1), +)) +def test_get_subreddit_search( + test_subreddits: list[str], + search_term: str, + time_filter: str, + limit: int, + max_expected_len: int, + downloader_mock: MagicMock, + reddit_instance: praw.Reddit, +): + downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot + downloader_mock.args.limit = limit + downloader_mock.args.search = search_term + downloader_mock.args.subreddit = test_subreddits + downloader_mock.reddit_instance = reddit_instance + downloader_mock.sort_filter = RedditTypes.SortType.HOT + downloader_mock.args.time = time_filter + downloader_mock.time_filter = RedditConnector.create_time_filter(downloader_mock) + results = RedditConnector.get_subreddits(downloader_mock) + results = [sub for res in results for sub in res] + assert all([isinstance(res, praw.models.Submission) for res in results]) + assert all([res.subreddit.display_name in test_subreddits for res in results]) + assert len(results) <= max_expected_len + + +@pytest.mark.online +@pytest.mark.reddit +@pytest.mark.parametrize(('test_user', 'test_multireddits', 'limit'), ( + ('helen_darten', ('cuteanimalpics',), 10), + ('korfor', ('chess',), 100), +)) +# Good sources at https://www.reddit.com/r/multihub/ +def test_get_multireddits_public( + test_user: str, + test_multireddits: list[str], + limit: int, + reddit_instance: praw.Reddit, + downloader_mock: MagicMock, +): + downloader_mock.determine_sort_function.return_value = praw.models.Subreddit.hot + downloader_mock.sort_filter = RedditTypes.SortType.HOT + downloader_mock.args.limit = limit + downloader_mock.args.multireddit = test_multireddits + downloader_mock.args.user = test_user + downloader_mock.reddit_instance = reddit_instance + downloader_mock.create_filtered_listing_generator.return_value = \ + RedditConnector.create_filtered_listing_generator( + downloader_mock, + reddit_instance.multireddit(test_user, test_multireddits[0]), + ) + results = RedditConnector.get_multireddits(downloader_mock) + results = [sub for res in results for sub in res] + assert all([isinstance(res, praw.models.Submission) for res in results]) + assert len(results) == limit + + +@pytest.mark.online +@pytest.mark.reddit +@pytest.mark.parametrize(('test_user', 'limit'), ( + ('danigirl3694', 10), + ('danigirl3694', 50), + ('CapitanHam', None), +)) +def test_get_user_submissions(test_user: str, limit: int, downloader_mock: MagicMock, reddit_instance: praw.Reddit): + downloader_mock.args.limit = limit + downloader_mock.determine_sort_function.return_value = praw.models.Subreddit.hot + downloader_mock.sort_filter = RedditTypes.SortType.HOT + downloader_mock.args.submitted = True + downloader_mock.args.user = test_user + downloader_mock.authenticated = False + downloader_mock.reddit_instance = reddit_instance + downloader_mock.create_filtered_listing_generator.return_value = \ + RedditConnector.create_filtered_listing_generator( + downloader_mock, + reddit_instance.redditor(test_user).submissions, + ) + results = RedditConnector.get_user_data(downloader_mock) + results = assert_all_results_are_submissions(limit, results) + assert all([res.author.name == test_user for res in results]) + + +@pytest.mark.online +@pytest.mark.reddit +@pytest.mark.authenticated +@pytest.mark.parametrize('test_flag', ( + 'upvoted', + 'saved', +)) +def test_get_user_authenticated_lists( + test_flag: str, + downloader_mock: MagicMock, + authenticated_reddit_instance: praw.Reddit, +): + downloader_mock.args.__dict__[test_flag] = True + downloader_mock.reddit_instance = authenticated_reddit_instance + downloader_mock.args.user = 'me' + downloader_mock.args.limit = 10 + downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot + downloader_mock.sort_filter = RedditTypes.SortType.HOT + RedditConnector.resolve_user_name(downloader_mock) + results = RedditConnector.get_user_data(downloader_mock) + assert_all_results_are_submissions(10, results) + + +@pytest.mark.parametrize(('test_name', 'expected'), ( + ('Mindustry', 'Mindustry'), + ('Futurology', 'Futurology'), + ('r/Mindustry', 'Mindustry'), + ('TrollXChromosomes', 'TrollXChromosomes'), + ('r/TrollXChromosomes', 'TrollXChromosomes'), + ('https://www.reddit.com/r/TrollXChromosomes/', 'TrollXChromosomes'), + ('https://www.reddit.com/r/TrollXChromosomes', 'TrollXChromosomes'), + ('https://www.reddit.com/r/Futurology/', 'Futurology'), + ('https://www.reddit.com/r/Futurology', 'Futurology'), +)) +def test_sanitise_subreddit_name(test_name: str, expected: str): + result = RedditConnector.sanitise_subreddit_name(test_name) + assert result == expected + + +@pytest.mark.parametrize(('test_subreddit_entries', 'expected'), ( + (['test1', 'test2', 'test3'], {'test1', 'test2', 'test3'}), + (['test1,test2', 'test3'], {'test1', 'test2', 'test3'}), + (['test1, test2', 'test3'], {'test1', 'test2', 'test3'}), + (['test1; test2', 'test3'], {'test1', 'test2', 'test3'}), + (['test1, test2', 'test1,test2,test3', 'test4'], {'test1', 'test2', 'test3', 'test4'}) +)) +def test_split_subreddit_entries(test_subreddit_entries: list[str], expected: set[str]): + results = RedditConnector.split_args_input(test_subreddit_entries) + assert results == expected + + +def test_read_excluded_submission_ids_from_file(downloader_mock: MagicMock, tmp_path: Path): + test_file = tmp_path / 'test.txt' + test_file.write_text('aaaaaa\nbbbbbb') + downloader_mock.args.skip_id_file = [test_file] + results = RedditConnector.read_excluded_ids(downloader_mock) + assert results == {'aaaaaa', 'bbbbbb'} + + +@pytest.mark.online +@pytest.mark.reddit +@pytest.mark.parametrize('test_redditor_name', ( + 'Paracortex', + 'crowdstrike', + 'HannibalGoddamnit', +)) +def test_check_user_existence_good( + test_redditor_name: str, + reddit_instance: praw.Reddit, + downloader_mock: MagicMock, +): + downloader_mock.reddit_instance = reddit_instance + RedditConnector.check_user_existence(downloader_mock, test_redditor_name) + + +@pytest.mark.online +@pytest.mark.reddit +@pytest.mark.parametrize('test_redditor_name', ( + 'lhnhfkuhwreolo', + 'adlkfmnhglojh', +)) +def test_check_user_existence_nonexistent( + test_redditor_name: str, + reddit_instance: praw.Reddit, + downloader_mock: MagicMock, +): + downloader_mock.reddit_instance = reddit_instance + with pytest.raises(BulkDownloaderException, match='Could not find'): + RedditConnector.check_user_existence(downloader_mock, test_redditor_name) + + +@pytest.mark.online +@pytest.mark.reddit +@pytest.mark.parametrize('test_redditor_name', ( + 'Bree-Boo', +)) +def test_check_user_existence_banned( + test_redditor_name: str, + reddit_instance: praw.Reddit, + downloader_mock: MagicMock, +): + downloader_mock.reddit_instance = reddit_instance + with pytest.raises(BulkDownloaderException, match='is banned'): + RedditConnector.check_user_existence(downloader_mock, test_redditor_name) + + +@pytest.mark.online +@pytest.mark.reddit +@pytest.mark.parametrize(('test_subreddit_name', 'expected_message'), ( + ('donaldtrump', 'cannot be found'), + ('submitters', 'private and cannot be scraped') +)) +def test_check_subreddit_status_bad(test_subreddit_name: str, expected_message: str, reddit_instance: praw.Reddit): + test_subreddit = reddit_instance.subreddit(test_subreddit_name) + with pytest.raises(BulkDownloaderException, match=expected_message): + RedditConnector.check_subreddit_status(test_subreddit) + + +@pytest.mark.online +@pytest.mark.reddit +@pytest.mark.parametrize('test_subreddit_name', ( + 'Python', + 'Mindustry', + 'TrollXChromosomes', + 'all', +)) +def test_check_subreddit_status_good(test_subreddit_name: str, reddit_instance: praw.Reddit): + test_subreddit = reddit_instance.subreddit(test_subreddit_name) + RedditConnector.check_subreddit_status(test_subreddit) diff --git a/tests/test_downloader.py b/tests/test_downloader.py index fd56994..ee43625 100644 --- a/tests/test_downloader.py +++ b/tests/test_downloader.py @@ -3,20 +3,15 @@ import re from pathlib import Path -from typing import Iterator from unittest.mock import MagicMock -import praw import praw.models import pytest from bdfr.__main__ import setup_logging from bdfr.configuration import Configuration -from bdfr.download_filter import DownloadFilter -from bdfr.downloader import RedditDownloader, RedditTypes -from bdfr.exceptions import BulkDownloaderException -from bdfr.file_name_formatter import FileNameFormatter -from bdfr.site_authenticator import SiteAuthenticator +from bdfr.connector import RedditConnector +from bdfr.downloader import RedditDownloader @pytest.fixture() @@ -30,411 +25,12 @@ def args() -> Configuration: def downloader_mock(args: Configuration): downloader_mock = MagicMock() downloader_mock.args = args - downloader_mock._sanitise_subreddit_name = RedditDownloader._sanitise_subreddit_name - downloader_mock._split_args_input = RedditDownloader._split_args_input + downloader_mock._sanitise_subreddit_name = RedditConnector.sanitise_subreddit_name + downloader_mock._split_args_input = RedditConnector.split_args_input downloader_mock.master_hash_list = {} return downloader_mock -def assert_all_results_are_submissions(result_limit: int, results: list[Iterator]): - results = [sub for res in results for sub in res] - assert all([isinstance(res, praw.models.Submission) for res in results]) - if result_limit is not None: - assert len(results) == result_limit - return results - - -def test_determine_directories(tmp_path: Path, downloader_mock: MagicMock): - downloader_mock.args.directory = tmp_path / 'test' - downloader_mock.config_directories.user_config_dir = tmp_path - RedditDownloader._determine_directories(downloader_mock) - assert Path(tmp_path / 'test').exists() - - -@pytest.mark.parametrize(('skip_extensions', 'skip_domains'), ( - ([], []), - (['.test'], ['test.com'],), -)) -def test_create_download_filter(skip_extensions: list[str], skip_domains: list[str], downloader_mock: MagicMock): - downloader_mock.args.skip_format = skip_extensions - downloader_mock.args.skip_domain = skip_domains - result = RedditDownloader._create_download_filter(downloader_mock) - - assert isinstance(result, DownloadFilter) - assert result.excluded_domains == skip_domains - assert result.excluded_extensions == skip_extensions - - -@pytest.mark.parametrize(('test_time', 'expected'), ( - ('all', 'all'), - ('hour', 'hour'), - ('day', 'day'), - ('week', 'week'), - ('random', 'all'), - ('', 'all'), -)) -def test_create_time_filter(test_time: str, expected: str, downloader_mock: MagicMock): - downloader_mock.args.time = test_time - result = RedditDownloader._create_time_filter(downloader_mock) - - assert isinstance(result, RedditTypes.TimeType) - assert result.name.lower() == expected - - -@pytest.mark.parametrize(('test_sort', 'expected'), ( - ('', 'hot'), - ('hot', 'hot'), - ('controversial', 'controversial'), - ('new', 'new'), -)) -def test_create_sort_filter(test_sort: str, expected: str, downloader_mock: MagicMock): - downloader_mock.args.sort = test_sort - result = RedditDownloader._create_sort_filter(downloader_mock) - - assert isinstance(result, RedditTypes.SortType) - assert result.name.lower() == expected - - -@pytest.mark.parametrize(('test_file_scheme', 'test_folder_scheme'), ( - ('{POSTID}', '{SUBREDDIT}'), - ('{REDDITOR}_{TITLE}_{POSTID}', '{SUBREDDIT}'), - ('{POSTID}', 'test'), - ('{POSTID}', ''), - ('{POSTID}', '{SUBREDDIT}/{REDDITOR}'), -)) -def test_create_file_name_formatter(test_file_scheme: str, test_folder_scheme: str, downloader_mock: MagicMock): - downloader_mock.args.file_scheme = test_file_scheme - downloader_mock.args.folder_scheme = test_folder_scheme - result = RedditDownloader._create_file_name_formatter(downloader_mock) - - assert isinstance(result, FileNameFormatter) - assert result.file_format_string == test_file_scheme - assert result.directory_format_string == test_folder_scheme.split('/') - - -@pytest.mark.parametrize(('test_file_scheme', 'test_folder_scheme'), ( - ('', ''), - ('', '{SUBREDDIT}'), - ('test', '{SUBREDDIT}'), -)) -def test_create_file_name_formatter_bad(test_file_scheme: str, test_folder_scheme: str, downloader_mock: MagicMock): - downloader_mock.args.file_scheme = test_file_scheme - downloader_mock.args.folder_scheme = test_folder_scheme - with pytest.raises(BulkDownloaderException): - RedditDownloader._create_file_name_formatter(downloader_mock) - - -def test_create_authenticator(downloader_mock: MagicMock): - result = RedditDownloader._create_authenticator(downloader_mock) - assert isinstance(result, SiteAuthenticator) - - -@pytest.mark.online -@pytest.mark.reddit -@pytest.mark.parametrize('test_submission_ids', ( - ('lvpf4l',), - ('lvpf4l', 'lvqnsn'), - ('lvpf4l', 'lvqnsn', 'lvl9kd'), -)) -def test_get_submissions_from_link( - test_submission_ids: list[str], - reddit_instance: praw.Reddit, - downloader_mock: MagicMock): - downloader_mock.args.link = test_submission_ids - downloader_mock.reddit_instance = reddit_instance - results = RedditDownloader._get_submissions_from_link(downloader_mock) - assert all([isinstance(sub, praw.models.Submission) for res in results for sub in res]) - assert len(results[0]) == len(test_submission_ids) - - -@pytest.mark.online -@pytest.mark.reddit -@pytest.mark.parametrize(('test_subreddits', 'limit', 'sort_type', 'time_filter', 'max_expected_len'), ( - (('Futurology',), 10, 'hot', 'all', 10), - (('Futurology', 'Mindustry, Python'), 10, 'hot', 'all', 30), - (('Futurology',), 20, 'hot', 'all', 20), - (('Futurology', 'Python'), 10, 'hot', 'all', 20), - (('Futurology',), 100, 'hot', 'all', 100), - (('Futurology',), 0, 'hot', 'all', 0), - (('Futurology',), 10, 'top', 'all', 10), - (('Futurology',), 10, 'top', 'week', 10), - (('Futurology',), 10, 'hot', 'week', 10), -)) -def test_get_subreddit_normal( - test_subreddits: list[str], - limit: int, - sort_type: str, - time_filter: str, - max_expected_len: int, - downloader_mock: MagicMock, - reddit_instance: praw.Reddit, -): - downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot - downloader_mock.args.limit = limit - downloader_mock.args.sort = sort_type - downloader_mock.args.subreddit = test_subreddits - downloader_mock.reddit_instance = reddit_instance - downloader_mock.sort_filter = RedditDownloader._create_sort_filter(downloader_mock) - results = RedditDownloader._get_subreddits(downloader_mock) - test_subreddits = downloader_mock._split_args_input(test_subreddits) - results = [sub for res1 in results for sub in res1] - assert all([isinstance(res1, praw.models.Submission) for res1 in results]) - assert all([res.subreddit.display_name in test_subreddits for res in results]) - assert len(results) <= max_expected_len - - -@pytest.mark.online -@pytest.mark.reddit -@pytest.mark.parametrize(('test_subreddits', 'search_term', 'limit', 'time_filter', 'max_expected_len'), ( - (('Python',), 'scraper', 10, 'all', 10), - (('Python',), '', 10, 'all', 10), - (('Python',), 'djsdsgewef', 10, 'all', 0), - (('Python',), 'scraper', 10, 'year', 10), - (('Python',), 'scraper', 10, 'hour', 1), -)) -def test_get_subreddit_search( - test_subreddits: list[str], - search_term: str, - time_filter: str, - limit: int, - max_expected_len: int, - downloader_mock: MagicMock, - reddit_instance: praw.Reddit, -): - downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot - downloader_mock.args.limit = limit - downloader_mock.args.search = search_term - downloader_mock.args.subreddit = test_subreddits - downloader_mock.reddit_instance = reddit_instance - downloader_mock.sort_filter = RedditTypes.SortType.HOT - downloader_mock.args.time = time_filter - downloader_mock.time_filter = RedditDownloader._create_time_filter(downloader_mock) - results = RedditDownloader._get_subreddits(downloader_mock) - results = [sub for res in results for sub in res] - assert all([isinstance(res, praw.models.Submission) for res in results]) - assert all([res.subreddit.display_name in test_subreddits for res in results]) - assert len(results) <= max_expected_len - - -@pytest.mark.online -@pytest.mark.reddit -@pytest.mark.parametrize(('test_user', 'test_multireddits', 'limit'), ( - ('helen_darten', ('cuteanimalpics',), 10), - ('korfor', ('chess',), 100), -)) -# Good sources at https://www.reddit.com/r/multihub/ -def test_get_multireddits_public( - test_user: str, - test_multireddits: list[str], - limit: int, - reddit_instance: praw.Reddit, - downloader_mock: MagicMock, -): - downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot - downloader_mock.sort_filter = RedditTypes.SortType.HOT - downloader_mock.args.limit = limit - downloader_mock.args.multireddit = test_multireddits - downloader_mock.args.user = test_user - downloader_mock.reddit_instance = reddit_instance - downloader_mock._create_filtered_listing_generator.return_value = \ - RedditDownloader._create_filtered_listing_generator( - downloader_mock, - reddit_instance.multireddit(test_user, test_multireddits[0]), - ) - results = RedditDownloader._get_multireddits(downloader_mock) - results = [sub for res in results for sub in res] - assert all([isinstance(res, praw.models.Submission) for res in results]) - assert len(results) == limit - - -@pytest.mark.online -@pytest.mark.reddit -@pytest.mark.parametrize(('test_user', 'limit'), ( - ('danigirl3694', 10), - ('danigirl3694', 50), - ('CapitanHam', None), -)) -def test_get_user_submissions(test_user: str, limit: int, downloader_mock: MagicMock, reddit_instance: praw.Reddit): - downloader_mock.args.limit = limit - downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot - downloader_mock.sort_filter = RedditTypes.SortType.HOT - downloader_mock.args.submitted = True - downloader_mock.args.user = test_user - downloader_mock.authenticated = False - downloader_mock.reddit_instance = reddit_instance - downloader_mock._create_filtered_listing_generator.return_value = \ - RedditDownloader._create_filtered_listing_generator( - downloader_mock, - reddit_instance.redditor(test_user).submissions, - ) - results = RedditDownloader._get_user_data(downloader_mock) - results = assert_all_results_are_submissions(limit, results) - assert all([res.author.name == test_user for res in results]) - - -@pytest.mark.online -@pytest.mark.reddit -@pytest.mark.authenticated -@pytest.mark.parametrize('test_flag', ( - 'upvoted', - 'saved', -)) -def test_get_user_authenticated_lists( - test_flag: str, - downloader_mock: MagicMock, - authenticated_reddit_instance: praw.Reddit, -): - downloader_mock.args.__dict__[test_flag] = True - downloader_mock.reddit_instance = authenticated_reddit_instance - downloader_mock.args.user = 'me' - downloader_mock.args.limit = 10 - downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot - downloader_mock.sort_filter = RedditTypes.SortType.HOT - RedditDownloader._resolve_user_name(downloader_mock) - results = RedditDownloader._get_user_data(downloader_mock) - assert_all_results_are_submissions(10, results) - - -@pytest.mark.online -@pytest.mark.reddit -@pytest.mark.parametrize(('test_submission_id', 'expected_files_len'), ( - ('ljyy27', 4), -)) -def test_download_submission( - test_submission_id: str, - expected_files_len: int, - downloader_mock: MagicMock, - reddit_instance: praw.Reddit, - tmp_path: Path): - downloader_mock.reddit_instance = reddit_instance - downloader_mock.download_filter.check_url.return_value = True - downloader_mock.args.folder_scheme = '' - downloader_mock.file_name_formatter = RedditDownloader._create_file_name_formatter(downloader_mock) - downloader_mock.download_directory = tmp_path - submission = downloader_mock.reddit_instance.submission(id=test_submission_id) - RedditDownloader._download_submission(downloader_mock, submission) - folder_contents = list(tmp_path.iterdir()) - assert len(folder_contents) == expected_files_len - - -@pytest.mark.online -@pytest.mark.reddit -def test_download_submission_file_exists( - downloader_mock: MagicMock, - reddit_instance: praw.Reddit, - tmp_path: Path, - capsys: pytest.CaptureFixture -): - setup_logging(3) - downloader_mock.reddit_instance = reddit_instance - downloader_mock.download_filter.check_url.return_value = True - downloader_mock.args.folder_scheme = '' - downloader_mock.file_name_formatter = RedditDownloader._create_file_name_formatter(downloader_mock) - downloader_mock.download_directory = tmp_path - submission = downloader_mock.reddit_instance.submission(id='m1hqw6') - Path(tmp_path, 'Arneeman_Metagaming isn\'t always a bad thing_m1hqw6.png').touch() - RedditDownloader._download_submission(downloader_mock, submission) - folder_contents = list(tmp_path.iterdir()) - output = capsys.readouterr() - assert len(folder_contents) == 1 - assert 'Arneeman_Metagaming isn\'t always a bad thing_m1hqw6.png already exists' in output.out - - -@pytest.mark.online -@pytest.mark.reddit -@pytest.mark.parametrize(('test_submission_id', 'test_hash'), ( - ('m1hqw6', 'a912af8905ae468e0121e9940f797ad7'), -)) -def test_download_submission_hash_exists( - test_submission_id: str, - test_hash: str, - downloader_mock: MagicMock, - reddit_instance: praw.Reddit, - tmp_path: Path, - capsys: pytest.CaptureFixture -): - setup_logging(3) - downloader_mock.reddit_instance = reddit_instance - downloader_mock.download_filter.check_url.return_value = True - downloader_mock.args.folder_scheme = '' - downloader_mock.args.no_dupes = True - downloader_mock.file_name_formatter = RedditDownloader._create_file_name_formatter(downloader_mock) - downloader_mock.download_directory = tmp_path - downloader_mock.master_hash_list = {test_hash: None} - submission = downloader_mock.reddit_instance.submission(id=test_submission_id) - RedditDownloader._download_submission(downloader_mock, submission) - folder_contents = list(tmp_path.iterdir()) - output = capsys.readouterr() - assert len(folder_contents) == 0 - assert re.search(r'Resource hash .*? downloaded elsewhere', output.out) - - -@pytest.mark.parametrize(('test_name', 'expected'), ( - ('Mindustry', 'Mindustry'), - ('Futurology', 'Futurology'), - ('r/Mindustry', 'Mindustry'), - ('TrollXChromosomes', 'TrollXChromosomes'), - ('r/TrollXChromosomes', 'TrollXChromosomes'), - ('https://www.reddit.com/r/TrollXChromosomes/', 'TrollXChromosomes'), - ('https://www.reddit.com/r/TrollXChromosomes', 'TrollXChromosomes'), - ('https://www.reddit.com/r/Futurology/', 'Futurology'), - ('https://www.reddit.com/r/Futurology', 'Futurology'), -)) -def test_sanitise_subreddit_name(test_name: str, expected: str): - result = RedditDownloader._sanitise_subreddit_name(test_name) - assert result == expected - - -def test_search_existing_files(): - results = RedditDownloader.scan_existing_files(Path('.')) - assert len(results.keys()) >= 40 - - -@pytest.mark.parametrize(('test_subreddit_entries', 'expected'), ( - (['test1', 'test2', 'test3'], {'test1', 'test2', 'test3'}), - (['test1,test2', 'test3'], {'test1', 'test2', 'test3'}), - (['test1, test2', 'test3'], {'test1', 'test2', 'test3'}), - (['test1; test2', 'test3'], {'test1', 'test2', 'test3'}), - (['test1, test2', 'test1,test2,test3', 'test4'], {'test1', 'test2', 'test3', 'test4'}) -)) -def test_split_subreddit_entries(test_subreddit_entries: list[str], expected: set[str]): - results = RedditDownloader._split_args_input(test_subreddit_entries) - assert results == expected - - -@pytest.mark.online -@pytest.mark.reddit -@pytest.mark.parametrize('test_submission_id', ( - 'm1hqw6', -)) -def test_mark_hard_link( - test_submission_id: str, - downloader_mock: MagicMock, - tmp_path: Path, - reddit_instance: praw.Reddit -): - downloader_mock.reddit_instance = reddit_instance - downloader_mock.args.make_hard_links = True - downloader_mock.download_directory = tmp_path - downloader_mock.args.folder_scheme = '' - downloader_mock.args.file_scheme = '{POSTID}' - downloader_mock.file_name_formatter = RedditDownloader._create_file_name_formatter(downloader_mock) - submission = downloader_mock.reddit_instance.submission(id=test_submission_id) - original = Path(tmp_path, f'{test_submission_id}.png') - - RedditDownloader._download_submission(downloader_mock, submission) - assert original.exists() - - downloader_mock.args.file_scheme = 'test2_{POSTID}' - downloader_mock.file_name_formatter = RedditDownloader._create_file_name_formatter(downloader_mock) - RedditDownloader._download_submission(downloader_mock, submission) - test_file_1_stats = original.stat() - test_file_2_inode = Path(tmp_path, f'test2_{test_submission_id}.png').stat().st_ino - - assert test_file_1_stats.st_nlink == 2 - assert test_file_1_stats.st_ino == test_file_2_inode - - @pytest.mark.parametrize(('test_ids', 'test_excluded', 'expected_len'), ( (('aaaaaa',), (), 1), (('aaaaaa',), ('aaaaaa',), 0), @@ -453,81 +49,113 @@ def test_excluded_ids(test_ids: tuple[str], test_excluded: tuple[str], expected_ assert downloader_mock._download_submission.call_count == expected_len -def test_read_excluded_submission_ids_from_file(downloader_mock: MagicMock, tmp_path: Path): - test_file = tmp_path / 'test.txt' - test_file.write_text('aaaaaa\nbbbbbb') - downloader_mock.args.skip_id_file = [test_file] - results = RedditDownloader._read_excluded_ids(downloader_mock) - assert results == {'aaaaaa', 'bbbbbb'} - - @pytest.mark.online @pytest.mark.reddit -@pytest.mark.parametrize('test_redditor_name', ( - 'Paracortex', - 'crowdstrike', - 'HannibalGoddamnit', +@pytest.mark.parametrize('test_submission_id', ( + 'm1hqw6', )) -def test_check_user_existence_good( - test_redditor_name: str, - reddit_instance: praw.Reddit, - downloader_mock: MagicMock, -): - downloader_mock.reddit_instance = reddit_instance - RedditDownloader._check_user_existence(downloader_mock, test_redditor_name) - - -@pytest.mark.online -@pytest.mark.reddit -@pytest.mark.parametrize('test_redditor_name', ( - 'lhnhfkuhwreolo', - 'adlkfmnhglojh', -)) -def test_check_user_existence_nonexistent( - test_redditor_name: str, - reddit_instance: praw.Reddit, +def test_mark_hard_link( + test_submission_id: str, downloader_mock: MagicMock, + tmp_path: Path, + reddit_instance: praw.Reddit ): downloader_mock.reddit_instance = reddit_instance - with pytest.raises(BulkDownloaderException, match='Could not find'): - RedditDownloader._check_user_existence(downloader_mock, test_redditor_name) + downloader_mock.args.make_hard_links = True + downloader_mock.download_directory = tmp_path + downloader_mock.args.folder_scheme = '' + downloader_mock.args.file_scheme = '{POSTID}' + downloader_mock.file_name_formatter = RedditConnector.create_file_name_formatter(downloader_mock) + submission = downloader_mock.reddit_instance.submission(id=test_submission_id) + original = Path(tmp_path, f'{test_submission_id}.png') + + RedditDownloader._download_submission(downloader_mock, submission) + assert original.exists() + + downloader_mock.args.file_scheme = 'test2_{POSTID}' + downloader_mock.file_name_formatter = RedditConnector.create_file_name_formatter(downloader_mock) + RedditDownloader._download_submission(downloader_mock, submission) + test_file_1_stats = original.stat() + test_file_2_inode = Path(tmp_path, f'test2_{test_submission_id}.png').stat().st_ino + + assert test_file_1_stats.st_nlink == 2 + assert test_file_1_stats.st_ino == test_file_2_inode + + +def test_search_existing_files(): + results = RedditDownloader.scan_existing_files(Path('.')) + assert len(results.keys()) >= 40 @pytest.mark.online @pytest.mark.reddit -@pytest.mark.parametrize('test_redditor_name', ( - 'Bree-Boo', +@pytest.mark.parametrize(('test_submission_id', 'test_hash'), ( + ('m1hqw6', 'a912af8905ae468e0121e9940f797ad7'), )) -def test_check_user_existence_banned( - test_redditor_name: str, - reddit_instance: praw.Reddit, +def test_download_submission_hash_exists( + test_submission_id: str, + test_hash: str, downloader_mock: MagicMock, + reddit_instance: praw.Reddit, + tmp_path: Path, + capsys: pytest.CaptureFixture ): + setup_logging(3) downloader_mock.reddit_instance = reddit_instance - with pytest.raises(BulkDownloaderException, match='is banned'): - RedditDownloader._check_user_existence(downloader_mock, test_redditor_name) + downloader_mock.download_filter.check_url.return_value = True + downloader_mock.args.folder_scheme = '' + downloader_mock.args.no_dupes = True + downloader_mock.file_name_formatter = RedditConnector.create_file_name_formatter(downloader_mock) + downloader_mock.download_directory = tmp_path + downloader_mock.master_hash_list = {test_hash: None} + submission = downloader_mock.reddit_instance.submission(id=test_submission_id) + RedditDownloader._download_submission(downloader_mock, submission) + folder_contents = list(tmp_path.iterdir()) + output = capsys.readouterr() + assert len(folder_contents) == 0 + assert re.search(r'Resource hash .*? downloaded elsewhere', output.out) @pytest.mark.online @pytest.mark.reddit -@pytest.mark.parametrize(('test_subreddit_name', 'expected_message'), ( - ('donaldtrump', 'cannot be found'), - ('submitters', 'private and cannot be scraped') -)) -def test_check_subreddit_status_bad(test_subreddit_name: str, expected_message: str, reddit_instance: praw.Reddit): - test_subreddit = reddit_instance.subreddit(test_subreddit_name) - with pytest.raises(BulkDownloaderException, match=expected_message): - RedditDownloader._check_subreddit_status(test_subreddit) +def test_download_submission_file_exists( + downloader_mock: MagicMock, + reddit_instance: praw.Reddit, + tmp_path: Path, + capsys: pytest.CaptureFixture +): + setup_logging(3) + downloader_mock.reddit_instance = reddit_instance + downloader_mock.download_filter.check_url.return_value = True + downloader_mock.args.folder_scheme = '' + downloader_mock.file_name_formatter = RedditConnector.create_file_name_formatter(downloader_mock) + downloader_mock.download_directory = tmp_path + submission = downloader_mock.reddit_instance.submission(id='m1hqw6') + Path(tmp_path, 'Arneeman_Metagaming isn\'t always a bad thing_m1hqw6.png').touch() + RedditDownloader._download_submission(downloader_mock, submission) + folder_contents = list(tmp_path.iterdir()) + output = capsys.readouterr() + assert len(folder_contents) == 1 + assert 'Arneeman_Metagaming isn\'t always a bad thing_m1hqw6.png already exists' in output.out @pytest.mark.online @pytest.mark.reddit -@pytest.mark.parametrize('test_subreddit_name', ( - 'Python', - 'Mindustry', - 'TrollXChromosomes', - 'all', +@pytest.mark.parametrize(('test_submission_id', 'expected_files_len'), ( + ('ljyy27', 4), )) -def test_check_subreddit_status_good(test_subreddit_name: str, reddit_instance: praw.Reddit): - test_subreddit = reddit_instance.subreddit(test_subreddit_name) - RedditDownloader._check_subreddit_status(test_subreddit) +def test_download_submission( + test_submission_id: str, + expected_files_len: int, + downloader_mock: MagicMock, + reddit_instance: praw.Reddit, + tmp_path: Path): + downloader_mock.reddit_instance = reddit_instance + downloader_mock.download_filter.check_url.return_value = True + downloader_mock.args.folder_scheme = '' + downloader_mock.file_name_formatter = RedditConnector.create_file_name_formatter(downloader_mock) + downloader_mock.download_directory = tmp_path + submission = downloader_mock.reddit_instance.submission(id=test_submission_id) + RedditDownloader._download_submission(downloader_mock, submission) + folder_contents = list(tmp_path.iterdir()) + assert len(folder_contents) == expected_files_len