1
0
Fork 0
mirror of synced 2024-05-19 19:52:41 +12:00
bulk-downloader-for-reddit/bdfr/oauth2.py

109 lines
4.1 KiB
Python
Raw Normal View History

2021-03-08 14:37:01 +13:00
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
2021-03-08 14:37:01 +13:00
import configparser
import logging
import random
2021-03-08 15:34:03 +13:00
import re
2021-03-08 14:37:01 +13:00
import socket
from pathlib import Path
2021-03-08 14:37:01 +13:00
import praw
import requests
2021-04-12 19:58:32 +12:00
from bdfr.exceptions import BulkDownloaderException, RedditAuthenticationError
2021-03-08 14:37:01 +13:00
# Module-level logger named after this module; used by the OAuth2 classes in this file.
logger = logging.getLogger(__name__)
class OAuth2Authenticator:
    """Interactive OAuth2 helper that obtains a Reddit refresh token.

    The requested scopes are validated against Reddit's published scope list
    on construction. ``retrieve_new_token`` then logs an authorisation URL,
    waits for the browser redirect on TCP port 7634 and exchanges the returned
    code for a token through praw.
    """

    def __init__(self, wanted_scopes: set[str], client_id: str, client_secret: str):
        """Validate ``wanted_scopes`` and remember the client credentials.

        Args:
            wanted_scopes: OAuth2 scopes to request.
            client_id: Reddit application client id.
            client_secret: Reddit application client secret.

        Raises:
            BulkDownloaderException: If a scope is not known to Reddit or the
                scope list could not be retrieved.
        """
        self._check_scopes(wanted_scopes)
        self.scopes = wanted_scopes
        self.client_id = client_id
        self.client_secret = client_secret

    @staticmethod
    def _check_scopes(wanted_scopes: set[str]):
        """Ensure every wanted scope appears in Reddit's scope list.

        Raises:
            BulkDownloaderException: If Reddit answers with an error status or
                if any scope in ``wanted_scopes`` is unknown.
        """
        response = requests.get(
            "https://www.reddit.com/api/v1/scopes.json",
            headers={"User-Agent": "fetch-scopes test"},
            # Without a timeout an unresponsive server would block start-up forever.
            timeout=10,
        )
        # An error response may still carry a JSON body; its keys must not be
        # mistaken for scope names.
        if not response.ok:
            raise BulkDownloaderException(
                f"Could not retrieve the list of scopes from reddit: HTTP {response.status_code}"
            )
        known_scopes = set(response.json())
        known_scopes.add("*")
        for scope in wanted_scopes:
            if scope not in known_scopes:
                raise BulkDownloaderException(f"Scope {scope} is not known to reddit")

    @staticmethod
    def split_scopes(scopes: str) -> set[str]:
        """Split a scope string on commas, colons and spaces.

        Empty fragments produced by leading, trailing or repeated separators
        are dropped so they are never treated as a scope name.
        """
        return {scope for scope in re.split(r"[,: ]+", scopes) if scope}

    @staticmethod
    def _parse_redirect_params(request: str) -> dict[str, str]:
        """Extract the URL-decoded query parameters from a raw HTTP request.

        Only the request line (``METHOD target VERSION``) is inspected. An
        empty dict is returned when the request line is malformed or the
        target carries no query string. When a key is repeated, the last
        value wins.
        """
        from urllib.parse import parse_qsl

        request_line = request.split("\r\n", 1)[0]
        parts = request_line.split(" ")
        if len(parts) < 2:
            return {}
        query = parts[1].partition("?")[2]
        return dict(parse_qsl(query, keep_blank_values=True))

    def retrieve_new_token(self) -> str:
        """Run the interactive authorisation flow and return the refresh token.

        Raises:
            RedditAuthenticationError: If the redirect carries a different
                ``state`` than the one issued, reports an error, or does not
                contain an authorisation code.
        """
        import secrets

        reddit = praw.Reddit(
            redirect_uri="http://localhost:7634",
            user_agent="obtain_refresh_token for BDFR",
            client_id=self.client_id,
            client_secret=self.client_secret,
        )
        # The state value guards against forged redirects, so it has to be
        # unpredictable; the random module is not suitable for that.
        state = secrets.token_hex(16)
        url = reddit.auth.url(self.scopes, state, "permanent")
        logger.warning("Authentication action required before the program can proceed")
        logger.warning(f"Authenticate at {url}")

        client = self.receive_connection()
        try:
            data = client.recv(1024).decode("utf-8", errors="replace")
        except OSError:
            # Do not leak the accepted connection when reading fails.
            client.close()
            raise
        params = self._parse_redirect_params(data)

        received_state = params.get("state")
        if state != received_state:
            self.send_message(client)
            raise RedditAuthenticationError(
                f"State mismatch in OAuth2. Expected: {state} Received: {received_state}"
            )
        elif "error" in params:
            self.send_message(client)
            raise RedditAuthenticationError(f'Error in OAuth2: {params["error"]}')
        elif "code" not in params:
            self.send_message(client)
            raise RedditAuthenticationError("Error in OAuth2: no authorisation code received")

        self.send_message(client, "<script>alert('You can go back to terminal window now.')</script>")
        refresh_token = reddit.auth.authorize(params["code"])
        return refresh_token

    @staticmethod
    def receive_connection() -> socket.socket:
        """Listen on port 7634 and return the first accepted client socket.

        The listening socket is closed before returning, including when
        binding or accepting fails.
        """
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server:
            server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            server.bind(("0.0.0.0", 7634))
            logger.log(9, "Server listening on 0.0.0.0:7634")
            server.listen(1)
            client = server.accept()[0]
        logger.log(9, "Server closed")
        return client

    @staticmethod
    def send_message(client: socket.socket, message: str = ""):
        """Send a minimal HTTP 200 response with ``message`` as body, then close ``client``."""
        try:
            # sendall keeps writing until the whole payload is out; send may stop early.
            client.sendall(f"HTTP/1.1 200 OK\r\n\r\n{message}".encode("utf-8"))
        finally:
            client.close()
class OAuth2TokenManager(praw.reddit.BaseTokenManager):
    """Token manager that keeps the refresh token in a configparser file.

    The token is stored under the ``user_token`` option of the ``DEFAULT``
    section and written back to ``config_location`` after every refresh.
    """

    def __init__(self, config: configparser.ConfigParser, config_location: Path):
        """Store the parsed configuration and the path it is persisted to."""
        super().__init__()
        self.config = config
        self.config_location = config_location

    def pre_refresh_callback(self, authorizer: praw.reddit.Authorizer):
        """Give the authorizer the stored refresh token when it has none.

        Raises:
            RedditAuthenticationError: If the authorizer has no token and the
                configuration has no ``user_token`` option either.
        """
        if authorizer.refresh_token is not None:
            return
        if not self.config.has_option("DEFAULT", "user_token"):
            raise RedditAuthenticationError("No auth token loaded in configuration")
        authorizer.refresh_token = self.config.get("DEFAULT", "user_token")
        logger.log(9, "Loaded OAuth2 token for authoriser")

    def post_refresh_callback(self, authorizer: praw.reddit.Authorizer):
        """Record the authorizer's refresh token and rewrite the config file."""
        self.config.set("DEFAULT", "user_token", authorizer.refresh_token)
        destination = Path(self.config_location)
        with destination.open(mode="w") as handle:
            self.config.write(handle, True)
        logger.log(9, f"Written OAuth2 token from authoriser to {self.config_location}")