1
0
Fork 0
mirror of synced 2024-05-20 20:22:43 +12:00

Implement callbacks for downloading

This commit is contained in:
Serene-Arc 2021-07-27 13:39:49 +10:00
parent 44453b1707
commit 3cdae99490
23 changed files with 112 additions and 92 deletions

View file

@ -76,17 +76,17 @@ class Archiver(RedditConnector):
logger.info(f'Record for entry item {praw_item.id} written to disk')
def _write_entry_json(self, entry: BaseArchiveEntry):
resource = Resource(entry.source, '', '.json')
resource = Resource(entry.source, '', lambda: None, '.json')
content = json.dumps(entry.compile())
self._write_content_to_disk(resource, content)
def _write_entry_xml(self, entry: BaseArchiveEntry):
resource = Resource(entry.source, '', '.xml')
resource = Resource(entry.source, '', lambda: None, '.xml')
content = dict2xml.dict2xml(entry.compile(), wrap='root')
self._write_content_to_disk(resource, content)
def _write_entry_yaml(self, entry: BaseArchiveEntry):
resource = Resource(entry.source, '', '.yaml')
resource = Resource(entry.source, '', lambda: None, '.yaml')
content = yaml.dump(entry.compile())
self._write_content_to_disk(resource, content)

View file

@ -82,7 +82,7 @@ class RedditDownloader(RedditConnector):
logger.debug(f'Download filter removed {submission.id} file with URL {submission.url}')
continue
try:
res.download(self.args.max_wait_time)
res.download()
except errors.BulkDownloaderException as e:
logger.error(f'Failed to download resource {res.url} in submission {submission.id} '
f'with downloader {downloader_class.__name__}: {e}')

View file

@ -6,7 +6,7 @@ import logging
import re
import time
import urllib.parse
from typing import Optional
from typing import Callable, Optional
import _hashlib
import requests
@ -18,40 +18,44 @@ logger = logging.getLogger(__name__)
class Resource:
def __init__(self, source_submission: Submission, url: str, extension: str = None):
def __init__(self, source_submission: Submission, url: str, download_function: Callable, extension: str = None):
self.source_submission = source_submission
self.content: Optional[bytes] = None
self.url = url
self.hash: Optional[_hashlib.HASH] = None
self.extension = extension
self.download_function = download_function
if not self.extension:
self.extension = self._determine_extension()
@staticmethod
def retry_download(url: str, max_wait_time: int, current_wait_time: int = 60) -> Optional[bytes]:
try:
response = requests.get(url)
if re.match(r'^2\d{2}', str(response.status_code)) and response.content:
return response.content
elif response.status_code in (408, 429):
raise requests.exceptions.ConnectionError(f'Response code {response.status_code}')
else:
raise BulkDownloaderException(
f'Unrecoverable error requesting resource: HTTP Code {response.status_code}')
except (requests.exceptions.ConnectionError, requests.exceptions.ChunkedEncodingError) as e:
logger.warning(f'Error occured downloading from {url}, waiting {current_wait_time} seconds: {e}')
time.sleep(current_wait_time)
if current_wait_time < max_wait_time:
current_wait_time += 60
return Resource.retry_download(url, max_wait_time, current_wait_time)
else:
logger.error(f'Max wait time exceeded for resource at url {url}')
raise
def retry_download(url: str, max_wait_time: int) -> Callable:
    """Build a zero-argument callable that downloads ``url`` with retries.

    The returned function performs ``requests.get(url)`` and returns the
    response body when the status is 2xx and the body is non-empty.
    Connection errors, chunked-encoding errors and HTTP 408/429 responses are
    treated as transient: the function pauses and tries again, starting with
    a 60 second pause that grows by 60 seconds after each failed attempt.

    Args:
        url: Address of the resource to fetch.
        max_wait_time: Longest single pause, in seconds, allowed between
            attempts. Once the next pause would exceed it, the last error is
            re-raised.

    Returns:
        A callable taking no arguments that returns the downloaded bytes.

    The returned callable raises:
        BulkDownloaderException: the response was neither a usable 2xx nor a
            retryable 408/429 (this includes a 2xx with an empty body).
        requests.exceptions.ConnectionError,
        requests.exceptions.ChunkedEncodingError: the retries were exhausted.
    """
    def http_download() -> Optional[bytes]:
        current_wait_time = 60
        while True:
            try:
                response = requests.get(url)
                if 200 <= response.status_code < 300 and response.content:
                    return response.content
                elif response.status_code in (408, 429):
                    # Funnel "try again later" statuses into the retry path below.
                    raise requests.exceptions.ConnectionError(f'Response code {response.status_code}')
                else:
                    raise BulkDownloaderException(
                        f'Unrecoverable error requesting resource: HTTP Code {response.status_code}')
            except (requests.exceptions.ConnectionError, requests.exceptions.ChunkedEncodingError) as e:
                # Decide *before* sleeping: pausing and then giving up without
                # another attempt would only waste the longest wait.
                if current_wait_time > max_wait_time:
                    logger.error(f'Max wait time exceeded for resource at url {url}')
                    raise
                logger.warning(f'Error occured downloading from {url}, waiting {current_wait_time} seconds: {e}')
                time.sleep(current_wait_time)
                current_wait_time += 60
    return http_download
