diff --git a/bdfr/archive_entry/base_archive_entry.py b/bdfr/archive_entry/base_archive_entry.py index a33381e..57e36f8 100644 --- a/bdfr/archive_entry/base_archive_entry.py +++ b/bdfr/archive_entry/base_archive_entry.py @@ -2,12 +2,13 @@ # coding=utf-8 from abc import ABC, abstractmethod +from typing import Union from praw.models import Comment, Submission class BaseArchiveEntry(ABC): - def __init__(self, source: (Comment, Submission)): + def __init__(self, source: Union[Comment, Submission]): self.source = source self.post_details: dict = {} diff --git a/bdfr/archiver.py b/bdfr/archiver.py index 4bd24f5..809af96 100644 --- a/bdfr/archiver.py +++ b/bdfr/archiver.py @@ -4,7 +4,7 @@ import json import logging import re -from typing import Iterator +from typing import Iterator, Union import dict2xml import praw.models @@ -65,7 +65,7 @@ class Archiver(RedditConnector): return results @staticmethod - def _pull_lever_entry_factory(praw_item: (praw.models.Submission, praw.models.Comment)) -> BaseArchiveEntry: + def _pull_lever_entry_factory(praw_item: Union[praw.models.Submission, praw.models.Comment]) -> BaseArchiveEntry: if isinstance(praw_item, praw.models.Submission): return SubmissionArchiveEntry(praw_item) elif isinstance(praw_item, praw.models.Comment): @@ -73,7 +73,7 @@ class Archiver(RedditConnector): else: raise ArchiverError(f'Factory failed to classify item of type {type(praw_item).__name__}') - def write_entry(self, praw_item: (praw.models.Submission, praw.models.Comment)): + def write_entry(self, praw_item: Union[praw.models.Submission, praw.models.Comment]): if self.args.comment_context and isinstance(praw_item, praw.models.Comment): logger.debug(f'Converting comment {praw_item.id} to submission {praw_item.submission.id}') praw_item = praw_item.submission diff --git a/bdfr/file_name_formatter.py b/bdfr/file_name_formatter.py index 70bd527..4a039c9 100644 --- a/bdfr/file_name_formatter.py +++ b/bdfr/file_name_formatter.py @@ -6,7 +6,7 @@ import platform import re import subprocess from pathlib import Path -from typing import Optional +from typing import Optional, Union from praw.models import Comment, Submission @@ -34,7 +34,7 @@ class FileNameFormatter: self.directory_format_string: list[str] = directory_format_string.split('/') self.time_format_string = time_format_string - def _format_name(self, submission: (Comment, Submission), format_string: str) -> str: + def _format_name(self, submission: Union[Comment, Submission], format_string: str) -> str: if isinstance(submission, Submission): attributes = self._generate_name_dict_from_submission(submission) elif isinstance(submission, Comment): diff --git a/tests/test_file_name_formatter.py b/tests/test_file_name_formatter.py index e7f1ebe..0492536 100644 --- a/tests/test_file_name_formatter.py +++ b/tests/test_file_name_formatter.py @@ -6,7 +6,7 @@ import sys import unittest.mock from datetime import datetime from pathlib import Path -from typing import Optional +from typing import Optional, Type, Union from unittest.mock import MagicMock import praw.models @@ -33,7 +33,7 @@ def submission() -> MagicMock: return test -def do_test_string_equality(result: [Path, str], expected: str) -> bool: +def do_test_string_equality(result: Union[Path, str], expected: str) -> bool: if platform.system() == 'Windows': expected = FileNameFormatter._format_for_windows(expected) return str(result).endswith(expected) @@ -411,7 +411,7 @@ def test_windows_max_path(tmp_path: Path): )) def test_name_submission( test_reddit_id: str, - test_downloader: type(BaseDownloader), + test_downloader: Type[BaseDownloader], expected_names: set[str], reddit_instance: praw.reddit.Reddit, ):