Refactor out super class RedditConnector
This commit is contained in:
parent
71da1556e5
commit
7016603763
5 changed files with 905 additions and 851 deletions
|
@ -15,13 +15,14 @@ from bdfr.archive_entry.comment_archive_entry import CommentArchiveEntry
|
||||||
from bdfr.archive_entry.submission_archive_entry import SubmissionArchiveEntry
|
from bdfr.archive_entry.submission_archive_entry import SubmissionArchiveEntry
|
||||||
from bdfr.configuration import Configuration
|
from bdfr.configuration import Configuration
|
||||||
from bdfr.downloader import RedditDownloader
|
from bdfr.downloader import RedditDownloader
|
||||||
|
from bdfr.connector import RedditConnector
|
||||||
from bdfr.exceptions import ArchiverError
|
from bdfr.exceptions import ArchiverError
|
||||||
from bdfr.resource import Resource
|
from bdfr.resource import Resource
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class Archiver(RedditDownloader):
|
class Archiver(RedditConnector):
|
||||||
def __init__(self, args: Configuration):
|
def __init__(self, args: Configuration):
|
||||||
super(Archiver, self).__init__(args)
|
super(Archiver, self).__init__(args)
|
||||||
|
|
||||||
|
@ -29,9 +30,9 @@ class Archiver(RedditDownloader):
|
||||||
for generator in self.reddit_lists:
|
for generator in self.reddit_lists:
|
||||||
for submission in generator:
|
for submission in generator:
|
||||||
logger.debug(f'Attempting to archive submission {submission.id}')
|
logger.debug(f'Attempting to archive submission {submission.id}')
|
||||||
self._write_entry(submission)
|
self.write_entry(submission)
|
||||||
|
|
||||||
def _get_submissions_from_link(self) -> list[list[praw.models.Submission]]:
|
def get_submissions_from_link(self) -> list[list[praw.models.Submission]]:
|
||||||
supplied_submissions = []
|
supplied_submissions = []
|
||||||
for sub_id in self.args.link:
|
for sub_id in self.args.link:
|
||||||
if len(sub_id) == 6:
|
if len(sub_id) == 6:
|
||||||
|
@ -42,10 +43,10 @@ class Archiver(RedditDownloader):
|
||||||
supplied_submissions.append(self.reddit_instance.submission(url=sub_id))
|
supplied_submissions.append(self.reddit_instance.submission(url=sub_id))
|
||||||
return [supplied_submissions]
|
return [supplied_submissions]
|
||||||
|
|
||||||
def _get_user_data(self) -> list[Iterator]:
|
def get_user_data(self) -> list[Iterator]:
|
||||||
results = super(Archiver, self)._get_user_data()
|
results = super(Archiver, self).get_user_data()
|
||||||
if self.args.user and self.args.all_comments:
|
if self.args.user and self.args.all_comments:
|
||||||
sort = self._determine_sort_function()
|
sort = self.determine_sort_function()
|
||||||
logger.debug(f'Retrieving comments of user {self.args.user}')
|
logger.debug(f'Retrieving comments of user {self.args.user}')
|
||||||
results.append(sort(self.reddit_instance.redditor(self.args.user).comments, limit=self.args.limit))
|
results.append(sort(self.reddit_instance.redditor(self.args.user).comments, limit=self.args.limit))
|
||||||
return results
|
return results
|
||||||
|
@ -59,7 +60,7 @@ class Archiver(RedditDownloader):
|
||||||
else:
|
else:
|
||||||
raise ArchiverError(f'Factory failed to classify item of type {type(praw_item).__name__}')
|
raise ArchiverError(f'Factory failed to classify item of type {type(praw_item).__name__}')
|
||||||
|
|
||||||
def _write_entry(self, praw_item: (praw.models.Submission, praw.models.Comment)):
|
def write_entry(self, praw_item: (praw.models.Submission, praw.models.Comment)):
|
||||||
archive_entry = self._pull_lever_entry_factory(praw_item)
|
archive_entry = self._pull_lever_entry_factory(praw_item)
|
||||||
if self.args.format == 'json':
|
if self.args.format == 'json':
|
||||||
self._write_entry_json(archive_entry)
|
self._write_entry_json(archive_entry)
|
||||||
|
|
401
bdfr/connector.py
Normal file
401
bdfr/connector.py
Normal file
|
@ -0,0 +1,401 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
# coding=utf-8
|
||||||
|
|
||||||
|
import configparser
|
||||||
|
import importlib.resources
|
||||||
|
import logging
|
||||||
|
import logging.handlers
|
||||||
|
import re
|
||||||
|
import shutil
|
||||||
|
import socket
|
||||||
|
from abc import ABCMeta, abstractmethod
|
||||||
|
from datetime import datetime
|
||||||
|
from enum import Enum, auto
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Callable, Iterator
|
||||||
|
|
||||||
|
import appdirs
|
||||||
|
import praw
|
||||||
|
import praw.exceptions
|
||||||
|
import praw.models
|
||||||
|
import prawcore
|
||||||
|
|
||||||
|
from bdfr import exceptions as errors
|
||||||
|
from bdfr.configuration import Configuration
|
||||||
|
from bdfr.download_filter import DownloadFilter
|
||||||
|
from bdfr.file_name_formatter import FileNameFormatter
|
||||||
|
from bdfr.oauth2 import OAuth2Authenticator, OAuth2TokenManager
|
||||||
|
from bdfr.site_authenticator import SiteAuthenticator
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class RedditTypes:
    """Namespace for the enumerations describing how Reddit listings are requested."""

    class SortType(Enum):
        """Listing sort orders; members are looked up by upper-cased name."""

        CONTROVERSIAL = auto()
        HOT = auto()
        NEW = auto()
        RELEVANCE = auto()
        # Backward-compatible alias for the historical misspelling. It shares the
        # value of RELEVANCE, so both SortType['RELEVENCE'] and
        # SortType.RELEVENCE keep working while `.name` now yields the correct
        # spelling.
        RELEVENCE = RELEVANCE
        RISING = auto()
        TOP = auto()

    class TimeType(Enum):
        """Time windows; each value is the lower-case string form of the member name."""

        ALL = 'all'
        DAY = 'day'
        HOUR = 'hour'
        MONTH = 'month'
        WEEK = 'week'
        YEAR = 'year'
