1
0
Fork 0
mirror of synced 2024-05-18 19:22:38 +12:00
bulk-downloader-for-reddit/bulkredditdownloader/downloader.py

294 lines
13 KiB
Python

#!/usr/bin/env python3
# coding=utf-8
import argparse
import configparser
import logging
import re
import socket
from datetime import datetime
from enum import Enum, auto
from pathlib import Path
from typing import Iterator
import appdirs
import praw
import praw.models
import prawcore
import bulkredditdownloader.exceptions as errors
from bulkredditdownloader.download_filter import DownloadFilter
from bulkredditdownloader.file_name_formatter import FileNameFormatter
from bulkredditdownloader.oauth2 import OAuth2Authenticator, OAuth2TokenManager
from bulkredditdownloader.site_authenticator import SiteAuthenticator
from bulkredditdownloader.site_downloaders.download_factory import DownloadFactory
# Module-wide logger; RedditDownloader additionally attaches a file handler to the root logger.
logger = logging.getLogger(name=__name__)
class RedditTypes:
    """Namespace for the Reddit listing enums used by the downloader."""

    class SortType(Enum):
        """Sort orders; the downloader matches members by upper-cased ``sort`` argument name."""

        # Explicit values equal what auto() assigned (1..5); they are spelled out so the
        # alias below is unambiguous.
        HOT = 1
        RISING = 2
        CONTROVERSIAL = 3
        NEW = 4
        RELEVENCE = 5
        # Alias (same value) so the correctly spelled name resolves to the same member
        # instead of silently falling back to the default sort. Aliases are not yielded
        # when iterating the enum, so the member list is unchanged.
        RELEVANCE = RELEVENCE

    class TimeType(Enum):
        """Time windows; the downloader matches members by upper-cased ``time`` argument name."""

        HOUR = auto()
        DAY = auto()
        WEEK = auto()
        MONTH = auto()
        YEAR = auto()
        ALL = auto()
