Add --subscribed option
This commit is contained in:
parent
274407537e
commit
a620ae91a1
|
@ -23,6 +23,7 @@ _common_options = [
|
||||||
click.option('--saved', is_flag=True, default=None),
|
click.option('--saved', is_flag=True, default=None),
|
||||||
click.option('--search', default=None, type=str),
|
click.option('--search', default=None, type=str),
|
||||||
click.option('--submitted', is_flag=True, default=None),
|
click.option('--submitted', is_flag=True, default=None),
|
||||||
|
click.option('--subscribed', is_flag=True, default=None),
|
||||||
click.option('--time-format', type=str, default=None),
|
click.option('--time-format', type=str, default=None),
|
||||||
click.option('--upvoted', is_flag=True, default=None),
|
click.option('--upvoted', is_flag=True, default=None),
|
||||||
click.option('-L', '--limit', default=None, type=int),
|
click.option('-L', '--limit', default=None, type=int),
|
||||||
|
|
|
@ -35,6 +35,7 @@ class Configuration(Namespace):
|
||||||
self.skip_subreddit: list[str] = []
|
self.skip_subreddit: list[str] = []
|
||||||
self.sort: str = 'hot'
|
self.sort: str = 'hot'
|
||||||
self.submitted: bool = False
|
self.submitted: bool = False
|
||||||
|
self.subscribed: bool = True
|
||||||
self.subreddit: list[str] = []
|
self.subreddit: list[str] = []
|
||||||
self.time: str = 'all'
|
self.time: str = 'all'
|
||||||
self.time_format = None
|
self.time_format = None
|
||||||
|
|
|
@ -243,9 +243,19 @@ class RedditConnector(metaclass=ABCMeta):
|
||||||
return set(all_entries)
|
return set(all_entries)
|
||||||
|
|
||||||
def get_subreddits(self) -> list[praw.models.ListingGenerator]:
|
def get_subreddits(self) -> list[praw.models.ListingGenerator]:
|
||||||
if self.args.subreddit:
|
out = []
|
||||||
out = []
|
subscribed_subreddits = set()
|
||||||
for reddit in self.split_args_input(self.args.subreddit):
|
if self.args.subscribed:
|
||||||
|
if self.args.authenticate:
|
||||||
|
try:
|
||||||
|
subscribed_subreddits = list(self.reddit_instance.user.subreddits(limit=None))
|
||||||
|
subscribed_subreddits = set([s.display_name for s in subscribed_subreddits])
|
||||||
|
except prawcore.InsufficientScope:
|
||||||
|
logger.error('BDFR has insufficient scope to access subreddit lists')
|
||||||
|
else:
|
||||||
|
logger.error('Cannot find subscribed subreddits without an authenticated instance')
|
||||||
|
if self.args.subreddit or subscribed_subreddits:
|
||||||
|
for reddit in self.split_args_input(self.args.subreddit) | subscribed_subreddits:
|
||||||
if reddit == 'friends' and self.authenticated is False:
|
if reddit == 'friends' and self.authenticated is False:
|
||||||
logger.error('Cannot read friends subreddit without an authenticated instance')
|
logger.error('Cannot read friends subreddit without an authenticated instance')
|
||||||
continue
|
continue
|
||||||
|
@ -270,9 +280,7 @@ class RedditConnector(metaclass=ABCMeta):
|
||||||
logger.debug(f'Added submissions from subreddit {reddit}')
|
logger.debug(f'Added submissions from subreddit {reddit}')
|
||||||
except (errors.BulkDownloaderException, praw.exceptions.PRAWException) as e:
|
except (errors.BulkDownloaderException, praw.exceptions.PRAWException) as e:
|
||||||
logger.error(f'Failed to get submissions for subreddit {reddit}: {e}')
|
logger.error(f'Failed to get submissions for subreddit {reddit}: {e}')
|
||||||
return out
|
return out
|
||||||
else:
|
|
||||||
return []
|
|
||||||
|
|
||||||
def resolve_user_name(self, in_name: str) -> str:
|
def resolve_user_name(self, in_name: str) -> str:
|
||||||
if in_name == 'me':
|
if in_name == 'me':
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
[DEFAULT]
|
[DEFAULT]
|
||||||
client_id = U-6gk4ZCh3IeNQ
|
client_id = U-6gk4ZCh3IeNQ
|
||||||
client_secret = 7CZHY6AmKweZME5s50SfDGylaPg
|
client_secret = 7CZHY6AmKweZME5s50SfDGylaPg
|
||||||
scopes = identity, history, read, save
|
scopes = identity, history, read, save, mysubreddits
|
||||||
backup_log_count = 3
|
backup_log_count = 3
|
||||||
max_wait_time = 120
|
max_wait_time = 120
|
||||||
time_format = ISO
|
time_format = ISO
|
|
@ -31,23 +31,23 @@ def create_basic_args_for_download_runner(test_args: list[str], run_path: Path):
|
||||||
@pytest.mark.reddit
|
@pytest.mark.reddit
|
||||||
@pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests')
|
@pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests')
|
||||||
@pytest.mark.parametrize('test_args', (
|
@pytest.mark.parametrize('test_args', (
|
||||||
['-s', 'Mindustry', '-L', 1],
|
['-s', 'Mindustry', '-L', 3],
|
||||||
['-s', 'r/Mindustry', '-L', 1],
|
['-s', 'r/Mindustry', '-L', 3],
|
||||||
['-s', 'r/mindustry', '-L', 1],
|
['-s', 'r/mindustry', '-L', 3],
|
||||||
['-s', 'mindustry', '-L', 1],
|
['-s', 'mindustry', '-L', 3],
|
||||||
['-s', 'https://www.reddit.com/r/TrollXChromosomes/', '-L', 1],
|
['-s', 'https://www.reddit.com/r/TrollXChromosomes/', '-L', 3],
|
||||||
['-s', 'r/TrollXChromosomes/', '-L', 1],
|
['-s', 'r/TrollXChromosomes/', '-L', 3],
|
||||||
['-s', 'TrollXChromosomes/', '-L', 1],
|
['-s', 'TrollXChromosomes/', '-L', 3],
|
||||||
['-s', 'trollxchromosomes', '-L', 1],
|
['-s', 'trollxchromosomes', '-L', 3],
|
||||||
['-s', 'trollxchromosomes,mindustry,python', '-L', 1],
|
['-s', 'trollxchromosomes,mindustry,python', '-L', 3],
|
||||||
['-s', 'trollxchromosomes, mindustry, python', '-L', 1],
|
['-s', 'trollxchromosomes, mindustry, python', '-L', 3],
|
||||||
['-s', 'trollxchromosomes', '-L', 1, '--time', 'day'],
|
['-s', 'trollxchromosomes', '-L', 3, '--time', 'day'],
|
||||||
['-s', 'trollxchromosomes', '-L', 1, '--sort', 'new'],
|
['-s', 'trollxchromosomes', '-L', 3, '--sort', 'new'],
|
||||||
['-s', 'trollxchromosomes', '-L', 1, '--time', 'day', '--sort', 'new'],
|
['-s', 'trollxchromosomes', '-L', 3, '--time', 'day', '--sort', 'new'],
|
||||||
['-s', 'trollxchromosomes', '-L', 1, '--search', 'women'],
|
['-s', 'trollxchromosomes', '-L', 3, '--search', 'women'],
|
||||||
['-s', 'trollxchromosomes', '-L', 1, '--time', 'day', '--search', 'women'],
|
['-s', 'trollxchromosomes', '-L', 3, '--time', 'day', '--search', 'women'],
|
||||||
['-s', 'trollxchromosomes', '-L', 1, '--sort', 'new', '--search', 'women'],
|
['-s', 'trollxchromosomes', '-L', 3, '--sort', 'new', '--search', 'women'],
|
||||||
['-s', 'trollxchromosomes', '-L', 1, '--time', 'day', '--sort', 'new', '--search', 'women'],
|
['-s', 'trollxchromosomes', '-L', 3, '--time', 'day', '--sort', 'new', '--search', 'women'],
|
||||||
))
|
))
|
||||||
def test_cli_download_subreddits(test_args: list[str], tmp_path: Path):
|
def test_cli_download_subreddits(test_args: list[str], tmp_path: Path):
|
||||||
runner = CliRunner()
|
runner = CliRunner()
|
||||||
|
@ -64,6 +64,7 @@ def test_cli_download_subreddits(test_args: list[str], tmp_path: Path):
|
||||||
@pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests')
|
@pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests')
|
||||||
@pytest.mark.parametrize('test_args', (
|
@pytest.mark.parametrize('test_args', (
|
||||||
['-s', 'hentai', '-L', 10, '--search', 'red', '--authenticate'],
|
['-s', 'hentai', '-L', 10, '--search', 'red', '--authenticate'],
|
||||||
|
['--authenticate', '--subscribed', '-L', 10],
|
||||||
))
|
))
|
||||||
def test_cli_download_search_subreddits_authenticated(test_args: list[str], tmp_path: Path):
|
def test_cli_download_search_subreddits_authenticated(test_args: list[str], tmp_path: Path):
|
||||||
runner = CliRunner()
|
runner = CliRunner()
|
||||||
|
|
|
@ -336,13 +336,29 @@ def test_get_user_authenticated_lists(
|
||||||
downloader_mock.args.__dict__[test_flag] = True
|
downloader_mock.args.__dict__[test_flag] = True
|
||||||
downloader_mock.reddit_instance = authenticated_reddit_instance
|
downloader_mock.reddit_instance = authenticated_reddit_instance
|
||||||
downloader_mock.args.limit = 10
|
downloader_mock.args.limit = 10
|
||||||
downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
|
downloader_mock.determine_sort_function.return_value = praw.models.Subreddit.hot
|
||||||
downloader_mock.sort_filter = RedditTypes.SortType.HOT
|
downloader_mock.sort_filter = RedditTypes.SortType.HOT
|
||||||
downloader_mock.args.user = [RedditConnector.resolve_user_name(downloader_mock, 'me')]
|
downloader_mock.args.user = [RedditConnector.resolve_user_name(downloader_mock, 'me')]
|
||||||
results = RedditConnector.get_user_data(downloader_mock)
|
results = RedditConnector.get_user_data(downloader_mock)
|
||||||
assert_all_results_are_submissions_or_comments(10, results)
|
assert_all_results_are_submissions_or_comments(10, results)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.online
|
||||||
|
@pytest.mark.reddit
|
||||||
|
@pytest.mark.authenticated
|
||||||
|
def test_get_subscribed_subreddits(downloader_mock: MagicMock, authenticated_reddit_instance: praw.Reddit):
|
||||||
|
downloader_mock.reddit_instance = authenticated_reddit_instance
|
||||||
|
downloader_mock.args.limit = 10
|
||||||
|
downloader_mock.args.authenticate = True
|
||||||
|
downloader_mock.args.subscribed = True
|
||||||
|
downloader_mock.determine_sort_function.return_value = praw.models.Subreddit.hot
|
||||||
|
downloader_mock.determine_sort_function.return_value = praw.models.Subreddit.hot
|
||||||
|
downloader_mock.sort_filter = RedditTypes.SortType.HOT
|
||||||
|
results = RedditConnector.get_subreddits(downloader_mock)
|
||||||
|
assert all([isinstance(s, praw.models.ListingGenerator) for s in results])
|
||||||
|
assert len(results) > 0
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(('test_name', 'expected'), (
|
@pytest.mark.parametrize(('test_name', 'expected'), (
|
||||||
('Mindustry', 'Mindustry'),
|
('Mindustry', 'Mindustry'),
|
||||||
('Futurology', 'Futurology'),
|
('Futurology', 'Futurology'),
|
||||||
|
|
Loading…
Reference in a new issue