1
0
Fork 0
mirror of synced 2024-06-27 02:20:45 +12:00
bulk-downloader-for-reddit/bulkredditdownloader/downloader.py
2021-04-18 16:43:55 +03:00

226 lines
9.8 KiB
Python

#!/usr/bin/env python3
# coding=utf-8
import argparse
import configparser
import logging
import socket
from datetime import datetime
from enum import Enum, auto
from pathlib import Path
import appdirs
import praw
import praw.models
from bulkredditdownloader.authenticator import Authenticator
from bulkredditdownloader.download_filter import DownloadFilter
from bulkredditdownloader.errors import NotADownloadableLinkError, RedditAuthenticationError
from bulkredditdownloader.file_name_formatter import FileNameFormatter
from bulkredditdownloader.site_downloaders.download_factory import DownloadFactory
logger = logging.getLogger(__name__)
class RedditTypes:
    """Namespace grouping the enumerations used to describe Reddit listing options."""

    class SortType(Enum):
        """Orderings that can be requested for a listing or a search."""

        HOT = auto()
        RISING = auto()
        CONTROVERSIAL = auto()
        NEW = auto()
        RELEVANCE = auto()
        # Backward-compatible alias for the historical misspelling. It resolves to the
        # RELEVANCE member, so ``.name.lower()`` yields the correctly spelled 'relevance'.
        RELEVENCE = RELEVANCE

    class TimeType(Enum):
        """Time windows that can be used to restrict a listing or a search."""

        HOUR = auto()
        DAY = auto()
        WEEK = auto()
        MONTH = auto()
        YEAR = auto()
        ALL = auto()
