
Format according to the black standard

Serene-Arc 2022-12-03 15:11:17 +10:00
parent 96cd7d7147
commit 0873a4a2b2
60 changed files with 2160 additions and 1790 deletions
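The changes in this commit are mechanical: black normalizes string literals to double quotes, splits statements that exceed the configured line length, and rewraps multi-line calls with explicit closing parentheses on their own lines. As an illustrative sketch of the rewrite (a hypothetical snippet, not taken verbatim from the diff below):

import click

# Before black: single-quoted strings on one long line
sort_option = click.option('-S', '--sort', type=click.Choice(('hot', 'top', 'new')), default=None)

# After black: double quotes, and an over-long call is exploded across lines
sort_option = click.option(
    "-S", "--sort", type=click.Choice(("hot", "top", "new")), default=None
)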

View file

@@ -13,53 +13,54 @@ from bdfr.downloader import RedditDownloader
 logger = logging.getLogger()

 _common_options = [
-    click.argument('directory', type=str),
-    click.option('--authenticate', is_flag=True, default=None),
-    click.option('--config', type=str, default=None),
-    click.option('--opts', type=str, default=None),
-    click.option('--disable-module', multiple=True, default=None, type=str),
-    click.option('--exclude-id', default=None, multiple=True),
-    click.option('--exclude-id-file', default=None, multiple=True),
-    click.option('--file-scheme', default=None, type=str),
-    click.option('--folder-scheme', default=None, type=str),
-    click.option('--ignore-user', type=str, multiple=True, default=None),
-    click.option('--include-id-file', multiple=True, default=None),
-    click.option('--log', type=str, default=None),
-    click.option('--saved', is_flag=True, default=None),
-    click.option('--search', default=None, type=str),
-    click.option('--submitted', is_flag=True, default=None),
-    click.option('--subscribed', is_flag=True, default=None),
-    click.option('--time-format', type=str, default=None),
-    click.option('--upvoted', is_flag=True, default=None),
-    click.option('-L', '--limit', default=None, type=int),
-    click.option('-l', '--link', multiple=True, default=None, type=str),
-    click.option('-m', '--multireddit', multiple=True, default=None, type=str),
-    click.option('-S', '--sort', type=click.Choice(('hot', 'top', 'new', 'controversial', 'rising', 'relevance')),
-                 default=None),
-    click.option('-s', '--subreddit', multiple=True, default=None, type=str),
-    click.option('-t', '--time', type=click.Choice(('all', 'hour', 'day', 'week', 'month', 'year')), default=None),
-    click.option('-u', '--user', type=str, multiple=True, default=None),
-    click.option('-v', '--verbose', default=None, count=True),
+    click.argument("directory", type=str),
+    click.option("--authenticate", is_flag=True, default=None),
+    click.option("--config", type=str, default=None),
+    click.option("--opts", type=str, default=None),
+    click.option("--disable-module", multiple=True, default=None, type=str),
+    click.option("--exclude-id", default=None, multiple=True),
+    click.option("--exclude-id-file", default=None, multiple=True),
+    click.option("--file-scheme", default=None, type=str),
+    click.option("--folder-scheme", default=None, type=str),
+    click.option("--ignore-user", type=str, multiple=True, default=None),
+    click.option("--include-id-file", multiple=True, default=None),
+    click.option("--log", type=str, default=None),
+    click.option("--saved", is_flag=True, default=None),
+    click.option("--search", default=None, type=str),
+    click.option("--submitted", is_flag=True, default=None),
+    click.option("--subscribed", is_flag=True, default=None),
+    click.option("--time-format", type=str, default=None),
+    click.option("--upvoted", is_flag=True, default=None),
+    click.option("-L", "--limit", default=None, type=int),
+    click.option("-l", "--link", multiple=True, default=None, type=str),
+    click.option("-m", "--multireddit", multiple=True, default=None, type=str),
+    click.option(
+        "-S", "--sort", type=click.Choice(("hot", "top", "new", "controversial", "rising", "relevance")), default=None
+    ),
+    click.option("-s", "--subreddit", multiple=True, default=None, type=str),
+    click.option("-t", "--time", type=click.Choice(("all", "hour", "day", "week", "month", "year")), default=None),
+    click.option("-u", "--user", type=str, multiple=True, default=None),
+    click.option("-v", "--verbose", default=None, count=True),
 ]

 _downloader_options = [
-    click.option('--make-hard-links', is_flag=True, default=None),
-    click.option('--max-wait-time', type=int, default=None),
-    click.option('--no-dupes', is_flag=True, default=None),
-    click.option('--search-existing', is_flag=True, default=None),
-    click.option('--skip', default=None, multiple=True),
-    click.option('--skip-domain', default=None, multiple=True),
-    click.option('--skip-subreddit', default=None, multiple=True),
-    click.option('--min-score', type=int, default=None),
-    click.option('--max-score', type=int, default=None),
-    click.option('--min-score-ratio', type=float, default=None),
-    click.option('--max-score-ratio', type=float, default=None),
+    click.option("--make-hard-links", is_flag=True, default=None),
+    click.option("--max-wait-time", type=int, default=None),
+    click.option("--no-dupes", is_flag=True, default=None),
+    click.option("--search-existing", is_flag=True, default=None),
+    click.option("--skip", default=None, multiple=True),
+    click.option("--skip-domain", default=None, multiple=True),
+    click.option("--skip-subreddit", default=None, multiple=True),
+    click.option("--min-score", type=int, default=None),
+    click.option("--max-score", type=int, default=None),
+    click.option("--min-score-ratio", type=float, default=None),
+    click.option("--max-score-ratio", type=float, default=None),
 ]

 _archiver_options = [
-    click.option('--all-comments', is_flag=True, default=None),
-    click.option('--comment-context', is_flag=True, default=None),
-    click.option('-f', '--format', type=click.Choice(('xml', 'json', 'yaml')), default=None),
+    click.option("--all-comments", is_flag=True, default=None),
+    click.option("--comment-context", is_flag=True, default=None),
+    click.option("-f", "--format", type=click.Choice(("xml", "json", "yaml")), default=None),
 ]

@@ -68,6 +69,7 @@ def _add_options(opts: list):
         for opt in opts:
             func = opt(func)
         return func
+
     return wrap

@@ -76,7 +78,7 @@ def cli():
     pass


-@cli.command('download')
+@cli.command("download")
 @_add_options(_common_options)
 @_add_options(_downloader_options)
 @click.pass_context

@@ -88,13 +90,13 @@ def cli_download(context: click.Context, **_):
         reddit_downloader = RedditDownloader(config)
         reddit_downloader.download()
     except Exception:
-        logger.exception('Downloader exited unexpectedly')
+        logger.exception("Downloader exited unexpectedly")
         raise
     else:
-        logger.info('Program complete')
+        logger.info("Program complete")


-@cli.command('archive')
+@cli.command("archive")
 @_add_options(_common_options)
 @_add_options(_archiver_options)
 @click.pass_context

@@ -106,13 +108,13 @@ def cli_archive(context: click.Context, **_):
         reddit_archiver = Archiver(config)
         reddit_archiver.download()
     except Exception:
-        logger.exception('Archiver exited unexpectedly')
+        logger.exception("Archiver exited unexpectedly")
         raise
     else:
-        logger.info('Program complete')
+        logger.info("Program complete")


-@cli.command('clone')
+@cli.command("clone")
 @_add_options(_common_options)
 @_add_options(_archiver_options)
 @_add_options(_downloader_options)

@@ -125,10 +127,10 @@ def cli_clone(context: click.Context, **_):
         reddit_scraper = RedditCloner(config)
         reddit_scraper.download()
     except Exception:
-        logger.exception('Scraper exited unexpectedly')
+        logger.exception("Scraper exited unexpectedly")
         raise
     else:
-        logger.info('Program complete')
+        logger.info("Program complete")


 def setup_logging(verbosity: int):

@@ -141,7 +143,7 @@ def setup_logging(verbosity: int):
     stream = logging.StreamHandler(sys.stdout)
     stream.addFilter(StreamExceptionFilter())
-    formatter = logging.Formatter('[%(asctime)s - %(name)s - %(levelname)s] - %(message)s')
+    formatter = logging.Formatter("[%(asctime)s - %(name)s - %(levelname)s] - %(message)s")
     stream.setFormatter(formatter)
     logger.addHandler(stream)

@@ -151,10 +153,10 @@
         stream.setLevel(logging.DEBUG)
     else:
         stream.setLevel(9)
-    logging.getLogger('praw').setLevel(logging.CRITICAL)
-    logging.getLogger('prawcore').setLevel(logging.CRITICAL)
-    logging.getLogger('urllib3').setLevel(logging.CRITICAL)
+    logging.getLogger("praw").setLevel(logging.CRITICAL)
+    logging.getLogger("prawcore").setLevel(logging.CRITICAL)
+    logging.getLogger("urllib3").setLevel(logging.CRITICAL)


-if __name__ == '__main__':
+if __name__ == "__main__":
     cli()

View file

@@ -19,21 +19,21 @@ class BaseArchiveEntry(ABC):
     @staticmethod
     def _convert_comment_to_dict(in_comment: Comment) -> dict:
         out_dict = {
-            'author': in_comment.author.name if in_comment.author else 'DELETED',
-            'id': in_comment.id,
-            'score': in_comment.score,
-            'subreddit': in_comment.subreddit.display_name,
-            'author_flair': in_comment.author_flair_text,
-            'submission': in_comment.submission.id,
-            'stickied': in_comment.stickied,
-            'body': in_comment.body,
-            'is_submitter': in_comment.is_submitter,
-            'distinguished': in_comment.distinguished,
-            'created_utc': in_comment.created_utc,
-            'parent_id': in_comment.parent_id,
-            'replies': [],
+            "author": in_comment.author.name if in_comment.author else "DELETED",
+            "id": in_comment.id,
+            "score": in_comment.score,
+            "subreddit": in_comment.subreddit.display_name,
+            "author_flair": in_comment.author_flair_text,
+            "submission": in_comment.submission.id,
+            "stickied": in_comment.stickied,
+            "body": in_comment.body,
+            "is_submitter": in_comment.is_submitter,
+            "distinguished": in_comment.distinguished,
+            "created_utc": in_comment.created_utc,
+            "parent_id": in_comment.parent_id,
+            "replies": [],
         }
         in_comment.replies.replace_more(limit=None)
         for reply in in_comment.replies:
-            out_dict['replies'].append(BaseArchiveEntry._convert_comment_to_dict(reply))
+            out_dict["replies"].append(BaseArchiveEntry._convert_comment_to_dict(reply))
         return out_dict

View file

@@ -17,5 +17,5 @@ class CommentArchiveEntry(BaseArchiveEntry):
     def compile(self) -> dict:
         self.source.refresh()
         self.post_details = self._convert_comment_to_dict(self.source)
-        self.post_details['submission_title'] = self.source.submission.title
+        self.post_details["submission_title"] = self.source.submission.title
         return self.post_details

View file

@@ -18,32 +18,32 @@ class SubmissionArchiveEntry(BaseArchiveEntry):
         comments = self._get_comments()
         self._get_post_details()
         out = self.post_details
-        out['comments'] = comments
+        out["comments"] = comments
         return out

     def _get_post_details(self):
         self.post_details = {
-            'title': self.source.title,
-            'name': self.source.name,
-            'url': self.source.url,
-            'selftext': self.source.selftext,
-            'score': self.source.score,
-            'upvote_ratio': self.source.upvote_ratio,
-            'permalink': self.source.permalink,
-            'id': self.source.id,
-            'author': self.source.author.name if self.source.author else 'DELETED',
-            'link_flair_text': self.source.link_flair_text,
-            'num_comments': self.source.num_comments,
-            'over_18': self.source.over_18,
-            'spoiler': self.source.spoiler,
-            'pinned': self.source.pinned,
-            'locked': self.source.locked,
-            'distinguished': self.source.distinguished,
-            'created_utc': self.source.created_utc,
+            "title": self.source.title,
+            "name": self.source.name,
+            "url": self.source.url,
+            "selftext": self.source.selftext,
+            "score": self.source.score,
+            "upvote_ratio": self.source.upvote_ratio,
+            "permalink": self.source.permalink,
+            "id": self.source.id,
+            "author": self.source.author.name if self.source.author else "DELETED",
+            "link_flair_text": self.source.link_flair_text,
+            "num_comments": self.source.num_comments,
+            "over_18": self.source.over_18,
+            "spoiler": self.source.spoiler,
+            "pinned": self.source.pinned,
+            "locked": self.source.locked,
+            "distinguished": self.source.distinguished,
+            "created_utc": self.source.created_utc,
         }

     def _get_comments(self) -> list[dict]:
-        logger.debug(f'Retrieving full comment tree for submission {self.source.id}')
+        logger.debug(f"Retrieving full comment tree for submission {self.source.id}")
         comments = []
         self.source.comments.replace_more(limit=None)
         for top_level_comment in self.source.comments:

View file

@@ -30,26 +30,28 @@ class Archiver(RedditConnector):
         for generator in self.reddit_lists:
             for submission in generator:
                 try:
-                    if (submission.author and submission.author.name in self.args.ignore_user) or \
-                            (submission.author is None and 'DELETED' in self.args.ignore_user):
+                    if (submission.author and submission.author.name in self.args.ignore_user) or (
+                        submission.author is None and "DELETED" in self.args.ignore_user
+                    ):
                         logger.debug(
-                            f'Submission {submission.id} in {submission.subreddit.display_name} skipped'
-                            f' due to {submission.author.name if submission.author else "DELETED"} being an ignored user')
+                            f"Submission {submission.id} in {submission.subreddit.display_name} skipped"
+                            f' due to {submission.author.name if submission.author else "DELETED"} being an ignored user'
+                        )
                         continue
                     if submission.id in self.excluded_submission_ids:
-                        logger.debug(f'Object {submission.id} in exclusion list, skipping')
+                        logger.debug(f"Object {submission.id} in exclusion list, skipping")
                         continue
-                    logger.debug(f'Attempting to archive submission {submission.id}')
+                    logger.debug(f"Attempting to archive submission {submission.id}")
                     self.write_entry(submission)
                 except prawcore.PrawcoreException as e:
-                    logger.error(f'Submission {submission.id} failed to be archived due to a PRAW exception: {e}')
+                    logger.error(f"Submission {submission.id} failed to be archived due to a PRAW exception: {e}")

     def get_submissions_from_link(self) -> list[list[praw.models.Submission]]:
         supplied_submissions = []
         for sub_id in self.args.link:
             if len(sub_id) == 6:
                 supplied_submissions.append(self.reddit_instance.submission(id=sub_id))
-            elif re.match(r'^\w{7}$', sub_id):
+            elif re.match(r"^\w{7}$", sub_id):
                 supplied_submissions.append(self.reddit_instance.comment(id=sub_id))
             else:
                 supplied_submissions.append(self.reddit_instance.submission(url=sub_id))

@@ -60,7 +62,7 @@ class Archiver(RedditConnector):
         if self.args.user and self.args.all_comments:
             sort = self.determine_sort_function()
             for user in self.args.user:
-                logger.debug(f'Retrieving comments of user {user}')
+                logger.debug(f"Retrieving comments of user {user}")
                 results.append(sort(self.reddit_instance.redditor(user).comments, limit=self.args.limit))
         return results

@@ -71,43 +73,44 @@
         elif isinstance(praw_item, praw.models.Comment):
             return CommentArchiveEntry(praw_item)
         else:
-            raise ArchiverError(f'Factory failed to classify item of type {type(praw_item).__name__}')
+            raise ArchiverError(f"Factory failed to classify item of type {type(praw_item).__name__}")

     def write_entry(self, praw_item: Union[praw.models.Submission, praw.models.Comment]):
         if self.args.comment_context and isinstance(praw_item, praw.models.Comment):
-            logger.debug(f'Converting comment {praw_item.id} to submission {praw_item.submission.id}')
+            logger.debug(f"Converting comment {praw_item.id} to submission {praw_item.submission.id}")
             praw_item = praw_item.submission
         archive_entry = self._pull_lever_entry_factory(praw_item)
-        if self.args.format == 'json':
+        if self.args.format == "json":
             self._write_entry_json(archive_entry)
-        elif self.args.format == 'xml':
+        elif self.args.format == "xml":
             self._write_entry_xml(archive_entry)
-        elif self.args.format == 'yaml':
+        elif self.args.format == "yaml":
             self._write_entry_yaml(archive_entry)
         else:
-            raise ArchiverError(f'Unknown format {self.args.format} given')
-        logger.info(f'Record for entry item {praw_item.id} written to disk')
+            raise ArchiverError(f"Unknown format {self.args.format} given")
+        logger.info(f"Record for entry item {praw_item.id} written to disk")

     def _write_entry_json(self, entry: BaseArchiveEntry):
-        resource = Resource(entry.source, '', lambda: None, '.json')
+        resource = Resource(entry.source, "", lambda: None, ".json")
         content = json.dumps(entry.compile())
         self._write_content_to_disk(resource, content)

     def _write_entry_xml(self, entry: BaseArchiveEntry):
-        resource = Resource(entry.source, '', lambda: None, '.xml')
-        content = dict2xml.dict2xml(entry.compile(), wrap='root')
+        resource = Resource(entry.source, "", lambda: None, ".xml")
+        content = dict2xml.dict2xml(entry.compile(), wrap="root")
         self._write_content_to_disk(resource, content)

     def _write_entry_yaml(self, entry: BaseArchiveEntry):
-        resource = Resource(entry.source, '', lambda: None, '.yaml')
+        resource = Resource(entry.source, "", lambda: None, ".yaml")
         content = yaml.dump(entry.compile())
         self._write_content_to_disk(resource, content)

     def _write_content_to_disk(self, resource: Resource, content: str):
         file_path = self.file_name_formatter.format_path(resource, self.download_directory)
         file_path.parent.mkdir(exist_ok=True, parents=True)
-        with open(file_path, 'w', encoding="utf-8") as file:
+        with open(file_path, "w", encoding="utf-8") as file:
             logger.debug(
-                f'Writing entry {resource.source_submission.id} to file in {resource.extension[1:].upper()}'
-                f' format at {file_path}')
+                f"Writing entry {resource.source_submission.id} to file in {resource.extension[1:].upper()}"
+                f" format at {file_path}"
+            )
             file.write(content)

View file

@@ -23,4 +23,4 @@ class RedditCloner(RedditDownloader, Archiver):
                     self._download_submission(submission)
                     self.write_entry(submission)
                 except prawcore.PrawcoreException as e:
-                    logger.error(f'Submission {submission.id} failed to be cloned due to a PRAW exception: {e}')
+                    logger.error(f"Submission {submission.id} failed to be cloned due to a PRAW exception: {e}")

View file

@@ -1,28 +1,29 @@
 #!/usr/bin/env python3
 # coding=utf-8
+import logging
 from argparse import Namespace
 from pathlib import Path
 from typing import Optional

-import logging
 import click
 import yaml

 logger = logging.getLogger(__name__)


 class Configuration(Namespace):
     def __init__(self):
         super(Configuration, self).__init__()
         self.authenticate = False
         self.config = None
         self.opts: Optional[str] = None
-        self.directory: str = '.'
+        self.directory: str = "."
         self.disable_module: list[str] = []
         self.exclude_id = []
         self.exclude_id_file = []
-        self.file_scheme: str = '{REDDITOR}_{TITLE}_{POSTID}'
-        self.folder_scheme: str = '{SUBREDDIT}'
+        self.file_scheme: str = "{REDDITOR}_{TITLE}_{POSTID}"
+        self.folder_scheme: str = "{SUBREDDIT}"
         self.ignore_user = []
         self.include_id_file = []
         self.limit: Optional[int] = None

@@ -42,11 +43,11 @@ class Configuration(Namespace):
         self.max_score = None
         self.min_score_ratio = None
         self.max_score_ratio = None
-        self.sort: str = 'hot'
+        self.sort: str = "hot"
         self.submitted: bool = False
         self.subscribed: bool = False
         self.subreddit: list[str] = []
-        self.time: str = 'all'
+        self.time: str = "all"
         self.time_format = None
         self.upvoted: bool = False
         self.user: list[str] = []

@@ -54,15 +55,15 @@ class Configuration(Namespace):
         # Archiver-specific options
         self.all_comments = False
-        self.format = 'json'
+        self.format = "json"
         self.comment_context: bool = False

     def process_click_arguments(self, context: click.Context):
-        if context.params.get('opts') is not None:
-            self.parse_yaml_options(context.params['opts'])
+        if context.params.get("opts") is not None:
+            self.parse_yaml_options(context.params["opts"])
         for arg_key in context.params.keys():
             if not hasattr(self, arg_key):
-                logger.warning(f'Ignoring an unknown CLI argument: {arg_key}')
+                logger.warning(f"Ignoring an unknown CLI argument: {arg_key}")
                 continue
             val = context.params[arg_key]
             if val is None or val == ():

@@ -73,16 +74,16 @@ class Configuration(Namespace):
     def parse_yaml_options(self, file_path: str):
         yaml_file_loc = Path(file_path)
         if not yaml_file_loc.exists():
-            logger.error(f'No YAML file found at {yaml_file_loc}')
+            logger.error(f"No YAML file found at {yaml_file_loc}")
             return
         with yaml_file_loc.open() as file:
             try:
                 opts = yaml.load(file, Loader=yaml.FullLoader)
             except yaml.YAMLError as e:
-                logger.error(f'Could not parse YAML options file: {e}')
+                logger.error(f"Could not parse YAML options file: {e}")
                 return
         for arg_key, val in opts.items():
             if not hasattr(self, arg_key):
-                logger.warning(f'Ignoring an unknown YAML argument: {arg_key}')
+                logger.warning(f"Ignoring an unknown YAML argument: {arg_key}")
                 continue
             setattr(self, arg_key, val)

View file

@@ -41,18 +41,18 @@ class RedditTypes:
         TOP = auto()

     class TimeType(Enum):
-        ALL = 'all'
-        DAY = 'day'
-        HOUR = 'hour'
-        MONTH = 'month'
-        WEEK = 'week'
-        YEAR = 'year'
+        ALL = "all"
+        DAY = "day"
+        HOUR = "hour"
+        MONTH = "month"
+        WEEK = "week"
+        YEAR = "year"


 class RedditConnector(metaclass=ABCMeta):
     def __init__(self, args: Configuration):
         self.args = args
-        self.config_directories = appdirs.AppDirs('bdfr', 'BDFR')
+        self.config_directories = appdirs.AppDirs("bdfr", "BDFR")
         self.run_time = datetime.now().isoformat()
         self._setup_internal_objects()

@@ -68,13 +68,13 @@ class RedditConnector(metaclass=ABCMeta):
         self.parse_disabled_modules()

         self.download_filter = self.create_download_filter()
-        logger.log(9, 'Created download filter')
+        logger.log(9, "Created download filter")
         self.time_filter = self.create_time_filter()
-        logger.log(9, 'Created time filter')
+        logger.log(9, "Created time filter")
         self.sort_filter = self.create_sort_filter()
-        logger.log(9, 'Created sort filter')
+        logger.log(9, "Created sort filter")
         self.file_name_formatter = self.create_file_name_formatter()
-        logger.log(9, 'Create file name formatter')
+        logger.log(9, "Create file name formatter")

         self.create_reddit_instance()
         self.args.user = list(filter(None, [self.resolve_user_name(user) for user in self.args.user]))

@@ -88,7 +88,7 @@ class RedditConnector(metaclass=ABCMeta):
         self.master_hash_list = {}
         self.authenticator = self.create_authenticator()
-        logger.log(9, 'Created site authenticator')
+        logger.log(9, "Created site authenticator")

         self.args.skip_subreddit = self.split_args_input(self.args.skip_subreddit)
         self.args.skip_subreddit = {sub.lower() for sub in self.args.skip_subreddit}

@@ -96,18 +96,18 @@ class RedditConnector(metaclass=ABCMeta):
     def read_config(self):
         """Read any cfg values that need to be processed"""
         if self.args.max_wait_time is None:
-            self.args.max_wait_time = self.cfg_parser.getint('DEFAULT', 'max_wait_time', fallback=120)
-            logger.debug(f'Setting maximum download wait time to {self.args.max_wait_time} seconds')
+            self.args.max_wait_time = self.cfg_parser.getint("DEFAULT", "max_wait_time", fallback=120)
+            logger.debug(f"Setting maximum download wait time to {self.args.max_wait_time} seconds")
         if self.args.time_format is None:
-            option = self.cfg_parser.get('DEFAULT', 'time_format', fallback='ISO')
-            if re.match(r'^[\s\'\"]*$', option):
-                option = 'ISO'
-            logger.debug(f'Setting datetime format string to {option}')
+            option = self.cfg_parser.get("DEFAULT", "time_format", fallback="ISO")
+            if re.match(r"^[\s\'\"]*$", option):
+                option = "ISO"
+            logger.debug(f"Setting datetime format string to {option}")
             self.args.time_format = option
         if not self.args.disable_module:
-            self.args.disable_module = [self.cfg_parser.get('DEFAULT', 'disabled_modules', fallback='')]
+            self.args.disable_module = [self.cfg_parser.get("DEFAULT", "disabled_modules", fallback="")]
         # Update config on disk
-        with open(self.config_location, 'w') as file:
+        with open(self.config_location, "w") as file:
             self.cfg_parser.write(file)

     def parse_disabled_modules(self):

@@ -119,48 +119,48 @@ class RedditConnector(metaclass=ABCMeta):
     def create_reddit_instance(self):
         if self.args.authenticate:
-            logger.debug('Using authenticated Reddit instance')
-            if not self.cfg_parser.has_option('DEFAULT', 'user_token'):
-                logger.log(9, 'Commencing OAuth2 authentication')
-                scopes = self.cfg_parser.get('DEFAULT', 'scopes', fallback='identity, history, read, save')
+            logger.debug("Using authenticated Reddit instance")
+            if not self.cfg_parser.has_option("DEFAULT", "user_token"):
+                logger.log(9, "Commencing OAuth2 authentication")
+                scopes = self.cfg_parser.get("DEFAULT", "scopes", fallback="identity, history, read, save")
                 scopes = OAuth2Authenticator.split_scopes(scopes)
                 oauth2_authenticator = OAuth2Authenticator(
                     scopes,
-                    self.cfg_parser.get('DEFAULT', 'client_id'),
-                    self.cfg_parser.get('DEFAULT', 'client_secret'),
+                    self.cfg_parser.get("DEFAULT", "client_id"),
+                    self.cfg_parser.get("DEFAULT", "client_secret"),
                 )
                 token = oauth2_authenticator.retrieve_new_token()
-                self.cfg_parser['DEFAULT']['user_token'] = token
-                with open(self.config_location, 'w') as file:
+                self.cfg_parser["DEFAULT"]["user_token"] = token
+                with open(self.config_location, "w") as file:
                     self.cfg_parser.write(file, True)
             token_manager = OAuth2TokenManager(self.cfg_parser, self.config_location)

             self.authenticated = True
             self.reddit_instance = praw.Reddit(
-                client_id=self.cfg_parser.get('DEFAULT', 'client_id'),
-                client_secret=self.cfg_parser.get('DEFAULT', 'client_secret'),
+                client_id=self.cfg_parser.get("DEFAULT", "client_id"),
+                client_secret=self.cfg_parser.get("DEFAULT", "client_secret"),
                 user_agent=socket.gethostname(),
                 token_manager=token_manager,
             )
         else:
-            logger.debug('Using unauthenticated Reddit instance')
+            logger.debug("Using unauthenticated Reddit instance")
             self.authenticated = False
             self.reddit_instance = praw.Reddit(
-                client_id=self.cfg_parser.get('DEFAULT', 'client_id'),
-                client_secret=self.cfg_parser.get('DEFAULT', 'client_secret'),
+                client_id=self.cfg_parser.get("DEFAULT", "client_id"),
+                client_secret=self.cfg_parser.get("DEFAULT", "client_secret"),
                 user_agent=socket.gethostname(),
             )

     def retrieve_reddit_lists(self) -> list[praw.models.ListingGenerator]:
         master_list = []
         master_list.extend(self.get_subreddits())
-        logger.log(9, 'Retrieved subreddits')
+        logger.log(9, "Retrieved subreddits")
         master_list.extend(self.get_multireddits())
-        logger.log(9, 'Retrieved multireddits')
+        logger.log(9, "Retrieved multireddits")
         master_list.extend(self.get_user_data())
-        logger.log(9, 'Retrieved user data')
+        logger.log(9, "Retrieved user data")
         master_list.extend(self.get_submissions_from_link())
-        logger.log(9, 'Retrieved submissions for given links')
+        logger.log(9, "Retrieved submissions for given links")
         return master_list

     def determine_directories(self):

@@ -178,37 +178,37 @@ class RedditConnector(metaclass=ABCMeta):
             self.config_location = cfg_path
             return
         possible_paths = [
-            Path('./config.cfg'),
-            Path('./default_config.cfg'),
-            Path(self.config_directory, 'config.cfg'),
-            Path(self.config_directory, 'default_config.cfg'),
+            Path("./config.cfg"),
+            Path("./default_config.cfg"),
+            Path(self.config_directory, "config.cfg"),
+            Path(self.config_directory, "default_config.cfg"),
         ]
         self.config_location = None
         for path in possible_paths:
             if path.resolve().expanduser().exists():
                 self.config_location = path
-                logger.debug(f'Loading configuration from {path}')
+                logger.debug(f"Loading configuration from {path}")
                 break
         if not self.config_location:
-            with importlib.resources.path('bdfr', 'default_config.cfg') as path:
+            with importlib.resources.path("bdfr", "default_config.cfg") as path:
                 self.config_location = path
-                shutil.copy(self.config_location, Path(self.config_directory, 'default_config.cfg'))
+                shutil.copy(self.config_location, Path(self.config_directory, "default_config.cfg"))
         if not self.config_location:
-            raise errors.BulkDownloaderException('Could not find a configuration file to load')
+            raise errors.BulkDownloaderException("Could not find a configuration file to load")
         self.cfg_parser.read(self.config_location)

     def create_file_logger(self):
         main_logger = logging.getLogger()
         if self.args.log is None:
-            log_path = Path(self.config_directory, 'log_output.txt')
+            log_path = Path(self.config_directory, "log_output.txt")
         else:
             log_path = Path(self.args.log).resolve().expanduser()
             if not log_path.parent.exists():
-                raise errors.BulkDownloaderException(f'Designated location for logfile does not exist')
-        backup_count = self.cfg_parser.getint('DEFAULT', 'backup_log_count', fallback=3)
+                raise errors.BulkDownloaderException(f"Designated location for logfile does not exist")
+        backup_count = self.cfg_parser.getint("DEFAULT", "backup_log_count", fallback=3)
         file_handler = logging.handlers.RotatingFileHandler(
             log_path,
-            mode='a',
+            mode="a",
             backupCount=backup_count,
         )
         if log_path.exists():

@@ -216,10 +216,11 @@ class RedditConnector(metaclass=ABCMeta):
                 file_handler.doRollover()
             except PermissionError:
                 logger.critical(
-                    'Cannot rollover logfile, make sure this is the only '
-                    'BDFR process or specify alternate logfile location')
+                    "Cannot rollover logfile, make sure this is the only "
+                    "BDFR process or specify alternate logfile location"
+                )
                 raise
-        formatter = logging.Formatter('[%(asctime)s - %(name)s - %(levelname)s] - %(message)s')
+        formatter = logging.Formatter("[%(asctime)s - %(name)s - %(levelname)s] - %(message)s")
         file_handler.setFormatter(formatter)
         file_handler.setLevel(0)

@@ -227,16 +228,16 @@ class RedditConnector(metaclass=ABCMeta):
     @staticmethod
     def sanitise_subreddit_name(subreddit: str) -> str:
-        pattern = re.compile(r'^(?:https://www\.reddit\.com/)?(?:r/)?(.*?)/?$')
+        pattern = re.compile(r"^(?:https://www\.reddit\.com/)?(?:r/)?(.*?)/?$")
         match = re.match(pattern, subreddit)
         if not match:
-            raise errors.BulkDownloaderException(f'Could not find subreddit name in string {subreddit}')
+            raise errors.BulkDownloaderException(f"Could not find subreddit name in string {subreddit}")
         return match.group(1)

     @staticmethod
     def split_args_input(entries: list[str]) -> set[str]:
         all_entries = []
-        split_pattern = re.compile(r'[,;]\s?')
+        split_pattern = re.compile(r"[,;]\s?")
         for entry in entries:
             results = re.split(split_pattern, entry)
             all_entries.extend([RedditConnector.sanitise_subreddit_name(name) for name in results])

@@ -251,13 +252,13 @@ class RedditConnector(metaclass=ABCMeta):
                 subscribed_subreddits = list(self.reddit_instance.user.subreddits(limit=None))
                 subscribed_subreddits = {s.display_name for s in subscribed_subreddits}
             except prawcore.InsufficientScope:
-                logger.error('BDFR has insufficient scope to access subreddit lists')
+                logger.error("BDFR has insufficient scope to access subreddit lists")
         else:
-            logger.error('Cannot find subscribed subreddits without an authenticated instance')
+            logger.error("Cannot find subscribed subreddits without an authenticated instance")
         if self.args.subreddit or subscribed_subreddits:
             for reddit in self.split_args_input(self.args.subreddit) | subscribed_subreddits:
-                if reddit == 'friends' and self.authenticated is False:
-                    logger.error('Cannot read friends subreddit without an authenticated instance')
+                if reddit == "friends" and self.authenticated is False:
+                    logger.error("Cannot read friends subreddit without an authenticated instance")
                     continue
                 try:
                     reddit = self.reddit_instance.subreddit(reddit)

@@ -267,26 +268,29 @@ class RedditConnector(metaclass=ABCMeta):
                         logger.error(e)
                         continue
                     if self.args.search:
-                        out.append(reddit.search(
-                            self.args.search,
-                            sort=self.sort_filter.name.lower(),
-                            limit=self.args.limit,
-                            time_filter=self.time_filter.value,
-                        ))
+                        out.append(
+                            reddit.search(
+                                self.args.search,
+                                sort=self.sort_filter.name.lower(),
+                                limit=self.args.limit,
+                                time_filter=self.time_filter.value,
+                            )
+                        )
                         logger.debug(
-                            f'Added submissions from subreddit {reddit} with the search term "{self.args.search}"')
+                            f'Added submissions from subreddit {reddit} with the search term "{self.args.search}"'
+                        )
                     else:
                         out.append(self.create_filtered_listing_generator(reddit))
-                        logger.debug(f'Added submissions from subreddit {reddit}')
+                        logger.debug(f"Added submissions from subreddit {reddit}")
                 except (errors.BulkDownloaderException, praw.exceptions.PRAWException) as e:
-                    logger.error(f'Failed to get submissions for subreddit {reddit}: {e}')
+                    logger.error(f"Failed to get submissions for subreddit {reddit}: {e}")
         return out

     def resolve_user_name(self, in_name: str) -> str:
-        if in_name == 'me':
+        if in_name == "me":
             if self.authenticated:
                 resolved_name = self.reddit_instance.user.me().name
-                logger.log(9, f'Resolved user to {resolved_name}')
+                logger.log(9, f"Resolved user to {resolved_name}")
                 return resolved_name
             else:
                 logger.warning('To use "me" as a user, an authenticated Reddit instance must be used')

@@ -318,7 +322,7 @@ class RedditConnector(metaclass=ABCMeta):
     def get_multireddits(self) -> list[Iterator]:
         if self.args.multireddit:
             if len(self.args.user) != 1:
-                logger.error(f'Only 1 user can be supplied when retrieving from multireddits')
+                logger.error(f"Only 1 user can be supplied when retrieving from multireddits")
                 return []
             out = []
             for multi in self.split_args_input(self.args.multireddit):

@@ -327,9 +331,9 @@ class RedditConnector(metaclass=ABCMeta):
                     if not multi.subreddits:
                         raise errors.BulkDownloaderException
                     out.append(self.create_filtered_listing_generator(multi))
-                    logger.debug(f'Added submissions from multireddit {multi}')
+                    logger.debug(f"Added submissions from multireddit {multi}")
                 except (errors.BulkDownloaderException, praw.exceptions.PRAWException, prawcore.PrawcoreException) as e:
-                    logger.error(f'Failed to get submissions for multireddit {multi}: {e}')
+                    logger.error(f"Failed to get submissions for multireddit {multi}: {e}")
             return out
         else:
             return []

@@ -344,7 +348,7 @@ class RedditConnector(metaclass=ABCMeta):
     def get_user_data(self) -> list[Iterator]:
         if any([self.args.submitted, self.args.upvoted, self.args.saved]):
             if not self.args.user:
-                logger.warning('At least one user must be supplied to download user data')
+                logger.warning("At least one user must be supplied to download user data")
                 return []
             generators = []
             for user in self.args.user:

@@ -354,18 +358,20 @@ class RedditConnector(metaclass=ABCMeta):
                     logger.error(e)
                     continue
                 if self.args.submitted:
-                    logger.debug(f'Retrieving submitted posts of user {self.args.user}')
-                    generators.append(self.create_filtered_listing_generator(
-                        self.reddit_instance.redditor(user).submissions,
-                    ))
+                    logger.debug(f"Retrieving submitted posts of user {self.args.user}")
+                    generators.append(
+                        self.create_filtered_listing_generator(
+                            self.reddit_instance.redditor(user).submissions,
+                        )
+                    )
                 if not self.authenticated and any((self.args.upvoted, self.args.saved)):
-                    logger.warning('Accessing user lists requires authentication')
+                    logger.warning("Accessing user lists requires authentication")
                 else:
                     if self.args.upvoted:
-                        logger.debug(f'Retrieving upvoted posts of user {self.args.user}')
+                        logger.debug(f"Retrieving upvoted posts of user {self.args.user}")
                         generators.append(self.reddit_instance.redditor(user).upvoted(limit=self.args.limit))
                     if self.args.saved:
-                        logger.debug(f'Retrieving saved posts of user {self.args.user}')
+                        logger.debug(f"Retrieving saved posts of user {self.args.user}")
                         generators.append(self.reddit_instance.redditor(user).saved(limit=self.args.limit))
             return generators
         else:

@@ -377,10 +383,10 @@ class RedditConnector(metaclass=ABCMeta):
             if user.id:
                 return
         except prawcore.exceptions.NotFound:
-            raise errors.BulkDownloaderException(f'Could not find user {name}')
+            raise errors.BulkDownloaderException(f"Could not find user {name}")
         except AttributeError:
-            if hasattr(user, 'is_suspended'):
-                raise errors.BulkDownloaderException(f'User {name} is banned')
+            if hasattr(user, "is_suspended"):
+                raise errors.BulkDownloaderException(f"User {name} is banned")

     def create_file_name_formatter(self) -> FileNameFormatter:
         return FileNameFormatter(self.args.file_scheme, self.args.folder_scheme, self.args.time_format)

@@ -409,7 +415,7 @@ class RedditConnector(metaclass=ABCMeta):
     @staticmethod
     def check_subreddit_status(subreddit: praw.models.Subreddit):
-        if subreddit.display_name in ('all', 'friends'):
+        if subreddit.display_name in ("all", "friends"):
             return
         try:
             assert subreddit.id

@@ -418,7 +424,7 @@ class RedditConnector(metaclass=ABCMeta):
         except prawcore.Redirect:
             raise errors.BulkDownloaderException(f"Source {subreddit.display_name} does not exist")
         except prawcore.Forbidden:
-            raise errors.BulkDownloaderException(f'Source {subreddit.display_name} is private and cannot be scraped')
+            raise errors.BulkDownloaderException(f"Source {subreddit.display_name} is private and cannot be scraped")

     @staticmethod
     def read_id_files(file_locations: list[str]) -> set[str]:

@@ -426,9 +432,9 @@
         for id_file in file_locations:
             id_file = Path(id_file).resolve().expanduser()
             if not id_file.exists():
-                logger.warning(f'ID file at {id_file} does not exist')
+                logger.warning(f"ID file at {id_file} does not exist")
                 continue
-            with id_file.open('r') as file:
+            with id_file.open("r") as file:
                 for line in file:
                     out.append(line.strip())
         return set(out)

View file

@@ -33,8 +33,8 @@ class DownloadFilter:
     def _check_extension(self, resource_extension: str) -> bool:
         if not self.excluded_extensions:
             return True
-        combined_extensions = '|'.join(self.excluded_extensions)
-        pattern = re.compile(r'.*({})$'.format(combined_extensions))
+        combined_extensions = "|".join(self.excluded_extensions)
+        pattern = re.compile(r".*({})$".format(combined_extensions))
         if re.match(pattern, resource_extension):
             logger.log(9, f'Url "{resource_extension}" matched with "{pattern}"')
             return False

@@ -44,8 +44,8 @@ class DownloadFilter:
     def _check_domain(self, url: str) -> bool:
         if not self.excluded_domains:
             return True
-        combined_domains = '|'.join(self.excluded_domains)
-        pattern = re.compile(r'https?://.*({}).*'.format(combined_domains))
+        combined_domains = "|".join(self.excluded_domains)
+        pattern = re.compile(r"https?://.*({}).*".format(combined_domains))
         if re.match(pattern, url):
             logger.log(9, f'Url "{url}" matched with "{pattern}"')
             return False

View file

@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
def _calc_hash(existing_file: Path): def _calc_hash(existing_file: Path):
chunk_size = 1024 * 1024 chunk_size = 1024 * 1024
md5_hash = hashlib.md5() md5_hash = hashlib.md5()
with existing_file.open('rb') as file: with existing_file.open("rb") as file:
chunk = file.read(chunk_size) chunk = file.read(chunk_size)
while chunk: while chunk:
md5_hash.update(chunk) md5_hash.update(chunk)
@ -46,28 +46,32 @@ class RedditDownloader(RedditConnector):
try: try:
self._download_submission(submission) self._download_submission(submission)
except prawcore.PrawcoreException as e: except prawcore.PrawcoreException as e:
logger.error(f'Submission {submission.id} failed to download due to a PRAW exception: {e}') logger.error(f"Submission {submission.id} failed to download due to a PRAW exception: {e}")
def _download_submission(self, submission: praw.models.Submission): def _download_submission(self, submission: praw.models.Submission):
if submission.id in self.excluded_submission_ids: if submission.id in self.excluded_submission_ids:
logger.debug(f'Object {submission.id} in exclusion list, skipping') logger.debug(f"Object {submission.id} in exclusion list, skipping")
return return
elif submission.subreddit.display_name.lower() in self.args.skip_subreddit: elif submission.subreddit.display_name.lower() in self.args.skip_subreddit:
logger.debug(f'Submission {submission.id} in {submission.subreddit.display_name} in skip list') logger.debug(f"Submission {submission.id} in {submission.subreddit.display_name} in skip list")
return return
elif (submission.author and submission.author.name in self.args.ignore_user) or \ elif (submission.author and submission.author.name in self.args.ignore_user) or (
(submission.author is None and 'DELETED' in self.args.ignore_user): submission.author is None and "DELETED" in self.args.ignore_user
):
logger.debug( logger.debug(
f'Submission {submission.id} in {submission.subreddit.display_name} skipped' f"Submission {submission.id} in {submission.subreddit.display_name} skipped"
f' due to {submission.author.name if submission.author else "DELETED"} being an ignored user') f' due to {submission.author.name if submission.author else "DELETED"} being an ignored user'
)
return return
elif self.args.min_score and submission.score < self.args.min_score: elif self.args.min_score and submission.score < self.args.min_score:
logger.debug( logger.debug(
f"Submission {submission.id} filtered due to score {submission.score} < [{self.args.min_score}]") f"Submission {submission.id} filtered due to score {submission.score} < [{self.args.min_score}]"
)
return return
elif self.args.max_score and self.args.max_score < submission.score: elif self.args.max_score and self.args.max_score < submission.score:
logger.debug( logger.debug(
f"Submission {submission.id} filtered due to score {submission.score} > [{self.args.max_score}]") f"Submission {submission.id} filtered due to score {submission.score} > [{self.args.max_score}]"
)
return return
elif (self.args.min_score_ratio and submission.upvote_ratio < self.args.min_score_ratio) or ( elif (self.args.min_score_ratio and submission.upvote_ratio < self.args.min_score_ratio) or (
self.args.max_score_ratio and self.args.max_score_ratio < submission.upvote_ratio self.args.max_score_ratio and self.args.max_score_ratio < submission.upvote_ratio
@ -75,47 +79,48 @@ class RedditDownloader(RedditConnector):
logger.debug(f"Submission {submission.id} filtered due to score ratio ({submission.upvote_ratio})") logger.debug(f"Submission {submission.id} filtered due to score ratio ({submission.upvote_ratio})")
return return
elif not isinstance(submission, praw.models.Submission): elif not isinstance(submission, praw.models.Submission):
logger.warning(f'{submission.id} is not a submission') logger.warning(f"{submission.id} is not a submission")
return return
elif not self.download_filter.check_url(submission.url): elif not self.download_filter.check_url(submission.url):
logger.debug(f'Submission {submission.id} filtered due to URL {submission.url}') logger.debug(f"Submission {submission.id} filtered due to URL {submission.url}")
return return
logger.debug(f'Attempting to download submission {submission.id}') logger.debug(f"Attempting to download submission {submission.id}")
try: try:
downloader_class = DownloadFactory.pull_lever(submission.url) downloader_class = DownloadFactory.pull_lever(submission.url)
downloader = downloader_class(submission) downloader = downloader_class(submission)
logger.debug(f'Using {downloader_class.__name__} with url {submission.url}') logger.debug(f"Using {downloader_class.__name__} with url {submission.url}")
except errors.NotADownloadableLinkError as e: except errors.NotADownloadableLinkError as e:
logger.error(f'Could not download submission {submission.id}: {e}') logger.error(f"Could not download submission {submission.id}: {e}")
return return
if downloader_class.__name__.lower() in self.args.disable_module: if downloader_class.__name__.lower() in self.args.disable_module:
logger.debug(f'Submission {submission.id} skipped due to disabled module {downloader_class.__name__}') logger.debug(f"Submission {submission.id} skipped due to disabled module {downloader_class.__name__}")
return return
try: try:
content = downloader.find_resources(self.authenticator) content = downloader.find_resources(self.authenticator)
except errors.SiteDownloaderError as e: except errors.SiteDownloaderError as e:
logger.error(f'Site {downloader_class.__name__} failed to download submission {submission.id}: {e}') logger.error(f"Site {downloader_class.__name__} failed to download submission {submission.id}: {e}")
return return
for destination, res in self.file_name_formatter.format_resource_paths(content, self.download_directory): for destination, res in self.file_name_formatter.format_resource_paths(content, self.download_directory):
if destination.exists(): if destination.exists():
                logger.debug(f"File {destination} from submission {submission.id} already exists, continuing")
                continue
            elif not self.download_filter.check_resource(res):
                logger.debug(f"Download filter removed {submission.id} file with URL {submission.url}")
                continue
            try:
                res.download({"max_wait_time": self.args.max_wait_time})
            except errors.BulkDownloaderException as e:
                logger.error(
                    f"Failed to download resource {res.url} in submission {submission.id} "
                    f"with downloader {downloader_class.__name__}: {e}"
                )
                return
            resource_hash = res.hash.hexdigest()
            destination.parent.mkdir(parents=True, exist_ok=True)
            if resource_hash in self.master_hash_list:
                if self.args.no_dupes:
                    logger.info(f"Resource hash {resource_hash} from submission {submission.id} downloaded elsewhere")
                    return
                elif self.args.make_hard_links:
                    try:
@@ -123,29 +128,30 @@ class RedditDownloader(RedditConnector):
                    except AttributeError:
                        self.master_hash_list[resource_hash].link_to(destination)
                    logger.info(
                        f"Hard link made linking {destination} to {self.master_hash_list[resource_hash]}"
                        f" in submission {submission.id}"
                    )
                    return
            try:
                with destination.open("wb") as file:
                    file.write(res.content)
                logger.debug(f"Written file to {destination}")
            except OSError as e:
                logger.exception(e)
                logger.error(f"Failed to write file in submission {submission.id} to {destination}: {e}")
                return
            creation_time = time.mktime(datetime.fromtimestamp(submission.created_utc).timetuple())
            os.utime(destination, (creation_time, creation_time))
            self.master_hash_list[resource_hash] = destination
            logger.debug(f"Hash added to master list: {resource_hash}")
        logger.info(f"Downloaded submission {submission.id} from {submission.subreddit.display_name}")

    @staticmethod
    def scan_existing_files(directory: Path) -> dict[str, Path]:
        files = []
        for (dirpath, dirnames, filenames) in os.walk(directory):
            files.extend([Path(dirpath, file) for file in filenames])
        logger.info(f"Calculating hashes for {len(files)} files")
        pool = Pool(15)
        results = pool.map(_calc_hash, files)
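The method above builds the duplicate-detection map: every existing file is hashed in a worker pool and keyed by its MD5 digest, so later downloads with a known digest can be skipped or hard-linked. A minimal self-contained sketch of the same idea; the names and the whole-file read are illustrative rather than bdfr's exact code:

import hashlib
import os
from multiprocessing import Pool
from pathlib import Path


def _calc_hash(file: Path) -> tuple[str, Path]:
    # Digest the whole file at once; a chunked read would work the same way.
    return hashlib.md5(file.read_bytes()).hexdigest(), file


def scan_existing_files(directory: Path) -> dict[str, Path]:
    files = [Path(dirpath, name) for dirpath, _, names in os.walk(directory) for name in names]
    with Pool(15) as pool:
        results = pool.map(_calc_hash, files)
    # A download whose digest is already a key here is a duplicate.
    return {digest: path for digest, path in results}


if __name__ == "__main__":
    print(len(scan_existing_files(Path("."))))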
View file
@@ -1,5 +1,6 @@
#!/usr/bin/env
class BulkDownloaderException(Exception):
    pass
View file
@@ -18,20 +18,20 @@ logger = logging.getLogger(__name__)
class FileNameFormatter:
    key_terms = (
        "date",
        "flair",
        "postid",
        "redditor",
        "subreddit",
        "title",
        "upvotes",
    )

    def __init__(self, file_format_string: str, directory_format_string: str, time_format_string: str):
        if not self.validate_string(file_format_string):
            raise BulkDownloaderException(f'"{file_format_string}" is not a valid format string')
        self.file_format_string = file_format_string
        self.directory_format_string: list[str] = directory_format_string.split("/")
        self.time_format_string = time_format_string

    def _format_name(self, submission: Union[Comment, Submission], format_string: str) -> str:
@@ -40,108 +40,111 @@ class FileNameFormatter:
        elif isinstance(submission, Comment):
            attributes = self._generate_name_dict_from_comment(submission)
        else:
            raise BulkDownloaderException(f"Cannot name object {type(submission).__name__}")
        result = format_string
        for key in attributes.keys():
            if re.search(rf"(?i).*{{{key}}}.*", result):
                key_value = str(attributes.get(key, "unknown"))
                key_value = FileNameFormatter._convert_unicode_escapes(key_value)
                key_value = key_value.replace("\\", "\\\\")
                result = re.sub(rf"(?i){{{key}}}", key_value, result)
        result = result.replace("/", "")
        if platform.system() == "Windows":
            result = FileNameFormatter._format_for_windows(result)
        return result

    @staticmethod
    def _convert_unicode_escapes(in_string: str) -> str:
        pattern = re.compile(r"(\\u\d{4})")
        matches = re.search(pattern, in_string)
        if matches:
            for match in matches.groups():
                converted_match = bytes(match, "utf-8").decode("unicode-escape")
                in_string = in_string.replace(match, converted_match)
        return in_string

    def _generate_name_dict_from_submission(self, submission: Submission) -> dict:
        submission_attributes = {
            "title": submission.title,
            "subreddit": submission.subreddit.display_name,
            "redditor": submission.author.name if submission.author else "DELETED",
            "postid": submission.id,
            "upvotes": submission.score,
            "flair": submission.link_flair_text,
            "date": self._convert_timestamp(submission.created_utc),
        }
        return submission_attributes

    def _convert_timestamp(self, timestamp: float) -> str:
        input_time = datetime.datetime.fromtimestamp(timestamp)
        if self.time_format_string.upper().strip() == "ISO":
            return input_time.isoformat()
        else:
            return input_time.strftime(self.time_format_string)

    def _generate_name_dict_from_comment(self, comment: Comment) -> dict:
        comment_attributes = {
            "title": comment.submission.title,
            "subreddit": comment.subreddit.display_name,
            "redditor": comment.author.name if comment.author else "DELETED",
            "postid": comment.id,
            "upvotes": comment.score,
            "flair": "",
            "date": self._convert_timestamp(comment.created_utc),
        }
        return comment_attributes

    def format_path(
        self,
        resource: Resource,
        destination_directory: Path,
        index: Optional[int] = None,
    ) -> Path:
        subfolder = Path(
            destination_directory,
            *[self._format_name(resource.source_submission, part) for part in self.directory_format_string],
        )
        index = f"_{index}" if index else ""
        if not resource.extension:
            raise BulkDownloaderException(f"Resource from {resource.url} has no extension")
        file_name = str(self._format_name(resource.source_submission, self.file_format_string))
        file_name = re.sub(r"\n", " ", file_name)
        if not re.match(r".*\.$", file_name) and not re.match(r"^\..*", resource.extension):
            ending = index + "." + resource.extension
        else:
            ending = index + resource.extension
        try:
            file_path = self.limit_file_name_length(file_name, ending, subfolder)
        except TypeError:
            raise BulkDownloaderException(f"Could not determine path name: {subfolder}, {index}, {resource.extension}")
        return file_path

    @staticmethod
    def limit_file_name_length(filename: str, ending: str, root: Path) -> Path:
        root = root.resolve().expanduser()
        possible_id = re.search(r"((?:_\w{6})?$)", filename)
        if possible_id:
            ending = possible_id.group(1) + ending
            filename = filename[: possible_id.start()]
        max_path = FileNameFormatter.find_max_path_length()
        max_file_part_length_chars = 255 - len(ending)
        max_file_part_length_bytes = 255 - len(ending.encode("utf-8"))
        max_path_length = max_path - len(ending) - len(str(root)) - 1
        out = Path(root, filename + ending)
        while any(
            [
                len(filename) > max_file_part_length_chars,
                len(filename.encode("utf-8")) > max_file_part_length_bytes,
                len(str(out)) > max_path_length,
            ]
        ):
            filename = filename[:-1]
            out = Path(root, filename + ending)
@@ -150,44 +153,46 @@ class FileNameFormatter:
    @staticmethod
    def find_max_path_length() -> int:
        try:
            return int(subprocess.check_output(["getconf", "PATH_MAX", "/"]))
        except (ValueError, subprocess.CalledProcessError, OSError):
            if platform.system() == "Windows":
                return 260
            else:
                return 4096

    def format_resource_paths(
        self,
        resources: list[Resource],
        destination_directory: Path,
    ) -> list[tuple[Path, Resource]]:
        out = []
        if len(resources) == 1:
            try:
                out.append((self.format_path(resources[0], destination_directory, None), resources[0]))
            except BulkDownloaderException as e:
                logger.error(f"Could not generate file path for resource {resources[0].url}: {e}")
                logger.exception("Could not generate file path")
        else:
            for i, res in enumerate(resources, start=1):
                logger.log(9, f"Formatting filename with index {i}")
                try:
                    out.append((self.format_path(res, destination_directory, i), res))
                except BulkDownloaderException as e:
                    logger.error(f"Could not generate file path for resource {res.url}: {e}")
                    logger.exception("Could not generate file path")
        return out

    @staticmethod
    def validate_string(test_string: str) -> bool:
        if not test_string:
            return False
        result = any([f"{{{key}}}" in test_string.lower() for key in FileNameFormatter.key_terms])
        if result:
            if "POSTID" not in test_string:
                logger.warning(
                    "Some files might not be downloaded due to name conflicts as filenames are"
                    " not guaranteed to be unique without {POSTID}"
                )
            return True
        else:
            return False
@@ -196,11 +201,11 @@ class FileNameFormatter:
    def _format_for_windows(input_string: str) -> str:
        invalid_characters = r'<>:"\/|?*'
        for char in invalid_characters:
            input_string = input_string.replace(char, "")
        input_string = FileNameFormatter._strip_emojis(input_string)
        return input_string

    @staticmethod
    def _strip_emojis(input_string: str) -> str:
        result = input_string.encode("ascii", errors="ignore").decode("utf-8")
        return result
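The heart of the formatter is the case-insensitive {key} substitution in _format_name above. A stripped-down, runnable sketch of just that mechanic; the attribute dict here is a hypothetical stand-in for the one built from a praw Submission or Comment:

import re


def fill(format_string: str, attributes: dict) -> str:
    result = format_string
    for key, value in attributes.items():
        if re.search(rf"(?i){{{key}}}", result):
            # Doubling backslashes mirrors the escaping done above before re.sub.
            result = re.sub(rf"(?i){{{key}}}", str(value).replace("\\", "\\\\"), result)
    # Strip path separators so a title cannot create extra directories.
    return result.replace("/", "")


print(fill("{REDDITOR}_{TITLE}_{POSTID}", {"redditor": "example_user", "title": "a post", "postid": "abc123"}))
# -> example_user_a post_abc123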
View file
@@ -17,7 +17,6 @@ logger = logging.getLogger(__name__)
class OAuth2Authenticator:
    def __init__(self, wanted_scopes: set[str], client_id: str, client_secret: str):
        self._check_scopes(wanted_scopes)
        self.scopes = wanted_scopes
@@ -26,39 +25,41 @@ class OAuth2Authenticator:
    @staticmethod
    def _check_scopes(wanted_scopes: set[str]):
        response = requests.get(
            "https://www.reddit.com/api/v1/scopes.json", headers={"User-Agent": "fetch-scopes test"}
        )
        known_scopes = [scope for scope, data in response.json().items()]
        known_scopes.append("*")
        for scope in wanted_scopes:
            if scope not in known_scopes:
                raise BulkDownloaderException(f"Scope {scope} is not known to reddit")

    @staticmethod
    def split_scopes(scopes: str) -> set[str]:
        scopes = re.split(r"[,: ]+", scopes)
        return set(scopes)

    def retrieve_new_token(self) -> str:
        reddit = praw.Reddit(
            redirect_uri="http://localhost:7634",
            user_agent="obtain_refresh_token for BDFR",
            client_id=self.client_id,
            client_secret=self.client_secret,
        )
        state = str(random.randint(0, 65000))
        url = reddit.auth.url(self.scopes, state, "permanent")
        logger.warning("Authentication action required before the program can proceed")
        logger.warning(f"Authenticate at {url}")

        client = self.receive_connection()
        data = client.recv(1024).decode("utf-8")
        param_tokens = data.split(" ", 2)[1].split("?", 1)[1].split("&")
        params = {key: value for (key, value) in [token.split("=") for token in param_tokens]}

        if state != params["state"]:
            self.send_message(client)
            raise RedditAuthenticationError(f'State mismatch in OAuth2. Expected: {state} Received: {params["state"]}')
        elif "error" in params:
            self.send_message(client)
            raise RedditAuthenticationError(f'Error in OAuth2: {params["error"]}')
@@ -70,19 +71,19 @@ class OAuth2Authenticator:
    def receive_connection() -> socket.socket:
        server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        server.bind(("0.0.0.0", 7634))
        logger.log(9, "Server listening on 0.0.0.0:7634")
        server.listen(1)
        client = server.accept()[0]
        server.close()
        logger.log(9, "Server closed")
        return client

    @staticmethod
    def send_message(client: socket.socket, message: str = ""):
        client.send(f"HTTP/1.1 200 OK\r\n\r\n{message}".encode("utf-8"))
        client.close()
@@ -94,14 +95,14 @@ class OAuth2TokenManager(praw.reddit.BaseTokenManager):
    def pre_refresh_callback(self, authorizer: praw.reddit.Authorizer):
        if authorizer.refresh_token is None:
            if self.config.has_option("DEFAULT", "user_token"):
                authorizer.refresh_token = self.config.get("DEFAULT", "user_token")
                logger.log(9, "Loaded OAuth2 token for authoriser")
            else:
                raise RedditAuthenticationError("No auth token loaded in configuration")

    def post_refresh_callback(self, authorizer: praw.reddit.Authorizer):
        self.config.set("DEFAULT", "user_token", authorizer.refresh_token)
        with open(self.config_location, "w") as file:
            self.config.write(file, True)
        logger.log(9, f"Written OAuth2 token from authoriser to {self.config_location}")
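split_scopes above accepts comma-, colon-, or space-separated scope lists. A quick runnable check of that behaviour, self-contained so it does not need the bdfr package:

import re


def split_scopes(scopes: str) -> set[str]:
    # Any run of commas, colons, or spaces acts as a delimiter.
    return set(re.split(r"[,: ]+", scopes))


assert split_scopes("identity,history read:save") == {"identity", "history", "read", "save"}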
View file
@@ -39,7 +39,7 @@ class Resource:
        try:
            content = self.download_function(download_parameters)
        except requests.exceptions.ConnectionError as e:
            raise BulkDownloaderException(f"Could not download resource: {e}")
        except BulkDownloaderException:
            raise
        if content:
@@ -51,7 +51,7 @@ class Resource:
        self.hash = hashlib.md5(self.content)

    def _determine_extension(self) -> Optional[str]:
        extension_pattern = re.compile(r".*(\..{3,5})$")
        stripped_url = urllib.parse.urlsplit(self.url).path
        match = re.search(extension_pattern, stripped_url)
        if match:
@@ -59,27 +59,28 @@ class Resource:
    @staticmethod
    def http_download(url: str, download_parameters: dict) -> Optional[bytes]:
        headers = download_parameters.get("headers")
        current_wait_time = 60
        if "max_wait_time" in download_parameters:
            max_wait_time = download_parameters["max_wait_time"]
        else:
            max_wait_time = 300
        while True:
            try:
                response = requests.get(url, headers=headers)
                if re.match(r"^2\d{2}", str(response.status_code)) and response.content:
                    return response.content
                elif response.status_code in (408, 429):
                    raise requests.exceptions.ConnectionError(f"Response code {response.status_code}")
                else:
                    raise BulkDownloaderException(
                        f"Unrecoverable error requesting resource: HTTP Code {response.status_code}"
                    )
            except (requests.exceptions.ConnectionError, requests.exceptions.ChunkedEncodingError) as e:
                logger.warning(f"Error occurred downloading from {url}, waiting {current_wait_time} seconds: {e}")
                time.sleep(current_wait_time)
                if current_wait_time < max_wait_time:
                    current_wait_time += 60
                else:
                    logger.error(f"Max wait time exceeded for resource at url {url}")
                    raise
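The retry loop in http_download backs off linearly: wait 60 seconds, then 120, and so on until max_wait_time is reached, at which point the error is re-raised. A simplified self-contained sketch of that loop; unlike the code above it treats every connection error the same rather than distinguishing 408/429 from fatal status codes:

import time

import requests


def download_with_backoff(url: str, max_wait_time: int = 300) -> bytes:
    current_wait_time = 60
    while True:
        try:
            response = requests.get(url)
            if response.ok and response.content:
                return response.content
            # Treat any other outcome as retriable for this sketch.
            raise requests.exceptions.ConnectionError(f"Response code {response.status_code}")
        except (requests.exceptions.ConnectionError, requests.exceptions.ChunkedEncodingError):
            time.sleep(current_wait_time)
            if current_wait_time < max_wait_time:
                current_wait_time += 60
            else:
                raise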
View file
@@ -31,7 +31,7 @@ class BaseDownloader(ABC):
            res = requests.get(url, cookies=cookies, headers=headers)
        except requests.exceptions.RequestException as e:
            logger.exception(e)
            raise SiteDownloaderError(f"Failed to get page {url}")
        if res.status_code != 200:
            raise ResourceNotFound(f"Server responded with {res.status_code} to {url}")
        return res
View file
@@ -5,8 +5,8 @@ from typing import Optional
from praw.models import Submission

from bdfr.resource import Resource
from bdfr.site_authenticator import SiteAuthenticator
from bdfr.site_downloaders.base_downloader import BaseDownloader

logger = logging.getLogger(__name__)
View file
@@ -4,8 +4,8 @@ from typing import Optional
from praw.models import Submission

from bdfr.resource import Resource
from bdfr.site_authenticator import SiteAuthenticator
from bdfr.site_downloaders.base_downloader import BaseDownloader
View file
@@ -26,62 +26,63 @@ class DownloadFactory:
    @staticmethod
    def pull_lever(url: str) -> Type[BaseDownloader]:
        sanitised_url = DownloadFactory.sanitise_url(url)
        if re.match(r"(i\.|m\.)?imgur", sanitised_url):
            return Imgur
        elif re.match(r"(i\.)?(redgifs|gifdeliverynetwork)", sanitised_url):
            return Redgifs
        elif re.match(r".*/.*\.\w{3,4}(\?[\w;&=]*)?$", sanitised_url) and not DownloadFactory.is_web_resource(
            sanitised_url
        ):
            return Direct
        elif re.match(r"erome\.com.*", sanitised_url):
            return Erome
        elif re.match(r"delayforreddit\.com", sanitised_url):
            return DelayForReddit
        elif re.match(r"reddit\.com/gallery/.*", sanitised_url):
            return Gallery
        elif re.match(r"patreon\.com.*", sanitised_url):
            return Gallery
        elif re.match(r"gfycat\.", sanitised_url):
            return Gfycat
        elif re.match(r"reddit\.com/r/", sanitised_url):
            return SelfPost
        elif re.match(r"(m\.)?youtu\.?be", sanitised_url):
            return Youtube
        elif re.match(r"i\.redd\.it.*", sanitised_url):
            return Direct
        elif re.match(r"v\.redd\.it.*", sanitised_url):
            return VReddit
        elif re.match(r"pornhub\.com.*", sanitised_url):
            return PornHub
        elif re.match(r"vidble\.com", sanitised_url):
            return Vidble
        elif YtdlpFallback.can_handle_link(sanitised_url):
            return YtdlpFallback
        else:
            raise NotADownloadableLinkError(f"No downloader module exists for url {url}")

    @staticmethod
    def sanitise_url(url: str) -> str:
        beginning_regex = re.compile(r"\s*(www\.?)?")
        split_url = urllib.parse.urlsplit(url)
        split_url = split_url.netloc + split_url.path
        split_url = re.sub(beginning_regex, "", split_url)
        return split_url

    @staticmethod
    def is_web_resource(url: str) -> bool:
        web_extensions = (
            "asp",
            "aspx",
            "cfm",
            "cfml",
            "css",
            "htm",
            "html",
            "js",
            "php",
            "php3",
            "xhtml",
        )
        if re.match(rf'(?i).*/.*\.({"|".join(web_extensions)})$', url):
            return True
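pull_lever above is a first-match-wins dispatch over the sanitised URL. A compact sketch of the same pattern using a lookup table instead of an elif chain; the handler names are string stand-ins for the real downloader classes:

import re
import urllib.parse

HANDLERS = (
    (r"(i\.|m\.)?imgur", "Imgur"),
    (r"v\.redd\.it.*", "VReddit"),
    (r"(m\.)?youtu\.?be", "Youtube"),
)


def sanitise_url(url: str) -> str:
    # Drop the scheme and a leading "www." so patterns match the bare host.
    split_url = urllib.parse.urlsplit(url)
    return re.sub(r"\s*(www\.?)?", "", split_url.netloc + split_url.path, count=1)


def pull_lever(url: str) -> str:
    sanitised_url = sanitise_url(url)
    for pattern, handler in HANDLERS:
        if re.match(pattern, sanitised_url):
            return handler
    raise ValueError(f"No downloader module exists for url {url}")


print(pull_lever("https://www.imgur.com/gallery/abc"))  # -> Imgur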
View file
@@ -23,34 +23,34 @@ class Erome(BaseDownloader):
        links = self._get_links(self.post.url)
        if not links:
            raise SiteDownloaderError("Erome parser could not find any links")

        out = []
        for link in links:
            if not re.match(r"https?://.*", link):
                link = "https://" + link
            out.append(Resource(self.post, link, self.erome_download(link)))
        return out

    @staticmethod
    def _get_links(url: str) -> set[str]:
        page = Erome.retrieve_url(url)
        soup = bs4.BeautifulSoup(page.text, "html.parser")
        front_images = soup.find_all("img", attrs={"class": "lasyload"})
        out = [im.get("data-src") for im in front_images]

        videos = soup.find_all("source")
        out.extend([vid.get("src") for vid in videos])
        return set(out)

    @staticmethod
    def erome_download(url: str) -> Callable:
        download_parameters = {
            "headers": {
                "user-agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko)"
                " Chrome/88.0.4324.104 Safari/537.36",
                "Referer": "https://www.erome.com/",
            },
        }
        return lambda global_params: Resource.http_download(url, global_params | download_parameters)
View file
@@ -7,7 +7,6 @@ from bdfr.site_downloaders.base_downloader import BaseDownloader
class BaseFallbackDownloader(BaseDownloader, ABC):
    @staticmethod
    @abstractmethod
    def can_handle_link(url: str) -> bool:
View file
@@ -9,7 +9,9 @@ from praw.models import Submission
from bdfr.exceptions import NotADownloadableLinkError
from bdfr.resource import Resource
from bdfr.site_authenticator import SiteAuthenticator
from bdfr.site_downloaders.fallback_downloaders.fallback_downloader import (
    BaseFallbackDownloader,
)
from bdfr.site_downloaders.youtube import Youtube

logger = logging.getLogger(__name__)
@@ -24,7 +26,7 @@ class YtdlpFallback(BaseFallbackDownloader, Youtube):
            self.post,
            self.post.url,
            super()._download_video({}),
            super().get_video_attributes(self.post.url)["ext"],
        )
        return [out]
View file
@@ -20,27 +20,27 @@ class Gallery(BaseDownloader):
    def find_resources(self, authenticator: Optional[SiteAuthenticator] = None) -> list[Resource]:
        try:
            image_urls = self._get_links(self.post.gallery_data["items"])
        except (AttributeError, TypeError):
            try:
                image_urls = self._get_links(self.post.crosspost_parent_list[0]["gallery_data"]["items"])
            except (AttributeError, IndexError, TypeError, KeyError):
                logger.error(f"Could not find gallery data in submission {self.post.id}")
                logger.exception("Gallery image find failure")
                raise SiteDownloaderError("No images found in Reddit gallery")

        if not image_urls:
            raise SiteDownloaderError("No images found in Reddit gallery")
        return [Resource(self.post, url, Resource.retry_download(url)) for url in image_urls]

    @staticmethod
    def _get_links(id_dict: list[dict]) -> list[str]:
        out = []
        for item in id_dict:
            image_id = item["media_id"]
            possible_extensions = (".jpg", ".png", ".gif", ".gifv", ".jpeg")
            for extension in possible_extensions:
                test_url = f"https://i.redd.it/{image_id}{extension}"
                response = requests.head(test_url)
                if response.status_code == 200:
                    out.append(test_url)
View file
@@ -22,21 +22,23 @@ class Gfycat(Redgifs):
    @staticmethod
    def _get_link(url: str) -> set[str]:
        gfycat_id = re.match(r".*/(.*?)/?$", url).group(1)
        url = "https://gfycat.com/" + gfycat_id

        response = Gfycat.retrieve_url(url)
        if re.search(r"(redgifs|gifdeliverynetwork)", response.url):
            url = url.lower()  # Fixes error with old gfycat/redgifs links
            return Redgifs._get_link(url)

        soup = BeautifulSoup(response.text, "html.parser")
        content = soup.find("script", attrs={"data-react-helmet": "true", "type": "application/ld+json"})

        try:
            out = json.loads(content.contents[0])["video"]["contentUrl"]
        except (IndexError, KeyError, AttributeError) as e:
            raise SiteDownloaderError(f"Failed to download Gfycat link {url}: {e}")
        except json.JSONDecodeError as e:
            raise SiteDownloaderError(f"Did not receive valid JSON data: {e}")
        return {
            out,
        }
View file
@@ -14,7 +14,6 @@ from bdfr.site_downloaders.base_downloader import BaseDownloader
class Imgur(BaseDownloader):
    def __init__(self, post: Submission):
        super().__init__(post)
        self.raw_data = {}
@@ -23,63 +22,63 @@ class Imgur(BaseDownloader):
        self.raw_data = self._get_data(self.post.url)
        out = []
        if "album_images" in self.raw_data:
            images = self.raw_data["album_images"]
            for image in images["images"]:
                out.append(self._compute_image_url(image))
        else:
            out.append(self._compute_image_url(self.raw_data))
        return out

    def _compute_image_url(self, image: dict) -> Resource:
        ext = self._validate_extension(image["ext"])
        if image.get("prefer_video", False):
            ext = ".mp4"
        image_url = "https://i.imgur.com/" + image["hash"] + ext
        return Resource(self.post, image_url, Resource.retry_download(image_url))

    @staticmethod
    def _get_data(link: str) -> dict:
        try:
            imgur_id = re.match(r".*/(.*?)(\..{0,})?$", link).group(1)
            gallery = "a/" if re.search(r".*/(.*?)(gallery/|a/)", link) else ""
            link = f"https://imgur.com/{gallery}{imgur_id}"
        except AttributeError:
            raise SiteDownloaderError(f"Could not extract Imgur ID from {link}")

        res = Imgur.retrieve_url(link, cookies={"over18": "1", "postpagebeta": "0"})

        soup = bs4.BeautifulSoup(res.text, "html.parser")
        scripts = soup.find_all("script", attrs={"type": "text/javascript"})
        scripts = [script.string.replace("\n", "") for script in scripts if script.string]

        script_regex = re.compile(r"\s*\(function\(widgetFactory\)\s*{\s*widgetFactory\.mergeConfig\(\'gallery\'")
        chosen_script = list(filter(lambda s: re.search(script_regex, s), scripts))
        if len(chosen_script) != 1:
            raise SiteDownloaderError(f"Could not read page source from {link}")
        chosen_script = chosen_script[0]

        outer_regex = re.compile(r"widgetFactory\.mergeConfig\(\'gallery\', ({.*})\);")
        inner_regex = re.compile(r"image\s*:(.*),\s*group")
        try:
            image_dict = re.search(outer_regex, chosen_script).group(1)
            image_dict = re.search(inner_regex, image_dict).group(1)
        except AttributeError:
            raise SiteDownloaderError(f"Could not find image dictionary in page source")
        try:
            image_dict = json.loads(image_dict)
        except json.JSONDecodeError as e:
            raise SiteDownloaderError(f"Could not parse received dict as JSON: {e}")
        return image_dict

    @staticmethod
    def _validate_extension(extension_suffix: str) -> str:
        extension_suffix = re.sub(r"\?.*", "", extension_suffix)
        possible_extensions = (".jpg", ".png", ".mp4", ".gif")
        selection = [ext for ext in possible_extensions if ext == extension_suffix]
        if len(selection) == 1:
            return selection[0]
View file
@@ -20,11 +20,11 @@ class PornHub(Youtube):
    def find_resources(self, authenticator: Optional[SiteAuthenticator] = None) -> list[Resource]:
        ytdl_options = {
            "format": "best",
            "nooverwrites": True,
        }
        if video_attributes := super().get_video_attributes(self.post.url):
            extension = video_attributes["ext"]
        else:
            raise SiteDownloaderError()
View file
@@ -2,9 +2,9 @@
import json
import re
from typing import Optional

import requests
from praw.models import Submission

from bdfr.exceptions import SiteDownloaderError
@@ -24,52 +24,53 @@ class Redgifs(BaseDownloader):
    @staticmethod
    def _get_link(url: str) -> set[str]:
        try:
            redgif_id = re.match(r".*/(.*?)(\..{0,})?$", url).group(1)
        except AttributeError:
            raise SiteDownloaderError(f"Could not extract Redgifs ID from {url}")

        auth_token = json.loads(Redgifs.retrieve_url("https://api.redgifs.com/v2/auth/temporary").text)["token"]
        if not auth_token:
            raise SiteDownloaderError("Unable to retrieve Redgifs API token")

        headers = {
            "referer": "https://www.redgifs.com/",
            "origin": "https://www.redgifs.com",
            "content-type": "application/json",
            "Authorization": f"Bearer {auth_token}",
        }

        content = Redgifs.retrieve_url(f"https://api.redgifs.com/v2/gifs/{redgif_id}", headers=headers)

        if content is None:
            raise SiteDownloaderError("Could not read the page source")

        try:
            response_json = json.loads(content.text)
        except json.JSONDecodeError as e:
            raise SiteDownloaderError(f"Received data was not valid JSON: {e}")

        out = set()
        try:
            if response_json["gif"]["type"] == 1:  # type 1 is a video
                if requests.get(response_json["gif"]["urls"]["hd"], headers=headers).ok:
                    out.add(response_json["gif"]["urls"]["hd"])
                else:
                    out.add(response_json["gif"]["urls"]["sd"])
            elif response_json["gif"]["type"] == 2:  # type 2 is an image
                if response_json["gif"]["gallery"]:
                    content = Redgifs.retrieve_url(
                        f'https://api.redgifs.com/v2/gallery/{response_json["gif"]["gallery"]}'
                    )
                    response_json = json.loads(content.text)
                    out = {p["urls"]["hd"] for p in response_json["gifs"]}
                else:
                    out.add(response_json["gif"]["urls"]["hd"])
            else:
                raise KeyError
        except (KeyError, AttributeError):
            raise SiteDownloaderError("Failed to find JSON data in page")

        # Update subdomain if old one is returned
        out = {re.sub("thumbs2", "thumbs3", link) for link in out}
        out = {re.sub("thumbs3", "thumbs4", link) for link in out}
        return out
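The Redgifs flow above is two HTTP calls: fetch a temporary bearer token, then query the v2 gifs endpoint. A hedged sketch of just that happy path, using the same endpoints as the code above; it covers only the type-1 video case and omits all error handling, and the live API may of course change:

import requests


def redgifs_hd_url(redgif_id: str) -> str:
    # Step 1: temporary API token.
    token = requests.get("https://api.redgifs.com/v2/auth/temporary").json()["token"]
    headers = {
        "referer": "https://www.redgifs.com/",
        "origin": "https://www.redgifs.com",
        "Authorization": f"Bearer {token}",
    }
    # Step 2: gif metadata, from which the HD URL is read.
    data = requests.get(f"https://api.redgifs.com/v2/gifs/{redgif_id}", headers=headers).json()
    return data["gif"]["urls"]["hd"]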
View file
@@ -17,27 +17,29 @@ class SelfPost(BaseDownloader):
        super().__init__(post)

    def find_resources(self, authenticator: Optional[SiteAuthenticator] = None) -> list[Resource]:
        out = Resource(self.post, self.post.url, lambda: None, ".txt")
        out.content = self.export_to_string().encode("utf-8")
        out.create_hash()
        return [out]

    def export_to_string(self) -> str:
        """Self posts are formatted here"""
        content = (
            "## ["
            + self.post.fullname
            + "]("
            + self.post.url
            + ")\n"
            + self.post.selftext
            + "\n\n---\n\n"
            + "submitted to [r/"
            + self.post.subreddit.title
            + "](https://www.reddit.com/r/"
            + self.post.subreddit.title
            + ") by [u/"
            + (self.post.author.name if self.post.author else "DELETED")
            + "](https://www.reddit.com/user/"
            + (self.post.author.name if self.post.author else "DELETED")
            + ")"
        )
        return content
View file
@@ -25,30 +25,30 @@ class Vidble(BaseDownloader):
        try:
            res = self.get_links(self.post.url)
        except AttributeError:
            raise SiteDownloaderError(f"Could not read page at {self.post.url}")
        if not res:
            raise SiteDownloaderError(rf"No resources found at {self.post.url}")
        res = [Resource(self.post, r, Resource.retry_download(r)) for r in res]
        return res

    @staticmethod
    def get_links(url: str) -> set[str]:
        if not re.search(r"vidble.com/(show/|album/|watch\?v)", url):
            url = re.sub(r"/(\w*?)$", r"/show/\1", url)

        page = requests.get(url)
        soup = bs4.BeautifulSoup(page.text, "html.parser")
        content_div = soup.find("div", attrs={"id": "ContentPlaceHolder1_divContent"})
        images = content_div.find_all("img")
        images = [i.get("src") for i in images]
        videos = content_div.find_all("source", attrs={"type": "video/mp4"})
        videos = [v.get("src") for v in videos]
        resources = filter(None, itertools.chain(images, videos))
        resources = ["https://www.vidble.com" + r for r in resources]
        resources = [Vidble.change_med_url(r) for r in resources]
        return set(resources)

    @staticmethod
    def change_med_url(url: str) -> str:
        out = re.sub(r"_med(\..{3,4})$", r"\1", url)
        return out
View file
@@ -22,18 +22,18 @@ class VReddit(Youtube):
    def find_resources(self, authenticator: Optional[SiteAuthenticator] = None) -> list[Resource]:
        ytdl_options = {
            "playlistend": 1,
            "nooverwrites": True,
        }
        download_function = self._download_video(ytdl_options)
        extension = self.get_video_attributes(self.post.url)["ext"]
        res = Resource(self.post, self.post.url, download_function, extension)
        return [res]

    @staticmethod
    def get_video_attributes(url: str) -> dict:
        result = VReddit.get_video_data(url)
        if "ext" in result:
            return result
        else:
            try:
@@ -41,4 +41,4 @@ class VReddit(Youtube):
                return result
            except Exception as e:
                logger.exception(e)
                raise NotADownloadableLinkError(f"Video info extraction failed for {url}")
View file
@@ -22,57 +22,62 @@ class Youtube(BaseDownloader):
    def find_resources(self, authenticator: Optional[SiteAuthenticator] = None) -> list[Resource]:
        ytdl_options = {
            "format": "best",
            "playlistend": 1,
            "nooverwrites": True,
        }
        download_function = self._download_video(ytdl_options)
        extension = self.get_video_attributes(self.post.url)["ext"]
        res = Resource(self.post, self.post.url, download_function, extension)
        return [res]

    def _download_video(self, ytdl_options: dict) -> Callable:
        yt_logger = logging.getLogger("youtube-dl")
        yt_logger.setLevel(logging.CRITICAL)
        ytdl_options["quiet"] = True
        ytdl_options["logger"] = yt_logger

        def download(_: dict) -> bytes:
            with tempfile.TemporaryDirectory() as temp_dir:
                download_path = Path(temp_dir).resolve()
                ytdl_options["outtmpl"] = str(download_path) + "/" + "test.%(ext)s"
                try:
                    with yt_dlp.YoutubeDL(ytdl_options) as ydl:
                        ydl.download([self.post.url])
                except yt_dlp.DownloadError as e:
                    raise SiteDownloaderError(f"Youtube download failed: {e}")

                downloaded_files = list(download_path.iterdir())
                if downloaded_files:
                    downloaded_file = downloaded_files[0]
                else:
                    raise NotADownloadableLinkError(f"No media exists in the URL {self.post.url}")
                with downloaded_file.open("rb") as file:
                    content = file.read()
                return content

        return download

    @staticmethod
    def get_video_data(url: str) -> dict:
        yt_logger = logging.getLogger("youtube-dl")
        yt_logger.setLevel(logging.CRITICAL)
        with yt_dlp.YoutubeDL(
            {
                "logger": yt_logger,
            }
        ) as ydl:
            try:
                result = ydl.extract_info(url, download=False)
            except Exception as e:
                logger.exception(e)
                raise NotADownloadableLinkError(f"Video info extraction failed for {url}")
        return result

    @staticmethod
    def get_video_attributes(url: str) -> dict:
        result = Youtube.get_video_data(url)
        if "ext" in result:
            return result
        else:
            raise NotADownloadableLinkError(f"Video info extraction failed for {url}")
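get_video_data above leans on yt-dlp's extract_info(url, download=False), which returns the metadata dict (including the "ext" container key used by the downloaders) without fetching any media. A minimal standalone probe; the URL is only an example:

import logging

import yt_dlp

# Silence yt-dlp's own logging, as the class above does.
yt_logger = logging.getLogger("youtube-dl")
yt_logger.setLevel(logging.CRITICAL)

with yt_dlp.YoutubeDL({"quiet": True, "logger": yt_logger}) as ydl:
    info = ydl.extract_info("https://www.youtube.com/watch?v=dQw4w9WgXcQ", download=False)
print(info["ext"], info.get("title"))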
@@ -1 +1 @@
-Subproject commit e8c840b58f0833e23461c682655fe540aa923f85
+Subproject commit ce5ca2802fabe5dc38393240cd40e20f8928d3b0

@@ -1 +1 @@
-Subproject commit 78fa631d1370562d2cd4a1390989e706158e7bf0
+Subproject commit e0de84e9c011223e7f88b7ccf1c929f4327097ba
View file
@@ -9,15 +9,21 @@ from bdfr.archive_entry.comment_archive_entry import CommentArchiveEntry
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize(
    ("test_comment_id", "expected_dict"),
    (
        (
            "gstd4hk",
            {
                "author": "james_pic",
                "subreddit": "Python",
                "submission": "mgi4op",
                "submission_title": "76% Faster CPython",
                "distinguished": None,
            },
        ),
    ),
)
def test_get_comment_details(test_comment_id: str, expected_dict: dict, reddit_instance: praw.Reddit):
    comment = reddit_instance.comment(id=test_comment_id)
    test_entry = CommentArchiveEntry(comment)
@@ -27,13 +33,16 @@ def test_get_comment_details(test_comment_id: str, expected_dict: dict, reddit_i
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize(
    ("test_comment_id", "expected_min_comments"),
    (
        ("gstd4hk", 4),
        ("gsvyste", 3),
        ("gsxnvvb", 5),
    ),
)
def test_get_comment_replies(test_comment_id: str, expected_min_comments: int, reddit_instance: praw.Reddit):
    comment = reddit_instance.comment(id=test_comment_id)
    test_entry = CommentArchiveEntry(comment)
    result = test_entry.compile()
    assert len(result.get("replies")) >= expected_min_comments
View file
@@ -9,9 +9,7 @@ from bdfr.archive_entry.submission_archive_entry import SubmissionArchiveEntry
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize(("test_submission_id", "min_comments"), (("m3reby", 27),))
def test_get_comments(test_submission_id: str, min_comments: int, reddit_instance: praw.Reddit):
    test_submission = reddit_instance.submission(id=test_submission_id)
    test_archive_entry = SubmissionArchiveEntry(test_submission)
@@ -21,21 +19,27 @@ def test_get_comments(test_submission_id: str, min_comments: int, reddit_instanc
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize(
    ("test_submission_id", "expected_dict"),
    (
        (
            "m3reby",
            {
                "author": "sinjen-tos",
                "id": "m3reby",
                "link_flair_text": "image",
                "pinned": False,
                "spoiler": False,
                "over_18": False,
                "locked": False,
                "distinguished": None,
                "created_utc": 1615583837,
                "permalink": "/r/australia/comments/m3reby/this_little_guy_fell_out_of_a_tree_and_in_front/",
            },
        ),
        # TODO: add deleted user test case
    ),
)
def test_get_post_details(test_submission_id: str, expected_dict: dict, reddit_instance: praw.Reddit):
    test_submission = reddit_instance.submission(id=test_submission_id)
    test_archive_entry = SubmissionArchiveEntry(test_submission)

View file

@@ -11,29 +11,29 @@ import pytest
 from bdfr.oauth2 import OAuth2TokenManager
-@pytest.fixture(scope='session')
+@pytest.fixture(scope="session")
 def reddit_instance():
     rd = praw.Reddit(
-        client_id='U-6gk4ZCh3IeNQ',
-        client_secret='7CZHY6AmKweZME5s50SfDGylaPg',
-        user_agent='test',
+        client_id="U-6gk4ZCh3IeNQ",
+        client_secret="7CZHY6AmKweZME5s50SfDGylaPg",
+        user_agent="test",
     )
     return rd
-@pytest.fixture(scope='session')
+@pytest.fixture(scope="session")
 def authenticated_reddit_instance():
-    test_config_path = Path('./tests/test_config.cfg')
+    test_config_path = Path("./tests/test_config.cfg")
     if not test_config_path.exists():
-        pytest.skip('Refresh token must be provided to authenticate with OAuth2')
+        pytest.skip("Refresh token must be provided to authenticate with OAuth2")
     cfg_parser = configparser.ConfigParser()
     cfg_parser.read(test_config_path)
-    if not cfg_parser.has_option('DEFAULT', 'user_token'):
-        pytest.skip('Refresh token must be provided to authenticate with OAuth2')
+    if not cfg_parser.has_option("DEFAULT", "user_token"):
+        pytest.skip("Refresh token must be provided to authenticate with OAuth2")
     token_manager = OAuth2TokenManager(cfg_parser, test_config_path)
     reddit_instance = praw.Reddit(
-        client_id=cfg_parser.get('DEFAULT', 'client_id'),
-        client_secret=cfg_parser.get('DEFAULT', 'client_secret'),
+        client_id=cfg_parser.get("DEFAULT", "client_id"),
+        client_secret=cfg_parser.get("DEFAULT", "client_secret"),
         user_agent=socket.gethostname(),
         token_manager=token_manager,
     )

View file

@@ -10,67 +10,78 @@ from click.testing import CliRunner
 from bdfr.__main__ import cli
-does_test_config_exist = Path('./tests/test_config.cfg').exists()
+does_test_config_exist = Path("./tests/test_config.cfg").exists()
 def copy_test_config(run_path: Path):
-    shutil.copy(Path('./tests/test_config.cfg'), Path(run_path, 'test_config.cfg'))
+    shutil.copy(Path("./tests/test_config.cfg"), Path(run_path, "test_config.cfg"))
 def create_basic_args_for_archive_runner(test_args: list[str], run_path: Path):
     copy_test_config(run_path)
     out = [
-        'archive',
+        "archive",
         str(run_path),
-        '-v',
-        '--config', str(Path(run_path, 'test_config.cfg')),
-        '--log', str(Path(run_path, 'test_log.txt')),
+        "-v",
+        "--config",
+        str(Path(run_path, "test_config.cfg")),
+        "--log",
+        str(Path(run_path, "test_log.txt")),
     ] + test_args
     return out
 @pytest.mark.online
 @pytest.mark.reddit
-@pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests')
-@pytest.mark.parametrize('test_args', (
-    ['-l', 'gstd4hk'],
-    ['-l', 'm2601g', '-f', 'yaml'],
-    ['-l', 'n60t4c', '-f', 'xml'],
-))
+@pytest.mark.skipif(not does_test_config_exist, reason="A test config file is required for integration tests")
+@pytest.mark.parametrize(
+    "test_args",
+    (
+        ["-l", "gstd4hk"],
+        ["-l", "m2601g", "-f", "yaml"],
+        ["-l", "n60t4c", "-f", "xml"],
+    ),
+)
 def test_cli_archive_single(test_args: list[str], tmp_path: Path):
     runner = CliRunner()
     test_args = create_basic_args_for_archive_runner(test_args, tmp_path)
     result = runner.invoke(cli, test_args)
     assert result.exit_code == 0
-    assert re.search(r'Writing entry .*? to file in .*? format', result.output)
+    assert re.search(r"Writing entry .*? to file in .*? format", result.output)
 @pytest.mark.online
 @pytest.mark.reddit
-@pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests')
-@pytest.mark.parametrize('test_args', (
-    ['--subreddit', 'Mindustry', '-L', 25],
-    ['--subreddit', 'Mindustry', '-L', 25, '--format', 'xml'],
-    ['--subreddit', 'Mindustry', '-L', 25, '--format', 'yaml'],
-    ['--subreddit', 'Mindustry', '-L', 25, '--sort', 'new'],
-    ['--subreddit', 'Mindustry', '-L', 25, '--time', 'day'],
-    ['--subreddit', 'Mindustry', '-L', 25, '--time', 'day', '--sort', 'new'],
-))
+@pytest.mark.skipif(not does_test_config_exist, reason="A test config file is required for integration tests")
+@pytest.mark.parametrize(
+    "test_args",
+    (
+        ["--subreddit", "Mindustry", "-L", 25],
+        ["--subreddit", "Mindustry", "-L", 25, "--format", "xml"],
+        ["--subreddit", "Mindustry", "-L", 25, "--format", "yaml"],
+        ["--subreddit", "Mindustry", "-L", 25, "--sort", "new"],
+        ["--subreddit", "Mindustry", "-L", 25, "--time", "day"],
+        ["--subreddit", "Mindustry", "-L", 25, "--time", "day", "--sort", "new"],
+    ),
+)
 def test_cli_archive_subreddit(test_args: list[str], tmp_path: Path):
     runner = CliRunner()
     test_args = create_basic_args_for_archive_runner(test_args, tmp_path)
     result = runner.invoke(cli, test_args)
     assert result.exit_code == 0
-    assert re.search(r'Writing entry .*? to file in .*? format', result.output)
+    assert re.search(r"Writing entry .*? to file in .*? format", result.output)
 @pytest.mark.online
 @pytest.mark.reddit
-@pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests')
-@pytest.mark.parametrize('test_args', (
-    ['--user', 'me', '--authenticate', '--all-comments', '-L', '10'],
-    ['--user', 'me', '--user', 'djnish', '--authenticate', '--all-comments', '-L', '10'],
-))
+@pytest.mark.skipif(not does_test_config_exist, reason="A test config file is required for integration tests")
+@pytest.mark.parametrize(
+    "test_args",
+    (
+        ["--user", "me", "--authenticate", "--all-comments", "-L", "10"],
+        ["--user", "me", "--user", "djnish", "--authenticate", "--all-comments", "-L", "10"],
+    ),
+)
 def test_cli_archive_all_user_comments(test_args: list[str], tmp_path: Path):
     runner = CliRunner()
     test_args = create_basic_args_for_archive_runner(test_args, tmp_path)
@@ -80,89 +91,88 @@ def test_cli_archive_all_user_comments(test_args: list[str], tmp_path: Path):
 @pytest.mark.online
 @pytest.mark.reddit
-@pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests')
-@pytest.mark.parametrize('test_args', (
-    ['--comment-context', '--link', 'gxqapql'],
-))
+@pytest.mark.skipif(not does_test_config_exist, reason="A test config file is required for integration tests")
+@pytest.mark.parametrize("test_args", (["--comment-context", "--link", "gxqapql"],))
 def test_cli_archive_full_context(test_args: list[str], tmp_path: Path):
     runner = CliRunner()
     test_args = create_basic_args_for_archive_runner(test_args, tmp_path)
     result = runner.invoke(cli, test_args)
     assert result.exit_code == 0
-    assert 'Converting comment' in result.output
+    assert "Converting comment" in result.output
 @pytest.mark.online
 @pytest.mark.reddit
 @pytest.mark.slow
-@pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests')
-@pytest.mark.parametrize('test_args', (
-    ['--subreddit', 'all', '-L', 100],
-    ['--subreddit', 'all', '-L', 100, '--sort', 'new'],
-))
+@pytest.mark.skipif(not does_test_config_exist, reason="A test config file is required for integration tests")
+@pytest.mark.parametrize(
+    "test_args",
+    (
+        ["--subreddit", "all", "-L", 100],
+        ["--subreddit", "all", "-L", 100, "--sort", "new"],
+    ),
+)
 def test_cli_archive_long(test_args: list[str], tmp_path: Path):
     runner = CliRunner()
     test_args = create_basic_args_for_archive_runner(test_args, tmp_path)
     result = runner.invoke(cli, test_args)
     assert result.exit_code == 0
-    assert re.search(r'Writing entry .*? to file in .*? format', result.output)
+    assert re.search(r"Writing entry .*? to file in .*? format", result.output)
 @pytest.mark.online
 @pytest.mark.reddit
-@pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests')
-@pytest.mark.parametrize('test_args', (
-    ['--ignore-user', 'ArjanEgges', '-l', 'm3hxzd'],
-))
+@pytest.mark.skipif(not does_test_config_exist, reason="A test config file is required for integration tests")
+@pytest.mark.parametrize("test_args", (["--ignore-user", "ArjanEgges", "-l", "m3hxzd"],))
 def test_cli_archive_ignore_user(test_args: list[str], tmp_path: Path):
     runner = CliRunner()
     test_args = create_basic_args_for_archive_runner(test_args, tmp_path)
     result = runner.invoke(cli, test_args)
     assert result.exit_code == 0
-    assert 'being an ignored user' in result.output
-    assert 'Attempting to archive submission' not in result.output
+    assert "being an ignored user" in result.output
+    assert "Attempting to archive submission" not in result.output
 @pytest.mark.online
 @pytest.mark.reddit
-@pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests')
-@pytest.mark.parametrize('test_args', (
-    ['--file-scheme', '{TITLE}', '-l', 'suy011'],
-))
+@pytest.mark.skipif(not does_test_config_exist, reason="A test config file is required for integration tests")
+@pytest.mark.parametrize("test_args", (["--file-scheme", "{TITLE}", "-l", "suy011"],))
 def test_cli_archive_file_format(test_args: list[str], tmp_path: Path):
     runner = CliRunner()
     test_args = create_basic_args_for_archive_runner(test_args, tmp_path)
     result = runner.invoke(cli, test_args)
     assert result.exit_code == 0
-    assert 'Attempting to archive submission' in result.output
-    assert re.search('format at /.+?/Judge says Trump and two adult', result.output)
+    assert "Attempting to archive submission" in result.output
+    assert re.search("format at /.+?/Judge says Trump and two adult", result.output)
 @pytest.mark.online
 @pytest.mark.reddit
-@pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests')
-@pytest.mark.parametrize('test_args', (
-    ['-l', 'm2601g', '--exclude-id', 'm2601g'],
-))
+@pytest.mark.skipif(not does_test_config_exist, reason="A test config file is required for integration tests")
+@pytest.mark.parametrize("test_args", (["-l", "m2601g", "--exclude-id", "m2601g"],))
 def test_cli_archive_links_exclusion(test_args: list[str], tmp_path: Path):
     runner = CliRunner()
     test_args = create_basic_args_for_archive_runner(test_args, tmp_path)
     result = runner.invoke(cli, test_args)
     assert result.exit_code == 0
-    assert 'in exclusion list' in result.output
-    assert 'Attempting to archive' not in result.output
+    assert "in exclusion list" in result.output
+    assert "Attempting to archive" not in result.output
 @pytest.mark.online
 @pytest.mark.reddit
-@pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests')
-@pytest.mark.parametrize('test_args', (
-    ['-l', 'ijy4ch'],  # user deleted post
-    ['-l', 'kw4wjm'],  # post from banned subreddit
-))
+@pytest.mark.skipif(not does_test_config_exist, reason="A test config file is required for integration tests")
+@pytest.mark.parametrize(
+    "test_args",
+    (
+        ["-l", "ijy4ch"],  # user deleted post
+        ["-l", "kw4wjm"],  # post from banned subreddit
+    ),
+)
 def test_cli_archive_soft_fail(test_args: list[str], tmp_path: Path):
     runner = CliRunner()
     test_args = create_basic_args_for_archive_runner(test_args, tmp_path)
     result = runner.invoke(cli, test_args)
     assert result.exit_code == 0
-    assert 'failed to be archived due to a PRAW exception' in result.output
-    assert 'Attempting to archive' not in result.output
+    assert "failed to be archived due to a PRAW exception" in result.output
+    assert "Attempting to archive" not in result.output

View file

@@ -9,54 +9,62 @@ from click.testing import CliRunner
 from bdfr.__main__ import cli
-does_test_config_exist = Path('./tests/test_config.cfg').exists()
+does_test_config_exist = Path("./tests/test_config.cfg").exists()
 def copy_test_config(run_path: Path):
-    shutil.copy(Path('./tests/test_config.cfg'), Path(run_path, 'test_config.cfg'))
+    shutil.copy(Path("./tests/test_config.cfg"), Path(run_path, "test_config.cfg"))
 def create_basic_args_for_cloner_runner(test_args: list[str], tmp_path: Path):
     copy_test_config(tmp_path)
     out = [
-        'clone',
+        "clone",
         str(tmp_path),
-        '-v',
-        '--config', str(Path(tmp_path, 'test_config.cfg')),
-        '--log', str(Path(tmp_path, 'test_log.txt')),
+        "-v",
+        "--config",
+        str(Path(tmp_path, "test_config.cfg")),
+        "--log",
+        str(Path(tmp_path, "test_log.txt")),
    ] + test_args
     return out
 @pytest.mark.online
 @pytest.mark.reddit
-@pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests')
-@pytest.mark.parametrize('test_args', (
-    ['-l', '6l7778'],
-    ['-s', 'TrollXChromosomes/', '-L', 1],
-    ['-l', 'eiajjw'],
-    ['-l', 'xl0lhi'],
-))
+@pytest.mark.skipif(not does_test_config_exist, reason="A test config file is required for integration tests")
+@pytest.mark.parametrize(
+    "test_args",
+    (
+        ["-l", "6l7778"],
+        ["-s", "TrollXChromosomes/", "-L", 1],
+        ["-l", "eiajjw"],
+        ["-l", "xl0lhi"],
+    ),
+)
 def test_cli_scrape_general(test_args: list[str], tmp_path: Path):
     runner = CliRunner()
     test_args = create_basic_args_for_cloner_runner(test_args, tmp_path)
     result = runner.invoke(cli, test_args)
     assert result.exit_code == 0
-    assert 'Downloaded submission' in result.output
-    assert 'Record for entry item' in result.output
+    assert "Downloaded submission" in result.output
+    assert "Record for entry item" in result.output
 @pytest.mark.online
 @pytest.mark.reddit
-@pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests')
-@pytest.mark.parametrize('test_args', (
-    ['-l', 'ijy4ch'],  # user deleted post
-    ['-l', 'kw4wjm'],  # post from banned subreddit
-))
+@pytest.mark.skipif(not does_test_config_exist, reason="A test config file is required for integration tests")
+@pytest.mark.parametrize(
+    "test_args",
+    (
+        ["-l", "ijy4ch"],  # user deleted post
+        ["-l", "kw4wjm"],  # post from banned subreddit
+    ),
+)
 def test_cli_scrape_soft_fail(test_args: list[str], tmp_path: Path):
     runner = CliRunner()
     test_args = create_basic_args_for_cloner_runner(test_args, tmp_path)
     result = runner.invoke(cli, test_args)
     assert result.exit_code == 0
-    assert 'Downloaded submission' not in result.output
-    assert 'Record for entry item' not in result.output
+    assert "Downloaded submission" not in result.output
+    assert "Record for entry item" not in result.output

View file

@@ -9,97 +9,107 @@ from click.testing import CliRunner
 from bdfr.__main__ import cli
-does_test_config_exist = Path('./tests/test_config.cfg').exists()
+does_test_config_exist = Path("./tests/test_config.cfg").exists()
 def copy_test_config(run_path: Path):
-    shutil.copy(Path('./tests/test_config.cfg'), Path(run_path, './test_config.cfg'))
+    shutil.copy(Path("./tests/test_config.cfg"), Path(run_path, "./test_config.cfg"))
 def create_basic_args_for_download_runner(test_args: list[str], run_path: Path):
     copy_test_config(run_path)
     out = [
-        'download', str(run_path),
-        '-v',
-        '--config', str(Path(run_path, './test_config.cfg')),
-        '--log', str(Path(run_path, 'test_log.txt')),
+        "download",
+        str(run_path),
+        "-v",
+        "--config",
+        str(Path(run_path, "./test_config.cfg")),
+        "--log",
+        str(Path(run_path, "test_log.txt")),
     ] + test_args
     return out
 @pytest.mark.online
 @pytest.mark.reddit
-@pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests')
-@pytest.mark.parametrize('test_args', (
-    ['-s', 'Mindustry', '-L', 3],
-    ['-s', 'r/Mindustry', '-L', 3],
-    ['-s', 'r/mindustry', '-L', 3],
-    ['-s', 'mindustry', '-L', 3],
-    ['-s', 'https://www.reddit.com/r/TrollXChromosomes/', '-L', 3],
-    ['-s', 'r/TrollXChromosomes/', '-L', 3],
-    ['-s', 'TrollXChromosomes/', '-L', 3],
-    ['-s', 'trollxchromosomes', '-L', 3],
-    ['-s', 'trollxchromosomes,mindustry,python', '-L', 3],
-    ['-s', 'trollxchromosomes, mindustry, python', '-L', 3],
-    ['-s', 'trollxchromosomes', '-L', 3, '--time', 'day'],
-    ['-s', 'trollxchromosomes', '-L', 3, '--sort', 'new'],
-    ['-s', 'trollxchromosomes', '-L', 3, '--time', 'day', '--sort', 'new'],
-    ['-s', 'trollxchromosomes', '-L', 3, '--search', 'women'],
-    ['-s', 'trollxchromosomes', '-L', 3, '--time', 'day', '--search', 'women'],
-    ['-s', 'trollxchromosomes', '-L', 3, '--sort', 'new', '--search', 'women'],
-    ['-s', 'trollxchromosomes', '-L', 3, '--time', 'day', '--sort', 'new', '--search', 'women'],
-))
+@pytest.mark.skipif(not does_test_config_exist, reason="A test config file is required for integration tests")
+@pytest.mark.parametrize(
+    "test_args",
+    (
+        ["-s", "Mindustry", "-L", 3],
+        ["-s", "r/Mindustry", "-L", 3],
+        ["-s", "r/mindustry", "-L", 3],
+        ["-s", "mindustry", "-L", 3],
+        ["-s", "https://www.reddit.com/r/TrollXChromosomes/", "-L", 3],
+        ["-s", "r/TrollXChromosomes/", "-L", 3],
+        ["-s", "TrollXChromosomes/", "-L", 3],
+        ["-s", "trollxchromosomes", "-L", 3],
+        ["-s", "trollxchromosomes,mindustry,python", "-L", 3],
+        ["-s", "trollxchromosomes, mindustry, python", "-L", 3],
+        ["-s", "trollxchromosomes", "-L", 3, "--time", "day"],
+        ["-s", "trollxchromosomes", "-L", 3, "--sort", "new"],
+        ["-s", "trollxchromosomes", "-L", 3, "--time", "day", "--sort", "new"],
+        ["-s", "trollxchromosomes", "-L", 3, "--search", "women"],
+        ["-s", "trollxchromosomes", "-L", 3, "--time", "day", "--search", "women"],
+        ["-s", "trollxchromosomes", "-L", 3, "--sort", "new", "--search", "women"],
+        ["-s", "trollxchromosomes", "-L", 3, "--time", "day", "--sort", "new", "--search", "women"],
+    ),
+)
 def test_cli_download_subreddits(test_args: list[str], tmp_path: Path):
     runner = CliRunner()
     test_args = create_basic_args_for_download_runner(test_args, tmp_path)
     result = runner.invoke(cli, test_args)
     assert result.exit_code == 0
-    assert 'Added submissions from subreddit ' in result.output
-    assert 'Downloaded submission' in result.output
+    assert "Added submissions from subreddit " in result.output
+    assert "Downloaded submission" in result.output
 @pytest.mark.online
 @pytest.mark.reddit
 @pytest.mark.slow
 @pytest.mark.authenticated
-@pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests')
-@pytest.mark.parametrize('test_args', (
-    ['-s', 'hentai', '-L', 10, '--search', 'red', '--authenticate'],
-    ['--authenticate', '--subscribed', '-L', 10],
-))
+@pytest.mark.skipif(not does_test_config_exist, reason="A test config file is required for integration tests")
+@pytest.mark.parametrize(
+    "test_args",
+    (
+        ["-s", "hentai", "-L", 10, "--search", "red", "--authenticate"],
+        ["--authenticate", "--subscribed", "-L", 10],
+    ),
+)
 def test_cli_download_search_subreddits_authenticated(test_args: list[str], tmp_path: Path):
     runner = CliRunner()
     test_args = create_basic_args_for_download_runner(test_args, tmp_path)
     result = runner.invoke(cli, test_args)
     assert result.exit_code == 0
-    assert 'Added submissions from subreddit ' in result.output
-    assert 'Downloaded submission' in result.output
+    assert "Added submissions from subreddit " in result.output
+    assert "Downloaded submission" in result.output
 @pytest.mark.online
 @pytest.mark.reddit
 @pytest.mark.authenticated
-@pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests')
-@pytest.mark.parametrize('test_args', (
-    ['--subreddit', 'friends', '-L', 10, '--authenticate'],
-))
+@pytest.mark.skipif(not does_test_config_exist, reason="A test config file is required for integration tests")
+@pytest.mark.parametrize("test_args", (["--subreddit", "friends", "-L", 10, "--authenticate"],))
 def test_cli_download_user_specific_subreddits(test_args: list[str], tmp_path: Path):
     runner = CliRunner()
     test_args = create_basic_args_for_download_runner(test_args, tmp_path)
     result = runner.invoke(cli, test_args)
     assert result.exit_code == 0
-    assert 'Added submissions from subreddit ' in result.output
+    assert "Added submissions from subreddit " in result.output
 @pytest.mark.online
 @pytest.mark.reddit
-@pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests')
-@pytest.mark.parametrize('test_args', (
-    ['-l', '6l7778'],
-    ['-l', 'https://reddit.com/r/EmpireDidNothingWrong/comments/6l7778/technically_true/'],
-    ['-l', 'm3hxzd'],  # Really long title used to overflow filename limit
-    ['-l', 'm5bqkf'],  # Resource leading to a 404
-))
+@pytest.mark.skipif(not does_test_config_exist, reason="A test config file is required for integration tests")
+@pytest.mark.parametrize(
+    "test_args",
+    (
+        ["-l", "6l7778"],
+        ["-l", "https://reddit.com/r/EmpireDidNothingWrong/comments/6l7778/technically_true/"],
+        ["-l", "m3hxzd"],  # Really long title used to overflow filename limit
+        ["-l", "m5bqkf"],  # Resource leading to a 404
+    ),
+)
 def test_cli_download_links(test_args: list[str], tmp_path: Path):
     runner = CliRunner()
     test_args = create_basic_args_for_download_runner(test_args, tmp_path)
@@ -109,64 +119,66 @@ def test_cli_download_links(test_args: list[str], tmp_path: Path):
 @pytest.mark.online
 @pytest.mark.reddit
-@pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests')
-@pytest.mark.parametrize('test_args', (
-    ['--user', 'helen_darten', '-m', 'cuteanimalpics', '-L', 10],
-    ['--user', 'helen_darten', '-m', 'cuteanimalpics', '-L', 10, '--sort', 'rising'],
-    ['--user', 'helen_darten', '-m', 'cuteanimalpics', '-L', 10, '--time', 'week'],
-    ['--user', 'helen_darten', '-m', 'cuteanimalpics', '-L', 10, '--time', 'week', '--sort', 'rising'],
-))
+@pytest.mark.skipif(not does_test_config_exist, reason="A test config file is required for integration tests")
+@pytest.mark.parametrize(
+    "test_args",
+    (
+        ["--user", "helen_darten", "-m", "cuteanimalpics", "-L", 10],
+        ["--user", "helen_darten", "-m", "cuteanimalpics", "-L", 10, "--sort", "rising"],
+        ["--user", "helen_darten", "-m", "cuteanimalpics", "-L", 10, "--time", "week"],
+        ["--user", "helen_darten", "-m", "cuteanimalpics", "-L", 10, "--time", "week", "--sort", "rising"],
+    ),
+)
 def test_cli_download_multireddit(test_args: list[str], tmp_path: Path):
     runner = CliRunner()
     test_args = create_basic_args_for_download_runner(test_args, tmp_path)
     result = runner.invoke(cli, test_args)
     assert result.exit_code == 0
-    assert 'Added submissions from multireddit ' in result.output
+    assert "Added submissions from multireddit " in result.output
 @pytest.mark.online
 @pytest.mark.reddit
-@pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests')
-@pytest.mark.parametrize('test_args', (
-    ['--user', 'helen_darten', '-m', 'xxyyzzqwerty', '-L', 10],
-))
+@pytest.mark.skipif(not does_test_config_exist, reason="A test config file is required for integration tests")
+@pytest.mark.parametrize("test_args", (["--user", "helen_darten", "-m", "xxyyzzqwerty", "-L", 10],))
 def test_cli_download_multireddit_nonexistent(test_args: list[str], tmp_path: Path):
     runner = CliRunner()
     test_args = create_basic_args_for_download_runner(test_args, tmp_path)
     result = runner.invoke(cli, test_args)
     assert result.exit_code == 0
-    assert 'Failed to get submissions for multireddit' in result.output
-    assert 'received 404 HTTP response' in result.output
+    assert "Failed to get submissions for multireddit" in result.output
+    assert "received 404 HTTP response" in result.output
 @pytest.mark.online
 @pytest.mark.reddit
 @pytest.mark.authenticated
-@pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests')
-@pytest.mark.parametrize('test_args', (
-    ['--user', 'djnish', '--submitted', '--user', 'FriesWithThat', '-L', 10],
-    ['--user', 'me', '--upvoted', '--authenticate', '-L', 10],
-    ['--user', 'me', '--saved', '--authenticate', '-L', 10],
-    ['--user', 'me', '--submitted', '--authenticate', '-L', 10],
-    ['--user', 'djnish', '--submitted', '-L', 10],
-    ['--user', 'djnish', '--submitted', '-L', 10, '--time', 'month'],
-    ['--user', 'djnish', '--submitted', '-L', 10, '--sort', 'controversial'],
-))
+@pytest.mark.skipif(not does_test_config_exist, reason="A test config file is required for integration tests")
+@pytest.mark.parametrize(
+    "test_args",
+    (
+        ["--user", "djnish", "--submitted", "--user", "FriesWithThat", "-L", 10],
+        ["--user", "me", "--upvoted", "--authenticate", "-L", 10],
+        ["--user", "me", "--saved", "--authenticate", "-L", 10],
+        ["--user", "me", "--submitted", "--authenticate", "-L", 10],
+        ["--user", "djnish", "--submitted", "-L", 10],
+        ["--user", "djnish", "--submitted", "-L", 10, "--time", "month"],
+        ["--user", "djnish", "--submitted", "-L", 10, "--sort", "controversial"],
+    ),
+)
 def test_cli_download_user_data_good(test_args: list[str], tmp_path: Path):
     runner = CliRunner()
     test_args = create_basic_args_for_download_runner(test_args, tmp_path)
     result = runner.invoke(cli, test_args)
     assert result.exit_code == 0
-    assert 'Downloaded submission ' in result.output
+    assert "Downloaded submission " in result.output
 @pytest.mark.online
 @pytest.mark.reddit
 @pytest.mark.authenticated
-@pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests')
-@pytest.mark.parametrize('test_args', (
-    ['--user', 'me', '-L', 10, '--folder-scheme', ''],
-))
+@pytest.mark.skipif(not does_test_config_exist, reason="A test config file is required for integration tests")
+@pytest.mark.parametrize("test_args", (["--user", "me", "-L", 10, "--folder-scheme", ""],))
 def test_cli_download_user_data_bad_me_unauthenticated(test_args: list[str], tmp_path: Path):
     runner = CliRunner()
     test_args = create_basic_args_for_download_runner(test_args, tmp_path)
@@ -177,42 +189,41 @@ def test_cli_download_user_data_bad_me_unauthenticated(test_args: list[str], tmp_path: Path):
 @pytest.mark.online
 @pytest.mark.reddit
-@pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests')
-@pytest.mark.parametrize('test_args', (
-    ['--subreddit', 'python', '-L', 1, '--search-existing'],
-))
+@pytest.mark.skipif(not does_test_config_exist, reason="A test config file is required for integration tests")
+@pytest.mark.parametrize("test_args", (["--subreddit", "python", "-L", 1, "--search-existing"],))
 def test_cli_download_search_existing(test_args: list[str], tmp_path: Path):
-    Path(tmp_path, 'test.txt').touch()
+    Path(tmp_path, "test.txt").touch()
     runner = CliRunner()
     test_args = create_basic_args_for_download_runner(test_args, tmp_path)
     result = runner.invoke(cli, test_args)
     assert result.exit_code == 0
-    assert 'Calculating hashes for' in result.output
+    assert "Calculating hashes for" in result.output
 @pytest.mark.online
 @pytest.mark.reddit
-@pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests')
-@pytest.mark.parametrize('test_args', (
-    ['--subreddit', 'tumblr', '-L', '25', '--skip', 'png', '--skip', 'jpg'],
-    ['--subreddit', 'MaliciousCompliance', '-L', '25', '--skip', 'txt'],
-    ['--subreddit', 'tumblr', '-L', '10', '--skip-domain', 'i.redd.it'],
-))
+@pytest.mark.skipif(not does_test_config_exist, reason="A test config file is required for integration tests")
+@pytest.mark.parametrize(
+    "test_args",
+    (
+        ["--subreddit", "tumblr", "-L", "25", "--skip", "png", "--skip", "jpg"],
+        ["--subreddit", "MaliciousCompliance", "-L", "25", "--skip", "txt"],
+        ["--subreddit", "tumblr", "-L", "10", "--skip-domain", "i.redd.it"],
+    ),
+)
 def test_cli_download_download_filters(test_args: list[str], tmp_path: Path):
     runner = CliRunner()
     test_args = create_basic_args_for_download_runner(test_args, tmp_path)
     result = runner.invoke(cli, test_args)
     assert result.exit_code == 0
-    assert any((string in result.output for string in ('Download filter removed ', 'filtered due to URL')))
+    assert any((string in result.output for string in ("Download filter removed ", "filtered due to URL")))
 @pytest.mark.online
 @pytest.mark.reddit
 @pytest.mark.slow
-@pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests')
-@pytest.mark.parametrize('test_args', (
-    ['--subreddit', 'all', '-L', '100', '--sort', 'new'],
-))
+@pytest.mark.skipif(not does_test_config_exist, reason="A test config file is required for integration tests")
+@pytest.mark.parametrize("test_args", (["--subreddit", "all", "-L", "100", "--sort", "new"],))
 def test_cli_download_long(test_args: list[str], tmp_path: Path):
     runner = CliRunner()
     test_args = create_basic_args_for_download_runner(test_args, tmp_path)
@@ -223,34 +234,40 @@ def test_cli_download_long(test_args: list[str], tmp_path: Path):
 @pytest.mark.online
 @pytest.mark.reddit
 @pytest.mark.slow
-@pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests')
-@pytest.mark.parametrize('test_args', (
-    ['--user', 'sdclhgsolgjeroij', '--submitted', '-L', 10],
-    ['--user', 'me', '--upvoted', '-L', 10],
-    ['--user', 'sdclhgsolgjeroij', '--upvoted', '-L', 10],
-    ['--subreddit', 'submitters', '-L', 10],  # Private subreddit
-    ['--subreddit', 'donaldtrump', '-L', 10],  # Banned subreddit
-    ['--user', 'djnish', '--user', 'helen_darten', '-m', 'cuteanimalpics', '-L', 10],
-    ['--subreddit', 'friends', '-L', 10],
-    ['-l', 'ijy4ch'],  # user deleted post
-    ['-l', 'kw4wjm'],  # post from banned subreddit
-))
+@pytest.mark.skipif(not does_test_config_exist, reason="A test config file is required for integration tests")
+@pytest.mark.parametrize(
+    "test_args",
+    (
+        ["--user", "sdclhgsolgjeroij", "--submitted", "-L", 10],
+        ["--user", "me", "--upvoted", "-L", 10],
+        ["--user", "sdclhgsolgjeroij", "--upvoted", "-L", 10],
+        ["--subreddit", "submitters", "-L", 10],  # Private subreddit
+        ["--subreddit", "donaldtrump", "-L", 10],  # Banned subreddit
+        ["--user", "djnish", "--user", "helen_darten", "-m", "cuteanimalpics", "-L", 10],
+        ["--subreddit", "friends", "-L", 10],
+        ["-l", "ijy4ch"],  # user deleted post
+        ["-l", "kw4wjm"],  # post from banned subreddit
+    ),
+)
 def test_cli_download_soft_fail(test_args: list[str], tmp_path: Path):
     runner = CliRunner()
     test_args = create_basic_args_for_download_runner(test_args, tmp_path)
     result = runner.invoke(cli, test_args)
     assert result.exit_code == 0
-    assert 'Downloaded' not in result.output
+    assert "Downloaded" not in result.output
 @pytest.mark.online
 @pytest.mark.reddit
 @pytest.mark.slow
-@pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests')
-@pytest.mark.parametrize('test_args', (
-    ['--time', 'random'],
-    ['--sort', 'random'],
-))
+@pytest.mark.skipif(not does_test_config_exist, reason="A test config file is required for integration tests")
+@pytest.mark.parametrize(
+    "test_args",
+    (
+        ["--time", "random"],
+        ["--sort", "random"],
+    ),
+)
 def test_cli_download_hard_fail(test_args: list[str], tmp_path: Path):
     runner = CliRunner()
     test_args = create_basic_args_for_download_runner(test_args, tmp_path)
@@ -260,114 +277,122 @@ def test_cli_download_hard_fail(test_args: list[str], tmp_path: Path):
 def test_cli_download_use_default_config(tmp_path: Path):
     runner = CliRunner()
-    test_args = ['download', '-vv', str(tmp_path)]
+    test_args = ["download", "-vv", str(tmp_path)]
     result = runner.invoke(cli, test_args)
     assert result.exit_code == 0
 @pytest.mark.online
 @pytest.mark.reddit
-@pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests')
-@pytest.mark.parametrize('test_args', (
-    ['-l', '6l7778', '--exclude-id', '6l7778'],
-))
+@pytest.mark.skipif(not does_test_config_exist, reason="A test config file is required for integration tests")
+@pytest.mark.parametrize("test_args", (["-l", "6l7778", "--exclude-id", "6l7778"],))
 def test_cli_download_links_exclusion(test_args: list[str], tmp_path: Path):
     runner = CliRunner()
     test_args = create_basic_args_for_download_runner(test_args, tmp_path)
     result = runner.invoke(cli, test_args)
     assert result.exit_code == 0
-    assert 'in exclusion list' in result.output
-    assert 'Downloaded submission ' not in result.output
+    assert "in exclusion list" in result.output
+    assert "Downloaded submission " not in result.output
 @pytest.mark.online
 @pytest.mark.reddit
-@pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests')
-@pytest.mark.parametrize('test_args', (
-    ['-l', '6l7778', '--skip-subreddit', 'EmpireDidNothingWrong'],
-    ['-s', 'trollxchromosomes', '--skip-subreddit', 'trollxchromosomes', '-L', '3'],
-))
+@pytest.mark.skipif(not does_test_config_exist, reason="A test config file is required for integration tests")
+@pytest.mark.parametrize(
+    "test_args",
+    (
+        ["-l", "6l7778", "--skip-subreddit", "EmpireDidNothingWrong"],
+        ["-s", "trollxchromosomes", "--skip-subreddit", "trollxchromosomes", "-L", "3"],
+    ),
+)
 def test_cli_download_subreddit_exclusion(test_args: list[str], tmp_path: Path):
     runner = CliRunner()
     test_args = create_basic_args_for_download_runner(test_args, tmp_path)
     result = runner.invoke(cli, test_args)
     assert result.exit_code == 0
-    assert 'in skip list' in result.output
-    assert 'Downloaded submission ' not in result.output
+    assert "in skip list" in result.output
+    assert "Downloaded submission " not in result.output
 @pytest.mark.online
 @pytest.mark.reddit
-@pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests')
-@pytest.mark.parametrize('test_args', (
-    ['--file-scheme', '{TITLE}'],
-    ['--file-scheme', '{TITLE}_test_{SUBREDDIT}'],
-))
+@pytest.mark.skipif(not does_test_config_exist, reason="A test config file is required for integration tests")
+@pytest.mark.parametrize(
+    "test_args",
+    (
+        ["--file-scheme", "{TITLE}"],
+        ["--file-scheme", "{TITLE}_test_{SUBREDDIT}"],
+    ),
+)
 def test_cli_download_file_scheme_warning(test_args: list[str], tmp_path: Path):
     runner = CliRunner()
     test_args = create_basic_args_for_download_runner(test_args, tmp_path)
     result = runner.invoke(cli, test_args)
     assert result.exit_code == 0
-    assert 'Some files might not be downloaded due to name conflicts' in result.output
+    assert "Some files might not be downloaded due to name conflicts" in result.output
 @pytest.mark.online
 @pytest.mark.reddit
-@pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests')
-@pytest.mark.parametrize('test_args', (
-    ['-l', 'n9w9fo', '--disable-module', 'SelfPost'],
-    ['-l', 'nnb9vs', '--disable-module', 'VReddit'],
-))
+@pytest.mark.skipif(not does_test_config_exist, reason="A test config file is required for integration tests")
+@pytest.mark.parametrize(
+    "test_args",
+    (
+        ["-l", "n9w9fo", "--disable-module", "SelfPost"],
+        ["-l", "nnb9vs", "--disable-module", "VReddit"],
+    ),
+)
 def test_cli_download_disable_modules(test_args: list[str], tmp_path: Path):
     runner = CliRunner()
     test_args = create_basic_args_for_download_runner(test_args, tmp_path)
     result = runner.invoke(cli, test_args)
     assert result.exit_code == 0
-    assert 'skipped due to disabled module' in result.output
-    assert 'Downloaded submission' not in result.output
+    assert "skipped due to disabled module" in result.output
+    assert "Downloaded submission" not in result.output
 @pytest.mark.online
 @pytest.mark.reddit
-@pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests')
+@pytest.mark.skipif(not does_test_config_exist, reason="A test config file is required for integration tests")
 def test_cli_download_include_id_file(tmp_path: Path):
-    test_file = Path(tmp_path, 'include.txt')
-    test_args = ['--include-id-file', str(test_file)]
-    test_file.write_text('odr9wg\nody576')
+    test_file = Path(tmp_path, "include.txt")
+    test_args = ["--include-id-file", str(test_file)]
+    test_file.write_text("odr9wg\nody576")
     runner = CliRunner()
     test_args = create_basic_args_for_download_runner(test_args, tmp_path)
     result = runner.invoke(cli, test_args)
     assert result.exit_code == 0
-    assert 'Downloaded submission' in result.output
+    assert "Downloaded submission" in result.output
 @pytest.mark.online
 @pytest.mark.reddit
-@pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests')
-@pytest.mark.parametrize('test_args', (
-    ['--ignore-user', 'ArjanEgges', '-l', 'm3hxzd'],
-))
+@pytest.mark.skipif(not does_test_config_exist, reason="A test config file is required for integration tests")
+@pytest.mark.parametrize("test_args", (["--ignore-user", "ArjanEgges", "-l", "m3hxzd"],))
 def test_cli_download_ignore_user(test_args: list[str], tmp_path: Path):
     runner = CliRunner()
     test_args = create_basic_args_for_download_runner(test_args, tmp_path)
     result = runner.invoke(cli, test_args)
     assert result.exit_code == 0
-    assert 'Downloaded submission' not in result.output
-    assert 'being an ignored user' in result.output
+    assert "Downloaded submission" not in result.output
+    assert "being an ignored user" in result.output
 @pytest.mark.online
 @pytest.mark.reddit
-@pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests')
-@pytest.mark.parametrize(('test_args', 'was_filtered'), (
-    (['-l', 'ljyy27', '--min-score', '50'], True),
-    (['-l', 'ljyy27', '--min-score', '1'], False),
-    (['-l', 'ljyy27', '--max-score', '1'], True),
-    (['-l', 'ljyy27', '--max-score', '100'], False),
-))
+@pytest.mark.skipif(not does_test_config_exist, reason="A test config file is required for integration tests")
+@pytest.mark.parametrize(
+    ("test_args", "was_filtered"),
+    (
+        (["-l", "ljyy27", "--min-score", "50"], True),
+        (["-l", "ljyy27", "--min-score", "1"], False),
+        (["-l", "ljyy27", "--max-score", "1"], True),
+        (["-l", "ljyy27", "--max-score", "100"], False),
+    ),
+)
 def test_cli_download_score_filter(test_args: list[str], was_filtered: bool, tmp_path: Path):
     runner = CliRunner()
     test_args = create_basic_args_for_download_runner(test_args, tmp_path)
     result = runner.invoke(cli, test_args)
     assert result.exit_code == 0
-    assert ('filtered due to score' in result.output) == was_filtered
+    assert ("filtered due to score" in result.output) == was_filtered

View file

@@ -10,22 +10,23 @@ from bdfr.site_downloaders.fallback_downloaders.ytdlp_fallback import YtdlpFallback
 @pytest.mark.online
-@pytest.mark.parametrize(('test_url', 'expected'), (
-    ('https://www.reddit.com/r/specializedtools/comments/n2nw5m/bamboo_splitter/', True),
-    ('https://www.youtube.com/watch?v=P19nvJOmqCc', True),
-    ('https://www.example.com/test', False),
-    ('https://milesmatrix.bandcamp.com/album/la-boum/', False),
-    ('https://v.redd.it/dlr54z8p182a1', True),
-))
+@pytest.mark.parametrize(
+    ("test_url", "expected"),
+    (
+        ("https://www.reddit.com/r/specializedtools/comments/n2nw5m/bamboo_splitter/", True),
+        ("https://www.youtube.com/watch?v=P19nvJOmqCc", True),
+        ("https://www.example.com/test", False),
+        ("https://milesmatrix.bandcamp.com/album/la-boum/", False),
+        ("https://v.redd.it/dlr54z8p182a1", True),
+    ),
+)
 def test_can_handle_link(test_url: str, expected: bool):
     result = YtdlpFallback.can_handle_link(test_url)
     assert result == expected
 @pytest.mark.online
-@pytest.mark.parametrize('test_url', (
-    'https://milesmatrix.bandcamp.com/album/la-boum/',
-))
+@pytest.mark.parametrize("test_url", ("https://milesmatrix.bandcamp.com/album/la-boum/",))
 def test_info_extraction_bad(test_url: str):
     with pytest.raises(NotADownloadableLinkError):
         YtdlpFallback.get_video_attributes(test_url)
@@ -33,12 +34,18 @@ def test_info_extraction_bad(test_url: str):
 @pytest.mark.online
 @pytest.mark.slow
-@pytest.mark.parametrize(('test_url', 'expected_hash'), (
-    ('https://streamable.com/dt46y', 'b7e465adaade5f2b6d8c2b4b7d0a2878'),
-    ('https://streamable.com/t8sem', '49b2d1220c485455548f1edbc05d4ecf'),
-    ('https://www.reddit.com/r/specializedtools/comments/n2nw5m/bamboo_splitter/', '6c6ff46e04b4e33a755ae2a9b5a45ac5'),
-    ('https://v.redd.it/9z1dnk3xr5k61', '226cee353421c7aefb05c92424cc8cdd'),
-))
+@pytest.mark.parametrize(
+    ("test_url", "expected_hash"),
+    (
+        ("https://streamable.com/dt46y", "b7e465adaade5f2b6d8c2b4b7d0a2878"),
+        ("https://streamable.com/t8sem", "49b2d1220c485455548f1edbc05d4ecf"),
+        (
+            "https://www.reddit.com/r/specializedtools/comments/n2nw5m/bamboo_splitter/",
+            "6c6ff46e04b4e33a755ae2a9b5a45ac5",
+        ),
+        ("https://v.redd.it/9z1dnk3xr5k61", "226cee353421c7aefb05c92424cc8cdd"),
+    ),
+)
 def test_find_resources(test_url: str, expected_hash: str):
     test_submission = MagicMock()
     test_submission.url = test_url

View file

@@ -10,10 +10,13 @@ from bdfr.site_downloaders.delay_for_reddit import DelayForReddit
 @pytest.mark.online
-@pytest.mark.parametrize(('test_url', 'expected_hash'), (
-    ('https://www.delayforreddit.com/dfr/calvin6123/MjU1Njc5NQ==', '3300f28c2f9358d05667985c9c04210d'),
-    ('https://www.delayforreddit.com/dfr/RoXs_26/NDAwMzAyOQ==', '09b7b01719dff45ab197bdc08b90f78a'),
-))
+@pytest.mark.parametrize(
+    ("test_url", "expected_hash"),
+    (
+        ("https://www.delayforreddit.com/dfr/calvin6123/MjU1Njc5NQ==", "3300f28c2f9358d05667985c9c04210d"),
+        ("https://www.delayforreddit.com/dfr/RoXs_26/NDAwMzAyOQ==", "09b7b01719dff45ab197bdc08b90f78a"),
+    ),
+)
 def test_download_resource(test_url: str, expected_hash: str):
     mock_submission = Mock()
     mock_submission.url = test_url

View file

@@ -10,10 +10,13 @@ from bdfr.site_downloaders.direct import Direct
 @pytest.mark.online
-@pytest.mark.parametrize(('test_url', 'expected_hash'), (
-    ('https://giant.gfycat.com/DefinitiveCanineCrayfish.mp4', '48f9bd4dbec1556d7838885612b13b39'),
-    ('https://giant.gfycat.com/DazzlingSilkyIguana.mp4', '808941b48fc1e28713d36dd7ed9dc648'),
-))
+@pytest.mark.parametrize(
+    ("test_url", "expected_hash"),
+    (
+        ("https://giant.gfycat.com/DefinitiveCanineCrayfish.mp4", "48f9bd4dbec1556d7838885612b13b39"),
+        ("https://giant.gfycat.com/DazzlingSilkyIguana.mp4", "808941b48fc1e28713d36dd7ed9dc648"),
+    ),
+)
 def test_download_resource(test_url: str, expected_hash: str):
     mock_submission = Mock()
     mock_submission.url = test_url

View file

@@ -21,67 +21,82 @@ from bdfr.site_downloaders.youtube import Youtube
 @pytest.mark.online
-@pytest.mark.parametrize(('test_submission_url', 'expected_class'), (
-    ('https://www.reddit.com/r/TwoXChromosomes/comments/lu29zn/i_refuse_to_live_my_life'
-     '_in_anything_but_comfort/', SelfPost),
-    ('https://i.redd.it/affyv0axd5k61.png', Direct),
-    ('https://i.imgur.com/bZx1SJQ.jpg', Imgur),
-    ('https://imgur.com/BuzvZwb.gifv', Imgur),
-    ('https://imgur.com/a/MkxAzeg', Imgur),
-    ('https://m.imgur.com/a/py3RW0j', Imgur),
-    ('https://www.reddit.com/gallery/lu93m7', Gallery),
-    ('https://gfycat.com/concretecheerfulfinwhale', Gfycat),
-    ('https://www.erome.com/a/NWGw0F09', Erome),
-    ('https://youtube.com/watch?v=Gv8Wz74FjVA', Youtube),
-    ('https://redgifs.com/watch/courageousimpeccablecanvasback', Redgifs),
-    ('https://www.gifdeliverynetwork.com/repulsivefinishedandalusianhorse', Redgifs),
-    ('https://youtu.be/DevfjHOhuFc', Youtube),
-    ('https://m.youtube.com/watch?v=kr-FeojxzUM', Youtube),
-    ('https://dynasty-scans.com/system/images_images/000/017/819/original/80215103_p0.png?1612232781', Direct),
-    ('https://v.redd.it/9z1dnk3xr5k61', VReddit),
-    ('https://streamable.com/dt46y', YtdlpFallback),
-    ('https://vimeo.com/channels/31259/53576664', YtdlpFallback),
-    ('http://video.pbs.org/viralplayer/2365173446/', YtdlpFallback),
-    ('https://www.pornhub.com/view_video.php?viewkey=ph5a2ee0461a8d0', PornHub),
-    ('https://www.patreon.com/posts/minecart-track-59346560', Gallery),
-))
+@pytest.mark.parametrize(
+    ("test_submission_url", "expected_class"),
+    (
+        (
+            "https://www.reddit.com/r/TwoXChromosomes/comments/lu29zn/i_refuse_to_live_my_life"
+            "_in_anything_but_comfort/",
+            SelfPost,
+        ),
+        ("https://i.redd.it/affyv0axd5k61.png", Direct),
+        ("https://i.imgur.com/bZx1SJQ.jpg", Imgur),
+        ("https://imgur.com/BuzvZwb.gifv", Imgur),
+        ("https://imgur.com/a/MkxAzeg", Imgur),
+        ("https://m.imgur.com/a/py3RW0j", Imgur),
+        ("https://www.reddit.com/gallery/lu93m7", Gallery),
+        ("https://gfycat.com/concretecheerfulfinwhale", Gfycat),
+        ("https://www.erome.com/a/NWGw0F09", Erome),
+        ("https://youtube.com/watch?v=Gv8Wz74FjVA", Youtube),
+        ("https://redgifs.com/watch/courageousimpeccablecanvasback", Redgifs),
+        ("https://www.gifdeliverynetwork.com/repulsivefinishedandalusianhorse", Redgifs),
+        ("https://youtu.be/DevfjHOhuFc", Youtube),
+        ("https://m.youtube.com/watch?v=kr-FeojxzUM", Youtube),
+        ("https://dynasty-scans.com/system/images_images/000/017/819/original/80215103_p0.png?1612232781", Direct),
+        ("https://v.redd.it/9z1dnk3xr5k61", VReddit),
+        ("https://streamable.com/dt46y", YtdlpFallback),
+        ("https://vimeo.com/channels/31259/53576664", YtdlpFallback),
+        ("http://video.pbs.org/viralplayer/2365173446/", YtdlpFallback),
+        ("https://www.pornhub.com/view_video.php?viewkey=ph5a2ee0461a8d0", PornHub),
+        ("https://www.patreon.com/posts/minecart-track-59346560", Gallery),
+    ),
+)
 def test_factory_lever_good(test_submission_url: str, expected_class: BaseDownloader, reddit_instance: praw.Reddit):
     result = DownloadFactory.pull_lever(test_submission_url)
     assert result is expected_class
-@pytest.mark.parametrize('test_url', (
-    'random.com',
-    'bad',
-    'https://www.google.com/',
-    'https://www.google.com',
-    'https://www.google.com/test',
-    'https://www.google.com/test/',
-))
+@pytest.mark.parametrize(
+    "test_url",
+    (
+        "random.com",
+        "bad",
+        "https://www.google.com/",
+        "https://www.google.com",
+        "https://www.google.com/test",
+        "https://www.google.com/test/",
+    ),
+)
 def test_factory_lever_bad(test_url: str):
     with pytest.raises(NotADownloadableLinkError):
         DownloadFactory.pull_lever(test_url)
-@pytest.mark.parametrize(('test_url', 'expected'), (
-    ('www.test.com/test.png', 'test.com/test.png'),
-    ('www.test.com/test.png?test_value=random', 'test.com/test.png'),
-    ('https://youtube.com/watch?v=Gv8Wz74FjVA', 'youtube.com/watch'),
-    ('https://i.imgur.com/BuzvZwb.gifv', 'i.imgur.com/BuzvZwb.gifv'),
-))
+@pytest.mark.parametrize(
+    ("test_url", "expected"),
+    (
+        ("www.test.com/test.png", "test.com/test.png"),
+        ("www.test.com/test.png?test_value=random", "test.com/test.png"),
+        ("https://youtube.com/watch?v=Gv8Wz74FjVA", "youtube.com/watch"),
+        ("https://i.imgur.com/BuzvZwb.gifv", "i.imgur.com/BuzvZwb.gifv"),
+    ),
+)
 def test_sanitise_url(test_url: str, expected: str):
     result = DownloadFactory.sanitise_url(test_url)
     assert result == expected
-@pytest.mark.parametrize(('test_url', 'expected'), (
-    ('www.example.com/test.asp', True),
-    ('www.example.com/test.html', True),
-    ('www.example.com/test.js', True),
-    ('www.example.com/test.xhtml', True),
-    ('www.example.com/test.mp4', False),
-    ('www.example.com/test.png', False),
-))
+@pytest.mark.parametrize(
+    ("test_url", "expected"),
+    (
+        ("www.example.com/test.asp", True),
+        ("www.example.com/test.html", True),
+        ("www.example.com/test.js", True),
+        ("www.example.com/test.xhtml", True),
+        ("www.example.com/test.mp4", False),
+        ("www.example.com/test.png", False),
+    ),
+)
 def test_is_web_resource(test_url: str, expected: bool):
     result = DownloadFactory.is_web_resource(test_url)
     assert result == expected

View file

@ -9,31 +9,38 @@ from bdfr.site_downloaders.erome import Erome
@pytest.mark.online @pytest.mark.online
@pytest.mark.parametrize(('test_url', 'expected_urls'), ( @pytest.mark.parametrize(
('https://www.erome.com/a/vqtPuLXh', ( ("test_url", "expected_urls"),
r'https://[a-z]\d+.erome.com/\d{3}/vqtPuLXh/KH2qBT99_480p.mp4', (
)), ("https://www.erome.com/a/vqtPuLXh", (r"https://[a-z]\d+.erome.com/\d{3}/vqtPuLXh/KH2qBT99_480p.mp4",)),
('https://www.erome.com/a/ORhX0FZz', ( (
r'https://[a-z]\d+.erome.com/\d{3}/ORhX0FZz/9IYQocM9_480p.mp4', "https://www.erome.com/a/ORhX0FZz",
r'https://[a-z]\d+.erome.com/\d{3}/ORhX0FZz/9eEDc8xm_480p.mp4', (
r'https://[a-z]\d+.erome.com/\d{3}/ORhX0FZz/EvApC7Rp_480p.mp4', r"https://[a-z]\d+.erome.com/\d{3}/ORhX0FZz/9IYQocM9_480p.mp4",
r'https://[a-z]\d+.erome.com/\d{3}/ORhX0FZz/LruobtMs_480p.mp4', r"https://[a-z]\d+.erome.com/\d{3}/ORhX0FZz/9eEDc8xm_480p.mp4",
r'https://[a-z]\d+.erome.com/\d{3}/ORhX0FZz/TJNmSUU5_480p.mp4', r"https://[a-z]\d+.erome.com/\d{3}/ORhX0FZz/EvApC7Rp_480p.mp4",
r'https://[a-z]\d+.erome.com/\d{3}/ORhX0FZz/X11Skh6Z_480p.mp4', r"https://[a-z]\d+.erome.com/\d{3}/ORhX0FZz/LruobtMs_480p.mp4",
r'https://[a-z]\d+.erome.com/\d{3}/ORhX0FZz/bjlTkpn7_480p.mp4' r"https://[a-z]\d+.erome.com/\d{3}/ORhX0FZz/TJNmSUU5_480p.mp4",
)), r"https://[a-z]\d+.erome.com/\d{3}/ORhX0FZz/X11Skh6Z_480p.mp4",
)) r"https://[a-z]\d+.erome.com/\d{3}/ORhX0FZz/bjlTkpn7_480p.mp4",
),
),
),
)
def test_get_link(test_url: str, expected_urls: tuple[str]): def test_get_link(test_url: str, expected_urls: tuple[str]):
result = Erome. _get_links(test_url) result = Erome._get_links(test_url)
assert all([any([re.match(p, r) for r in result]) for p in expected_urls]) assert all([any([re.match(p, r) for r in result]) for p in expected_urls])
@pytest.mark.online @pytest.mark.online
@pytest.mark.slow @pytest.mark.slow
@pytest.mark.parametrize(('test_url', 'expected_hashes_len'), ( @pytest.mark.parametrize(
('https://www.erome.com/a/vqtPuLXh', 1), ("test_url", "expected_hashes_len"),
('https://www.erome.com/a/4tP3KI6F', 1), (
)) ("https://www.erome.com/a/vqtPuLXh", 1),
("https://www.erome.com/a/4tP3KI6F", 1),
),
)
def test_download_resource(test_url: str, expected_hashes_len: int): def test_download_resource(test_url: str, expected_hashes_len: int):
# Can't compare hashes for this test, Erome doesn't return the exact same file from request to request so the hash # Can't compare hashes for this test, Erome doesn't return the exact same file from request to request so the hash
# will change back and forth randomly # will change back and forth randomly
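Note: the hunk ends before the body of test_download_resource, so only its comment is visible here. A hypothetical sketch of a count-based body, since hashes cannot be pinned down for Erome; the Erome constructor and find_resources() are assumed to behave as in the sibling downloader tests, not confirmed by this diff:

from unittest.mock import Mock

from bdfr.site_downloaders.erome import Erome

def test_download_resource_sketch(test_url: str, expected_hashes_len: int):
    # hypothetical body, not the commit's: count the resources found
    # instead of hashing their non-deterministic bytes
    mock_submission = Mock()
    mock_submission.url = test_url
    resources = Erome(mock_submission).find_resources()  # API assumed from the other tests in this suite
    assert len(resources) == expected_hashes_len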
@@ -9,30 +9,39 @@ from bdfr.site_downloaders.gallery import Gallery
@pytest.mark.online @pytest.mark.online
@pytest.mark.parametrize(('test_ids', 'expected'), ( @pytest.mark.parametrize(
([ ("test_ids", "expected"),
{'media_id': '18nzv9ch0hn61'}, (
{'media_id': 'jqkizcch0hn61'}, (
{'media_id': 'k0fnqzbh0hn61'}, [
{'media_id': 'm3gamzbh0hn61'}, {"media_id": "18nzv9ch0hn61"},
], { {"media_id": "jqkizcch0hn61"},
'https://i.redd.it/18nzv9ch0hn61.jpg', {"media_id": "k0fnqzbh0hn61"},
'https://i.redd.it/jqkizcch0hn61.jpg', {"media_id": "m3gamzbh0hn61"},
'https://i.redd.it/k0fnqzbh0hn61.jpg', ],
'https://i.redd.it/m3gamzbh0hn61.jpg' {
}), "https://i.redd.it/18nzv9ch0hn61.jpg",
([ "https://i.redd.it/jqkizcch0hn61.jpg",
{'media_id': '04vxj25uqih61'}, "https://i.redd.it/k0fnqzbh0hn61.jpg",
{'media_id': '0fnx83kpqih61'}, "https://i.redd.it/m3gamzbh0hn61.jpg",
{'media_id': '7zkmr1wqqih61'}, },
{'media_id': 'u37k5gxrqih61'}, ),
], { (
'https://i.redd.it/04vxj25uqih61.png', [
'https://i.redd.it/0fnx83kpqih61.png', {"media_id": "04vxj25uqih61"},
'https://i.redd.it/7zkmr1wqqih61.png', {"media_id": "0fnx83kpqih61"},
'https://i.redd.it/u37k5gxrqih61.png' {"media_id": "7zkmr1wqqih61"},
}), {"media_id": "u37k5gxrqih61"},
)) ],
{
"https://i.redd.it/04vxj25uqih61.png",
"https://i.redd.it/0fnx83kpqih61.png",
"https://i.redd.it/7zkmr1wqqih61.png",
"https://i.redd.it/u37k5gxrqih61.png",
},
),
),
)
def test_gallery_get_links(test_ids: list[dict], expected: set[str]): def test_gallery_get_links(test_ids: list[dict], expected: set[str]):
results = Gallery._get_links(test_ids) results = Gallery._get_links(test_ids)
assert set(results) == expected assert set(results) == expected
@@ -40,32 +49,47 @@ def test_gallery_get_links(test_ids: list[dict], expected: set[str]):
@pytest.mark.online @pytest.mark.online
@pytest.mark.reddit @pytest.mark.reddit
@pytest.mark.parametrize(('test_submission_id', 'expected_hashes'), ( @pytest.mark.parametrize(
('m6lvrh', { ("test_submission_id", "expected_hashes"),
'5c42b8341dd56eebef792e86f3981c6a', (
'8f38d76da46f4057bf2773a778e725ca', (
'f5776f8f90491c8b770b8e0a6bfa49b3', "m6lvrh",
'fa1a43c94da30026ad19a9813a0ed2c2', {
}), "5c42b8341dd56eebef792e86f3981c6a",
('ljyy27', { "8f38d76da46f4057bf2773a778e725ca",
'359c203ec81d0bc00e675f1023673238', "f5776f8f90491c8b770b8e0a6bfa49b3",
'79262fd46bce5bfa550d878a3b898be4', "fa1a43c94da30026ad19a9813a0ed2c2",
'808c35267f44acb523ce03bfa5687404', },
'ec8b65bdb7f1279c4b3af0ea2bbb30c3', ),
}), (
('obkflw', { "ljyy27",
'65163f685fb28c5b776e0e77122718be', {
'2a337eb5b13c34d3ca3f51b5db7c13e9', "359c203ec81d0bc00e675f1023673238",
}), "79262fd46bce5bfa550d878a3b898be4",
('rb3ub6', { # patreon post "808c35267f44acb523ce03bfa5687404",
'748a976c6cedf7ea85b6f90e7cb685c7', "ec8b65bdb7f1279c4b3af0ea2bbb30c3",
'839796d7745e88ced6355504e1f74508', },
'bcdb740367d0f19f97a77e614b48a42d', ),
'0f230b8c4e5d103d35a773fab9814ec3', (
'e5192d6cb4f84c4f4a658355310bf0f9', "obkflw",
'91cbe172cd8ccbcf049fcea4204eb979', {
}) "65163f685fb28c5b776e0e77122718be",
)) "2a337eb5b13c34d3ca3f51b5db7c13e9",
},
),
(
"rb3ub6",
{ # patreon post
"748a976c6cedf7ea85b6f90e7cb685c7",
"839796d7745e88ced6355504e1f74508",
"bcdb740367d0f19f97a77e614b48a42d",
"0f230b8c4e5d103d35a773fab9814ec3",
"e5192d6cb4f84c4f4a658355310bf0f9",
"91cbe172cd8ccbcf049fcea4204eb979",
},
),
),
)
def test_gallery_download(test_submission_id: str, expected_hashes: set[str], reddit_instance: praw.Reddit): def test_gallery_download(test_submission_id: str, expected_hashes: set[str], reddit_instance: praw.Reddit):
test_submission = reddit_instance.submission(id=test_submission_id) test_submission = reddit_instance.submission(id=test_submission_id)
gallery = Gallery(test_submission) gallery = Gallery(test_submission)
@@ -75,10 +99,13 @@ def test_gallery_download(test_submission_id: str, expected_hashes: set[str], re
assert set(hashes) == expected_hashes assert set(hashes) == expected_hashes
@pytest.mark.parametrize('test_id', ( @pytest.mark.parametrize(
'n0pyzp', "test_id",
'nxyahw', (
)) "n0pyzp",
"nxyahw",
),
)
def test_gallery_download_raises_right_error(test_id: str, reddit_instance: praw.Reddit): def test_gallery_download_raises_right_error(test_id: str, reddit_instance: praw.Reddit):
test_submission = reddit_instance.submission(id=test_id) test_submission = reddit_instance.submission(id=test_id)
gallery = Gallery(test_submission) gallery = Gallery(test_submission)
@@ -10,20 +10,26 @@ from bdfr.site_downloaders.gfycat import Gfycat
@pytest.mark.online @pytest.mark.online
@pytest.mark.parametrize(('test_url', 'expected_url'), ( @pytest.mark.parametrize(
('https://gfycat.com/definitivecaninecrayfish', 'https://giant.gfycat.com/DefinitiveCanineCrayfish.mp4'), ("test_url", "expected_url"),
('https://gfycat.com/dazzlingsilkyiguana', 'https://giant.gfycat.com/DazzlingSilkyIguana.mp4'), (
)) ("https://gfycat.com/definitivecaninecrayfish", "https://giant.gfycat.com/DefinitiveCanineCrayfish.mp4"),
("https://gfycat.com/dazzlingsilkyiguana", "https://giant.gfycat.com/DazzlingSilkyIguana.mp4"),
),
)
def test_get_link(test_url: str, expected_url: str): def test_get_link(test_url: str, expected_url: str):
result = Gfycat._get_link(test_url) result = Gfycat._get_link(test_url)
assert result.pop() == expected_url assert result.pop() == expected_url
@pytest.mark.online @pytest.mark.online
@pytest.mark.parametrize(('test_url', 'expected_hash'), ( @pytest.mark.parametrize(
('https://gfycat.com/definitivecaninecrayfish', '48f9bd4dbec1556d7838885612b13b39'), ("test_url", "expected_hash"),
('https://gfycat.com/dazzlingsilkyiguana', '808941b48fc1e28713d36dd7ed9dc648'), (
)) ("https://gfycat.com/definitivecaninecrayfish", "48f9bd4dbec1556d7838885612b13b39"),
("https://gfycat.com/dazzlingsilkyiguana", "808941b48fc1e28713d36dd7ed9dc648"),
),
)
def test_download_resource(test_url: str, expected_hash: str): def test_download_resource(test_url: str, expected_hash: str):
mock_submission = Mock() mock_submission = Mock()
mock_submission.url = test_url mock_submission.url = test_url
@@ -11,166 +11,167 @@ from bdfr.site_downloaders.imgur import Imgur
@pytest.mark.online @pytest.mark.online
@pytest.mark.parametrize(('test_url', 'expected_gen_dict', 'expected_image_dict'), ( @pytest.mark.parametrize(
("test_url", "expected_gen_dict", "expected_image_dict"),
( (
'https://imgur.com/a/xWZsDDP', (
{'num_images': '1', 'id': 'xWZsDDP', 'hash': 'xWZsDDP'}, "https://imgur.com/a/xWZsDDP",
[ {"num_images": "1", "id": "xWZsDDP", "hash": "xWZsDDP"},
{'hash': 'ypa8YfS', 'title': '', 'ext': '.png', 'animated': False} [{"hash": "ypa8YfS", "title": "", "ext": ".png", "animated": False}],
] ),
(
"https://imgur.com/gallery/IjJJdlC",
{"num_images": 1, "id": 384898055, "hash": "IjJJdlC"},
[
{
"hash": "CbbScDt",
"description": "watch when he gets it",
"ext": ".gif",
"animated": True,
"has_sound": False,
}
],
),
(
"https://imgur.com/a/dcc84Gt",
{"num_images": "4", "id": "dcc84Gt", "hash": "dcc84Gt"},
[
{"hash": "ylx0Kle", "ext": ".jpg", "title": ""},
{"hash": "TdYfKbK", "ext": ".jpg", "title": ""},
{"hash": "pCxGbe8", "ext": ".jpg", "title": ""},
{"hash": "TSAkikk", "ext": ".jpg", "title": ""},
],
),
(
"https://m.imgur.com/a/py3RW0j",
{
"num_images": "1",
"id": "py3RW0j",
"hash": "py3RW0j",
},
[{"hash": "K24eQmK", "has_sound": False, "ext": ".jpg"}],
),
), ),
( )
'https://imgur.com/gallery/IjJJdlC',
{'num_images': 1, 'id': 384898055, 'hash': 'IjJJdlC'},
[
{'hash': 'CbbScDt',
'description': 'watch when he gets it',
'ext': '.gif',
'animated': True,
'has_sound': False
}
],
),
(
'https://imgur.com/a/dcc84Gt',
{'num_images': '4', 'id': 'dcc84Gt', 'hash': 'dcc84Gt'},
[
{'hash': 'ylx0Kle', 'ext': '.jpg', 'title': ''},
{'hash': 'TdYfKbK', 'ext': '.jpg', 'title': ''},
{'hash': 'pCxGbe8', 'ext': '.jpg', 'title': ''},
{'hash': 'TSAkikk', 'ext': '.jpg', 'title': ''},
]
),
(
'https://m.imgur.com/a/py3RW0j',
{'num_images': '1', 'id': 'py3RW0j', 'hash': 'py3RW0j', },
[
{'hash': 'K24eQmK', 'has_sound': False, 'ext': '.jpg'}
],
),
))
def test_get_data_album(test_url: str, expected_gen_dict: dict, expected_image_dict: list[dict]): def test_get_data_album(test_url: str, expected_gen_dict: dict, expected_image_dict: list[dict]):
result = Imgur._get_data(test_url) result = Imgur._get_data(test_url)
assert all([result.get(key) == expected_gen_dict[key] for key in expected_gen_dict.keys()]) assert all([result.get(key) == expected_gen_dict[key] for key in expected_gen_dict.keys()])
# Check if all the keys from the test dict are correct in at least one of the album entries # Check if all the keys from the test dict are correct in at least one of the album entries
assert any([all([image.get(key) == image_dict[key] for key in image_dict.keys()]) assert any(
for image_dict in expected_image_dict for image in result['album_images']['images']]) [
all([image.get(key) == image_dict[key] for key in image_dict.keys()])
for image_dict in expected_image_dict
for image in result["album_images"]["images"]
]
)
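Note: the reflowed any()/all() comprehension above remains dense. As written it passes when some expected image dict matches some album image. An equivalent spelled-out form, runnable with the xWZsDDP values from the table above (illustrative only, not part of the commit):

def matches(image: dict, expected: dict) -> bool:
    # every key the expected dict pins down must agree with the album image
    return all(image.get(k) == v for k, v in expected.items())

album_images = [{"hash": "ypa8YfS", "title": "", "ext": ".png", "animated": False}]
expected_image_dict = [{"hash": "ypa8YfS", "title": "", "ext": ".png", "animated": False}]
# same shape as the assertion in the test: does any expected dict match any image?
assert any(matches(img, exp) for exp in expected_image_dict for img in album_images)

A stricter all(any(...)) nesting would require every expected dict to match some image, which is arguably what the test's comment intends.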
@pytest.mark.online @pytest.mark.online
@pytest.mark.parametrize(('test_url', 'expected_image_dict'), ( @pytest.mark.parametrize(
("test_url", "expected_image_dict"),
( (
'https://i.imgur.com/dLk3FGY.gifv', ("https://i.imgur.com/dLk3FGY.gifv", {"hash": "dLk3FGY", "title": "", "ext": ".mp4", "animated": True}),
{'hash': 'dLk3FGY', 'title': '', 'ext': '.mp4', 'animated': True} (
"https://imgur.com/65FqTpT.gifv",
{"hash": "65FqTpT", "title": "", "description": "", "animated": True, "mimetype": "video/mp4"},
),
), ),
( )
'https://imgur.com/65FqTpT.gifv',
{
'hash': '65FqTpT',
'title': '',
'description': '',
'animated': True,
'mimetype': 'video/mp4'
},
),
))
def test_get_data_gif(test_url: str, expected_image_dict: dict): def test_get_data_gif(test_url: str, expected_image_dict: dict):
result = Imgur._get_data(test_url) result = Imgur._get_data(test_url)
assert all([result.get(key) == expected_image_dict[key] for key in expected_image_dict.keys()]) assert all([result.get(key) == expected_image_dict[key] for key in expected_image_dict.keys()])
@pytest.mark.parametrize('test_extension', ( @pytest.mark.parametrize("test_extension", (".gif", ".png", ".jpg", ".mp4"))
'.gif',
'.png',
'.jpg',
'.mp4'
))
def test_imgur_extension_validation_good(test_extension: str): def test_imgur_extension_validation_good(test_extension: str):
result = Imgur._validate_extension(test_extension) result = Imgur._validate_extension(test_extension)
assert result == test_extension assert result == test_extension
@pytest.mark.parametrize('test_extension', ( @pytest.mark.parametrize(
'.jpeg', "test_extension",
'bad', (
'.avi', ".jpeg",
'.test', "bad",
'.flac', ".avi",
)) ".test",
".flac",
),
)
def test_imgur_extension_validation_bad(test_extension: str): def test_imgur_extension_validation_bad(test_extension: str):
with pytest.raises(SiteDownloaderError): with pytest.raises(SiteDownloaderError):
Imgur._validate_extension(test_extension) Imgur._validate_extension(test_extension)
@pytest.mark.online @pytest.mark.online
@pytest.mark.parametrize(('test_url', 'expected_hashes'), ( @pytest.mark.parametrize(
("test_url", "expected_hashes"),
( (
'https://imgur.com/a/xWZsDDP', ("https://imgur.com/a/xWZsDDP", ("f551d6e6b0fef2ce909767338612e31b",)),
('f551d6e6b0fef2ce909767338612e31b',)
),
(
'https://imgur.com/gallery/IjJJdlC',
('740b006cf9ec9d6f734b6e8f5130bdab',),
),
(
'https://imgur.com/a/dcc84Gt',
( (
'cf1158e1de5c3c8993461383b96610cf', "https://imgur.com/gallery/IjJJdlC",
'28d6b791a2daef8aa363bf5a3198535d', ("740b006cf9ec9d6f734b6e8f5130bdab",),
'248ef8f2a6d03eeb2a80d0123dbaf9b6', ),
'029c475ce01b58fdf1269d8771d33913', (
"https://imgur.com/a/dcc84Gt",
(
"cf1158e1de5c3c8993461383b96610cf",
"28d6b791a2daef8aa363bf5a3198535d",
"248ef8f2a6d03eeb2a80d0123dbaf9b6",
"029c475ce01b58fdf1269d8771d33913",
),
),
(
"https://imgur.com/a/eemHCCK",
(
"9cb757fd8f055e7ef7aa88addc9d9fa5",
"b6cb6c918e2544e96fb7c07d828774b5",
"fb6c913d721c0bbb96aa65d7f560d385",
),
),
(
"https://i.imgur.com/lFJai6i.gifv",
("01a6e79a30bec0e644e5da12365d5071",),
),
(
"https://i.imgur.com/ywSyILa.gifv?",
("56d4afc32d2966017c38d98568709b45",),
),
(
"https://imgur.com/ubYwpbk.GIFV",
("d4a774aac1667783f9ed3a1bd02fac0c",),
),
(
"https://i.imgur.com/j1CNCZY.gifv",
("58e7e6d972058c18b7ecde910ca147e3",),
),
(
"https://i.imgur.com/uTvtQsw.gifv",
("46c86533aa60fc0e09f2a758513e3ac2",),
),
(
"https://i.imgur.com/OGeVuAe.giff",
("77389679084d381336f168538793f218",),
),
(
"https://i.imgur.com/OGeVuAe.gift",
("77389679084d381336f168538793f218",),
),
(
"https://i.imgur.com/3SKrQfK.jpg?1",
("aa299e181b268578979cad176d1bd1d0",),
),
(
"https://i.imgur.com/cbivYRW.jpg?3",
("7ec6ceef5380cb163a1d498c359c51fd",),
),
(
"http://i.imgur.com/s9uXxlq.jpg?5.jpg",
("338de3c23ee21af056b3a7c154e2478f",),
), ),
), ),
( )
'https://imgur.com/a/eemHCCK',
(
'9cb757fd8f055e7ef7aa88addc9d9fa5',
'b6cb6c918e2544e96fb7c07d828774b5',
'fb6c913d721c0bbb96aa65d7f560d385',
),
),
(
'https://i.imgur.com/lFJai6i.gifv',
('01a6e79a30bec0e644e5da12365d5071',),
),
(
'https://i.imgur.com/ywSyILa.gifv?',
('56d4afc32d2966017c38d98568709b45',),
),
(
'https://imgur.com/ubYwpbk.GIFV',
('d4a774aac1667783f9ed3a1bd02fac0c',),
),
(
'https://i.imgur.com/j1CNCZY.gifv',
('58e7e6d972058c18b7ecde910ca147e3',),
),
(
'https://i.imgur.com/uTvtQsw.gifv',
('46c86533aa60fc0e09f2a758513e3ac2',),
),
(
'https://i.imgur.com/OGeVuAe.giff',
('77389679084d381336f168538793f218',),
),
(
'https://i.imgur.com/OGeVuAe.gift',
('77389679084d381336f168538793f218',),
),
(
'https://i.imgur.com/3SKrQfK.jpg?1',
('aa299e181b268578979cad176d1bd1d0',),
),
(
'https://i.imgur.com/cbivYRW.jpg?3',
('7ec6ceef5380cb163a1d498c359c51fd',),
),
(
'http://i.imgur.com/s9uXxlq.jpg?5.jpg',
('338de3c23ee21af056b3a7c154e2478f',),
),
))
def test_find_resources(test_url: str, expected_hashes: list[str]): def test_find_resources(test_url: str, expected_hashes: list[str]):
mock_download = Mock() mock_download = Mock()
mock_download.url = test_url mock_download.url = test_url
@@ -12,9 +12,10 @@ from bdfr.site_downloaders.pornhub import PornHub
@pytest.mark.online @pytest.mark.online
@pytest.mark.slow @pytest.mark.slow
@pytest.mark.parametrize(('test_url', 'expected_hash'), ( @pytest.mark.parametrize(
('https://www.pornhub.com/view_video.php?viewkey=ph6074c59798497', 'ad52a0f4fce8f99df0abed17de1d04c7'), ("test_url", "expected_hash"),
)) (("https://www.pornhub.com/view_video.php?viewkey=ph6074c59798497", "ad52a0f4fce8f99df0abed17de1d04c7"),),
)
def test_hash_resources_good(test_url: str, expected_hash: str): def test_hash_resources_good(test_url: str, expected_hash: str):
test_submission = MagicMock() test_submission = MagicMock()
test_submission.url = test_url test_submission.url = test_url
@@ -27,9 +28,7 @@ def test_hash_resources_good(test_url: str, expected_hash: str):
@pytest.mark.online @pytest.mark.online
@pytest.mark.parametrize('test_url', ( @pytest.mark.parametrize("test_url", ("https://www.pornhub.com/view_video.php?viewkey=ph5ede121f0d3f8",))
'https://www.pornhub.com/view_video.php?viewkey=ph5ede121f0d3f8',
))
def test_find_resources_good(test_url: str): def test_find_resources_good(test_url: str):
test_submission = MagicMock() test_submission = MagicMock()
test_submission.url = test_url test_submission.url = test_url
@@ -1,8 +1,8 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
# coding=utf-8 # coding=utf-8
from unittest.mock import Mock
import re import re
from unittest.mock import Mock
import pytest import pytest
@@ -11,45 +11,55 @@ from bdfr.site_downloaders.redgifs import Redgifs
@pytest.mark.online @pytest.mark.online
@pytest.mark.parametrize(('test_url', 'expected'), ( @pytest.mark.parametrize(
('https://redgifs.com/watch/frighteningvictorioussalamander', ("test_url", "expected"),
{'FrighteningVictoriousSalamander.mp4'}), (
('https://redgifs.com/watch/springgreendecisivetaruca', ("https://redgifs.com/watch/frighteningvictorioussalamander", {"FrighteningVictoriousSalamander.mp4"}),
{'SpringgreenDecisiveTaruca.mp4'}), ("https://redgifs.com/watch/springgreendecisivetaruca", {"SpringgreenDecisiveTaruca.mp4"}),
('https://www.redgifs.com/watch/palegoldenrodrawhalibut', ("https://www.redgifs.com/watch/palegoldenrodrawhalibut", {"PalegoldenrodRawHalibut.mp4"}),
{'PalegoldenrodRawHalibut.mp4'}), ("https://redgifs.com/watch/hollowintentsnowyowl", {"HollowIntentSnowyowl-large.jpg"}),
('https://redgifs.com/watch/hollowintentsnowyowl', (
{'HollowIntentSnowyowl-large.jpg'}), "https://www.redgifs.com/watch/lustrousstickywaxwing",
('https://www.redgifs.com/watch/lustrousstickywaxwing', {
{'EntireEnchantingHypsilophodon-large.jpg', "EntireEnchantingHypsilophodon-large.jpg",
'FancyMagnificentAdamsstaghornedbeetle-large.jpg', "FancyMagnificentAdamsstaghornedbeetle-large.jpg",
'LustrousStickyWaxwing-large.jpg', "LustrousStickyWaxwing-large.jpg",
'ParchedWindyArmyworm-large.jpg', "ParchedWindyArmyworm-large.jpg",
'ThunderousColorlessErmine-large.jpg', "ThunderousColorlessErmine-large.jpg",
'UnripeUnkemptWoodpecker-large.jpg'}), "UnripeUnkemptWoodpecker-large.jpg",
)) },
),
),
)
def test_get_link(test_url: str, expected: set[str]): def test_get_link(test_url: str, expected: set[str]):
result = Redgifs._get_link(test_url) result = Redgifs._get_link(test_url)
result = list(result) result = list(result)
patterns = [r'https://thumbs\d\.redgifs\.com/' + e + r'.*' for e in expected] patterns = [r"https://thumbs\d\.redgifs\.com/" + e + r".*" for e in expected]
assert all([re.match(p, r) for p in patterns] for r in result) assert all([re.match(p, r) for p in patterns] for r in result)
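Note: one thing black cannot fix is that this assertion is vacuous. The outer all() iterates over result but only tests the truthiness of the inner list, which is non-empty whenever patterns is, regardless of whether anything matched; and all() over an empty iterable is True anyway. A self-contained demonstration with a stricter alternative (values are illustrative, taken from the cases above):

import re

result = ["https://thumbs4.redgifs.com/FrighteningVictoriousSalamander.mp4"]
patterns = [r"https://thumbs\d\.redgifs\.com/FrighteningVictoriousSalamander.*"]

# the commit's form: always True, because it checks list truthiness, not matches
assert all([re.match(p, r) for p in patterns] for r in result)

# stricter: there must be links, and each must match at least one pattern
assert result and all(any(re.match(p, r) for p in patterns) for r in result)

The same construction reappears in test_hd_soft_fail below.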
@pytest.mark.online @pytest.mark.online
@pytest.mark.parametrize(('test_url', 'expected_hashes'), ( @pytest.mark.parametrize(
('https://redgifs.com/watch/frighteningvictorioussalamander', {'4007c35d9e1f4b67091b5f12cffda00a'}), ("test_url", "expected_hashes"),
('https://redgifs.com/watch/springgreendecisivetaruca', {'8dac487ac49a1f18cc1b4dabe23f0869'}), (
('https://redgifs.com/watch/leafysaltydungbeetle', {'076792c660b9c024c0471ef4759af8bd'}), ("https://redgifs.com/watch/frighteningvictorioussalamander", {"4007c35d9e1f4b67091b5f12cffda00a"}),
('https://www.redgifs.com/watch/palegoldenrodrawhalibut', {'46d5aa77fe80c6407de1ecc92801c10e'}), ("https://redgifs.com/watch/springgreendecisivetaruca", {"8dac487ac49a1f18cc1b4dabe23f0869"}),
('https://redgifs.com/watch/hollowintentsnowyowl', {'5ee51fa15e0a58e98f11dea6a6cca771'}), ("https://redgifs.com/watch/leafysaltydungbeetle", {"076792c660b9c024c0471ef4759af8bd"}),
('https://www.redgifs.com/watch/lustrousstickywaxwing', ("https://www.redgifs.com/watch/palegoldenrodrawhalibut", {"46d5aa77fe80c6407de1ecc92801c10e"}),
{'b461e55664f07bed8d2f41d8586728fa', ("https://redgifs.com/watch/hollowintentsnowyowl", {"5ee51fa15e0a58e98f11dea6a6cca771"}),
'30ba079a8ed7d7adf17929dc3064c10f', (
'0d4f149d170d29fc2f015c1121bab18b', "https://www.redgifs.com/watch/lustrousstickywaxwing",
'53987d99cfd77fd65b5fdade3718f9f1', {
'fb2e7d972846b83bf4016447d3060d60', "b461e55664f07bed8d2f41d8586728fa",
'44fb28f72ec9a5cca63fa4369ab4f672'}), "30ba079a8ed7d7adf17929dc3064c10f",
)) "0d4f149d170d29fc2f015c1121bab18b",
"53987d99cfd77fd65b5fdade3718f9f1",
"fb2e7d972846b83bf4016447d3060d60",
"44fb28f72ec9a5cca63fa4369ab4f672",
},
),
),
)
def test_download_resource(test_url: str, expected_hashes: set[str]): def test_download_resource(test_url: str, expected_hashes: set[str]):
mock_submission = Mock() mock_submission = Mock()
mock_submission.url = test_url mock_submission.url = test_url
@@ -62,18 +72,30 @@ def test_download_resource(test_url: str, expected_hashes: set[str]):
@pytest.mark.online @pytest.mark.online
@pytest.mark.parametrize(('test_url', 'expected_link', 'expected_hash'), ( @pytest.mark.parametrize(
('https://redgifs.com/watch/flippantmemorablebaiji', {'FlippantMemorableBaiji-mobile.mp4'}, ("test_url", "expected_link", "expected_hash"),
{'41a5fb4865367ede9f65fc78736f497a'}), (
('https://redgifs.com/watch/thirstyunfortunatewaterdragons', {'thirstyunfortunatewaterdragons-mobile.mp4'}, (
{'1a51dad8fedb594bdd84f027b3cbe8af'}), "https://redgifs.com/watch/flippantmemorablebaiji",
('https://redgifs.com/watch/conventionalplainxenopterygii', {'conventionalplainxenopterygii-mobile.mp4'}, {"FlippantMemorableBaiji-mobile.mp4"},
{'2e1786b3337da85b80b050e2c289daa4'}) {"41a5fb4865367ede9f65fc78736f497a"},
)) ),
(
"https://redgifs.com/watch/thirstyunfortunatewaterdragons",
{"thirstyunfortunatewaterdragons-mobile.mp4"},
{"1a51dad8fedb594bdd84f027b3cbe8af"},
),
(
"https://redgifs.com/watch/conventionalplainxenopterygii",
{"conventionalplainxenopterygii-mobile.mp4"},
{"2e1786b3337da85b80b050e2c289daa4"},
),
),
)
def test_hd_soft_fail(test_url: str, expected_link: set[str], expected_hash: set[str]): def test_hd_soft_fail(test_url: str, expected_link: set[str], expected_hash: set[str]):
link = Redgifs._get_link(test_url) link = Redgifs._get_link(test_url)
link = list(link) link = list(link)
patterns = [r'https://thumbs\d\.redgifs\.com/' + e + r'.*' for e in expected_link] patterns = [r"https://thumbs\d\.redgifs\.com/" + e + r".*" for e in expected_link]
assert all([re.match(p, r) for p in patterns] for r in link) assert all([re.match(p, r) for p in patterns] for r in link)
mock_submission = Mock() mock_submission = Mock()
mock_submission.url = test_url mock_submission.url = test_url
@@ -10,11 +10,14 @@ from bdfr.site_downloaders.self_post import SelfPost
@pytest.mark.online @pytest.mark.online
@pytest.mark.reddit @pytest.mark.reddit
@pytest.mark.parametrize(('test_submission_id', 'expected_hash'), ( @pytest.mark.parametrize(
('ltmivt', '7d2c9e4e989e5cf2dca2e55a06b1c4f6'), ("test_submission_id", "expected_hash"),
('ltoaan', '221606386b614d6780c2585a59bd333f'), (
('d3sc8o', 'c1ff2b6bd3f6b91381dcd18dfc4ca35f'), ("ltmivt", "7d2c9e4e989e5cf2dca2e55a06b1c4f6"),
)) ("ltoaan", "221606386b614d6780c2585a59bd333f"),
("d3sc8o", "c1ff2b6bd3f6b91381dcd18dfc4ca35f"),
),
)
def test_find_resource(test_submission_id: str, expected_hash: str, reddit_instance: praw.Reddit): def test_find_resource(test_submission_id: str, expected_hash: str, reddit_instance: praw.Reddit):
submission = reddit_instance.submission(id=test_submission_id) submission = reddit_instance.submission(id=test_submission_id)
downloader = SelfPost(submission) downloader = SelfPost(submission)
@@ -8,55 +8,83 @@ from bdfr.resource import Resource
from bdfr.site_downloaders.vidble import Vidble from bdfr.site_downloaders.vidble import Vidble
@pytest.mark.parametrize(('test_url', 'expected'), ( @pytest.mark.parametrize(("test_url", "expected"), (("/RDFbznUvcN_med.jpg", "/RDFbznUvcN.jpg"),))
('/RDFbznUvcN_med.jpg', '/RDFbznUvcN.jpg'),
))
def test_change_med_url(test_url: str, expected: str): def test_change_med_url(test_url: str, expected: str):
result = Vidble.change_med_url(test_url) result = Vidble.change_med_url(test_url)
assert result == expected assert result == expected
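Note: the diff shows only the test, not the helper. Judging by the single parametrized case, change_med_url strips the "_med" thumbnail suffix; a hypothetical stand-in with the same behaviour for this case (not the repo's implementation):

def change_med_url(url: str) -> str:
    # drop the "_med" size suffix immediately before the extension
    return url.replace("_med.", ".")

assert change_med_url("/RDFbznUvcN_med.jpg") == "/RDFbznUvcN.jpg"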
@pytest.mark.online @pytest.mark.online
@pytest.mark.parametrize(('test_url', 'expected'), ( @pytest.mark.parametrize(
('https://www.vidble.com/show/UxsvAssYe5', { ("test_url", "expected"),
'https://www.vidble.com/UxsvAssYe5.gif', (
}), (
('https://vidble.com/show/RDFbznUvcN', { "https://www.vidble.com/show/UxsvAssYe5",
'https://www.vidble.com/RDFbznUvcN.jpg', {
}), "https://www.vidble.com/UxsvAssYe5.gif",
('https://vidble.com/album/h0jTLs6B', { },
'https://www.vidble.com/XG4eAoJ5JZ.jpg', ),
'https://www.vidble.com/IqF5UdH6Uq.jpg', (
'https://www.vidble.com/VWuNsnLJMD.jpg', "https://vidble.com/show/RDFbznUvcN",
'https://www.vidble.com/sMmM8O650W.jpg', {
}), "https://www.vidble.com/RDFbznUvcN.jpg",
('https://www.vidble.com/pHuwWkOcEb', { },
'https://www.vidble.com/pHuwWkOcEb.jpg', ),
}), (
)) "https://vidble.com/album/h0jTLs6B",
{
"https://www.vidble.com/XG4eAoJ5JZ.jpg",
"https://www.vidble.com/IqF5UdH6Uq.jpg",
"https://www.vidble.com/VWuNsnLJMD.jpg",
"https://www.vidble.com/sMmM8O650W.jpg",
},
),
(
"https://www.vidble.com/pHuwWkOcEb",
{
"https://www.vidble.com/pHuwWkOcEb.jpg",
},
),
),
)
def test_get_links(test_url: str, expected: set[str]): def test_get_links(test_url: str, expected: set[str]):
results = Vidble.get_links(test_url) results = Vidble.get_links(test_url)
assert results == expected assert results == expected
@pytest.mark.online @pytest.mark.online
@pytest.mark.parametrize(('test_url', 'expected_hashes'), ( @pytest.mark.parametrize(
('https://www.vidble.com/show/UxsvAssYe5', { ("test_url", "expected_hashes"),
'0ef2f8e0e0b45936d2fb3e6fbdf67e28', (
}), (
('https://vidble.com/show/RDFbznUvcN', { "https://www.vidble.com/show/UxsvAssYe5",
'c2dd30a71e32369c50eed86f86efff58', {
}), "0ef2f8e0e0b45936d2fb3e6fbdf67e28",
('https://vidble.com/album/h0jTLs6B', { },
'3b3cba02e01c91f9858a95240b942c71', ),
'dd6ecf5fc9e936f9fb614eb6a0537f99', (
'b31a942cd8cdda218ed547bbc04c3a27', "https://vidble.com/show/RDFbznUvcN",
'6f77c570b451eef4222804bd52267481', {
}), "c2dd30a71e32369c50eed86f86efff58",
('https://www.vidble.com/pHuwWkOcEb', { },
'585f486dd0b2f23a57bddbd5bf185bc7', ),
}), (
)) "https://vidble.com/album/h0jTLs6B",
{
"3b3cba02e01c91f9858a95240b942c71",
"dd6ecf5fc9e936f9fb614eb6a0537f99",
"b31a942cd8cdda218ed547bbc04c3a27",
"6f77c570b451eef4222804bd52267481",
},
),
(
"https://www.vidble.com/pHuwWkOcEb",
{
"585f486dd0b2f23a57bddbd5bf185bc7",
},
),
),
)
def test_find_resources(test_url: str, expected_hashes: set[str]): def test_find_resources(test_url: str, expected_hashes: set[str]):
mock_download = Mock() mock_download = Mock()
mock_download.url = test_url mock_download.url = test_url
@@ -12,9 +12,10 @@ from bdfr.site_downloaders.vreddit import VReddit
@pytest.mark.online @pytest.mark.online
@pytest.mark.slow @pytest.mark.slow
@pytest.mark.parametrize(('test_url', 'expected_hash'), ( @pytest.mark.parametrize(
('https://reddit.com/r/Unexpected/comments/z4xsuj/omg_thats_so_cute/', '1ffab5e5c0cc96db18108e4f37e8ca7f'), ("test_url", "expected_hash"),
)) (("https://reddit.com/r/Unexpected/comments/z4xsuj/omg_thats_so_cute/", "1ffab5e5c0cc96db18108e4f37e8ca7f"),),
)
def test_find_resources_good(test_url: str, expected_hash: str): def test_find_resources_good(test_url: str, expected_hash: str):
test_submission = MagicMock() test_submission = MagicMock()
test_submission.url = test_url test_submission.url = test_url
@@ -27,10 +28,13 @@ def test_find_resources_good(test_url: str, expected_hash: str):
@pytest.mark.online @pytest.mark.online
@pytest.mark.parametrize('test_url', ( @pytest.mark.parametrize(
'https://www.polygon.com/disney-plus/2020/5/14/21249881/gargoyles-animated-series-disney-plus-greg-weisman' "test_url",
'-interview-oj-simpson-goliath-chronicles', (
)) "https://www.polygon.com/disney-plus/2020/5/14/21249881/gargoyles-animated-series-disney-plus-greg-weisman"
"-interview-oj-simpson-goliath-chronicles",
),
)
def test_find_resources_bad(test_url: str): def test_find_resources_bad(test_url: str):
test_submission = MagicMock() test_submission = MagicMock()
test_submission.url = test_url test_submission.url = test_url
@@ -12,10 +12,13 @@ from bdfr.site_downloaders.youtube import Youtube
@pytest.mark.online @pytest.mark.online
@pytest.mark.slow @pytest.mark.slow
@pytest.mark.parametrize(('test_url', 'expected_hash'), ( @pytest.mark.parametrize(
('https://www.youtube.com/watch?v=uSm2VDgRIUs', '2d60b54582df5b95ec72bb00b580d2ff'), ("test_url", "expected_hash"),
('https://www.youtube.com/watch?v=GcI7nxQj7HA', '5db0fc92a0a7fb9ac91e63505eea9cf0'), (
)) ("https://www.youtube.com/watch?v=uSm2VDgRIUs", "2d60b54582df5b95ec72bb00b580d2ff"),
("https://www.youtube.com/watch?v=GcI7nxQj7HA", "5db0fc92a0a7fb9ac91e63505eea9cf0"),
),
)
def test_find_resources_good(test_url: str, expected_hash: str): def test_find_resources_good(test_url: str, expected_hash: str):
test_submission = MagicMock() test_submission = MagicMock()
test_submission.url = test_url test_submission.url = test_url
@@ -28,10 +31,13 @@ def test_find_resources_good(test_url: str, expected_hash: str):
@pytest.mark.online @pytest.mark.online
@pytest.mark.parametrize('test_url', ( @pytest.mark.parametrize(
'https://www.polygon.com/disney-plus/2020/5/14/21249881/gargoyles-animated-series-disney-plus-greg-weisman' "test_url",
'-interview-oj-simpson-goliath-chronicles', (
)) "https://www.polygon.com/disney-plus/2020/5/14/21249881/gargoyles-animated-series-disney-plus-greg-weisman"
"-interview-oj-simpson-goliath-chronicles",
),
)
def test_find_resources_bad(test_url: str): def test_find_resources_bad(test_url: str):
test_submission = MagicMock() test_submission = MagicMock()
test_submission.url = test_url test_submission.url = test_url
@@ -12,15 +12,18 @@ from bdfr.archiver import Archiver
@pytest.mark.online @pytest.mark.online
@pytest.mark.reddit @pytest.mark.reddit
@pytest.mark.parametrize(('test_submission_id', 'test_format'), ( @pytest.mark.parametrize(
('m3reby', 'xml'), ("test_submission_id", "test_format"),
('m3reby', 'json'), (
('m3reby', 'yaml'), ("m3reby", "xml"),
)) ("m3reby", "json"),
("m3reby", "yaml"),
),
)
def test_write_submission_json(test_submission_id: str, tmp_path: Path, test_format: str, reddit_instance: praw.Reddit): def test_write_submission_json(test_submission_id: str, tmp_path: Path, test_format: str, reddit_instance: praw.Reddit):
archiver_mock = MagicMock() archiver_mock = MagicMock()
archiver_mock.args.format = test_format archiver_mock.args.format = test_format
test_path = Path(tmp_path, 'test') test_path = Path(tmp_path, "test")
test_submission = reddit_instance.submission(id=test_submission_id) test_submission = reddit_instance.submission(id=test_submission_id)
archiver_mock.file_name_formatter.format_path.return_value = test_path archiver_mock.file_name_formatter.format_path.return_value = test_path
Archiver.write_entry(archiver_mock, test_submission) Archiver.write_entry(archiver_mock, test_submission)
@@ -8,13 +8,16 @@ import pytest
from bdfr.configuration import Configuration from bdfr.configuration import Configuration
@pytest.mark.parametrize('arg_dict', ( @pytest.mark.parametrize(
{'directory': 'test_dir'}, "arg_dict",
{ (
'directory': 'test_dir', {"directory": "test_dir"},
'no_dupes': True, {
}, "directory": "test_dir",
)) "no_dupes": True,
},
),
)
def test_process_click_context(arg_dict: dict): def test_process_click_context(arg_dict: dict):
test_config = Configuration() test_config = Configuration()
test_context = MagicMock() test_context = MagicMock()
@@ -25,9 +28,9 @@ def test_process_click_context(arg_dict: dict):
def test_yaml_file_read(): def test_yaml_file_read():
file = './tests/yaml_test_configuration.yaml' file = "./tests/yaml_test_configuration.yaml"
test_config = Configuration() test_config = Configuration()
test_config.parse_yaml_options(file) test_config.parse_yaml_options(file)
assert test_config.subreddit == ['EarthPorn', 'TwoXChromosomes', 'Mindustry'] assert test_config.subreddit == ["EarthPorn", "TwoXChromosomes", "Mindustry"]
assert test_config.sort == 'new' assert test_config.sort == "new"
assert test_config.limit == 10 assert test_config.limit == 10
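Note: the three assertions fully determine what ./tests/yaml_test_configuration.yaml must contain. A file satisfying them would look like this (reconstructed from the asserts, not copied from the repo):

subreddit:
  - EarthPorn
  - TwoXChromosomes
  - Mindustry
sort: new
limit: 10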
@@ -20,7 +20,7 @@ from bdfr.site_authenticator import SiteAuthenticator
@pytest.fixture() @pytest.fixture()
def args() -> Configuration: def args() -> Configuration:
args = Configuration() args = Configuration()
args.time_format = 'ISO' args.time_format = "ISO"
return args return args
@@ -30,7 +30,8 @@ def downloader_mock(args: Configuration):
downloader_mock.args = args downloader_mock.args = args
downloader_mock.sanitise_subreddit_name = RedditConnector.sanitise_subreddit_name downloader_mock.sanitise_subreddit_name = RedditConnector.sanitise_subreddit_name
downloader_mock.create_filtered_listing_generator = lambda x: RedditConnector.create_filtered_listing_generator( downloader_mock.create_filtered_listing_generator = lambda x: RedditConnector.create_filtered_listing_generator(
downloader_mock, x) downloader_mock, x
)
downloader_mock.split_args_input = RedditConnector.split_args_input downloader_mock.split_args_input = RedditConnector.split_args_input
downloader_mock.master_hash_list = {} downloader_mock.master_hash_list = {}
return downloader_mock return downloader_mock
@@ -55,16 +56,22 @@ def assert_all_results_are_submissions_or_comments(result_limit: int, results: l
def test_determine_directories(tmp_path: Path, downloader_mock: MagicMock): def test_determine_directories(tmp_path: Path, downloader_mock: MagicMock):
downloader_mock.args.directory = tmp_path / 'test' downloader_mock.args.directory = tmp_path / "test"
downloader_mock.config_directories.user_config_dir = tmp_path downloader_mock.config_directories.user_config_dir = tmp_path
RedditConnector.determine_directories(downloader_mock) RedditConnector.determine_directories(downloader_mock)
assert Path(tmp_path / 'test').exists() assert Path(tmp_path / "test").exists()
@pytest.mark.parametrize(('skip_extensions', 'skip_domains'), ( @pytest.mark.parametrize(
([], []), ("skip_extensions", "skip_domains"),
(['.test'], ['test.com'],), (
)) ([], []),
(
[".test"],
["test.com"],
),
),
)
def test_create_download_filter(skip_extensions: list[str], skip_domains: list[str], downloader_mock: MagicMock): def test_create_download_filter(skip_extensions: list[str], skip_domains: list[str], downloader_mock: MagicMock):
downloader_mock.args.skip = skip_extensions downloader_mock.args.skip = skip_extensions
downloader_mock.args.skip_domain = skip_domains downloader_mock.args.skip_domain = skip_domains
@@ -75,14 +82,17 @@ def test_create_download_filter(skip_extensions: list[str], skip_domains: list[s
assert result.excluded_extensions == skip_extensions assert result.excluded_extensions == skip_extensions
@pytest.mark.parametrize(('test_time', 'expected'), ( @pytest.mark.parametrize(
('all', 'all'), ("test_time", "expected"),
('hour', 'hour'), (
('day', 'day'), ("all", "all"),
('week', 'week'), ("hour", "hour"),
('random', 'all'), ("day", "day"),
('', 'all'), ("week", "week"),
)) ("random", "all"),
("", "all"),
),
)
def test_create_time_filter(test_time: str, expected: str, downloader_mock: MagicMock): def test_create_time_filter(test_time: str, expected: str, downloader_mock: MagicMock):
downloader_mock.args.time = test_time downloader_mock.args.time = test_time
result = RedditConnector.create_time_filter(downloader_mock) result = RedditConnector.create_time_filter(downloader_mock)
@@ -91,12 +101,15 @@ def test_create_time_filter(test_time: str, expected: str, downloader_mock: Magi
assert result.name.lower() == expected assert result.name.lower() == expected
@pytest.mark.parametrize(('test_sort', 'expected'), ( @pytest.mark.parametrize(
('', 'hot'), ("test_sort", "expected"),
('hot', 'hot'), (
('controversial', 'controversial'), ("", "hot"),
('new', 'new'), ("hot", "hot"),
)) ("controversial", "controversial"),
("new", "new"),
),
)
def test_create_sort_filter(test_sort: str, expected: str, downloader_mock: MagicMock): def test_create_sort_filter(test_sort: str, expected: str, downloader_mock: MagicMock):
downloader_mock.args.sort = test_sort downloader_mock.args.sort = test_sort
result = RedditConnector.create_sort_filter(downloader_mock) result = RedditConnector.create_sort_filter(downloader_mock)
@@ -105,13 +118,16 @@ def test_create_sort_filter(test_sort: str, expected: str, downloader_mock: Magi
assert result.name.lower() == expected assert result.name.lower() == expected
@pytest.mark.parametrize(('test_file_scheme', 'test_folder_scheme'), ( @pytest.mark.parametrize(
('{POSTID}', '{SUBREDDIT}'), ("test_file_scheme", "test_folder_scheme"),
('{REDDITOR}_{TITLE}_{POSTID}', '{SUBREDDIT}'), (
('{POSTID}', 'test'), ("{POSTID}", "{SUBREDDIT}"),
('{POSTID}', ''), ("{REDDITOR}_{TITLE}_{POSTID}", "{SUBREDDIT}"),
('{POSTID}', '{SUBREDDIT}/{REDDITOR}'), ("{POSTID}", "test"),
)) ("{POSTID}", ""),
("{POSTID}", "{SUBREDDIT}/{REDDITOR}"),
),
)
def test_create_file_name_formatter(test_file_scheme: str, test_folder_scheme: str, downloader_mock: MagicMock): def test_create_file_name_formatter(test_file_scheme: str, test_folder_scheme: str, downloader_mock: MagicMock):
downloader_mock.args.file_scheme = test_file_scheme downloader_mock.args.file_scheme = test_file_scheme
downloader_mock.args.folder_scheme = test_folder_scheme downloader_mock.args.folder_scheme = test_folder_scheme
@@ -119,14 +135,17 @@ def test_create_file_name_formatter(test_file_scheme: str, test_folder_scheme: s
assert isinstance(result, FileNameFormatter) assert isinstance(result, FileNameFormatter)
assert result.file_format_string == test_file_scheme assert result.file_format_string == test_file_scheme
assert result.directory_format_string == test_folder_scheme.split('/') assert result.directory_format_string == test_folder_scheme.split("/")
@pytest.mark.parametrize(('test_file_scheme', 'test_folder_scheme'), ( @pytest.mark.parametrize(
('', ''), ("test_file_scheme", "test_folder_scheme"),
('', '{SUBREDDIT}'), (
('test', '{SUBREDDIT}'), ("", ""),
)) ("", "{SUBREDDIT}"),
("test", "{SUBREDDIT}"),
),
)
def test_create_file_name_formatter_bad(test_file_scheme: str, test_folder_scheme: str, downloader_mock: MagicMock): def test_create_file_name_formatter_bad(test_file_scheme: str, test_folder_scheme: str, downloader_mock: MagicMock):
downloader_mock.args.file_scheme = test_file_scheme downloader_mock.args.file_scheme = test_file_scheme
downloader_mock.args.folder_scheme = test_folder_scheme downloader_mock.args.folder_scheme = test_folder_scheme
@@ -141,15 +160,17 @@ def test_create_authenticator(downloader_mock: MagicMock):
@pytest.mark.online @pytest.mark.online
@pytest.mark.reddit @pytest.mark.reddit
@pytest.mark.parametrize('test_submission_ids', ( @pytest.mark.parametrize(
('lvpf4l',), "test_submission_ids",
('lvpf4l', 'lvqnsn'), (
('lvpf4l', 'lvqnsn', 'lvl9kd'), ("lvpf4l",),
)) ("lvpf4l", "lvqnsn"),
("lvpf4l", "lvqnsn", "lvl9kd"),
),
)
def test_get_submissions_from_link( def test_get_submissions_from_link(
test_submission_ids: list[str], test_submission_ids: list[str], reddit_instance: praw.Reddit, downloader_mock: MagicMock
reddit_instance: praw.Reddit, ):
downloader_mock: MagicMock):
downloader_mock.args.link = test_submission_ids downloader_mock.args.link = test_submission_ids
downloader_mock.reddit_instance = reddit_instance downloader_mock.reddit_instance = reddit_instance
results = RedditConnector.get_submissions_from_link(downloader_mock) results = RedditConnector.get_submissions_from_link(downloader_mock)
@@ -159,25 +180,28 @@ def test_get_submissions_from_link(
@pytest.mark.online @pytest.mark.online
@pytest.mark.reddit @pytest.mark.reddit
@pytest.mark.parametrize(('test_subreddits', 'limit', 'sort_type', 'time_filter', 'max_expected_len'), ( @pytest.mark.parametrize(
(('Futurology',), 10, 'hot', 'all', 10), ("test_subreddits", "limit", "sort_type", "time_filter", "max_expected_len"),
(('Futurology', 'Mindustry, Python'), 10, 'hot', 'all', 30), (
(('Futurology',), 20, 'hot', 'all', 20), (("Futurology",), 10, "hot", "all", 10),
(('Futurology', 'Python'), 10, 'hot', 'all', 20), (("Futurology", "Mindustry, Python"), 10, "hot", "all", 30),
(('Futurology',), 100, 'hot', 'all', 100), (("Futurology",), 20, "hot", "all", 20),
(('Futurology',), 0, 'hot', 'all', 0), (("Futurology", "Python"), 10, "hot", "all", 20),
(('Futurology',), 10, 'top', 'all', 10), (("Futurology",), 100, "hot", "all", 100),
(('Futurology',), 10, 'top', 'week', 10), (("Futurology",), 0, "hot", "all", 0),
(('Futurology',), 10, 'hot', 'week', 10), (("Futurology",), 10, "top", "all", 10),
)) (("Futurology",), 10, "top", "week", 10),
(("Futurology",), 10, "hot", "week", 10),
),
)
def test_get_subreddit_normal( def test_get_subreddit_normal(
test_subreddits: list[str], test_subreddits: list[str],
limit: int, limit: int,
sort_type: str, sort_type: str,
time_filter: str, time_filter: str,
max_expected_len: int, max_expected_len: int,
downloader_mock: MagicMock, downloader_mock: MagicMock,
reddit_instance: praw.Reddit, reddit_instance: praw.Reddit,
): ):
downloader_mock.args.limit = limit downloader_mock.args.limit = limit
downloader_mock.args.sort = sort_type downloader_mock.args.sort = sort_type
@@ -197,26 +221,29 @@ def test_get_subreddit_normal(
@pytest.mark.online @pytest.mark.online
@pytest.mark.reddit @pytest.mark.reddit
@pytest.mark.parametrize(('test_time', 'test_delta'), ( @pytest.mark.parametrize(
('hour', timedelta(hours=1)), ("test_time", "test_delta"),
('day', timedelta(days=1)), (
('week', timedelta(days=7)), ("hour", timedelta(hours=1)),
('month', timedelta(days=31)), ("day", timedelta(days=1)),
('year', timedelta(days=365)), ("week", timedelta(days=7)),
)) ("month", timedelta(days=31)),
("year", timedelta(days=365)),
),
)
def test_get_subreddit_time_verification( def test_get_subreddit_time_verification(
test_time: str, test_time: str,
test_delta: timedelta, test_delta: timedelta,
downloader_mock: MagicMock, downloader_mock: MagicMock,
reddit_instance: praw.Reddit, reddit_instance: praw.Reddit,
): ):
downloader_mock.args.limit = 10 downloader_mock.args.limit = 10
downloader_mock.args.sort = 'top' downloader_mock.args.sort = "top"
downloader_mock.args.time = test_time downloader_mock.args.time = test_time
downloader_mock.time_filter = RedditConnector.create_time_filter(downloader_mock) downloader_mock.time_filter = RedditConnector.create_time_filter(downloader_mock)
downloader_mock.sort_filter = RedditConnector.create_sort_filter(downloader_mock) downloader_mock.sort_filter = RedditConnector.create_sort_filter(downloader_mock)
downloader_mock.determine_sort_function.return_value = RedditConnector.determine_sort_function(downloader_mock) downloader_mock.determine_sort_function.return_value = RedditConnector.determine_sort_function(downloader_mock)
downloader_mock.args.subreddit = ['all'] downloader_mock.args.subreddit = ["all"]
downloader_mock.reddit_instance = reddit_instance downloader_mock.reddit_instance = reddit_instance
results = RedditConnector.get_subreddits(downloader_mock) results = RedditConnector.get_subreddits(downloader_mock)
results = [sub for res1 in results for sub in res1] results = [sub for res1 in results for sub in res1]
@@ -230,20 +257,23 @@ def test_get_subreddit_time_verification(
@pytest.mark.online @pytest.mark.online
@pytest.mark.reddit @pytest.mark.reddit
@pytest.mark.parametrize(('test_subreddits', 'search_term', 'limit', 'time_filter', 'max_expected_len'), ( @pytest.mark.parametrize(
(('Python',), 'scraper', 10, 'all', 10), ("test_subreddits", "search_term", "limit", "time_filter", "max_expected_len"),
(('Python',), '', 10, 'all', 0), (
(('Python',), 'djsdsgewef', 10, 'all', 0), (("Python",), "scraper", 10, "all", 10),
(('Python',), 'scraper', 10, 'year', 10), (("Python",), "", 10, "all", 0),
)) (("Python",), "djsdsgewef", 10, "all", 0),
(("Python",), "scraper", 10, "year", 10),
),
)
def test_get_subreddit_search( def test_get_subreddit_search(
test_subreddits: list[str], test_subreddits: list[str],
search_term: str, search_term: str,
time_filter: str, time_filter: str,
limit: int, limit: int,
max_expected_len: int, max_expected_len: int,
downloader_mock: MagicMock, downloader_mock: MagicMock,
reddit_instance: praw.Reddit, reddit_instance: praw.Reddit,
): ):
downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
downloader_mock.args.limit = limit downloader_mock.args.limit = limit
@@ -265,17 +295,20 @@ def test_get_subreddit_search(
@pytest.mark.online @pytest.mark.online
@pytest.mark.reddit @pytest.mark.reddit
@pytest.mark.parametrize(('test_user', 'test_multireddits', 'limit'), ( @pytest.mark.parametrize(
('helen_darten', ('cuteanimalpics',), 10), ("test_user", "test_multireddits", "limit"),
('korfor', ('chess',), 100), (
)) ("helen_darten", ("cuteanimalpics",), 10),
("korfor", ("chess",), 100),
),
)
# Good sources at https://www.reddit.com/r/multihub/ # Good sources at https://www.reddit.com/r/multihub/
def test_get_multireddits_public( def test_get_multireddits_public(
test_user: str, test_user: str,
test_multireddits: list[str], test_multireddits: list[str],
limit: int, limit: int,
reddit_instance: praw.Reddit, reddit_instance: praw.Reddit,
downloader_mock: MagicMock, downloader_mock: MagicMock,
): ):
downloader_mock.determine_sort_function.return_value = praw.models.Subreddit.hot downloader_mock.determine_sort_function.return_value = praw.models.Subreddit.hot
downloader_mock.sort_filter = RedditTypes.SortType.HOT downloader_mock.sort_filter = RedditTypes.SortType.HOT
@@ -283,11 +316,10 @@ def test_get_multireddits_public(
downloader_mock.args.multireddit = test_multireddits downloader_mock.args.multireddit = test_multireddits
downloader_mock.args.user = [test_user] downloader_mock.args.user = [test_user]
downloader_mock.reddit_instance = reddit_instance downloader_mock.reddit_instance = reddit_instance
downloader_mock.create_filtered_listing_generator.return_value = \ downloader_mock.create_filtered_listing_generator.return_value = RedditConnector.create_filtered_listing_generator(
RedditConnector.create_filtered_listing_generator( downloader_mock,
downloader_mock, reddit_instance.multireddit(redditor=test_user, name=test_multireddits[0]),
reddit_instance.multireddit(redditor=test_user, name=test_multireddits[0]), )
)
results = RedditConnector.get_multireddits(downloader_mock) results = RedditConnector.get_multireddits(downloader_mock)
results = [sub for res in results for sub in res] results = [sub for res in results for sub in res]
assert all([isinstance(res, praw.models.Submission) for res in results]) assert all([isinstance(res, praw.models.Submission) for res in results])
@@ -297,11 +329,14 @@ def test_get_multireddits_public(
@pytest.mark.online @pytest.mark.online
@pytest.mark.reddit @pytest.mark.reddit
@pytest.mark.parametrize(('test_user', 'limit'), ( @pytest.mark.parametrize(
('danigirl3694', 10), ("test_user", "limit"),
('danigirl3694', 50), (
('CapitanHam', None), ("danigirl3694", 10),
)) ("danigirl3694", 50),
("CapitanHam", None),
),
)
def test_get_user_submissions(test_user: str, limit: int, downloader_mock: MagicMock, reddit_instance: praw.Reddit): def test_get_user_submissions(test_user: str, limit: int, downloader_mock: MagicMock, reddit_instance: praw.Reddit):
downloader_mock.args.limit = limit downloader_mock.args.limit = limit
downloader_mock.determine_sort_function.return_value = praw.models.Subreddit.hot downloader_mock.determine_sort_function.return_value = praw.models.Subreddit.hot
@@ -310,11 +345,10 @@ def test_get_user_submissions(test_user: str, limit: int, downloader_mock: Magic
downloader_mock.args.user = [test_user] downloader_mock.args.user = [test_user]
downloader_mock.authenticated = False downloader_mock.authenticated = False
downloader_mock.reddit_instance = reddit_instance downloader_mock.reddit_instance = reddit_instance
downloader_mock.create_filtered_listing_generator.return_value = \ downloader_mock.create_filtered_listing_generator.return_value = RedditConnector.create_filtered_listing_generator(
RedditConnector.create_filtered_listing_generator( downloader_mock,
downloader_mock, reddit_instance.redditor(test_user).submissions,
reddit_instance.redditor(test_user).submissions, )
)
results = RedditConnector.get_user_data(downloader_mock) results = RedditConnector.get_user_data(downloader_mock)
results = assert_all_results_are_submissions(limit, results) results = assert_all_results_are_submissions(limit, results)
assert all([res.author.name == test_user for res in results]) assert all([res.author.name == test_user for res in results])
@@ -324,21 +358,24 @@ def test_get_user_submissions(test_user: str, limit: int, downloader_mock: Magic
@pytest.mark.online @pytest.mark.online
@pytest.mark.reddit @pytest.mark.reddit
@pytest.mark.authenticated @pytest.mark.authenticated
@pytest.mark.parametrize('test_flag', ( @pytest.mark.parametrize(
'upvoted', "test_flag",
'saved', (
)) "upvoted",
"saved",
),
)
def test_get_user_authenticated_lists( def test_get_user_authenticated_lists(
test_flag: str, test_flag: str,
downloader_mock: MagicMock, downloader_mock: MagicMock,
authenticated_reddit_instance: praw.Reddit, authenticated_reddit_instance: praw.Reddit,
): ):
downloader_mock.args.__dict__[test_flag] = True downloader_mock.args.__dict__[test_flag] = True
downloader_mock.reddit_instance = authenticated_reddit_instance downloader_mock.reddit_instance = authenticated_reddit_instance
downloader_mock.args.limit = 10 downloader_mock.args.limit = 10
downloader_mock.determine_sort_function.return_value = praw.models.Subreddit.hot downloader_mock.determine_sort_function.return_value = praw.models.Subreddit.hot
downloader_mock.sort_filter = RedditTypes.SortType.HOT downloader_mock.sort_filter = RedditTypes.SortType.HOT
downloader_mock.args.user = [RedditConnector.resolve_user_name(downloader_mock, 'me')] downloader_mock.args.user = [RedditConnector.resolve_user_name(downloader_mock, "me")]
results = RedditConnector.get_user_data(downloader_mock) results = RedditConnector.get_user_data(downloader_mock)
assert_all_results_are_submissions_or_comments(10, results) assert_all_results_are_submissions_or_comments(10, results)
@@ -359,54 +396,63 @@ def test_get_subscribed_subreddits(downloader_mock: MagicMock, authenticated_red
assert results assert results
@pytest.mark.parametrize(('test_name', 'expected'), ( @pytest.mark.parametrize(
('Mindustry', 'Mindustry'), ("test_name", "expected"),
('Futurology', 'Futurology'), (
('r/Mindustry', 'Mindustry'), ("Mindustry", "Mindustry"),
('TrollXChromosomes', 'TrollXChromosomes'), ("Futurology", "Futurology"),
('r/TrollXChromosomes', 'TrollXChromosomes'), ("r/Mindustry", "Mindustry"),
('https://www.reddit.com/r/TrollXChromosomes/', 'TrollXChromosomes'), ("TrollXChromosomes", "TrollXChromosomes"),
('https://www.reddit.com/r/TrollXChromosomes', 'TrollXChromosomes'), ("r/TrollXChromosomes", "TrollXChromosomes"),
('https://www.reddit.com/r/Futurology/', 'Futurology'), ("https://www.reddit.com/r/TrollXChromosomes/", "TrollXChromosomes"),
('https://www.reddit.com/r/Futurology', 'Futurology'), ("https://www.reddit.com/r/TrollXChromosomes", "TrollXChromosomes"),
)) ("https://www.reddit.com/r/Futurology/", "Futurology"),
("https://www.reddit.com/r/Futurology", "Futurology"),
),
)
def test_sanitise_subreddit_name(test_name: str, expected: str): def test_sanitise_subreddit_name(test_name: str, expected: str):
result = RedditConnector.sanitise_subreddit_name(test_name) result = RedditConnector.sanitise_subreddit_name(test_name)
assert result == expected assert result == expected
@pytest.mark.parametrize(('test_subreddit_entries', 'expected'), ( @pytest.mark.parametrize(
(['test1', 'test2', 'test3'], {'test1', 'test2', 'test3'}), ("test_subreddit_entries", "expected"),
(['test1,test2', 'test3'], {'test1', 'test2', 'test3'}), (
(['test1, test2', 'test3'], {'test1', 'test2', 'test3'}), (["test1", "test2", "test3"], {"test1", "test2", "test3"}),
(['test1; test2', 'test3'], {'test1', 'test2', 'test3'}), (["test1,test2", "test3"], {"test1", "test2", "test3"}),
(['test1, test2', 'test1,test2,test3', 'test4'], {'test1', 'test2', 'test3', 'test4'}), (["test1, test2", "test3"], {"test1", "test2", "test3"}),
([''], {''}), (["test1; test2", "test3"], {"test1", "test2", "test3"}),
(['test'], {'test'}), (["test1, test2", "test1,test2,test3", "test4"], {"test1", "test2", "test3", "test4"}),
)) ([""], {""}),
(["test"], {"test"}),
),
)
def test_split_subreddit_entries(test_subreddit_entries: list[str], expected: set[str]): def test_split_subreddit_entries(test_subreddit_entries: list[str], expected: set[str]):
results = RedditConnector.split_args_input(test_subreddit_entries) results = RedditConnector.split_args_input(test_subreddit_entries)
assert results == expected assert results == expected
def test_read_submission_ids_from_file(downloader_mock: MagicMock, tmp_path: Path): def test_read_submission_ids_from_file(downloader_mock: MagicMock, tmp_path: Path):
test_file = tmp_path / 'test.txt' test_file = tmp_path / "test.txt"
test_file.write_text('aaaaaa\nbbbbbb') test_file.write_text("aaaaaa\nbbbbbb")
results = RedditConnector.read_id_files([str(test_file)]) results = RedditConnector.read_id_files([str(test_file)])
assert results == {'aaaaaa', 'bbbbbb'} assert results == {"aaaaaa", "bbbbbb"}
@pytest.mark.online @pytest.mark.online
@pytest.mark.reddit @pytest.mark.reddit
@pytest.mark.parametrize('test_redditor_name', ( @pytest.mark.parametrize(
'nasa', "test_redditor_name",
'crowdstrike', (
'HannibalGoddamnit', "nasa",
)) "crowdstrike",
"HannibalGoddamnit",
),
)
def test_check_user_existence_good( def test_check_user_existence_good(
test_redditor_name: str, test_redditor_name: str,
reddit_instance: praw.Reddit, reddit_instance: praw.Reddit,
downloader_mock: MagicMock, downloader_mock: MagicMock,
): ):
downloader_mock.reddit_instance = reddit_instance downloader_mock.reddit_instance = reddit_instance
RedditConnector.check_user_existence(downloader_mock, test_redditor_name) RedditConnector.check_user_existence(downloader_mock, test_redditor_name)
@@ -414,42 +460,46 @@ def test_check_user_existence_good(
@pytest.mark.online @pytest.mark.online
@pytest.mark.reddit @pytest.mark.reddit
@pytest.mark.parametrize('test_redditor_name', ( @pytest.mark.parametrize(
'lhnhfkuhwreolo', "test_redditor_name",
'adlkfmnhglojh', (
)) "lhnhfkuhwreolo",
"adlkfmnhglojh",
),
)
def test_check_user_existence_nonexistent( def test_check_user_existence_nonexistent(
test_redditor_name: str, test_redditor_name: str,
reddit_instance: praw.Reddit, reddit_instance: praw.Reddit,
downloader_mock: MagicMock, downloader_mock: MagicMock,
): ):
downloader_mock.reddit_instance = reddit_instance downloader_mock.reddit_instance = reddit_instance
with pytest.raises(BulkDownloaderException, match='Could not find'): with pytest.raises(BulkDownloaderException, match="Could not find"):
RedditConnector.check_user_existence(downloader_mock, test_redditor_name) RedditConnector.check_user_existence(downloader_mock, test_redditor_name)
@pytest.mark.online @pytest.mark.online
@pytest.mark.reddit @pytest.mark.reddit
@pytest.mark.parametrize('test_redditor_name', ( @pytest.mark.parametrize("test_redditor_name", ("Bree-Boo",))
'Bree-Boo',
))
def test_check_user_existence_banned( def test_check_user_existence_banned(
test_redditor_name: str, test_redditor_name: str,
reddit_instance: praw.Reddit, reddit_instance: praw.Reddit,
downloader_mock: MagicMock, downloader_mock: MagicMock,
): ):
downloader_mock.reddit_instance = reddit_instance downloader_mock.reddit_instance = reddit_instance
with pytest.raises(BulkDownloaderException, match='is banned'): with pytest.raises(BulkDownloaderException, match="is banned"):
RedditConnector.check_user_existence(downloader_mock, test_redditor_name) RedditConnector.check_user_existence(downloader_mock, test_redditor_name)
@pytest.mark.online @pytest.mark.online
@pytest.mark.reddit @pytest.mark.reddit
@pytest.mark.parametrize(('test_subreddit_name', 'expected_message'), ( @pytest.mark.parametrize(
('donaldtrump', 'cannot be found'), ("test_subreddit_name", "expected_message"),
('submitters', 'private and cannot be scraped'), (
('lhnhfkuhwreolo', 'does not exist') ("donaldtrump", "cannot be found"),
)) ("submitters", "private and cannot be scraped"),
("lhnhfkuhwreolo", "does not exist"),
),
)
def test_check_subreddit_status_bad(test_subreddit_name: str, expected_message: str, reddit_instance: praw.Reddit): def test_check_subreddit_status_bad(test_subreddit_name: str, expected_message: str, reddit_instance: praw.Reddit):
test_subreddit = reddit_instance.subreddit(test_subreddit_name) test_subreddit = reddit_instance.subreddit(test_subreddit_name)
with pytest.raises(BulkDownloaderException, match=expected_message): with pytest.raises(BulkDownloaderException, match=expected_message):
@@ -458,12 +508,15 @@ def test_check_subreddit_status_bad(test_subreddit_name: str, expected_message:
@pytest.mark.online @pytest.mark.online
@pytest.mark.reddit @pytest.mark.reddit
@pytest.mark.parametrize('test_subreddit_name', ( @pytest.mark.parametrize(
'Python', "test_subreddit_name",
'Mindustry', (
'TrollXChromosomes', "Python",
'all', "Mindustry",
)) "TrollXChromosomes",
"all",
),
)
def test_check_subreddit_status_good(test_subreddit_name: str, reddit_instance: praw.Reddit): def test_check_subreddit_status_good(test_subreddit_name: str, reddit_instance: praw.Reddit):
test_subreddit = reddit_instance.subreddit(test_subreddit_name) test_subreddit = reddit_instance.subreddit(test_subreddit_name)
RedditConnector.check_subreddit_status(test_subreddit) RedditConnector.check_subreddit_status(test_subreddit)
View file
@@ -11,55 +11,67 @@ from bdfr.resource import Resource
@pytest.fixture() @pytest.fixture()
def download_filter() -> DownloadFilter: def download_filter() -> DownloadFilter:
return DownloadFilter(['mp4', 'mp3'], ['test.com', 'reddit.com', 'img.example.com']) return DownloadFilter(["mp4", "mp3"], ["test.com", "reddit.com", "img.example.com"])
@pytest.mark.parametrize(('test_extension', 'expected'), ( @pytest.mark.parametrize(
('.mp4', False), ("test_extension", "expected"),
('.avi', True), (
('.random.mp3', False), (".mp4", False),
('mp4', False), (".avi", True),
)) (".random.mp3", False),
("mp4", False),
),
)
def test_filter_extension(test_extension: str, expected: bool, download_filter: DownloadFilter): def test_filter_extension(test_extension: str, expected: bool, download_filter: DownloadFilter):
result = download_filter._check_extension(test_extension) result = download_filter._check_extension(test_extension)
assert result == expected assert result == expected
@pytest.mark.parametrize(('test_url', 'expected'), ( @pytest.mark.parametrize(
('test.mp4', True), ("test_url", "expected"),
('http://reddit.com/test.mp4', False), (
('http://reddit.com/test.gif', False), ("test.mp4", True),
('https://www.example.com/test.mp4', True), ("http://reddit.com/test.mp4", False),
('https://www.example.com/test.png', True), ("http://reddit.com/test.gif", False),
('https://i.example.com/test.png', True), ("https://www.example.com/test.mp4", True),
('https://img.example.com/test.png', False), ("https://www.example.com/test.png", True),
('https://i.test.com/test.png', False), ("https://i.example.com/test.png", True),
)) ("https://img.example.com/test.png", False),
("https://i.test.com/test.png", False),
),
)
def test_filter_domain(test_url: str, expected: bool, download_filter: DownloadFilter): def test_filter_domain(test_url: str, expected: bool, download_filter: DownloadFilter):
result = download_filter._check_domain(test_url) result = download_filter._check_domain(test_url)
assert result == expected assert result == expected
@pytest.mark.parametrize(('test_url', 'expected'), ( @pytest.mark.parametrize(
('test.mp4', False), ("test_url", "expected"),
('test.gif', True), (
('https://www.example.com/test.mp4', False), ("test.mp4", False),
('https://www.example.com/test.png', True), ("test.gif", True),
('http://reddit.com/test.mp4', False), ("https://www.example.com/test.mp4", False),
('http://reddit.com/test.gif', False), ("https://www.example.com/test.png", True),
)) ("http://reddit.com/test.mp4", False),
("http://reddit.com/test.gif", False),
),
)
def test_filter_all(test_url: str, expected: bool, download_filter: DownloadFilter): def test_filter_all(test_url: str, expected: bool, download_filter: DownloadFilter):
test_resource = Resource(MagicMock(), test_url, lambda: None) test_resource = Resource(MagicMock(), test_url, lambda: None)
result = download_filter.check_resource(test_resource) result = download_filter.check_resource(test_resource)
assert result == expected assert result == expected
@pytest.mark.parametrize('test_url', ( @pytest.mark.parametrize(
'test.mp3', "test_url",
'test.mp4', (
'http://reddit.com/test.mp4', "test.mp3",
't', "test.mp4",
)) "http://reddit.com/test.mp4",
"t",
),
)
def test_filter_empty_filter(test_url: str): def test_filter_empty_filter(test_url: str):
download_filter = DownloadFilter() download_filter = DownloadFilter()
test_resource = Resource(MagicMock(), test_url, lambda: None) test_resource = Resource(MagicMock(), test_url, lambda: None)
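
Taken together, these cases pin down the filter's contract: extensions and domains passed at construction are excluded, and an empty filter lets everything through. A short usage sketch built from the same fixtures (import paths assumed from this test module; URLs are illustrative):

    from unittest.mock import MagicMock

    from bdfr.download_filter import DownloadFilter
    from bdfr.resource import Resource

    download_filter = DownloadFilter(["mp4", "mp3"], ["test.com"])
    allowed = Resource(MagicMock(), "https://example.com/clip.png", lambda: None)
    blocked = Resource(MagicMock(), "https://test.com/clip.png", lambda: None)
    assert download_filter.check_resource(allowed)
    assert not download_filter.check_resource(blocked)
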
View file
@@ -18,7 +18,7 @@ from bdfr.downloader import RedditDownloader
@pytest.fixture() @pytest.fixture()
def args() -> Configuration: def args() -> Configuration:
args = Configuration() args = Configuration()
args.time_format = 'ISO' args.time_format = "ISO"
return args return args
@@ -32,29 +32,32 @@ def downloader_mock(args: Configuration):
return downloader_mock return downloader_mock
@pytest.mark.parametrize(('test_ids', 'test_excluded', 'expected_len'), ( @pytest.mark.parametrize(
(('aaaaaa',), (), 1), ("test_ids", "test_excluded", "expected_len"),
(('aaaaaa',), ('aaaaaa',), 0), (
((), ('aaaaaa',), 0), (("aaaaaa",), (), 1),
(('aaaaaa', 'bbbbbb'), ('aaaaaa',), 1), (("aaaaaa",), ("aaaaaa",), 0),
(('aaaaaa', 'bbbbbb', 'cccccc'), ('aaaaaa',), 2), ((), ("aaaaaa",), 0),
)) (("aaaaaa", "bbbbbb"), ("aaaaaa",), 1),
@patch('bdfr.site_downloaders.download_factory.DownloadFactory.pull_lever') (("aaaaaa", "bbbbbb", "cccccc"), ("aaaaaa",), 2),
),
)
@patch("bdfr.site_downloaders.download_factory.DownloadFactory.pull_lever")
def test_excluded_ids( def test_excluded_ids(
mock_function: MagicMock, mock_function: MagicMock,
test_ids: tuple[str], test_ids: tuple[str],
test_excluded: tuple[str], test_excluded: tuple[str],
expected_len: int, expected_len: int,
downloader_mock: MagicMock, downloader_mock: MagicMock,
): ):
downloader_mock.excluded_submission_ids = test_excluded downloader_mock.excluded_submission_ids = test_excluded
mock_function.return_value = MagicMock() mock_function.return_value = MagicMock()
mock_function.return_value.__name__ = 'test' mock_function.return_value.__name__ = "test"
test_submissions = [] test_submissions = []
for test_id in test_ids: for test_id in test_ids:
m = MagicMock() m = MagicMock()
m.id = test_id m.id = test_id
m.subreddit.display_name.return_value = 'https://www.example.com/' m.subreddit.display_name.return_value = "https://www.example.com/"
m.__class__ = praw.models.Submission m.__class__ = praw.models.Submission
test_submissions.append(m) test_submissions.append(m)
downloader_mock.reddit_lists = [test_submissions] downloader_mock.reddit_lists = [test_submissions]
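
The expected lengths follow from plain set exclusion: any submission whose id appears in excluded_submission_ids is skipped before download. Schematically (a sketch of the rule, not the project's code):

    test_ids = ("aaaaaa", "bbbbbb", "cccccc")
    excluded = ("aaaaaa",)
    kept = [i for i in test_ids if i not in excluded]
    assert len(kept) == 2
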
@@ -65,32 +68,27 @@ def test_excluded_ids(
@pytest.mark.online @pytest.mark.online
@pytest.mark.reddit @pytest.mark.reddit
@pytest.mark.parametrize('test_submission_id', ( @pytest.mark.parametrize("test_submission_id", ("m1hqw6",))
'm1hqw6',
))
def test_mark_hard_link( def test_mark_hard_link(
test_submission_id: str, test_submission_id: str, downloader_mock: MagicMock, tmp_path: Path, reddit_instance: praw.Reddit
downloader_mock: MagicMock,
tmp_path: Path,
reddit_instance: praw.Reddit
): ):
downloader_mock.reddit_instance = reddit_instance downloader_mock.reddit_instance = reddit_instance
downloader_mock.args.make_hard_links = True downloader_mock.args.make_hard_links = True
downloader_mock.download_directory = tmp_path downloader_mock.download_directory = tmp_path
downloader_mock.args.folder_scheme = '' downloader_mock.args.folder_scheme = ""
downloader_mock.args.file_scheme = '{POSTID}' downloader_mock.args.file_scheme = "{POSTID}"
downloader_mock.file_name_formatter = RedditConnector.create_file_name_formatter(downloader_mock) downloader_mock.file_name_formatter = RedditConnector.create_file_name_formatter(downloader_mock)
submission = downloader_mock.reddit_instance.submission(id=test_submission_id) submission = downloader_mock.reddit_instance.submission(id=test_submission_id)
original = Path(tmp_path, f'{test_submission_id}.png') original = Path(tmp_path, f"{test_submission_id}.png")
RedditDownloader._download_submission(downloader_mock, submission) RedditDownloader._download_submission(downloader_mock, submission)
assert original.exists() assert original.exists()
downloader_mock.args.file_scheme = 'test2_{POSTID}' downloader_mock.args.file_scheme = "test2_{POSTID}"
downloader_mock.file_name_formatter = RedditConnector.create_file_name_formatter(downloader_mock) downloader_mock.file_name_formatter = RedditConnector.create_file_name_formatter(downloader_mock)
RedditDownloader._download_submission(downloader_mock, submission) RedditDownloader._download_submission(downloader_mock, submission)
test_file_1_stats = original.stat() test_file_1_stats = original.stat()
test_file_2_inode = Path(tmp_path, f'test2_{test_submission_id}.png').stat().st_ino test_file_2_inode = Path(tmp_path, f"test2_{test_submission_id}.png").stat().st_ino
assert test_file_1_stats.st_nlink == 2 assert test_file_1_stats.st_nlink == 2
assert test_file_1_stats.st_ino == test_file_2_inode assert test_file_1_stats.st_ino == test_file_2_inode
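
The final assertions lean on filesystem metadata: a hard link is a second name for the same inode, so st_nlink reports two links and both paths share one st_ino. A self-contained check of the same property:

    import os
    from pathlib import Path
    from tempfile import TemporaryDirectory

    with TemporaryDirectory() as tmp:
        original = Path(tmp, "a.png")
        original.touch()
        os.link(original, Path(tmp, "b.png"))  # second name, same inode
        assert original.stat().st_nlink == 2
        assert original.stat().st_ino == Path(tmp, "b.png").stat().st_ino
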
@@ -98,20 +96,18 @@ def test_mark_hard_link(
@pytest.mark.online @pytest.mark.online
@pytest.mark.reddit @pytest.mark.reddit
@pytest.mark.parametrize(('test_submission_id', 'test_creation_date'), ( @pytest.mark.parametrize(("test_submission_id", "test_creation_date"), (("ndzz50", 1621204841.0),))
('ndzz50', 1621204841.0),
))
def test_file_creation_date( def test_file_creation_date(
test_submission_id: str, test_submission_id: str,
test_creation_date: float, test_creation_date: float,
downloader_mock: MagicMock, downloader_mock: MagicMock,
tmp_path: Path, tmp_path: Path,
reddit_instance: praw.Reddit reddit_instance: praw.Reddit,
): ):
downloader_mock.reddit_instance = reddit_instance downloader_mock.reddit_instance = reddit_instance
downloader_mock.download_directory = tmp_path downloader_mock.download_directory = tmp_path
downloader_mock.args.folder_scheme = '' downloader_mock.args.folder_scheme = ""
downloader_mock.args.file_scheme = '{POSTID}' downloader_mock.args.file_scheme = "{POSTID}"
downloader_mock.file_name_formatter = RedditConnector.create_file_name_formatter(downloader_mock) downloader_mock.file_name_formatter = RedditConnector.create_file_name_formatter(downloader_mock)
submission = downloader_mock.reddit_instance.submission(id=test_submission_id) submission = downloader_mock.reddit_instance.submission(id=test_submission_id)
@@ -123,27 +119,25 @@ def test_file_creation_date(
def test_search_existing_files(): def test_search_existing_files():
results = RedditDownloader.scan_existing_files(Path('.')) results = RedditDownloader.scan_existing_files(Path("."))
assert len(results.keys()) != 0 assert len(results.keys()) != 0
@pytest.mark.online @pytest.mark.online
@pytest.mark.reddit @pytest.mark.reddit
@pytest.mark.parametrize(('test_submission_id', 'test_hash'), ( @pytest.mark.parametrize(("test_submission_id", "test_hash"), (("m1hqw6", "a912af8905ae468e0121e9940f797ad7"),))
('m1hqw6', 'a912af8905ae468e0121e9940f797ad7'),
))
def test_download_submission_hash_exists( def test_download_submission_hash_exists(
test_submission_id: str, test_submission_id: str,
test_hash: str, test_hash: str,
downloader_mock: MagicMock, downloader_mock: MagicMock,
reddit_instance: praw.Reddit, reddit_instance: praw.Reddit,
tmp_path: Path, tmp_path: Path,
capsys: pytest.CaptureFixture capsys: pytest.CaptureFixture,
): ):
setup_logging(3) setup_logging(3)
downloader_mock.reddit_instance = reddit_instance downloader_mock.reddit_instance = reddit_instance
downloader_mock.download_filter.check_url.return_value = True downloader_mock.download_filter.check_url.return_value = True
downloader_mock.args.folder_scheme = '' downloader_mock.args.folder_scheme = ""
downloader_mock.args.no_dupes = True downloader_mock.args.no_dupes = True
downloader_mock.file_name_formatter = RedditConnector.create_file_name_formatter(downloader_mock) downloader_mock.file_name_formatter = RedditConnector.create_file_name_formatter(downloader_mock)
downloader_mock.download_directory = tmp_path downloader_mock.download_directory = tmp_path
@@ -153,47 +147,44 @@ def test_download_submission_hash_exists(
folder_contents = list(tmp_path.iterdir()) folder_contents = list(tmp_path.iterdir())
output = capsys.readouterr() output = capsys.readouterr()
assert not folder_contents assert not folder_contents
assert re.search(r'Resource hash .*? downloaded elsewhere', output.out) assert re.search(r"Resource hash .*? downloaded elsewhere", output.out)
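
With no_dupes enabled, a resource is skipped whenever its content hash has already been written somewhere, which is why the directory stays empty here. One way to express that bookkeeping (a hypothetical helper, not the project's implementation):

    import hashlib

    seen_hashes: set[str] = set()

    def is_duplicate(content: bytes) -> bool:
        digest = hashlib.md5(content).hexdigest()
        if digest in seen_hashes:
            return True
        seen_hashes.add(digest)
        return False

    assert not is_duplicate(b"first")
    assert is_duplicate(b"first")
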
@pytest.mark.online @pytest.mark.online
@pytest.mark.reddit @pytest.mark.reddit
def test_download_submission_file_exists( def test_download_submission_file_exists(
downloader_mock: MagicMock, downloader_mock: MagicMock, reddit_instance: praw.Reddit, tmp_path: Path, capsys: pytest.CaptureFixture
reddit_instance: praw.Reddit,
tmp_path: Path,
capsys: pytest.CaptureFixture
): ):
setup_logging(3) setup_logging(3)
downloader_mock.reddit_instance = reddit_instance downloader_mock.reddit_instance = reddit_instance
downloader_mock.download_filter.check_url.return_value = True downloader_mock.download_filter.check_url.return_value = True
downloader_mock.args.folder_scheme = '' downloader_mock.args.folder_scheme = ""
downloader_mock.file_name_formatter = RedditConnector.create_file_name_formatter(downloader_mock) downloader_mock.file_name_formatter = RedditConnector.create_file_name_formatter(downloader_mock)
downloader_mock.download_directory = tmp_path downloader_mock.download_directory = tmp_path
submission = downloader_mock.reddit_instance.submission(id='m1hqw6') submission = downloader_mock.reddit_instance.submission(id="m1hqw6")
Path(tmp_path, 'Arneeman_Metagaming isn\'t always a bad thing_m1hqw6.png').touch() Path(tmp_path, "Arneeman_Metagaming isn't always a bad thing_m1hqw6.png").touch()
RedditDownloader._download_submission(downloader_mock, submission) RedditDownloader._download_submission(downloader_mock, submission)
folder_contents = list(tmp_path.iterdir()) folder_contents = list(tmp_path.iterdir())
output = capsys.readouterr() output = capsys.readouterr()
assert len(folder_contents) == 1 assert len(folder_contents) == 1
assert 'Arneeman_Metagaming isn\'t always a bad thing_m1hqw6.png'\ assert (
' from submission m1hqw6 already exists' in output.out "Arneeman_Metagaming isn't always a bad thing_m1hqw6.png" " from submission m1hqw6 already exists" in output.out
)
@pytest.mark.online @pytest.mark.online
@pytest.mark.reddit @pytest.mark.reddit
@pytest.mark.parametrize(('test_submission_id', 'expected_files_len'), ( @pytest.mark.parametrize(("test_submission_id", "expected_files_len"), (("ljyy27", 4),))
('ljyy27', 4),
))
def test_download_submission( def test_download_submission(
test_submission_id: str, test_submission_id: str,
expected_files_len: int, expected_files_len: int,
downloader_mock: MagicMock, downloader_mock: MagicMock,
reddit_instance: praw.Reddit, reddit_instance: praw.Reddit,
tmp_path: Path): tmp_path: Path,
):
downloader_mock.reddit_instance = reddit_instance downloader_mock.reddit_instance = reddit_instance
downloader_mock.download_filter.check_url.return_value = True downloader_mock.download_filter.check_url.return_value = True
downloader_mock.args.folder_scheme = '' downloader_mock.args.folder_scheme = ""
downloader_mock.file_name_formatter = RedditConnector.create_file_name_formatter(downloader_mock) downloader_mock.file_name_formatter = RedditConnector.create_file_name_formatter(downloader_mock)
downloader_mock.download_directory = tmp_path downloader_mock.download_directory = tmp_path
submission = downloader_mock.reddit_instance.submission(id=test_submission_id) submission = downloader_mock.reddit_instance.submission(id=test_submission_id)
@@ -204,103 +195,95 @@ def test_download_submission(
@pytest.mark.online @pytest.mark.online
@pytest.mark.reddit @pytest.mark.reddit
@pytest.mark.parametrize(('test_submission_id', 'min_score'), ( @pytest.mark.parametrize(("test_submission_id", "min_score"), (("ljyy27", 1),))
('ljyy27', 1),
))
def test_download_submission_min_score_above( def test_download_submission_min_score_above(
test_submission_id: str, test_submission_id: str,
min_score: int, min_score: int,
downloader_mock: MagicMock, downloader_mock: MagicMock,
reddit_instance: praw.Reddit, reddit_instance: praw.Reddit,
tmp_path: Path, tmp_path: Path,
capsys: pytest.CaptureFixture, capsys: pytest.CaptureFixture,
): ):
setup_logging(3) setup_logging(3)
downloader_mock.reddit_instance = reddit_instance downloader_mock.reddit_instance = reddit_instance
downloader_mock.download_filter.check_url.return_value = True downloader_mock.download_filter.check_url.return_value = True
downloader_mock.args.folder_scheme = '' downloader_mock.args.folder_scheme = ""
downloader_mock.args.min_score = min_score downloader_mock.args.min_score = min_score
downloader_mock.file_name_formatter = RedditConnector.create_file_name_formatter(downloader_mock) downloader_mock.file_name_formatter = RedditConnector.create_file_name_formatter(downloader_mock)
downloader_mock.download_directory = tmp_path downloader_mock.download_directory = tmp_path
submission = downloader_mock.reddit_instance.submission(id=test_submission_id) submission = downloader_mock.reddit_instance.submission(id=test_submission_id)
RedditDownloader._download_submission(downloader_mock, submission) RedditDownloader._download_submission(downloader_mock, submission)
output = capsys.readouterr() output = capsys.readouterr()
assert 'filtered due to score' not in output.out assert "filtered due to score" not in output.out
@pytest.mark.online @pytest.mark.online
@pytest.mark.reddit @pytest.mark.reddit
@pytest.mark.parametrize(('test_submission_id', 'min_score'), ( @pytest.mark.parametrize(("test_submission_id", "min_score"), (("ljyy27", 25),))
('ljyy27', 25),
))
def test_download_submission_min_score_below( def test_download_submission_min_score_below(
test_submission_id: str, test_submission_id: str,
min_score: int, min_score: int,
downloader_mock: MagicMock, downloader_mock: MagicMock,
reddit_instance: praw.Reddit, reddit_instance: praw.Reddit,
tmp_path: Path, tmp_path: Path,
capsys: pytest.CaptureFixture, capsys: pytest.CaptureFixture,
): ):
setup_logging(3) setup_logging(3)
downloader_mock.reddit_instance = reddit_instance downloader_mock.reddit_instance = reddit_instance
downloader_mock.download_filter.check_url.return_value = True downloader_mock.download_filter.check_url.return_value = True
downloader_mock.args.folder_scheme = '' downloader_mock.args.folder_scheme = ""
downloader_mock.args.min_score = min_score downloader_mock.args.min_score = min_score
downloader_mock.file_name_formatter = RedditConnector.create_file_name_formatter(downloader_mock) downloader_mock.file_name_formatter = RedditConnector.create_file_name_formatter(downloader_mock)
downloader_mock.download_directory = tmp_path downloader_mock.download_directory = tmp_path
submission = downloader_mock.reddit_instance.submission(id=test_submission_id) submission = downloader_mock.reddit_instance.submission(id=test_submission_id)
RedditDownloader._download_submission(downloader_mock, submission) RedditDownloader._download_submission(downloader_mock, submission)
output = capsys.readouterr() output = capsys.readouterr()
assert 'filtered due to score' in output.out assert "filtered due to score" in output.out
@pytest.mark.online @pytest.mark.online
@pytest.mark.reddit @pytest.mark.reddit
@pytest.mark.parametrize(('test_submission_id', 'max_score'), ( @pytest.mark.parametrize(("test_submission_id", "max_score"), (("ljyy27", 25),))
('ljyy27', 25),
))
def test_download_submission_max_score_below( def test_download_submission_max_score_below(
test_submission_id: str, test_submission_id: str,
max_score: int, max_score: int,
downloader_mock: MagicMock, downloader_mock: MagicMock,
reddit_instance: praw.Reddit, reddit_instance: praw.Reddit,
tmp_path: Path, tmp_path: Path,
capsys: pytest.CaptureFixture, capsys: pytest.CaptureFixture,
): ):
setup_logging(3) setup_logging(3)
downloader_mock.reddit_instance = reddit_instance downloader_mock.reddit_instance = reddit_instance
downloader_mock.download_filter.check_url.return_value = True downloader_mock.download_filter.check_url.return_value = True
downloader_mock.args.folder_scheme = '' downloader_mock.args.folder_scheme = ""
downloader_mock.args.max_score = max_score downloader_mock.args.max_score = max_score
downloader_mock.file_name_formatter = RedditConnector.create_file_name_formatter(downloader_mock) downloader_mock.file_name_formatter = RedditConnector.create_file_name_formatter(downloader_mock)
downloader_mock.download_directory = tmp_path downloader_mock.download_directory = tmp_path
submission = downloader_mock.reddit_instance.submission(id=test_submission_id) submission = downloader_mock.reddit_instance.submission(id=test_submission_id)
RedditDownloader._download_submission(downloader_mock, submission) RedditDownloader._download_submission(downloader_mock, submission)
output = capsys.readouterr() output = capsys.readouterr()
assert 'filtered due to score' not in output.out assert "filtered due to score" not in output.out
@pytest.mark.online @pytest.mark.online
@pytest.mark.reddit @pytest.mark.reddit
@pytest.mark.parametrize(('test_submission_id', 'max_score'), ( @pytest.mark.parametrize(("test_submission_id", "max_score"), (("ljyy27", 1),))
('ljyy27', 1),
))
def test_download_submission_max_score_above( def test_download_submission_max_score_above(
test_submission_id: str, test_submission_id: str,
max_score: int, max_score: int,
downloader_mock: MagicMock, downloader_mock: MagicMock,
reddit_instance: praw.Reddit, reddit_instance: praw.Reddit,
tmp_path: Path, tmp_path: Path,
capsys: pytest.CaptureFixture, capsys: pytest.CaptureFixture,
): ):
setup_logging(3) setup_logging(3)
downloader_mock.reddit_instance = reddit_instance downloader_mock.reddit_instance = reddit_instance
downloader_mock.download_filter.check_url.return_value = True downloader_mock.download_filter.check_url.return_value = True
downloader_mock.args.folder_scheme = '' downloader_mock.args.folder_scheme = ""
downloader_mock.args.max_score = max_score downloader_mock.args.max_score = max_score
downloader_mock.file_name_formatter = RedditConnector.create_file_name_formatter(downloader_mock) downloader_mock.file_name_formatter = RedditConnector.create_file_name_formatter(downloader_mock)
downloader_mock.download_directory = tmp_path downloader_mock.download_directory = tmp_path
submission = downloader_mock.reddit_instance.submission(id=test_submission_id) submission = downloader_mock.reddit_instance.submission(id=test_submission_id)
RedditDownloader._download_submission(downloader_mock, submission) RedditDownloader._download_submission(downloader_mock, submission)
output = capsys.readouterr() output = capsys.readouterr()
assert 'filtered due to score' in output.out assert "filtered due to score" in output.out
View file
@@ -22,26 +22,26 @@ from bdfr.site_downloaders.self_post import SelfPost
@pytest.fixture() @pytest.fixture()
def submission() -> MagicMock: def submission() -> MagicMock:
test = MagicMock() test = MagicMock()
test.title = 'name' test.title = "name"
test.subreddit.display_name = 'randomreddit' test.subreddit.display_name = "randomreddit"
test.author.name = 'person' test.author.name = "person"
test.id = '12345' test.id = "12345"
test.score = 1000 test.score = 1000
test.link_flair_text = 'test_flair' test.link_flair_text = "test_flair"
test.created_utc = datetime(2021, 4, 21, 9, 30, 0).timestamp() test.created_utc = datetime(2021, 4, 21, 9, 30, 0).timestamp()
test.__class__ = praw.models.Submission test.__class__ = praw.models.Submission
return test return test
def do_test_string_equality(result: Union[Path, str], expected: str) -> bool: def do_test_string_equality(result: Union[Path, str], expected: str) -> bool:
if platform.system() == 'Windows': if platform.system() == "Windows":
expected = FileNameFormatter._format_for_windows(expected) expected = FileNameFormatter._format_for_windows(expected)
return str(result).endswith(expected) return str(result).endswith(expected)
def do_test_path_equality(result: Path, expected: str) -> bool: def do_test_path_equality(result: Path, expected: str) -> bool:
if platform.system() == 'Windows': if platform.system() == "Windows":
expected = expected.split('/') expected = expected.split("/")
expected = [FileNameFormatter._format_for_windows(part) for part in expected] expected = [FileNameFormatter._format_for_windows(part) for part in expected]
expected = Path(*expected) expected = Path(*expected)
else: else:
@@ -49,35 +49,41 @@ def do_test_path_equality(result: Path, expected: str) -> bool:
return str(result).endswith(str(expected)) return str(result).endswith(str(expected))
@pytest.fixture(scope='session') @pytest.fixture(scope="session")
def reddit_submission(reddit_instance: praw.Reddit) -> praw.models.Submission: def reddit_submission(reddit_instance: praw.Reddit) -> praw.models.Submission:
return reddit_instance.submission(id='w22m5l') return reddit_instance.submission(id="w22m5l")
@pytest.mark.parametrize(('test_format_string', 'expected'), ( @pytest.mark.parametrize(
('{SUBREDDIT}', 'randomreddit'), ("test_format_string", "expected"),
('{REDDITOR}', 'person'), (
('{POSTID}', '12345'), ("{SUBREDDIT}", "randomreddit"),
('{UPVOTES}', '1000'), ("{REDDITOR}", "person"),
('{FLAIR}', 'test_flair'), ("{POSTID}", "12345"),
('{DATE}', '2021-04-21T09:30:00'), ("{UPVOTES}", "1000"),
('{REDDITOR}_{TITLE}_{POSTID}', 'person_name_12345'), ("{FLAIR}", "test_flair"),
)) ("{DATE}", "2021-04-21T09:30:00"),
("{REDDITOR}_{TITLE}_{POSTID}", "person_name_12345"),
),
)
def test_format_name_mock(test_format_string: str, expected: str, submission: MagicMock): def test_format_name_mock(test_format_string: str, expected: str, submission: MagicMock):
test_formatter = FileNameFormatter(test_format_string, '', 'ISO') test_formatter = FileNameFormatter(test_format_string, "", "ISO")
result = test_formatter._format_name(submission, test_format_string) result = test_formatter._format_name(submission, test_format_string)
assert do_test_string_equality(result, expected) assert do_test_string_equality(result, expected)
@pytest.mark.parametrize(('test_string', 'expected'), ( @pytest.mark.parametrize(
('', False), ("test_string", "expected"),
('test', False), (
('{POSTID}', True), ("", False),
('POSTID', False), ("test", False),
('{POSTID}_test', True), ("{POSTID}", True),
('test_{TITLE}', True), ("POSTID", False),
('TITLE_POSTID', False), ("{POSTID}_test", True),
)) ("test_{TITLE}", True),
("TITLE_POSTID", False),
),
)
def test_check_format_string_validity(test_string: str, expected: bool): def test_check_format_string_validity(test_string: str, expected: bool):
result = FileNameFormatter.validate_string(test_string) result = FileNameFormatter.validate_string(test_string)
assert result == expected assert result == expected
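
The truth table implies validate_string accepts a scheme only when it contains at least one brace-delimited key such as {POSTID}. A regex that reproduces the cases above (an assumed implementation, shown for clarity):

    import re

    def has_format_key(scheme: str) -> bool:
        # True when the scheme holds at least one {KEY} placeholder
        return re.search(r"\{\w+\}", scheme) is not None

    assert has_format_key("{POSTID}_test")
    assert not has_format_key("TITLE_POSTID")
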
@@ -85,84 +91,98 @@ def test_check_format_string_validity(test_string: str, expected: bool):
@pytest.mark.online @pytest.mark.online
@pytest.mark.reddit @pytest.mark.reddit
@pytest.mark.parametrize(('test_format_string', 'expected'), ( @pytest.mark.parametrize(
('{SUBREDDIT}', 'formula1'), ("test_format_string", "expected"),
('{REDDITOR}', 'Kirsty-Blue'), (
('{POSTID}', 'w22m5l'), ("{SUBREDDIT}", "formula1"),
('{FLAIR}', 'Social Media rall'), ("{REDDITOR}", "Kirsty-Blue"),
('{SUBREDDIT}_{TITLE}', 'formula1_George Russel acknowledges the Twitter trend about him'), ("{POSTID}", "w22m5l"),
('{REDDITOR}_{TITLE}_{POSTID}', 'Kirsty-Blue_George Russel acknowledges the Twitter trend about him_w22m5l') ("{FLAIR}", "Social Media rall"),
)) ("{SUBREDDIT}_{TITLE}", "formula1_George Russel acknowledges the Twitter trend about him"),
("{REDDITOR}_{TITLE}_{POSTID}", "Kirsty-Blue_George Russel acknowledges the Twitter trend about him_w22m5l"),
),
)
def test_format_name_real(test_format_string: str, expected: str, reddit_submission: praw.models.Submission): def test_format_name_real(test_format_string: str, expected: str, reddit_submission: praw.models.Submission):
test_formatter = FileNameFormatter(test_format_string, '', '') test_formatter = FileNameFormatter(test_format_string, "", "")
result = test_formatter._format_name(reddit_submission, test_format_string) result = test_formatter._format_name(reddit_submission, test_format_string)
assert do_test_string_equality(result, expected) assert do_test_string_equality(result, expected)
@pytest.mark.online @pytest.mark.online
@pytest.mark.reddit @pytest.mark.reddit
@pytest.mark.parametrize(('format_string_directory', 'format_string_file', 'expected'), ( @pytest.mark.parametrize(
("format_string_directory", "format_string_file", "expected"),
( (
'{SUBREDDIT}', (
'{POSTID}', "{SUBREDDIT}",
'test/formula1/w22m5l.png', "{POSTID}",
"test/formula1/w22m5l.png",
),
(
"{SUBREDDIT}",
"{TITLE}_{POSTID}",
"test/formula1/George Russel acknowledges the Twitter trend about him_w22m5l.png",
),
(
"{SUBREDDIT}",
"{REDDITOR}_{TITLE}_{POSTID}",
"test/formula1/Kirsty-Blue_George Russel acknowledges the Twitter trend about him_w22m5l.png",
),
), ),
( )
'{SUBREDDIT}',
'{TITLE}_{POSTID}',
'test/formula1/George Russel acknowledges the Twitter trend about him_w22m5l.png',
),
(
'{SUBREDDIT}',
'{REDDITOR}_{TITLE}_{POSTID}',
'test/formula1/Kirsty-Blue_George Russel acknowledges the Twitter trend about him_w22m5l.png',
),
))
def test_format_full( def test_format_full(
format_string_directory: str, format_string_directory: str, format_string_file: str, expected: str, reddit_submission: praw.models.Submission
format_string_file: str, ):
expected: str, test_resource = Resource(reddit_submission, "i.reddit.com/blabla.png", lambda: None)
reddit_submission: praw.models.Submission): test_formatter = FileNameFormatter(format_string_file, format_string_directory, "ISO")
test_resource = Resource(reddit_submission, 'i.reddit.com/blabla.png', lambda: None) result = test_formatter.format_path(test_resource, Path("test"))
test_formatter = FileNameFormatter(format_string_file, format_string_directory, 'ISO')
result = test_formatter.format_path(test_resource, Path('test'))
assert do_test_path_equality(result, expected) assert do_test_path_equality(result, expected)
@pytest.mark.online @pytest.mark.online
@pytest.mark.reddit @pytest.mark.reddit
@pytest.mark.parametrize(('format_string_directory', 'format_string_file'), ( @pytest.mark.parametrize(
('{SUBREDDIT}', '{POSTID}'), ("format_string_directory", "format_string_file"),
('{SUBREDDIT}', '{UPVOTES}'), (
('{SUBREDDIT}', '{UPVOTES}{POSTID}'), ("{SUBREDDIT}", "{POSTID}"),
)) ("{SUBREDDIT}", "{UPVOTES}"),
("{SUBREDDIT}", "{UPVOTES}{POSTID}"),
),
)
def test_format_full_conform( def test_format_full_conform(
format_string_directory: str, format_string_directory: str, format_string_file: str, reddit_submission: praw.models.Submission
format_string_file: str, ):
reddit_submission: praw.models.Submission): test_resource = Resource(reddit_submission, "i.reddit.com/blabla.png", lambda: None)
test_resource = Resource(reddit_submission, 'i.reddit.com/blabla.png', lambda: None) test_formatter = FileNameFormatter(format_string_file, format_string_directory, "ISO")
test_formatter = FileNameFormatter(format_string_file, format_string_directory, 'ISO') test_formatter.format_path(test_resource, Path("test"))
test_formatter.format_path(test_resource, Path('test'))
@pytest.mark.online @pytest.mark.online
@pytest.mark.reddit @pytest.mark.reddit
@pytest.mark.parametrize(('format_string_directory', 'format_string_file', 'index', 'expected'), ( @pytest.mark.parametrize(
('{SUBREDDIT}', '{POSTID}', None, 'test/formula1/w22m5l.png'), ("format_string_directory", "format_string_file", "index", "expected"),
('{SUBREDDIT}', '{POSTID}', 1, 'test/formula1/w22m5l_1.png'), (
('{SUBREDDIT}', '{POSTID}', 2, 'test/formula1/w22m5l_2.png'), ("{SUBREDDIT}", "{POSTID}", None, "test/formula1/w22m5l.png"),
('{SUBREDDIT}', '{TITLE}_{POSTID}', 2, 'test/formula1/George Russel acknowledges the Twitter trend about him_w22m5l_2.png'), ("{SUBREDDIT}", "{POSTID}", 1, "test/formula1/w22m5l_1.png"),
)) ("{SUBREDDIT}", "{POSTID}", 2, "test/formula1/w22m5l_2.png"),
(
"{SUBREDDIT}",
"{TITLE}_{POSTID}",
2,
"test/formula1/George Russel acknowledges the Twitter trend about him_w22m5l_2.png",
),
),
)
def test_format_full_with_index_suffix( def test_format_full_with_index_suffix(
format_string_directory: str, format_string_directory: str,
format_string_file: str, format_string_file: str,
index: Optional[int], index: Optional[int],
expected: str, expected: str,
reddit_submission: praw.models.Submission, reddit_submission: praw.models.Submission,
): ):
test_resource = Resource(reddit_submission, 'i.reddit.com/blabla.png', lambda: None) test_resource = Resource(reddit_submission, "i.reddit.com/blabla.png", lambda: None)
test_formatter = FileNameFormatter(format_string_file, format_string_directory, 'ISO') test_formatter = FileNameFormatter(format_string_file, format_string_directory, "ISO")
result = test_formatter.format_path(test_resource, Path('test'), index) result = test_formatter.format_path(test_resource, Path("test"), index)
assert do_test_path_equality(result, expected) assert do_test_path_equality(result, expected)
@@ -170,99 +190,114 @@ def test_format_multiple_resources():
mocks = [] mocks = []
for i in range(1, 5): for i in range(1, 5):
new_mock = MagicMock() new_mock = MagicMock()
new_mock.url = 'https://example.com/test.png' new_mock.url = "https://example.com/test.png"
new_mock.extension = '.png' new_mock.extension = ".png"
new_mock.source_submission.title = 'test' new_mock.source_submission.title = "test"
new_mock.source_submission.__class__ = praw.models.Submission new_mock.source_submission.__class__ = praw.models.Submission
mocks.append(new_mock) mocks.append(new_mock)
test_formatter = FileNameFormatter('{TITLE}', '', 'ISO') test_formatter = FileNameFormatter("{TITLE}", "", "ISO")
results = test_formatter.format_resource_paths(mocks, Path('.')) results = test_formatter.format_resource_paths(mocks, Path("."))
results = set([str(res[0].name) for res in results]) results = set([str(res[0].name) for res in results])
expected = {'test_1.png', 'test_2.png', 'test_3.png', 'test_4.png'} expected = {"test_1.png", "test_2.png", "test_3.png", "test_4.png"}
assert results == expected assert results == expected
@pytest.mark.parametrize(('test_filename', 'test_ending'), ( @pytest.mark.parametrize(
('A' * 300, '.png'), ("test_filename", "test_ending"),
('A' * 300, '_1.png'), (
('a' * 300, '_1000.jpeg'), ("A" * 300, ".png"),
('😍💕✨' * 100, '_1.png'), ("A" * 300, "_1.png"),
)) ("a" * 300, "_1000.jpeg"),
("😍💕✨" * 100, "_1.png"),
),
)
def test_limit_filename_length(test_filename: str, test_ending: str): def test_limit_filename_length(test_filename: str, test_ending: str):
result = FileNameFormatter.limit_file_name_length(test_filename, test_ending, Path('.')) result = FileNameFormatter.limit_file_name_length(test_filename, test_ending, Path("."))
assert len(result.name) <= 255 assert len(result.name) <= 255
assert len(result.name.encode('utf-8')) <= 255 assert len(result.name.encode("utf-8")) <= 255
assert len(str(result)) <= FileNameFormatter.find_max_path_length() assert len(str(result)) <= FileNameFormatter.find_max_path_length()
assert isinstance(result, Path) assert isinstance(result, Path)
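
Both assertions matter because common filesystems cap a single file name at 255 bytes, not 255 characters, so multi-byte characters must be measured after UTF-8 encoding:

    name = "😍" * 100
    assert len(name) == 100                   # characters
    assert len(name.encode("utf-8")) == 400   # bytes: 4 per emoji, well over a 255-byte cap
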
@pytest.mark.parametrize(('test_filename', 'test_ending', 'expected_end'), ( @pytest.mark.parametrize(
('test_aaaaaa', '_1.png', 'test_aaaaaa_1.png'), ("test_filename", "test_ending", "expected_end"),
('test_aataaa', '_1.png', 'test_aataaa_1.png'), (
('test_abcdef', '_1.png', 'test_abcdef_1.png'), ("test_aaaaaa", "_1.png", "test_aaaaaa_1.png"),
('test_aaaaaa', '.png', 'test_aaaaaa.png'), ("test_aataaa", "_1.png", "test_aataaa_1.png"),
('test', '_1.png', 'test_1.png'), ("test_abcdef", "_1.png", "test_abcdef_1.png"),
('test_m1hqw6', '_1.png', 'test_m1hqw6_1.png'), ("test_aaaaaa", ".png", "test_aaaaaa.png"),
('A' * 300 + '_bbbccc', '.png', '_bbbccc.png'), ("test", "_1.png", "test_1.png"),
('A' * 300 + '_bbbccc', '_1000.jpeg', '_bbbccc_1000.jpeg'), ("test_m1hqw6", "_1.png", "test_m1hqw6_1.png"),
('😍💕✨' * 100 + '_aaa1aa', '_1.png', '_aaa1aa_1.png'), ("A" * 300 + "_bbbccc", ".png", "_bbbccc.png"),
)) ("A" * 300 + "_bbbccc", "_1000.jpeg", "_bbbccc_1000.jpeg"),
("😍💕✨" * 100 + "_aaa1aa", "_1.png", "_aaa1aa_1.png"),
),
)
def test_preserve_id_append_when_shortening(test_filename: str, test_ending: str, expected_end: str): def test_preserve_id_append_when_shortening(test_filename: str, test_ending: str, expected_end: str):
result = FileNameFormatter.limit_file_name_length(test_filename, test_ending, Path('.')) result = FileNameFormatter.limit_file_name_length(test_filename, test_ending, Path("."))
assert len(result.name) <= 255 assert len(result.name) <= 255
assert len(result.name.encode('utf-8')) <= 255 assert len(result.name.encode("utf-8")) <= 255
assert result.name.endswith(expected_end) assert result.name.endswith(expected_end)
assert len(str(result)) <= FileNameFormatter.find_max_path_length() assert len(str(result)) <= FileNameFormatter.find_max_path_length()
@pytest.mark.skipif(sys.platform == 'win32', reason='Test broken on windows github') @pytest.mark.skipif(sys.platform == "win32", reason="Test broken on windows github")
def test_shorten_filename_real(submission: MagicMock, tmp_path: Path): def test_shorten_filename_real(submission: MagicMock, tmp_path: Path):
submission.title = 'A' * 500 submission.title = "A" * 500
submission.author.name = 'test' submission.author.name = "test"
submission.subreddit.display_name = 'test' submission.subreddit.display_name = "test"
submission.id = 'BBBBBB' submission.id = "BBBBBB"
test_resource = Resource(submission, 'www.example.com/empty', lambda: None, '.jpeg') test_resource = Resource(submission, "www.example.com/empty", lambda: None, ".jpeg")
test_formatter = FileNameFormatter('{REDDITOR}_{TITLE}_{POSTID}', '{SUBREDDIT}', 'ISO') test_formatter = FileNameFormatter("{REDDITOR}_{TITLE}_{POSTID}", "{SUBREDDIT}", "ISO")
result = test_formatter.format_path(test_resource, tmp_path) result = test_formatter.format_path(test_resource, tmp_path)
result.parent.mkdir(parents=True) result.parent.mkdir(parents=True)
result.touch() result.touch()
@pytest.mark.parametrize(('test_name', 'test_ending'), ( @pytest.mark.parametrize(
('a', 'b'), ("test_name", "test_ending"),
('a', '_bbbbbb.jpg'), (
('a' * 20, '_bbbbbb.jpg'), ("a", "b"),
('a' * 50, '_bbbbbb.jpg'), ("a", "_bbbbbb.jpg"),
('a' * 500, '_bbbbbb.jpg'), ("a" * 20, "_bbbbbb.jpg"),
)) ("a" * 50, "_bbbbbb.jpg"),
("a" * 500, "_bbbbbb.jpg"),
),
)
def test_shorten_path(test_name: str, test_ending: str, tmp_path: Path): def test_shorten_path(test_name: str, test_ending: str, tmp_path: Path):
result = FileNameFormatter.limit_file_name_length(test_name, test_ending, tmp_path) result = FileNameFormatter.limit_file_name_length(test_name, test_ending, tmp_path)
assert len(str(result.name)) <= 255 assert len(str(result.name)) <= 255
assert len(str(result.name).encode('UTF-8')) <= 255 assert len(str(result.name).encode("UTF-8")) <= 255
assert len(str(result.name).encode('cp1252')) <= 255 assert len(str(result.name).encode("cp1252")) <= 255
assert len(str(result)) <= FileNameFormatter.find_max_path_length() assert len(str(result)) <= FileNameFormatter.find_max_path_length()
@pytest.mark.parametrize(('test_string', 'expected'), ( @pytest.mark.parametrize(
('test', 'test'), ("test_string", "expected"),
('test😍', 'test'), (
('test.png', 'test.png'), ("test", "test"),
('test*', 'test'), ("test😍", "test"),
('test**', 'test'), ("test.png", "test.png"),
('test?*', 'test'), ("test*", "test"),
('test_???.png', 'test_.png'), ("test**", "test"),
('test_???😍.png', 'test_.png'), ("test?*", "test"),
)) ("test_???.png", "test_.png"),
("test_???😍.png", "test_.png"),
),
)
def test_format_file_name_for_windows(test_string: str, expected: str): def test_format_file_name_for_windows(test_string: str, expected: str):
result = FileNameFormatter._format_for_windows(test_string) result = FileNameFormatter._format_for_windows(test_string)
assert result == expected assert result == expected
@pytest.mark.parametrize(('test_string', 'expected'), ( @pytest.mark.parametrize(
('test', 'test'), ("test_string", "expected"),
('test😍', 'test'), (
('😍', ''), ("test", "test"),
)) ("test😍", "test"),
("😍", ""),
),
)
def test_strip_emojies(test_string: str, expected: str): def test_strip_emojies(test_string: str, expected: str):
result = FileNameFormatter._strip_emojis(test_string) result = FileNameFormatter._strip_emojis(test_string)
assert result == expected assert result == expected
@@ -270,121 +305,151 @@ def test_strip_emojies(test_string: str, expected: str):
@pytest.mark.online @pytest.mark.online
@pytest.mark.reddit @pytest.mark.reddit
@pytest.mark.parametrize(('test_submission_id', 'expected'), ( @pytest.mark.parametrize(
('mfuteh', { ("test_submission_id", "expected"),
'title': 'Why Do Interviewers Ask Linked List Questions?', (
'redditor': 'mjgardner', (
}), "mfuteh",
)) {
"title": "Why Do Interviewers Ask Linked List Questions?",
"redditor": "mjgardner",
},
),
),
)
def test_generate_dict_for_submission(test_submission_id: str, expected: dict, reddit_instance: praw.Reddit): def test_generate_dict_for_submission(test_submission_id: str, expected: dict, reddit_instance: praw.Reddit):
test_submission = reddit_instance.submission(id=test_submission_id) test_submission = reddit_instance.submission(id=test_submission_id)
test_formatter = FileNameFormatter('{TITLE}', '', 'ISO') test_formatter = FileNameFormatter("{TITLE}", "", "ISO")
result = test_formatter._generate_name_dict_from_submission(test_submission) result = test_formatter._generate_name_dict_from_submission(test_submission)
assert all([result.get(key) == expected[key] for key in expected.keys()]) assert all([result.get(key) == expected[key] for key in expected.keys()])
@pytest.mark.online @pytest.mark.online
@pytest.mark.reddit @pytest.mark.reddit
@pytest.mark.parametrize(('test_comment_id', 'expected'), ( @pytest.mark.parametrize(
('gsq0yuw', { ("test_comment_id", "expected"),
'title': 'Why Do Interviewers Ask Linked List Questions?', (
'redditor': 'Doctor-Dapper', (
'postid': 'gsq0yuw', "gsq0yuw",
'flair': '', {
}), "title": "Why Do Interviewers Ask Linked List Questions?",
)) "redditor": "Doctor-Dapper",
"postid": "gsq0yuw",
"flair": "",
},
),
),
)
def test_generate_dict_for_comment(test_comment_id: str, expected: dict, reddit_instance: praw.Reddit): def test_generate_dict_for_comment(test_comment_id: str, expected: dict, reddit_instance: praw.Reddit):
test_comment = reddit_instance.comment(id=test_comment_id) test_comment = reddit_instance.comment(id=test_comment_id)
test_formatter = FileNameFormatter('{TITLE}', '', 'ISO') test_formatter = FileNameFormatter("{TITLE}", "", "ISO")
result = test_formatter._generate_name_dict_from_comment(test_comment) result = test_formatter._generate_name_dict_from_comment(test_comment)
assert all([result.get(key) == expected[key] for key in expected.keys()]) assert all([result.get(key) == expected[key] for key in expected.keys()])
@pytest.mark.online @pytest.mark.online
@pytest.mark.reddit @pytest.mark.reddit
@pytest.mark.parametrize(('test_file_scheme', 'test_folder_scheme', 'test_comment_id', 'expected_name'), ( @pytest.mark.parametrize(
('{POSTID}', '', 'gsoubde', 'gsoubde.json'), ("test_file_scheme", "test_folder_scheme", "test_comment_id", "expected_name"),
('{REDDITOR}_{POSTID}', '', 'gsoubde', 'DELETED_gsoubde.json'), (
)) ("{POSTID}", "", "gsoubde", "gsoubde.json"),
("{REDDITOR}_{POSTID}", "", "gsoubde", "DELETED_gsoubde.json"),
),
)
def test_format_archive_entry_comment( def test_format_archive_entry_comment(
test_file_scheme: str, test_file_scheme: str,
test_folder_scheme: str, test_folder_scheme: str,
test_comment_id: str, test_comment_id: str,
expected_name: str, expected_name: str,
tmp_path: Path, tmp_path: Path,
reddit_instance: praw.Reddit, reddit_instance: praw.Reddit,
): ):
test_comment = reddit_instance.comment(id=test_comment_id) test_comment = reddit_instance.comment(id=test_comment_id)
test_formatter = FileNameFormatter(test_file_scheme, test_folder_scheme, 'ISO') test_formatter = FileNameFormatter(test_file_scheme, test_folder_scheme, "ISO")
test_entry = Resource(test_comment, '', lambda: None, '.json') test_entry = Resource(test_comment, "", lambda: None, ".json")
result = test_formatter.format_path(test_entry, tmp_path) result = test_formatter.format_path(test_entry, tmp_path)
assert do_test_string_equality(result, expected_name) assert do_test_string_equality(result, expected_name)
@pytest.mark.parametrize(('test_folder_scheme', 'expected'), ( @pytest.mark.parametrize(
('{REDDITOR}/{SUBREDDIT}', 'person/randomreddit'), ("test_folder_scheme", "expected"),
('{POSTID}/{SUBREDDIT}/{REDDITOR}', '12345/randomreddit/person'), (
)) ("{REDDITOR}/{SUBREDDIT}", "person/randomreddit"),
("{POSTID}/{SUBREDDIT}/{REDDITOR}", "12345/randomreddit/person"),
),
)
def test_multilevel_folder_scheme( def test_multilevel_folder_scheme(
test_folder_scheme: str, test_folder_scheme: str,
expected: str, expected: str,
tmp_path: Path, tmp_path: Path,
submission: MagicMock, submission: MagicMock,
): ):
test_formatter = FileNameFormatter('{POSTID}', test_folder_scheme, 'ISO') test_formatter = FileNameFormatter("{POSTID}", test_folder_scheme, "ISO")
test_resource = MagicMock() test_resource = MagicMock()
test_resource.source_submission = submission test_resource.source_submission = submission
test_resource.extension = '.png' test_resource.extension = ".png"
result = test_formatter.format_path(test_resource, tmp_path) result = test_formatter.format_path(test_resource, tmp_path)
result = result.relative_to(tmp_path) result = result.relative_to(tmp_path)
assert do_test_path_equality(result.parent, expected) assert do_test_path_equality(result.parent, expected)
assert len(result.parents) == (len(expected.split('/')) + 1) assert len(result.parents) == (len(expected.split("/")) + 1)
@pytest.mark.parametrize(('test_name_string', 'expected'), ( @pytest.mark.parametrize(
('test', 'test'), ("test_name_string", "expected"),
('😍', '😍'), (
('test😍', 'test😍'), ("test", "test"),
('test😍 ', 'test😍 '), ("😍", "😍"),
('test😍 \\u2019', 'test😍 '), ("test😍", "test😍"),
('Using that real good [1\\4]', 'Using that real good [1\\4]'), ("test😍 ", "test😍 "),
)) ("test😍 \\u2019", "test😍 "),
("Using that real good [1\\4]", "Using that real good [1\\4]"),
),
)
def test_preserve_emojis(test_name_string: str, expected: str, submission: MagicMock): def test_preserve_emojis(test_name_string: str, expected: str, submission: MagicMock):
submission.title = test_name_string submission.title = test_name_string
test_formatter = FileNameFormatter('{TITLE}', '', 'ISO') test_formatter = FileNameFormatter("{TITLE}", "", "ISO")
result = test_formatter._format_name(submission, '{TITLE}') result = test_formatter._format_name(submission, "{TITLE}")
assert do_test_string_equality(result, expected) assert do_test_string_equality(result, expected)
@pytest.mark.parametrize(('test_string', 'expected'), ( @pytest.mark.parametrize(
('test \\u2019', 'test '), ("test_string", "expected"),
('My cat\\u2019s paws are so cute', 'My cats paws are so cute'), (
)) ("test \\u2019", "test "),
("My cat\\u2019s paws are so cute", "My cats paws are so cute"),
),
)
def test_convert_unicode_escapes(test_string: str, expected: str): def test_convert_unicode_escapes(test_string: str, expected: str):
result = FileNameFormatter._convert_unicode_escapes(test_string) result = FileNameFormatter._convert_unicode_escapes(test_string)
assert result == expected assert result == expected
@pytest.mark.parametrize(('test_datetime', 'expected'), ( @pytest.mark.parametrize(
(datetime(2020, 1, 1, 8, 0, 0), '2020-01-01T08:00:00'), ("test_datetime", "expected"),
(datetime(2020, 1, 1, 8, 0), '2020-01-01T08:00:00'), (
(datetime(2021, 4, 21, 8, 30, 21), '2021-04-21T08:30:21'), (datetime(2020, 1, 1, 8, 0, 0), "2020-01-01T08:00:00"),
)) (datetime(2020, 1, 1, 8, 0), "2020-01-01T08:00:00"),
(datetime(2021, 4, 21, 8, 30, 21), "2021-04-21T08:30:21"),
),
)
def test_convert_timestamp(test_datetime: datetime, expected: str): def test_convert_timestamp(test_datetime: datetime, expected: str):
test_timestamp = test_datetime.timestamp() test_timestamp = test_datetime.timestamp()
test_formatter = FileNameFormatter('{POSTID}', '', 'ISO') test_formatter = FileNameFormatter("{POSTID}", "", "ISO")
result = test_formatter._convert_timestamp(test_timestamp) result = test_formatter._convert_timestamp(test_timestamp)
assert result == expected assert result == expected
@pytest.mark.parametrize(('test_time_format', 'expected'), ( @pytest.mark.parametrize(
('ISO', '2021-05-02T13:33:00'), ("test_time_format", "expected"),
('%Y_%m', '2021_05'), (
('%Y-%m-%d', '2021-05-02'), ("ISO", "2021-05-02T13:33:00"),
)) ("%Y_%m", "2021_05"),
("%Y-%m-%d", "2021-05-02"),
),
)
def test_time_string_formats(test_time_format: str, expected: str): def test_time_string_formats(test_time_format: str, expected: str):
test_time = datetime(2021, 5, 2, 13, 33) test_time = datetime(2021, 5, 2, 13, 33)
test_formatter = FileNameFormatter('{TITLE}', '', test_time_format) test_formatter = FileNameFormatter("{TITLE}", "", test_time_format)
result = test_formatter._convert_timestamp(test_time.timestamp()) result = test_formatter._convert_timestamp(test_time.timestamp())
assert result == expected assert result == expected
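
The accepted values suggest "ISO" is special-cased and anything else is handed to strftime. A minimal sketch of that branch (an assumption based on the formats above):

    from datetime import datetime

    def format_time(timestamp: float, time_format: str) -> str:
        moment = datetime.fromtimestamp(timestamp)
        if time_format == "ISO":
            return moment.isoformat()
        return moment.strftime(time_format)

    stamp = datetime(2021, 5, 2, 13, 33).timestamp()
    assert format_time(stamp, "%Y_%m") == "2021_05"
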
@@ -395,29 +460,32 @@ def test_get_max_path_length():
def test_windows_max_path(tmp_path: Path): def test_windows_max_path(tmp_path: Path):
with unittest.mock.patch('platform.system', return_value='Windows'): with unittest.mock.patch("platform.system", return_value="Windows"):
with unittest.mock.patch('bdfr.file_name_formatter.FileNameFormatter.find_max_path_length', return_value=260): with unittest.mock.patch("bdfr.file_name_formatter.FileNameFormatter.find_max_path_length", return_value=260):
result = FileNameFormatter.limit_file_name_length('test' * 100, '_1.png', tmp_path) result = FileNameFormatter.limit_file_name_length("test" * 100, "_1.png", tmp_path)
assert len(str(result)) <= 260 assert len(str(result)) <= 260
assert len(result.name) <= (260 - len(str(tmp_path))) assert len(result.name) <= (260 - len(str(tmp_path)))
@pytest.mark.online @pytest.mark.online
@pytest.mark.reddit @pytest.mark.reddit
@pytest.mark.parametrize(('test_reddit_id', 'test_downloader', 'expected_names'), ( @pytest.mark.parametrize(
('gphmnr', YtdlpFallback, {'He has a lot to say today.mp4'}), ("test_reddit_id", "test_downloader", "expected_names"),
('d0oir2', YtdlpFallback, {"Crunk's finest moment. Welcome to the new subreddit!.mp4"}), (
('jiecu', SelfPost, {'[deleted by user].txt'}), ("gphmnr", YtdlpFallback, {"He has a lot to say today.mp4"}),
)) ("d0oir2", YtdlpFallback, {"Crunk's finest moment. Welcome to the new subreddit!.mp4"}),
("jiecu", SelfPost, {"[deleted by user].txt"}),
),
)
def test_name_submission( def test_name_submission(
test_reddit_id: str, test_reddit_id: str,
test_downloader: Type[BaseDownloader], test_downloader: Type[BaseDownloader],
expected_names: set[str], expected_names: set[str],
reddit_instance: praw.reddit.Reddit, reddit_instance: praw.reddit.Reddit,
): ):
test_submission = reddit_instance.submission(id=test_reddit_id) test_submission = reddit_instance.submission(id=test_reddit_id)
test_resources = test_downloader(test_submission).find_resources() test_resources = test_downloader(test_submission).find_resources()
test_formatter = FileNameFormatter('{TITLE}', '', '') test_formatter = FileNameFormatter("{TITLE}", "", "")
results = test_formatter.format_resource_paths(test_resources, Path('.')) results = test_formatter.format_resource_paths(test_resources, Path("."))
results = set([r[0].name for r in results]) results = set([r[0].name for r in results])
assert results == expected_names assert results == expected_names
View file
@@ -14,38 +14,58 @@ from bdfr.oauth2 import OAuth2Authenticator, OAuth2TokenManager
@pytest.fixture() @pytest.fixture()
def example_config() -> configparser.ConfigParser: def example_config() -> configparser.ConfigParser:
out = configparser.ConfigParser() out = configparser.ConfigParser()
config_dict = {'DEFAULT': {'user_token': 'example'}} config_dict = {"DEFAULT": {"user_token": "example"}}
out.read_dict(config_dict) out.read_dict(config_dict)
return out return out
@pytest.mark.online @pytest.mark.online
@pytest.mark.parametrize('test_scopes', ( @pytest.mark.parametrize(
{'history', }, "test_scopes",
{'history', 'creddits'}, (
{'account', 'flair'}, {
{'*', }, "history",
)) },
{"history", "creddits"},
{"account", "flair"},
{
"*",
},
),
)
def test_check_scopes(test_scopes: set[str]): def test_check_scopes(test_scopes: set[str]):
OAuth2Authenticator._check_scopes(test_scopes) OAuth2Authenticator._check_scopes(test_scopes)
@pytest.mark.parametrize(('test_scopes', 'expected'), ( @pytest.mark.parametrize(
('history', {'history', }), ("test_scopes", "expected"),
('history creddits', {'history', 'creddits'}), (
('history, creddits, account', {'history', 'creddits', 'account'}), (
('history,creddits,account,flair', {'history', 'creddits', 'account', 'flair'}), "history",
)) {
"history",
},
),
("history creddits", {"history", "creddits"}),
("history, creddits, account", {"history", "creddits", "account"}),
("history,creddits,account,flair", {"history", "creddits", "account", "flair"}),
),
)
def test_split_scopes(test_scopes: str, expected: set[str]): def test_split_scopes(test_scopes: str, expected: set[str]):
result = OAuth2Authenticator.split_scopes(test_scopes) result = OAuth2Authenticator.split_scopes(test_scopes)
assert result == expected assert result == expected
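
The inputs mix spaces and commas, so splitting on either delimiter and dropping empty pieces reproduces every expected set. A plausible implementation (an assumption, for illustration):

    import re

    def split_scopes(scopes: str) -> set[str]:
        # Split on commas and/or whitespace, ignoring empty fragments
        return {part for part in re.split(r"[,\s]+", scopes) if part}

    assert split_scopes("history, creddits, account") == {"history", "creddits", "account"}
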
@pytest.mark.online @pytest.mark.online
@pytest.mark.parametrize('test_scopes', ( @pytest.mark.parametrize(
{'random', }, "test_scopes",
{'scope', 'another_scope'}, (
)) {
"random",
},
{"scope", "another_scope"},
),
)
def test_check_scopes_bad(test_scopes: set[str]): def test_check_scopes_bad(test_scopes: set[str]):
with pytest.raises(BulkDownloaderException): with pytest.raises(BulkDownloaderException):
OAuth2Authenticator._check_scopes(test_scopes) OAuth2Authenticator._check_scopes(test_scopes)
@@ -56,16 +76,16 @@ def test_token_manager_read(example_config: configparser.ConfigParser):
mock_authoriser.refresh_token = None mock_authoriser.refresh_token = None
test_manager = OAuth2TokenManager(example_config, MagicMock()) test_manager = OAuth2TokenManager(example_config, MagicMock())
test_manager.pre_refresh_callback(mock_authoriser) test_manager.pre_refresh_callback(mock_authoriser)
assert mock_authoriser.refresh_token == example_config.get('DEFAULT', 'user_token') assert mock_authoriser.refresh_token == example_config.get("DEFAULT", "user_token")
def test_token_manager_write(example_config: configparser.ConfigParser, tmp_path: Path): def test_token_manager_write(example_config: configparser.ConfigParser, tmp_path: Path):
test_path = tmp_path / 'test.cfg' test_path = tmp_path / "test.cfg"
mock_authoriser = MagicMock() mock_authoriser = MagicMock()
mock_authoriser.refresh_token = 'changed_token' mock_authoriser.refresh_token = "changed_token"
test_manager = OAuth2TokenManager(example_config, test_path) test_manager = OAuth2TokenManager(example_config, test_path)
test_manager.post_refresh_callback(mock_authoriser) test_manager.post_refresh_callback(mock_authoriser)
assert example_config.get('DEFAULT', 'user_token') == 'changed_token' assert example_config.get("DEFAULT", "user_token") == "changed_token"
with test_path.open('r') as file: with test_path.open("r") as file:
file_contents = file.read() file_contents = file.read()
assert 'user_token = changed_token' in file_contents assert "user_token = changed_token" in file_contents
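
The final assertion works because ConfigParser.write serialises each option as "key = value" lines, so the refreshed token is visible in the raw file text. A compact round-trip demonstrating this:

    import configparser
    import io

    config = configparser.ConfigParser()
    config.read_dict({"DEFAULT": {"user_token": "example"}})
    config["DEFAULT"]["user_token"] = "changed_token"
    buffer = io.StringIO()
    config.write(buffer)
    assert "user_token = changed_token" in buffer.getvalue()
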
View file
@@ -8,18 +8,21 @@ import pytest
from bdfr.resource import Resource from bdfr.resource import Resource
@pytest.mark.parametrize(('test_url', 'expected'), ( @pytest.mark.parametrize(
('test.png', '.png'), ("test_url", "expected"),
('another.mp4', '.mp4'), (
('test.jpeg', '.jpeg'), ("test.png", ".png"),
('http://www.random.com/resource.png', '.png'), ("another.mp4", ".mp4"),
('https://www.resource.com/test/example.jpg', '.jpg'), ("test.jpeg", ".jpeg"),
('hard.png.mp4', '.mp4'), ("http://www.random.com/resource.png", ".png"),
('https://preview.redd.it/7zkmr1wqqih61.png?width=237&format=png&auto=webp&s=19de214e634cbcad99', '.png'), ("https://www.resource.com/test/example.jpg", ".jpg"),
('test.jpg#test', '.jpg'), ("hard.png.mp4", ".mp4"),
('test.jpg?width=247#test', '.jpg'), ("https://preview.redd.it/7zkmr1wqqih61.png?width=237&format=png&auto=webp&s=19de214e634cbcad99", ".png"),
('https://www.test.com/test/test2/example.png?random=test#thing', '.png'), ("test.jpg#test", ".jpg"),
)) ("test.jpg?width=247#test", ".jpg"),
("https://www.test.com/test/test2/example.png?random=test#thing", ".png"),
),
)
def test_resource_get_extension(test_url: str, expected: str): def test_resource_get_extension(test_url: str, expected: str):
test_resource = Resource(MagicMock(), test_url, lambda: None) test_resource = Resource(MagicMock(), test_url, lambda: None)
result = test_resource._determine_extension() result = test_resource._determine_extension()
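
The parametrized URLs show that query strings and fragments are stripped before the suffix is read. The standard library can express that rule (a sketch, not necessarily the project's code):

    from pathlib import PurePosixPath
    from urllib.parse import urlsplit

    def determine_extension(url: str) -> str:
        # Drop ?query and #fragment, then read the final suffix
        return PurePosixPath(urlsplit(url).path).suffix

    assert determine_extension("test.jpg?width=247#test") == ".jpg"
    assert determine_extension("hard.png.mp4") == ".mp4"
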
@@ -27,9 +30,10 @@ def test_resource_get_extension(test_url: str, expected: str):
@pytest.mark.online @pytest.mark.online
@pytest.mark.parametrize(('test_url', 'expected_hash'), ( @pytest.mark.parametrize(
('https://www.iana.org/_img/2013.1/iana-logo-header.svg', '426b3ac01d3584c820f3b7f5985d6623'), ("test_url", "expected_hash"),
)) (("https://www.iana.org/_img/2013.1/iana-logo-header.svg", "426b3ac01d3584c820f3b7f5985d6623"),),
)
def test_download_online_resource(test_url: str, expected_hash: str): def test_download_online_resource(test_url: str, expected_hash: str):
test_resource = Resource(MagicMock(), test_url, Resource.retry_download(test_url)) test_resource = Resource(MagicMock(), test_url, Resource.retry_download(test_url))
test_resource.download() test_resource.download()