#!/usr/bin/env python3
# coding=utf-8

import argparse
from pathlib import Path
from unittest.mock import MagicMock

import praw
import praw.models
import pytest

from bulkredditdownloader.download_filter import DownloadFilter
from bulkredditdownloader.downloader import RedditDownloader, RedditTypes
from bulkredditdownloader.errors import BulkDownloaderException
from bulkredditdownloader.file_name_formatter import FileNameFormatter
from bulkredditdownloader.site_authenticator import SiteAuthenticator


@pytest.fixture()
def args() -> argparse.Namespace:
    """Build a namespace holding the default option values used by these tests."""
    return argparse.Namespace(
        directory='.',
        verbose=0,
        link=[],
        submitted=False,
        upvoted=False,
        subreddit=[],
        multireddit=[],
        user=None,
        search=None,
        sort='hot',
        limit=None,
        time='all',
        skip=[],
        skip_domain=[],
        set_folder_scheme='{SUBREDDIT}',
        set_file_scheme='{REDDITOR}_{TITLE}_{POSTID}',
        no_dupes=False,
    )


@pytest.fixture()
def downloader_mock(args: argparse.Namespace):
    """Provide a ``MagicMock`` standing in for ``RedditDownloader``, with ``args`` attached.

    The ``RedditDownloader`` methods under test are called unbound with this
    mock as ``self``, so each test can tweak ``args`` before the call.
    """
    stand_in = MagicMock()
    stand_in.args = args
    return stand_in


def test_determine_directories(tmp_path: Path, downloader_mock: MagicMock):
    """The output directory and its ``LOG_FILES`` subdirectory must both exist afterwards."""
    target_dir = tmp_path / 'test'
    downloader_mock.args.directory = target_dir

    RedditDownloader._determine_directories(downloader_mock)

    assert target_dir.exists()
    assert downloader_mock.logfile_directory == target_dir / 'LOG_FILES'
    assert downloader_mock.logfile_directory.exists()


@pytest.mark.parametrize(('skip_extensions', 'skip_domains'), (
    ([], []),
    (['.test'], ['test.com']),
))
def test_create_download_filter(skip_extensions: list[str], skip_domains: list[str], downloader_mock: MagicMock):
    """The filter built from ``args.skip`` / ``args.skip_domain`` must expose exactly those exclusions."""
    downloader_mock.args.skip = skip_extensions
    downloader_mock.args.skip_domain = skip_domains

    download_filter = RedditDownloader._create_download_filter(downloader_mock)

    assert isinstance(download_filter, DownloadFilter)
    assert download_filter.excluded_domains == skip_domains
    assert download_filter.excluded_extensions == skip_extensions
@pytest.mark.parametrize(('test_time', 'expected'), (
    ('all', 'all'),
    ('hour', 'hour'),
    ('day', 'day'),
    ('week', 'week'),
    ('random', 'all'),
    ('', 'all'),
))
def test_create_time_filter(test_time: str, expected: str, downloader_mock: MagicMock):
    """``args.time`` maps onto a ``RedditTypes.TimeType`` member.

    The 'random' and '' cases expect a fallback to ``ALL``.
    """
    downloader_mock.args.time = test_time
    result = RedditDownloader._create_time_filter(downloader_mock)

    assert isinstance(result, RedditTypes.TimeType)
    assert result.name.lower() == expected


@pytest.mark.parametrize(('test_sort', 'expected'), (
    ('', 'hot'),
    ('hot', 'hot'),
    ('controversial', 'controversial'),
    ('new', 'new'),
))
def test_create_sort_filter(test_sort: str, expected: str, downloader_mock: MagicMock):
    """``args.sort`` maps onto a ``RedditTypes.SortType`` member; '' expects ``HOT``."""
    downloader_mock.args.sort = test_sort
    result = RedditDownloader._create_sort_filter(downloader_mock)

    assert isinstance(result, RedditTypes.SortType)
    assert result.name.lower() == expected


@pytest.mark.parametrize(('test_file_scheme', 'test_folder_scheme'), (
    ('{POSTID}', '{SUBREDDIT}'),
    ('{REDDITOR}_{TITLE}_{POSTID}', '{SUBREDDIT}'),
))
def test_create_file_name_formatter(test_file_scheme: str, test_folder_scheme: str, downloader_mock: MagicMock):
    """Valid file/folder schemes are stored verbatim on the resulting ``FileNameFormatter``."""
    downloader_mock.args.set_file_scheme = test_file_scheme
    downloader_mock.args.set_folder_scheme = test_folder_scheme
    result = RedditDownloader._create_file_name_formatter(downloader_mock)

    assert isinstance(result, FileNameFormatter)
    assert result.file_format_string == test_file_scheme
    assert result.directory_format_string == test_folder_scheme


