1
0
Fork 0
mirror of synced 2024-06-02 10:24:39 +12:00
bulk-downloader-for-reddit/bulkredditdownloader/downloader.py

294 lines
13 KiB
Python
Raw Normal View History

2021-02-11 12:10:40 +13:00
#!/usr/bin/env python3
# coding=utf-8
import argparse
import configparser
import logging
2021-03-08 15:35:34 +13:00
import re
2021-02-11 12:10:40 +13:00
import socket
from datetime import datetime
from enum import Enum, auto
from pathlib import Path
from typing import Iterator
2021-02-11 12:10:40 +13:00
import appdirs
import praw
import praw.models
import prawcore
2021-02-11 12:10:40 +13:00
2021-03-05 16:32:24 +13:00
import bulkredditdownloader.exceptions as errors
2021-02-11 12:10:40 +13:00
from bulkredditdownloader.download_filter import DownloadFilter
from bulkredditdownloader.file_name_formatter import FileNameFormatter
2021-03-08 15:35:34 +13:00
from bulkredditdownloader.oauth2 import OAuth2Authenticator, OAuth2TokenManager
2021-03-03 15:53:53 +13:00
from bulkredditdownloader.site_authenticator import SiteAuthenticator
2021-02-11 12:10:40 +13:00
from bulkredditdownloader.site_downloaders.download_factory import DownloadFactory
logger = logging.getLogger(__name__)
class RedditTypes:
    """Namespace grouping the enumerations used for Reddit listing options."""

    class SortType(Enum):
        """Submission orderings accepted by the --sort option.

        NOTE(review): ``RELEVENCE`` is a misspelling of "relevance"; the name
        is kept unchanged because members are looked up by name elsewhere.
        """
        # Explicit values mirror what enum.auto() would assign (1..N).
        HOT = 1
        RISING = 2
        CONTROVERSIAL = 3
        NEW = 4
        RELEVENCE = 5

    class TimeType(Enum):
        """Time windows accepted by the --time option."""
        HOUR = 1
        DAY = 2
        WEEK = 3
        MONTH = 4
        YEAR = 5
        ALL = 6
