
Merge pull request #440 from aliparlakci/development

v2.2
Authored by Ali Parlakçı on 2021-06-06 14:25:51 +03:00, committed by GitHub
commit 349abbfb44
28 changed files with 1323 additions and 990 deletions

.gitignore

@@ -139,3 +139,6 @@ cython_debug/
# Test configuration file
test_config.cfg
.vscode/
.idea/

README.md

@@ -27,16 +27,24 @@ If you want to use the source code or make contributions, refer to [CONTRIBUTING
The BDFR works by taking submissions from a variety of "sources" from Reddit and then parsing them to download. These sources might be a subreddit, multireddit, a user list, or individual links. These sources are combined and downloaded to disk, according to a naming and organisational scheme defined by the user.
There are two modes to the BDFR: download, and archive. Each one has a command that performs similar but distinct functions. The `download` command will download the resource linked in the Reddit submission, such as the images, video, etc. The `archive` command will download the submission data itself and store it, such as the submission details, upvotes, text, and statistics, as well as all the comments on that submission. These can then be saved in a data markup language such as JSON, XML, or YAML.
There are three modes to the BDFR: download, archive, and clone. Each one has a command that performs similar but distinct functions. The `download` command will download the resource linked in the Reddit submission, such as the images, video, etc. The `archive` command will download the submission data itself and store it, such as the submission details, upvotes, text, and statistics, as well as all the comments on that submission. These can then be saved in a data markup language such as JSON, XML, or YAML. Lastly, the `clone` command will perform both functions of the previous commands at once and is more efficient than running those commands sequentially.
Note that the `clone` command is not a true, faithful clone of Reddit. It simply retrieves much of the raw data that Reddit provides. To get a true clone of Reddit, another tool such as HTTrack should be used.
After installation, run the program from any directory as shown below:
```bash
python3 -m bdfr download
```
```bash
python3 -m bdfr archive
```
```bash
python3 -m bdfr clone
```
However, these commands are not enough. You should chain parameters in [Options](#options) according to your use case. Don't forget that some parameters can be provided multiple times. Some quick reference commands are:
```bash
@@ -64,6 +72,10 @@ The following options are common between both the `archive` and `download` comma
- `--config`
- If the path to a configuration file is supplied with this option, the BDFR will use the specified config
- See [Configuration Files](#configuration) for more details
- `--disable-module`
- Can be specified multiple times
- Disables certain modules from being used
- See [Disabling Modules](#disabling-modules) for more information and a list of module names
- `--log`
- This allows one to specify the location of the logfile
- This must be done when running multiple instances of the BDFR, see [Multiple Instances](#multiple-instances) below
@@ -124,6 +136,8 @@ The following options are common between both the `archive` and `download` comma
- `-u, --user`
- This specifies the user to scrape in concert with other options
- When using `--authenticate`, `--user me` can be used to refer to the authenticated user
- Can be specified multiple times for multiple users
- If downloading a multireddit, only one user can be specified
- `-v, --verbose`
- Increases the verbosity of the program
- Can be specified multiple times
@@ -132,13 +146,6 @@ The following options are common between both the `archive` and `download` comma
The following options apply only to the `download` command. This command downloads the files and resources linked to in the submission, or a text submission itself, to the disk in the specified directory.
- `--exclude-id`
- This will skip the download of any submission with the ID provided
- Can be specified multiple times
- `--exclude-id-file`
- This will skip the download of any submission with any of the IDs in the files provided
- Can be specified multiple times
- Format is one ID per line
- `--make-hard-links`
- This flag will create hard links to an existing file when a duplicate is downloaded
- This will make the file appear in multiple directories while only taking the space of a single instance
@@ -159,6 +166,13 @@ The following options apply only to the `download` command. This command downloa
- Sets the scheme for folders
- Default is `{SUBREDDIT}`
- See [Folder and File Name Schemes](#folder-and-file-name-schemes) for more details
- `--exclude-id`
- This will skip the download of any submission with the ID provided
- Can be specified multiple times
- `--exclude-id-file`
- This will skip the download of any submission with any of the IDs in the files provided
- Can be specified multiple times
- Format is one ID per line
- `--skip-domain`
- This adds domains to the download filter, i.e. submissions coming from these domains will not be downloaded (an illustrative command follows this list)
- Can be specified multiple times
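As an illustrative sketch, several of these download options can be combined in one command; the directory, subreddit, and domain below are placeholders:
```bash
python3 -m bdfr download ./output --subreddit wallpapers --skip-domain example.com --make-hard-links
```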
@@ -183,6 +197,10 @@ The following options are for the `archive` command specifically.
- `xml`
- `yaml`
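For example, the following command (the directory and subreddit are illustrative) would store each submission's data as a YAML file:
```bash
python3 -m bdfr archive ./archives --subreddit Python -f yaml
```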
### Cloner Options
The `clone` command can take all the options listed above for both the `archive` and `download` commands since it performs the functions of both.
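As a sketch, a single clone run can therefore mix options from both commands; all values here are illustrative:
```bash
python3 -m bdfr clone ./reddit-backup --subreddit EarthPorn --limit 10 -f json --no-dupes
```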
## Authentication and Security
The BDFR uses OAuth2 authentication to connect to Reddit if authentication is required. This means that it is a secure, token-based system for making requests. This also means that the BDFR only has access to specific parts of the account authenticated, by default only saved posts, upvoted posts, and the identity of the authenticated account. Note that authentication is not required unless accessing private things like upvoted posts, saved posts, and private multireddits.
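For instance, downloading the authenticated account's own saved posts requires the `--authenticate` flag (the output directory here is illustrative):
```bash
python3 -m bdfr download ./saved-posts --authenticate --user me --saved
```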
@@ -253,6 +271,7 @@ The following keys are optional, and defaults will be used if they cannot be fou
- `backup_log_count`
- `max_wait_time`
- `time_format`
- `disabled_modules`
None of these should be modified unless you know what you're doing, as the default values will enable the BDFR to function just fine. A default configuration file is included with the BDFR when it is installed, and it will be placed in the configuration directory as the default.
@@ -264,6 +283,22 @@ The option `time_format` will specify the format of the timestamp that replaces
The format can be specified through the [format codes](https://docs.python.org/3/library/datetime.html#strftime-strptime-behavior) that are standard in the Python `datetime` library.
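As a hypothetical illustration, a `time_format` value such as `%Y-%m-%d_%H-%M` can be previewed with Python itself:
```bash
python3 -c "from datetime import datetime; print(datetime.now().strftime('%Y-%m-%d_%H-%M'))"
```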
#### Disabling Modules
The individual modules of the BDFR, used to download submissions from websites, can be disabled. This is especially helpful in the case of the fallback downloaders, since the `--skip-domain` option cannot be used effectively against them. For example, the Youtube-DL downloader can retrieve data from hundreds of websites and domains; the only way to disable it fully is via the `--disable-module` option.
Modules can be disabled through the command line interface for the BDFR or, more permanently, in the configuration file via the `disabled_modules` option. The downloaders that can be disabled are listed below; note that the names are case-insensitive, and an illustrative command follows the list.
- `Direct`
- `Erome`
- `Gallery` (Reddit Image Galleries)
- `Gfycat`
- `Imgur`
- `Redgifs`
- `SelfPost` (Reddit Text Post)
- `Youtube`
- `YoutubeDlFallback`
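For example, the following command (the directory and subreddit are illustrative) skips both the Youtube downloader and Reddit text posts:
```bash
python3 -m bdfr download ./output --subreddit pics --disable-module Youtube --disable-module SelfPost
```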
### Rate Limiting
The option `max_wait_time` governs how downloads are retried. Some HTTP errors mean that no number of requests will ever return the wanted data, but others are the result of rate-limiting: a single client making so many requests that the remote website cuts the client off to preserve the function of the site. This is a common situation when downloading many resources from the same site, and it is polite and best practice to obey the website's wishes in these cases.
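As a sketch, the retry ceiling can be raised for heavily rate-limited sites; the value of 300 seconds and the other arguments below are illustrative:
```bash
python3 -m bdfr download ./output --subreddit EarthPorn --max-wait-time 300
```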

bdfr/__main__.py

@@ -8,35 +8,58 @@ import click
from bdfr.archiver import Archiver
from bdfr.configuration import Configuration
from bdfr.downloader import RedditDownloader
from bdfr.cloner import RedditCloner
logger = logging.getLogger()
_common_options = [
click.argument('directory', type=str),
click.option('--config', type=str, default=None),
click.option('-v', '--verbose', default=None, count=True),
click.option('-l', '--link', multiple=True, default=None, type=str),
click.option('-s', '--subreddit', multiple=True, default=None, type=str),
click.option('-m', '--multireddit', multiple=True, default=None, type=str),
click.option('-L', '--limit', default=None, type=int),
click.option('--authenticate', is_flag=True, default=None),
click.option('--config', type=str, default=None),
click.option('--disable-module', multiple=True, default=None, type=str),
click.option('--log', type=str, default=None),
click.option('--submitted', is_flag=True, default=None),
click.option('--upvoted', is_flag=True, default=None),
click.option('--saved', is_flag=True, default=None),
click.option('--search', default=None, type=str),
click.option('--submitted', is_flag=True, default=None),
click.option('--time-format', type=str, default=None),
click.option('-u', '--user', type=str, default=None),
click.option('--upvoted', is_flag=True, default=None),
click.option('-L', '--limit', default=None, type=int),
click.option('-l', '--link', multiple=True, default=None, type=str),
click.option('-m', '--multireddit', multiple=True, default=None, type=str),
click.option('-s', '--subreddit', multiple=True, default=None, type=str),
click.option('-v', '--verbose', default=None, count=True),
click.option('-u', '--user', type=str, multiple=True, default=None),
click.option('-t', '--time', type=click.Choice(('all', 'hour', 'day', 'week', 'month', 'year')), default=None),
click.option('-S', '--sort', type=click.Choice(('hot', 'top', 'new',
'controversial', 'rising', 'relevance')), default=None),
]
_downloader_options = [
click.option('--file-scheme', default=None, type=str),
click.option('--folder-scheme', default=None, type=str),
click.option('--make-hard-links', is_flag=True, default=None),
click.option('--max-wait-time', type=int, default=None),
click.option('--no-dupes', is_flag=True, default=None),
click.option('--search-existing', is_flag=True, default=None),
click.option('--exclude-id', default=None, multiple=True),
click.option('--exclude-id-file', default=None, multiple=True),
click.option('--skip', default=None, multiple=True),
click.option('--skip-domain', default=None, multiple=True),
click.option('--skip-subreddit', default=None, multiple=True),
]
def _add_common_options(func):
for opt in _common_options:
func = opt(func)
return func
_archiver_options = [
click.option('--all-comments', is_flag=True, default=None),
click.option('-f', '--format', type=click.Choice(('xml', 'json', 'yaml')), default=None),
]
def _add_options(opts: list):
def wrap(func):
for opt in opts:
func = opt(func)
return func
return wrap
@click.group()
@@ -45,18 +68,8 @@ def cli():
@cli.command('download')
@click.option('--exclude-id', default=None, multiple=True)
@click.option('--exclude-id-file', default=None, multiple=True)
@click.option('--file-scheme', default=None, type=str)
@click.option('--folder-scheme', default=None, type=str)
@click.option('--make-hard-links', is_flag=True, default=None)
@click.option('--max-wait-time', type=int, default=None)
@click.option('--no-dupes', is_flag=True, default=None)
@click.option('--search-existing', is_flag=True, default=None)
@click.option('--skip', default=None, multiple=True)
@click.option('--skip-domain', default=None, multiple=True)
@click.option('--skip-subreddit', default=None, multiple=True)
@_add_common_options
@_add_options(_common_options)
@_add_options(_downloader_options)
@click.pass_context
def cli_download(context: click.Context, **_):
config = Configuration()
@@ -73,9 +86,8 @@ def cli_download(context: click.Context, **_):
@cli.command('archive')
@_add_common_options
@click.option('--all-comments', is_flag=True, default=None)
@click.option('-f', '--format', type=click.Choice(('xml', 'json', 'yaml')), default=None)
@_add_options(_common_options)
@_add_options(_archiver_options)
@click.pass_context
def cli_archive(context: click.Context, **_):
config = Configuration()
@@ -85,7 +97,26 @@ def cli_archive(context: click.Context, **_):
reddit_archiver = Archiver(config)
reddit_archiver.download()
except Exception:
logger.exception('Downloader exited unexpectedly')
logger.exception('Archiver exited unexpectedly')
raise
else:
logger.info('Program complete')
@cli.command('clone')
@_add_options(_common_options)
@_add_options(_archiver_options)
@_add_options(_downloader_options)
@click.pass_context
def cli_clone(context: click.Context, **_):
config = Configuration()
config.process_click_arguments(context)
setup_logging(config.verbose)
try:
reddit_scraper = RedditCloner(config)
reddit_scraper.download()
except Exception:
logger.exception('Scraper exited unexpectedly')
raise
else:
logger.info('Program complete')

bdfr/archive_entry/base_archive_entry.py

@@ -26,6 +26,7 @@ class BaseArchiveEntry(ABC):
'stickied': in_comment.stickied,
'body': in_comment.body,
'is_submitter': in_comment.is_submitter,
'distinguished': in_comment.distinguished,
'created_utc': in_comment.created_utc,
'parent_id': in_comment.parent_id,
'replies': [],

bdfr/archive_entry/submission_archive_entry.py

@@ -35,6 +35,10 @@ class SubmissionArchiveEntry(BaseArchiveEntry):
'link_flair_text': self.source.link_flair_text,
'num_comments': self.source.num_comments,
'over_18': self.source.over_18,
'spoiler': self.source.spoiler,
'pinned': self.source.pinned,
'locked': self.source.locked,
'distinguished': self.source.distinguished,
'created_utc': self.source.created_utc,
}

bdfr/archiver.py

@@ -14,14 +14,14 @@ from bdfr.archive_entry.base_archive_entry import BaseArchiveEntry
from bdfr.archive_entry.comment_archive_entry import CommentArchiveEntry
from bdfr.archive_entry.submission_archive_entry import SubmissionArchiveEntry
from bdfr.configuration import Configuration
from bdfr.downloader import RedditDownloader
from bdfr.connector import RedditConnector
from bdfr.exceptions import ArchiverError
from bdfr.resource import Resource
logger = logging.getLogger(__name__)
class Archiver(RedditDownloader):
class Archiver(RedditConnector):
def __init__(self, args: Configuration):
super(Archiver, self).__init__(args)
@@ -29,9 +29,9 @@ class Archiver(RedditDownloader):
for generator in self.reddit_lists:
for submission in generator:
logger.debug(f'Attempting to archive submission {submission.id}')
self._write_entry(submission)
self.write_entry(submission)
def _get_submissions_from_link(self) -> list[list[praw.models.Submission]]:
def get_submissions_from_link(self) -> list[list[praw.models.Submission]]:
supplied_submissions = []
for sub_id in self.args.link:
if len(sub_id) == 6:
@@ -42,12 +42,13 @@ class Archiver(RedditDownloader):
supplied_submissions.append(self.reddit_instance.submission(url=sub_id))
return [supplied_submissions]
def _get_user_data(self) -> list[Iterator]:
results = super(Archiver, self)._get_user_data()
def get_user_data(self) -> list[Iterator]:
results = super(Archiver, self).get_user_data()
if self.args.user and self.args.all_comments:
sort = self._determine_sort_function()
logger.debug(f'Retrieving comments of user {self.args.user}')
results.append(sort(self.reddit_instance.redditor(self.args.user).comments, limit=self.args.limit))
sort = self.determine_sort_function()
for user in self.args.user:
logger.debug(f'Retrieving comments of user {user}')
results.append(sort(self.reddit_instance.redditor(user).comments, limit=self.args.limit))
return results
@staticmethod
@@ -59,7 +60,7 @@ class Archiver(RedditDownloader):
else:
raise ArchiverError(f'Factory failed to classify item of type {type(praw_item).__name__}')
def _write_entry(self, praw_item: (praw.models.Submission, praw.models.Comment)):
def write_entry(self, praw_item: (praw.models.Submission, praw.models.Comment)):
archive_entry = self._pull_lever_entry_factory(praw_item)
if self.args.format == 'json':
self._write_entry_json(archive_entry)

bdfr/cloner.py (new file, 21 lines)

@@ -0,0 +1,21 @@
#!/usr/bin/env python3
# coding=utf-8
import logging
from bdfr.archiver import Archiver
from bdfr.configuration import Configuration
from bdfr.downloader import RedditDownloader
logger = logging.getLogger(__name__)
class RedditCloner(RedditDownloader, Archiver):
def __init__(self, args: Configuration):
super(RedditCloner, self).__init__(args)
def download(self):
for generator in self.reddit_lists:
for submission in generator:
self._download_submission(submission)
self.write_entry(submission)

bdfr/configuration.py

@@ -13,19 +13,21 @@ class Configuration(Namespace):
self.authenticate = False
self.config = None
self.directory: str = '.'
self.disable_module: list[str] = []
self.exclude_id = []
self.exclude_id_file = []
self.file_scheme: str = '{REDDITOR}_{TITLE}_{POSTID}'
self.folder_scheme: str = '{SUBREDDIT}'
self.limit: Optional[int] = None
self.link: list[str] = []
self.log: Optional[str] = None
self.make_hard_links = False
self.max_wait_time = None
self.multireddit: list[str] = []
self.no_dupes: bool = False
self.saved: bool = False
self.search: Optional[str] = None
self.search_existing: bool = False
self.file_scheme: str = '{REDDITOR}_{TITLE}_{POSTID}'
self.folder_scheme: str = '{SUBREDDIT}'
self.skip: list[str] = []
self.skip_domain: list[str] = []
self.skip_subreddit: list[str] = []
@@ -35,9 +37,8 @@ class Configuration(Namespace):
self.time: str = 'all'
self.time_format = None
self.upvoted: bool = False
self.user: Optional[str] = None
self.user: list[str] = []
self.verbose: int = 0
self.make_hard_links = False
# Archiver-specific options
self.format = 'json'

bdfr/connector.py (new file, 417 lines)

@@ -0,0 +1,417 @@
#!/usr/bin/env python3
# coding=utf-8
import configparser
import importlib.resources
import logging
import logging.handlers
import re
import shutil
import socket
from abc import ABCMeta, abstractmethod
from datetime import datetime
from enum import Enum, auto
from pathlib import Path
from typing import Callable, Iterator
import appdirs
import praw
import praw.exceptions
import praw.models
import prawcore
from bdfr import exceptions as errors
from bdfr.configuration import Configuration
from bdfr.download_filter import DownloadFilter
from bdfr.file_name_formatter import FileNameFormatter
from bdfr.oauth2 import OAuth2Authenticator, OAuth2TokenManager
from bdfr.site_authenticator import SiteAuthenticator
logger = logging.getLogger(__name__)
class RedditTypes:
class SortType(Enum):
CONTROVERSIAL = auto()
HOT = auto()
NEW = auto()
RELEVANCE = auto()
RISING = auto()
TOP = auto()
class TimeType(Enum):
ALL = 'all'
DAY = 'day'
HOUR = 'hour'
MONTH = 'month'
WEEK = 'week'
YEAR = 'year'
class RedditConnector(metaclass=ABCMeta):
def __init__(self, args: Configuration):
self.args = args
self.config_directories = appdirs.AppDirs('bdfr', 'BDFR')
self.run_time = datetime.now().isoformat()
self._setup_internal_objects()
self.reddit_lists = self.retrieve_reddit_lists()
def _setup_internal_objects(self):
self.determine_directories()
self.load_config()
self.create_file_logger()
self.read_config()
self.parse_disabled_modules()
self.download_filter = self.create_download_filter()
logger.log(9, 'Created download filter')
self.time_filter = self.create_time_filter()
logger.log(9, 'Created time filter')
self.sort_filter = self.create_sort_filter()
logger.log(9, 'Created sort filter')
self.file_name_formatter = self.create_file_name_formatter()
logger.log(9, 'Created file name formatter')
self.create_reddit_instance()
self.args.user = list(filter(None, [self.resolve_user_name(user) for user in self.args.user]))
self.excluded_submission_ids = self.read_excluded_ids()
self.master_hash_list = {}
self.authenticator = self.create_authenticator()
logger.log(9, 'Created site authenticator')
self.args.skip_subreddit = self.split_args_input(self.args.skip_subreddit)
self.args.skip_subreddit = set([sub.lower() for sub in self.args.skip_subreddit])
def read_config(self):
"""Read any cfg values that need to be processed"""
if self.args.max_wait_time is None:
if not self.cfg_parser.has_option('DEFAULT', 'max_wait_time'):
self.cfg_parser.set('DEFAULT', 'max_wait_time', '120')
logger.log(9, 'Wrote default download wait time to config file')
self.args.max_wait_time = self.cfg_parser.getint('DEFAULT', 'max_wait_time')
logger.debug(f'Setting maximum download wait time to {self.args.max_wait_time} seconds')
if self.args.time_format is None:
option = self.cfg_parser.get('DEFAULT', 'time_format', fallback='ISO')
if re.match(r'^[ \'\"]*$', option):
option = 'ISO'
logger.debug(f'Setting datetime format string to {option}')
self.args.time_format = option
if not self.args.disable_module:
self.args.disable_module = [self.cfg_parser.get('DEFAULT', 'disabled_modules', fallback='')]
# Update config on disk
with open(self.config_location, 'w') as file:
self.cfg_parser.write(file)
def parse_disabled_modules(self):
disabled_modules = self.args.disable_module
disabled_modules = self.split_args_input(disabled_modules)
disabled_modules = set([name.strip().lower() for name in disabled_modules])
self.args.disable_module = disabled_modules
logger.debug(f'Disabling the following modules: {", ".join(self.args.disable_module)}')
def create_reddit_instance(self):
if self.args.authenticate:
logger.debug('Using authenticated Reddit instance')
if not self.cfg_parser.has_option('DEFAULT', 'user_token'):
logger.log(9, 'Commencing OAuth2 authentication')
scopes = self.cfg_parser.get('DEFAULT', 'scopes')
scopes = OAuth2Authenticator.split_scopes(scopes)
oauth2_authenticator = OAuth2Authenticator(
scopes,
self.cfg_parser.get('DEFAULT', 'client_id'),
self.cfg_parser.get('DEFAULT', 'client_secret'),
)
token = oauth2_authenticator.retrieve_new_token()
self.cfg_parser['DEFAULT']['user_token'] = token
with open(self.config_location, 'w') as file:
self.cfg_parser.write(file, True)
token_manager = OAuth2TokenManager(self.cfg_parser, self.config_location)
self.authenticated = True
self.reddit_instance = praw.Reddit(
client_id=self.cfg_parser.get('DEFAULT', 'client_id'),
client_secret=self.cfg_parser.get('DEFAULT', 'client_secret'),
user_agent=socket.gethostname(),
token_manager=token_manager,
)
else:
logger.debug('Using unauthenticated Reddit instance')
self.authenticated = False
self.reddit_instance = praw.Reddit(
client_id=self.cfg_parser.get('DEFAULT', 'client_id'),
client_secret=self.cfg_parser.get('DEFAULT', 'client_secret'),
user_agent=socket.gethostname(),
)
def retrieve_reddit_lists(self) -> list[praw.models.ListingGenerator]:
master_list = []
master_list.extend(self.get_subreddits())
logger.log(9, 'Retrieved subreddits')
master_list.extend(self.get_multireddits())
logger.log(9, 'Retrieved multireddits')
master_list.extend(self.get_user_data())
logger.log(9, 'Retrieved user data')
master_list.extend(self.get_submissions_from_link())
logger.log(9, 'Retrieved submissions for given links')
return master_list
def determine_directories(self):
self.download_directory = Path(self.args.directory).resolve().expanduser()
self.config_directory = Path(self.config_directories.user_config_dir)
self.download_directory.mkdir(exist_ok=True, parents=True)
self.config_directory.mkdir(exist_ok=True, parents=True)
def load_config(self):
self.cfg_parser = configparser.ConfigParser()
if self.args.config:
if (cfg_path := Path(self.args.config)).exists():
self.cfg_parser.read(cfg_path)
self.config_location = cfg_path
return
possible_paths = [
Path('./config.cfg'),
Path('./default_config.cfg'),
Path(self.config_directory, 'config.cfg'),
Path(self.config_directory, 'default_config.cfg'),
]
self.config_location = None
for path in possible_paths:
if path.resolve().expanduser().exists():
self.config_location = path
logger.debug(f'Loading configuration from {path}')
break
if not self.config_location:
self.config_location = list(importlib.resources.path('bdfr', 'default_config.cfg').gen)[0]
shutil.copy(self.config_location, Path(self.config_directory, 'default_config.cfg'))
if not self.config_location:
raise errors.BulkDownloaderException('Could not find a configuration file to load')
self.cfg_parser.read(self.config_location)
def create_file_logger(self):
main_logger = logging.getLogger()
if self.args.log is None:
log_path = Path(self.config_directory, 'log_output.txt')
else:
log_path = Path(self.args.log).resolve().expanduser()
if not log_path.parent.exists():
raise errors.BulkDownloaderException(f'Designated location for logfile does not exist')
backup_count = self.cfg_parser.getint('DEFAULT', 'backup_log_count', fallback=3)
file_handler = logging.handlers.RotatingFileHandler(
log_path,
mode='a',
backupCount=backup_count,
)
if log_path.exists():
try:
file_handler.doRollover()
except PermissionError as e:
logger.critical(
'Cannot rollover logfile, make sure this is the only '
'BDFR process or specify alternate logfile location')
raise
formatter = logging.Formatter('[%(asctime)s - %(name)s - %(levelname)s] - %(message)s')
file_handler.setFormatter(formatter)
file_handler.setLevel(0)
main_logger.addHandler(file_handler)
@staticmethod
def sanitise_subreddit_name(subreddit: str) -> str:
pattern = re.compile(r'^(?:https://www\.reddit\.com/)?(?:r/)?(.*?)/?$')
match = re.match(pattern, subreddit)
if not match:
raise errors.BulkDownloaderException(f'Could not find subreddit name in string {subreddit}')
return match.group(1)
@staticmethod
def split_args_input(entries: list[str]) -> set[str]:
all_entries = []
split_pattern = re.compile(r'[,;]\s?')
for entry in entries:
results = re.split(split_pattern, entry)
all_entries.extend([RedditConnector.sanitise_subreddit_name(name) for name in results])
return set(all_entries)
def get_subreddits(self) -> list[praw.models.ListingGenerator]:
if self.args.subreddit:
out = []
for reddit in self.split_args_input(self.args.subreddit):
try:
reddit = self.reddit_instance.subreddit(reddit)
try:
self.check_subreddit_status(reddit)
except errors.BulkDownloaderException as e:
logger.error(e)
continue
if self.args.search:
out.append(reddit.search(
self.args.search,
sort=self.sort_filter.name.lower(),
limit=self.args.limit,
time_filter=self.time_filter.value,
))
logger.debug(
f'Added submissions from subreddit {reddit} with the search term "{self.args.search}"')
else:
out.append(self.create_filtered_listing_generator(reddit))
logger.debug(f'Added submissions from subreddit {reddit}')
except (errors.BulkDownloaderException, praw.exceptions.PRAWException) as e:
logger.error(f'Failed to get submissions for subreddit {reddit}: {e}')
return out
else:
return []
def resolve_user_name(self, in_name: str) -> str:
if in_name == 'me':
if self.authenticated:
resolved_name = self.reddit_instance.user.me().name
logger.log(9, f'Resolved user to {resolved_name}')
return resolved_name
else:
logger.warning('To use "me" as a user, an authenticated Reddit instance must be used')
else:
return in_name
def get_submissions_from_link(self) -> list[list[praw.models.Submission]]:
supplied_submissions = []
for sub_id in self.args.link:
if len(sub_id) == 6:
supplied_submissions.append(self.reddit_instance.submission(id=sub_id))
else:
supplied_submissions.append(self.reddit_instance.submission(url=sub_id))
return [supplied_submissions]
def determine_sort_function(self) -> Callable:
if self.sort_filter is RedditTypes.SortType.NEW:
sort_function = praw.models.Subreddit.new
elif self.sort_filter is RedditTypes.SortType.RISING:
sort_function = praw.models.Subreddit.rising
elif self.sort_filter is RedditTypes.SortType.CONTROVERSIAL:
sort_function = praw.models.Subreddit.controversial
elif self.sort_filter is RedditTypes.SortType.TOP:
sort_function = praw.models.Subreddit.top
else:
sort_function = praw.models.Subreddit.hot
return sort_function
def get_multireddits(self) -> list[Iterator]:
if self.args.multireddit:
if len(self.args.user) != 1:
logger.error(f'Only 1 user can be supplied when retrieving from multireddits')
return []
out = []
for multi in self.split_args_input(self.args.multireddit):
try:
multi = self.reddit_instance.multireddit(self.args.user[0], multi)
if not multi.subreddits:
raise errors.BulkDownloaderException
out.append(self.create_filtered_listing_generator(multi))
logger.debug(f'Added submissions from multireddit {multi}')
except (errors.BulkDownloaderException, praw.exceptions.PRAWException, prawcore.PrawcoreException) as e:
logger.error(f'Failed to get submissions for multireddit {multi}: {e}')
return out
else:
return []
def create_filtered_listing_generator(self, reddit_source) -> Iterator:
sort_function = self.determine_sort_function()
if self.sort_filter in (RedditTypes.SortType.TOP, RedditTypes.SortType.CONTROVERSIAL):
return sort_function(reddit_source, limit=self.args.limit, time_filter=self.time_filter.value)
else:
return sort_function(reddit_source, limit=self.args.limit)
def get_user_data(self) -> list[Iterator]:
if any([self.args.submitted, self.args.upvoted, self.args.saved]):
if not self.args.user:
logger.warning('At least one user must be supplied to download user data')
return []
generators = []
for user in self.args.user:
try:
self.check_user_existence(user)
except errors.BulkDownloaderException as e:
logger.error(e)
continue
if self.args.submitted:
logger.debug(f'Retrieving submitted posts of user {user}')
generators.append(self.create_filtered_listing_generator(
self.reddit_instance.redditor(user).submissions,
))
if not self.authenticated and any((self.args.upvoted, self.args.saved)):
logger.warning('Accessing user lists requires authentication')
else:
if self.args.upvoted:
logger.debug(f'Retrieving upvoted posts of user {user}')
generators.append(self.reddit_instance.redditor(user).upvoted(limit=self.args.limit))
if self.args.saved:
logger.debug(f'Retrieving saved posts of user {user}')
generators.append(self.reddit_instance.redditor(user).saved(limit=self.args.limit))
return generators
else:
return []
def check_user_existence(self, name: str):
user = self.reddit_instance.redditor(name=name)
try:
if user.id:
return
except prawcore.exceptions.NotFound:
raise errors.BulkDownloaderException(f'Could not find user {name}')
except AttributeError:
if hasattr(user, 'is_suspended'):
raise errors.BulkDownloaderException(f'User {name} is banned')
def create_file_name_formatter(self) -> FileNameFormatter:
return FileNameFormatter(self.args.file_scheme, self.args.folder_scheme, self.args.time_format)
def create_time_filter(self) -> RedditTypes.TimeType:
try:
return RedditTypes.TimeType[self.args.time.upper()]
except (KeyError, AttributeError):
return RedditTypes.TimeType.ALL
def create_sort_filter(self) -> RedditTypes.SortType:
try:
return RedditTypes.SortType[self.args.sort.upper()]
except (KeyError, AttributeError):
return RedditTypes.SortType.HOT
def create_download_filter(self) -> DownloadFilter:
return DownloadFilter(self.args.skip, self.args.skip_domain)
def create_authenticator(self) -> SiteAuthenticator:
return SiteAuthenticator(self.cfg_parser)
@abstractmethod
def download(self):
pass
@staticmethod
def check_subreddit_status(subreddit: praw.models.Subreddit):
if subreddit.display_name == 'all':
return
try:
assert subreddit.id
except prawcore.NotFound:
raise errors.BulkDownloaderException(f'Source {subreddit.display_name} does not exist or cannot be found')
except prawcore.Forbidden:
raise errors.BulkDownloaderException(f'Source {subreddit.display_name} is private and cannot be scraped')
def read_excluded_ids(self) -> set[str]:
out = []
out.extend(self.args.exclude_id)
for id_file in self.args.exclude_id_file:
id_file = Path(id_file).resolve().expanduser()
if not id_file.exists():
logger.warning(f'ID exclusion file at {id_file} does not exist')
continue
with open(id_file, 'r') as file:
for line in file:
out.append(line.strip())
return set(out)

bdfr/downloader.py

@@ -1,405 +1,61 @@
#!/usr/bin/env python3
# coding=utf-8
import configparser
import hashlib
import importlib.resources
import logging
import logging.handlers
import os
import re
import shutil
import socket
import time
from datetime import datetime
from enum import Enum, auto
from multiprocessing import Pool
from pathlib import Path
from typing import Callable, Iterator
import appdirs
import praw
import praw.exceptions
import praw.models
import prawcore
import bdfr.exceptions as errors
from bdfr import exceptions as errors
from bdfr.configuration import Configuration
from bdfr.download_filter import DownloadFilter
from bdfr.file_name_formatter import FileNameFormatter
from bdfr.oauth2 import OAuth2Authenticator, OAuth2TokenManager
from bdfr.site_authenticator import SiteAuthenticator
from bdfr.connector import RedditConnector
from bdfr.site_downloaders.download_factory import DownloadFactory
logger = logging.getLogger(__name__)
def _calc_hash(existing_file: Path):
chunk_size = 1024 * 1024
md5_hash = hashlib.md5()
with open(existing_file, 'rb') as file:
file_hash = hashlib.md5(file.read()).hexdigest()
return existing_file, file_hash
chunk = file.read(chunk_size)
while chunk:
md5_hash.update(chunk)
chunk = file.read(chunk_size)
file_hash = md5_hash.hexdigest()
return existing_file, file_hash
class RedditTypes:
class SortType(Enum):
CONTROVERSIAL = auto()
HOT = auto()
NEW = auto()
RELEVANCE = auto()
RISING = auto()
TOP = auto()
class TimeType(Enum):
ALL = 'all'
DAY = 'day'
HOUR = 'hour'
MONTH = 'month'
WEEK = 'week'
YEAR = 'year'
class RedditDownloader:
class RedditDownloader(RedditConnector):
def __init__(self, args: Configuration):
self.args = args
self.config_directories = appdirs.AppDirs('bdfr', 'BDFR')
self.run_time = datetime.now().isoformat()
self._setup_internal_objects()
self.reddit_lists = self._retrieve_reddit_lists()
def _setup_internal_objects(self):
self._determine_directories()
self._load_config()
self._create_file_logger()
self._read_config()
self.download_filter = self._create_download_filter()
logger.log(9, 'Created download filter')
self.time_filter = self._create_time_filter()
logger.log(9, 'Created time filter')
self.sort_filter = self._create_sort_filter()
logger.log(9, 'Created sort filter')
self.file_name_formatter = self._create_file_name_formatter()
logger.log(9, 'Created file name formatter')
self._create_reddit_instance()
self._resolve_user_name()
self.excluded_submission_ids = self._read_excluded_ids()
super(RedditDownloader, self).__init__(args)
if self.args.search_existing:
self.master_hash_list = self.scan_existing_files(self.download_directory)
else:
self.master_hash_list = {}
self.authenticator = self._create_authenticator()
logger.log(9, 'Created site authenticator')
self.args.skip_subreddit = self._split_args_input(self.args.skip_subreddit)
self.args.skip_subreddit = set([sub.lower() for sub in self.args.skip_subreddit])
def _read_config(self):
"""Read any cfg values that need to be processed"""
if self.args.max_wait_time is None:
if not self.cfg_parser.has_option('DEFAULT', 'max_wait_time'):
self.cfg_parser.set('DEFAULT', 'max_wait_time', '120')
logger.log(9, 'Wrote default download wait time to config file')
self.args.max_wait_time = self.cfg_parser.getint('DEFAULT', 'max_wait_time')
logger.debug(f'Setting maximum download wait time to {self.args.max_wait_time} seconds')
if self.args.time_format is None:
option = self.cfg_parser.get('DEFAULT', 'time_format', fallback='ISO')
if re.match(r'^[ \'\"]*$', option):
option = 'ISO'
logger.debug(f'Setting datetime format string to {option}')
self.args.time_format = option
# Update config on disk
with open(self.config_location, 'w') as file:
self.cfg_parser.write(file)
def _create_reddit_instance(self):
if self.args.authenticate:
logger.debug('Using authenticated Reddit instance')
if not self.cfg_parser.has_option('DEFAULT', 'user_token'):
logger.log(9, 'Commencing OAuth2 authentication')
scopes = self.cfg_parser.get('DEFAULT', 'scopes')
scopes = OAuth2Authenticator.split_scopes(scopes)
oauth2_authenticator = OAuth2Authenticator(
scopes,
self.cfg_parser.get('DEFAULT', 'client_id'),
self.cfg_parser.get('DEFAULT', 'client_secret'),
)
token = oauth2_authenticator.retrieve_new_token()
self.cfg_parser['DEFAULT']['user_token'] = token
with open(self.config_location, 'w') as file:
self.cfg_parser.write(file, True)
token_manager = OAuth2TokenManager(self.cfg_parser, self.config_location)
self.authenticated = True
self.reddit_instance = praw.Reddit(
client_id=self.cfg_parser.get('DEFAULT', 'client_id'),
client_secret=self.cfg_parser.get('DEFAULT', 'client_secret'),
user_agent=socket.gethostname(),
token_manager=token_manager,
)
else:
logger.debug('Using unauthenticated Reddit instance')
self.authenticated = False
self.reddit_instance = praw.Reddit(
client_id=self.cfg_parser.get('DEFAULT', 'client_id'),
client_secret=self.cfg_parser.get('DEFAULT', 'client_secret'),
user_agent=socket.gethostname(),
)
def _retrieve_reddit_lists(self) -> list[praw.models.ListingGenerator]:
master_list = []
master_list.extend(self._get_subreddits())
logger.log(9, 'Retrieved subreddits')
master_list.extend(self._get_multireddits())
logger.log(9, 'Retrieved multireddits')
master_list.extend(self._get_user_data())
logger.log(9, 'Retrieved user data')
master_list.extend(self._get_submissions_from_link())
logger.log(9, 'Retrieved submissions for given links')
return master_list
def _determine_directories(self):
self.download_directory = Path(self.args.directory).resolve().expanduser()
self.config_directory = Path(self.config_directories.user_config_dir)
self.download_directory.mkdir(exist_ok=True, parents=True)
self.config_directory.mkdir(exist_ok=True, parents=True)
def _load_config(self):
self.cfg_parser = configparser.ConfigParser()
if self.args.config:
if (cfg_path := Path(self.args.config)).exists():
self.cfg_parser.read(cfg_path)
self.config_location = cfg_path
return
possible_paths = [
Path('./config.cfg'),
Path('./default_config.cfg'),
Path(self.config_directory, 'config.cfg'),
Path(self.config_directory, 'default_config.cfg'),
]
self.config_location = None
for path in possible_paths:
if path.resolve().expanduser().exists():
self.config_location = path
logger.debug(f'Loading configuration from {path}')
break
if not self.config_location:
self.config_location = list(importlib.resources.path('bdfr', 'default_config.cfg').gen)[0]
shutil.copy(self.config_location, Path(self.config_directory, 'default_config.cfg'))
if not self.config_location:
raise errors.BulkDownloaderException('Could not find a configuration file to load')
self.cfg_parser.read(self.config_location)
def _create_file_logger(self):
main_logger = logging.getLogger()
if self.args.log is None:
log_path = Path(self.config_directory, 'log_output.txt')
else:
log_path = Path(self.args.log).resolve().expanduser()
if not log_path.parent.exists():
raise errors.BulkDownloaderException(f'Designated location for logfile does not exist')
backup_count = self.cfg_parser.getint('DEFAULT', 'backup_log_count', fallback=3)
file_handler = logging.handlers.RotatingFileHandler(
log_path,
mode='a',
backupCount=backup_count,
)
if log_path.exists():
try:
file_handler.doRollover()
except PermissionError as e:
logger.critical(
'Cannot rollover logfile, make sure this is the only '
'BDFR process or specify alternate logfile location')
raise
formatter = logging.Formatter('[%(asctime)s - %(name)s - %(levelname)s] - %(message)s')
file_handler.setFormatter(formatter)
file_handler.setLevel(0)
main_logger.addHandler(file_handler)
@staticmethod
def _sanitise_subreddit_name(subreddit: str) -> str:
pattern = re.compile(r'^(?:https://www\.reddit\.com/)?(?:r/)?(.*?)/?$')
match = re.match(pattern, subreddit)
if not match:
raise errors.BulkDownloaderException(f'Could not find subreddit name in string {subreddit}')
return match.group(1)
@staticmethod
def _split_args_input(entries: list[str]) -> set[str]:
all_entries = []
split_pattern = re.compile(r'[,;]\s?')
for entry in entries:
results = re.split(split_pattern, entry)
all_entries.extend([RedditDownloader._sanitise_subreddit_name(name) for name in results])
return set(all_entries)
def _get_subreddits(self) -> list[praw.models.ListingGenerator]:
if self.args.subreddit:
out = []
for reddit in self._split_args_input(self.args.subreddit):
try:
reddit = self.reddit_instance.subreddit(reddit)
try:
self._check_subreddit_status(reddit)
except errors.BulkDownloaderException as e:
logger.error(e)
continue
if self.args.search:
out.append(reddit.search(
self.args.search,
sort=self.sort_filter.name.lower(),
limit=self.args.limit,
time_filter=self.time_filter.value,
))
logger.debug(
f'Added submissions from subreddit {reddit} with the search term "{self.args.search}"')
else:
out.append(self._create_filtered_listing_generator(reddit))
logger.debug(f'Added submissions from subreddit {reddit}')
except (errors.BulkDownloaderException, praw.exceptions.PRAWException) as e:
logger.error(f'Failed to get submissions for subreddit {reddit}: {e}')
return out
else:
return []
def _resolve_user_name(self):
if self.args.user == 'me':
if self.authenticated:
self.args.user = self.reddit_instance.user.me().name
logger.log(9, f'Resolved user to {self.args.user}')
else:
self.args.user = None
logger.warning('To use "me" as a user, an authenticated Reddit instance must be used')
def _get_submissions_from_link(self) -> list[list[praw.models.Submission]]:
supplied_submissions = []
for sub_id in self.args.link:
if len(sub_id) == 6:
supplied_submissions.append(self.reddit_instance.submission(id=sub_id))
else:
supplied_submissions.append(self.reddit_instance.submission(url=sub_id))
return [supplied_submissions]
def _determine_sort_function(self) -> Callable:
if self.sort_filter is RedditTypes.SortType.NEW:
sort_function = praw.models.Subreddit.new
elif self.sort_filter is RedditTypes.SortType.RISING:
sort_function = praw.models.Subreddit.rising
elif self.sort_filter is RedditTypes.SortType.CONTROVERSIAL:
sort_function = praw.models.Subreddit.controversial
elif self.sort_filter is RedditTypes.SortType.TOP:
sort_function = praw.models.Subreddit.top
else:
sort_function = praw.models.Subreddit.hot
return sort_function
def _get_multireddits(self) -> list[Iterator]:
if self.args.multireddit:
out = []
for multi in self._split_args_input(self.args.multireddit):
try:
multi = self.reddit_instance.multireddit(self.args.user, multi)
if not multi.subreddits:
raise errors.BulkDownloaderException
out.append(self._create_filtered_listing_generator(multi))
logger.debug(f'Added submissions from multireddit {multi}')
except (errors.BulkDownloaderException, praw.exceptions.PRAWException, prawcore.PrawcoreException) as e:
logger.error(f'Failed to get submissions for multireddit {multi}: {e}')
return out
else:
return []
def _create_filtered_listing_generator(self, reddit_source) -> Iterator:
sort_function = self._determine_sort_function()
if self.sort_filter in (RedditTypes.SortType.TOP, RedditTypes.SortType.CONTROVERSIAL):
return sort_function(reddit_source, limit=self.args.limit, time_filter=self.time_filter.value)
else:
return sort_function(reddit_source, limit=self.args.limit)
def _get_user_data(self) -> list[Iterator]:
if any([self.args.submitted, self.args.upvoted, self.args.saved]):
if self.args.user:
try:
self._check_user_existence(self.args.user)
except errors.BulkDownloaderException as e:
logger.error(e)
return []
generators = []
if self.args.submitted:
logger.debug(f'Retrieving submitted posts of user {self.args.user}')
generators.append(self._create_filtered_listing_generator(
self.reddit_instance.redditor(self.args.user).submissions,
))
if not self.authenticated and any((self.args.upvoted, self.args.saved)):
logger.warning('Accessing user lists requires authentication')
else:
if self.args.upvoted:
logger.debug(f'Retrieving upvoted posts of user {self.args.user}')
generators.append(self.reddit_instance.redditor(self.args.user).upvoted(limit=self.args.limit))
if self.args.saved:
logger.debug(f'Retrieving saved posts of user {self.args.user}')
generators.append(self.reddit_instance.redditor(self.args.user).saved(limit=self.args.limit))
return generators
else:
logger.warning('A user must be supplied to download user data')
return []
else:
return []
def _check_user_existence(self, name: str):
user = self.reddit_instance.redditor(name=name)
try:
if user.id:
return
except prawcore.exceptions.NotFound:
raise errors.BulkDownloaderException(f'Could not find user {name}')
except AttributeError:
if hasattr(user, 'is_suspended'):
raise errors.BulkDownloaderException(f'User {name} is banned')
def _create_file_name_formatter(self) -> FileNameFormatter:
return FileNameFormatter(self.args.file_scheme, self.args.folder_scheme, self.args.time_format)
def _create_time_filter(self) -> RedditTypes.TimeType:
try:
return RedditTypes.TimeType[self.args.time.upper()]
except (KeyError, AttributeError):
return RedditTypes.TimeType.ALL
def _create_sort_filter(self) -> RedditTypes.SortType:
try:
return RedditTypes.SortType[self.args.sort.upper()]
except (KeyError, AttributeError):
return RedditTypes.SortType.HOT
def _create_download_filter(self) -> DownloadFilter:
return DownloadFilter(self.args.skip, self.args.skip_domain)
def _create_authenticator(self) -> SiteAuthenticator:
return SiteAuthenticator(self.cfg_parser)
def download(self):
for generator in self.reddit_lists:
for submission in generator:
if submission.id in self.excluded_submission_ids:
logger.debug(f'Object {submission.id} in exclusion list, skipping')
continue
elif submission.subreddit.display_name.lower() in self.args.skip_subreddit:
logger.debug(f'Submission {submission.id} in {submission.subreddit.display_name} in skip list')
else:
logger.debug(f'Attempting to download submission {submission.id}')
self._download_submission(submission)
self._download_submission(submission)
def _download_submission(self, submission: praw.models.Submission):
if not isinstance(submission, praw.models.Submission):
if submission.id in self.excluded_submission_ids:
logger.debug(f'Object {submission.id} in exclusion list, skipping')
return
elif submission.subreddit.display_name.lower() in self.args.skip_subreddit:
logger.debug(f'Submission {submission.id} in {submission.subreddit.display_name} in skip list')
return
elif not isinstance(submission, praw.models.Submission):
logger.warning(f'{submission.id} is not a submission')
return
logger.debug(f'Attempting to download submission {submission.id}')
try:
downloader_class = DownloadFactory.pull_lever(submission.url)
downloader = downloader_class(submission)
@@ -407,7 +63,9 @@ class RedditDownloader:
except errors.NotADownloadableLinkError as e:
logger.error(f'Could not download submission {submission.id}: {e}')
return
if downloader_class.__name__.lower() in self.args.disable_module:
logger.debug(f'Submission {submission.id} skipped due to disabled module {downloader_class.__name__}')
return
try:
content = downloader.find_resources(self.authenticator)
except errors.SiteDownloaderError as e:
@@ -415,34 +73,42 @@ class RedditDownloader:
return
for destination, res in self.file_name_formatter.format_resource_paths(content, self.download_directory):
if destination.exists():
logger.debug(f'File {destination} already exists, continuing')
logger.debug(f'File {destination} from submission {submission.id} already exists, continuing')
continue
elif not self.download_filter.check_resource(res):
logger.debug(f'Download filter removed {submission.id} with URL {submission.url}')
else:
try:
res.download(self.args.max_wait_time)
except errors.BulkDownloaderException as e:
logger.error(f'Failed to download resource {res.url} in submission {submission.id} '
f'with downloader {downloader_class.__name__}: {e}')
continue
try:
res.download(self.args.max_wait_time)
except errors.BulkDownloaderException as e:
logger.error(f'Failed to download resource {res.url} in submission {submission.id} '
f'with downloader {downloader_class.__name__}: {e}')
return
resource_hash = res.hash.hexdigest()
destination.parent.mkdir(parents=True, exist_ok=True)
if resource_hash in self.master_hash_list:
if self.args.no_dupes:
logger.info(
f'Resource hash {resource_hash} from submission {submission.id} downloaded elsewhere')
return
resource_hash = res.hash.hexdigest()
destination.parent.mkdir(parents=True, exist_ok=True)
if resource_hash in self.master_hash_list:
if self.args.no_dupes:
logger.info(
f'Resource hash {resource_hash} from submission {submission.id} downloaded elsewhere')
return
elif self.args.make_hard_links:
self.master_hash_list[resource_hash].link_to(destination)
logger.info(
f'Hard link made linking {destination} to {self.master_hash_list[resource_hash]}')
return
elif self.args.make_hard_links:
self.master_hash_list[resource_hash].link_to(destination)
logger.info(
f'Hard link made linking {destination} to {self.master_hash_list[resource_hash]}'
f' in submission {submission.id}')
return
try:
with open(destination, 'wb') as file:
file.write(res.content)
logger.debug(f'Written file to {destination}')
self.master_hash_list[resource_hash] = destination
logger.debug(f'Hash added to master list: {resource_hash}')
logger.info(f'Downloaded submission {submission.id} from {submission.subreddit.display_name}')
except OSError as e:
logger.exception(e)
logger.error(f'Failed to write file to {destination} in submission {submission.id}: {e}')
creation_time = time.mktime(datetime.fromtimestamp(submission.created_utc).timetuple())
os.utime(destination, (creation_time, creation_time))
self.master_hash_list[resource_hash] = destination
logger.debug(f'Hash added to master list: {resource_hash}')
logger.info(f'Downloaded submission {submission.id} from {submission.subreddit.display_name}')
@staticmethod
def scan_existing_files(directory: Path) -> dict[str, Path]:
@@ -457,27 +123,3 @@ class RedditDownloader:
hash_list = {res[1]: res[0] for res in results}
return hash_list
def _read_excluded_ids(self) -> set[str]:
out = []
out.extend(self.args.exclude_id)
for id_file in self.args.exclude_id_file:
id_file = Path(id_file).resolve().expanduser()
if not id_file.exists():
logger.warning(f'ID exclusion file at {id_file} does not exist')
continue
with open(id_file, 'r') as file:
for line in file:
out.append(line.strip())
return set(out)
@staticmethod
def _check_subreddit_status(subreddit: praw.models.Subreddit):
if subreddit.display_name == 'all':
return
try:
assert subreddit.id
except prawcore.NotFound:
raise errors.BulkDownloaderException(f'Source {subreddit.display_name} does not exist or cannot be found')
except prawcore.Forbidden:
raise errors.BulkDownloaderException(f'Source {subreddit.display_name} is private and cannot be scraped')

bdfr/file_name_formatter.py

@@ -4,6 +4,7 @@ import datetime
import logging
import platform
import re
import subprocess
from pathlib import Path
from typing import Optional
@@ -104,32 +105,46 @@ class FileNameFormatter:
) -> Path:
subfolder = Path(
destination_directory,
*[self._format_name(resource.source_submission, part) for part in self.directory_format_string]
*[self._format_name(resource.source_submission, part) for part in self.directory_format_string],
)
index = f'_{str(index)}' if index else ''
if not resource.extension:
raise BulkDownloaderException(f'Resource from {resource.url} has no extension')
ending = index + resource.extension
file_name = str(self._format_name(resource.source_submission, self.file_format_string))
file_name = self._limit_file_name_length(file_name, ending)
try:
file_path = Path(subfolder, file_name)
file_path = self._limit_file_name_length(file_name, ending, subfolder)
except TypeError:
raise BulkDownloaderException(f'Could not determine path name: {subfolder}, {index}, {resource.extension}')
return file_path
@staticmethod
def _limit_file_name_length(filename: str, ending: str) -> str:
def _limit_file_name_length(filename: str, ending: str, root: Path) -> Path:
root = root.resolve().expanduser()
possible_id = re.search(r'((?:_\w{6})?$)', filename)
if possible_id:
ending = possible_id.group(1) + ending
filename = filename[:possible_id.start()]
max_path = FileNameFormatter.find_max_path_length()
max_length_chars = 255 - len(ending)
max_length_bytes = 255 - len(ending.encode('utf-8'))
while len(filename) > max_length_chars or len(filename.encode('utf-8')) > max_length_bytes:
max_path_length = max_path - len(ending) - len(str(root)) - 1
while len(filename) > max_length_chars or \
len(filename.encode('utf-8')) > max_length_bytes or \
len(filename) > max_path_length:
filename = filename[:-1]
return filename + ending
return Path(root, filename + ending)
@staticmethod
def find_max_path_length() -> int:
try:
return int(subprocess.check_output(['getconf', 'PATH_MAX', '/']))
except (ValueError, subprocess.CalledProcessError, OSError):
if platform.system() == 'Windows':
return 260
else:
return 4096
def format_resource_paths(
self,

bdfr/resource.py

@@ -5,8 +5,8 @@ import hashlib
import logging
import re
import time
from typing import Optional
import urllib.parse
from typing import Optional
import _hashlib
import requests
@@ -28,8 +28,7 @@ class Resource:
self.extension = self._determine_extension()
@staticmethod
def retry_download(url: str, max_wait_time: int) -> Optional[bytes]:
wait_time = 60
def retry_download(url: str, max_wait_time: int, current_wait_time: int = 60) -> Optional[bytes]:
try:
response = requests.get(url)
if re.match(r'^2\d{2}', str(response.status_code)) and response.content:
@@ -39,11 +38,12 @@
else:
raise BulkDownloaderException(
f'Unrecoverable error requesting resource: HTTP Code {response.status_code}')
except requests.exceptions.ConnectionError as e:
logger.warning(f'Error occurred downloading from {url}, waiting {wait_time} seconds: {e}')
time.sleep(wait_time)
if wait_time < max_wait_time:
return Resource.retry_download(url, max_wait_time)
except (requests.exceptions.ConnectionError, requests.exceptions.ChunkedEncodingError) as e:
logger.warning(f'Error occurred downloading from {url}, waiting {current_wait_time} seconds: {e}')
time.sleep(current_wait_time)
if current_wait_time < max_wait_time:
current_wait_time += 60
return Resource.retry_download(url, max_wait_time, current_wait_time)
else:
logger.error(f'Max wait time exceeded for resource at url {url}')
raise

bdfr/site_downloaders/download_factory.py

@@ -21,10 +21,11 @@ from bdfr.site_downloaders.youtube import Youtube
class DownloadFactory:
@staticmethod
def pull_lever(url: str) -> Type[BaseDownloader]:
sanitised_url = DownloadFactory._sanitise_url(url)
sanitised_url = DownloadFactory.sanitise_url(url)
if re.match(r'(i\.)?imgur.*\.gifv$', sanitised_url):
return Imgur
elif re.match(r'.*/.*\.\w{3,4}(\?[\w;&=]*)?$', sanitised_url):
elif re.match(r'.*/.*\.\w{3,4}(\?[\w;&=]*)?$', sanitised_url) and \
not DownloadFactory.is_web_resource(sanitised_url):
return Direct
elif re.match(r'erome\.com.*', sanitised_url):
return Erome
@@ -49,9 +50,29 @@
f'No downloader module exists for url {url}')
@staticmethod
def _sanitise_url(url: str) -> str:
def sanitise_url(url: str) -> str:
beginning_regex = re.compile(r'\s*(www\.?)?')
split_url = urllib.parse.urlsplit(url)
split_url = split_url.netloc + split_url.path
split_url = re.sub(beginning_regex, '', split_url)
return split_url
@staticmethod
def is_web_resource(url: str) -> bool:
web_extensions = (
'asp',
'aspx',
'cfm',
'cfml',
'css',
'htm',
'html',
'js',
'php',
'php3',
'xhtml',
)
if re.match(rf'(?i).*/.*\.({"|".join(web_extensions)})$', url):
return True
else:
return False

bdfr/site_downloaders/imgur.py

@@ -71,6 +71,7 @@ class Imgur(BaseDownloader):
@staticmethod
def _validate_extension(extension_suffix: str) -> str:
extension_suffix = extension_suffix.strip('?1')
possible_extensions = ('.jpg', '.png', '.mp4', '.gif')
selection = [ext for ext in possible_extensions if ext == extension_suffix]
if len(selection) == 1:

scripts/extract_failed_ids.sh

@@ -14,5 +14,9 @@ else
output="failed.txt"
fi
grep 'Could not download submission' "$file" | awk '{ print $12 }' | rev | cut -c 2- | rev >>"$output"
grep 'Failed to download resource' "$file" | awk '{ print $15 }' >>"$output"
{
grep 'Could not download submission' "$file" | awk '{ print $12 }' | rev | cut -c 2- | rev ;
grep 'Failed to download resource' "$file" | awk '{ print $15 }' ;
grep 'failed to download submission' "$file" | awk '{ print $14 }' | rev | cut -c 2- | rev ;
grep 'Failed to write file' "$file" | awk '{ print $16 }' | rev | cut -c 2- | rev ;
} >>"$output"

scripts/extract_successful_ids.sh

@@ -14,4 +14,10 @@ else
output="successful.txt"
fi
grep 'Downloaded submission' "$file" | awk '{ print $(NF-2) }' >> "$output"
{
grep 'Downloaded submission' "$file" | awk '{ print $(NF-2) }' ;
grep 'Resource hash' "$file" | awk '{ print $(NF-2) }' ;
grep 'Download filter' "$file" | awk '{ print $(NF-3) }' ;
grep 'already exists, continuing' "$file" | awk '{ print $(NF-3) }' ;
grep 'Hard link made' "$file" | awk '{ print $(NF) }' ;
} >> "$output"
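Both helper scripts above pull submission and resource identifiers out of a BDFR logfile with grep/awk pipelines grouped into a single redirect. A rough Python equivalent of the failure-extraction script; the grep phrases are taken from the script, but the exact shape of each log line around the ID is an assumption inferred from the awk field numbers:

```python
import re
import sys

# Assumed log-line shapes; only the quoted phrases are certain.
FAILURE_PATTERNS = (
    re.compile(r'Could not download submission (\w+):'),
    re.compile(r'Failed to download resource .* in submission (\w+)'),  # assumed wording
    re.compile(r'failed to download submission (\w+):'),
    re.compile(r'Failed to write file .* in submission (\w+):'),  # assumed wording
)


def extract_failed_ids(log_path: str) -> list[str]:
    failed = []
    with open(log_path, encoding='utf-8') as log_file:
        for line in log_file:
            for pattern in FAILURE_PATTERNS:
                match = pattern.search(line)
                if match:
                    failed.append(match.group(1))
    return failed


if __name__ == '__main__':
    print('\n'.join(extract_failed_ids(sys.argv[1])))
```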


@@ -4,7 +4,7 @@ description_file = README.md
description_content_type = text/markdown
home_page = https://github.com/aliparlakci/bulk-downloader-for-reddit
keywords = reddit, download, archive
version = 2.1.1
version = 2.2.0
author = Ali Parlakci
author_email = parlakciali@gmail.com
maintainer = Serene Arc


@@ -15,6 +15,7 @@ from bdfr.archive_entry.comment_archive_entry import CommentArchiveEntry
'subreddit': 'Python',
'submission': 'mgi4op',
'submission_title': '76% Faster CPython',
'distinguished': None,
}),
))
def test_get_comment_details(test_comment_id: str, expected_dict: dict, reddit_instance: praw.Reddit):


@@ -26,6 +26,13 @@ def test_get_comments(test_submission_id: str, min_comments: int, reddit_instanc
'author': 'sinjen-tos',
'id': 'm3reby',
'link_flair_text': 'image',
'pinned': False,
'spoiler': False,
'over_18': False,
'locked': False,
'distinguished': None,
'created_utc': 1615583837,
'permalink': '/r/australia/comments/m3reby/this_little_guy_fell_out_of_a_tree_and_in_front/'
}),
('m3kua3', {'author': 'DELETED'}),
))


@@ -69,6 +69,19 @@ def test_factory_lever_bad(test_url: str):
('https://youtube.com/watch?v=Gv8Wz74FjVA', 'youtube.com/watch'),
('https://i.imgur.com/BuzvZwb.gifv', 'i.imgur.com/BuzvZwb.gifv'),
))
def test_sanitise_urll(test_url: str, expected: str):
result = DownloadFactory._sanitise_url(test_url)
def test_sanitise_url(test_url: str, expected: str):
result = DownloadFactory.sanitise_url(test_url)
assert result == expected
@pytest.mark.parametrize(('test_url', 'expected'), (
('www.example.com/test.asp', True),
('www.example.com/test.html', True),
('www.example.com/test.js', True),
('www.example.com/test.xhtml', True),
('www.example.com/test.mp4', False),
('www.example.com/test.png', False),
))
def test_is_web_resource(test_url: str, expected: bool):
result = DownloadFactory.is_web_resource(test_url)
assert result == expected


@@ -34,9 +34,6 @@ def test_get_link(test_url: str, expected_urls: tuple[str]):
('https://www.erome.com/a/vqtPuLXh', {
'5da2a8d60d87bed279431fdec8e7d72f'
}),
('https://www.erome.com/i/ItASD33e', {
'b0d73fedc9ce6995c2f2c4fdb6f11eff'
}),
('https://www.erome.com/a/lGrcFxmb', {
'0e98f9f527a911dcedde4f846bb5b69f',
'25696ae364750a5303fc7d7dc78b35c1',


@@ -122,6 +122,14 @@ def test_imgur_extension_validation_bad(test_extension: str):
'029c475ce01b58fdf1269d8771d33913',
),
),
(
'https://imgur.com/a/eemHCCK',
(
'9cb757fd8f055e7ef7aa88addc9d9fa5',
'b6cb6c918e2544e96fb7c07d828774b5',
'fb6c913d721c0bbb96aa65d7f560d385',
),
),
))
def test_find_resources(test_url: str, expected_hashes: list[str]):
mock_download = Mock()
@@ -131,5 +139,4 @@ def test_find_resources(test_url: str, expected_hashes: list[str]):
assert all([isinstance(res, Resource) for res in results])
[res.download(120) for res in results]
hashes = set([res.hash.hexdigest() for res in results])
assert len(results) == len(expected_hashes)
assert hashes == set(expected_hashes)


@@ -7,51 +7,20 @@ from unittest.mock import MagicMock
import praw
import pytest
from bdfr.archive_entry.submission_archive_entry import SubmissionArchiveEntry
from bdfr.archiver import Archiver
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize('test_submission_id', (
'm3reby',
@pytest.mark.parametrize(('test_submission_id', 'test_format'), (
('m3reby', 'xml'),
('m3reby', 'json'),
('m3reby', 'yaml'),
))
def test_write_submission_json(test_submission_id: str, tmp_path: Path, reddit_instance: praw.Reddit):
def test_write_submission_json(test_submission_id: str, tmp_path: Path, test_format: str, reddit_instance: praw.Reddit):
archiver_mock = MagicMock()
test_path = Path(tmp_path, 'test.json')
archiver_mock.args.format = test_format
test_path = Path(tmp_path, 'test')
test_submission = reddit_instance.submission(id=test_submission_id)
archiver_mock.file_name_formatter.format_path.return_value = test_path
test_entry = SubmissionArchiveEntry(test_submission)
Archiver._write_entry_json(archiver_mock, test_entry)
archiver_mock._write_content_to_disk.assert_called_once()
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize('test_submission_id', (
'm3reby',
))
def test_write_submission_xml(test_submission_id: str, tmp_path: Path, reddit_instance: praw.Reddit):
archiver_mock = MagicMock()
test_path = Path(tmp_path, 'test.xml')
test_submission = reddit_instance.submission(id=test_submission_id)
archiver_mock.file_name_formatter.format_path.return_value = test_path
test_entry = SubmissionArchiveEntry(test_submission)
Archiver._write_entry_xml(archiver_mock, test_entry)
archiver_mock._write_content_to_disk.assert_called_once()
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize('test_submission_id', (
'm3reby',
))
def test_write_submission_yaml(test_submission_id: str, tmp_path: Path, reddit_instance: praw.Reddit):
archiver_mock = MagicMock()
archiver_mock.download_directory = tmp_path
test_path = Path(tmp_path, 'test.yaml')
test_submission = reddit_instance.submission(id=test_submission_id)
archiver_mock.file_name_formatter.format_path.return_value = test_path
test_entry = SubmissionArchiveEntry(test_submission)
Archiver._write_entry_yaml(archiver_mock, test_entry)
archiver_mock._write_content_to_disk.assert_called_once()
Archiver.write_entry(archiver_mock, test_submission)
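The three near-identical per-format tests collapse into a single parametrised one because the archiver now selects the serialiser from `args.format` at runtime. A minimal sketch of such a dispatch, with illustrative names and assumed third-party helpers; the real `Archiver` internals may differ:

```python
import json


def serialise_entry(entry_data: dict, fmt: str) -> str:
    # Hypothetical dispatch on the configured format, the behaviour the
    # parametrised test exercises via archiver_mock.args.format.
    if fmt == 'json':
        return json.dumps(entry_data, indent=4)
    if fmt == 'yaml':
        import yaml  # PyYAML, assumed to be available
        return yaml.dump(entry_data)
    if fmt == 'xml':
        import dict2xml  # assumed helper library
        return dict2xml.dict2xml(entry_data, wrap='root')
    raise ValueError(f'Unknown archive format: {fmt}')


print(serialise_entry({'id': 'm3reby', 'subreddit': 'australia'}, 'json'))
```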

tests/test_connector.py (new file, 402 additions)

@@ -0,0 +1,402 @@
#!/usr/bin/env python3
# coding=utf-8
from pathlib import Path
from typing import Iterator
from unittest.mock import MagicMock
import praw
import praw.models
import pytest
from bdfr.configuration import Configuration
from bdfr.connector import RedditConnector, RedditTypes
from bdfr.download_filter import DownloadFilter
from bdfr.exceptions import BulkDownloaderException
from bdfr.file_name_formatter import FileNameFormatter
from bdfr.site_authenticator import SiteAuthenticator
@pytest.fixture()
def args() -> Configuration:
args = Configuration()
args.time_format = 'ISO'
return args
@pytest.fixture()
def downloader_mock(args: Configuration):
downloader_mock = MagicMock()
downloader_mock.args = args
downloader_mock.sanitise_subreddit_name = RedditConnector.sanitise_subreddit_name
downloader_mock.split_args_input = RedditConnector.split_args_input
downloader_mock.master_hash_list = {}
return downloader_mock
def assert_all_results_are_submissions(result_limit: int, results: list[Iterator]) -> list:
results = [sub for res in results for sub in res]
assert all([isinstance(res, praw.models.Submission) for res in results])
if result_limit is not None:
assert len(results) == result_limit
return results
def test_determine_directories(tmp_path: Path, downloader_mock: MagicMock):
downloader_mock.args.directory = tmp_path / 'test'
downloader_mock.config_directories.user_config_dir = tmp_path
RedditConnector.determine_directories(downloader_mock)
assert Path(tmp_path / 'test').exists()
@pytest.mark.parametrize(('skip_extensions', 'skip_domains'), (
([], []),
(['.test'], ['test.com'],),
))
def test_create_download_filter(skip_extensions: list[str], skip_domains: list[str], downloader_mock: MagicMock):
downloader_mock.args.skip = skip_extensions
downloader_mock.args.skip_domain = skip_domains
result = RedditConnector.create_download_filter(downloader_mock)
assert isinstance(result, DownloadFilter)
assert result.excluded_domains == skip_domains
assert result.excluded_extensions == skip_extensions
@pytest.mark.parametrize(('test_time', 'expected'), (
('all', 'all'),
('hour', 'hour'),
('day', 'day'),
('week', 'week'),
('random', 'all'),
('', 'all'),
))
def test_create_time_filter(test_time: str, expected: str, downloader_mock: MagicMock):
downloader_mock.args.time = test_time
result = RedditConnector.create_time_filter(downloader_mock)
assert isinstance(result, RedditTypes.TimeType)
assert result.name.lower() == expected
@pytest.mark.parametrize(('test_sort', 'expected'), (
('', 'hot'),
('hot', 'hot'),
('controversial', 'controversial'),
('new', 'new'),
))
def test_create_sort_filter(test_sort: str, expected: str, downloader_mock: MagicMock):
downloader_mock.args.sort = test_sort
result = RedditConnector.create_sort_filter(downloader_mock)
assert isinstance(result, RedditTypes.SortType)
assert result.name.lower() == expected
@pytest.mark.parametrize(('test_file_scheme', 'test_folder_scheme'), (
('{POSTID}', '{SUBREDDIT}'),
('{REDDITOR}_{TITLE}_{POSTID}', '{SUBREDDIT}'),
('{POSTID}', 'test'),
('{POSTID}', ''),
('{POSTID}', '{SUBREDDIT}/{REDDITOR}'),
))
def test_create_file_name_formatter(test_file_scheme: str, test_folder_scheme: str, downloader_mock: MagicMock):
downloader_mock.args.file_scheme = test_file_scheme
downloader_mock.args.folder_scheme = test_folder_scheme
result = RedditConnector.create_file_name_formatter(downloader_mock)
assert isinstance(result, FileNameFormatter)
assert result.file_format_string == test_file_scheme
assert result.directory_format_string == test_folder_scheme.split('/')
@pytest.mark.parametrize(('test_file_scheme', 'test_folder_scheme'), (
('', ''),
('', '{SUBREDDIT}'),
('test', '{SUBREDDIT}'),
))
def test_create_file_name_formatter_bad(test_file_scheme: str, test_folder_scheme: str, downloader_mock: MagicMock):
downloader_mock.args.file_scheme = test_file_scheme
downloader_mock.args.folder_scheme = test_folder_scheme
with pytest.raises(BulkDownloaderException):
RedditConnector.create_file_name_formatter(downloader_mock)
def test_create_authenticator(downloader_mock: MagicMock):
result = RedditConnector.create_authenticator(downloader_mock)
assert isinstance(result, SiteAuthenticator)
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize('test_submission_ids', (
('lvpf4l',),
('lvpf4l', 'lvqnsn'),
('lvpf4l', 'lvqnsn', 'lvl9kd'),
))
def test_get_submissions_from_link(
test_submission_ids: list[str],
reddit_instance: praw.Reddit,
downloader_mock: MagicMock):
downloader_mock.args.link = test_submission_ids
downloader_mock.reddit_instance = reddit_instance
results = RedditConnector.get_submissions_from_link(downloader_mock)
assert all([isinstance(sub, praw.models.Submission) for res in results for sub in res])
assert len(results[0]) == len(test_submission_ids)
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize(('test_subreddits', 'limit', 'sort_type', 'time_filter', 'max_expected_len'), (
(('Futurology',), 10, 'hot', 'all', 10),
(('Futurology', 'Mindustry, Python'), 10, 'hot', 'all', 30),
(('Futurology',), 20, 'hot', 'all', 20),
(('Futurology', 'Python'), 10, 'hot', 'all', 20),
(('Futurology',), 100, 'hot', 'all', 100),
(('Futurology',), 0, 'hot', 'all', 0),
(('Futurology',), 10, 'top', 'all', 10),
(('Futurology',), 10, 'top', 'week', 10),
(('Futurology',), 10, 'hot', 'week', 10),
))
def test_get_subreddit_normal(
test_subreddits: list[str],
limit: int,
sort_type: str,
time_filter: str,
max_expected_len: int,
downloader_mock: MagicMock,
reddit_instance: praw.Reddit,
):
downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
downloader_mock.args.limit = limit
downloader_mock.args.sort = sort_type
downloader_mock.args.subreddit = test_subreddits
downloader_mock.reddit_instance = reddit_instance
downloader_mock.sort_filter = RedditConnector.create_sort_filter(downloader_mock)
results = RedditConnector.get_subreddits(downloader_mock)
test_subreddits = downloader_mock._split_args_input(test_subreddits)
results = [sub for res1 in results for sub in res1]
assert all([isinstance(res1, praw.models.Submission) for res1 in results])
assert all([res.subreddit.display_name in test_subreddits for res in results])
assert len(results) <= max_expected_len
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize(('test_subreddits', 'search_term', 'limit', 'time_filter', 'max_expected_len'), (
(('Python',), 'scraper', 10, 'all', 10),
(('Python',), '', 10, 'all', 10),
(('Python',), 'djsdsgewef', 10, 'all', 0),
(('Python',), 'scraper', 10, 'year', 10),
(('Python',), 'scraper', 10, 'hour', 1),
))
def test_get_subreddit_search(
test_subreddits: list[str],
search_term: str,
time_filter: str,
limit: int,
max_expected_len: int,
downloader_mock: MagicMock,
reddit_instance: praw.Reddit,
):
downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
downloader_mock.args.limit = limit
downloader_mock.args.search = search_term
downloader_mock.args.subreddit = test_subreddits
downloader_mock.reddit_instance = reddit_instance
downloader_mock.sort_filter = RedditTypes.SortType.HOT
downloader_mock.args.time = time_filter
downloader_mock.time_filter = RedditConnector.create_time_filter(downloader_mock)
results = RedditConnector.get_subreddits(downloader_mock)
results = [sub for res in results for sub in res]
assert all([isinstance(res, praw.models.Submission) for res in results])
assert all([res.subreddit.display_name in test_subreddits for res in results])
assert len(results) <= max_expected_len
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize(('test_user', 'test_multireddits', 'limit'), (
('helen_darten', ('cuteanimalpics',), 10),
('korfor', ('chess',), 100),
))
# Good sources at https://www.reddit.com/r/multihub/
def test_get_multireddits_public(
test_user: str,
test_multireddits: list[str],
limit: int,
reddit_instance: praw.Reddit,
downloader_mock: MagicMock,
):
downloader_mock.determine_sort_function.return_value = praw.models.Subreddit.hot
downloader_mock.sort_filter = RedditTypes.SortType.HOT
downloader_mock.args.limit = limit
downloader_mock.args.multireddit = test_multireddits
downloader_mock.args.user = [test_user]
downloader_mock.reddit_instance = reddit_instance
downloader_mock.create_filtered_listing_generator.return_value = \
RedditConnector.create_filtered_listing_generator(
downloader_mock,
reddit_instance.multireddit(test_user, test_multireddits[0]),
)
results = RedditConnector.get_multireddits(downloader_mock)
results = [sub for res in results for sub in res]
assert all([isinstance(res, praw.models.Submission) for res in results])
assert len(results) == limit
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize(('test_user', 'limit'), (
('danigirl3694', 10),
('danigirl3694', 50),
('CapitanHam', None),
))
def test_get_user_submissions(test_user: str, limit: int, downloader_mock: MagicMock, reddit_instance: praw.Reddit):
downloader_mock.args.limit = limit
downloader_mock.determine_sort_function.return_value = praw.models.Subreddit.hot
downloader_mock.sort_filter = RedditTypes.SortType.HOT
downloader_mock.args.submitted = True
downloader_mock.args.user = [test_user]
downloader_mock.authenticated = False
downloader_mock.reddit_instance = reddit_instance
downloader_mock.create_filtered_listing_generator.return_value = \
RedditConnector.create_filtered_listing_generator(
downloader_mock,
reddit_instance.redditor(test_user).submissions,
)
results = RedditConnector.get_user_data(downloader_mock)
results = assert_all_results_are_submissions(limit, results)
assert all([res.author.name == test_user for res in results])
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.authenticated
@pytest.mark.parametrize('test_flag', (
'upvoted',
'saved',
))
def test_get_user_authenticated_lists(
test_flag: str,
downloader_mock: MagicMock,
authenticated_reddit_instance: praw.Reddit,
):
downloader_mock.args.__dict__[test_flag] = True
downloader_mock.reddit_instance = authenticated_reddit_instance
downloader_mock.args.limit = 10
downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
downloader_mock.sort_filter = RedditTypes.SortType.HOT
downloader_mock.args.user = [RedditConnector.resolve_user_name(downloader_mock, 'me')]
results = RedditConnector.get_user_data(downloader_mock)
assert_all_results_are_submissions(10, results)
@pytest.mark.parametrize(('test_name', 'expected'), (
('Mindustry', 'Mindustry'),
('Futurology', 'Futurology'),
('r/Mindustry', 'Mindustry'),
('TrollXChromosomes', 'TrollXChromosomes'),
('r/TrollXChromosomes', 'TrollXChromosomes'),
('https://www.reddit.com/r/TrollXChromosomes/', 'TrollXChromosomes'),
('https://www.reddit.com/r/TrollXChromosomes', 'TrollXChromosomes'),
('https://www.reddit.com/r/Futurology/', 'Futurology'),
('https://www.reddit.com/r/Futurology', 'Futurology'),
))
def test_sanitise_subreddit_name(test_name: str, expected: str):
result = RedditConnector.sanitise_subreddit_name(test_name)
assert result == expected
@pytest.mark.parametrize(('test_subreddit_entries', 'expected'), (
(['test1', 'test2', 'test3'], {'test1', 'test2', 'test3'}),
(['test1,test2', 'test3'], {'test1', 'test2', 'test3'}),
(['test1, test2', 'test3'], {'test1', 'test2', 'test3'}),
(['test1; test2', 'test3'], {'test1', 'test2', 'test3'}),
(['test1, test2', 'test1,test2,test3', 'test4'], {'test1', 'test2', 'test3', 'test4'}),
([''], {''}),
(['test'], {'test'}),
))
def test_split_subreddit_entries(test_subreddit_entries: list[str], expected: set[str]):
results = RedditConnector.split_args_input(test_subreddit_entries)
assert results == expected
def test_read_excluded_submission_ids_from_file(downloader_mock: MagicMock, tmp_path: Path):
test_file = tmp_path / 'test.txt'
test_file.write_text('aaaaaa\nbbbbbb')
downloader_mock.args.exclude_id_file = [test_file]
results = RedditConnector.read_excluded_ids(downloader_mock)
assert results == {'aaaaaa', 'bbbbbb'}
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize('test_redditor_name', (
'Paracortex',
'crowdstrike',
'HannibalGoddamnit',
))
def test_check_user_existence_good(
test_redditor_name: str,
reddit_instance: praw.Reddit,
downloader_mock: MagicMock,
):
downloader_mock.reddit_instance = reddit_instance
RedditConnector.check_user_existence(downloader_mock, test_redditor_name)
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize('test_redditor_name', (
'lhnhfkuhwreolo',
'adlkfmnhglojh',
))
def test_check_user_existence_nonexistent(
test_redditor_name: str,
reddit_instance: praw.Reddit,
downloader_mock: MagicMock,
):
downloader_mock.reddit_instance = reddit_instance
with pytest.raises(BulkDownloaderException, match='Could not find'):
RedditConnector.check_user_existence(downloader_mock, test_redditor_name)
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize('test_redditor_name', (
'Bree-Boo',
))
def test_check_user_existence_banned(
test_redditor_name: str,
reddit_instance: praw.Reddit,
downloader_mock: MagicMock,
):
downloader_mock.reddit_instance = reddit_instance
with pytest.raises(BulkDownloaderException, match='is banned'):
RedditConnector.check_user_existence(downloader_mock, test_redditor_name)
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize(('test_subreddit_name', 'expected_message'), (
('donaldtrump', 'cannot be found'),
('submitters', 'private and cannot be scraped')
))
def test_check_subreddit_status_bad(test_subreddit_name: str, expected_message: str, reddit_instance: praw.Reddit):
test_subreddit = reddit_instance.subreddit(test_subreddit_name)
with pytest.raises(BulkDownloaderException, match=expected_message):
RedditConnector.check_subreddit_status(test_subreddit)
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize('test_subreddit_name', (
'Python',
'Mindustry',
'TrollXChromosomes',
'all',
))
def test_check_subreddit_status_good(test_subreddit_name: str, reddit_instance: praw.Reddit):
test_subreddit = reddit_instance.subreddit(test_subreddit_name)
RedditConnector.check_subreddit_status(test_subreddit)
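For reference, the name-sanitisation cases in the new test file above boil down to this behaviour, using the same import path as the test's own imports:

```python
from bdfr.connector import RedditConnector

# All accepted spellings normalise to the bare subreddit name:
for name in ('Mindustry', 'r/Mindustry', 'https://www.reddit.com/r/Mindustry/'):
    print(RedditConnector.sanitise_subreddit_name(name))  # Mindustry, three times
```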


@@ -1,22 +1,19 @@
#!/usr/bin/env python3
# coding=utf-8
import os
import re
from pathlib import Path
from typing import Iterator
from unittest.mock import MagicMock
from unittest.mock import MagicMock, patch
import praw
import praw.models
import pytest
import bdfr.site_downloaders.download_factory
from bdfr.__main__ import setup_logging
from bdfr.configuration import Configuration
from bdfr.download_filter import DownloadFilter
from bdfr.downloader import RedditDownloader, RedditTypes
from bdfr.exceptions import BulkDownloaderException
from bdfr.file_name_formatter import FileNameFormatter
from bdfr.site_authenticator import SiteAuthenticator
from bdfr.connector import RedditConnector
from bdfr.downloader import RedditDownloader
@pytest.fixture()
@@ -30,314 +27,105 @@ def args() -> Configuration:
def downloader_mock(args: Configuration):
downloader_mock = MagicMock()
downloader_mock.args = args
downloader_mock._sanitise_subreddit_name = RedditDownloader._sanitise_subreddit_name
downloader_mock._split_args_input = RedditDownloader._split_args_input
downloader_mock._sanitise_subreddit_name = RedditConnector.sanitise_subreddit_name
downloader_mock._split_args_input = RedditConnector.split_args_input
downloader_mock.master_hash_list = {}
return downloader_mock
def assert_all_results_are_submissions(result_limit: int, results: list[Iterator]):
results = [sub for res in results for sub in res]
assert all([isinstance(res, praw.models.Submission) for res in results])
if result_limit is not None:
assert len(results) == result_limit
return results
def test_determine_directories(tmp_path: Path, downloader_mock: MagicMock):
downloader_mock.args.directory = tmp_path / 'test'
downloader_mock.config_directories.user_config_dir = tmp_path
RedditDownloader._determine_directories(downloader_mock)
assert Path(tmp_path / 'test').exists()
@pytest.mark.parametrize(('skip_extensions', 'skip_domains'), (
([], []),
(['.test'], ['test.com'],),
@pytest.mark.parametrize(('test_ids', 'test_excluded', 'expected_len'), (
(('aaaaaa',), (), 1),
(('aaaaaa',), ('aaaaaa',), 0),
((), ('aaaaaa',), 0),
(('aaaaaa', 'bbbbbb'), ('aaaaaa',), 1),
(('aaaaaa', 'bbbbbb', 'cccccc'), ('aaaaaa',), 2),
))
def test_create_download_filter(skip_extensions: list[str], skip_domains: list[str], downloader_mock: MagicMock):
downloader_mock.args.skip = skip_extensions
downloader_mock.args.skip_domain = skip_domains
result = RedditDownloader._create_download_filter(downloader_mock)
assert isinstance(result, DownloadFilter)
assert result.excluded_domains == skip_domains
assert result.excluded_extensions == skip_extensions
@pytest.mark.parametrize(('test_time', 'expected'), (
('all', 'all'),
('hour', 'hour'),
('day', 'day'),
('week', 'week'),
('random', 'all'),
('', 'all'),
))
def test_create_time_filter(test_time: str, expected: str, downloader_mock: MagicMock):
downloader_mock.args.time = test_time
result = RedditDownloader._create_time_filter(downloader_mock)
assert isinstance(result, RedditTypes.TimeType)
assert result.name.lower() == expected
@pytest.mark.parametrize(('test_sort', 'expected'), (
('', 'hot'),
('hot', 'hot'),
('controversial', 'controversial'),
('new', 'new'),
))
def test_create_sort_filter(test_sort: str, expected: str, downloader_mock: MagicMock):
downloader_mock.args.sort = test_sort
result = RedditDownloader._create_sort_filter(downloader_mock)
assert isinstance(result, RedditTypes.SortType)
assert result.name.lower() == expected
@pytest.mark.parametrize(('test_file_scheme', 'test_folder_scheme'), (
('{POSTID}', '{SUBREDDIT}'),
('{REDDITOR}_{TITLE}_{POSTID}', '{SUBREDDIT}'),
('{POSTID}', 'test'),
('{POSTID}', ''),
('{POSTID}', '{SUBREDDIT}/{REDDITOR}'),
))
def test_create_file_name_formatter(test_file_scheme: str, test_folder_scheme: str, downloader_mock: MagicMock):
downloader_mock.args.file_scheme = test_file_scheme
downloader_mock.args.folder_scheme = test_folder_scheme
result = RedditDownloader._create_file_name_formatter(downloader_mock)
assert isinstance(result, FileNameFormatter)
assert result.file_format_string == test_file_scheme
assert result.directory_format_string == test_folder_scheme.split('/')
@pytest.mark.parametrize(('test_file_scheme', 'test_folder_scheme'), (
('', ''),
('', '{SUBREDDIT}'),
('test', '{SUBREDDIT}'),
))
def test_create_file_name_formatter_bad(test_file_scheme: str, test_folder_scheme: str, downloader_mock: MagicMock):
downloader_mock.args.file_scheme = test_file_scheme
downloader_mock.args.folder_scheme = test_folder_scheme
with pytest.raises(BulkDownloaderException):
RedditDownloader._create_file_name_formatter(downloader_mock)
def test_create_authenticator(downloader_mock: MagicMock):
result = RedditDownloader._create_authenticator(downloader_mock)
assert isinstance(result, SiteAuthenticator)
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize('test_submission_ids', (
('lvpf4l',),
('lvpf4l', 'lvqnsn'),
('lvpf4l', 'lvqnsn', 'lvl9kd'),
))
def test_get_submissions_from_link(
test_submission_ids: list[str],
reddit_instance: praw.Reddit,
downloader_mock: MagicMock):
downloader_mock.args.link = test_submission_ids
downloader_mock.reddit_instance = reddit_instance
results = RedditDownloader._get_submissions_from_link(downloader_mock)
assert all([isinstance(sub, praw.models.Submission) for res in results for sub in res])
assert len(results[0]) == len(test_submission_ids)
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize(('test_subreddits', 'limit', 'sort_type', 'time_filter', 'max_expected_len'), (
(('Futurology',), 10, 'hot', 'all', 10),
(('Futurology', 'Mindustry, Python'), 10, 'hot', 'all', 30),
(('Futurology',), 20, 'hot', 'all', 20),
(('Futurology', 'Python'), 10, 'hot', 'all', 20),
(('Futurology',), 100, 'hot', 'all', 100),
(('Futurology',), 0, 'hot', 'all', 0),
(('Futurology',), 10, 'top', 'all', 10),
(('Futurology',), 10, 'top', 'week', 10),
(('Futurology',), 10, 'hot', 'week', 10),
))
def test_get_subreddit_normal(
test_subreddits: list[str],
limit: int,
sort_type: str,
time_filter: str,
max_expected_len: int,
downloader_mock: MagicMock,
reddit_instance: praw.Reddit,
):
downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
downloader_mock.args.limit = limit
downloader_mock.args.sort = sort_type
downloader_mock.args.subreddit = test_subreddits
downloader_mock.reddit_instance = reddit_instance
downloader_mock.sort_filter = RedditDownloader._create_sort_filter(downloader_mock)
results = RedditDownloader._get_subreddits(downloader_mock)
test_subreddits = downloader_mock._split_args_input(test_subreddits)
results = [sub for res1 in results for sub in res1]
assert all([isinstance(res1, praw.models.Submission) for res1 in results])
assert all([res.subreddit.display_name in test_subreddits for res in results])
assert len(results) <= max_expected_len
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize(('test_subreddits', 'search_term', 'limit', 'time_filter', 'max_expected_len'), (
(('Python',), 'scraper', 10, 'all', 10),
(('Python',), '', 10, 'all', 10),
(('Python',), 'djsdsgewef', 10, 'all', 0),
(('Python',), 'scraper', 10, 'year', 10),
(('Python',), 'scraper', 10, 'hour', 1),
))
def test_get_subreddit_search(
test_subreddits: list[str],
search_term: str,
time_filter: str,
limit: int,
max_expected_len: int,
downloader_mock: MagicMock,
reddit_instance: praw.Reddit,
):
downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
downloader_mock.args.limit = limit
downloader_mock.args.search = search_term
downloader_mock.args.subreddit = test_subreddits
downloader_mock.reddit_instance = reddit_instance
downloader_mock.sort_filter = RedditTypes.SortType.HOT
downloader_mock.args.time = time_filter
downloader_mock.time_filter = RedditDownloader._create_time_filter(downloader_mock)
results = RedditDownloader._get_subreddits(downloader_mock)
results = [sub for res in results for sub in res]
assert all([isinstance(res, praw.models.Submission) for res in results])
assert all([res.subreddit.display_name in test_subreddits for res in results])
assert len(results) <= max_expected_len
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize(('test_user', 'test_multireddits', 'limit'), (
('helen_darten', ('cuteanimalpics',), 10),
('korfor', ('chess',), 100),
))
# Good sources at https://www.reddit.com/r/multihub/
def test_get_multireddits_public(
test_user: str,
test_multireddits: list[str],
limit: int,
reddit_instance: praw.Reddit,
@patch('bdfr.site_downloaders.download_factory.DownloadFactory.pull_lever')
def test_excluded_ids(
mock_function: MagicMock,
test_ids: tuple[str],
test_excluded: tuple[str],
expected_len: int,
downloader_mock: MagicMock,
):
downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
downloader_mock.sort_filter = RedditTypes.SortType.HOT
downloader_mock.args.limit = limit
downloader_mock.args.multireddit = test_multireddits
downloader_mock.args.user = test_user
downloader_mock.reddit_instance = reddit_instance
downloader_mock._create_filtered_listing_generator.return_value = \
RedditDownloader._create_filtered_listing_generator(
downloader_mock,
reddit_instance.multireddit(test_user, test_multireddits[0]),
)
results = RedditDownloader._get_multireddits(downloader_mock)
results = [sub for res in results for sub in res]
assert all([isinstance(res, praw.models.Submission) for res in results])
assert len(results) == limit
downloader_mock.excluded_submission_ids = test_excluded
mock_function.return_value = MagicMock()
mock_function.return_value.__name__ = 'test'
test_submissions = []
for test_id in test_ids:
m = MagicMock()
m.id = test_id
m.subreddit.display_name.return_value = 'https://www.example.com/'
m.__class__ = praw.models.Submission
test_submissions.append(m)
downloader_mock.reddit_lists = [test_submissions]
for submission in test_submissions:
RedditDownloader._download_submission(downloader_mock, submission)
assert mock_function.call_count == expected_len
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize(('test_user', 'limit'), (
('danigirl3694', 10),
('danigirl3694', 50),
('CapitanHam', None),
@pytest.mark.parametrize('test_submission_id', (
'm1hqw6',
))
def test_get_user_submissions(test_user: str, limit: int, downloader_mock: MagicMock, reddit_instance: praw.Reddit):
downloader_mock.args.limit = limit
downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
downloader_mock.sort_filter = RedditTypes.SortType.HOT
downloader_mock.args.submitted = True
downloader_mock.args.user = test_user
downloader_mock.authenticated = False
downloader_mock.reddit_instance = reddit_instance
downloader_mock._create_filtered_listing_generator.return_value = \
RedditDownloader._create_filtered_listing_generator(
downloader_mock,
reddit_instance.redditor(test_user).submissions,
)
results = RedditDownloader._get_user_data(downloader_mock)
results = assert_all_results_are_submissions(limit, results)
assert all([res.author.name == test_user for res in results])
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.authenticated
@pytest.mark.parametrize('test_flag', (
'upvoted',
'saved',
))
def test_get_user_authenticated_lists(
test_flag: str,
downloader_mock: MagicMock,
authenticated_reddit_instance: praw.Reddit,
):
downloader_mock.args.__dict__[test_flag] = True
downloader_mock.reddit_instance = authenticated_reddit_instance
downloader_mock.args.user = 'me'
downloader_mock.args.limit = 10
downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
downloader_mock.sort_filter = RedditTypes.SortType.HOT
RedditDownloader._resolve_user_name(downloader_mock)
results = RedditDownloader._get_user_data(downloader_mock)
assert_all_results_are_submissions(10, results)
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize(('test_submission_id', 'expected_files_len'), (
('ljyy27', 4),
))
def test_download_submission(
def test_mark_hard_link(
test_submission_id: str,
expected_files_len: int,
downloader_mock: MagicMock,
reddit_instance: praw.Reddit,
tmp_path: Path):
tmp_path: Path,
reddit_instance: praw.Reddit
):
downloader_mock.reddit_instance = reddit_instance
downloader_mock.download_filter.check_url.return_value = True
downloader_mock.args.folder_scheme = ''
downloader_mock.file_name_formatter = RedditDownloader._create_file_name_formatter(downloader_mock)
downloader_mock.args.make_hard_links = True
downloader_mock.download_directory = tmp_path
downloader_mock.args.folder_scheme = ''
downloader_mock.args.file_scheme = '{POSTID}'
downloader_mock.file_name_formatter = RedditConnector.create_file_name_formatter(downloader_mock)
submission = downloader_mock.reddit_instance.submission(id=test_submission_id)
original = Path(tmp_path, f'{test_submission_id}.png')
RedditDownloader._download_submission(downloader_mock, submission)
folder_contents = list(tmp_path.iterdir())
assert len(folder_contents) == expected_files_len
assert original.exists()
downloader_mock.args.file_scheme = 'test2_{POSTID}'
downloader_mock.file_name_formatter = RedditConnector.create_file_name_formatter(downloader_mock)
RedditDownloader._download_submission(downloader_mock, submission)
test_file_1_stats = original.stat()
test_file_2_inode = Path(tmp_path, f'test2_{test_submission_id}.png').stat().st_ino
assert test_file_1_stats.st_nlink == 2
assert test_file_1_stats.st_ino == test_file_2_inode
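In isolation, the hard-link assertions above verify that with `--make-hard-links` a duplicate resource becomes a second directory entry for the same inode. A self-contained sketch of that check:

```python
import os
import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp:
    original = Path(tmp, 'm1hqw6.png')
    original.write_bytes(b'example data')
    link = Path(tmp, 'test2_m1hqw6.png')
    os.link(original, link)  # a hard link: same inode, two names
    assert original.stat().st_nlink == 2
    assert original.stat().st_ino == link.stat().st_ino
```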
@pytest.mark.online
@pytest.mark.reddit
def test_download_submission_file_exists(
@pytest.mark.parametrize(('test_submission_id', 'test_creation_date'), (
('ndzz50', 1621204841.0),
))
def test_file_creation_date(
test_submission_id: str,
test_creation_date: float,
downloader_mock: MagicMock,
reddit_instance: praw.Reddit,
tmp_path: Path,
capsys: pytest.CaptureFixture
reddit_instance: praw.Reddit
):
setup_logging(3)
downloader_mock.reddit_instance = reddit_instance
downloader_mock.download_filter.check_url.return_value = True
downloader_mock.args.folder_scheme = ''
downloader_mock.file_name_formatter = RedditDownloader._create_file_name_formatter(downloader_mock)
downloader_mock.download_directory = tmp_path
submission = downloader_mock.reddit_instance.submission(id='m1hqw6')
Path(tmp_path, 'Arneeman_Metagaming isn\'t always a bad thing_m1hqw6.png').touch()
downloader_mock.args.folder_scheme = ''
downloader_mock.args.file_scheme = '{POSTID}'
downloader_mock.file_name_formatter = RedditConnector.create_file_name_formatter(downloader_mock)
submission = downloader_mock.reddit_instance.submission(id=test_submission_id)
RedditDownloader._download_submission(downloader_mock, submission)
folder_contents = list(tmp_path.iterdir())
output = capsys.readouterr()
assert len(folder_contents) == 1
assert 'Arneeman_Metagaming isn\'t always a bad thing_m1hqw6.png already exists' in output.out
for file_path in Path(tmp_path).iterdir():
file_stats = os.stat(file_path)
assert file_stats.st_mtime == test_creation_date
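The creation-date test checks that the downloaded file's mtime matches the submission's `created_utc`, presumably set via something like `os.utime`. A self-contained sketch of that mechanism:

```python
import os
import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp:
    file_path = Path(tmp, 'ndzz50.png')
    file_path.touch()
    os.utime(file_path, (1621204841.0, 1621204841.0))  # (atime, mtime)
    assert file_path.stat().st_mtime == 1621204841.0
```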
def test_search_existing_files():
results = RedditDownloader.scan_existing_files(Path('.'))
assert len(results.keys()) != 0
@pytest.mark.online
@@ -358,7 +146,7 @@ def test_download_submission_hash_exists(
downloader_mock.download_filter.check_url.return_value = True
downloader_mock.args.folder_scheme = ''
downloader_mock.args.no_dupes = True
downloader_mock.file_name_formatter = RedditDownloader._create_file_name_formatter(downloader_mock)
downloader_mock.file_name_formatter = RedditConnector.create_file_name_formatter(downloader_mock)
downloader_mock.download_directory = tmp_path
downloader_mock.master_hash_list = {test_hash: None}
submission = downloader_mock.reddit_instance.submission(id=test_submission_id)
@@ -369,165 +157,47 @@ def test_download_submission_hash_exists(
assert re.search(r'Resource hash .*? downloaded elsewhere', output.out)
@pytest.mark.parametrize(('test_name', 'expected'), (
('Mindustry', 'Mindustry'),
('Futurology', 'Futurology'),
('r/Mindustry', 'Mindustry'),
('TrollXChromosomes', 'TrollXChromosomes'),
('r/TrollXChromosomes', 'TrollXChromosomes'),
('https://www.reddit.com/r/TrollXChromosomes/', 'TrollXChromosomes'),
('https://www.reddit.com/r/TrollXChromosomes', 'TrollXChromosomes'),
('https://www.reddit.com/r/Futurology/', 'Futurology'),
('https://www.reddit.com/r/Futurology', 'Futurology'),
))
def test_sanitise_subreddit_name(test_name: str, expected: str):
result = RedditDownloader._sanitise_subreddit_name(test_name)
assert result == expected
def test_search_existing_files():
results = RedditDownloader.scan_existing_files(Path('.'))
assert len(results.keys()) >= 40
@pytest.mark.parametrize(('test_subreddit_entries', 'expected'), (
(['test1', 'test2', 'test3'], {'test1', 'test2', 'test3'}),
(['test1,test2', 'test3'], {'test1', 'test2', 'test3'}),
(['test1, test2', 'test3'], {'test1', 'test2', 'test3'}),
(['test1; test2', 'test3'], {'test1', 'test2', 'test3'}),
(['test1, test2', 'test1,test2,test3', 'test4'], {'test1', 'test2', 'test3', 'test4'})
))
def test_split_subreddit_entries(test_subreddit_entries: list[str], expected: set[str]):
results = RedditDownloader._split_args_input(test_subreddit_entries)
assert results == expected
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize('test_submission_id', (
'm1hqw6',
))
def test_mark_hard_link(
test_submission_id: str,
def test_download_submission_file_exists(
downloader_mock: MagicMock,
reddit_instance: praw.Reddit,
tmp_path: Path,
reddit_instance: praw.Reddit
capsys: pytest.CaptureFixture
):
setup_logging(3)
downloader_mock.reddit_instance = reddit_instance
downloader_mock.args.make_hard_links = True
downloader_mock.download_directory = tmp_path
downloader_mock.download_filter.check_url.return_value = True
downloader_mock.args.folder_scheme = ''
downloader_mock.args.file_scheme = '{POSTID}'
downloader_mock.file_name_formatter = RedditDownloader._create_file_name_formatter(downloader_mock)
downloader_mock.file_name_formatter = RedditConnector.create_file_name_formatter(downloader_mock)
downloader_mock.download_directory = tmp_path
submission = downloader_mock.reddit_instance.submission(id='m1hqw6')
Path(tmp_path, 'Arneeman_Metagaming isn\'t always a bad thing_m1hqw6.png').touch()
RedditDownloader._download_submission(downloader_mock, submission)
folder_contents = list(tmp_path.iterdir())
output = capsys.readouterr()
assert len(folder_contents) == 1
assert 'Arneeman_Metagaming isn\'t always a bad thing_m1hqw6.png'\
' from submission m1hqw6 already exists' in output.out
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize(('test_submission_id', 'expected_files_len'), (
('ljyy27', 4),
))
def test_download_submission(
test_submission_id: str,
expected_files_len: int,
downloader_mock: MagicMock,
reddit_instance: praw.Reddit,
tmp_path: Path):
downloader_mock.reddit_instance = reddit_instance
downloader_mock.download_filter.check_url.return_value = True
downloader_mock.args.folder_scheme = ''
downloader_mock.file_name_formatter = RedditConnector.create_file_name_formatter(downloader_mock)
downloader_mock.download_directory = tmp_path
submission = downloader_mock.reddit_instance.submission(id=test_submission_id)
original = Path(tmp_path, f'{test_submission_id}.png')
RedditDownloader._download_submission(downloader_mock, submission)
assert original.exists()
downloader_mock.args.file_scheme = 'test2_{POSTID}'
downloader_mock.file_name_formatter = RedditDownloader._create_file_name_formatter(downloader_mock)
RedditDownloader._download_submission(downloader_mock, submission)
test_file_1_stats = original.stat()
test_file_2_inode = Path(tmp_path, f'test2_{test_submission_id}.png').stat().st_ino
assert test_file_1_stats.st_nlink == 2
assert test_file_1_stats.st_ino == test_file_2_inode
@pytest.mark.parametrize(('test_ids', 'test_excluded', 'expected_len'), (
(('aaaaaa',), (), 1),
(('aaaaaa',), ('aaaaaa',), 0),
((), ('aaaaaa',), 0),
(('aaaaaa', 'bbbbbb'), ('aaaaaa',), 1),
))
def test_excluded_ids(test_ids: tuple[str], test_excluded: tuple[str], expected_len: int, downloader_mock: MagicMock):
downloader_mock.excluded_submission_ids = test_excluded
test_submissions = []
for test_id in test_ids:
m = MagicMock()
m.id = test_id
test_submissions.append(m)
downloader_mock.reddit_lists = [test_submissions]
RedditDownloader.download(downloader_mock)
assert downloader_mock._download_submission.call_count == expected_len
def test_read_excluded_submission_ids_from_file(downloader_mock: MagicMock, tmp_path: Path):
test_file = tmp_path / 'test.txt'
test_file.write_text('aaaaaa\nbbbbbb')
downloader_mock.args.exclude_id_file = [test_file]
results = RedditDownloader._read_excluded_ids(downloader_mock)
assert results == {'aaaaaa', 'bbbbbb'}
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize('test_redditor_name', (
'Paracortex',
'crowdstrike',
'HannibalGoddamnit',
))
def test_check_user_existence_good(
test_redditor_name: str,
reddit_instance: praw.Reddit,
downloader_mock: MagicMock,
):
downloader_mock.reddit_instance = reddit_instance
RedditDownloader._check_user_existence(downloader_mock, test_redditor_name)
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize('test_redditor_name', (
'lhnhfkuhwreolo',
'adlkfmnhglojh',
))
def test_check_user_existence_nonexistent(
test_redditor_name: str,
reddit_instance: praw.Reddit,
downloader_mock: MagicMock,
):
downloader_mock.reddit_instance = reddit_instance
with pytest.raises(BulkDownloaderException, match='Could not find'):
RedditDownloader._check_user_existence(downloader_mock, test_redditor_name)
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize('test_redditor_name', (
'Bree-Boo',
))
def test_check_user_existence_banned(
test_redditor_name: str,
reddit_instance: praw.Reddit,
downloader_mock: MagicMock,
):
downloader_mock.reddit_instance = reddit_instance
with pytest.raises(BulkDownloaderException, match='is banned'):
RedditDownloader._check_user_existence(downloader_mock, test_redditor_name)
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize(('test_subreddit_name', 'expected_message'), (
('donaldtrump', 'cannot be found'),
('submitters', 'private and cannot be scraped')
))
def test_check_subreddit_status_bad(test_subreddit_name: str, expected_message: str, reddit_instance: praw.Reddit):
test_subreddit = reddit_instance.subreddit(test_subreddit_name)
with pytest.raises(BulkDownloaderException, match=expected_message):
RedditDownloader._check_subreddit_status(test_subreddit)
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize('test_subreddit_name', (
'Python',
'Mindustry',
'TrollXChromosomes',
'all',
))
def test_check_subreddit_status_good(test_subreddit_name: str, reddit_instance: praw.Reddit):
test_subreddit = reddit_instance.subreddit(test_subreddit_name)
RedditDownloader._check_subreddit_status(test_subreddit)
folder_contents = list(tmp_path.iterdir())
assert len(folder_contents) == expected_files_len


@@ -1,11 +1,12 @@
#!/usr/bin/env python3
# coding=utf-8
import platform
import unittest.mock
from datetime import datetime
from pathlib import Path
from typing import Optional
from unittest.mock import MagicMock
import platform
import praw.models
import pytest
@@ -28,10 +29,10 @@ def submission() -> MagicMock:
return test
def do_test_string_equality(result: str, expected: str) -> bool:
def do_test_string_equality(result: [Path, str], expected: str) -> bool:
if platform.system() == 'Windows':
expected = FileNameFormatter._format_for_windows(expected)
return expected == result
return str(result).endswith(expected)
def do_test_path_equality(result: Path, expected: str) -> bool:
@@ -41,7 +42,7 @@ def do_test_path_equality(result: Path, expected: str) -> bool:
expected = Path(*expected)
else:
expected = Path(expected)
return result == expected
return str(result).endswith(str(expected))
@pytest.fixture(scope='session')
@@ -172,8 +173,9 @@ def test_format_multiple_resources():
mocks.append(new_mock)
test_formatter = FileNameFormatter('{TITLE}', '', 'ISO')
results = test_formatter.format_resource_paths(mocks, Path('.'))
results = set([str(res[0]) for res in results])
assert results == {'test_1.png', 'test_2.png', 'test_3.png', 'test_4.png'}
results = set([str(res[0].name) for res in results])
expected = {'test_1.png', 'test_2.png', 'test_3.png', 'test_4.png'}
assert results == expected
@pytest.mark.parametrize(('test_filename', 'test_ending'), (
@@ -183,10 +185,11 @@ def test_format_multiple_resources():
('😍💕✨' * 100, '_1.png'),
))
def test_limit_filename_length(test_filename: str, test_ending: str):
result = FileNameFormatter._limit_file_name_length(test_filename, test_ending)
assert len(result) <= 255
assert len(result.encode('utf-8')) <= 255
assert isinstance(result, str)
result = FileNameFormatter._limit_file_name_length(test_filename, test_ending, Path('.'))
assert len(result.name) <= 255
assert len(result.name.encode('utf-8')) <= 255
assert len(str(result)) <= FileNameFormatter.find_max_path_length()
assert isinstance(result, Path)
@pytest.mark.parametrize(('test_filename', 'test_ending', 'expected_end'), (
@@ -201,11 +204,11 @@ def test_limit_filename_length(test_filename: str, test_ending: str):
('😍💕✨' * 100 + '_aaa1aa', '_1.png', '_aaa1aa_1.png'),
))
def test_preserve_id_append_when_shortening(test_filename: str, test_ending: str, expected_end: str):
result = FileNameFormatter._limit_file_name_length(test_filename, test_ending)
assert len(result) <= 255
assert len(result.encode('utf-8')) <= 255
assert isinstance(result, str)
assert result.endswith(expected_end)
result = FileNameFormatter._limit_file_name_length(test_filename, test_ending, Path('.'))
assert len(result.name) <= 255
assert len(result.name.encode('utf-8')) <= 255
assert result.name.endswith(expected_end)
assert len(str(result)) <= FileNameFormatter.find_max_path_length()
def test_shorten_filenames(submission: MagicMock, tmp_path: Path):
@@ -295,7 +298,7 @@ def test_format_archive_entry_comment(
test_formatter = FileNameFormatter(test_file_scheme, test_folder_scheme, 'ISO')
test_entry = Resource(test_comment, '', '.json')
result = test_formatter.format_path(test_entry, tmp_path)
assert do_test_string_equality(result.name, expected_name)
assert do_test_string_equality(result, expected_name)
@pytest.mark.parametrize(('test_folder_scheme', 'expected'), (
@@ -364,3 +367,16 @@ def test_time_string_formats(test_time_format: str, expected: str):
test_formatter = FileNameFormatter('{TITLE}', '', test_time_format)
result = test_formatter._convert_timestamp(test_time.timestamp())
assert result == expected
def test_get_max_path_length():
result = FileNameFormatter.find_max_path_length()
assert result in (4096, 260, 1024)
def test_windows_max_path(tmp_path: Path):
with unittest.mock.patch('platform.system', return_value='Windows'):
with unittest.mock.patch('bdfr.file_name_formatter.FileNameFormatter.find_max_path_length', return_value=260):
result = FileNameFormatter._limit_file_name_length('test' * 100, '_1.png', tmp_path)
assert len(str(result)) <= 260
assert len(result.name) <= (260 - len(str(tmp_path)))
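These tests pin down a two-part budget: the file name must fit in 255 bytes once UTF-8 encoded, and the full path must fit the platform maximum (260 when Windows is mocked). A small illustrative check of that budget:

```python
from pathlib import Path


def fits_path_limits(path: Path, max_path_length: int = 260) -> bool:
    # Name limited to 255 bytes, whole path limited to the platform maximum.
    return (len(path.name.encode('utf-8')) <= 255
            and len(str(path)) <= max_path_length)


print(fits_path_limits(Path('C:/downloads', 'a' * 300 + '_1.png')))  # False: name too long
```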


@@ -33,6 +33,17 @@ def create_basic_args_for_archive_runner(test_args: list[str], tmp_path: Path):
return out
def create_basic_args_for_cloner_runner(test_args: list[str], tmp_path: Path):
out = [
'clone',
str(tmp_path),
'-v',
'--config', 'test_config.cfg',
'--log', str(Path(tmp_path, 'test_log.txt')),
] + test_args
return out
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests')
@@ -117,6 +128,7 @@ def test_cli_download_multireddit_nonexistent(test_args: list[str], tmp_path: Pa
@pytest.mark.authenticated
@pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests')
@pytest.mark.parametrize('test_args', (
['--user', 'djnish', '--submitted', '--user', 'FriesWithThat', '-L', 10],
['--user', 'me', '--upvoted', '--authenticate', '-L', 10],
['--user', 'me', '--saved', '--authenticate', '-L', 10],
['--user', 'me', '--submitted', '--authenticate', '-L', 10],
@@ -231,6 +243,7 @@ def test_cli_archive_subreddit(test_args: list[str], tmp_path: Path):
@pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests')
@pytest.mark.parametrize('test_args', (
['--user', 'me', '--authenticate', '--all-comments', '-L', '10'],
['--user', 'me', '--user', 'djnish', '--authenticate', '--all-comments', '-L', '10'],
))
def test_cli_archive_all_user_comments(test_args: list[str], tmp_path: Path):
runner = CliRunner()
@@ -265,12 +278,14 @@ def test_cli_archive_long(test_args: list[str], tmp_path: Path):
['--user', 'sdclhgsolgjeroij', '--upvoted', '-L', 10],
['--subreddit', 'submitters', '-L', 10], # Private subreddit
['--subreddit', 'donaldtrump', '-L', 10], # Banned subreddit
['--user', 'djnish', '--user', 'helen_darten', '-m', 'cuteanimalpics', '-L', 10],
))
def test_cli_download_soft_fail(test_args: list[str], tmp_path: Path):
runner = CliRunner()
test_args = create_basic_args_for_download_runner(test_args, tmp_path)
result = runner.invoke(cli, test_args)
assert result.exit_code == 0
assert 'Downloaded' not in result.output
@pytest.mark.online
@@ -333,9 +348,42 @@ def test_cli_download_subreddit_exclusion(test_args: list[str], tmp_path: Path):
['--file-scheme', '{TITLE}'],
['--file-scheme', '{TITLE}_test_{SUBREDDIT}'],
))
def test_cli_file_scheme_warning(test_args: list[str], tmp_path: Path):
def test_cli_download_file_scheme_warning(test_args: list[str], tmp_path: Path):
runner = CliRunner()
test_args = create_basic_args_for_download_runner(test_args, tmp_path)
result = runner.invoke(cli, test_args)
assert result.exit_code == 0
assert 'Some files might not be downloaded due to name conflicts' in result.output
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests')
@pytest.mark.parametrize('test_args', (
['-l', 'm2601g', '--disable-module', 'Direct'],
['-l', 'nnb9vs', '--disable-module', 'YoutubeDlFallback'],
['-l', 'nnb9vs', '--disable-module', 'youtubedlfallback'],
))
def test_cli_download_disable_modules(test_args: list[str], tmp_path: Path):
runner = CliRunner()
test_args = create_basic_args_for_download_runner(test_args, tmp_path)
result = runner.invoke(cli, test_args)
assert result.exit_code == 0
assert 'skipped due to disabled module' in result.output
assert 'Downloaded submission' not in result.output
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests')
@pytest.mark.parametrize('test_args', (
['-l', 'm2601g'],
['-s', 'TrollXChromosomes/', '-L', 1],
))
def test_cli_scrape_general(test_args: list[str], tmp_path: Path):
runner = CliRunner()
test_args = create_basic_args_for_cloner_runner(test_args, tmp_path)
result = runner.invoke(cli, test_args)
assert result.exit_code == 0
assert 'Downloaded submission' in result.output
assert 'Record for entry item' in result.output
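Finally, the new clone integration test drives the CLI the same way as the other runners. A minimal sketch of invoking it programmatically, assuming `cli` is the Click entry point in `bdfr.__main__`, as the tests' `runner.invoke(cli, ...)` calls imply:

```python
from click.testing import CliRunner

from bdfr.__main__ import cli  # assumed entry point, matching the tests above

# Mirrors create_basic_args_for_cloner_runner: clone a single submission
# into ./output with verbose logging.
runner = CliRunner()
result = runner.invoke(cli, ['clone', './output', '-v', '-l', 'm2601g'])
print(result.exit_code, result.output)
```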