1
0
Fork 0
mirror of synced 2024-10-01 09:41:03 +13:00

Merge pull request #818 from OMEGARAZER/Annotations

This commit is contained in:
Serene 2023-05-29 13:04:10 +10:00 committed by GitHub
commit dd6f12345e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
29 changed files with 77 additions and 71 deletions

View file

@ -69,8 +69,8 @@ _archiver_options = [
]
def _add_options(opts: list):
def wrap(func):
def _add_options(opts: list): # noqa: ANN202
def wrap(func): # noqa: ANN001,ANN202
for opt in opts:
func = opt(func)
return func
@ -78,7 +78,7 @@ def _add_options(opts: list):
return wrap
def _check_version(context, _param, value):
def _check_version(context: click.core.Context, _param, value: bool) -> None:
if not value or context.resilient_parsing:
return
current = __version__
@ -101,7 +101,7 @@ def _check_version(context, _param, value):
callback=_check_version,
help="Check version and exit.",
)
def cli():
def cli() -> None:
"""BDFR is used to download and archive content from Reddit."""
pass
@ -111,7 +111,7 @@ def cli():
@_add_options(_downloader_options)
@click.help_option("-h", "--help")
@click.pass_context
def cli_download(context: click.Context, **_):
def cli_download(context: click.Context, **_) -> None:
"""Used to download content posted to Reddit."""
config = Configuration()
config.process_click_arguments(context)
@ -132,7 +132,7 @@ def cli_download(context: click.Context, **_):
@_add_options(_archiver_options)
@click.help_option("-h", "--help")
@click.pass_context
def cli_archive(context: click.Context, **_):
def cli_archive(context: click.Context, **_) -> None:
"""Used to archive post data from Reddit."""
config = Configuration()
config.process_click_arguments(context)
@ -154,7 +154,7 @@ def cli_archive(context: click.Context, **_):
@_add_options(_downloader_options)
@click.help_option("-h", "--help")
@click.pass_context
def cli_clone(context: click.Context, **_):
def cli_clone(context: click.Context, **_) -> None:
"""Combines archive and download commands."""
config = Configuration()
config.process_click_arguments(context)
@ -174,7 +174,7 @@ def cli_clone(context: click.Context, **_):
@click.argument("shell", type=click.Choice(("all", "bash", "fish", "zsh"), case_sensitive=False), default="all")
@click.help_option("-h", "--help")
@click.option("-u", "--uninstall", is_flag=True, default=False, help="Uninstall completion")
def cli_completion(shell: str, uninstall: bool):
def cli_completion(shell: str, uninstall: bool) -> None:
"""\b
Installs shell completions for BDFR.
Options: all, bash, fish, zsh
@ -216,7 +216,7 @@ def make_console_logging_handler(verbosity: int) -> logging.StreamHandler:
return stream
def silence_module_loggers():
def silence_module_loggers() -> None:
logging.getLogger("praw").setLevel(logging.CRITICAL)
logging.getLogger("prawcore").setLevel(logging.CRITICAL)
logging.getLogger("urllib3").setLevel(logging.CRITICAL)

View file

@ -7,7 +7,7 @@ from praw.models import Comment, Submission
class BaseArchiveEntry(ABC):
def __init__(self, source: Union[Comment, Submission]):
def __init__(self, source: Union[Comment, Submission]) -> None:
self.source = source
self.post_details: dict = {}

View file

@ -10,7 +10,7 @@ logger = logging.getLogger(__name__)
class CommentArchiveEntry(BaseArchiveEntry):
def __init__(self, comment: praw.models.Comment):
def __init__(self, comment: praw.models.Comment) -> None:
super().__init__(comment)
def compile(self) -> dict:

View file

@ -10,7 +10,7 @@ logger = logging.getLogger(__name__)
class SubmissionArchiveEntry(BaseArchiveEntry):
def __init__(self, submission: praw.models.Submission):
def __init__(self, submission: praw.models.Submission) -> None:
super().__init__(submission)
def compile(self) -> dict:
@ -20,7 +20,7 @@ class SubmissionArchiveEntry(BaseArchiveEntry):
out["comments"] = comments
return out
def _get_post_details(self):
def _get_post_details(self) -> None:
self.post_details = {
"title": self.source.title,
"name": self.source.name,

View file

@ -25,10 +25,10 @@ logger = logging.getLogger(__name__)
class Archiver(RedditConnector):
def __init__(self, args: Configuration, logging_handlers: Iterable[logging.Handler] = ()):
def __init__(self, args: Configuration, logging_handlers: Iterable[logging.Handler] = ()) -> None:
super().__init__(args, logging_handlers)
def download(self):
def download(self) -> None:
for generator in self.reddit_lists:
try:
for submission in generator:
@ -82,7 +82,7 @@ class Archiver(RedditConnector):
else:
raise ArchiverError(f"Factory failed to classify item of type {type(praw_item).__name__}")
def write_entry(self, praw_item: Union[praw.models.Submission, praw.models.Comment]):
def write_entry(self, praw_item: Union[praw.models.Submission, praw.models.Comment]) -> None:
if self.args.comment_context and isinstance(praw_item, praw.models.Comment):
logger.debug(f"Converting comment {praw_item.id} to submission {praw_item.submission.id}")
praw_item = praw_item.submission
@ -97,22 +97,22 @@ class Archiver(RedditConnector):
raise ArchiverError(f"Unknown format {self.args.format!r} given")
logger.info(f"Record for entry item {praw_item.id} written to disk")
def _write_entry_json(self, entry: BaseArchiveEntry):
def _write_entry_json(self, entry: BaseArchiveEntry) -> None:
resource = Resource(entry.source, "", lambda: None, ".json")
content = json.dumps(entry.compile())
self._write_content_to_disk(resource, content)
def _write_entry_xml(self, entry: BaseArchiveEntry):
def _write_entry_xml(self, entry: BaseArchiveEntry) -> None:
resource = Resource(entry.source, "", lambda: None, ".xml")
content = dict2xml.dict2xml(entry.compile(), wrap="root")
self._write_content_to_disk(resource, content)
def _write_entry_yaml(self, entry: BaseArchiveEntry):
def _write_entry_yaml(self, entry: BaseArchiveEntry) -> None:
resource = Resource(entry.source, "", lambda: None, ".yaml")
content = yaml.safe_dump(entry.compile())
self._write_content_to_disk(resource, content)
def _write_content_to_disk(self, resource: Resource, content: str):
def _write_content_to_disk(self, resource: Resource, content: str) -> None:
file_path = self.file_name_formatter.format_path(resource, self.download_directory)
file_path.parent.mkdir(exist_ok=True, parents=True)
with Path(file_path).open(mode="w", encoding="utf-8") as file:

View file

@ -14,10 +14,10 @@ logger = logging.getLogger(__name__)
class RedditCloner(RedditDownloader, Archiver):
def __init__(self, args: Configuration, logging_handlers: Iterable[logging.Handler] = ()):
def __init__(self, args: Configuration, logging_handlers: Iterable[logging.Handler] = ()) -> None:
super().__init__(args, logging_handlers)
def download(self):
def download(self) -> None:
for generator in self.reddit_lists:
try:
for submission in generator:

View file

@ -8,13 +8,13 @@ import appdirs
class Completion:
def __init__(self, shell: str):
def __init__(self, shell: str) -> None:
self.shell = shell
self.env = environ.copy()
self.share_dir = appdirs.user_data_dir()
self.entry_points = ["bdfr", "bdfr-archive", "bdfr-clone", "bdfr-download"]
def install(self):
def install(self) -> None:
if self.shell in ("all", "bash"):
comp_dir = self.share_dir + "/bash-completion/completions/"
if not Path(comp_dir).exists():
@ -46,7 +46,7 @@ class Completion:
file.write(subprocess.run([point], env=self.env, capture_output=True, text=True).stdout)
print(f"Zsh completion for {point} written to {comp_dir}_{point}")
def uninstall(self):
def uninstall(self) -> None:
if self.shell in ("all", "bash"):
comp_dir = self.share_dir + "/bash-completion/completions/"
for point in self.entry_points:

View file

@ -12,7 +12,7 @@ logger = logging.getLogger(__name__)
class Configuration(Namespace):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.authenticate = False
self.config = None
@ -59,7 +59,7 @@ class Configuration(Namespace):
self.format = "json"
self.comment_context: bool = False
def process_click_arguments(self, context: click.Context):
def process_click_arguments(self, context: click.Context) -> None:
if context.params.get("opts") is not None:
self.parse_yaml_options(context.params["opts"])
for arg_key in context.params.keys():
@ -72,7 +72,7 @@ class Configuration(Namespace):
continue
setattr(self, arg_key, val)
def parse_yaml_options(self, file_path: str):
def parse_yaml_options(self, file_path: str) -> None:
yaml_file_loc = Path(file_path)
if not yaml_file_loc.exists():
logger.error(f"No YAML file found at {yaml_file_loc}")

View file

@ -14,6 +14,7 @@ from datetime import datetime
from enum import Enum, auto
from pathlib import Path
from time import sleep
from typing import Union
import appdirs
import praw
@ -51,7 +52,7 @@ class RedditTypes:
class RedditConnector(metaclass=ABCMeta):
def __init__(self, args: Configuration, logging_handlers: Iterable[logging.Handler] = ()):
def __init__(self, args: Configuration, logging_handlers: Iterable[logging.Handler] = ()) -> None:
self.args = args
self.config_directories = appdirs.AppDirs("bdfr", "BDFR")
self.determine_directories()
@ -64,7 +65,7 @@ class RedditConnector(metaclass=ABCMeta):
self.reddit_lists = self.retrieve_reddit_lists()
def _setup_internal_objects(self):
def _setup_internal_objects(self) -> None:
self.parse_disabled_modules()
self.download_filter = self.create_download_filter()
@ -95,12 +96,12 @@ class RedditConnector(metaclass=ABCMeta):
self.args.skip_subreddit = {sub.lower() for sub in self.args.skip_subreddit}
@staticmethod
def _apply_logging_handlers(handlers: Iterable[logging.Handler]):
def _apply_logging_handlers(handlers: Iterable[logging.Handler]) -> None:
main_logger = logging.getLogger()
for handler in handlers:
main_logger.addHandler(handler)
def read_config(self):
def read_config(self) -> None:
"""Read any cfg values that need to be processed"""
if self.args.max_wait_time is None:
self.args.max_wait_time = self.cfg_parser.getint("DEFAULT", "max_wait_time", fallback=120)
@ -122,14 +123,14 @@ class RedditConnector(metaclass=ABCMeta):
with Path(self.config_location).open(mode="w") as file:
self.cfg_parser.write(file)
def parse_disabled_modules(self):
def parse_disabled_modules(self) -> None:
disabled_modules = self.args.disable_module
disabled_modules = self.split_args_input(disabled_modules)
disabled_modules = {name.strip().lower() for name in disabled_modules}
self.args.disable_module = disabled_modules
logger.debug(f"Disabling the following modules: {', '.join(self.args.disable_module)}")
def create_reddit_instance(self):
def create_reddit_instance(self) -> None:
if self.args.authenticate:
logger.debug("Using authenticated Reddit instance")
if not self.cfg_parser.has_option("DEFAULT", "user_token"):
@ -176,14 +177,14 @@ class RedditConnector(metaclass=ABCMeta):
logger.log(9, "Retrieved submissions for given links")
return master_list
def determine_directories(self):
def determine_directories(self) -> None:
self.download_directory = Path(self.args.directory).resolve().expanduser()
self.config_directory = Path(self.config_directories.user_config_dir)
self.download_directory.mkdir(exist_ok=True, parents=True)
self.config_directory.mkdir(exist_ok=True, parents=True)
def load_config(self):
def load_config(self) -> None:
self.cfg_parser = configparser.ConfigParser()
if self.args.config:
if (cfg_path := Path(self.args.config)).exists():
@ -349,7 +350,9 @@ class RedditConnector(metaclass=ABCMeta):
else:
return []
def create_filtered_listing_generator(self, reddit_source) -> Iterator:
def create_filtered_listing_generator(
self, reddit_source: Union[praw.models.Subreddit, praw.models.Multireddit, praw.models.Redditor.submissions]
) -> Iterator:
sort_function = self.determine_sort_function()
if self.sort_filter in (RedditTypes.SortType.TOP, RedditTypes.SortType.CONTROVERSIAL):
return sort_function(reddit_source, limit=self.args.limit, time_filter=self.time_filter.value)
@ -396,7 +399,7 @@ class RedditConnector(metaclass=ABCMeta):
else:
return []
def check_user_existence(self, name: str):
def check_user_existence(self, name: str) -> None:
user = self.reddit_instance.redditor(name=name)
try:
if user.id:
@ -431,15 +434,16 @@ class RedditConnector(metaclass=ABCMeta):
return SiteAuthenticator(self.cfg_parser)
@abstractmethod
def download(self):
def download(self) -> None:
pass
@staticmethod
def check_subreddit_status(subreddit: praw.models.Subreddit):
def check_subreddit_status(subreddit: praw.models.Subreddit) -> None:
if subreddit.display_name in ("all", "friends"):
return
try:
assert subreddit.id
if subreddit.id:
return
except prawcore.NotFound:
raise errors.BulkDownloaderException(f"Source {subreddit.display_name} cannot be found")
except prawcore.Redirect:

View file

@ -9,7 +9,7 @@ logger = logging.getLogger(__name__)
class DownloadFilter:
def __init__(self, excluded_extensions: list[str] = None, excluded_domains: list[str] = None):
def __init__(self, excluded_extensions: list[str] = None, excluded_domains: list[str] = None) -> None:
self.excluded_extensions = excluded_extensions
self.excluded_domains = excluded_domains

View file

@ -23,7 +23,7 @@ from bdfr.site_downloaders.download_factory import DownloadFactory
logger = logging.getLogger(__name__)
def _calc_hash(existing_file: Path):
def _calc_hash(existing_file: Path) -> tuple[Path, str]:
chunk_size = 1024 * 1024
md5_hash = hashlib.md5(usedforsecurity=False)
with existing_file.open("rb") as file:
@ -36,12 +36,12 @@ def _calc_hash(existing_file: Path):
class RedditDownloader(RedditConnector):
def __init__(self, args: Configuration, logging_handlers: Iterable[logging.Handler] = ()):
def __init__(self, args: Configuration, logging_handlers: Iterable[logging.Handler] = ()) -> None:
super().__init__(args, logging_handlers)
if self.args.search_existing:
self.master_hash_list = self.scan_existing_files(self.download_directory)
def download(self):
def download(self) -> None:
for generator in self.reddit_lists:
try:
for submission in generator:
@ -54,7 +54,7 @@ class RedditDownloader(RedditConnector):
logger.debug("Waiting 60 seconds to continue")
sleep(60)
def _download_submission(self, submission: praw.models.Submission):
def _download_submission(self, submission: praw.models.Submission) -> None:
if submission.id in self.excluded_submission_ids:
logger.debug(f"Object {submission.id} in exclusion list, skipping")
return

View file

@ -35,7 +35,7 @@ class FileNameFormatter:
directory_format_string: str,
time_format_string: str,
restriction_scheme: Optional[str] = None,
):
) -> None:
if not self.validate_string(file_format_string):
raise BulkDownloaderException(f"{file_format_string!r} is not a valid format string")
self.file_format_string = file_format_string

View file

@ -16,14 +16,14 @@ logger = logging.getLogger(__name__)
class OAuth2Authenticator:
def __init__(self, wanted_scopes: set[str], client_id: str, client_secret: str, user_agent: str):
def __init__(self, wanted_scopes: set[str], client_id: str, client_secret: str, user_agent: str) -> None:
self._check_scopes(wanted_scopes, user_agent)
self.scopes = wanted_scopes
self.client_id = client_id
self.client_secret = client_secret
@staticmethod
def _check_scopes(wanted_scopes: set[str], user_agent: str):
def _check_scopes(wanted_scopes: set[str], user_agent: str) -> None:
try:
response = requests.get(
"https://www.reddit.com/api/v1/scopes.json",
@ -86,18 +86,18 @@ class OAuth2Authenticator:
return client
@staticmethod
def send_message(client: socket.socket, message: str = ""):
def send_message(client: socket.socket, message: str = "") -> None:
client.send(f"HTTP/1.1 200 OK\r\n\r\n{message}".encode())
client.close()
class OAuth2TokenManager(praw.reddit.BaseTokenManager):
def __init__(self, config: configparser.ConfigParser, config_location: Path):
def __init__(self, config: configparser.ConfigParser, config_location: Path) -> None:
super().__init__()
self.config = config
self.config_location = config_location
def pre_refresh_callback(self, authorizer: praw.reddit.Authorizer):
def pre_refresh_callback(self, authorizer: praw.reddit.Authorizer) -> None:
if authorizer.refresh_token is None:
if self.config.has_option("DEFAULT", "user_token"):
authorizer.refresh_token = self.config.get("DEFAULT", "user_token")
@ -105,7 +105,7 @@ class OAuth2TokenManager(praw.reddit.BaseTokenManager):
else:
raise RedditAuthenticationError("No auth token loaded in configuration")
def post_refresh_callback(self, authorizer: praw.reddit.Authorizer):
def post_refresh_callback(self, authorizer: praw.reddit.Authorizer) -> None:
self.config.set("DEFAULT", "user_token", authorizer.refresh_token)
with Path(self.config_location).open(mode="w") as file:
self.config.write(file, True)

View file

@ -18,7 +18,9 @@ logger = logging.getLogger(__name__)
class Resource:
def __init__(self, source_submission: Submission, url: str, download_function: Callable, extension: str = None):
def __init__(
self, source_submission: Submission, url: str, download_function: Callable, extension: str = None
) -> None:
self.source_submission = source_submission
self.content: Optional[bytes] = None
self.url = url
@ -32,7 +34,7 @@ class Resource:
def retry_download(url: str) -> Callable:
return lambda global_params: Resource.http_download(url, global_params)
def download(self, download_parameters: Optional[dict] = None):
def download(self, download_parameters: Optional[dict] = None) -> None:
if download_parameters is None:
download_parameters = {}
if not self.content:
@ -47,7 +49,7 @@ class Resource:
if not self.hash and self.content:
self.create_hash()
def create_hash(self):
def create_hash(self) -> None:
self.hash = hashlib.md5(self.content, usedforsecurity=False)
def _determine_extension(self) -> Optional[str]:

View file

@ -4,5 +4,5 @@ import configparser
class SiteAuthenticator:
def __init__(self, cfg: configparser.ConfigParser):
def __init__(self, cfg: configparser.ConfigParser) -> None:
self.imgur_authentication = None

View file

@ -15,7 +15,7 @@ logger = logging.getLogger(__name__)
class BaseDownloader(ABC):
def __init__(self, post: Submission, typical_extension: Optional[str] = None):
def __init__(self, post: Submission, typical_extension: Optional[str] = None) -> None:
self.post = post
self.typical_extension = typical_extension

View file

@ -13,7 +13,7 @@ logger = logging.getLogger(__name__)
class DelayForReddit(BaseDownloader):
def __init__(self, post: Submission):
def __init__(self, post: Submission) -> None:
super().__init__(post)
def find_resources(self, authenticator: Optional[SiteAuthenticator] = None) -> list[Resource]:

View file

@ -10,7 +10,7 @@ from bdfr.site_downloaders.base_downloader import BaseDownloader
class Direct(BaseDownloader):
def __init__(self, post: Submission):
def __init__(self, post: Submission) -> None:
super().__init__(post)
def find_resources(self, authenticator: Optional[SiteAuthenticator] = None) -> list[Resource]:

View file

@ -17,7 +17,7 @@ logger = logging.getLogger(__name__)
class Erome(BaseDownloader):
def __init__(self, post: Submission):
def __init__(self, post: Submission) -> None:
super().__init__(post)
def find_resources(self, authenticator: Optional[SiteAuthenticator] = None) -> list[Resource]:

View file

@ -15,7 +15,7 @@ logger = logging.getLogger(__name__)
class YtdlpFallback(BaseFallbackDownloader, Youtube):
def __init__(self, post: Submission):
def __init__(self, post: Submission) -> None:
super().__init__(post)
def find_resources(self, authenticator: Optional[SiteAuthenticator] = None) -> list[Resource]:

View file

@ -14,7 +14,7 @@ logger = logging.getLogger(__name__)
class Gallery(BaseDownloader):
def __init__(self, post: Submission):
def __init__(self, post: Submission) -> None:
super().__init__(post)
def find_resources(self, authenticator: Optional[SiteAuthenticator] = None) -> list[Resource]:

View file

@ -14,7 +14,7 @@ from bdfr.site_downloaders.redgifs import Redgifs
class Gfycat(Redgifs):
def __init__(self, post: Submission):
def __init__(self, post: Submission) -> None:
super().__init__(post)
def find_resources(self, authenticator: Optional[SiteAuthenticator] = None) -> list[Resource]:

View file

@ -13,7 +13,7 @@ from bdfr.site_downloaders.base_downloader import BaseDownloader
class Imgur(BaseDownloader):
def __init__(self, post: Submission):
def __init__(self, post: Submission) -> None:
super().__init__(post)
self.raw_data = {}

View file

@ -14,7 +14,7 @@ logger = logging.getLogger(__name__)
class PornHub(Youtube):
def __init__(self, post: Submission):
def __init__(self, post: Submission) -> None:
super().__init__(post)
def find_resources(self, authenticator: Optional[SiteAuthenticator] = None) -> list[Resource]:

View file

@ -14,7 +14,7 @@ from bdfr.site_downloaders.base_downloader import BaseDownloader
class Redgifs(BaseDownloader):
def __init__(self, post: Submission):
def __init__(self, post: Submission) -> None:
super().__init__(post)
def find_resources(self, authenticator: Optional[SiteAuthenticator] = None) -> list[Resource]:

View file

@ -13,7 +13,7 @@ logger = logging.getLogger(__name__)
class SelfPost(BaseDownloader):
def __init__(self, post: Submission):
def __init__(self, post: Submission) -> None:
super().__init__(post)
def find_resources(self, authenticator: Optional[SiteAuthenticator] = None) -> list[Resource]:

View file

@ -17,7 +17,7 @@ logger = logging.getLogger(__name__)
class Vidble(BaseDownloader):
def __init__(self, post: Submission):
def __init__(self, post: Submission) -> None:
super().__init__(post)
def find_resources(self, authenticator: Optional[SiteAuthenticator] = None) -> list[Resource]:

View file

@ -14,7 +14,7 @@ logger = logging.getLogger(__name__)
class VReddit(Youtube):
def __init__(self, post: Submission):
def __init__(self, post: Submission) -> None:
super().__init__(post)
def find_resources(self, authenticator: Optional[SiteAuthenticator] = None) -> list[Resource]:

View file

@ -18,7 +18,7 @@ logger = logging.getLogger(__name__)
class Youtube(BaseDownloader):
def __init__(self, post: Submission):
def __init__(self, post: Submission) -> None:
super().__init__(post)
def find_resources(self, authenticator: Optional[SiteAuthenticator] = None) -> list[Resource]: