Add OAuth2 class
This commit is contained in:
parent
aeb9afdc66
commit
5a2e045c77
88
bulkredditdownloader/oauth2.py
Normal file
88
bulkredditdownloader/oauth2.py
Normal file
|
@ -0,0 +1,88 @@
|
|||
#!/usr/bin/env python3
|
||||
# coding=utf-8
|
||||
|
||||
import configparser
|
||||
import logging
|
||||
import random
|
||||
import socket
|
||||
|
||||
import praw
|
||||
import requests
|
||||
|
||||
from bulkredditdownloader.exceptions import BulkDownloaderException, RedditAuthenticationError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OAuth2Authenticator:
|
||||
|
||||
def __init__(self, wanted_scopes: list[str]):
|
||||
self._check_scopes(wanted_scopes)
|
||||
self.scopes = wanted_scopes
|
||||
|
||||
@staticmethod
|
||||
def _check_scopes(wanted_scopes: list[str]):
|
||||
response = requests.get('https://www.reddit.com/api/v1/scopes.json',
|
||||
headers={'User-Agent': 'fetch-scopes test'})
|
||||
known_scopes = [scope for scope, data in response.json().items()]
|
||||
known_scopes.append('*')
|
||||
for scope in wanted_scopes:
|
||||
if scope not in known_scopes:
|
||||
raise BulkDownloaderException(f'Scope {scope} is not known to reddit')
|
||||
|
||||
def retrieve_new_token(self) -> str:
|
||||
reddit = praw.Reddit(redirect_uri='http://localhost:8080', user_agent='obtain_refresh_token for BDFR')
|
||||
state = str(random.randint(0, 65000))
|
||||
url = reddit.auth.url(self.scopes, state, 'permanent')
|
||||
logger.warning('Authentication action required before the program can proceed')
|
||||
logger.warning(f'Authenticate at {url}')
|
||||
|
||||
client = self.receive_connection()
|
||||
data = client.recv(1024).decode('utf-8')
|
||||
param_tokens = data.split(' ', 2)[1].split('?', 1)[1].split('&')
|
||||
params = {key: value for (key, value) in [token.split('=') for token in param_tokens]}
|
||||
|
||||
if state != params['state']:
|
||||
self.send_message(client)
|
||||
raise RedditAuthenticationError(f'State mismatch in OAuth2. Expected: {state} Received: {params["state"]}')
|
||||
elif 'error' in params:
|
||||
self.send_message(client)
|
||||
raise RedditAuthenticationError(f'Error in OAuth2: {params["error"]}')
|
||||
|
||||
refresh_token = reddit.auth.authorize(params["code"])
|
||||
return refresh_token
|
||||
|
||||
@staticmethod
|
||||
def receive_connection() -> socket.socket:
|
||||
server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
server.bind(('localhost', 8080))
|
||||
logger.debug('Server listening on localhost:8080')
|
||||
|
||||
server.listen(1)
|
||||
client = server.accept()[0]
|
||||
server.close()
|
||||
logger.debug('Server closed')
|
||||
|
||||
return client
|
||||
|
||||
@staticmethod
|
||||
def send_message(client: socket.socket):
|
||||
client.send('HTTP/1.1 200 OK'.encode('utf-8'))
|
||||
client.close()
|
||||
|
||||
|
||||
class OAuth2TokenManager(praw.reddit.BaseTokenManager):
|
||||
def __init__(self, config: configparser.ConfigParser):
|
||||
super(OAuth2TokenManager, self).__init__()
|
||||
self.config = config
|
||||
|
||||
def pre_refresh_callback(self, authorizer: praw.reddit.Authorizer):
|
||||
if authorizer.refresh_token is None:
|
||||
if self.config.has_option('DEFAULT', 'user_token'):
|
||||
authorizer.refresh_token = self.config.get('DEFAULT', 'user_token')
|
||||
else:
|
||||
raise RedditAuthenticationError('No auth token loaded in configuration')
|
||||
|
||||
def post_refresh_callback(self, authorizer: praw.reddit.Authorizer):
|
||||
self.config.set('DEFAULT', 'user_token', authorizer.refresh_token)
|
56
bulkredditdownloader/tests/test_oauth2.py
Normal file
56
bulkredditdownloader/tests/test_oauth2.py
Normal file
|
@ -0,0 +1,56 @@
|
|||
#!/usr/bin/env python3
|
||||
# coding=utf-8
|
||||
|
||||
import configparser
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import praw
|
||||
import pytest
|
||||
|
||||
from bulkredditdownloader.exceptions import BulkDownloaderException
|
||||
from bulkredditdownloader.oauth2 import OAuth2Authenticator, OAuth2TokenManager
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def example_config() -> configparser.ConfigParser:
|
||||
out = configparser.ConfigParser()
|
||||
config_dict = {'DEFAULT': {'user_token': 'example'}}
|
||||
out.read_dict(config_dict)
|
||||
return out
|
||||
|
||||
|
||||
@pytest.mark.online
|
||||
@pytest.mark.parametrize('test_scopes', (
|
||||
('history',),
|
||||
('history', 'creddits'),
|
||||
('account', 'flair'),
|
||||
('*',),
|
||||
))
|
||||
def test_check_scopes(test_scopes: list[str]):
|
||||
OAuth2Authenticator._check_scopes(test_scopes)
|
||||
|
||||
|
||||
@pytest.mark.online
|
||||
@pytest.mark.parametrize('test_scopes', (
|
||||
('random',),
|
||||
('scope', 'another_scope'),
|
||||
))
|
||||
def test_check_scopes_bad(test_scopes: list[str]):
|
||||
with pytest.raises(BulkDownloaderException):
|
||||
OAuth2Authenticator._check_scopes(test_scopes)
|
||||
|
||||
|
||||
def test_token_manager_read(example_config: configparser.ConfigParser):
|
||||
mock_authoriser = MagicMock()
|
||||
mock_authoriser.refresh_token = None
|
||||
test_manager = OAuth2TokenManager(example_config)
|
||||
test_manager.pre_refresh_callback(mock_authoriser)
|
||||
assert mock_authoriser.refresh_token == example_config.get('DEFAULT', 'user_token')
|
||||
|
||||
|
||||
def test_token_manager_write(example_config: configparser.ConfigParser):
|
||||
mock_authoriser = MagicMock()
|
||||
mock_authoriser.refresh_token = 'changed_token'
|
||||
test_manager = OAuth2TokenManager(example_config)
|
||||
test_manager.post_refresh_callback(mock_authoriser)
|
||||
assert example_config.get('DEFAULT', 'user_token') == 'changed_token'
|
Loading…
Reference in a new issue