1
0
Fork 0
mirror of synced 2024-10-01 17:47:46 +13:00

Merge pull request #818 from OMEGARAZER/Annotations

This commit is contained in:
Serene 2023-05-29 13:04:10 +10:00 committed by GitHub
commit dd6f12345e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
29 changed files with 77 additions and 71 deletions

View file

@ -69,8 +69,8 @@ _archiver_options = [
] ]
def _add_options(opts: list): def _add_options(opts: list): # noqa: ANN202
def wrap(func): def wrap(func): # noqa: ANN001,ANN202
for opt in opts: for opt in opts:
func = opt(func) func = opt(func)
return func return func
@ -78,7 +78,7 @@ def _add_options(opts: list):
return wrap return wrap
def _check_version(context, _param, value): def _check_version(context: click.core.Context, _param, value: bool) -> None:
if not value or context.resilient_parsing: if not value or context.resilient_parsing:
return return
current = __version__ current = __version__
@ -101,7 +101,7 @@ def _check_version(context, _param, value):
callback=_check_version, callback=_check_version,
help="Check version and exit.", help="Check version and exit.",
) )
def cli(): def cli() -> None:
"""BDFR is used to download and archive content from Reddit.""" """BDFR is used to download and archive content from Reddit."""
pass pass
@ -111,7 +111,7 @@ def cli():
@_add_options(_downloader_options) @_add_options(_downloader_options)
@click.help_option("-h", "--help") @click.help_option("-h", "--help")
@click.pass_context @click.pass_context
def cli_download(context: click.Context, **_): def cli_download(context: click.Context, **_) -> None:
"""Used to download content posted to Reddit.""" """Used to download content posted to Reddit."""
config = Configuration() config = Configuration()
config.process_click_arguments(context) config.process_click_arguments(context)
@ -132,7 +132,7 @@ def cli_download(context: click.Context, **_):
@_add_options(_archiver_options) @_add_options(_archiver_options)
@click.help_option("-h", "--help") @click.help_option("-h", "--help")
@click.pass_context @click.pass_context
def cli_archive(context: click.Context, **_): def cli_archive(context: click.Context, **_) -> None:
"""Used to archive post data from Reddit.""" """Used to archive post data from Reddit."""
config = Configuration() config = Configuration()
config.process_click_arguments(context) config.process_click_arguments(context)
@ -154,7 +154,7 @@ def cli_archive(context: click.Context, **_):
@_add_options(_downloader_options) @_add_options(_downloader_options)
@click.help_option("-h", "--help") @click.help_option("-h", "--help")
@click.pass_context @click.pass_context
def cli_clone(context: click.Context, **_): def cli_clone(context: click.Context, **_) -> None:
"""Combines archive and download commands.""" """Combines archive and download commands."""
config = Configuration() config = Configuration()
config.process_click_arguments(context) config.process_click_arguments(context)
@ -174,7 +174,7 @@ def cli_clone(context: click.Context, **_):
@click.argument("shell", type=click.Choice(("all", "bash", "fish", "zsh"), case_sensitive=False), default="all") @click.argument("shell", type=click.Choice(("all", "bash", "fish", "zsh"), case_sensitive=False), default="all")
@click.help_option("-h", "--help") @click.help_option("-h", "--help")
@click.option("-u", "--uninstall", is_flag=True, default=False, help="Uninstall completion") @click.option("-u", "--uninstall", is_flag=True, default=False, help="Uninstall completion")
def cli_completion(shell: str, uninstall: bool): def cli_completion(shell: str, uninstall: bool) -> None:
"""\b """\b
Installs shell completions for BDFR. Installs shell completions for BDFR.
Options: all, bash, fish, zsh Options: all, bash, fish, zsh
@ -216,7 +216,7 @@ def make_console_logging_handler(verbosity: int) -> logging.StreamHandler:
return stream return stream
def silence_module_loggers(): def silence_module_loggers() -> None:
logging.getLogger("praw").setLevel(logging.CRITICAL) logging.getLogger("praw").setLevel(logging.CRITICAL)
logging.getLogger("prawcore").setLevel(logging.CRITICAL) logging.getLogger("prawcore").setLevel(logging.CRITICAL)
logging.getLogger("urllib3").setLevel(logging.CRITICAL) logging.getLogger("urllib3").setLevel(logging.CRITICAL)