class RedditDownloader:
    """Drives a complete download run.

    Builds the download/time/sort filters and the file-name formatter from the
    parsed CLI arguments, loads the configuration file, creates a (possibly
    OAuth2-authenticated) PRAW Reddit instance, collects every requested
    submission listing, and writes each submission's resources to disk.
    """

    def __init__(self, args: argparse.Namespace):
        self.args = args
        self.config_directories = appdirs.AppDirs('bulk_reddit_downloader', 'BDFR')
        # ISO timestamp identifying this run.
        self.run_time = datetime.now().isoformat()
        self._setup_internal_objects()

        self.reddit_lists = self._retrieve_reddit_lists()

    def _setup_internal_objects(self):
        """Create all helper objects, in dependency order."""
        self._determine_directories()
        self._create_file_logger()

        self.download_filter = self._create_download_filter()
        logger.debug('Created download filter')
        self.time_filter = self._create_time_filter()
        logger.debug('Created time filter')
        self.sort_filter = self._create_sort_filter()
        logger.debug('Created sort filter')
        self.file_name_formatter = self._create_file_name_formatter()
        logger.debug('Created file name formatter')

        self._load_config()
        logger.debug(f'Configuration loaded from {self.config_location}')

        # Hex digests of every resource written so far; used to detect
        # duplicates across the whole run when --no-dupes is given.
        self.master_hash_list = []
        self.authenticator = self._create_authenticator()
        logger.debug('Created site authenticator')

        self._create_reddit_instance()
        # Fix: this must run AFTER _create_reddit_instance(); it dereferences
        # self.reddit_instance, which previously did not exist yet and raised
        # AttributeError whenever '--user me' was supplied.
        self._resolve_user_name()

    def _create_reddit_instance(self):
        """Create self.reddit_instance, running the OAuth2 flow if needed."""
        if self.args.authenticate:
            logger.debug('Using authenticated Reddit instance')
            if not self.cfg_parser.has_option('DEFAULT', 'user_token'):
                # No cached token: walk the user through OAuth2 and persist
                # the new token back into the configuration file.
                logger.debug('Commencing OAuth2 authentication')
                scopes = self.cfg_parser.get('DEFAULT', 'scopes')
                scopes = OAuth2Authenticator.split_scopes(scopes)
                oauth2_authenticator = OAuth2Authenticator(
                    scopes,
                    self.cfg_parser.get('DEFAULT', 'client_id'),
                    self.cfg_parser.get('DEFAULT', 'client_secret'))
                token = oauth2_authenticator.retrieve_new_token()
                self.cfg_parser['DEFAULT']['user_token'] = token
                with open(self.config_location, 'w') as file:
                    self.cfg_parser.write(file, True)
            token_manager = OAuth2TokenManager(self.cfg_parser, self.config_location)

            self.authenticated = True
            self.reddit_instance = praw.Reddit(client_id=self.cfg_parser.get('DEFAULT', 'client_id'),
                                               client_secret=self.cfg_parser.get('DEFAULT', 'client_secret'),
                                               user_agent=socket.gethostname(),
                                               token_manager=token_manager)
        else:
            logger.debug('Using unauthenticated Reddit instance')
            self.authenticated = False
            self.reddit_instance = praw.Reddit(client_id=self.cfg_parser.get('DEFAULT', 'client_id'),
                                               client_secret=self.cfg_parser.get('DEFAULT', 'client_secret'),
                                               user_agent=socket.gethostname())

    def _retrieve_reddit_lists(self) -> list[praw.models.ListingGenerator]:
        """Collect every submission source requested on the command line."""
        master_list = []
        master_list.extend(self._get_subreddits())
        logger.debug('Retrieved subreddits')
        master_list.extend(self._get_multireddits())
        logger.debug('Retrieved multireddits')
        master_list.extend(self._get_user_data())
        logger.debug('Retrieved user data')
        master_list.extend(self._get_submissions_from_link())
        logger.debug('Retrieved submissions for given links')
        return master_list

    def _determine_directories(self):
        """Resolve and create the download and log-file directories."""
        self.download_directory = Path(self.args.directory)
        self.logfile_directory = self.download_directory / 'LOG_FILES'
        self.config_directory = self.config_directories.user_config_dir
        self.download_directory.mkdir(exist_ok=True, parents=True)
        self.logfile_directory.mkdir(exist_ok=True, parents=True)

    def _load_config(self):
        """Load the first configuration file found among the known locations.

        Raises:
            errors.BulkDownloaderException: if no candidate file exists.
        """
        self.cfg_parser = configparser.ConfigParser()
        possible_paths = [Path('./config.cfg'),
                          Path(self.config_directory, 'config.cfg'),
                          Path('./default_config.cfg'),
                          ]
        self.config_location = None
        for path in possible_paths:
            # Fix: expand '~' BEFORE resolving; resolve() first makes the path
            # absolute, so a leading '~' component would never be expanded.
            if path.expanduser().resolve().exists():
                self.config_location = path
                break
        if not self.config_location:
            raise errors.BulkDownloaderException('Could not find a configuration file to load')
        self.cfg_parser.read(self.config_location)

    def _create_file_logger(self):
        """Attach a per-run file handler to the root logger."""
        main_logger = logging.getLogger()
        file_handler = logging.FileHandler(self.logfile_directory / 'log_output.txt')
        formatter = logging.Formatter('[%(asctime)s - %(name)s - %(levelname)s] - %(message)s')
        file_handler.setFormatter(formatter)
        # Level 0 (NOTSET) defers level filtering to the root logger.
        file_handler.setLevel(0)
        main_logger.addHandler(file_handler)

    def _get_subreddits(self) -> list[praw.models.ListingGenerator]:
        """Return a listing generator per requested subreddit (search or sort)."""
        if not self.args.subreddit:
            return []
        subreddits = [self.reddit_instance.subreddit(chosen_subreddit) for chosen_subreddit in self.args.subreddit]
        if self.args.search:
            return [
                reddit.search(
                    self.args.search,
                    sort=self.sort_filter.name.lower(),
                    limit=self.args.limit) for reddit in subreddits]
        else:
            sort_function = self._determine_sort_function()
            return [sort_function(reddit, limit=self.args.limit) for reddit in subreddits]

    def _resolve_user_name(self):
        """Replace the 'me' placeholder with the authenticated username."""
        if self.args.user == 'me':
            # Fix: user.me() returns a Redditor object, but every consumer of
            # args.user (redditor(), multireddit(), log messages) expects the
            # username string, so take .name.
            self.args.user = self.reddit_instance.user.me().name

    def _get_submissions_from_link(self) -> list[list[praw.models.Submission]]:
        """Wrap the explicitly supplied submission IDs in a single list."""
        supplied_submissions = []
        for sub_id in self.args.link:
            supplied_submissions.append(self.reddit_instance.submission(id=sub_id))
        return [supplied_submissions]

    def _determine_sort_function(self):
        """Map the chosen sort filter onto the matching Subreddit method.

        Defaults to Subreddit.hot for HOT, RELEVENCE, or anything unmapped.
        """
        if self.sort_filter is RedditTypes.SortType.NEW:
            sort_function = praw.models.Subreddit.new
        elif self.sort_filter is RedditTypes.SortType.RISING:
            sort_function = praw.models.Subreddit.rising
        elif self.sort_filter is RedditTypes.SortType.CONTROVERSIAL:
            sort_function = praw.models.Subreddit.controversial
        else:
            sort_function = praw.models.Subreddit.hot
        return sort_function

    def _get_multireddits(self) -> list[Iterator]:
        """Return listing generators for each requested multireddit.

        Raises:
            errors.RedditAuthenticationError: when not authenticated.
            errors.BulkDownloaderException: when no user was supplied.
        """
        if not self.args.multireddit:
            return []
        if not self.authenticated:
            raise errors.RedditAuthenticationError('Accessing multireddits requires authentication')
        if not self.args.user:
            raise errors.BulkDownloaderException('A user must be provided to download a multireddit')
        sort_function = self._determine_sort_function()
        return [
            sort_function(self.reddit_instance.multireddit(
                self.args.user,
                m_reddit_choice), limit=self.args.limit) for m_reddit_choice in self.args.multireddit]

    def _get_user_data(self) -> list[Iterator]:
        """Return generators for a user's submitted/upvoted/saved listings.

        Raises:
            errors.BulkDownloaderException: when no user was supplied.
            errors.RedditUserError: when the user does not exist.
            errors.RedditAuthenticationError: when upvoted/saved is requested
                without authentication.
        """
        if not any([self.args.submitted, self.args.upvoted, self.args.saved]):
            return []
        if not self.args.user:
            raise errors.BulkDownloaderException('A user must be supplied to download user data')
        if not self._check_user_existence(self.args.user):
            raise errors.RedditUserError(f'User {self.args.user} does not exist')
        generators = []
        sort_function = self._determine_sort_function()
        if self.args.submitted:
            generators.append(
                sort_function(
                    self.reddit_instance.redditor(self.args.user).submissions,
                    limit=self.args.limit))
        if not self.authenticated and any((self.args.upvoted, self.args.saved)):
            raise errors.RedditAuthenticationError('Accessing user lists requires authentication')
        if self.args.upvoted:
            generators.append(self.reddit_instance.redditor(self.args.user).upvoted)
        if self.args.saved:
            generators.append(self.reddit_instance.redditor(self.args.user).saved)
        return generators

    def _check_user_existence(self, name: str) -> bool:
        """Return True iff a Redditor with the given name exists."""
        user = self.reddit_instance.redditor(name=name)
        try:
            # Accessing .id forces PRAW's lazy fetch; missing users raise.
            if not user.id:
                return False
        except prawcore.exceptions.NotFound:
            return False
        return True

    def _create_file_name_formatter(self) -> FileNameFormatter:
        """Build the formatter from the CLI file/folder naming schemes."""
        return FileNameFormatter(self.args.set_file_scheme, self.args.set_folder_scheme)

    def _create_time_filter(self) -> RedditTypes.TimeType:
        """Parse --time, defaulting to ALL on a missing or unknown value."""
        try:
            return RedditTypes.TimeType[self.args.time.upper()]
        except (KeyError, AttributeError):
            return RedditTypes.TimeType.ALL

    def _create_sort_filter(self) -> RedditTypes.SortType:
        """Parse --sort, defaulting to HOT on a missing or unknown value."""
        try:
            return RedditTypes.SortType[self.args.sort.upper()]
        except (KeyError, AttributeError):
            return RedditTypes.SortType.HOT

    def _create_download_filter(self) -> DownloadFilter:
        """Build the URL filter from the --skip/--skip-domain options."""
        return DownloadFilter(self.args.skip, self.args.skip_domain)

    def _create_authenticator(self) -> SiteAuthenticator:
        """Build the per-site authenticator from the loaded configuration."""
        return SiteAuthenticator(self.cfg_parser)

    def download(self):
        """Download every submission from every collected listing."""
        for generator in self.reddit_lists:
            for submission in generator:
                self._download_submission(submission)

    def _download_submission(self, submission: praw.models.Submission):
        """Download all resources of a single submission, if its URL passes
        the download filter; skips already-present files and (with --no-dupes)
        resources whose hash was already written this run."""
        if self.download_filter.check_url(submission.url):
            logger.debug('Attempting to download submission {}'.format(submission.id))
            try:
                downloader_class = DownloadFactory.pull_lever(submission.url)
                downloader = downloader_class(submission)
            except errors.NotADownloadableLinkError as e:
                logger.error('Could not download submission {}: {}'.format(submission.name, e))
                return
            if self.args.no_download:
                logger.info('Skipping download for submission {}'.format(submission.id))
            else:
                content = downloader.find_resources(self.authenticator)
                for res in content:
                    destination = self.file_name_formatter.format_path(res, self.download_directory)
                    if destination.exists():
                        logger.debug('File already exists: {}'.format(destination))
                    elif self.args.no_dupes and res.hash.hexdigest() in self.master_hash_list:
                        # Fix: the original condition ('not in list AND not
                        # no_dupes') skipped EVERY download when --no-dupes was
                        # set and skipped duplicates even without it; only
                        # duplicates under --no-dupes should be skipped.
                        logger.debug(f'Resource from {res.url} downloaded elsewhere')
                    else:
                        # TODO: consider making a hard link/symlink here
                        destination.parent.mkdir(parents=True, exist_ok=True)
                        with open(destination, 'wb') as file:
                            file.write(res.content)
                        logger.debug('Written file to {}'.format(destination))
                        self.master_hash_list.append(res.hash.hexdigest())
                        logger.debug('Hash added to master list: {}'.format(res.hash.hexdigest()))
                logger.info('Downloaded submission {}'.format(submission.name))