diff --git a/bdfr/file_name_formatter.py b/bdfr/file_name_formatter.py index 9ee481d..0330336 100644 --- a/bdfr/file_name_formatter.py +++ b/bdfr/file_name_formatter.py @@ -28,12 +28,19 @@ class FileNameFormatter: "upvotes", ) - def __init__(self, file_format_string: str, directory_format_string: str, time_format_string: str): + def __init__( + self, + file_format_string: str, + directory_format_string: str, + time_format_string: str, + restriction_scheme: Optional[str] = None, + ): if not self.validate_string(file_format_string): raise BulkDownloaderException(f'"{file_format_string}" is not a valid format string') self.file_format_string = file_format_string self.directory_format_string: list[str] = directory_format_string.split("/") self.time_format_string = time_format_string + self.restiction_scheme = restriction_scheme.lower().strip() if restriction_scheme else None def _format_name(self, submission: Union[Comment, Submission], format_string: str) -> str: if isinstance(submission, Submission): @@ -52,9 +59,11 @@ class FileNameFormatter: result = result.replace("/", "") - if platform.system() == "Windows": + if self.restiction_scheme is None: + if platform.system() == "Windows": + result = FileNameFormatter._format_for_windows(result) + elif self.restiction_scheme == "windows": result = FileNameFormatter._format_for_windows(result) - return result @staticmethod diff --git a/tests/test_file_name_formatter.py b/tests/test_file_name_formatter.py index 4964b3b..5f94e5f 100644 --- a/tests/test_file_name_formatter.py +++ b/tests/test_file_name_formatter.py @@ -33,6 +33,10 @@ def submission() -> MagicMock: return test +def check_valid_windows_path(test_string: str): + return test_string == FileNameFormatter._format_for_windows(test_string) + + def do_test_string_equality(result: Union[Path, str], expected: str) -> bool: if platform.system() == "Windows": expected = FileNameFormatter._format_for_windows(expected) @@ -91,6 +95,15 @@ def test_check_format_string_validity(test_string: str, expected: bool): @pytest.mark.online @pytest.mark.reddit +@pytest.mark.parametrize( + "restriction_scheme", + ( + "windows", + "linux", + "bla", + None, + ), +) @pytest.mark.parametrize( ("test_format_string", "expected"), ( @@ -102,10 +115,17 @@ def test_check_format_string_validity(test_string: str, expected: bool): ("{REDDITOR}_{TITLE}_{POSTID}", "Kirsty-Blue_George Russel acknowledges the Twitter trend about him_w22m5l"), ), ) -def test_format_name_real(test_format_string: str, expected: str, reddit_submission: praw.models.Submission): - test_formatter = FileNameFormatter(test_format_string, "", "") +def test_format_name_real( + test_format_string: str, + expected: str, + reddit_submission: praw.models.Submission, + restriction_scheme: Optional[str], +): + test_formatter = FileNameFormatter(test_format_string, "", "", restriction_scheme) result = test_formatter._format_name(reddit_submission, test_format_string) assert do_test_string_equality(result, expected) + if restriction_scheme == "windows": + assert check_valid_windows_path(result) @pytest.mark.online