@pytest.mark.parametrize(('test_file_scheme', 'test_folder_scheme'), (
    ('', ''),
    ('{POSTID}', ''),
    ('', '{SUBREDDIT}'),
    ('test', '{SUBREDDIT}'),
    ('{POSTID}', 'test'),
))
def test_create_file_name_formatter_bad(test_file_scheme: str, test_folder_scheme: str, downloader_mock: MagicMock):
    """Invalid scheme pairs must raise ``BulkDownloaderException``.

    The cases cover empty schemes and schemes with no ``{...}`` placeholder.
    """
    downloader_mock.args.set_file_scheme = test_file_scheme
    downloader_mock.args.set_folder_scheme = test_folder_scheme
    with pytest.raises(BulkDownloaderException):
        RedditDownloader._create_file_name_formatter(downloader_mock)


@pytest.mark.skip
def test_create_authenticator(downloader_mock: MagicMock):
    result = RedditDownloader._create_authenticator(downloader_mock)
    assert isinstance(result, SiteAuthenticator)


@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize('test_submission_ids', (
    ('lvpf4l',),
    ('lvpf4l', 'lvqnsn'),
    ('lvpf4l', 'lvqnsn', 'lvl9kd'),
))
def test_get_submissions_from_link(
        test_submission_ids: tuple[str, ...],
        reddit_instance: praw.Reddit,
        downloader_mock: MagicMock):
    """Every ID placed in ``args.link`` is resolved into a ``praw.models.Submission``."""
    downloader_mock.args.link = test_submission_ids
    downloader_mock.reddit_instance = reddit_instance
    results = RedditDownloader._get_submissions_from_link(downloader_mock)

    # Generator instead of a list: no throwaway list is built just for all().
    assert all(isinstance(sub, praw.models.Submission) for res in results for sub in res)
    assert len(results[0]) == len(test_submission_ids)


@pytest.mark.skip
def test_load_config(downloader_mock: MagicMock):
    raise NotImplementedError


def _assert_subreddit_results(results, test_subreddits: tuple[str, ...], limit) -> None:
    """Flatten per-subreddit listings and check type, subreddit presence and total count.

    Shared by the subreddit tests below. ``limit`` is expected to apply per
    subreddit, so the total is ``limit * len(test_subreddits)``.
    """
    submissions = [sub for res in results for sub in res]
    assert all(isinstance(sub, praw.models.Submission) for sub in submissions)
    assert all(sub.subreddit.display_name for sub in submissions)
    # Guard kept so a future ``None`` (unlimited) case does not hit ``None * int``.
    if limit is not None:
        assert len(submissions) == (limit * len(test_subreddits))


@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize(('test_subreddits', 'limit'), (
    (('Futurology',), 10),
    (('Futurology',), 20),
    (('Futurology', 'Python'), 10),
    (('Futurology',), 100),
    (('Futurology',), 0),
))
def test_get_subreddit_normal(
        test_subreddits: tuple[str, ...],
        limit: int,
        downloader_mock: MagicMock,
        reddit_instance: praw.Reddit):
    """Fetching 'hot' listings yields ``limit`` submissions per requested subreddit."""
    downloader_mock.reddit_instance = reddit_instance
    downloader_mock.args.subreddit = test_subreddits
    downloader_mock.args.limit = limit
    downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
    downloader_mock.sort_filter = RedditTypes.SortType.HOT

    results = RedditDownloader._get_subreddits(downloader_mock)

    _assert_subreddit_results(results, test_subreddits, limit)


@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize(('test_subreddits', 'search_term', 'limit'), (
    (('Python',), 'scraper', 10),
    (('Python',), '', 10),
))
def test_get_subreddit_search(
        test_subreddits: tuple[str, ...],
        search_term: str,
        limit: int,
        downloader_mock: MagicMock,
        reddit_instance: praw.Reddit):
    """With ``args.search`` set, each subreddit still yields ``limit`` submissions."""
    downloader_mock.reddit_instance = reddit_instance
    downloader_mock.args.subreddit = test_subreddits
    downloader_mock.args.limit = limit
    downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
    downloader_mock.sort_filter = RedditTypes.SortType.HOT
    downloader_mock.args.search = search_term

    results = RedditDownloader._get_subreddits(downloader_mock)

    _assert_subreddit_results(results, test_subreddits, limit)


@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.skip
def test_get_subreddits_search_bad():
    raise NotImplementedError


@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.skip
def test_get_multireddits():
    raise NotImplementedError


@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.skip
def test_get_user_submissions():
    raise NotImplementedError


@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.skip
def test_get_user_upvoted():
    raise NotImplementedError


@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.skip
def test_get_user_saved():
    raise NotImplementedError


@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.skip
def test_download_submission():
    raise NotImplementedError


@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.skip
def test_download_submission_file_exists():
    raise NotImplementedError


@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.skip
def test_download_submission_hash_exists():
    raise NotImplementedError