1
0
Fork 0
mirror of synced 2024-06-03 02:44:39 +12:00

Use proper user agent string

This commit is contained in:
OMEGARAZER 2023-03-04 11:45:20 -05:00
parent a16622e11e
commit 1705884dce
No known key found for this signature in database
GPG key ID: D89925310D306E35
3 changed files with 12 additions and 9 deletions

View file

@ -5,9 +5,9 @@ import importlib.resources
import itertools
import logging
import logging.handlers
import platform
import re
import shutil
import socket
from abc import ABCMeta, abstractmethod
from collections.abc import Callable, Iterable, Iterator
from datetime import datetime
@ -21,6 +21,7 @@ import praw.exceptions
import praw.models
import prawcore
from bdfr import __version__
from bdfr import exceptions as errors
from bdfr.configuration import Configuration
from bdfr.download_filter import DownloadFilter
@ -75,6 +76,7 @@ class RedditConnector(metaclass=ABCMeta):
self.file_name_formatter = self.create_file_name_formatter()
logger.log(9, "Create file name formatter")
self.user_agent = praw.const.USER_AGENT_FORMAT.format(":".join([platform.uname()[0], __package__, __version__]))
self.create_reddit_instance()
self.args.user = list(filter(None, [self.resolve_user_name(user) for user in self.args.user]))
@ -138,6 +140,7 @@ class RedditConnector(metaclass=ABCMeta):
scopes,
self.cfg_parser.get("DEFAULT", "client_id"),
self.cfg_parser.get("DEFAULT", "client_secret"),
user_agent=self.user_agent,
)
token = oauth2_authenticator.retrieve_new_token()
self.cfg_parser["DEFAULT"]["user_token"] = token
@ -149,7 +152,7 @@ class RedditConnector(metaclass=ABCMeta):
self.reddit_instance = praw.Reddit(
client_id=self.cfg_parser.get("DEFAULT", "client_id"),
client_secret=self.cfg_parser.get("DEFAULT", "client_secret"),
user_agent=socket.gethostname(),
user_agent=self.user_agent,
token_manager=token_manager,
)
else:
@ -158,7 +161,7 @@ class RedditConnector(metaclass=ABCMeta):
self.reddit_instance = praw.Reddit(
client_id=self.cfg_parser.get("DEFAULT", "client_id"),
client_secret=self.cfg_parser.get("DEFAULT", "client_secret"),
user_agent=socket.gethostname(),
user_agent=self.user_agent,
)
def retrieve_reddit_lists(self) -> list[praw.models.ListingGenerator]:

View file

@ -16,18 +16,18 @@ logger = logging.getLogger(__name__)
class OAuth2Authenticator:
def __init__(self, wanted_scopes: set[str], client_id: str, client_secret: str):
self._check_scopes(wanted_scopes)
def __init__(self, wanted_scopes: set[str], client_id: str, client_secret: str, user_agent: str):
self._check_scopes(wanted_scopes, user_agent)
self.scopes = wanted_scopes
self.client_id = client_id
self.client_secret = client_secret
@staticmethod
def _check_scopes(wanted_scopes: set[str]):
def _check_scopes(wanted_scopes: set[str], user_agent: str):
try:
response = requests.get(
"https://www.reddit.com/api/v1/scopes.json",
headers={"User-Agent": "fetch-scopes test"},
headers={"User-Agent": user_agent},
timeout=10,
)
except TimeoutError:

View file

@ -33,7 +33,7 @@ def example_config() -> configparser.ConfigParser:
),
)
def test_check_scopes(test_scopes: set[str]):
OAuth2Authenticator._check_scopes(test_scopes)
OAuth2Authenticator._check_scopes(test_scopes, "fetch-scopes test")
@pytest.mark.parametrize(
@ -67,7 +67,7 @@ def test_split_scopes(test_scopes: str, expected: set[str]):
)
def test_check_scopes_bad(test_scopes: set[str]):
with pytest.raises(BulkDownloaderException):
OAuth2Authenticator._check_scopes(test_scopes)
OAuth2Authenticator._check_scopes(test_scopes, "fetch-scopes test")
def test_token_manager_read(example_config: configparser.ConfigParser):