|
||||||
|
|
||||||
|
|
||||||
|
class RedditConnector(metaclass=ABCMeta):
    """Abstract base holding the Reddit-facing setup shared by its subclasses.

    Construction resolves the working directories, loads the configuration
    file, attaches a rotating file logger, builds the PRAW instance and finally
    collects every requested listing into ``self.reddit_lists``. Subclasses
    must implement :meth:`download`.
    """

    def __init__(self, args: Configuration):
        """Store ``args``, run the full setup and gather the Reddit listings."""
        self.args = args
        self.config_directories = appdirs.AppDirs('bdfr', 'BDFR')
        self.run_time = datetime.now().isoformat()
        self._setup_internal_objects()

        self.reddit_lists = self.retrieve_reddit_lists()

    def _setup_internal_objects(self):
        """Create every helper object the connector needs, in dependency order."""
        # Directories must exist before the config is located, and the config
        # must be loaded before the logger reads `backup_log_count` from it.
        self.determine_directories()
        self.load_config()
        self.create_file_logger()

        self.read_config()

        self.download_filter = self.create_download_filter()
        logger.log(9, 'Created download filter')
        self.time_filter = self.create_time_filter()
        logger.log(9, 'Created time filter')
        self.sort_filter = self.create_sort_filter()
        logger.log(9, 'Created sort filter')
        self.file_name_formatter = self.create_file_name_formatter()
        logger.log(9, 'Create file name formatter')

        self.create_reddit_instance()
        self.resolve_user_name()

        self.excluded_submission_ids = self.read_excluded_ids()

        self.master_hash_list = {}
        self.authenticator = self.create_authenticator()
        logger.log(9, 'Created site authenticator')

        # Normalise the skip list once so later comparisons can be lower-case.
        self.args.skip_subreddit = {
            sub.lower() for sub in self.split_args_input(self.args.skip_subreddit)
        }

    def read_config(self):
        """Read any cfg values that need to be processed"""
        if self.args.max_wait_time is None:
            if not self.cfg_parser.has_option('DEFAULT', 'max_wait_time'):
                self.cfg_parser.set('DEFAULT', 'max_wait_time', '120')
                logger.log(9, 'Wrote default download wait time download to config file')
            self.args.max_wait_time = self.cfg_parser.getint('DEFAULT', 'max_wait_time')
            logger.debug(f'Setting maximum download wait time to {self.args.max_wait_time} seconds')
        if self.args.time_format is None:
            option = self.cfg_parser.get('DEFAULT', 'time_format', fallback='ISO')
            # A value made only of spaces/quotes is treated as "not set".
            if re.match(r'^[ \'\"]*$', option):
                option = 'ISO'
            logger.debug(f'Setting datetime format string to {option}')
            self.args.time_format = option
        # Update config on disk
        with open(self.config_location, 'w') as file:
            self.cfg_parser.write(file)

    def create_reddit_instance(self):
        """Build ``self.reddit_instance`` and record whether it is authenticated.

        With ``args.authenticate`` set, a user token is obtained through OAuth2
        and written to the config file if none is stored there yet.
        """
        if self.args.authenticate:
            logger.debug('Using authenticated Reddit instance')
            if not self.cfg_parser.has_option('DEFAULT', 'user_token'):
                logger.log(9, 'Commencing OAuth2 authentication')
                scopes = self.cfg_parser.get('DEFAULT', 'scopes')
                scopes = OAuth2Authenticator.split_scopes(scopes)
                oauth2_authenticator = OAuth2Authenticator(
                    scopes,
                    self.cfg_parser.get('DEFAULT', 'client_id'),
                    self.cfg_parser.get('DEFAULT', 'client_secret'),
                )
                token = oauth2_authenticator.retrieve_new_token()
                self.cfg_parser['DEFAULT']['user_token'] = token
                with open(self.config_location, 'w') as file:
                    self.cfg_parser.write(file, True)
            token_manager = OAuth2TokenManager(self.cfg_parser, self.config_location)

            self.authenticated = True
            self.reddit_instance = praw.Reddit(
                client_id=self.cfg_parser.get('DEFAULT', 'client_id'),
                client_secret=self.cfg_parser.get('DEFAULT', 'client_secret'),
                user_agent=socket.gethostname(),
                token_manager=token_manager,
            )
        else:
            logger.debug('Using unauthenticated Reddit instance')
            self.authenticated = False
            self.reddit_instance = praw.Reddit(
                client_id=self.cfg_parser.get('DEFAULT', 'client_id'),
                client_secret=self.cfg_parser.get('DEFAULT', 'client_secret'),
                user_agent=socket.gethostname(),
            )

    def retrieve_reddit_lists(self) -> list[praw.models.ListingGenerator]:
        """Collect the listings from subreddits, multireddits, user data and links."""
        master_list = []
        master_list.extend(self.get_subreddits())
        logger.log(9, 'Retrieved subreddits')
        master_list.extend(self.get_multireddits())
        logger.log(9, 'Retrieved multireddits')
        master_list.extend(self.get_user_data())
        logger.log(9, 'Retrieved user data')
        master_list.extend(self.get_submissions_from_link())
        logger.log(9, 'Retrieved submissions for given links')
        return master_list

    def determine_directories(self):
        """Resolve the download and config directories and create them if missing."""
        # expanduser() must run before resolve(): once the path is absolute a
        # leading '~' is no longer leading and would never be expanded.
        self.download_directory = Path(self.args.directory).expanduser().resolve()
        self.config_directory = Path(self.config_directories.user_config_dir)

        self.download_directory.mkdir(exist_ok=True, parents=True)
        self.config_directory.mkdir(exist_ok=True, parents=True)

    def load_config(self):
        """Locate a configuration file and load it into ``self.cfg_parser``.

        Search order: the path given in ``args.config``, then ``config.cfg`` /
        ``default_config.cfg`` in the working directory and in the user config
        directory, and finally the ``default_config.cfg`` packaged with
        ``bdfr`` (which is also copied into the user config directory).

        Raises:
            errors.BulkDownloaderException: if no configuration file is found.
        """
        self.cfg_parser = configparser.ConfigParser()
        if self.args.config:
            if (cfg_path := Path(self.args.config)).exists():
                self.cfg_parser.read(cfg_path)
                self.config_location = cfg_path
                return
            # Do not fall back silently: tell the user their file was ignored.
            logger.error(f'Could not find config file at {self.args.config}, attempting to find elsewhere')
        possible_paths = [
            Path('./config.cfg'),
            Path('./default_config.cfg'),
            Path(self.config_directory, 'config.cfg'),
            Path(self.config_directory, 'default_config.cfg'),
        ]
        self.config_location = None
        for path in possible_paths:
            if path.expanduser().resolve().exists():
                self.config_location = path
                logger.debug(f'Loading configuration from {path}')
                break
        if not self.config_location:
            # Use the documented context-manager protocol rather than the
            # private `.gen` attribute of the object returned by path().
            with importlib.resources.path('bdfr', 'default_config.cfg') as default_path:
                self.config_location = default_path
                shutil.copy(default_path, Path(self.config_directory, 'default_config.cfg'))
        if not self.config_location:
            raise errors.BulkDownloaderException('Could not find a configuration file to load')
        self.cfg_parser.read(self.config_location)

    def create_file_logger(self):
        """Attach a rotating file handler to the root logger.

        The log goes to ``args.log`` when given, otherwise to
        ``log_output.txt`` in the config directory. An existing log file is
        rolled over first.

        Raises:
            errors.BulkDownloaderException: if the parent directory of the
                requested log file does not exist.
            PermissionError: if the existing log file cannot be rolled over.
        """
        main_logger = logging.getLogger()
        if self.args.log is None:
            log_path = Path(self.config_directory, 'log_output.txt')
        else:
            # expanduser() first so that '~/...' is honoured (see determine_directories).
            log_path = Path(self.args.log).expanduser().resolve()
            if not log_path.parent.exists():
                raise errors.BulkDownloaderException('Designated location for logfile does not exist')
        backup_count = self.cfg_parser.getint('DEFAULT', 'backup_log_count', fallback=3)
        file_handler = logging.handlers.RotatingFileHandler(
            log_path,
            mode='a',
            backupCount=backup_count,
        )
        if log_path.exists():
            try:
                file_handler.doRollover()
            except PermissionError:
                logger.critical(
                    'Cannot rollover logfile, make sure this is the only '
                    'BDFR process or specify alternate logfile location')
                raise
        formatter = logging.Formatter('[%(asctime)s - %(name)s - %(levelname)s] - %(message)s')
        file_handler.setFormatter(formatter)
        file_handler.setLevel(0)

        main_logger.addHandler(file_handler)

    @staticmethod
    def sanitise_subreddit_name(subreddit: str) -> str:
        """Strip an optional ``https://www.reddit.com/`` / ``r/`` prefix and a trailing slash.

        Raises:
            errors.BulkDownloaderException: if the pattern does not match.
        """
        pattern = re.compile(r'^(?:https://www\.reddit\.com/)?(?:r/)?(.*?)/?$')
        match = pattern.match(subreddit)
        if not match:
            raise errors.BulkDownloaderException(f'Could not find subreddit name in string {subreddit}')
        return match.group(1)

    @staticmethod
    def split_args_input(entries: list[str]) -> set[str]:
        """Split each entry on ``,`` or ``;`` and return the sanitised, de-duplicated names."""
        split_pattern = re.compile(r'[,;]\s?')
        all_entries = set()
        for entry in entries:
            all_entries.update(
                RedditConnector.sanitise_subreddit_name(name) for name in split_pattern.split(entry)
            )
        return all_entries

    def get_subreddits(self) -> list[praw.models.ListingGenerator]:
        """Return one listing (or search result) per requested subreddit.

        Subreddits that fail the status check or raise while being queried are
        logged and skipped.
        """
        if not self.args.subreddit:
            return []
        out = []
        for reddit in self.split_args_input(self.args.subreddit):
            try:
                reddit = self.reddit_instance.subreddit(reddit)
                try:
                    self.check_subreddit_status(reddit)
                except errors.BulkDownloaderException as e:
                    logger.error(e)
                    continue
                if self.args.search:
                    out.append(reddit.search(
                        self.args.search,
                        sort=self.sort_filter.name.lower(),
                        limit=self.args.limit,
                        time_filter=self.time_filter.value,
                    ))
                    logger.debug(
                        f'Added submissions from subreddit {reddit} with the search term "{self.args.search}"')
                else:
                    out.append(self.create_filtered_listing_generator(reddit))
                    logger.debug(f'Added submissions from subreddit {reddit}')
            except (errors.BulkDownloaderException, praw.exceptions.PRAWException) as e:
                logger.error(f'Failed to get submissions for subreddit {reddit}: {e}')
        return out

    def resolve_user_name(self):
        """Replace the placeholder user ``me`` with the authenticated account name.

        Without authentication the user is cleared and a warning is logged.
        """
        if self.args.user != 'me':
            return
        if self.authenticated:
            self.args.user = self.reddit_instance.user.me().name
            logger.log(9, f'Resolved user to {self.args.user}')
        else:
            self.args.user = None
            logger.warning('To use "me" as a user, an authenticated Reddit instance must be used')

    def get_submissions_from_link(self) -> list[list[praw.models.Submission]]:
        """Return the submissions named in ``args.link`` wrapped in a single list.

        Six-character entries are treated as submission IDs, everything else
        as a URL.
        """
        supplied_submissions = []
        for sub_id in self.args.link:
            if len(sub_id) == 6:
                supplied_submissions.append(self.reddit_instance.submission(id=sub_id))
            else:
                supplied_submissions.append(self.reddit_instance.submission(url=sub_id))
        return [supplied_submissions]

    def determine_sort_function(self) -> Callable:
        """Map ``self.sort_filter`` to the matching unbound ``Subreddit`` listing method.

        Any sort type without an explicit branch falls back to ``hot``.
        """
        if self.sort_filter is RedditTypes.SortType.NEW:
            sort_function = praw.models.Subreddit.new
        elif self.sort_filter is RedditTypes.SortType.RISING:
            sort_function = praw.models.Subreddit.rising
        elif self.sort_filter is RedditTypes.SortType.CONTROVERSIAL:
            sort_function = praw.models.Subreddit.controversial
        elif self.sort_filter is RedditTypes.SortType.TOP:
            sort_function = praw.models.Subreddit.top
        else:
            sort_function = praw.models.Subreddit.hot
        return sort_function

    def get_multireddits(self) -> list[Iterator]:
        """Return one listing per requested multireddit of ``args.user``.

        Multireddits that are empty or raise while being queried are logged
        and skipped.
        """
        if not self.args.multireddit:
            return []
        out = []
        for multi in self.split_args_input(self.args.multireddit):
            try:
                multi = self.reddit_instance.multireddit(self.args.user, multi)
                if not multi.subreddits:
                    # Give the log line below something meaningful to print.
                    raise errors.BulkDownloaderException('multireddit contains no subreddits')
                out.append(self.create_filtered_listing_generator(multi))
                logger.debug(f'Added submissions from multireddit {multi}')
            except (errors.BulkDownloaderException, praw.exceptions.PRAWException, prawcore.PrawcoreException) as e:
                logger.error(f'Failed to get submissions for multireddit {multi}: {e}')
        return out

    def create_filtered_listing_generator(self, reddit_source) -> Iterator:
        """Apply the configured sort (and time filter where used) to ``reddit_source``."""
        sort_function = self.determine_sort_function()
        # Only the TOP and CONTROVERSIAL sorts are given a time filter here.
        if self.sort_filter in (RedditTypes.SortType.TOP, RedditTypes.SortType.CONTROVERSIAL):
            return sort_function(reddit_source, limit=self.args.limit, time_filter=self.time_filter.value)
        return sort_function(reddit_source, limit=self.args.limit)

    def get_user_data(self) -> list[Iterator]:
        """Return listings for the submitted/upvoted/saved posts of ``args.user``.

        An empty list is returned when none of those options is set, when no
        user is given, or when the user fails the existence check. Upvoted and
        saved lists are only fetched with an authenticated instance.
        """
        if not any([self.args.submitted, self.args.upvoted, self.args.saved]):
            return []
        if not self.args.user:
            logger.warning('A user must be supplied to download user data')
            return []
        try:
            self.check_user_existence(self.args.user)
        except errors.BulkDownloaderException as e:
            logger.error(e)
            return []
        generators = []
        if self.args.submitted:
            logger.debug(f'Retrieving submitted posts of user {self.args.user}')
            generators.append(self.create_filtered_listing_generator(
                self.reddit_instance.redditor(self.args.user).submissions,
            ))
        if not self.authenticated and any((self.args.upvoted, self.args.saved)):
            logger.warning('Accessing user lists requires authentication')
        else:
            if self.args.upvoted:
                logger.debug(f'Retrieving upvoted posts of user {self.args.user}')
                generators.append(self.reddit_instance.redditor(self.args.user).upvoted(limit=self.args.limit))
            if self.args.saved:
                logger.debug(f'Retrieving saved posts of user {self.args.user}')
                generators.append(self.reddit_instance.redditor(self.args.user).saved(limit=self.args.limit))
        return generators

    def check_user_existence(self, name: str):
        """Raise if the redditor ``name`` cannot be found or appears suspended.

        Raises:
            errors.BulkDownloaderException: on ``prawcore`` NotFound, or when
                ``id`` is missing and the user object has ``is_suspended``.
        """
        user = self.reddit_instance.redditor(name=name)
        try:
            if user.id:
                return
        except prawcore.exceptions.NotFound:
            raise errors.BulkDownloaderException(f'Could not find user {name}')
        except AttributeError:
            # NOTE(review): an AttributeError without `is_suspended` is
            # deliberately left unreported, as in the original code.
            if hasattr(user, 'is_suspended'):
                raise errors.BulkDownloaderException(f'User {name} is banned')

    def create_file_name_formatter(self) -> FileNameFormatter:
        """Build the file name formatter from the file/folder schemes and time format."""
        return FileNameFormatter(self.args.file_scheme, self.args.folder_scheme, self.args.time_format)

    def create_time_filter(self) -> RedditTypes.TimeType:
        """Translate ``args.time`` into a ``TimeType``; unknown or missing values give ``ALL``."""
        try:
            return RedditTypes.TimeType[self.args.time.upper()]
        except (KeyError, AttributeError):
            return RedditTypes.TimeType.ALL

    def create_sort_filter(self) -> RedditTypes.SortType:
        """Translate ``args.sort`` into a ``SortType``; unknown or missing values give ``HOT``."""
        try:
            return RedditTypes.SortType[self.args.sort.upper()]
        except (KeyError, AttributeError):
            return RedditTypes.SortType.HOT

    def create_download_filter(self) -> DownloadFilter:
        """Build the download filter from the skipped formats and domains."""
        return DownloadFilter(self.args.skip_format, self.args.skip_domain)

    def create_authenticator(self) -> SiteAuthenticator:
        """Build the site authenticator from the loaded configuration."""
        return SiteAuthenticator(self.cfg_parser)

    @abstractmethod
    def download(self):
        """Process ``self.reddit_lists``; must be provided by each subclass."""
        pass

    @staticmethod
    def check_subreddit_status(subreddit: praw.models.Subreddit):
        """Raise if ``subreddit`` cannot be found or access to it is forbidden.

        The pseudo-subreddit ``all`` is accepted without a lookup.

        Raises:
            errors.BulkDownloaderException: on ``prawcore`` NotFound or Forbidden.
        """
        if subreddit.display_name == 'all':
            return
        try:
            # Read the attribute for its side effect: this access is what can
            # raise NotFound/Forbidden. A plain read is used instead of
            # `assert`, which is removed entirely when Python runs with -O.
            _ = subreddit.id
        except prawcore.NotFound:
            raise errors.BulkDownloaderException(f'Source {subreddit.display_name} does not exist or cannot be found')
        except prawcore.Forbidden:
            raise errors.BulkDownloaderException(f'Source {subreddit.display_name} is private and cannot be scraped')

    def read_excluded_ids(self) -> set[str]:
        """Return the IDs from ``args.skip_id`` plus every line of each ``args.skip_id_file``.

        Missing files are logged and skipped.
        """
        excluded = set(self.args.skip_id)
        for id_file in self.args.skip_id_file:
            # expanduser() first so that '~/...' is honoured (see determine_directories).
            id_file = Path(id_file).expanduser().resolve()
            if not id_file.exists():
                logger.warning(f'ID exclusion file at {id_file} does not exist')
                continue
            with open(id_file, 'r') as file:
                excluded.update(line.strip() for line in file)
        return excluded