View file

@ -7,7 +7,7 @@ from praw.models import Comment, Submission
class BaseArchiveEntry(ABC): class BaseArchiveEntry(ABC):
def __init__(self, source: Union[Comment, Submission]): def __init__(self, source: Union[Comment, Submission]) -> None:
self.source = source self.source = source
self.post_details: dict = {} self.post_details: dict = {}

View file

@ -10,7 +10,7 @@ logger = logging.getLogger(__name__)
class CommentArchiveEntry(BaseArchiveEntry): class CommentArchiveEntry(BaseArchiveEntry):
def __init__(self, comment: praw.models.Comment): def __init__(self, comment: praw.models.Comment) -> None:
super().__init__(comment) super().__init__(comment)
def compile(self) -> dict: def compile(self) -> dict:

View file

@ -10,7 +10,7 @@ logger = logging.getLogger(__name__)
class SubmissionArchiveEntry(BaseArchiveEntry): class SubmissionArchiveEntry(BaseArchiveEntry):
def __init__(self, submission: praw.models.Submission): def __init__(self, submission: praw.models.Submission) -> None:
super().__init__(submission) super().__init__(submission)
def compile(self) -> dict: def compile(self) -> dict:
@ -20,7 +20,7 @@ class SubmissionArchiveEntry(BaseArchiveEntry):
out["comments"] = comments out["comments"] = comments
return out return out
def _get_post_details(self): def _get_post_details(self) -> None:
self.post_details = { self.post_details = {
"title": self.source.title, "title": self.source.title,
"name": self.source.name, "name": self.source.name,

View file

@ -25,10 +25,10 @@ logger = logging.getLogger(__name__)
class Archiver(RedditConnector): class Archiver(RedditConnector):
def __init__(self, args: Configuration, logging_handlers: Iterable[logging.Handler] = ()): def __init__(self, args: Configuration, logging_handlers: Iterable[logging.Handler] = ()) -> None:
super().__init__(args, logging_handlers) super().__init__(args, logging_handlers)
def download(self): def download(self) -> None:
for generator in self.reddit_lists: for generator in self.reddit_lists:
try: try:
for submission in generator: for submission in generator:
@ -82,7 +82,7 @@ class Archiver(RedditConnector):
else: else:
raise ArchiverError(f"Factory failed to classify item of type {type(praw_item).__name__}") raise ArchiverError(f"Factory failed to classify item of type {type(praw_item).__name__}")
def write_entry(self, praw_item: Union[praw.models.Submission, praw.models.Comment]): def write_entry(self, praw_item: Union[praw.models.Submission, praw.models.Comment]) -> None:
if self.args.comment_context and isinstance(praw_item, praw.models.Comment): if self.args.comment_context and isinstance(praw_item, praw.models.Comment):
logger.debug(f"Converting comment {praw_item.id} to submission {praw_item.submission.id}") logger.debug(f"Converting comment {praw_item.id} to submission {praw_item.submission.id}")
praw_item = praw_item.submission praw_item = praw_item.submission
@ -97,22 +97,22 @@ class Archiver(RedditConnector):
raise ArchiverError(f"Unknown format {self.args.format!r} given") raise ArchiverError(f"Unknown format {self.args.format!r} given")
logger.info(f"Record for entry item {praw_item.id} written to disk") logger.info(f"Record for entry item {praw_item.id} written to disk")
def _write_entry_json(self, entry: BaseArchiveEntry): def _write_entry_json(self, entry: BaseArchiveEntry) -> None:
resource = Resource(entry.source, "", lambda: None, ".json") resource = Resource(entry.source, "", lambda: None, ".json")
content = json.dumps(entry.compile()) content = json.dumps(entry.compile())
self._write_content_to_disk(resource, content) self._write_content_to_disk(resource, content)
def _write_entry_xml(self, entry: BaseArchiveEntry): def _write_entry_xml(self, entry: BaseArchiveEntry) -> None:
resource = Resource(entry.source, "", lambda: None, ".xml") resource = Resource(entry.source, "", lambda: None, ".xml")
content = dict2xml.dict2xml(entry.compile(), wrap="root") content = dict2xml.dict2xml(entry.compile(), wrap="root")
self._write_content_to_disk(resource, content) self._write_content_to_disk(resource, content)
def _write_entry_yaml(self, entry: BaseArchiveEntry): def _write_entry_yaml(self, entry: BaseArchiveEntry) -> None:
resource = Resource(entry.source, "", lambda: None, ".yaml") resource = Resource(entry.source, "", lambda: None, ".yaml")
content = yaml.safe_dump(entry.compile()) content = yaml.safe_dump(entry.compile())
self._write_content_to_disk(resource, content) self._write_content_to_disk(resource, content)
def _write_content_to_disk(self, resource: Resource, content: str): def _write_content_to_disk(self, resource: Resource, content: str) -> None:
file_path = self.file_name_formatter.format_path(resource, self.download_directory) file_path = self.file_name_formatter.format_path(resource, self.download_directory)
file_path.parent.mkdir(exist_ok=True, parents=True) file_path.parent.mkdir(exist_ok=True, parents=True)
with Path(file_path).open(mode="w", encoding="utf-8") as file: with Path(file_path).open(mode="w", encoding="utf-8") as file:

View file

@ -14,10 +14,10 @@ logger = logging.getLogger(__name__)
class RedditCloner(RedditDownloader, Archiver): class RedditCloner(RedditDownloader, Archiver):
def __init__(self, args: Configuration, logging_handlers: Iterable[logging.Handler] = ()): def __init__(self, args: Configuration, logging_handlers: Iterable[logging.Handler] = ()) -> None:
super().__init__(args, logging_handlers) super().__init__(args, logging_handlers)
def download(self): def download(self) -> None:
for generator in self.reddit_lists: for generator in self.reddit_lists:
try: try:
for submission in generator: for submission in generator:

View file

@ -8,13 +8,13 @@ import appdirs
class Completion: class Completion:
def __init__(self, shell: str): def __init__(self, shell: str) -> None:
self.shell = shell self.shell = shell
self.env = environ.copy() self.env = environ.copy()
self.share_dir = appdirs.user_data_dir() self.share_dir = appdirs.user_data_dir()
self.entry_points = ["bdfr", "bdfr-archive", "bdfr-clone", "bdfr-download"] self.entry_points = ["bdfr", "bdfr-archive", "bdfr-clone", "bdfr-download"]
def install(self): def install(self) -> None:
if self.shell in ("all", "bash"): if self.shell in ("all", "bash"):
comp_dir = self.share_dir + "/bash-completion/completions/" comp_dir = self.share_dir + "/bash-completion/completions/"
if not Path(comp_dir).exists(): if not Path(comp_dir).exists():
@ -46,7 +46,7 @@ class Completion:
file.write(subprocess.run([point], env=self.env, capture_output=True, text=True).stdout) file.write(subprocess.run([point], env=self.env, capture_output=True, text=True).stdout)
print(f"Zsh completion for {point} written to {comp_dir}_{point}") print(f"Zsh completion for {point} written to {comp_dir}_{point}")
def uninstall(self): def uninstall(self) -> None:
if self.shell in ("all", "bash"): if self.shell in ("all", "bash"):
comp_dir = self.share_dir + "/bash-completion/completions/" comp_dir = self.share_dir + "/bash-completion/completions/"
for point in self.entry_points: for point in self.entry_points:

View file

@ -12,7 +12,7 @@ logger = logging.getLogger(__name__)
class Configuration(Namespace): class Configuration(Namespace):
def __init__(self): def __init__(self) -> None:
super().__init__() super().__init__()
self.authenticate = False self.authenticate = False
self.config = None self.config = None
@ -59,7 +59,7 @@ class Configuration(Namespace):
self.format = "json" self.format = "json"
self.comment_context: bool = False self.comment_context: bool = False
def process_click_arguments(self, context: click.Context): def process_click_arguments(self, context: click.Context) -> None:
if context.params.get("opts") is not None: if context.params.get("opts") is not None:
self.parse_yaml_options(context.params["opts"]) self.parse_yaml_options(context.params["opts"])
for arg_key in context.params.keys(): for arg_key in context.params.keys():
@ -72,7 +72,7 @@ class Configuration(Namespace):
continue continue
setattr(self, arg_key, val) setattr(self, arg_key, val)
def parse_yaml_options(self, file_path: str): def parse_yaml_options(self, file_path: str) -> None:
yaml_file_loc = Path(file_path) yaml_file_loc = Path(file_path)
if not yaml_file_loc.exists(): if not yaml_file_loc.exists():
logger.error(f"No YAML file found at {yaml_file_loc}") logger.error(f"No YAML file found at {yaml_file_loc}")

View file

@ -14,6 +14,7 @@ from datetime import datetime
from enum import Enum, auto from enum import Enum, auto
from pathlib import Path from pathlib import Path
from time import sleep from time import sleep
from typing import Union
import appdirs import appdirs
import praw import praw
@ -51,7 +52,7 @@ class RedditTypes:
class RedditConnector(metaclass=ABCMeta): class RedditConnector(metaclass=ABCMeta):
def __init__(self, args: Configuration, logging_handlers: Iterable[logging.Handler] = ()): def __init__(self, args: Configuration, logging_handlers: Iterable[logging.Handler] = ()) -> None:
self.args = args self.args = args
self.config_directories = appdirs.AppDirs("bdfr", "BDFR") self.config_directories = appdirs.AppDirs("bdfr", "BDFR")
self.determine_directories() self.determine_directories()
@ -64,7 +65,7 @@ class RedditConnector(metaclass=ABCMeta):
self.reddit_lists = self.retrieve_reddit_lists() self.reddit_lists = self.retrieve_reddit_lists()
def _setup_internal_objects(self): def _setup_internal_objects(self) -> None:
self.parse_disabled_modules() self.parse_disabled_modules()
self.download_filter = self.create_download_filter() self.download_filter = self.create_download_filter()
@ -95,12 +96,12 @@ class RedditConnector(metaclass=ABCMeta):
self.args.skip_subreddit = {sub.lower() for sub in self.args.skip_subreddit} self.args.skip_subreddit = {sub.lower() for sub in self.args.skip_subreddit}
@staticmethod @staticmethod
def _apply_logging_handlers(handlers: Iterable[logging.Handler]): def _apply_logging_handlers(handlers: Iterable[logging.Handler]) -> None:
main_logger = logging.getLogger() main_logger = logging.getLogger()
for handler in handlers: for handler in handlers:
main_logger.addHandler(handler) main_logger.addHandler(handler)
def read_config(self): def read_config(self) -> None:
"""Read any cfg values that need to be processed""" """Read any cfg values that need to be processed"""
if self.args.max_wait_time is None: if self.args.max_wait_time is None:
self.args.max_wait_time = self.cfg_parser.getint("DEFAULT", "max_wait_time", fallback=120) self.args.max_wait_time = self.cfg_parser.getint("DEFAULT", "max_wait_time", fallback=120)
@ -122,14 +123,14 @@ class RedditConnector(metaclass=ABCMeta):
with Path(self.config_location).open(mode="w") as file: with Path(self.config_location).open(mode="w") as file:
self.cfg_parser.write(file) self.cfg_parser.write(file)
def parse_disabled_modules(self): def parse_disabled_modules(self) -> None:
disabled_modules = self.args.disable_module disabled_modules = self.args.disable_module
disabled_modules = self.split_args_input(disabled_modules) disabled_modules = self.split_args_input(disabled_modules)
disabled_modules = {name.strip().lower() for name in disabled_modules} disabled_modules = {name.strip().lower() for name in disabled_modules}
self.args.disable_module = disabled_modules self.args.disable_module = disabled_modules
logger.debug(f"Disabling the following modules: {', '.join(self.args.disable_module)}") logger.debug(f"Disabling the following modules: {', '.join(self.args.disable_module)}")
def create_reddit_instance(self): def create_reddit_instance(self) -> None:
if self.args.authenticate: if self.args.authenticate:
logger.debug("Using authenticated Reddit instance") logger.debug("Using authenticated Reddit instance")
if not self.cfg_parser.has_option("DEFAULT", "user_token"): if not self.cfg_parser.has_option("DEFAULT", "user_token"):
@ -176,14 +177,14 @@ class RedditConnector(metaclass=ABCMeta):
logger.log(9, "Retrieved submissions for given links") logger.log(9, "Retrieved submissions for given links")
return master_list return master_list
def determine_directories(self): def determine_directories(self) -> None:
self.download_directory = Path(self.args.directory).resolve().expanduser() self.download_directory = Path(self.args.directory).resolve().expanduser()
self.config_directory = Path(self.config_directories.user_config_dir) self.config_directory = Path(self.config_directories.user_config_dir)
self.download_directory.mkdir(exist_ok=True, parents=True) self.download_directory.mkdir(exist_ok=True, parents=True)
self.config_directory.mkdir(exist_ok=True, parents=True) self.config_directory.mkdir(exist_ok=True, parents=True)
def load_config(self): def load_config(self) -> None:
self.cfg_parser = configparser.ConfigParser() self.cfg_parser = configparser.ConfigParser()
if self.args.config: if self.args.config:
if (cfg_path := Path(self.args.config)).exists(): if (cfg_path := Path(self.args.config)).exists():
@ -349,7 +350,9 @@ class RedditConnector(metaclass=ABCMeta):
else: else:
return [] return []
def create_filtered_listing_generator(self, reddit_source) -> Iterator: def create_filtered_listing_generator(
self, reddit_source: Union[praw.models.Subreddit, praw.models.Multireddit, praw.models.Redditor.submissions]
) -> Iterator:
sort_function = self.determine_sort_function() sort_function = self.determine_sort_function()
if self.sort_filter in (RedditTypes.SortType.TOP, RedditTypes.SortType.CONTROVERSIAL): if self.sort_filter in (RedditTypes.SortType.TOP, RedditTypes.SortType.CONTROVERSIAL):
return sort_function(reddit_source, limit=self.args.limit, time_filter=self.time_filter.value) return sort_function(reddit_source, limit=self.args.limit, time_filter=self.time_filter.value)
@ -396,7 +399,7 @@ class RedditConnector(metaclass=ABCMeta):
else: else:
return [] return []
def check_user_existence(self, name: str): def check_user_existence(self, name: str) -> None:
user = self.reddit_instance.redditor(name=name) user = self.reddit_instance.redditor(name=name)
try: try:
if user.id: if user.id:
@ -431,15 +434,16 @@ class RedditConnector(metaclass=ABCMeta):
return SiteAuthenticator(self.cfg_parser) return SiteAuthenticator(self.cfg_parser)
@abstractmethod @abstractmethod
def download(self): def download(self) -> None:
pass pass
@staticmethod @staticmethod
def check_subreddit_status(subreddit: praw.models.Subreddit): def check_subreddit_status(subreddit: praw.models.Subreddit) -> None:
if subreddit.display_name in ("all", "friends"): if subreddit.display_name in ("all", "friends"):
return return
try: try:
assert subreddit.id if subreddit.id:
return
except prawcore.NotFound: except prawcore.NotFound:
raise errors.BulkDownloaderException(f"Source {subreddit.display_name} cannot be found") raise errors.BulkDownloaderException(f"Source {subreddit.display_name} cannot be found")
except prawcore.Redirect: except prawcore.Redirect:

View file

@ -9,7 +9,7 @@ logger = logging.getLogger(__name__)
class DownloadFilter: class DownloadFilter:
def __init__(self, excluded_extensions: list[str] = None, excluded_domains: list[str] = None): def __init__(self, excluded_extensions: list[str] = None, excluded_domains: list[str] = None) -> None:
self.excluded_extensions = excluded_extensions self.excluded_extensions = excluded_extensions
self.excluded_domains = excluded_domains self.excluded_domains = excluded_domains

View file

@ -23,7 +23,7 @@ from bdfr.site_downloaders.download_factory import DownloadFactory
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _calc_hash(existing_file: Path): def _calc_hash(existing_file: Path) -> tuple[Path, str]:
chunk_size = 1024 * 1024 chunk_size = 1024 * 1024
md5_hash = hashlib.md5(usedforsecurity=False) md5_hash = hashlib.md5(usedforsecurity=False)
with existing_file.open("rb") as file: with existing_file.open("rb") as file:
@ -36,12 +36,12 @@ def _calc_hash(existing_file: Path):
class RedditDownloader(RedditConnector): class RedditDownloader(RedditConnector):
def __init__(self, args: Configuration, logging_handlers: Iterable[logging.Handler] = ()): def __init__(self, args: Configuration, logging_handlers: Iterable[logging.Handler] = ()) -> None:
super().__init__(args, logging_handlers) super().__init__(args, logging_handlers)
if self.args.search_existing: if self.args.search_existing:
self.master_hash_list = self.scan_existing_files(self.download_directory) self.master_hash_list = self.scan_existing_files(self.download_directory)
def download(self): def download(self) -> None:
for generator in self.reddit_lists: for generator in self.reddit_lists:
try: try:
for submission in generator: for submission in generator:
@ -54,7 +54,7 @@ class RedditDownloader(RedditConnector):
logger.debug("Waiting 60 seconds to continue") logger.debug("Waiting 60 seconds to continue")
sleep(60) sleep(60)
def _download_submission(self, submission: praw.models.Submission): def _download_submission(self, submission: praw.models.Submission) -> None:
if submission.id in self.excluded_submission_ids: if submission.id in self.excluded_submission_ids:
logger.debug(f"Object {submission.id} in exclusion list, skipping") logger.debug(f"Object {submission.id} in exclusion list, skipping")
return return

View file

@ -35,7 +35,7 @@ class FileNameFormatter:
directory_format_string: str, directory_format_string: str,
time_format_string: str, time_format_string: str,
restriction_scheme: Optional[str] = None, restriction_scheme: Optional[str] = None,
): ) -> None:
if not self.validate_string(file_format_string): if not self.validate_string(file_format_string):
raise BulkDownloaderException(f"{file_format_string!r} is not a valid format string") raise BulkDownloaderException(f"{file_format_string!r} is not a valid format string")
self.file_format_string = file_format_string self.file_format_string = file_format_string