def download(self, max_wait_time: int):
def download(self):
if not self.content:
try:
content = self.retry_download(self.url, max_wait_time)
content = self.download_function()
except requests.exceptions.ConnectionError as e:
raise BulkDownloaderException(f'Could not download resource: {e}')
except BulkDownloaderException:

View file

@ -14,4 +14,4 @@ class Direct(BaseDownloader):
super().__init__(post)
def find_resources(self, authenticator: Optional[SiteAuthenticator] = None) -> list[Resource]:
return [Resource(self.post, self.post.url)]
return [Resource(self.post, self.post.url, Resource.retry_download(self.post.url, 300))]

View file

@ -29,7 +29,7 @@ class Erome(BaseDownloader):
for link in links:
if not re.match(r'https?://.*', link):
link = 'https://' + link
out.append(Resource(self.post, link))
out.append(Resource(self.post, link, Resource.retry_download(link, 300)))
return out
@staticmethod

View file

@ -4,7 +4,6 @@
import logging
from typing import Optional
import youtube_dl
from praw.models import Submission
from bdfr.resource import Resource
@ -20,21 +19,18 @@ class YoutubeDlFallback(BaseFallbackDownloader, Youtube):
super(YoutubeDlFallback, self).__init__(post)
def find_resources(self, authenticator: Optional[SiteAuthenticator] = None) -> list[Resource]:
out = super()._download_video({})
out = Resource(
self.post,
self.post.url,
super()._download_video({}),
super().get_video_attributes(self.post.url)['ext'],
)
return [out]
@staticmethod
def can_handle_link(url: str) -> bool:
yt_logger = logging.getLogger('youtube-dl')
yt_logger.setLevel(logging.CRITICAL)
with youtube_dl.YoutubeDL({
'logger': yt_logger,
}) as ydl:
try:
result = ydl.extract_info(url, download=False)
if result:
return True
except Exception as e:
logger.exception(e)
return False
return False
attributes = YoutubeDlFallback.get_video_attributes(url)
if attributes:
return True
else:
return False

View file

@ -31,7 +31,7 @@ class Gallery(BaseDownloader):
if not image_urls:
raise SiteDownloaderError('No images found in Reddit gallery')
return [Resource(self.post, url) for url in image_urls]
return [Resource(self.post, url, Resource.retry_download(url, 300)) for url in image_urls]
@ staticmethod
def _get_links(id_dict: list[dict]) -> list[str]:

View file

@ -33,7 +33,7 @@ class Imgur(BaseDownloader):
def _compute_image_url(self, image: dict) -> Resource:
image_url = 'https://i.imgur.com/' + image['hash'] + self._validate_extension(image['ext'])
return Resource(self.post, image_url)
return Resource(self.post, image_url, Resource.retry_download(image_url, 300))
@staticmethod
def _get_data(link: str) -> dict:

View file

