
Add option to search for files pre-emptively

Serene-Arc 2021-03-12 13:24:25 +10:00 committed by Ali Parlakci
parent 285d422c0e
commit ae0269e13b
3 changed files with 52 additions and 32 deletions

View file

@@ -74,6 +74,8 @@ class RedditDownloader:
         self._resolve_user_name()
         self.master_hash_list = []
+        if self.args.search_existing:
+            self.master_hash_list.extend(self.scan_existing_files(self.download_directory))
         self.authenticator = self._create_authenticator()
         logger.log(9, 'Created site authenticator')
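
The list seeded here feeds the no_dupes check inside _download_submission below: a resource whose digest is already in master_hash_list is warned about and skipped instead of being written again. A minimal, runnable sketch of that decision follows; MD5 and the helper name are illustrative assumptions, not the commit's own code:

import hashlib
from pathlib import Path


def write_unless_duplicate(destination: Path, content: bytes,
                           master_hash_list: list[str], no_dupes: bool) -> bool:
    # Counterpart of the 'downloaded elsewhere' warning branch below:
    # a digest already in the list means the bytes exist on disk somewhere.
    digest = hashlib.md5(content).hexdigest()
    if no_dupes and digest in master_hash_list:
        return False
    destination.parent.mkdir(parents=True, exist_ok=True)
    destination.write_bytes(content)
    master_hash_list.append(digest)
    return True

With --search-existing, the pre-emptive scan simply pre-populates master_hash_list before any download starts, so files already present from earlier runs hit the skip branch.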
@@ -302,37 +304,39 @@ class RedditDownloader:
             self._download_submission(submission)
 
     def _download_submission(self, submission: praw.models.Submission):
-        if self.download_filter.check_url(submission.url):
-            try:
-                downloader_class = DownloadFactory.pull_lever(submission.url)
-                downloader = downloader_class(submission)
-            except errors.NotADownloadableLinkError as e:
-                logger.error(f'Could not download submission {submission.name}: {e}')
-                return
-            content = downloader.find_resources(self.authenticator)
-            for destination, res in self.file_name_formatter.format_resource_paths(content, self.download_directory):
-                if destination.exists():
-                    logger.warning(f'File already exists: {destination}')
-                else:
-                    res.download()
-                    if res.hash.hexdigest() in self.master_hash_list and self.args.no_dupes:
-                        logger.warning(
-                            f'Resource from "{res.url}" and hash "{res.hash.hexdigest()}" downloaded elsewhere')
-                    else:
-                        # TODO: consider making a hard link/symlink here
-                        destination.parent.mkdir(parents=True, exist_ok=True)
-                        with open(destination, 'wb') as file:
-                            file.write(res.content)
-                        logger.debug(f'Written file to {destination}')
-                        self.master_hash_list.append(res.hash.hexdigest())
-                        logger.debug(f'Hash added to master list: {res.hash.hexdigest()}')
-            logger.info(f'Downloaded submission {submission.name}')
+        if not self.download_filter.check_url(submission.url):
+            logger.debug(f'Download filter removed submission {submission.id} with URL {submission.url}')
+            return
+        try:
+            downloader_class = DownloadFactory.pull_lever(submission.url)
+            downloader = downloader_class(submission)
+        except errors.NotADownloadableLinkError as e:
+            logger.error(f'Could not download submission {submission.name}: {e}')
+            return
+        content = downloader.find_resources(self.authenticator)
+        for destination, res in self.file_name_formatter.format_resource_paths(content, self.download_directory):
+            if destination.exists():
+                logger.warning(f'File already exists: {destination}')
+            else:
+                res.download()
+                if res.hash.hexdigest() in self.master_hash_list and self.args.no_dupes:
+                    logger.warning(
+                        f'Resource from "{res.url}" and hash "{res.hash.hexdigest()}" downloaded elsewhere')
+                else:
+                    # TODO: consider making a hard link/symlink here
+                    destination.parent.mkdir(parents=True, exist_ok=True)
+                    with open(destination, 'wb') as file:
+                        file.write(res.content)
+                    logger.debug(f'Written file to {destination}')
+                    self.master_hash_list.append(res.hash.hexdigest())
+                    logger.debug(f'Hash added to master list: {res.hash.hexdigest()}')
+        logger.info(f'Downloaded submission {submission.name}')
 
-    def scan_existing_files(self) -> list[str]:
+    @staticmethod
+    def scan_existing_files(directory: Path) -> list[str]:
         files = []
-        for (dirpath, dirnames, filenames) in os.walk(self.download_directory):
+        for (dirpath, dirnames, filenames) in os.walk(directory):
             files.extend([Path(dirpath, file) for file in filenames])
         logger.info(f'Calculating hashes for {len(files)} files')
         hash_list = []
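
The hunk above is cut off after hash_list = []; the loop that fills the list is not shown. A plausible completion of the method, assuming each collected file is hashed with MD5 in fixed-size chunks (digest choice and chunk size are guesses, not the commit's verbatim code):

import hashlib
import os
from pathlib import Path


def scan_existing_files(directory: Path) -> list[str]:
    # Mirror of the os.walk collection step shown in the diff above.
    files = []
    for (dirpath, dirnames, filenames) in os.walk(directory):
        files.extend([Path(dirpath, file) for file in filenames])
    # Presumed continuation: hash every file and return the digests.
    hash_list = []
    for existing_file in files:
        md5 = hashlib.md5()
        with open(existing_file, 'rb') as handle:
            # Chunked reads keep memory use flat on large downloads.
            for chunk in iter(lambda: handle.read(1024 * 1024), b''):
                md5.update(chunk)
        hash_list.append(md5.hexdigest())
    return hash_list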

View file

@@ -388,8 +388,7 @@ def test_sanitise_subreddit_name(test_name: str, expected: str):
     assert result == expected
 
 
-def test_search_existing_files(downloader_mock: MagicMock):
-    downloader_mock.download_directory = Path('.').resolve().expanduser()
-    results = RedditDownloader.scan_existing_files(downloader_mock)
+def test_search_existing_files():
+    results = RedditDownloader.scan_existing_files(Path('.'))
    assert all([isinstance(result, str) for result in results])
    assert len(results) >= 40

View file

@@ -73,14 +73,15 @@ def test_cli_download_multireddit(test_args: list[str], tmp_path: Path):
 @pytest.mark.reddit
 @pytest.mark.skipif(Path('test_config.cfg') is False, reason='A test config file is required for integration tests')
 @pytest.mark.parametrize('test_args', (
-    ['--user', 'helen_darten', '-m', 'xxyyzzqwertty', '-L', 10],
+    ['--user', 'helen_darten', '-m', 'xxyyzzqwerty', '-L', 10],
 ))
 def test_cli_download_multireddit_nonexistent(test_args: list[str], tmp_path: Path):
     runner = CliRunner()
     test_args = ['download', str(tmp_path), '-v', '--config', 'test_config.cfg'] + test_args
     result = runner.invoke(cli, test_args)
     assert result.exit_code == 0
-    assert 'Failed to get submissions for multireddit xxyyzzqwerty' in result.output
+    assert 'Failed to get submissions for multireddit' in result.output
     assert 'received 404 HTTP response' in result.output
 
 
 @pytest.mark.online
@@ -117,3 +118,19 @@ def test_cli_download_user_data_bad_me_unauthenticated(test_args: list[str], tmp_path: Path):
     result = runner.invoke(cli, test_args)
     assert result.exit_code == 0
     assert 'To use "me" as a user, an authenticated Reddit instance must be used' in result.output
+
+
+@pytest.mark.online
+@pytest.mark.reddit
+@pytest.mark.authenticated
+@pytest.mark.skipif(Path('test_config.cfg') is False, reason='A test config file is required for integration tests')
+@pytest.mark.parametrize('test_args', (
+    ['--subreddit', 'python', '-L', 10, '--search-existing'],
+))
+def test_cli_download_search_existing(test_args: list[str], tmp_path: Path):
+    Path(tmp_path, 'test.txt').touch()
+    runner = CliRunner()
+    test_args = ['download', str(tmp_path), '-v', '--config', 'test_config.cfg'] + test_args
+    result = runner.invoke(cli, test_args)
+    assert result.exit_code == 0
+    assert 'Calculating hashes for' in result.output