From 0007912ad57df399a7120fa34d37deccbd883c0e Mon Sep 17 00:00:00 2001 From: Serene-Arc Date: Mon, 22 Mar 2021 14:21:56 +1000 Subject: [PATCH] Scrub windows paths for invalid characters --- bulkredditdownloader/file_name_formatter.py | 15 ++++++++++++++- .../tests/test_file_name_formatter.py | 13 +++++++++++++ 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/bulkredditdownloader/file_name_formatter.py b/bulkredditdownloader/file_name_formatter.py index f36284e..dbb8d08 100644 --- a/bulkredditdownloader/file_name_formatter.py +++ b/bulkredditdownloader/file_name_formatter.py @@ -2,6 +2,7 @@ # coding=utf-8 import logging +import platform import re from pathlib import Path from typing import Optional @@ -41,6 +42,10 @@ class FileNameFormatter: logger.log(9, f'Found key string {key} in name') result = result.replace('/', '') + + if platform.system() == 'Windows': + result = FileNameFormatter._format_for_windows(result) + return result def format_path(self, resource: Resource, destination_directory: Path, index: Optional[int] = None) -> Path: @@ -51,6 +56,7 @@ class FileNameFormatter: ending = index + resource.extension file_name = str(self._format_name(resource.source_submission, self.file_format_string)) file_name = self._limit_file_name_length(file_name, ending) + try: file_path = Path(subfolder, file_name) except TypeError: @@ -76,8 +82,15 @@ class FileNameFormatter: out.append((self.format_path(res, destination_directory, i), res)) return out - @ staticmethod + @staticmethod def validate_string(test_string: str) -> bool: if not test_string: return False return any([f'{{{key}}}' in test_string.lower() for key in FileNameFormatter.key_terms]) + + @staticmethod + def _format_for_windows(input_string: str) -> str: + invalid_characters = r'<>:"\/|?*' + for char in invalid_characters: + input_string = input_string.replace(char, '') + return input_string diff --git a/bulkredditdownloader/tests/test_file_name_formatter.py b/bulkredditdownloader/tests/test_file_name_formatter.py index b376e9d..55909df 100644 --- a/bulkredditdownloader/tests/test_file_name_formatter.py +++ b/bulkredditdownloader/tests/test_file_name_formatter.py @@ -169,3 +169,16 @@ def test_shorten_filenames(reddit_instance: praw.Reddit, tmp_path: Path): result = test_formatter.format_path(test_resource, tmp_path) result.parent.mkdir(parents=True) result.touch() + + +@pytest.mark.parametrize(('test_string', 'expected'), ( + ('test', 'test'), + ('test.png', 'test.png'), + ('test*', 'test'), + ('test**', 'test'), + ('test?*', 'test'), + ('test_???.png', 'test_.png'), +)) +def test_format_file_name_for_windows(test_string: str, expected: str): + result = FileNameFormatter._format_for_windows(test_string) + assert result == expected