class RedditDownloader:
    """Gather Reddit submissions chosen through command-line arguments and download their resources.

    Constructing an instance builds the filters, the file name formatter and the authenticator,
    creates the download and log directories, attaches a file log handler, loads the
    configuration, creates the PRAW client and collects the requested listings. Call
    :meth:`download` afterwards to fetch the content.
    """

    def __init__(self, args: argparse.Namespace):
        """Set up every internal object and collect the requested listings.

        Args:
            args: Parsed command-line options. The attributes read by this class are
                ``directory``, ``use_local_config``, ``subreddit``, ``search``, ``limit``,
                ``link``, ``multireddit``, ``user``, ``upvoted``, ``submitted``, ``saved``,
                ``set_filename``, ``set_folderpath``, ``sort``, ``time``, ``skip``,
                ``skip_domain``, ``no_download`` and ``no_dupes``.
        """
        self.args = args
        self.config_directories = appdirs.AppDirs('bulk_reddit_downloader')
        self.run_time = datetime.now().isoformat()
        self._setup_internal_objects()
        self.reddit_lists = self._retrieve_reddit_lists()

    def _setup_internal_objects(self):
        """Create the filters, formatter, authenticator, directories, logger and PRAW client."""
        self.download_filter = self._create_download_filter()
        self.time_filter = self._create_time_filter()
        self.sort_filter = self._create_sort_filter()
        self.file_name_formatter = self._create_file_name_formatter()
        self.authenticator = self._create_authenticator()
        self._determine_directories()
        self._create_file_logger()
        self.master_hash_list = []
        self._load_config()

        # TODO: implement OAuth2 authentication
        # Until then the presence of a token only flips the ``authenticated`` flag; the client
        # itself is created from the same credentials in both cases.
        self.authenticated = self.cfg_parser.has_option('DEFAULT', 'reddit_token')
        self.reddit_instance = praw.Reddit(
            client_id=self.cfg_parser.get('DEFAULT', 'client_id'),
            client_secret=self.cfg_parser.get('DEFAULT', 'client_secret'),
            user_agent=socket.gethostname(),
        )

    def _retrieve_reddit_lists(self) -> list[praw.models.ListingGenerator]:
        """Return every iterable of submissions requested through the arguments."""
        master_list = []
        master_list.extend(self._get_subreddits())
        master_list.extend(self._get_multireddits())
        master_list.extend(self._get_user_data())
        master_list.extend(self._get_submissions_from_link())
        return master_list

    def _determine_directories(self):
        """Resolve the download, log and config locations and create the first two on disk."""
        self.download_directory = Path(self.args.directory)
        self.logfile_directory = self.download_directory / 'LOG_FILES'
        self.config_directory = self.config_directories.user_config_dir

        self.download_directory.mkdir(exist_ok=True, parents=True)
        self.logfile_directory.mkdir(exist_ok=True, parents=True)

    def _load_config(self):
        """Read ``./config.cfg`` when requested and present, otherwise ``./default_config.cfg``."""
        self.cfg_parser = configparser.ConfigParser()
        local_config = Path('./config.cfg')
        if self.args.use_local_config and local_config.exists():
            self.cfg_parser.read(local_config)
        else:
            self.cfg_parser.read(Path('./default_config.cfg').resolve())

    def _create_file_logger(self):
        """Attach a handler writing ``log_output.txt`` in the log directory to the root logger."""
        main_logger = logging.getLogger()
        file_handler = logging.FileHandler(self.logfile_directory / 'log_output.txt')
        formatter = logging.Formatter('[%(asctime)s - %(name)s - %(levelname)s] - %(message)s')
        file_handler.setFormatter(formatter)
        # NOTSET (0): the handler itself does not filter any record by level.
        file_handler.setLevel(logging.NOTSET)
        main_logger.addHandler(file_handler)

    def _get_subreddits(self) -> list[praw.models.ListingGenerator]:
        """Return one listing per requested subreddit, either search results or a sorted listing."""
        if not self.args.subreddit:
            return []
        subreddits = [self.reddit_instance.subreddit(chosen_subreddit) for chosen_subreddit in self.args.subreddit]
        if self.args.search:
            # Honour the requested limit and time window for searches as well; previously
            # both options were silently ignored on this path.
            return [
                reddit.search(
                    self.args.search,
                    sort=self.sort_filter.name.lower(),
                    time_filter=self.time_filter.name.lower(),
                    limit=self.args.limit,
                )
                for reddit in subreddits
            ]
        sort_function = self._determine_sort_function()
        return [sort_function(reddit, limit=self.args.limit) for reddit in subreddits]

    def _get_submissions_from_link(self) -> list[list[praw.models.Submission]]:
        """Return a single-element list holding the submissions built from the supplied URLs."""
        # ``args.link`` may be unset; treat that as "no links" like the sibling getters do.
        links = self.args.link or []
        return [[self.reddit_instance.submission(url=url) for url in links]]

    def _determine_sort_function(self):
        """Return the unbound PRAW listing method matching the sort filter (``hot`` by default)."""
        sort_functions = {
            RedditTypes.SortType.NEW: praw.models.Subreddit.new,
            RedditTypes.SortType.RISING: praw.models.Subreddit.rising,
            RedditTypes.SortType.CONTROVERSIAL: praw.models.Subreddit.controversial,
        }
        return sort_functions.get(self.sort_filter, praw.models.Subreddit.hot)

    def _get_multireddits(self) -> list[praw.models.ListingGenerator]:
        """Return one sorted listing per requested multireddit of ``args.user``.

        Raises:
            RedditAuthenticationError: If multireddits are requested without authentication.
        """
        if not self.args.multireddit:
            return []
        if not self.authenticated:
            raise RedditAuthenticationError('Accessing multireddits requires authentication')
        # A multireddit is addressed by its owner and its name, and it has to be turned into a
        # listing before it can be iterated by download().
        sort_function = self._determine_sort_function()
        return [
            sort_function(
                self.reddit_instance.multireddit(redditor=self.args.user, name=m_reddit_choice),
                limit=self.args.limit,
            )
            for m_reddit_choice in self.args.multireddit
        ]

    def _get_user_data(self) -> list[praw.models.ListingGenerator]:
        """Return the upvoted, submitted and/or saved listings of ``args.user``.

        Raises:
            RedditAuthenticationError: If any user list is requested without authentication.
        """
        if not any((self.args.upvoted, self.args.submitted, self.args.saved)):
            return []
        if not self.authenticated:
            raise RedditAuthenticationError('Accessing user lists requires authentication')

        generators = []
        sort_function = self._determine_sort_function()
        redditor = self.reddit_instance.redditor(self.args.user)
        if self.args.upvoted:
            # Call the method: appending the bound method itself is not iterable.
            generators.append(redditor.upvoted(limit=self.args.limit))
        if self.args.submitted:
            generators.append(sort_function(redditor.submissions, limit=self.args.limit))
        if self.args.saved:
            generators.append(redditor.saved(limit=self.args.limit))
        return generators

    def _create_file_name_formatter(self) -> FileNameFormatter:
        """Build the formatter from the file name and folder path schemes given in the arguments."""
        return FileNameFormatter(self.args.set_filename, self.args.set_folderpath)

    def _create_time_filter(self) -> RedditTypes.TimeType:
        """Translate ``args.time`` into a ``TimeType``, falling back to ``ALL``."""
        try:
            return RedditTypes.TimeType[self.args.time.upper()]
        except (KeyError, AttributeError):
            # Unknown name, or the option is missing / not a string.
            return RedditTypes.TimeType.ALL

    def _create_sort_filter(self) -> RedditTypes.SortType:
        """Translate ``args.sort`` into a ``SortType``, falling back to ``HOT``."""
        try:
            return RedditTypes.SortType[self.args.sort.upper()]
        except (KeyError, AttributeError):
            # Unknown name, or the option is missing / not a string.
            return RedditTypes.SortType.HOT

    def _create_download_filter(self) -> DownloadFilter:
        """Build the filter from the skipped content types and skipped domains."""
        formats = {
            "videos": [".mp4", ".webm"],
            "images": [".jpg", ".jpeg", ".png", ".bmp"],
            "gifs": [".gif"],
            "self": []
        }
        # ``args.skip`` may be unset; unknown type names contribute no extensions.
        skipped_types = self.args.skip or []
        excluded_extensions = [extension for ext_type in skipped_types for extension in formats.get(ext_type, ())]
        return DownloadFilter(excluded_extensions, self.args.skip_domain)

    def _create_authenticator(self) -> Authenticator:
        """Create the authenticator handed to the site downloaders.

        NOTE(review): this is still a placeholder, and because _setup_internal_objects() calls
        it, constructing this class raises unless the method is overridden or patched.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError

    def download(self):
        """Download every submission of every collected listing."""
        for generator in self.reddit_lists:
            for submission in generator:
                self._download_submission(submission)

    def _download_submission(self, submission: praw.models.Submission):
        """Download all resources of one submission, unless it is filtered out or downloads are disabled.

        A ``NotADownloadableLinkError`` raised while handling the submission is logged
        and not propagated.
        """
        if not self.download_filter.check_url(submission.url):
            return
        logger.debug('Attempting to download submission %s', submission.id)
        try:
            downloader_class = DownloadFactory.pull_lever(submission.url)
            downloader = downloader_class(submission)
            if self.args.no_download:
                logger.info('Skipping download for submission %s', submission.id)
                return
            for res in downloader.find_resources(self.authenticator):
                self._save_resource(res)
            logger.info('Downloaded submission %s', submission.name)
        except NotADownloadableLinkError as e:
            logger.error('Could not download submission %s: %s', submission.name, e)

    def _save_resource(self, res):
        """Write one resource to disk unless the file exists or it is an unwanted duplicate."""
        destination = self.file_name_formatter.format_path(res, self.download_directory)
        if destination.exists():
            logger.debug('File already exists: %s', destination)
            return

        resource_hash = res.hash.hexdigest()
        already_downloaded = resource_hash in self.master_hash_list
        # Duplicates are skipped only when the user asked for that with ``no_dupes``.
        if already_downloaded and self.args.no_dupes:
            logger.debug('Resource from %s downloaded elsewhere', res.url)
            return

        # TODO: consider making a hard link/symlink here
        destination.parent.mkdir(parents=True, exist_ok=True)
        with open(destination, 'wb') as file:
            file.write(res.content)
        logger.debug('Written file to %s', destination)
        if not already_downloaded:
            self.master_hash_list.append(resource_hash)
            logger.debug('Hash added to master list: %s', resource_hash)