View file

@ -16,14 +16,14 @@ logger = logging.getLogger(__name__)
class OAuth2Authenticator: class OAuth2Authenticator:
def __init__(self, wanted_scopes: set[str], client_id: str, client_secret: str, user_agent: str): def __init__(self, wanted_scopes: set[str], client_id: str, client_secret: str, user_agent: str) -> None:
self._check_scopes(wanted_scopes, user_agent) self._check_scopes(wanted_scopes, user_agent)
self.scopes = wanted_scopes self.scopes = wanted_scopes
self.client_id = client_id self.client_id = client_id
self.client_secret = client_secret self.client_secret = client_secret
@staticmethod @staticmethod
def _check_scopes(wanted_scopes: set[str], user_agent: str): def _check_scopes(wanted_scopes: set[str], user_agent: str) -> None:
try: try:
response = requests.get( response = requests.get(
"https://www.reddit.com/api/v1/scopes.json", "https://www.reddit.com/api/v1/scopes.json",
@ -86,18 +86,18 @@ class OAuth2Authenticator:
return client return client
@staticmethod @staticmethod
def send_message(client: socket.socket, message: str = ""): def send_message(client: socket.socket, message: str = "") -> None:
client.send(f"HTTP/1.1 200 OK\r\n\r\n{message}".encode()) client.send(f"HTTP/1.1 200 OK\r\n\r\n{message}".encode())
client.close() client.close()
class OAuth2TokenManager(praw.reddit.BaseTokenManager): class OAuth2TokenManager(praw.reddit.BaseTokenManager):
def __init__(self, config: configparser.ConfigParser, config_location: Path): def __init__(self, config: configparser.ConfigParser, config_location: Path) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
self.config_location = config_location self.config_location = config_location
def pre_refresh_callback(self, authorizer: praw.reddit.Authorizer): def pre_refresh_callback(self, authorizer: praw.reddit.Authorizer) -> None:
if authorizer.refresh_token is None: if authorizer.refresh_token is None:
if self.config.has_option("DEFAULT", "user_token"): if self.config.has_option("DEFAULT", "user_token"):
authorizer.refresh_token = self.config.get("DEFAULT", "user_token") authorizer.refresh_token = self.config.get("DEFAULT", "user_token")
@ -105,7 +105,7 @@ class OAuth2TokenManager(praw.reddit.BaseTokenManager):
else: else:
raise RedditAuthenticationError("No auth token loaded in configuration") raise RedditAuthenticationError("No auth token loaded in configuration")
def post_refresh_callback(self, authorizer: praw.reddit.Authorizer): def post_refresh_callback(self, authorizer: praw.reddit.Authorizer) -> None:
self.config.set("DEFAULT", "user_token", authorizer.refresh_token) self.config.set("DEFAULT", "user_token", authorizer.refresh_token)
with Path(self.config_location).open(mode="w") as file: with Path(self.config_location).open(mode="w") as file:
self.config.write(file, True) self.config.write(file, True)

View file

@ -18,7 +18,9 @@ logger = logging.getLogger(__name__)
class Resource: class Resource:
def __init__(self, source_submission: Submission, url: str, download_function: Callable, extension: str = None): def __init__(
self, source_submission: Submission, url: str, download_function: Callable, extension: str = None
) -> None:
self.source_submission = source_submission self.source_submission = source_submission
self.content: Optional[bytes] = None self.content: Optional[bytes] = None
self.url = url self.url = url
@ -32,7 +34,7 @@ class Resource:
def retry_download(url: str) -> Callable: def retry_download(url: str) -> Callable:
return lambda global_params: Resource.http_download(url, global_params) return lambda global_params: Resource.http_download(url, global_params)
def download(self, download_parameters: Optional[dict] = None): def download(self, download_parameters: Optional[dict] = None) -> None:
if download_parameters is None: if download_parameters is None:
download_parameters = {} download_parameters = {}
if not self.content: if not self.content:
@ -47,7 +49,7 @@ class Resource:
if not self.hash and self.content: if not self.hash and self.content:
self.create_hash() self.create_hash()
def create_hash(self): def create_hash(self) -> None:
self.hash = hashlib.md5(self.content, usedforsecurity=False) self.hash = hashlib.md5(self.content, usedforsecurity=False)
def _determine_extension(self) -> Optional[str]: def _determine_extension(self) -> Optional[str]:

View file

@ -4,5 +4,5 @@ import configparser
class SiteAuthenticator: class SiteAuthenticator:
def __init__(self, cfg: configparser.ConfigParser): def __init__(self, cfg: configparser.ConfigParser) -> None:
self.imgur_authentication = None self.imgur_authentication = None

View file

@ -15,7 +15,7 @@ logger = logging.getLogger(__name__)
class BaseDownloader(ABC): class BaseDownloader(ABC):
def __init__(self, post: Submission, typical_extension: Optional[str] = None): def __init__(self, post: Submission, typical_extension: Optional[str] = None) -> None:
self.post = post self.post = post
self.typical_extension = typical_extension self.typical_extension = typical_extension

View file

@ -13,7 +13,7 @@ logger = logging.getLogger(__name__)
class DelayForReddit(BaseDownloader): class DelayForReddit(BaseDownloader):
def __init__(self, post: Submission): def __init__(self, post: Submission) -> None:
super().__init__(post) super().__init__(post)
def find_resources(self, authenticator: Optional[SiteAuthenticator] = None) -> list[Resource]: def find_resources(self, authenticator: Optional[SiteAuthenticator] = None) -> list[Resource]:

View file

@ -10,7 +10,7 @@ from bdfr.site_downloaders.base_downloader import BaseDownloader
class Direct(BaseDownloader): class Direct(BaseDownloader):
def __init__(self, post: Submission): def __init__(self, post: Submission) -> None:
super().__init__(post) super().__init__(post)
def find_resources(self, authenticator: Optional[SiteAuthenticator] = None) -> list[Resource]: def find_resources(self, authenticator: Optional[SiteAuthenticator] = None) -> list[Resource]:

View file

@ -17,7 +17,7 @@ logger = logging.getLogger(__name__)
class Erome(BaseDownloader): class Erome(BaseDownloader):
def __init__(self, post: Submission): def __init__(self, post: Submission) -> None:
super().__init__(post) super().__init__(post)
def find_resources(self, authenticator: Optional[SiteAuthenticator] = None) -> list[Resource]: def find_resources(self, authenticator: Optional[SiteAuthenticator] = None) -> list[Resource]:

View file

@ -15,7 +15,7 @@ logger = logging.getLogger(__name__)
class YtdlpFallback(BaseFallbackDownloader, Youtube): class YtdlpFallback(BaseFallbackDownloader, Youtube):
def __init__(self, post: Submission): def __init__(self, post: Submission) -> None:
super().__init__(post) super().__init__(post)
def find_resources(self, authenticator: Optional[SiteAuthenticator] = None) -> list[Resource]: def find_resources(self, authenticator: Optional[SiteAuthenticator] = None) -> list[Resource]:

View file

@ -14,7 +14,7 @@ logger = logging.getLogger(__name__)
class Gallery(BaseDownloader): class Gallery(BaseDownloader):
def __init__(self, post: Submission): def __init__(self, post: Submission) -> None:
super().__init__(post) super().__init__(post)
def find_resources(self, authenticator: Optional[SiteAuthenticator] = None) -> list[Resource]: def find_resources(self, authenticator: Optional[SiteAuthenticator] = None) -> list[Resource]:

View file

@ -14,7 +14,7 @@ from bdfr.site_downloaders.redgifs import Redgifs
class Gfycat(Redgifs): class Gfycat(Redgifs):
def __init__(self, post: Submission): def __init__(self, post: Submission) -> None:
super().__init__(post) super().__init__(post)
def find_resources(self, authenticator: Optional[SiteAuthenticator] = None) -> list[Resource]: def find_resources(self, authenticator: Optional[SiteAuthenticator] = None) -> list[Resource]:

View file

@ -13,7 +13,7 @@ from bdfr.site_downloaders.base_downloader import BaseDownloader
class Imgur(BaseDownloader): class Imgur(BaseDownloader):
def __init__(self, post: Submission): def __init__(self, post: Submission) -> None:
super().__init__(post) super().__init__(post)
self.raw_data = {} self.raw_data = {}

View file

@ -14,7 +14,7 @@ logger = logging.getLogger(__name__)
class PornHub(Youtube): class PornHub(Youtube):
def __init__(self, post: Submission): def __init__(self, post: Submission) -> None:
super().__init__(post) super().__init__(post)
def find_resources(self, authenticator: Optional[SiteAuthenticator] = None) -> list[Resource]: def find_resources(self, authenticator: Optional[SiteAuthenticator] = None) -> list[Resource]:

View file

@ -14,7 +14,7 @@ from bdfr.site_downloaders.base_downloader import BaseDownloader
class Redgifs(BaseDownloader): class Redgifs(BaseDownloader):
def __init__(self, post: Submission): def __init__(self, post: Submission) -> None:
super().__init__(post) super().__init__(post)
def find_resources(self, authenticator: Optional[SiteAuthenticator] = None) -> list[Resource]: def find_resources(self, authenticator: Optional[SiteAuthenticator] = None) -> list[Resource]:

View file

@ -13,7 +13,7 @@ logger = logging.getLogger(__name__)
class SelfPost(BaseDownloader): class SelfPost(BaseDownloader):
def __init__(self, post: Submission): def __init__(self, post: Submission) -> None:
super().__init__(post) super().__init__(post)
def find_resources(self, authenticator: Optional[SiteAuthenticator] = None) -> list[Resource]: def find_resources(self, authenticator: Optional[SiteAuthenticator] = None) -> list[Resource]:

View file

@ -17,7 +17,7 @@ logger = logging.getLogger(__name__)
class Vidble(BaseDownloader): class Vidble(BaseDownloader):
def __init__(self, post: Submission): def __init__(self, post: Submission) -> None:
super().__init__(post) super().__init__(post)
def find_resources(self, authenticator: Optional[SiteAuthenticator] = None) -> list[Resource]: def find_resources(self, authenticator: Optional[SiteAuthenticator] = None) -> list[Resource]:

View file

@ -14,7 +14,7 @@ logger = logging.getLogger(__name__)
class VReddit(Youtube): class VReddit(Youtube):
def __init__(self, post: Submission): def __init__(self, post: Submission) -> None:
super().__init__(post) super().__init__(post)
def find_resources(self, authenticator: Optional[SiteAuthenticator] = None) -> list[Resource]: def find_resources(self, authenticator: Optional[SiteAuthenticator] = None) -> list[Resource]:

View file

@ -18,7 +18,7 @@ logger = logging.getLogger(__name__)
class Youtube(BaseDownloader): class Youtube(BaseDownloader):
def __init__(self, post: Submission): def __init__(self, post: Submission) -> None:
super().__init__(post) super().__init__(post)
def find_resources(self, authenticator: Optional[SiteAuthenticator] = None) -> list[Resource]: def find_resources(self, authenticator: Optional[SiteAuthenticator] = None) -> list[Resource]: