#!/usr/bin/env python3
# coding=utf-8
"""OAuth2 helpers: interactive refresh-token retrieval and a PRAW token manager."""

import configparser
import logging
import random
import secrets
import socket
import urllib.parse

import praw
import requests

from bulkredditdownloader.exceptions import BulkDownloaderException, RedditAuthenticationError

logger = logging.getLogger(__name__)


class OAuth2Authenticator:
    """Drive the interactive OAuth2 flow used to obtain a Reddit refresh token."""

    def __init__(self, wanted_scopes: list[str]):
        """Validate and store the requested scopes.

        Args:
            wanted_scopes: Scope names to request when authorising.

        Raises:
            BulkDownloaderException: If any scope is not in Reddit's published list.
        """
        self._check_scopes(wanted_scopes)
        self.scopes = wanted_scopes

    @staticmethod
    def _check_scopes(wanted_scopes: list[str]):
        """Verify every requested scope against Reddit's published scope list.

        Args:
            wanted_scopes: Scope names to validate.

        Raises:
            BulkDownloaderException: If a scope is not known to Reddit.
        """
        # A timeout keeps a stalled connection from hanging the program forever;
        # requests never times out unless told to.
        response = requests.get(
            'https://www.reddit.com/api/v1/scopes.json',
            headers={'User-Agent': 'fetch-scopes test'},
            timeout=10,
        )
        known_scopes = set(response.json())
        # '*' is accepted here in addition to the scope names returned by the endpoint.
        known_scopes.add('*')
        for scope in wanted_scopes:
            if scope not in known_scopes:
                raise BulkDownloaderException(f'Scope {scope} is not known to reddit')

    def retrieve_new_token(self) -> str:
        """Run the browser-based authorisation flow and return a refresh token.

        Logs the authorisation URL, waits for the redirect on localhost:8080,
        validates the callback parameters and exchanges the code for a token.

        Returns:
            The value returned by ``reddit.auth.authorize`` for the received code.

        Raises:
            RedditAuthenticationError: If the callback is malformed, the state
                does not match, Reddit reports an error, or no code is supplied.
        """
        reddit = praw.Reddit(redirect_uri='http://localhost:8080', user_agent='obtain_refresh_token for BDFR')
        # The state value guards the callback against forgery, so it must be
        # unpredictable; hex output needs no URL escaping.
        state = secrets.token_hex(16)
        url = reddit.auth.url(self.scopes, state, 'permanent')
        logger.warning('Authentication action required before the program can proceed')
        logger.warning(f'Authenticate at {url}')

        client = self.receive_connection()
        reply = ''
        try:
            data = client.recv(1024).decode('utf-8', errors='replace')
            params = self._parse_callback_params(data)
            received_state = params.get('state')
            if state != received_state:
                raise RedditAuthenticationError(
                    f'State mismatch in OAuth2. Expected: {state} Received: {received_state}')
            elif 'error' in params:
                raise RedditAuthenticationError(f'Error in OAuth2: {params["error"]}')
            elif 'code' not in params:
                raise RedditAuthenticationError('No authorisation code received in OAuth2 callback')
            reply = 'Authentication complete. You may close this window and return to the terminal.'
        finally:
            # Always answer and close the connection, on success as well as failure.
            # The reply is a courtesy, so a browser that has already hung up must not
            # mask the real outcome.
            try:
                self.send_message(client, reply)
            except OSError as e:
                logger.debug(f'Could not reply to OAuth2 callback client: {e}')

        refresh_token = reddit.auth.authorize(params["code"])
        return refresh_token

    @staticmethod
    def _parse_callback_params(request: str) -> dict[str, str]:
        """Extract the query parameters from a raw HTTP request.

        Args:
            request: The decoded request text; only the request line is used.

        Returns:
            Mapping of query parameter names to percent-decoded values. Empty if
            the request target has no query string.

        Raises:
            RedditAuthenticationError: If the request line has no target.
        """
        try:
            target = request.split(' ', 2)[1]
        except IndexError:
            raise RedditAuthenticationError('Malformed OAuth2 callback request') from None
        query = urllib.parse.urlsplit(target).query
        return dict(urllib.parse.parse_qsl(query, keep_blank_values=True))

    @staticmethod
    def receive_connection() -> socket.socket:
        """Accept a single connection on localhost:8080.

        Returns:
            The connected client socket; the caller is responsible for closing it.
        """
        # The context manager closes the listening socket even if bind/accept fails.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server:
            server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            server.bind(('localhost', 8080))
            logger.debug('Server listening on localhost:8080')
            server.listen(1)
            client = server.accept()[0]
        logger.debug('Server closed')
        return client

    @staticmethod
    def send_message(client: socket.socket, message: str = ''):
        """Send a minimal HTTP 200 response to the client and close the socket.

        Args:
            client: Connected socket to answer.
            message: Optional plain-text body to send.
        """
        body = message.encode('utf-8')
        header = (
            'HTTP/1.1 200 OK\r\n'
            'Content-Type: text/plain; charset=utf-8\r\n'
            f'Content-Length: {len(body)}\r\n'
            'Connection: close\r\n'
            '\r\n'
        ).encode('utf-8')
        try:
            # sendall retries until every byte is written; send may stop early.
            client.sendall(header + body)
        finally:
            client.close()


class OAuth2TokenManager(praw.reddit.BaseTokenManager):
    """Token manager that reads and stores the refresh token in a ConfigParser."""

    def __init__(self, config: configparser.ConfigParser):
        """Store the configuration object that holds the ``user_token`` option."""
        super().__init__()
        self.config = config

    def pre_refresh_callback(self, authorizer: praw.reddit.Authorizer):
        """Load the refresh token from the configuration if the authorizer has none.

        Raises:
            RedditAuthenticationError: If the configuration has no ``user_token``.
        """
        if authorizer.refresh_token is None:
            if self.config.has_option('DEFAULT', 'user_token'):
                authorizer.refresh_token = self.config.get('DEFAULT', 'user_token')
            else:
                raise RedditAuthenticationError('No auth token loaded in configuration')

    def post_refresh_callback(self, authorizer: praw.reddit.Authorizer):
        """Record the authorizer's refresh token in the in-memory configuration."""
        # Only the ConfigParser object is updated here; nothing is written to disk.
        self.config.set('DEFAULT', 'user_token', authorizer.refresh_token)