@ -22,5 +22,10 @@ class PornHub(Youtube):
'format': 'best',
'nooverwrites': True,
}
out = self._download_video(ytdl_options)
out = Resource(
self.post,
self.post.url,
super()._download_video(ytdl_options),
super().get_video_attributes(self.post.url)['ext'],
)
return [out]

View file

@ -18,7 +18,7 @@ class Redgifs(BaseDownloader):
def find_resources(self, authenticator: Optional[SiteAuthenticator] = None) -> list[Resource]:
media_url = self._get_link(self.post.url)
return [Resource(self.post, media_url, '.mp4')]
return [Resource(self.post, media_url, Resource.retry_download(media_url, 300), '.mp4')]
@staticmethod
def _get_link(url: str) -> str:

View file

@ -17,7 +17,7 @@ class SelfPost(BaseDownloader):
super().__init__(post)
def find_resources(self, authenticator: Optional[SiteAuthenticator] = None) -> list[Resource]:
out = Resource(self.post, self.post.url, '.txt')
out = Resource(self.post, self.post.url, lambda: None, '.txt')
out.content = self.export_to_string().encode('utf-8')
out.create_hash()
return [out]

View file

@ -3,12 +3,12 @@
import logging
import tempfile
from pathlib import Path
from typing import Optional
from typing import Callable, Optional
import youtube_dl
from praw.models import Submission
from bdfr.exceptions import (NotADownloadableLinkError, SiteDownloaderError)
from bdfr.exceptions import NotADownloadableLinkError, SiteDownloaderError
from bdfr.resource import Resource
from bdfr.site_authenticator import SiteAuthenticator
from bdfr.site_downloaders.base_downloader import BaseDownloader
@ -26,32 +26,47 @@ class Youtube(BaseDownloader):
'playlistend': 1,
'nooverwrites': True,
}
out = self._download_video(ytdl_options)
return [out]
download_function = self._download_video(ytdl_options)
try:
extension = self.get_video_attributes(self.post.url)['ext']
except KeyError:
raise NotADownloadableLinkError(f'Youtube-DL cannot download URL {self.post.url}')
res = Resource(self.post, self.post.url, download_function, extension)
return [res]
def _download_video(self, ytdl_options: dict) -> Resource:
def _download_video(self, ytdl_options: dict) -> Callable:
yt_logger = logging.getLogger('youtube-dl')
yt_logger.setLevel(logging.CRITICAL)
ytdl_options['quiet'] = True
ytdl_options['logger'] = yt_logger
with tempfile.TemporaryDirectory() as temp_dir:
download_path = Path(temp_dir).resolve()
ytdl_options['outtmpl'] = str(download_path) + '/' + 'test.%(ext)s'
try:
with youtube_dl.YoutubeDL(ytdl_options) as ydl:
ydl.download([self.post.url])
except youtube_dl.DownloadError as e:
raise SiteDownloaderError(f'Youtube download failed: {e}')
downloaded_files = list(download_path.iterdir())
if len(downloaded_files) > 0:
downloaded_file = downloaded_files[0]
else:
raise NotADownloadableLinkError(f"No media exists in the URL {self.post.url}")
extension = downloaded_file.suffix
with open(downloaded_file, 'rb') as file:
content = file.read()
out = Resource(self.post, self.post.url, extension)
out.content = content
out.create_hash()
return out
def download() -> bytes:
    """Fetch the post's media with youtube-dl and return the file's bytes.

    The media is downloaded into a temporary directory that is removed
    before this function returns; only the content of the first file found
    there is kept.

    Returns:
        The raw bytes of the downloaded file.

    Raises:
        SiteDownloaderError: youtube-dl reported a download error.
        NotADownloadableLinkError: youtube-dl produced no file at all.
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        download_path = Path(temp_dir).resolve()
        # Work on a per-call copy so the captured options never retain a
        # path into a temporary directory that no longer exists.
        options = {**ytdl_options, 'outtmpl': str(download_path / 'test.%(ext)s')}
        try:
            with youtube_dl.YoutubeDL(options) as ydl:
                ydl.download([self.post.url])
        except youtube_dl.DownloadError as e:
            raise SiteDownloaderError(f'Youtube download failed: {e}') from e

        downloaded_files = list(download_path.iterdir())
        if not downloaded_files:
            raise NotADownloadableLinkError(f"No media exists in the URL {self.post.url}")
        # Read while the temporary directory still exists.
        return downloaded_files[0].read_bytes()
return download
@staticmethod
def get_video_attributes(url: str) -> dict:
    """Return the metadata youtube-dl extracts for ``url`` without downloading.

    Args:
        url: Address of the page or media to inspect.

    Returns:
        The info dictionary produced by ``YoutubeDL.extract_info``. An empty
        dict is returned when extraction raises or yields nothing, so the
        result is always a ``dict``: truthiness checks still read it as
        "no attributes", and a lookup such as ``['ext']`` fails with
        ``KeyError`` rather than ``TypeError`` on ``None``.
    """
    yt_logger = logging.getLogger('youtube-dl')
    yt_logger.setLevel(logging.CRITICAL)
    with youtube_dl.YoutubeDL({'logger': yt_logger, }) as ydl:
        try:
            result = ydl.extract_info(url, download=False)
        except Exception as e:
            # Best-effort probe: record the failure and report "nothing found".
            logger.exception(e)
            return {}
    return result or {}

View file

@ -21,5 +21,5 @@ def test_download_resource(test_url: str, expected_hash: str):
resources = test_site.find_resources()
assert len(resources) == 1
assert isinstance(resources[0], Resource)
resources[0].download(120)
resources[0].download()
assert resources[0].hash.hexdigest() == expected_hash

View file

@ -49,6 +49,6 @@ def test_download_resource(test_url: str, expected_hashes: tuple[str]):
mock_submission.url = test_url
test_site = Erome(mock_submission)
resources = test_site.find_resources()
[res.download(120) for res in resources]
[res.download() for res in resources]
resource_hashes = [res.hash.hexdigest() for res in resources]
assert len(resource_hashes) == len(expected_hashes)

View file

@ -62,7 +62,7 @@ def test_gallery_download(test_submission_id: str, expected_hashes: set[str], re
test_submission = reddit_instance.submission(id=test_submission_id)
gallery = Gallery(test_submission)
results = gallery.find_resources()
[res.download(120) for res in results]
[res.download() for res in results]
hashes = [res.hash.hexdigest() for res in results]
assert set(hashes) == expected_hashes

View file

@ -31,5 +31,5 @@ def test_download_resource(test_url: str, expected_hash: str):
resources = test_site.find_resources()
assert len(resources) == 1
assert isinstance(resources[0], Resource)
resources[0].download(120)
resources[0].download()
assert resources[0].hash.hexdigest() == expected_hash

View file

@ -149,6 +149,6 @@ def test_find_resources(test_url: str, expected_hashes: list[str]):
downloader = Imgur(mock_download)
results = downloader.find_resources()
assert all([isinstance(res, Resource) for res in results])
[res.download(120) for res in results]
[res.download() for res in results]
hashes = set([res.hash.hexdigest() for res in results])
assert hashes == set(expected_hashes)

View file

@ -21,5 +21,5 @@ def test_find_resources_good(test_url: str, expected_hash: str):
resources = downloader.find_resources()
assert len(resources) == 1
assert isinstance(resources[0], Resource)
resources[0].download(120)
resources[0].download()
assert resources[0].hash.hexdigest() == expected_hash

View file

@ -37,5 +37,5 @@ def test_download_resource(test_url: str, expected_hash: str):
resources = test_site.find_resources()
assert len(resources) == 1
assert isinstance(resources[0], Resource)
resources[0].download(120)
resources[0].download()
assert resources[0].hash.hexdigest() == expected_hash

View file

@ -23,7 +23,7 @@ def test_find_resources_good(test_url: str, expected_hash: str):
resources = downloader.find_resources()
assert len(resources) == 1
assert isinstance(resources[0], Resource)
resources[0].download(120)
resources[0].download()
assert resources[0].hash.hexdigest() == expected_hash

View file

@ -46,7 +46,7 @@ def test_filter_domain(test_url: str, expected: bool, download_filter: DownloadF
('http://reddit.com/test.gif', False),
))
def test_filter_all(test_url: str, expected: bool, download_filter: DownloadFilter):
test_resource = Resource(MagicMock(), test_url)
test_resource = Resource(MagicMock(), test_url, lambda: None)
result = download_filter.check_resource(test_resource)
assert result == expected
@ -59,6 +59,6 @@ def test_filter_all(test_url: str, expected: bool, download_filter: DownloadFilt
))
def test_filter_empty_filter(test_url: str):
download_filter = DownloadFilter()
test_resource = Resource(MagicMock(), test_url)
test_resource = Resource(MagicMock(), test_url, lambda: None)
result = download_filter.check_resource(test_resource)
assert result is True

View file

@ -119,7 +119,7 @@ def test_format_full(
format_string_file: str,
expected: str,
reddit_submission: praw.models.Submission):
test_resource = Resource(reddit_submission, 'i.reddit.com/blabla.png')
test_resource = Resource(reddit_submission, 'i.reddit.com/blabla.png', lambda: None)
test_formatter = FileNameFormatter(format_string_file, format_string_directory, 'ISO')
result = test_formatter.format_path(test_resource, Path('test'))
assert do_test_path_equality(result, expected)
@ -136,7 +136,7 @@ def test_format_full_conform(
format_string_directory: str,
format_string_file: str,
reddit_submission: praw.models.Submission):
test_resource = Resource(reddit_submission, 'i.reddit.com/blabla.png')
test_resource = Resource(reddit_submission, 'i.reddit.com/blabla.png', lambda: None)
test_formatter = FileNameFormatter(format_string_file, format_string_directory, 'ISO')
test_formatter.format_path(test_resource, Path('test'))
@ -156,7 +156,7 @@ def test_format_full_with_index_suffix(
expected: str,
reddit_submission: praw.models.Submission,
):
test_resource = Resource(reddit_submission, 'i.reddit.com/blabla.png')
test_resource = Resource(reddit_submission, 'i.reddit.com/blabla.png', lambda: None)
test_formatter = FileNameFormatter(format_string_file, format_string_directory, 'ISO')
result = test_formatter.format_path(test_resource, Path('test'), index)
assert do_test_path_equality(result, expected)
@ -216,7 +216,7 @@ def test_shorten_filenames(submission: MagicMock, tmp_path: Path):
submission.author.name = 'test'
submission.subreddit.display_name = 'test'
submission.id = 'BBBBBB'
test_resource = Resource(submission, 'www.example.com/empty', '.jpeg')
test_resource = Resource(submission, 'www.example.com/empty', lambda: None, '.jpeg')
test_formatter = FileNameFormatter('{REDDITOR}_{TITLE}_{POSTID}', '{SUBREDDIT}', 'ISO')
result = test_formatter.format_path(test_resource, tmp_path)
result.parent.mkdir(parents=True)
@ -296,7 +296,7 @@ def test_format_archive_entry_comment(
):
test_comment = reddit_instance.comment(id=test_comment_id)
test_formatter = FileNameFormatter(test_file_scheme, test_folder_scheme, 'ISO')
test_entry = Resource(test_comment, '', '.json')
test_entry = Resource(test_comment, '', lambda: None, '.json')
result = test_formatter.format_path(test_entry, tmp_path)
assert do_test_string_equality(result, expected_name)

View file

@ -21,7 +21,7 @@ from bdfr.resource import Resource
('https://www.test.com/test/test2/example.png?random=test#thing', '.png'),
))
def test_resource_get_extension(test_url: str, expected: str):
test_resource = Resource(MagicMock(), test_url)
test_resource = Resource(MagicMock(), test_url, lambda: None)
result = test_resource._determine_extension()
assert result == expected
@ -31,6 +31,6 @@ def test_resource_get_extension(test_url: str, expected: str):
('https://www.iana.org/_img/2013.1/iana-logo-header.svg', '426b3ac01d3584c820f3b7f5985d6623'),
))
def test_download_online_resource(test_url: str, expected_hash: str):
test_resource = Resource(MagicMock(), test_url)
test_resource.download(120)
test_resource = Resource(MagicMock(), test_url, Resource.retry_download(test_url, 60))
test_resource.download()
assert test_resource.hash.hexdigest() == expected_hash