|
|
@ -1,33 +1,19 @@
|
||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
# coding=utf-8
|
# coding=utf-8
|
||||||
|
|
||||||
import configparser
|
|
||||||
import hashlib
|
import hashlib
|
||||||
import importlib.resources
|
|
||||||
import logging
|
|
||||||
import logging.handlers
|
import logging.handlers
|
||||||
import os
|
import os
|
||||||
import re
|
|
||||||
import shutil
|
|
||||||
import socket
|
|
||||||
from datetime import datetime
|
|
||||||
from enum import Enum, auto
|
|
||||||
from multiprocessing import Pool
|
from multiprocessing import Pool
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable, Iterator
|
|
||||||
|
|
||||||
import appdirs
|
|
||||||
import praw
|
import praw
|
||||||
import praw.exceptions
|
import praw.exceptions
|
||||||
import praw.models
|
import praw.models
|
||||||
import prawcore
|
|
||||||
|
|
||||||
import bdfr.exceptions as errors
|
from bdfr import exceptions as errors
|
||||||
from bdfr.configuration import Configuration
|
from bdfr.configuration import Configuration
|
||||||
from bdfr.download_filter import DownloadFilter
|
from bdfr.connector import RedditConnector
|
||||||
from bdfr.file_name_formatter import FileNameFormatter
|
|
||||||
from bdfr.oauth2 import OAuth2Authenticator, OAuth2TokenManager
|
|
||||||
from bdfr.site_authenticator import SiteAuthenticator
|
|
||||||
from bdfr.site_downloaders.download_factory import DownloadFactory
|
from bdfr.site_downloaders.download_factory import DownloadFactory
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -39,350 +25,11 @@ def _calc_hash(existing_file: Path):
|
||||||
return existing_file, file_hash
|
return existing_file, file_hash
|
||||||
|
|
||||||
|
|
||||||
class RedditTypes:
|
class RedditDownloader(RedditConnector):
|
||||||
class SortType(Enum):
|
|
||||||
CONTROVERSIAL = auto()
|
|
||||||
HOT = auto()
|
|
||||||
NEW = auto()
|
|
||||||
RELEVENCE = auto()
|
|
||||||
RISING = auto()
|
|
||||||
TOP = auto()
|
|
||||||
|
|
||||||
class TimeType(Enum):
|
|
||||||
ALL = 'all'
|
|
||||||
DAY = 'day'
|
|
||||||
HOUR = 'hour'
|
|
||||||
MONTH = 'month'
|
|
||||||
WEEK = 'week'
|
|
||||||
YEAR = 'year'
|
|
||||||
|
|
||||||
|
|
||||||
class RedditDownloader:
|
|
||||||
def __init__(self, args: Configuration):
|
def __init__(self, args: Configuration):
|
||||||
self.args = args
|
super(RedditDownloader, self).__init__(args)
|
||||||
self.config_directories = appdirs.AppDirs('bdfr', 'BDFR')
|
|
||||||
self.run_time = datetime.now().isoformat()
|
|
||||||
self._setup_internal_objects()
|
|
||||||
|
|
||||||
self.reddit_lists = self._retrieve_reddit_lists()
|
|
||||||
|
|
||||||
def _setup_internal_objects(self):
|
|
||||||
self._determine_directories()
|
|
||||||
self._load_config()
|
|
||||||
self._create_file_logger()
|
|
||||||
|
|
||||||
self._read_config()
|
|
||||||
|
|
||||||
self.download_filter = self._create_download_filter()
|
|
||||||
logger.log(9, 'Created download filter')
|
|
||||||
self.time_filter = self._create_time_filter()
|
|
||||||
logger.log(9, 'Created time filter')
|
|
||||||
self.sort_filter = self._create_sort_filter()
|
|
||||||
logger.log(9, 'Created sort filter')
|
|
||||||
self.file_name_formatter = self._create_file_name_formatter()
|
|
||||||
logger.log(9, 'Create file name formatter')
|
|
||||||
|
|
||||||
self._create_reddit_instance()
|
|
||||||
self._resolve_user_name()
|
|
||||||
|
|
||||||
self.excluded_submission_ids = self._read_excluded_ids()
|
|
||||||
|
|
||||||
if self.args.search_existing:
|
if self.args.search_existing:
|
||||||
self.master_hash_list = self.scan_existing_files(self.download_directory)
|
self.master_hash_list = self.scan_existing_files(self.download_directory)
|
||||||
else:
|
|
||||||
self.master_hash_list = {}
|
|
||||||
self.authenticator = self._create_authenticator()
|
|
||||||
logger.log(9, 'Created site authenticator')
|
|
||||||
|
|
||||||
self.args.skip_subreddit = self._split_args_input(self.args.skip_subreddit)
|
|
||||||
self.args.skip_subreddit = set([sub.lower() for sub in self.args.skip_subreddit])
|
|
||||||
|
|
||||||
def _read_config(self):
|
|
||||||
"""Read any cfg values that need to be processed"""
|
|
||||||
if self.args.max_wait_time is None:
|
|
||||||
if not self.cfg_parser.has_option('DEFAULT', 'max_wait_time'):
|
|
||||||
self.cfg_parser.set('DEFAULT', 'max_wait_time', '120')
|
|
||||||
logger.log(9, 'Wrote default download wait time download to config file')
|
|
||||||
self.args.max_wait_time = self.cfg_parser.getint('DEFAULT', 'max_wait_time')
|
|
||||||
logger.debug(f'Setting maximum download wait time to {self.args.max_wait_time} seconds')
|
|
||||||
if self.args.time_format is None:
|
|
||||||
option = self.cfg_parser.get('DEFAULT', 'time_format', fallback='ISO')
|
|
||||||
if re.match(r'^[ \'\"]*$', option):
|
|
||||||
option = 'ISO'
|
|
||||||
logger.debug(f'Setting datetime format string to {option}')
|
|
||||||
self.args.time_format = option
|
|
||||||
# Update config on disk
|
|
||||||
with open(self.config_location, 'w') as file:
|
|
||||||
self.cfg_parser.write(file)
|
|
||||||
|
|
||||||
def _create_reddit_instance(self):
|
|
||||||
if self.args.authenticate:
|
|
||||||
logger.debug('Using authenticated Reddit instance')
|
|
||||||
if not self.cfg_parser.has_option('DEFAULT', 'user_token'):
|
|
||||||
logger.log(9, 'Commencing OAuth2 authentication')
|
|
||||||
scopes = self.cfg_parser.get('DEFAULT', 'scopes')
|
|
||||||
scopes = OAuth2Authenticator.split_scopes(scopes)
|
|
||||||
oauth2_authenticator = OAuth2Authenticator(
|
|
||||||
scopes,
|
|
||||||
self.cfg_parser.get('DEFAULT', 'client_id'),
|
|
||||||
self.cfg_parser.get('DEFAULT', 'client_secret'),
|
|
||||||
)
|
|
||||||
token = oauth2_authenticator.retrieve_new_token()
|
|
||||||
self.cfg_parser['DEFAULT']['user_token'] = token
|
|
||||||
with open(self.config_location, 'w') as file:
|
|
||||||
self.cfg_parser.write(file, True)
|
|
||||||
token_manager = OAuth2TokenManager(self.cfg_parser, self.config_location)
|
|
||||||
|
|
||||||
self.authenticated = True
|
|
||||||
self.reddit_instance = praw.Reddit(
|
|
||||||
client_id=self.cfg_parser.get('DEFAULT', 'client_id'),
|
|
||||||
client_secret=self.cfg_parser.get('DEFAULT', 'client_secret'),
|
|
||||||
user_agent=socket.gethostname(),
|
|
||||||
token_manager=token_manager,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.debug('Using unauthenticated Reddit instance')
|
|
||||||
self.authenticated = False
|
|
||||||
self.reddit_instance = praw.Reddit(
|
|
||||||
client_id=self.cfg_parser.get('DEFAULT', 'client_id'),
|
|
||||||
client_secret=self.cfg_parser.get('DEFAULT', 'client_secret'),
|
|
||||||
user_agent=socket.gethostname(),
|
|
||||||
)
|
|
||||||
|
|
||||||
def _retrieve_reddit_lists(self) -> list[praw.models.ListingGenerator]:
|
|
||||||
master_list = []
|
|
||||||
master_list.extend(self._get_subreddits())
|
|
||||||
logger.log(9, 'Retrieved subreddits')
|
|
||||||
master_list.extend(self._get_multireddits())
|
|
||||||
logger.log(9, 'Retrieved multireddits')
|
|
||||||
master_list.extend(self._get_user_data())
|
|
||||||
logger.log(9, 'Retrieved user data')
|
|
||||||
master_list.extend(self._get_submissions_from_link())
|
|
||||||
logger.log(9, 'Retrieved submissions for given links')
|
|
||||||
return master_list
|
|
||||||
|
|
||||||
def _determine_directories(self):
|
|
||||||
self.download_directory = Path(self.args.directory).resolve().expanduser()
|
|
||||||
self.config_directory = Path(self.config_directories.user_config_dir)
|
|
||||||
|
|
||||||
self.download_directory.mkdir(exist_ok=True, parents=True)
|
|
||||||
self.config_directory.mkdir(exist_ok=True, parents=True)
|
|
||||||
|
|
||||||
def _load_config(self):
|
|
||||||
self.cfg_parser = configparser.ConfigParser()
|
|
||||||
if self.args.config:
|
|
||||||
if (cfg_path := Path(self.args.config)).exists():
|
|
||||||
self.cfg_parser.read(cfg_path)
|
|
||||||
self.config_location = cfg_path
|
|
||||||
return
|
|
||||||
possible_paths = [
|
|
||||||
Path('./config.cfg'),
|
|
||||||
Path('./default_config.cfg'),
|
|
||||||
Path(self.config_directory, 'config.cfg'),
|
|
||||||
Path(self.config_directory, 'default_config.cfg'),
|
|
||||||
]
|
|
||||||
self.config_location = None
|
|
||||||
for path in possible_paths:
|
|
||||||
if path.resolve().expanduser().exists():
|
|
||||||
self.config_location = path
|
|
||||||
logger.debug(f'Loading configuration from {path}')
|
|
||||||
break
|
|
||||||
if not self.config_location:
|
|
||||||
self.config_location = list(importlib.resources.path('bdfr', 'default_config.cfg').gen)[0]
|
|
||||||
shutil.copy(self.config_location, Path(self.config_directory, 'default_config.cfg'))
|
|
||||||
if not self.config_location:
|
|
||||||
raise errors.BulkDownloaderException('Could not find a configuration file to load')
|
|
||||||
self.cfg_parser.read(self.config_location)
|
|
||||||
|
|
||||||
def _create_file_logger(self):
|
|
||||||
main_logger = logging.getLogger()
|
|
||||||
if self.args.log is None:
|
|
||||||
log_path = Path(self.config_directory, 'log_output.txt')
|
|
||||||
else:
|
|
||||||
log_path = Path(self.args.log).resolve().expanduser()
|
|
||||||
if not log_path.parent.exists():
|
|
||||||
raise errors.BulkDownloaderException(f'Designated location for logfile does not exist')
|
|
||||||
backup_count = self.cfg_parser.getint('DEFAULT', 'backup_log_count', fallback=3)
|
|
||||||
file_handler = logging.handlers.RotatingFileHandler(
|
|
||||||
log_path,
|
|
||||||
mode='a',
|
|
||||||
backupCount=backup_count,
|
|
||||||
)
|
|
||||||
if log_path.exists():
|
|
||||||
try:
|
|
||||||
file_handler.doRollover()
|
|
||||||
except PermissionError as e:
|
|
||||||
logger.critical(
|
|
||||||
'Cannot rollover logfile, make sure this is the only '
|
|
||||||
'BDFR process or specify alternate logfile location')
|
|
||||||
raise
|
|
||||||
formatter = logging.Formatter('[%(asctime)s - %(name)s - %(levelname)s] - %(message)s')
|
|
||||||
file_handler.setFormatter(formatter)
|
|
||||||
file_handler.setLevel(0)
|
|
||||||
|
|
||||||
main_logger.addHandler(file_handler)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _sanitise_subreddit_name(subreddit: str) -> str:
|
|
||||||
pattern = re.compile(r'^(?:https://www\.reddit\.com/)?(?:r/)?(.*?)/?$')
|
|
||||||
match = re.match(pattern, subreddit)
|
|
||||||
if not match:
|
|
||||||
raise errors.BulkDownloaderException(f'Could not find subreddit name in string {subreddit}')
|
|
||||||
return match.group(1)
|
|
||||||
|
|
||||||
@staticmethod
def _split_args_input(entries: list[str]) -> set[str]:
    """Split comma/semicolon separated entries into a set of sanitised names.

    Each entry may hold several names separated by ``,`` or ``;`` with any
    amount of following whitespace. Empty names (for example from a trailing
    separator) are dropped instead of being returned as ``''``.
    """
    split_pattern = re.compile(r'[,;]\s*')
    all_entries = set()
    for entry in entries:
        for raw_name in split_pattern.split(entry):
            name = RedditDownloader._sanitise_subreddit_name(raw_name.strip())
            if name:
                all_entries.add(name)
    return all_entries
|
|
||||||
|
|
||||||
def _get_subreddits(self) -> list[praw.models.ListingGenerator]:
    """Build one listing generator per subreddit named in ``self.args.subreddit``.

    A subreddit that fails the status check is logged and skipped. When
    ``self.args.search`` is set, a search listing is produced instead of a
    sorted listing. Returns an empty list when no subreddit was requested.
    """
    if not self.args.subreddit:
        return []
    out = []
    for name in self._split_args_input(self.args.subreddit):
        try:
            subreddit = self.reddit_instance.subreddit(name)
            try:
                self._check_subreddit_status(subreddit)
            except errors.BulkDownloaderException as e:
                logger.error(e)
                continue
            if self.args.search:
                out.append(subreddit.search(
                    self.args.search,
                    sort=self.sort_filter.name.lower(),
                    limit=self.args.limit,
                    time_filter=self.time_filter.value,
                ))
                logger.debug(
                    f'Added submissions from subreddit {subreddit} with the search term "{self.args.search}"')
            else:
                out.append(self._create_filtered_listing_generator(subreddit))
                logger.debug(f'Added submissions from subreddit {subreddit}')
        # prawcore errors are caught as well, matching _get_multireddits, so
        # one failing subreddit does not abort the whole run.
        except (errors.BulkDownloaderException, praw.exceptions.PRAWException, prawcore.PrawcoreException) as e:
            logger.error(f'Failed to get submissions for subreddit {name}: {e}')
    return out
|
|
||||||
|
|
||||||
def _resolve_user_name(self):
    """Replace the placeholder user ``'me'`` in ``self.args.user``.

    With an authenticated instance the placeholder becomes the logged-in
    account's name; otherwise the user is cleared and a warning is logged.
    Any other user value is left untouched.
    """
    if self.args.user != 'me':
        return
    if not self.authenticated:
        self.args.user = None
        logger.warning('To use "me" as a user, an authenticated Reddit instance must be used')
        return
    self.args.user = self.reddit_instance.user.me().name
    logger.log(9, f'Resolved user to {self.args.user}')
|
|
||||||
|
|
||||||
def _get_submissions_from_link(self) -> list[list[praw.models.Submission]]:
    """Create submission objects for every entry in ``self.args.link``.

    An entry is treated as a submission ID when it is 6 characters long
    (original rule) or a bare 7-character word; anything else is passed on
    as a URL. The result is a single-element list wrapping the submissions.
    """
    supplied_submissions = []
    for sub_id in self.args.link:
        # Bare 7-character IDs previously fell through to the URL branch.
        if len(sub_id) == 6 or re.fullmatch(r'\w{7}', sub_id):
            supplied_submissions.append(self.reddit_instance.submission(id=sub_id))
        else:
            supplied_submissions.append(self.reddit_instance.submission(url=sub_id))
    return [supplied_submissions]
|
|
||||||
|
|
||||||
def _determine_sort_function(self) -> Callable:
    """Return the ``praw.models.Subreddit`` listing function for ``self.sort_filter``.

    NEW, RISING, CONTROVERSIAL and TOP map to their namesake functions; any
    other value falls back to ``hot``.
    """
    chosen = self.sort_filter
    if chosen is RedditTypes.SortType.NEW:
        return praw.models.Subreddit.new
    if chosen is RedditTypes.SortType.RISING:
        return praw.models.Subreddit.rising
    if chosen is RedditTypes.SortType.CONTROVERSIAL:
        return praw.models.Subreddit.controversial
    if chosen is RedditTypes.SortType.TOP:
        return praw.models.Subreddit.top
    return praw.models.Subreddit.hot
|
|
||||||
|
|
||||||
def _get_multireddits(self) -> list[Iterator]:
    """Build one listing generator per multireddit in ``self.args.multireddit``.

    Multireddits are looked up under ``self.args.user``. One that exposes no
    subreddits, or whose lookup raises, is logged and skipped. Returns an
    empty list when no multireddit was requested.
    """
    if not self.args.multireddit:
        return []
    out = []
    for name in self._split_args_input(self.args.multireddit):
        try:
            multi = self.reddit_instance.multireddit(self.args.user, name)
            if not multi.subreddits:
                # Give the error a message so the log line below is not left
                # with an empty reason.
                raise errors.BulkDownloaderException('Multireddit contains no subreddits')
            out.append(self._create_filtered_listing_generator(multi))
            logger.debug(f'Added submissions from multireddit {multi}')
        except (errors.BulkDownloaderException, praw.exceptions.PRAWException, prawcore.PrawcoreException) as e:
            logger.error(f'Failed to get submissions for multireddit {name}: {e}')
    return out
|
|
||||||
|
|
||||||
def _create_filtered_listing_generator(self, reddit_source) -> Iterator:
    """Call the configured sort function on ``reddit_source``.

    ``limit`` is always passed; ``time_filter`` is added only for the TOP
    and CONTROVERSIAL sorts.
    """
    sort_function = self._determine_sort_function()
    time_sensitive = (RedditTypes.SortType.TOP, RedditTypes.SortType.CONTROVERSIAL)
    if self.sort_filter not in time_sensitive:
        return sort_function(reddit_source, limit=self.args.limit)
    return sort_function(reddit_source, limit=self.args.limit, time_filter=self.time_filter.value)
|
|
||||||
|
|
||||||
def _get_user_data(self) -> list[Iterator]:
    """Collect the listing generators requested for ``self.args.user``.

    Covers submitted, upvoted and saved posts depending on the flags set on
    ``self.args``. Upvoted and saved lists are only fetched when the
    instance is authenticated. Returns an empty list when no flag is set,
    no user is given, or the user check fails.
    """
    if not any([self.args.submitted, self.args.upvoted, self.args.saved]):
        return []
    if not self.args.user:
        logger.warning('A user must be supplied to download user data')
        return []
    try:
        self._check_user_existence(self.args.user)
    except errors.BulkDownloaderException as e:
        logger.error(e)
        return []

    generators = []
    if self.args.submitted:
        logger.debug(f'Retrieving submitted posts of user {self.args.user}')
        submissions = self.reddit_instance.redditor(self.args.user).submissions
        generators.append(self._create_filtered_listing_generator(submissions))
    if not self.authenticated and any((self.args.upvoted, self.args.saved)):
        logger.warning('Accessing user lists requires authentication')
        return generators
    if self.args.upvoted:
        logger.debug(f'Retrieving upvoted posts of user {self.args.user}')
        generators.append(self.reddit_instance.redditor(self.args.user).upvoted(limit=self.args.limit))
    if self.args.saved:
        logger.debug(f'Retrieving saved posts of user {self.args.user}')
        generators.append(self.reddit_instance.redditor(self.args.user).saved(limit=self.args.limit))
    return generators
|
|
||||||
|
|
||||||
def _check_user_existence(self, name: str):
    """Check that the redditor ``name`` can be looked up.

    Raises:
        errors.BulkDownloaderException: the lookup raised ``NotFound``, or it
            raised ``AttributeError`` and the redditor object carries an
            ``is_suspended`` attribute.
    """
    redditor = self.reddit_instance.redditor(name=name)
    try:
        if redditor.id:
            return None
    except prawcore.exceptions.NotFound:
        raise errors.BulkDownloaderException(f'Could not find user {name}')
    except AttributeError:
        suspended = hasattr(redditor, 'is_suspended')
        if suspended:
            raise errors.BulkDownloaderException(f'User {name} is banned')
    return None
|
|
||||||
|
|
||||||
def _create_file_name_formatter(self) -> FileNameFormatter:
    """Build a ``FileNameFormatter`` from the file scheme, folder scheme and time format in ``self.args``."""
    options = self.args
    return FileNameFormatter(options.file_scheme, options.folder_scheme, options.time_format)
|
|
||||||
|
|
||||||
def _create_time_filter(self) -> RedditTypes.TimeType:
    """Map ``self.args.time`` onto a ``RedditTypes.TimeType`` member.

    Falls back to ``ALL`` when the value is not a member name or has no
    ``upper`` method.
    """
    try:
        requested = self.args.time.upper()
        return RedditTypes.TimeType[requested]
    except (KeyError, AttributeError):
        return RedditTypes.TimeType.ALL
|
|
||||||
|
|
||||||
def _create_sort_filter(self) -> RedditTypes.SortType:
    """Map ``self.args.sort`` onto a ``RedditTypes.SortType`` member.

    Falls back to ``HOT`` when the value is not a member name or has no
    ``upper`` method.
    """
    try:
        requested = self.args.sort.upper()
        return RedditTypes.SortType[requested]
    except (KeyError, AttributeError):
        return RedditTypes.SortType.HOT
|
|
||||||
|
|
||||||
def _create_download_filter(self) -> DownloadFilter:
    """Build a ``DownloadFilter`` from the skipped formats and skipped domains in ``self.args``."""
    options = self.args
    return DownloadFilter(options.skip_format, options.skip_domain)
|
|
||||||
|
|
||||||
def _create_authenticator(self) -> SiteAuthenticator:
    """Build a ``SiteAuthenticator`` around the loaded config parser."""
    authenticator = SiteAuthenticator(self.cfg_parser)
    return authenticator
|
|
||||||
|
|
||||||
def download(self):
|
def download(self):
|
||||||
for generator in self.reddit_lists:
|
for generator in self.reddit_lists:
|
||||||
|
@ -457,27 +104,3 @@ class RedditDownloader:
|
||||||
|
|
||||||
hash_list = {res[1]: res[0] for res in results}
|
hash_list = {res[1]: res[0] for res in results}
|
||||||
return hash_list
|
return hash_list
|
||||||
|
|
||||||
def _read_excluded_ids(self) -> set[str]:
    """Collect the submission IDs that should be skipped.

    Combines ``self.args.skip_id`` with the non-blank lines of every file
    in ``self.args.skip_id_file``. A missing file is logged and skipped.
    """
    out = set(self.args.skip_id)
    for id_file in self.args.skip_id_file:
        # Expand '~' BEFORE resolving; resolve() first turns '~' into a
        # literal directory under the working directory.
        id_file = Path(id_file).expanduser().resolve()
        if not id_file.exists():
            logger.warning(f'ID exclusion file at {id_file} does not exist')
            continue
        with open(id_file, 'r') as file:
            for line in file:
                submission_id = line.strip()
                # Blank lines would otherwise add '' to the exclusion set.
                if submission_id:
                    out.add(submission_id)
    return out
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _check_subreddit_status(subreddit: praw.models.Subreddit):
|
|
||||||
if subreddit.display_name == 'all':
|
|
||||||
return
|
|
||||||
try:
|
|
||||||
assert subreddit.id
|
|
||||||
except prawcore.NotFound:
|
|
||||||
raise errors.BulkDownloaderException(f'Source {subreddit.display_name} does not exist or cannot be found')
|
|
||||||
except prawcore.Forbidden:
|
|
||||||
raise errors.BulkDownloaderException(f'Source {subreddit.display_name} is private and cannot be scraped')
|
|
||||||
|
|
401
tests/test_connector.py
Normal file
401
tests/test_connector.py
Normal file
|
@ -0,0 +1,401 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
# coding=utf-8
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Iterator
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import praw
|
||||||
|
import praw.models
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from bdfr.configuration import Configuration
|
||||||
|
from bdfr.connector import RedditConnector, RedditTypes
|
||||||
|
from bdfr.download_filter import DownloadFilter
|
||||||
|
from bdfr.exceptions import BulkDownloaderException
|
||||||
|
from bdfr.file_name_formatter import FileNameFormatter
|
||||||
|
from bdfr.site_authenticator import SiteAuthenticator
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
def args() -> Configuration:
    """Provide a fresh ``Configuration`` with the time format set to ISO."""
    configuration = Configuration()
    configuration.time_format = 'ISO'
    return configuration
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
def downloader_mock(args: Configuration):
    """Provide a ``MagicMock`` standing in for a connector instance.

    The two static name-handling helpers are the real implementations so
    that methods under test can split and sanitise input.
    """
    mock = MagicMock()
    mock.args = args
    mock.master_hash_list = {}
    mock.sanitise_subreddit_name = RedditConnector.sanitise_subreddit_name
    mock.split_args_input = RedditConnector.split_args_input
    return mock
|
||||||
|
|
||||||
|
|
||||||
|
def assert_all_results_are_submissions(result_limit: int, results: list[Iterator]):
    """Flatten ``results`` and assert every item is a ``praw.models.Submission``.

    When ``result_limit`` is not None the flattened length must equal it.
    Returns the flattened list.
    """
    flattened = []
    for listing in results:
        flattened.extend(listing)
    assert all(isinstance(item, praw.models.Submission) for item in flattened)
    if result_limit is not None:
        assert len(flattened) == result_limit
    return flattened
|
||||||
|
|
||||||
|
|
||||||
|
def test_determine_directories(tmp_path: Path, downloader_mock: MagicMock):
    """The requested download directory exists after ``determine_directories`` runs."""
    target_directory = tmp_path / 'test'
    downloader_mock.args.directory = target_directory
    downloader_mock.config_directories.user_config_dir = tmp_path
    RedditConnector.determine_directories(downloader_mock)
    assert Path(target_directory).exists()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(('skip_extensions', 'skip_domains'), (
    ([], []),
    (['.test'], ['test.com'],),
))
def test_create_download_filter(skip_extensions: list[str], skip_domains: list[str], downloader_mock: MagicMock):
    """The created filter carries the configured skipped extensions and domains."""
    downloader_mock.args.skip_domain = skip_domains
    downloader_mock.args.skip_format = skip_extensions
    download_filter = RedditConnector.create_download_filter(downloader_mock)

    assert isinstance(download_filter, DownloadFilter)
    assert download_filter.excluded_domains == skip_domains
    assert download_filter.excluded_extensions == skip_extensions
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(('test_time', 'expected'), (
    ('all', 'all'),
    ('hour', 'hour'),
    ('day', 'day'),
    ('week', 'week'),
    ('random', 'all'),
    ('', 'all'),
))
def test_create_time_filter(test_time: str, expected: str, downloader_mock: MagicMock):
    """Known time names map to their member; unknown or empty names fall back to ``all``."""
    downloader_mock.args.time = test_time
    time_filter = RedditConnector.create_time_filter(downloader_mock)

    assert isinstance(time_filter, RedditTypes.TimeType)
    assert time_filter.name.lower() == expected
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(('test_sort', 'expected'), (
    ('', 'hot'),
    ('hot', 'hot'),
    ('controversial', 'controversial'),
    ('new', 'new'),
))
def test_create_sort_filter(test_sort: str, expected: str, downloader_mock: MagicMock):
    """Known sort names map to their member; an empty name falls back to ``hot``."""
    downloader_mock.args.sort = test_sort
    sort_filter = RedditConnector.create_sort_filter(downloader_mock)

    assert isinstance(sort_filter, RedditTypes.SortType)
    assert sort_filter.name.lower() == expected
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(('test_file_scheme', 'test_folder_scheme'), (
    ('{POSTID}', '{SUBREDDIT}'),
    ('{REDDITOR}_{TITLE}_{POSTID}', '{SUBREDDIT}'),
    ('{POSTID}', 'test'),
    ('{POSTID}', ''),
    ('{POSTID}', '{SUBREDDIT}/{REDDITOR}'),
))
def test_create_file_name_formatter(test_file_scheme: str, test_folder_scheme: str, downloader_mock: MagicMock):
    """Valid schemes produce a formatter holding the file scheme and the split folder scheme."""
    downloader_mock.args.folder_scheme = test_folder_scheme
    downloader_mock.args.file_scheme = test_file_scheme
    formatter = RedditConnector.create_file_name_formatter(downloader_mock)

    expected_directories = test_folder_scheme.split('/')
    assert isinstance(formatter, FileNameFormatter)
    assert formatter.file_format_string == test_file_scheme
    assert formatter.directory_format_string == expected_directories
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(('test_file_scheme', 'test_folder_scheme'), (
    ('', ''),
    ('', '{SUBREDDIT}'),
    ('test', '{SUBREDDIT}'),
))
def test_create_file_name_formatter_bad(test_file_scheme: str, test_folder_scheme: str, downloader_mock: MagicMock):
    """Each of these file schemes makes formatter creation raise ``BulkDownloaderException``."""
    downloader_mock.args.folder_scheme = test_folder_scheme
    downloader_mock.args.file_scheme = test_file_scheme
    with pytest.raises(BulkDownloaderException):
        RedditConnector.create_file_name_formatter(downloader_mock)
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_authenticator(downloader_mock: MagicMock):
    """``create_authenticator`` returns a ``SiteAuthenticator``."""
    authenticator = RedditConnector.create_authenticator(downloader_mock)
    assert isinstance(authenticator, SiteAuthenticator)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize('test_submission_ids', (
    ('lvpf4l',),
    ('lvpf4l', 'lvqnsn'),
    ('lvpf4l', 'lvqnsn', 'lvl9kd'),
))
def test_get_submissions_from_link(
        test_submission_ids: list[str],
        reddit_instance: praw.Reddit,
        downloader_mock: MagicMock):
    """Every supplied ID yields one ``Submission`` in the single returned group."""
    downloader_mock.reddit_instance = reddit_instance
    downloader_mock.args.link = test_submission_ids
    results = RedditConnector.get_submissions_from_link(downloader_mock)
    submissions = [sub for res in results for sub in res]
    assert all(isinstance(sub, praw.models.Submission) for sub in submissions)
    assert len(results[0]) == len(test_submission_ids)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize(('test_subreddits', 'limit', 'sort_type', 'time_filter', 'max_expected_len'), (
    (('Futurology',), 10, 'hot', 'all', 10),
    (('Futurology', 'Mindustry, Python'), 10, 'hot', 'all', 30),
    (('Futurology',), 20, 'hot', 'all', 20),
    (('Futurology', 'Python'), 10, 'hot', 'all', 20),
    (('Futurology',), 100, 'hot', 'all', 100),
    (('Futurology',), 0, 'hot', 'all', 0),
    (('Futurology',), 10, 'top', 'all', 10),
    (('Futurology',), 10, 'top', 'week', 10),
    (('Futurology',), 10, 'hot', 'week', 10),
))
def test_get_subreddit_normal(
        test_subreddits: list[str],
        limit: int,
        sort_type: str,
        time_filter: str,
        max_expected_len: int,
        downloader_mock: MagicMock,
        reddit_instance: praw.Reddit,
):
    """Listings from ``get_subreddits`` hold only submissions from the requested subreddits."""
    # Use the public (un-prefixed) helper names the fixture defines; the old
    # underscore-prefixed names only create fresh, unused mock attributes.
    downloader_mock.determine_sort_function.return_value = praw.models.Subreddit.hot
    downloader_mock.args.limit = limit
    downloader_mock.args.sort = sort_type
    downloader_mock.args.time = time_filter
    downloader_mock.args.subreddit = test_subreddits
    downloader_mock.reddit_instance = reddit_instance
    downloader_mock.sort_filter = RedditConnector.create_sort_filter(downloader_mock)
    downloader_mock.time_filter = RedditConnector.create_time_filter(downloader_mock)
    results = RedditConnector.get_subreddits(downloader_mock)
    # The real splitter returns a set of names; the mock attribute
    # `_split_args_input` would return a MagicMock and break the `in` check.
    test_subreddits = downloader_mock.split_args_input(test_subreddits)
    results = [sub for res1 in results for sub in res1]
    assert all([isinstance(res1, praw.models.Submission) for res1 in results])
    assert all([res.subreddit.display_name in test_subreddits for res in results])
    assert len(results) <= max_expected_len
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize(('test_subreddits', 'search_term', 'limit', 'time_filter', 'max_expected_len'), (
    (('Python',), 'scraper', 10, 'all', 10),
    (('Python',), '', 10, 'all', 10),
    (('Python',), 'djsdsgewef', 10, 'all', 0),
    (('Python',), 'scraper', 10, 'year', 10),
    (('Python',), 'scraper', 10, 'hour', 1),
))
def test_get_subreddit_search(
        test_subreddits: list[str],
        search_term: str,
        time_filter: str,
        limit: int,
        max_expected_len: int,
        downloader_mock: MagicMock,
        reddit_instance: praw.Reddit,
):
    """Search listings hold only submissions from the requested subreddits, within the expected size."""
    # Public helper name, consistent with the fixture and sibling tests.
    downloader_mock.determine_sort_function.return_value = praw.models.Subreddit.hot
    downloader_mock.args.limit = limit
    downloader_mock.args.search = search_term
    downloader_mock.args.subreddit = test_subreddits
    downloader_mock.reddit_instance = reddit_instance
    downloader_mock.sort_filter = RedditTypes.SortType.HOT
    downloader_mock.args.time = time_filter
    downloader_mock.time_filter = RedditConnector.create_time_filter(downloader_mock)
    results = RedditConnector.get_subreddits(downloader_mock)
    results = [sub for res in results for sub in res]
    assert all([isinstance(res, praw.models.Submission) for res in results])
    assert all([res.subreddit.display_name in test_subreddits for res in results])
    assert len(results) <= max_expected_len
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize(('test_user', 'test_multireddits', 'limit'), (
    ('helen_darten', ('cuteanimalpics',), 10),
    ('korfor', ('chess',), 100),
))
# Good sources at https://www.reddit.com/r/multihub/
def test_get_multireddits_public(
        test_user: str,
        test_multireddits: list[str],
        limit: int,
        reddit_instance: praw.Reddit,
        downloader_mock: MagicMock,
):
    """A public multireddit yields exactly ``limit`` submissions."""
    downloader_mock.determine_sort_function.return_value = praw.models.Subreddit.hot
    downloader_mock.sort_filter = RedditTypes.SortType.HOT
    downloader_mock.args.limit = limit
    downloader_mock.args.multireddit = test_multireddits
    downloader_mock.args.user = test_user
    downloader_mock.reddit_instance = reddit_instance
    # Route the mocked generator factory through the real implementation.
    multireddit = reddit_instance.multireddit(test_user, test_multireddits[0])
    listing = RedditConnector.create_filtered_listing_generator(downloader_mock, multireddit)
    downloader_mock.create_filtered_listing_generator.return_value = listing
    results = RedditConnector.get_multireddits(downloader_mock)
    submissions = [sub for res in results for sub in res]
    assert all(isinstance(sub, praw.models.Submission) for sub in submissions)
    assert len(submissions) == limit
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize(('test_user', 'limit'), (
    ('danigirl3694', 10),
    ('danigirl3694', 50),
    ('CapitanHam', None),
))
def test_get_user_submissions(test_user: str, limit: int, downloader_mock: MagicMock, reddit_instance: praw.Reddit):
    """Submitted posts of a user are all submissions authored by that user."""
    downloader_mock.args.limit = limit
    downloader_mock.determine_sort_function.return_value = praw.models.Subreddit.hot
    downloader_mock.sort_filter = RedditTypes.SortType.HOT
    downloader_mock.args.submitted = True
    downloader_mock.args.user = test_user
    downloader_mock.authenticated = False
    downloader_mock.reddit_instance = reddit_instance
    # Route the mocked generator factory through the real implementation.
    user_submissions = reddit_instance.redditor(test_user).submissions
    listing = RedditConnector.create_filtered_listing_generator(downloader_mock, user_submissions)
    downloader_mock.create_filtered_listing_generator.return_value = listing
    results = RedditConnector.get_user_data(downloader_mock)
    submissions = assert_all_results_are_submissions(limit, results)
    assert all(sub.author.name == test_user for sub in submissions)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.authenticated
@pytest.mark.parametrize('test_flag', (
    'upvoted',
    'saved',
))
def test_get_user_authenticated_lists(
        test_flag: str,
        downloader_mock: MagicMock,
        authenticated_reddit_instance: praw.Reddit,
):
    """The upvoted/saved list of the authenticated user yields 10 submissions."""
    # setattr rather than poking args.__dict__ directly.
    setattr(downloader_mock.args, test_flag, True)
    downloader_mock.reddit_instance = authenticated_reddit_instance
    downloader_mock.args.user = 'me'
    downloader_mock.args.limit = 10
    # Public helper name, consistent with the fixture and sibling tests.
    downloader_mock.determine_sort_function.return_value = praw.models.Subreddit.hot
    downloader_mock.sort_filter = RedditTypes.SortType.HOT
    RedditConnector.resolve_user_name(downloader_mock)
    results = RedditConnector.get_user_data(downloader_mock)
    assert_all_results_are_submissions(10, results)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(('test_name', 'expected'), (
    ('Mindustry', 'Mindustry'),
    ('Futurology', 'Futurology'),
    ('r/Mindustry', 'Mindustry'),
    ('TrollXChromosomes', 'TrollXChromosomes'),
    ('r/TrollXChromosomes', 'TrollXChromosomes'),
    ('https://www.reddit.com/r/TrollXChromosomes/', 'TrollXChromosomes'),
    ('https://www.reddit.com/r/TrollXChromosomes', 'TrollXChromosomes'),
    ('https://www.reddit.com/r/Futurology/', 'Futurology'),
    ('https://www.reddit.com/r/Futurology', 'Futurology'),
))
def test_sanitise_subreddit_name(test_name: str, expected: str):
    """Bare names, ``r/`` names and full URLs all reduce to the subreddit name."""
    sanitised = RedditConnector.sanitise_subreddit_name(test_name)
    assert sanitised == expected
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(('test_subreddit_entries', 'expected'), (
    (['test1', 'test2', 'test3'], {'test1', 'test2', 'test3'}),
    (['test1,test2', 'test3'], {'test1', 'test2', 'test3'}),
    (['test1, test2', 'test3'], {'test1', 'test2', 'test3'}),
    (['test1; test2', 'test3'], {'test1', 'test2', 'test3'}),
    (['test1, test2', 'test1,test2,test3', 'test4'], {'test1', 'test2', 'test3', 'test4'})
))
def test_split_subreddit_entries(test_subreddit_entries: list[str], expected: set[str]):
    """Comma and semicolon separated entries are split and de-duplicated."""
    split_entries = RedditConnector.split_args_input(test_subreddit_entries)
    assert split_entries == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_read_excluded_submission_ids_from_file(downloader_mock: MagicMock, tmp_path: Path):
    """IDs listed one per line in a skip file are returned as a set."""
    id_file = tmp_path / 'test.txt'
    id_file.write_text('aaaaaa\nbbbbbb')
    downloader_mock.args.skip_id_file = [id_file]
    excluded_ids = RedditConnector.read_excluded_ids(downloader_mock)
    assert excluded_ids == {'aaaaaa', 'bbbbbb'}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize('test_redditor_name', (
    'Paracortex',
    'crowdstrike',
    'HannibalGoddamnit',
))
def test_check_user_existence_good(
        test_redditor_name: str,
        reddit_instance: praw.Reddit,
        downloader_mock: MagicMock,
):
    """The existence check completes without raising for these redditors."""
    downloader_mock.reddit_instance = reddit_instance
    RedditConnector.check_user_existence(downloader_mock, name=test_redditor_name)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize('test_redditor_name', (
    'lhnhfkuhwreolo',
    'adlkfmnhglojh',
))
def test_check_user_existence_nonexistent(
        test_redditor_name: str,
        reddit_instance: praw.Reddit,
        downloader_mock: MagicMock,
):
    """The existence check raises 'Could not find' for these redditor names."""
    downloader_mock.reddit_instance = reddit_instance
    expectation = pytest.raises(BulkDownloaderException, match='Could not find')
    with expectation:
        RedditConnector.check_user_existence(downloader_mock, test_redditor_name)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize('test_redditor_name', (
    'Bree-Boo',
))
def test_check_user_existence_banned(
        test_redditor_name: str,
        reddit_instance: praw.Reddit,
        downloader_mock: MagicMock,
):
    """The existence check raises 'is banned' for this redditor."""
    downloader_mock.reddit_instance = reddit_instance
    expectation = pytest.raises(BulkDownloaderException, match='is banned')
    with expectation:
        RedditConnector.check_user_existence(downloader_mock, test_redditor_name)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize(('test_subreddit_name', 'expected_message'), (
    ('donaldtrump', 'cannot be found'),
    ('submitters', 'private and cannot be scraped')
))
def test_check_subreddit_status_bad(test_subreddit_name: str, expected_message: str, reddit_instance: praw.Reddit):
    """The status check raises with the expected message for each of these sources."""
    subreddit = reddit_instance.subreddit(test_subreddit_name)
    expectation = pytest.raises(BulkDownloaderException, match=expected_message)
    with expectation:
        RedditConnector.check_subreddit_status(subreddit)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize('test_subreddit_name', (
    'Python',
    'Mindustry',
    'TrollXChromosomes',
    'all',
))
def test_check_subreddit_status_good(test_subreddit_name: str, reddit_instance: praw.Reddit):
    """The status check completes without raising for these sources."""
    subreddit = reddit_instance.subreddit(test_subreddit_name)
    RedditConnector.check_subreddit_status(subreddit)
|
|
@ -3,20 +3,15 @@
|
||||||
|
|
||||||
import re
|
import re
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Iterator
|
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
import praw
|
|
||||||
import praw.models
|
import praw.models
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from bdfr.__main__ import setup_logging
|
from bdfr.__main__ import setup_logging
|
||||||
from bdfr.configuration import Configuration
|
from bdfr.configuration import Configuration
|
||||||
from bdfr.download_filter import DownloadFilter
|
from bdfr.connector import RedditConnector
|
||||||
from bdfr.downloader import RedditDownloader, RedditTypes
|
from bdfr.downloader import RedditDownloader
|
||||||
from bdfr.exceptions import BulkDownloaderException
|
|
||||||
from bdfr.file_name_formatter import FileNameFormatter
|
|
||||||
from bdfr.site_authenticator import SiteAuthenticator
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture()
|
||||||
|
@ -30,411 +25,12 @@ def args() -> Configuration:
|
||||||
def downloader_mock(args: Configuration):
|
def downloader_mock(args: Configuration):
|
||||||
downloader_mock = MagicMock()
|
downloader_mock = MagicMock()
|
||||||
downloader_mock.args = args
|
downloader_mock.args = args
|
||||||
downloader_mock._sanitise_subreddit_name = RedditDownloader._sanitise_subreddit_name
|
downloader_mock._sanitise_subreddit_name = RedditConnector.sanitise_subreddit_name
|
||||||
downloader_mock._split_args_input = RedditDownloader._split_args_input
|
downloader_mock._split_args_input = RedditConnector.split_args_input
|
||||||
downloader_mock.master_hash_list = {}
|
downloader_mock.master_hash_list = {}
|
||||||
return downloader_mock
|
return downloader_mock
|
||||||
|
|
||||||
|
|
||||||
def assert_all_results_are_submissions(result_limit: int, results: list[Iterator]):
    """Flatten ``results`` and assert every item is a ``praw.models.Submission``.

    When ``result_limit`` is not None the flattened length must equal it.
    Returns the flattened list.
    """
    flattened = []
    for listing in results:
        flattened.extend(listing)
    assert all(isinstance(item, praw.models.Submission) for item in flattened)
    if result_limit is not None:
        assert len(flattened) == result_limit
    return flattened
|
|
||||||
|
|
||||||
|
|
||||||
def test_determine_directories(tmp_path: Path, downloader_mock: MagicMock):
    """The requested download directory exists after ``_determine_directories`` runs."""
    target_directory = tmp_path / 'test'
    downloader_mock.args.directory = target_directory
    downloader_mock.config_directories.user_config_dir = tmp_path
    RedditDownloader._determine_directories(downloader_mock)
    assert Path(target_directory).exists()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(('skip_extensions', 'skip_domains'), (
    ([], []),
    (['.test'], ['test.com'],),
))
def test_create_download_filter(skip_extensions: list[str], skip_domains: list[str], downloader_mock: MagicMock):
    """The created filter carries the configured skipped extensions and domains."""
    downloader_mock.args.skip_domain = skip_domains
    downloader_mock.args.skip_format = skip_extensions
    download_filter = RedditDownloader._create_download_filter(downloader_mock)

    assert isinstance(download_filter, DownloadFilter)
    assert download_filter.excluded_domains == skip_domains
    assert download_filter.excluded_extensions == skip_extensions
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(('test_time', 'expected'), (
    ('all', 'all'),
    ('hour', 'hour'),
    ('day', 'day'),
    ('week', 'week'),
    ('random', 'all'),
    ('', 'all'),
))
def test_create_time_filter(test_time: str, expected: str, downloader_mock: MagicMock):
    """Known time names map to their member; unknown or empty names fall back to ``all``."""
    downloader_mock.args.time = test_time
    time_filter = RedditDownloader._create_time_filter(downloader_mock)

    assert isinstance(time_filter, RedditTypes.TimeType)
    assert time_filter.name.lower() == expected
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(('test_sort', 'expected'), (
    ('', 'hot'),
    ('hot', 'hot'),
    ('controversial', 'controversial'),
    ('new', 'new'),
))
def test_create_sort_filter(test_sort: str, expected: str, downloader_mock: MagicMock):
    """Known sort names map to their member; an empty name falls back to ``hot``."""
    downloader_mock.args.sort = test_sort
    sort_filter = RedditDownloader._create_sort_filter(downloader_mock)

    assert isinstance(sort_filter, RedditTypes.SortType)
    assert sort_filter.name.lower() == expected
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(('test_file_scheme', 'test_folder_scheme'), (
|
|
||||||
('{POSTID}', '{SUBREDDIT}'),
|
|
||||||
('{REDDITOR}_{TITLE}_{POSTID}', '{SUBREDDIT}'),
|
|
||||||
('{POSTID}', 'test'),
|
|
||||||
('{POSTID}', ''),
|
|
||||||
('{POSTID}', '{SUBREDDIT}/{REDDITOR}'),
|
|
||||||
))
|
|
||||||
def test_create_file_name_formatter(test_file_scheme: str, test_folder_scheme: str, downloader_mock: MagicMock):
|
|
||||||
downloader_mock.args.file_scheme = test_file_scheme
|
|
||||||
downloader_mock.args.folder_scheme = test_folder_scheme
|
|
||||||
result = RedditDownloader._create_file_name_formatter(downloader_mock)
|
|
||||||
|
|
||||||
assert isinstance(result, FileNameFormatter)
|
|
||||||
assert result.file_format_string == test_file_scheme
|
|
||||||
assert result.directory_format_string == test_folder_scheme.split('/')
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(('test_file_scheme', 'test_folder_scheme'), (
|
|
||||||
('', ''),
|
|
||||||
('', '{SUBREDDIT}'),
|
|
||||||
('test', '{SUBREDDIT}'),
|
|
||||||
))
|
|
||||||
def test_create_file_name_formatter_bad(test_file_scheme: str, test_folder_scheme: str, downloader_mock: MagicMock):
|
|
||||||
downloader_mock.args.file_scheme = test_file_scheme
|
|
||||||
downloader_mock.args.folder_scheme = test_folder_scheme
|
|
||||||
with pytest.raises(BulkDownloaderException):
|
|
||||||
RedditDownloader._create_file_name_formatter(downloader_mock)
|
|
||||||
|
|
||||||
|
|
||||||
def test_create_authenticator(downloader_mock: MagicMock):
|
|
||||||
result = RedditDownloader._create_authenticator(downloader_mock)
|
|
||||||
assert isinstance(result, SiteAuthenticator)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.online
|
|
||||||
@pytest.mark.reddit
|
|
||||||
@pytest.mark.parametrize('test_submission_ids', (
|
|
||||||
('lvpf4l',),
|
|
||||||
('lvpf4l', 'lvqnsn'),
|
|
||||||
('lvpf4l', 'lvqnsn', 'lvl9kd'),
|
|
||||||
))
|
|
||||||
def test_get_submissions_from_link(
|
|
||||||
test_submission_ids: list[str],
|
|
||||||
reddit_instance: praw.Reddit,
|
|
||||||
downloader_mock: MagicMock):
|
|
||||||
downloader_mock.args.link = test_submission_ids
|
|
||||||
downloader_mock.reddit_instance = reddit_instance
|
|
||||||
results = RedditDownloader._get_submissions_from_link(downloader_mock)
|
|
||||||
assert all([isinstance(sub, praw.models.Submission) for res in results for sub in res])
|
|
||||||
assert len(results[0]) == len(test_submission_ids)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.online
|
|
||||||
@pytest.mark.reddit
|
|
||||||
@pytest.mark.parametrize(('test_subreddits', 'limit', 'sort_type', 'time_filter', 'max_expected_len'), (
|
|
||||||
(('Futurology',), 10, 'hot', 'all', 10),
|
|
||||||
(('Futurology', 'Mindustry, Python'), 10, 'hot', 'all', 30),
|
|
||||||
(('Futurology',), 20, 'hot', 'all', 20),
|
|
||||||
(('Futurology', 'Python'), 10, 'hot', 'all', 20),
|
|
||||||
(('Futurology',), 100, 'hot', 'all', 100),
|
|
||||||
(('Futurology',), 0, 'hot', 'all', 0),
|
|
||||||
(('Futurology',), 10, 'top', 'all', 10),
|
|
||||||
(('Futurology',), 10, 'top', 'week', 10),
|
|
||||||
(('Futurology',), 10, 'hot', 'week', 10),
|
|
||||||
))
|
|
||||||
def test_get_subreddit_normal(
|
|
||||||
test_subreddits: list[str],
|
|
||||||
limit: int,
|
|
||||||
sort_type: str,
|
|
||||||
time_filter: str,
|
|
||||||
max_expected_len: int,
|
|
||||||
downloader_mock: MagicMock,
|
|
||||||
reddit_instance: praw.Reddit,
|
|
||||||
):
|
|
||||||
downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
|
|
||||||
downloader_mock.args.limit = limit
|
|
||||||
downloader_mock.args.sort = sort_type
|
|
||||||
downloader_mock.args.subreddit = test_subreddits
|
|
||||||
downloader_mock.reddit_instance = reddit_instance
|
|
||||||
downloader_mock.sort_filter = RedditDownloader._create_sort_filter(downloader_mock)
|
|
||||||
results = RedditDownloader._get_subreddits(downloader_mock)
|
|
||||||
test_subreddits = downloader_mock._split_args_input(test_subreddits)
|
|
||||||
results = [sub for res1 in results for sub in res1]
|
|
||||||
assert all([isinstance(res1, praw.models.Submission) for res1 in results])
|
|
||||||
assert all([res.subreddit.display_name in test_subreddits for res in results])
|
|
||||||
assert len(results) <= max_expected_len
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.online
|
|
||||||
@pytest.mark.reddit
|
|
||||||
@pytest.mark.parametrize(('test_subreddits', 'search_term', 'limit', 'time_filter', 'max_expected_len'), (
|
|
||||||
(('Python',), 'scraper', 10, 'all', 10),
|
|
||||||
(('Python',), '', 10, 'all', 10),
|
|
||||||
(('Python',), 'djsdsgewef', 10, 'all', 0),
|
|
||||||
(('Python',), 'scraper', 10, 'year', 10),
|
|
||||||
(('Python',), 'scraper', 10, 'hour', 1),
|
|
||||||
))
|
|
||||||
def test_get_subreddit_search(
|
|
||||||
test_subreddits: list[str],
|
|
||||||
search_term: str,
|
|
||||||
time_filter: str,
|
|
||||||
limit: int,
|
|
||||||
max_expected_len: int,
|
|
||||||
downloader_mock: MagicMock,
|
|
||||||
reddit_instance: praw.Reddit,
|
|
||||||
):
|
|
||||||
downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
|
|
||||||
downloader_mock.args.limit = limit
|
|
||||||
downloader_mock.args.search = search_term
|
|
||||||
downloader_mock.args.subreddit = test_subreddits
|
|
||||||
downloader_mock.reddit_instance = reddit_instance
|
|
||||||
downloader_mock.sort_filter = RedditTypes.SortType.HOT
|
|
||||||
downloader_mock.args.time = time_filter
|
|
||||||
downloader_mock.time_filter = RedditDownloader._create_time_filter(downloader_mock)
|
|
||||||
results = RedditDownloader._get_subreddits(downloader_mock)
|
|
||||||
results = [sub for res in results for sub in res]
|
|
||||||
assert all([isinstance(res, praw.models.Submission) for res in results])
|
|
||||||
assert all([res.subreddit.display_name in test_subreddits for res in results])
|
|
||||||
assert len(results) <= max_expected_len
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.online
|
|
||||||
@pytest.mark.reddit
|
|
||||||
@pytest.mark.parametrize(('test_user', 'test_multireddits', 'limit'), (
|
|
||||||
('helen_darten', ('cuteanimalpics',), 10),
|
|
||||||
('korfor', ('chess',), 100),
|
|
||||||
))
|
|
||||||
# Good sources at https://www.reddit.com/r/multihub/
|
|
||||||
def test_get_multireddits_public(
|
|
||||||
test_user: str,
|
|
||||||
test_multireddits: list[str],
|
|
||||||
limit: int,
|
|
||||||
reddit_instance: praw.Reddit,
|
|
||||||
downloader_mock: MagicMock,
|
|
||||||
):
|
|
||||||
downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
|
|
||||||
downloader_mock.sort_filter = RedditTypes.SortType.HOT
|
|
||||||
downloader_mock.args.limit = limit
|
|
||||||
downloader_mock.args.multireddit = test_multireddits
|
|
||||||
downloader_mock.args.user = test_user
|
|
||||||
downloader_mock.reddit_instance = reddit_instance
|
|
||||||
downloader_mock._create_filtered_listing_generator.return_value = \
|
|
||||||
RedditDownloader._create_filtered_listing_generator(
|
|
||||||
downloader_mock,
|
|
||||||
reddit_instance.multireddit(test_user, test_multireddits[0]),
|
|
||||||
)
|
|
||||||
results = RedditDownloader._get_multireddits(downloader_mock)
|
|
||||||
results = [sub for res in results for sub in res]
|
|
||||||
assert all([isinstance(res, praw.models.Submission) for res in results])
|
|
||||||
assert len(results) == limit
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.online
|
|
||||||
@pytest.mark.reddit
|
|
||||||
@pytest.mark.parametrize(('test_user', 'limit'), (
|
|
||||||
('danigirl3694', 10),
|
|
||||||
('danigirl3694', 50),
|
|
||||||
('CapitanHam', None),
|
|
||||||
))
|
|
||||||
def test_get_user_submissions(test_user: str, limit: int, downloader_mock: MagicMock, reddit_instance: praw.Reddit):
|
|
||||||
downloader_mock.args.limit = limit
|
|
||||||
downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
|
|
||||||
downloader_mock.sort_filter = RedditTypes.SortType.HOT
|
|
||||||
downloader_mock.args.submitted = True
|
|
||||||
downloader_mock.args.user = test_user
|
|
||||||
downloader_mock.authenticated = False
|
|
||||||
downloader_mock.reddit_instance = reddit_instance
|
|
||||||
downloader_mock._create_filtered_listing_generator.return_value = \
|
|
||||||
RedditDownloader._create_filtered_listing_generator(
|
|
||||||
downloader_mock,
|
|
||||||
reddit_instance.redditor(test_user).submissions,
|
|
||||||
)
|
|
||||||
results = RedditDownloader._get_user_data(downloader_mock)
|
|
||||||
results = assert_all_results_are_submissions(limit, results)
|
|
||||||
assert all([res.author.name == test_user for res in results])
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.online
|
|
||||||
@pytest.mark.reddit
|
|
||||||
@pytest.mark.authenticated
|
|
||||||
@pytest.mark.parametrize('test_flag', (
|
|
||||||
'upvoted',
|
|
||||||
'saved',
|
|
||||||
))
|
|
||||||
def test_get_user_authenticated_lists(
|
|
||||||
test_flag: str,
|
|
||||||
downloader_mock: MagicMock,
|
|
||||||
authenticated_reddit_instance: praw.Reddit,
|
|
||||||
):
|
|
||||||
downloader_mock.args.__dict__[test_flag] = True
|
|
||||||
downloader_mock.reddit_instance = authenticated_reddit_instance
|
|
||||||
downloader_mock.args.user = 'me'
|
|
||||||
downloader_mock.args.limit = 10
|
|
||||||
downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
|
|
||||||
downloader_mock.sort_filter = RedditTypes.SortType.HOT
|
|
||||||
RedditDownloader._resolve_user_name(downloader_mock)
|
|
||||||
results = RedditDownloader._get_user_data(downloader_mock)
|
|
||||||
assert_all_results_are_submissions(10, results)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.online
|
|
||||||
@pytest.mark.reddit
|
|
||||||
@pytest.mark.parametrize(('test_submission_id', 'expected_files_len'), (
|
|
||||||
('ljyy27', 4),
|
|
||||||
))
|
|
||||||
def test_download_submission(
|
|
||||||
test_submission_id: str,
|
|
||||||
expected_files_len: int,
|
|
||||||
downloader_mock: MagicMock,
|
|
||||||
reddit_instance: praw.Reddit,
|
|
||||||
tmp_path: Path):
|
|
||||||
downloader_mock.reddit_instance = reddit_instance
|
|
||||||
downloader_mock.download_filter.check_url.return_value = True
|
|
||||||
downloader_mock.args.folder_scheme = ''
|
|
||||||
downloader_mock.file_name_formatter = RedditDownloader._create_file_name_formatter(downloader_mock)
|
|
||||||
downloader_mock.download_directory = tmp_path
|
|
||||||
submission = downloader_mock.reddit_instance.submission(id=test_submission_id)
|
|
||||||
RedditDownloader._download_submission(downloader_mock, submission)
|
|
||||||
folder_contents = list(tmp_path.iterdir())
|
|
||||||
assert len(folder_contents) == expected_files_len
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.online
|
|
||||||
@pytest.mark.reddit
|
|
||||||
def test_download_submission_file_exists(
|
|
||||||
downloader_mock: MagicMock,
|
|
||||||
reddit_instance: praw.Reddit,
|
|
||||||
tmp_path: Path,
|
|
||||||
capsys: pytest.CaptureFixture
|
|
||||||
):
|
|
||||||
setup_logging(3)
|
|
||||||
downloader_mock.reddit_instance = reddit_instance
|
|
||||||
downloader_mock.download_filter.check_url.return_value = True
|
|
||||||
downloader_mock.args.folder_scheme = ''
|
|
||||||
downloader_mock.file_name_formatter = RedditDownloader._create_file_name_formatter(downloader_mock)
|
|
||||||
downloader_mock.download_directory = tmp_path
|
|
||||||
submission = downloader_mock.reddit_instance.submission(id='m1hqw6')
|
|
||||||
Path(tmp_path, 'Arneeman_Metagaming isn\'t always a bad thing_m1hqw6.png').touch()
|
|
||||||
RedditDownloader._download_submission(downloader_mock, submission)
|
|
||||||
folder_contents = list(tmp_path.iterdir())
|
|
||||||
output = capsys.readouterr()
|
|
||||||
assert len(folder_contents) == 1
|
|
||||||
assert 'Arneeman_Metagaming isn\'t always a bad thing_m1hqw6.png already exists' in output.out
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.online
|
|
||||||
@pytest.mark.reddit
|
|
||||||
@pytest.mark.parametrize(('test_submission_id', 'test_hash'), (
|
|
||||||
('m1hqw6', 'a912af8905ae468e0121e9940f797ad7'),
|
|
||||||
))
|
|
||||||
def test_download_submission_hash_exists(
|
|
||||||
test_submission_id: str,
|
|
||||||
test_hash: str,
|
|
||||||
downloader_mock: MagicMock,
|
|
||||||
reddit_instance: praw.Reddit,
|
|
||||||
tmp_path: Path,
|
|
||||||
capsys: pytest.CaptureFixture
|
|
||||||
):
|
|
||||||
setup_logging(3)
|
|
||||||
downloader_mock.reddit_instance = reddit_instance
|
|
||||||
downloader_mock.download_filter.check_url.return_value = True
|
|
||||||
downloader_mock.args.folder_scheme = ''
|
|
||||||
downloader_mock.args.no_dupes = True
|
|
||||||
downloader_mock.file_name_formatter = RedditDownloader._create_file_name_formatter(downloader_mock)
|
|
||||||
downloader_mock.download_directory = tmp_path
|
|
||||||
downloader_mock.master_hash_list = {test_hash: None}
|
|
||||||
submission = downloader_mock.reddit_instance.submission(id=test_submission_id)
|
|
||||||
RedditDownloader._download_submission(downloader_mock, submission)
|
|
||||||
folder_contents = list(tmp_path.iterdir())
|
|
||||||
output = capsys.readouterr()
|
|
||||||
assert len(folder_contents) == 0
|
|
||||||
assert re.search(r'Resource hash .*? downloaded elsewhere', output.out)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(('test_name', 'expected'), (
|
|
||||||
('Mindustry', 'Mindustry'),
|
|
||||||
('Futurology', 'Futurology'),
|
|
||||||
('r/Mindustry', 'Mindustry'),
|
|
||||||
('TrollXChromosomes', 'TrollXChromosomes'),
|
|
||||||
('r/TrollXChromosomes', 'TrollXChromosomes'),
|
|
||||||
('https://www.reddit.com/r/TrollXChromosomes/', 'TrollXChromosomes'),
|
|
||||||
('https://www.reddit.com/r/TrollXChromosomes', 'TrollXChromosomes'),
|
|
||||||
('https://www.reddit.com/r/Futurology/', 'Futurology'),
|
|
||||||
('https://www.reddit.com/r/Futurology', 'Futurology'),
|
|
||||||
))
|
|
||||||
def test_sanitise_subreddit_name(test_name: str, expected: str):
|
|
||||||
result = RedditDownloader._sanitise_subreddit_name(test_name)
|
|
||||||
assert result == expected
|
|
||||||
|
|
||||||
|
|
||||||
def test_search_existing_files():
|
|
||||||
results = RedditDownloader.scan_existing_files(Path('.'))
|
|
||||||
assert len(results.keys()) >= 40
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(('test_subreddit_entries', 'expected'), (
|
|
||||||
(['test1', 'test2', 'test3'], {'test1', 'test2', 'test3'}),
|
|
||||||
(['test1,test2', 'test3'], {'test1', 'test2', 'test3'}),
|
|
||||||
(['test1, test2', 'test3'], {'test1', 'test2', 'test3'}),
|
|
||||||
(['test1; test2', 'test3'], {'test1', 'test2', 'test3'}),
|
|
||||||
(['test1, test2', 'test1,test2,test3', 'test4'], {'test1', 'test2', 'test3', 'test4'})
|
|
||||||
))
|
|
||||||
def test_split_subreddit_entries(test_subreddit_entries: list[str], expected: set[str]):
|
|
||||||
results = RedditDownloader._split_args_input(test_subreddit_entries)
|
|
||||||
assert results == expected
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.online
|
|
||||||
@pytest.mark.reddit
|
|
||||||
@pytest.mark.parametrize('test_submission_id', (
|
|
||||||
'm1hqw6',
|
|
||||||
))
|
|
||||||
def test_mark_hard_link(
|
|
||||||
test_submission_id: str,
|
|
||||||
downloader_mock: MagicMock,
|
|
||||||
tmp_path: Path,
|
|
||||||
reddit_instance: praw.Reddit
|
|
||||||
):
|
|
||||||
downloader_mock.reddit_instance = reddit_instance
|
|
||||||
downloader_mock.args.make_hard_links = True
|
|
||||||
downloader_mock.download_directory = tmp_path
|
|
||||||
downloader_mock.args.folder_scheme = ''
|
|
||||||
downloader_mock.args.file_scheme = '{POSTID}'
|
|
||||||
downloader_mock.file_name_formatter = RedditDownloader._create_file_name_formatter(downloader_mock)
|
|
||||||
submission = downloader_mock.reddit_instance.submission(id=test_submission_id)
|
|
||||||
original = Path(tmp_path, f'{test_submission_id}.png')
|
|
||||||
|
|
||||||
RedditDownloader._download_submission(downloader_mock, submission)
|
|
||||||
assert original.exists()
|
|
||||||
|
|
||||||
downloader_mock.args.file_scheme = 'test2_{POSTID}'
|
|
||||||
downloader_mock.file_name_formatter = RedditDownloader._create_file_name_formatter(downloader_mock)
|
|
||||||
RedditDownloader._download_submission(downloader_mock, submission)
|
|
||||||
test_file_1_stats = original.stat()
|
|
||||||
test_file_2_inode = Path(tmp_path, f'test2_{test_submission_id}.png').stat().st_ino
|
|
||||||
|
|
||||||
assert test_file_1_stats.st_nlink == 2
|
|
||||||
assert test_file_1_stats.st_ino == test_file_2_inode
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(('test_ids', 'test_excluded', 'expected_len'), (
|
@pytest.mark.parametrize(('test_ids', 'test_excluded', 'expected_len'), (
|
||||||
(('aaaaaa',), (), 1),
|
(('aaaaaa',), (), 1),
|
||||||
(('aaaaaa',), ('aaaaaa',), 0),
|
(('aaaaaa',), ('aaaaaa',), 0),
|
||||||
|
@ -453,81 +49,113 @@ def test_excluded_ids(test_ids: tuple[str], test_excluded: tuple[str], expected_
|
||||||
assert downloader_mock._download_submission.call_count == expected_len
|
assert downloader_mock._download_submission.call_count == expected_len
|
||||||
|
|
||||||
|
|
||||||
def test_read_excluded_submission_ids_from_file(downloader_mock: MagicMock, tmp_path: Path):
|
|
||||||
test_file = tmp_path / 'test.txt'
|
|
||||||
test_file.write_text('aaaaaa\nbbbbbb')
|
|
||||||
downloader_mock.args.skip_id_file = [test_file]
|
|
||||||
results = RedditDownloader._read_excluded_ids(downloader_mock)
|
|
||||||
assert results == {'aaaaaa', 'bbbbbb'}
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.online
|
@pytest.mark.online
|
||||||
@pytest.mark.reddit
|
@pytest.mark.reddit
|
||||||
@pytest.mark.parametrize('test_redditor_name', (
|
@pytest.mark.parametrize('test_submission_id', (
|
||||||
'Paracortex',
|
'm1hqw6',
|
||||||
'crowdstrike',
|
|
||||||
'HannibalGoddamnit',
|
|
||||||
))
|
))
|
||||||
def test_check_user_existence_good(
|
def test_mark_hard_link(
|
||||||
test_redditor_name: str,
|
test_submission_id: str,
|
||||||
reddit_instance: praw.Reddit,
|
|
||||||
downloader_mock: MagicMock,
|
|
||||||
):
|
|
||||||
downloader_mock.reddit_instance = reddit_instance
|
|
||||||
RedditDownloader._check_user_existence(downloader_mock, test_redditor_name)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.online
|
|
||||||
@pytest.mark.reddit
|
|
||||||
@pytest.mark.parametrize('test_redditor_name', (
|
|
||||||
'lhnhfkuhwreolo',
|
|
||||||
'adlkfmnhglojh',
|
|
||||||
))
|
|
||||||
def test_check_user_existence_nonexistent(
|
|
||||||
test_redditor_name: str,
|
|
||||||
reddit_instance: praw.Reddit,
|
|
||||||
downloader_mock: MagicMock,
|
downloader_mock: MagicMock,
|
||||||
|
tmp_path: Path,
|
||||||
|
reddit_instance: praw.Reddit
|
||||||
):
|
):
|
||||||
downloader_mock.reddit_instance = reddit_instance
|
downloader_mock.reddit_instance = reddit_instance
|
||||||
with pytest.raises(BulkDownloaderException, match='Could not find'):
|
downloader_mock.args.make_hard_links = True
|
||||||
RedditDownloader._check_user_existence(downloader_mock, test_redditor_name)
|
downloader_mock.download_directory = tmp_path
|
||||||
|
downloader_mock.args.folder_scheme = ''
|
||||||
|
downloader_mock.args.file_scheme = '{POSTID}'
|
||||||
|
downloader_mock.file_name_formatter = RedditConnector.create_file_name_formatter(downloader_mock)
|
||||||
|
submission = downloader_mock.reddit_instance.submission(id=test_submission_id)
|
||||||
|
original = Path(tmp_path, f'{test_submission_id}.png')
|
||||||
|
|
||||||
|
RedditDownloader._download_submission(downloader_mock, submission)
|
||||||
|
assert original.exists()
|
||||||
|
|
||||||
|
downloader_mock.args.file_scheme = 'test2_{POSTID}'
|
||||||
|
downloader_mock.file_name_formatter = RedditConnector.create_file_name_formatter(downloader_mock)
|
||||||
|
RedditDownloader._download_submission(downloader_mock, submission)
|
||||||
|
test_file_1_stats = original.stat()
|
||||||
|
test_file_2_inode = Path(tmp_path, f'test2_{test_submission_id}.png').stat().st_ino
|
||||||
|
|
||||||
|
assert test_file_1_stats.st_nlink == 2
|
||||||
|
assert test_file_1_stats.st_ino == test_file_2_inode
|
||||||
|
|
||||||
|
|
||||||
|
def test_search_existing_files():
|
||||||
|
results = RedditDownloader.scan_existing_files(Path('.'))
|
||||||
|
assert len(results.keys()) >= 40
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.online
|
@pytest.mark.online
|
||||||
@pytest.mark.reddit
|
@pytest.mark.reddit
|
||||||
@pytest.mark.parametrize('test_redditor_name', (
|
@pytest.mark.parametrize(('test_submission_id', 'test_hash'), (
|
||||||
'Bree-Boo',
|
('m1hqw6', 'a912af8905ae468e0121e9940f797ad7'),
|
||||||
))
|
))
|
||||||
def test_check_user_existence_banned(
|
def test_download_submission_hash_exists(
|
||||||
test_redditor_name: str,
|
test_submission_id: str,
|
||||||
reddit_instance: praw.Reddit,
|
test_hash: str,
|
||||||
downloader_mock: MagicMock,
|
downloader_mock: MagicMock,
|
||||||
|
reddit_instance: praw.Reddit,
|
||||||
|
tmp_path: Path,
|
||||||
|
capsys: pytest.CaptureFixture
|
||||||
):
|
):
|
||||||
|
setup_logging(3)
|
||||||
downloader_mock.reddit_instance = reddit_instance
|
downloader_mock.reddit_instance = reddit_instance
|
||||||
with pytest.raises(BulkDownloaderException, match='is banned'):
|
downloader_mock.download_filter.check_url.return_value = True
|
||||||
RedditDownloader._check_user_existence(downloader_mock, test_redditor_name)
|
downloader_mock.args.folder_scheme = ''
|
||||||
|
downloader_mock.args.no_dupes = True
|
||||||
|
downloader_mock.file_name_formatter = RedditConnector.create_file_name_formatter(downloader_mock)
|
||||||
|
downloader_mock.download_directory = tmp_path
|
||||||
|
downloader_mock.master_hash_list = {test_hash: None}
|
||||||
|
submission = downloader_mock.reddit_instance.submission(id=test_submission_id)
|
||||||
|
RedditDownloader._download_submission(downloader_mock, submission)
|
||||||
|
folder_contents = list(tmp_path.iterdir())
|
||||||
|
output = capsys.readouterr()
|
||||||
|
assert len(folder_contents) == 0
|
||||||
|
assert re.search(r'Resource hash .*? downloaded elsewhere', output.out)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.online
|
@pytest.mark.online
|
||||||
@pytest.mark.reddit
|
@pytest.mark.reddit
|
||||||
@pytest.mark.parametrize(('test_subreddit_name', 'expected_message'), (
|
def test_download_submission_file_exists(
|
||||||
('donaldtrump', 'cannot be found'),
|
downloader_mock: MagicMock,
|
||||||
('submitters', 'private and cannot be scraped')
|
reddit_instance: praw.Reddit,
|
||||||
))
|
tmp_path: Path,
|
||||||
def test_check_subreddit_status_bad(test_subreddit_name: str, expected_message: str, reddit_instance: praw.Reddit):
|
capsys: pytest.CaptureFixture
|
||||||
test_subreddit = reddit_instance.subreddit(test_subreddit_name)
|
):
|
||||||
with pytest.raises(BulkDownloaderException, match=expected_message):
|
setup_logging(3)
|
||||||
RedditDownloader._check_subreddit_status(test_subreddit)
|
downloader_mock.reddit_instance = reddit_instance
|
||||||
|
downloader_mock.download_filter.check_url.return_value = True
|
||||||
|
downloader_mock.args.folder_scheme = ''
|
||||||
|
downloader_mock.file_name_formatter = RedditConnector.create_file_name_formatter(downloader_mock)
|
||||||
|
downloader_mock.download_directory = tmp_path
|
||||||
|
submission = downloader_mock.reddit_instance.submission(id='m1hqw6')
|
||||||
|
Path(tmp_path, 'Arneeman_Metagaming isn\'t always a bad thing_m1hqw6.png').touch()
|
||||||
|
RedditDownloader._download_submission(downloader_mock, submission)
|
||||||
|
folder_contents = list(tmp_path.iterdir())
|
||||||
|
output = capsys.readouterr()
|
||||||
|
assert len(folder_contents) == 1
|
||||||
|
assert 'Arneeman_Metagaming isn\'t always a bad thing_m1hqw6.png already exists' in output.out
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.online
|
@pytest.mark.online
|
||||||
@pytest.mark.reddit
|
@pytest.mark.reddit
|
||||||
@pytest.mark.parametrize('test_subreddit_name', (
|
@pytest.mark.parametrize(('test_submission_id', 'expected_files_len'), (
|
||||||
'Python',
|
('ljyy27', 4),
|
||||||
'Mindustry',
|
|
||||||
'TrollXChromosomes',
|
|
||||||
'all',
|
|
||||||
))
|
))
|
||||||
def test_check_subreddit_status_good(test_subreddit_name: str, reddit_instance: praw.Reddit):
|
def test_download_submission(
|
||||||
test_subreddit = reddit_instance.subreddit(test_subreddit_name)
|
test_submission_id: str,
|
||||||
RedditDownloader._check_subreddit_status(test_subreddit)
|
expected_files_len: int,
|
||||||
|
downloader_mock: MagicMock,
|
||||||
|
reddit_instance: praw.Reddit,
|
||||||
|
tmp_path: Path):
|
||||||
|
downloader_mock.reddit_instance = reddit_instance
|
||||||
|
downloader_mock.download_filter.check_url.return_value = True
|
||||||
|
downloader_mock.args.folder_scheme = ''
|
||||||
|
downloader_mock.file_name_formatter = RedditConnector.create_file_name_formatter(downloader_mock)
|
||||||
|
downloader_mock.download_directory = tmp_path
|
||||||
|
submission = downloader_mock.reddit_instance.submission(id=test_submission_id)
|
||||||
|
RedditDownloader._download_submission(downloader_mock, submission)
|
||||||
|
folder_contents = list(tmp_path.iterdir())
|
||||||
|
assert len(folder_contents) == expected_files_len
|
||||||
|
|
Loading…
Reference in a new issue