diff --git a/bulkredditdownloader/oauth2.py b/bulkredditdownloader/oauth2.py new file mode 100644 index 0000000..67444d8 --- /dev/null +++ b/bulkredditdownloader/oauth2.py @@ -0,0 +1,88 @@ +#!/usr/bin/env python3 +# coding=utf-8 + +import configparser +import logging +import random +import socket + +import praw +import requests + +from bulkredditdownloader.exceptions import BulkDownloaderException, RedditAuthenticationError + +logger = logging.getLogger(__name__) + + +class OAuth2Authenticator: + + def __init__(self, wanted_scopes: list[str]): + self._check_scopes(wanted_scopes) + self.scopes = wanted_scopes + + @staticmethod + def _check_scopes(wanted_scopes: list[str]): + response = requests.get('https://www.reddit.com/api/v1/scopes.json', + headers={'User-Agent': 'fetch-scopes test'}) + known_scopes = [scope for scope, data in response.json().items()] + known_scopes.append('*') + for scope in wanted_scopes: + if scope not in known_scopes: + raise BulkDownloaderException(f'Scope {scope} is not known to reddit') + + def retrieve_new_token(self) -> str: + reddit = praw.Reddit(redirect_uri='http://localhost:8080', user_agent='obtain_refresh_token for BDFR') + state = str(random.randint(0, 65000)) + url = reddit.auth.url(self.scopes, state, 'permanent') + logger.warning('Authentication action required before the program can proceed') + logger.warning(f'Authenticate at {url}') + + client = self.receive_connection() + data = client.recv(1024).decode('utf-8') + param_tokens = data.split(' ', 2)[1].split('?', 1)[1].split('&') + params = {key: value for (key, value) in [token.split('=') for token in param_tokens]} + + if state != params['state']: + self.send_message(client) + raise RedditAuthenticationError(f'State mismatch in OAuth2. Expected: {state} Received: {params["state"]}') + elif 'error' in params: + self.send_message(client) + raise RedditAuthenticationError(f'Error in OAuth2: {params["error"]}') + + refresh_token = reddit.auth.authorize(params["code"]) + return refresh_token + + @staticmethod + def receive_connection() -> socket.socket: + server = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + server.bind(('localhost', 8080)) + logger.debug('Server listening on localhost:8080') + + server.listen(1) + client = server.accept()[0] + server.close() + logger.debug('Server closed') + + return client + + @staticmethod + def send_message(client: socket.socket): + client.send('HTTP/1.1 200 OK'.encode('utf-8')) + client.close() + + +class OAuth2TokenManager(praw.reddit.BaseTokenManager): + def __init__(self, config: configparser.ConfigParser): + super(OAuth2TokenManager, self).__init__() + self.config = config + + def pre_refresh_callback(self, authorizer: praw.reddit.Authorizer): + if authorizer.refresh_token is None: + if self.config.has_option('DEFAULT', 'user_token'): + authorizer.refresh_token = self.config.get('DEFAULT', 'user_token') + else: + raise RedditAuthenticationError('No auth token loaded in configuration') + + def post_refresh_callback(self, authorizer: praw.reddit.Authorizer): + self.config.set('DEFAULT', 'user_token', authorizer.refresh_token) diff --git a/bulkredditdownloader/tests/test_oauth2.py b/bulkredditdownloader/tests/test_oauth2.py new file mode 100644 index 0000000..a80d7a7 --- /dev/null +++ b/bulkredditdownloader/tests/test_oauth2.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python3 +# coding=utf-8 + +import configparser +from unittest.mock import MagicMock + +import praw +import pytest + +from bulkredditdownloader.exceptions import BulkDownloaderException +from bulkredditdownloader.oauth2 import OAuth2Authenticator, OAuth2TokenManager + + +@pytest.fixture() +def example_config() -> configparser.ConfigParser: + out = configparser.ConfigParser() + config_dict = {'DEFAULT': {'user_token': 'example'}} + out.read_dict(config_dict) + return out + + +@pytest.mark.online +@pytest.mark.parametrize('test_scopes', ( + ('history',), + ('history', 'creddits'), + ('account', 'flair'), + ('*',), +)) +def test_check_scopes(test_scopes: list[str]): + OAuth2Authenticator._check_scopes(test_scopes) + + +@pytest.mark.online +@pytest.mark.parametrize('test_scopes', ( + ('random',), + ('scope', 'another_scope'), +)) +def test_check_scopes_bad(test_scopes: list[str]): + with pytest.raises(BulkDownloaderException): + OAuth2Authenticator._check_scopes(test_scopes) + + +def test_token_manager_read(example_config: configparser.ConfigParser): + mock_authoriser = MagicMock() + mock_authoriser.refresh_token = None + test_manager = OAuth2TokenManager(example_config) + test_manager.pre_refresh_callback(mock_authoriser) + assert mock_authoriser.refresh_token == example_config.get('DEFAULT', 'user_token') + + +def test_token_manager_write(example_config: configparser.ConfigParser): + mock_authoriser = MagicMock() + mock_authoriser.refresh_token = 'changed_token' + test_manager = OAuth2TokenManager(example_config) + test_manager.post_refresh_callback(mock_authoriser) + assert example_config.get('DEFAULT', 'user_token') == 'changed_token'