class RedditDownloader:
    """Resolve Reddit listings from parsed CLI arguments and save their resources to disk.

    Construction does all setup eagerly: it creates the download and log directories,
    attaches a file log handler, loads the configuration file, builds the PRAW Reddit
    instance and resolves every requested listing into ``self.reddit_lists``.
    Call :meth:`download` to fetch and write the resources.
    """

    def __init__(self, args: argparse.Namespace):
        """Set up the downloader.

        :param args: parsed command-line options; the attributes read throughout this
            class include ``directory``, ``subreddit``, ``multireddit``, ``user``,
            ``link``, ``sort``, ``time``, ``limit``, ``search`` and ``no_dupes``.
        """
        self.args = args
        self.config_directories = appdirs.AppDirs('bulk_reddit_downloader', 'BDFR')
        self.run_time = datetime.now().isoformat()
        self._setup_internal_objects()
        self.reddit_lists = self._retrieve_reddit_lists()

    def _setup_internal_objects(self):
        """Create every helper object the downloader needs, in dependency order."""
        self._determine_directories()
        self._create_file_logger()
        self.download_filter = self._create_download_filter()
        logger.debug('Created download filter')
        self.time_filter = self._create_time_filter()
        logger.debug('Created time filter')
        self.sort_filter = self._create_sort_filter()
        logger.debug('Created sort filter')
        self.file_name_formatter = self._create_file_name_formatter()
        logger.debug('Create file name formatter')
        self._load_config()
        logger.debug(f'Configuration loaded from {self.config_location}')
        self.master_hash_list = []
        self.authenticator = self._create_authenticator()
        logger.debug('Created site authenticator')
        self._create_reddit_instance()
        # Resolving 'me' queries Reddit, so it must run only after the Reddit instance
        # (and self.authenticated) have been created.
        self._resolve_user_name()

    def _create_reddit_instance(self):
        """Create ``self.reddit_instance`` and set ``self.authenticated``.

        With ``authenticate`` set, an OAuth2 token manager is used; if the config has no
        ``user_token`` yet, a new token is retrieved first and written back to the
        configuration file. Otherwise an instance is built from the client credentials only.
        """
        client_id = self.cfg_parser.get('DEFAULT', 'client_id')
        client_secret = self.cfg_parser.get('DEFAULT', 'client_secret')
        if self.args.authenticate:
            logger.debug('Using authenticated Reddit instance')
            if not self.cfg_parser.has_option('DEFAULT', 'user_token'):
                logger.debug('Commencing OAuth2 authentication')
                scopes = OAuth2Authenticator.split_scopes(self.cfg_parser.get('DEFAULT', 'scopes'))
                oauth2_authenticator = OAuth2Authenticator(scopes, client_id, client_secret)
                token = oauth2_authenticator.retrieve_new_token()
                self.cfg_parser['DEFAULT']['user_token'] = token
                # Persist the token so later runs skip the OAuth2 flow.
                with open(self.config_location, 'w') as file:
                    self.cfg_parser.write(file, True)
            token_manager = OAuth2TokenManager(self.cfg_parser, self.config_location)
            self.authenticated = True
            self.reddit_instance = praw.Reddit(client_id=client_id,
                                               client_secret=client_secret,
                                               user_agent=socket.gethostname(),
                                               token_manager=token_manager)
        else:
            logger.debug('Using unauthenticated Reddit instance')
            self.authenticated = False
            self.reddit_instance = praw.Reddit(client_id=client_id,
                                               client_secret=client_secret,
                                               user_agent=socket.gethostname())

    def _retrieve_reddit_lists(self) -> list[praw.models.ListingGenerator]:
        """Collect every requested listing: subreddits, multireddits, user data and links."""
        master_list = []
        master_list.extend(self._get_subreddits())
        logger.debug('Retrieved subreddits')
        master_list.extend(self._get_multireddits())
        logger.debug('Retrieved multireddits')
        master_list.extend(self._get_user_data())
        logger.debug('Retrieved user data')
        master_list.extend(self._get_submissions_from_link())
        logger.debug('Retrieved submissions for given links')
        return master_list

    def _determine_directories(self):
        """Set the download, log and config directories, creating the first two."""
        # expanduser() so a quoted '~/...' directory argument is honoured.
        self.download_directory = Path(self.args.directory).expanduser()
        self.logfile_directory = self.download_directory / 'LOG_FILES'
        self.config_directory = self.config_directories.user_config_dir
        self.download_directory.mkdir(exist_ok=True, parents=True)
        self.logfile_directory.mkdir(exist_ok=True, parents=True)

    def _load_config(self):
        """Load the first configuration file found into ``self.cfg_parser``.

        Candidates, in order: ``./config.cfg``, ``config.cfg`` in the user config
        directory, ``./default_config.cfg``.

        :raises errors.BulkDownloaderException: if none of the candidates exists.
        """
        self.cfg_parser = configparser.ConfigParser()
        possible_paths = [
            Path('./config.cfg'),
            Path(self.config_directory, 'config.cfg'),
            Path('./default_config.cfg'),
        ]
        # expanduser() must come before resolve(): resolving first would turn '~/x'
        # into '<cwd>/~/x'.
        self.config_location = next(
            (path for path in possible_paths if path.expanduser().resolve().exists()), None)
        if self.config_location is None:
            raise errors.BulkDownloaderException('Could not find a configuration file to load')
        self.cfg_parser.read(self.config_location)

    def _create_file_logger(self):
        """Attach a handler writing log records to ``LOG_FILES/log_output.txt``.

        The handler is added to the root logger and has level 0, so it filters nothing
        itself; which records reach it depends on the loggers' own levels.
        """
        main_logger = logging.getLogger()
        # UTF-8 so paths/messages containing non-ASCII text are not lost on platforms
        # whose default locale encoding cannot represent them.
        file_handler = logging.FileHandler(self.logfile_directory / 'log_output.txt', encoding='utf-8')
        formatter = logging.Formatter('[%(asctime)s - %(name)s - %(levelname)s] - %(message)s')
        file_handler.setFormatter(formatter)
        file_handler.setLevel(0)
        main_logger.addHandler(file_handler)

    def _get_subreddits(self) -> list[praw.models.ListingGenerator]:
        """Return one listing per requested subreddit (empty when none were given).

        With ``search`` set, each subreddit is searched for the query using the chosen
        sort and time filter; otherwise the subreddit's sorted listing is returned.
        """
        if not self.args.subreddit:
            return []
        subreddits = [self.reddit_instance.subreddit(chosen_subreddit) for chosen_subreddit in self.args.subreddit]
        if self.args.search:
            search_sort = self.sort_filter.name.lower()
            if self.sort_filter is RedditTypes.SortType.RELEVENCE:
                # The enum member name is misspelt; Reddit's search sort is 'relevance'.
                search_sort = 'relevance'
            return [
                reddit.search(
                    self.args.search,
                    sort=search_sort,
                    time_filter=self.time_filter.name.lower(),
                    limit=self.args.limit)
                for reddit in subreddits]
        return [self._create_filtered_listing_generator(reddit) for reddit in subreddits]

    def _resolve_user_name(self):
        """Replace the special user value ``me`` with the authenticated account's name.

        :raises errors.RedditAuthenticationError: if ``me`` is used without authentication.
        """
        if self.args.user == 'me':
            if not self.authenticated:
                raise errors.RedditAuthenticationError('Resolving the user "me" requires authentication')
            # user.me() returns a Redditor object; keep args.user a plain name string
            # like every other value it can hold.
            self.args.user = self.reddit_instance.user.me().name

    def _get_submissions_from_link(self) -> list[list[praw.models.Submission]]:
        """Return the submissions given by ID in ``link``, wrapped as a single listing."""
        return [[self.reddit_instance.submission(id=sub_id) for sub_id in (self.args.link or [])]]

    def _determine_sort_function(self):
        """Return the PRAW ``Subreddit`` listing function for the chosen sort.

        The function is called with the listing source as its first argument, and is
        applied to subreddits, multireddits and redditor submission listings alike.
        Sorts without a dedicated branch fall back to ``hot``.
        """
        if self.sort_filter is RedditTypes.SortType.NEW:
            sort_function = praw.models.Subreddit.new
        elif self.sort_filter is RedditTypes.SortType.RISING:
            sort_function = praw.models.Subreddit.rising
        elif self.sort_filter is RedditTypes.SortType.CONTROVERSIAL:
            sort_function = praw.models.Subreddit.controversial
        else:
            sort_function = praw.models.Subreddit.hot
        return sort_function

    def _create_filtered_listing_generator(self, reddit_source) -> Iterator:
        """Apply the chosen sort (and, where supported, the time filter) to ``reddit_source``.

        Only the controversial listing accepts a ``time_filter`` argument among the
        sorts handled here, so the time filter is forwarded for that sort only.
        """
        sort_function = self._determine_sort_function()
        if self.sort_filter is RedditTypes.SortType.CONTROVERSIAL:
            return sort_function(reddit_source, limit=self.args.limit, time_filter=self.time_filter.name.lower())
        return sort_function(reddit_source, limit=self.args.limit)

    def _get_multireddits(self) -> list[Iterator]:
        """Return one sorted listing per requested multireddit of ``user``.

        :raises errors.RedditAuthenticationError: if not authenticated.
        :raises errors.BulkDownloaderException: if no user was supplied.
        """
        if not self.args.multireddit:
            return []
        if not self.authenticated:
            raise errors.RedditAuthenticationError('Accessing multireddits requires authentication')
        if not self.args.user:
            raise errors.BulkDownloaderException('A user must be provided to download a multireddit')
        return [
            self._create_filtered_listing_generator(
                self.reddit_instance.multireddit(self.args.user, m_reddit_choice))
            for m_reddit_choice in self.args.multireddit]

    def _get_user_data(self) -> list[Iterator]:
        """Return the requested submitted/upvoted/saved listings of ``user``.

        :raises errors.BulkDownloaderException: if no user was supplied.
        :raises errors.RedditUserError: if the user does not exist.
        :raises errors.RedditAuthenticationError: if upvoted/saved are requested
            without authentication.
        """
        if not any([self.args.submitted, self.args.upvoted, self.args.saved]):
            return []
        if not self.args.user:
            raise errors.BulkDownloaderException('A user must be supplied to download user data')
        if not self._check_user_existence(self.args.user):
            raise errors.RedditUserError(f'User {self.args.user} does not exist')
        redditor = self.reddit_instance.redditor(self.args.user)
        generators = []
        if self.args.submitted:
            generators.append(self._create_filtered_listing_generator(redditor.submissions))
        if not self.authenticated and any((self.args.upvoted, self.args.saved)):
            raise errors.RedditAuthenticationError('Accessing user lists requires authentication')
        # upvoted/saved are methods on PRAW's Redditor, not listings: they must be
        # called to obtain an iterable generator.
        if self.args.upvoted:
            generators.append(redditor.upvoted(limit=self.args.limit))
        if self.args.saved:
            generators.append(redditor.saved(limit=self.args.limit))
        return generators

    def _check_user_existence(self, name: str) -> bool:
        """Return whether a redditor called ``name`` exists (Reddit answers NotFound otherwise)."""
        user = self.reddit_instance.redditor(name=name)
        try:
            # Reading an attribute of the lazy Redditor makes PRAW fetch it.
            return bool(user.id)
        except prawcore.exceptions.NotFound:
            return False

    def _create_file_name_formatter(self) -> FileNameFormatter:
        """Build the formatter from the file and folder naming schemes."""
        return FileNameFormatter(self.args.set_file_scheme, self.args.set_folder_scheme)

    def _create_time_filter(self) -> RedditTypes.TimeType:
        """Map the ``time`` argument to a TimeType, defaulting to ALL when missing or unknown."""
        try:
            return RedditTypes.TimeType[self.args.time.upper()]
        except (KeyError, AttributeError):
            return RedditTypes.TimeType.ALL

    def _create_sort_filter(self) -> RedditTypes.SortType:
        """Map the ``sort`` argument to a SortType, defaulting to HOT when missing or unknown."""
        try:
            return RedditTypes.SortType[self.args.sort.upper()]
        except (KeyError, AttributeError):
            return RedditTypes.SortType.HOT

    def _create_download_filter(self) -> DownloadFilter:
        """Build the URL filter from the skipped extensions and domains."""
        return DownloadFilter(self.args.skip, self.args.skip_domain)

    def _create_authenticator(self) -> SiteAuthenticator:
        """Build the site authenticator from the loaded configuration."""
        return SiteAuthenticator(self.cfg_parser)

    def download(self):
        """Download every submission of every listing resolved at construction."""
        for generator in self.reddit_lists:
            for submission in generator:
                self._download_submission(submission)

    def _download_submission(self, submission: praw.models.Submission):
        """Download and write all resources of one submission.

        Submissions whose URL fails the download filter are skipped without logging;
        unsupported links are logged and skipped. Existing destination files are never
        overwritten. When ``no_dupes`` is set, a resource whose content hash was already
        recorded during this run is not written again.
        """
        if not self.download_filter.check_url(submission.url):
            return
        logger.debug('Attempting to download submission {}'.format(submission.id))
        try:
            downloader_class = DownloadFactory.pull_lever(submission.url)
            downloader = downloader_class(submission)
        except errors.NotADownloadableLinkError as e:
            logger.error('Could not download submission {}: {}'.format(submission.name, e))
            return
        if self.args.no_download:
            logger.info('Skipping download for submission {}'.format(submission.id))
            return
        content = downloader.find_resources(self.authenticator)
        for res in content:
            destination = self.file_name_formatter.format_path(res, self.download_directory)
            if destination.exists():
                logger.debug('File already exists: {}'.format(destination))
                continue
            resource_hash = res.hash.hexdigest()
            # Duplicates are skipped only when the user asked for it via no_dupes.
            if self.args.no_dupes and resource_hash in self.master_hash_list:
                logger.debug(f'Resource from {res.url} downloaded elsewhere')
                continue
            # TODO: consider making a hard link/symlink here
            destination.parent.mkdir(parents=True, exist_ok=True)
            with open(destination, 'wb') as file:
                file.write(res.content)
            logger.debug('Written file to {}'.format(destination))
            if resource_hash not in self.master_hash_list:
                self.master_hash_list.append(resource_hash)
                logger.debug('Hash added to master list: {}'.format(resource_hash))
        logger.info('Downloaded submission {}'.format(submission.name))