#!/usr/bin/env python3 # coding=utf-8 from datetime import datetime, timedelta from pathlib import Path from typing import Iterator from unittest.mock import MagicMock import praw import praw.models import pytest from bdfr.configuration import Configuration from bdfr.connector import RedditConnector, RedditTypes from bdfr.download_filter import DownloadFilter from bdfr.exceptions import BulkDownloaderException from bdfr.file_name_formatter import FileNameFormatter from bdfr.site_authenticator import SiteAuthenticator @pytest.fixture() def args() -> Configuration: args = Configuration() args.time_format = 'ISO' return args @pytest.fixture() def downloader_mock(args: Configuration): downloader_mock = MagicMock() downloader_mock.args = args downloader_mock.sanitise_subreddit_name = RedditConnector.sanitise_subreddit_name downloader_mock.create_filtered_listing_generator = lambda x: RedditConnector.create_filtered_listing_generator( downloader_mock, x) downloader_mock.split_args_input = RedditConnector.split_args_input downloader_mock.master_hash_list = {} return downloader_mock def assert_all_results_are_submissions(result_limit: int, results: list[Iterator]) -> list: results = [sub for res in results for sub in res] assert all([isinstance(res, praw.models.Submission) for res in results]) assert not any([isinstance(m, MagicMock) for m in results]) if result_limit is not None: assert len(results) == result_limit return results def assert_all_results_are_submissions_or_comments(result_limit: int, results: list[Iterator]) -> list: results = [sub for res in results for sub in res] assert all([isinstance(res, praw.models.Submission) or isinstance(res, praw.models.Comment) for res in results]) assert not any([isinstance(m, MagicMock) for m in results]) if result_limit is not None: assert len(results) == result_limit return results def test_determine_directories(tmp_path: Path, downloader_mock: MagicMock): downloader_mock.args.directory = tmp_path / 'test' downloader_mock.config_directories.user_config_dir = tmp_path RedditConnector.determine_directories(downloader_mock) assert Path(tmp_path / 'test').exists() @pytest.mark.parametrize(('skip_extensions', 'skip_domains'), ( ([], []), (['.test'], ['test.com'],), )) def test_create_download_filter(skip_extensions: list[str], skip_domains: list[str], downloader_mock: MagicMock): downloader_mock.args.skip = skip_extensions downloader_mock.args.skip_domain = skip_domains result = RedditConnector.create_download_filter(downloader_mock) assert isinstance(result, DownloadFilter) assert result.excluded_domains == skip_domains assert result.excluded_extensions == skip_extensions @pytest.mark.parametrize(('test_time', 'expected'), ( ('all', 'all'), ('hour', 'hour'), ('day', 'day'), ('week', 'week'), ('random', 'all'), ('', 'all'), )) def test_create_time_filter(test_time: str, expected: str, downloader_mock: MagicMock): downloader_mock.args.time = test_time result = RedditConnector.create_time_filter(downloader_mock) assert isinstance(result, RedditTypes.TimeType) assert result.name.lower() == expected @pytest.mark.parametrize(('test_sort', 'expected'), ( ('', 'hot'), ('hot', 'hot'), ('controversial', 'controversial'), ('new', 'new'), )) def test_create_sort_filter(test_sort: str, expected: str, downloader_mock: MagicMock): downloader_mock.args.sort = test_sort result = RedditConnector.create_sort_filter(downloader_mock) assert isinstance(result, RedditTypes.SortType) assert result.name.lower() == expected @pytest.mark.parametrize(('test_file_scheme', 'test_folder_scheme'), ( ('{POSTID}', '{SUBREDDIT}'), ('{REDDITOR}_{TITLE}_{POSTID}', '{SUBREDDIT}'), ('{POSTID}', 'test'), ('{POSTID}', ''), ('{POSTID}', '{SUBREDDIT}/{REDDITOR}'), )) def test_create_file_name_formatter(test_file_scheme: str, test_folder_scheme: str, downloader_mock: MagicMock): downloader_mock.args.file_scheme = test_file_scheme downloader_mock.args.folder_scheme = test_folder_scheme result = RedditConnector.create_file_name_formatter(downloader_mock) assert isinstance(result, FileNameFormatter) assert result.file_format_string == test_file_scheme assert result.directory_format_string == test_folder_scheme.split('/') @pytest.mark.parametrize(('test_file_scheme', 'test_folder_scheme'), ( ('', ''), ('', '{SUBREDDIT}'), ('test', '{SUBREDDIT}'), )) def test_create_file_name_formatter_bad(test_file_scheme: str, test_folder_scheme: str, downloader_mock: MagicMock): downloader_mock.args.file_scheme = test_file_scheme downloader_mock.args.folder_scheme = test_folder_scheme with pytest.raises(BulkDownloaderException): RedditConnector.create_file_name_formatter(downloader_mock) def test_create_authenticator(downloader_mock: MagicMock): result = RedditConnector.create_authenticator(downloader_mock) assert isinstance(result, SiteAuthenticator) @pytest.mark.online @pytest.mark.reddit @pytest.mark.parametrize('test_submission_ids', ( ('lvpf4l',), ('lvpf4l', 'lvqnsn'), ('lvpf4l', 'lvqnsn', 'lvl9kd'), )) def test_get_submissions_from_link( test_submission_ids: list[str], reddit_instance: praw.Reddit, downloader_mock: MagicMock): downloader_mock.args.link = test_submission_ids downloader_mock.reddit_instance = reddit_instance results = RedditConnector.get_submissions_from_link(downloader_mock) assert all([isinstance(sub, praw.models.Submission) for res in results for sub in res]) assert len(results[0]) == len(test_submission_ids) @pytest.mark.online @pytest.mark.reddit @pytest.mark.parametrize(('test_subreddits', 'limit', 'sort_type', 'time_filter', 'max_expected_len'), ( (('Futurology',), 10, 'hot', 'all', 10), (('Futurology', 'Mindustry, Python'), 10, 'hot', 'all', 30), (('Futurology',), 20, 'hot', 'all', 20), (('Futurology', 'Python'), 10, 'hot', 'all', 20), (('Futurology',), 100, 'hot', 'all', 100), (('Futurology',), 0, 'hot', 'all', 0), (('Futurology',), 10, 'top', 'all', 10), (('Futurology',), 10, 'top', 'week', 10), (('Futurology',), 10, 'hot', 'week', 10), )) def test_get_subreddit_normal( test_subreddits: list[str], limit: int, sort_type: str, time_filter: str, max_expected_len: int, downloader_mock: MagicMock, reddit_instance: praw.Reddit, ): downloader_mock.args.limit = limit downloader_mock.args.sort = sort_type downloader_mock.time_filter = RedditConnector.create_time_filter(downloader_mock) downloader_mock.sort_filter = RedditConnector.create_sort_filter(downloader_mock) downloader_mock.determine_sort_function.return_value = RedditConnector.determine_sort_function(downloader_mock) downloader_mock.args.subreddit = test_subreddits downloader_mock.reddit_instance = reddit_instance results = RedditConnector.get_subreddits(downloader_mock) test_subreddits = downloader_mock.split_args_input(test_subreddits) results = [sub for res1 in results for sub in res1] assert all([isinstance(res1, praw.models.Submission) for res1 in results]) assert all([res.subreddit.display_name in test_subreddits for res in results]) assert len(results) <= max_expected_len assert not any([isinstance(m, MagicMock) for m in results]) @pytest.mark.online @pytest.mark.reddit @pytest.mark.parametrize(('test_time', 'test_delta'), ( ('hour', timedelta(hours=1)), ('day', timedelta(days=1)), ('week', timedelta(days=7)), ('month', timedelta(days=31)), ('year', timedelta(days=365)), )) def test_get_subreddit_time_verification( test_time: str, test_delta: timedelta, downloader_mock: MagicMock, reddit_instance: praw.Reddit, ): downloader_mock.args.limit = 10 downloader_mock.args.sort = 'top' downloader_mock.args.time = test_time downloader_mock.time_filter = RedditConnector.create_time_filter(downloader_mock) downloader_mock.sort_filter = RedditConnector.create_sort_filter(downloader_mock) downloader_mock.determine_sort_function.return_value = RedditConnector.determine_sort_function(downloader_mock) downloader_mock.args.subreddit = ['all'] downloader_mock.reddit_instance = reddit_instance results = RedditConnector.get_subreddits(downloader_mock) results = [sub for res1 in results for sub in res1] assert all([isinstance(res1, praw.models.Submission) for res1 in results]) nowtime = datetime.now() for r in results: result_time = datetime.fromtimestamp(r.created_utc) time_diff = nowtime - result_time assert time_diff < test_delta @pytest.mark.online @pytest.mark.reddit @pytest.mark.parametrize(('test_subreddits', 'search_term', 'limit', 'time_filter', 'max_expected_len'), ( (('Python',), 'scraper', 10, 'all', 10), (('Python',), '', 10, 'all', 0), (('Python',), 'djsdsgewef', 10, 'all', 0), (('Python',), 'scraper', 10, 'year', 10), )) def test_get_subreddit_search( test_subreddits: list[str], search_term: str, time_filter: str, limit: int, max_expected_len: int, downloader_mock: MagicMock, reddit_instance: praw.Reddit, ): downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot downloader_mock.args.limit = limit downloader_mock.args.search = search_term downloader_mock.args.subreddit = test_subreddits downloader_mock.reddit_instance = reddit_instance downloader_mock.sort_filter = RedditTypes.SortType.HOT downloader_mock.args.time = time_filter downloader_mock.time_filter = RedditConnector.create_time_filter(downloader_mock) results = RedditConnector.get_subreddits(downloader_mock) results = [sub for res in results for sub in res] assert all([isinstance(res, praw.models.Submission) for res in results]) assert all([res.subreddit.display_name in test_subreddits for res in results]) assert len(results) <= max_expected_len if max_expected_len != 0: assert len(results) > 0 assert not any([isinstance(m, MagicMock) for m in results]) @pytest.mark.online @pytest.mark.reddit @pytest.mark.parametrize(('test_user', 'test_multireddits', 'limit'), ( ('helen_darten', ('cuteanimalpics',), 10), ('korfor', ('chess',), 100), )) # Good sources at https://www.reddit.com/r/multihub/ def test_get_multireddits_public( test_user: str, test_multireddits: list[str], limit: int, reddit_instance: praw.Reddit, downloader_mock: MagicMock, ): downloader_mock.determine_sort_function.return_value = praw.models.Subreddit.hot downloader_mock.sort_filter = RedditTypes.SortType.HOT downloader_mock.args.limit = limit downloader_mock.args.multireddit = test_multireddits downloader_mock.args.user = [test_user] downloader_mock.reddit_instance = reddit_instance downloader_mock.create_filtered_listing_generator.return_value = \ RedditConnector.create_filtered_listing_generator( downloader_mock, reddit_instance.multireddit(test_user, test_multireddits[0]), ) results = RedditConnector.get_multireddits(downloader_mock) results = [sub for res in results for sub in res] assert all([isinstance(res, praw.models.Submission) for res in results]) assert len(results) == limit assert not any([isinstance(m, MagicMock) for m in results]) @pytest.mark.online @pytest.mark.reddit @pytest.mark.parametrize(('test_user', 'limit'), ( ('danigirl3694', 10), ('danigirl3694', 50), ('CapitanHam', None), )) def test_get_user_submissions(test_user: str, limit: int, downloader_mock: MagicMock, reddit_instance: praw.Reddit): downloader_mock.args.limit = limit downloader_mock.determine_sort_function.return_value = praw.models.Subreddit.hot downloader_mock.sort_filter = RedditTypes.SortType.HOT downloader_mock.args.submitted = True downloader_mock.args.user = [test_user] downloader_mock.authenticated = False downloader_mock.reddit_instance = reddit_instance downloader_mock.create_filtered_listing_generator.return_value = \ RedditConnector.create_filtered_listing_generator( downloader_mock, reddit_instance.redditor(test_user).submissions, ) results = RedditConnector.get_user_data(downloader_mock) results = assert_all_results_are_submissions(limit, results) assert all([res.author.name == test_user for res in results]) assert not any([isinstance(m, MagicMock) for m in results]) @pytest.mark.online @pytest.mark.reddit @pytest.mark.authenticated @pytest.mark.parametrize('test_flag', ( 'upvoted', 'saved', )) def test_get_user_authenticated_lists( test_flag: str, downloader_mock: MagicMock, authenticated_reddit_instance: praw.Reddit, ): downloader_mock.args.__dict__[test_flag] = True downloader_mock.reddit_instance = authenticated_reddit_instance downloader_mock.args.limit = 10 downloader_mock.determine_sort_function.return_value = praw.models.Subreddit.hot downloader_mock.sort_filter = RedditTypes.SortType.HOT downloader_mock.args.user = [RedditConnector.resolve_user_name(downloader_mock, 'me')] results = RedditConnector.get_user_data(downloader_mock) assert_all_results_are_submissions_or_comments(10, results) @pytest.mark.online @pytest.mark.reddit @pytest.mark.authenticated def test_get_subscribed_subreddits(downloader_mock: MagicMock, authenticated_reddit_instance: praw.Reddit): downloader_mock.reddit_instance = authenticated_reddit_instance downloader_mock.args.limit = 10 downloader_mock.args.authenticate = True downloader_mock.args.subscribed = True downloader_mock.determine_sort_function.return_value = praw.models.Subreddit.hot downloader_mock.determine_sort_function.return_value = praw.models.Subreddit.hot downloader_mock.sort_filter = RedditTypes.SortType.HOT results = RedditConnector.get_subreddits(downloader_mock) assert all([isinstance(s, praw.models.ListingGenerator) for s in results]) assert len(results) > 0 @pytest.mark.parametrize(('test_name', 'expected'), ( ('Mindustry', 'Mindustry'), ('Futurology', 'Futurology'), ('r/Mindustry', 'Mindustry'), ('TrollXChromosomes', 'TrollXChromosomes'), ('r/TrollXChromosomes', 'TrollXChromosomes'), ('https://www.reddit.com/r/TrollXChromosomes/', 'TrollXChromosomes'), ('https://www.reddit.com/r/TrollXChromosomes', 'TrollXChromosomes'), ('https://www.reddit.com/r/Futurology/', 'Futurology'), ('https://www.reddit.com/r/Futurology', 'Futurology'), )) def test_sanitise_subreddit_name(test_name: str, expected: str): result = RedditConnector.sanitise_subreddit_name(test_name) assert result == expected @pytest.mark.parametrize(('test_subreddit_entries', 'expected'), ( (['test1', 'test2', 'test3'], {'test1', 'test2', 'test3'}), (['test1,test2', 'test3'], {'test1', 'test2', 'test3'}), (['test1, test2', 'test3'], {'test1', 'test2', 'test3'}), (['test1; test2', 'test3'], {'test1', 'test2', 'test3'}), (['test1, test2', 'test1,test2,test3', 'test4'], {'test1', 'test2', 'test3', 'test4'}), ([''], {''}), (['test'], {'test'}), )) def test_split_subreddit_entries(test_subreddit_entries: list[str], expected: set[str]): results = RedditConnector.split_args_input(test_subreddit_entries) assert results == expected def test_read_submission_ids_from_file(downloader_mock: MagicMock, tmp_path: Path): test_file = tmp_path / 'test.txt' test_file.write_text('aaaaaa\nbbbbbb') results = RedditConnector.read_id_files([str(test_file)]) assert results == {'aaaaaa', 'bbbbbb'} @pytest.mark.online @pytest.mark.reddit @pytest.mark.parametrize('test_redditor_name', ( 'nasa', 'crowdstrike', 'HannibalGoddamnit', )) def test_check_user_existence_good( test_redditor_name: str, reddit_instance: praw.Reddit, downloader_mock: MagicMock, ): downloader_mock.reddit_instance = reddit_instance RedditConnector.check_user_existence(downloader_mock, test_redditor_name) @pytest.mark.online @pytest.mark.reddit @pytest.mark.parametrize('test_redditor_name', ( 'lhnhfkuhwreolo', 'adlkfmnhglojh', )) def test_check_user_existence_nonexistent( test_redditor_name: str, reddit_instance: praw.Reddit, downloader_mock: MagicMock, ): downloader_mock.reddit_instance = reddit_instance with pytest.raises(BulkDownloaderException, match='Could not find'): RedditConnector.check_user_existence(downloader_mock, test_redditor_name) @pytest.mark.online @pytest.mark.reddit @pytest.mark.parametrize('test_redditor_name', ( 'Bree-Boo', )) def test_check_user_existence_banned( test_redditor_name: str, reddit_instance: praw.Reddit, downloader_mock: MagicMock, ): downloader_mock.reddit_instance = reddit_instance with pytest.raises(BulkDownloaderException, match='is banned'): RedditConnector.check_user_existence(downloader_mock, test_redditor_name) @pytest.mark.online @pytest.mark.reddit @pytest.mark.parametrize(('test_subreddit_name', 'expected_message'), ( ('donaldtrump', 'cannot be found'), ('submitters', 'private and cannot be scraped'), ('lhnhfkuhwreolo', 'does not exist') )) def test_check_subreddit_status_bad(test_subreddit_name: str, expected_message: str, reddit_instance: praw.Reddit): test_subreddit = reddit_instance.subreddit(test_subreddit_name) with pytest.raises(BulkDownloaderException, match=expected_message): RedditConnector.check_subreddit_status(test_subreddit) @pytest.mark.online @pytest.mark.reddit @pytest.mark.parametrize('test_subreddit_name', ( 'Python', 'Mindustry', 'TrollXChromosomes', 'all', )) def test_check_subreddit_status_good(test_subreddit_name: str, reddit_instance: praw.Reddit): test_subreddit = reddit_instance.subreddit(test_subreddit_name) RedditConnector.check_subreddit_status(test_subreddit)