From c76e4eecfbd4e7bc9a0e126455b3ead123e2fa1e Mon Sep 17 00:00:00 2001 From: Dan Hess Date: Thu, 19 Nov 2020 14:58:41 -0600 Subject: [PATCH] Optimize checks --- utils/checks.py | 81 ++++++++--------- utils/paginator.py | 216 +++++++++++++++++++++++++++++---------------- 2 files changed, 176 insertions(+), 121 deletions(-) diff --git a/utils/checks.py b/utils/checks.py index 615a808..135f037 100644 --- a/utils/checks.py +++ b/utils/checks.py @@ -2,7 +2,6 @@ import asyncio from discord.ext import commands import discord -from . import utilities loop = asyncio.get_event_loop() @@ -13,7 +12,10 @@ def should_ignore(ctx): ignored = ctx.bot.cache.ignored[ctx.guild.id] if not ignored: return False - return ctx.message.author.id in ignored['members'] or ctx.message.channel.id in ignored['channels'] + return ( + ctx.message.author.id in ignored["members"] + or ctx.message.channel.id in ignored["channels"] + ) async def check_not_restricted(ctx): @@ -24,30 +26,27 @@ async def check_not_restricted(ctx): # First get all the restrictions restrictions = ctx.bot.cache.restrictions[ctx.guild.id] # Now lets check the "from" restrictions - for from_restriction in restrictions.get('from', []): + for from_restriction in restrictions.get("from", []): # Get the source and destination # Source should ALWAYS be a command in this case - source = from_restriction.get('source') - destination = from_restriction.get('destination') + source = from_restriction.get("source") + destination = int(from_restriction.get("destination")) # Special check for what the "disable" command produces if destination == "everyone" and ctx.command.qualified_name == source: return False - # Convert destination to the object we want - destination = await utilities.convert(ctx, destination) - # If we couldn't find the destination, just continue with other restrictions - # Also if this restriction we're checking isn't for this command - if destination is None or source != ctx.command.qualified_name: + # If this isn't the command we care about, continue + if source != ctx.command.qualified_name: continue # This means that the type of restriction we have is `command from channel` # Which means we do not want commands to be ran in this channel - if destination == ctx.message.channel: + if destination == ctx.channel.id: return False # This type is `command from Role` meaning anyone with this role can't run this command - elif destination in ctx.message.author.roles: + elif discord.utils.get(ctx.author.roles, id=destination): return False # This is `command from Member` meaning this user specifically cannot run this command - elif destination == ctx.message.author: + elif destination == ctx.author.id: return False # If we are here, then there are no blacklists stopping this from running @@ -55,44 +54,38 @@ async def check_not_restricted(ctx): # Now for the to restrictions this is a little different, we need to make a whitelist and # see if our current channel is in this whitelist, as well as any whitelisted roles are in the author's roles # Only if there is no whitelist, do we want to blanket return True - to_restrictions = restrictions.get('to', []) - if len(to_restrictions) == 0: + to_restrictions = restrictions.get("to", []) + if not to_restrictions: return True - # Otherwise there is a whitelist, and we need to start it - whitelisted_channels = [] - whitelisted_roles = [] + # If the author has a role that should whitelist them + whitelisted_role = False + # If this channel is one that is whitelisted + whitelisted_channel = False + # If a whitelist was found for this command + whitelist_found = False + # Otherwise check whitelists for to_restriction in to_restrictions: # Get the source and destination # Source should ALWAYS be a command in this case - source = to_restriction.get('source') - destination = to_restriction.get('destination') - # Convert destination to the object we want - destination = await utilities.convert(ctx, destination) - # If we couldn't find the destination, just continue with other restrictions - # Also if this restriction we're checking isn't for this command - if destination is None or source != ctx.command.qualified_name: + source = to_restriction.get("source") + destination = int(to_restriction.get("destination")) + # If this isn't the source we care about, continue + if source != ctx.command.qualified_name: continue - # Append to our two whitelists depending on what type this is - if isinstance(destination, discord.TextChannel): - whitelisted_channels.append(destination) - elif isinstance(destination, discord.Role): - whitelisted_roles.append(destination) + # If we've found a whitelist valid for this command, now we can set it + whitelist_found = True + # Now check against roles + if not whitelisted_role and discord.utils.get(ctx.author.roles, id=destination): + whitelisted_role = True + if ctx.channel.id == destination: + whitelisted_channel = True - if whitelisted_channels: - if ctx.channel not in whitelisted_channels: - return False - if whitelisted_roles: - if not any(x in ctx.message.author.roles for x in whitelisted_roles): - return False - - # If we have passed all of these, then we are allowed to run this command - # This looks like a whole lot, but all of these lists will be very tiny in almost all cases - # And only delving deep into the specific lists that may be large, will we finally see "large" lists - # Which means this still will not be slow in other cases - return True + # If we have reached here, then there is a whitelist... so we just need to return if they matched + # the whitelist + return whitelisted_role or whitelisted_channel or not whitelist_found def has_perms(ctx, **perms): @@ -112,7 +105,9 @@ def has_perms(ctx, **perms): for perm, setting in perms.items(): setattr(required_perm, perm, setting) - required_perm_value = ctx.bot.cache.custom_permissions[ctx.guild.id].get(ctx.command.qualified_name) + required_perm_value = ctx.bot.cache.custom_permissions[ctx.guild.id].get( + ctx.command.qualified_name + ) if required_perm_value: required_perm = discord.Permissions(required_perm_value) diff --git a/utils/paginator.py b/utils/paginator.py index 7a6e3c4..a7f572c 100644 --- a/utils/paginator.py +++ b/utils/paginator.py @@ -54,13 +54,19 @@ class Pages: self.paginating = len(entries) > per_page self.show_entry_count = show_entry_count self.reaction_emojis = [ - ('\N{BLACK LEFT-POINTING DOUBLE TRIANGLE WITH VERTICAL BAR}', self.first_page), - ('\N{BLACK LEFT-POINTING TRIANGLE}', self.previous_page), - ('\N{BLACK RIGHT-POINTING TRIANGLE}', self.next_page), - ('\N{BLACK RIGHT-POINTING DOUBLE TRIANGLE WITH VERTICAL BAR}', self.last_page), - ('\N{INPUT SYMBOL FOR NUMBERS}', self.numbered_page), - ('\N{BLACK SQUARE FOR STOP}', self.stop_pages), - ('\N{INFORMATION SOURCE}', self.show_help), + ( + "\N{BLACK LEFT-POINTING DOUBLE TRIANGLE WITH VERTICAL BAR}", + self.first_page, + ), + ("\N{BLACK LEFT-POINTING TRIANGLE}", self.previous_page), + ("\N{BLACK RIGHT-POINTING TRIANGLE}", self.next_page), + ( + "\N{BLACK RIGHT-POINTING DOUBLE TRIANGLE WITH VERTICAL BAR}", + self.last_page, + ), + ("\N{INPUT SYMBOL FOR NUMBERS}", self.numbered_page), + ("\N{BLACK SQUARE FOR STOP}", self.stop_pages), + ("\N{INFORMATION SOURCE}", self.show_help), ] if ctx.guild is not None: @@ -69,53 +75,55 @@ class Pages: self.permissions = self.channel.permissions_for(ctx.bot.user) if not self.permissions.embed_links: - raise CannotPaginate('Bot does not have embed links permission.') + raise CannotPaginate("Bot does not have embed links permission.") if not self.permissions.send_messages: - raise CannotPaginate('Bot cannot send messages.') + raise CannotPaginate("Bot cannot send messages.") if self.paginating: # verify we can actually use the pagination session if not self.permissions.add_reactions: - raise CannotPaginate('Bot does not have add reactions permission.') + raise CannotPaginate("Bot does not have add reactions permission.") if not self.permissions.read_message_history: - raise CannotPaginate('Bot does not have Read Message History permission.') + raise CannotPaginate( + "Bot does not have Read Message History permission." + ) def get_page(self, page): base = (page - 1) * self.per_page - return self.entries[base:base + self.per_page] + return self.entries[base : base + self.per_page] async def show_page(self, page, *, first=False): self.current_page = page entries = self.get_page(page) p = [] for index, entry in enumerate(entries, 1 + ((page - 1) * self.per_page)): - p.append(f'{index}. {entry}') + p.append(f"{index}. {entry}") if self.maximum_pages > 1: if self.show_entry_count: - text = f'Page {page}/{self.maximum_pages} ({len(self.entries)} entries)' + text = f"Page {page}/{self.maximum_pages} ({len(self.entries)} entries)" else: - text = f'Page {page}/{self.maximum_pages}' + text = f"Page {page}/{self.maximum_pages}" self.embed.set_footer(text=text) if not self.paginating: - self.embed.description = '\n'.join(p) + self.embed.description = "\n".join(p) return await self.channel.send(embed=self.embed) if not first: - self.embed.description = '\n'.join(p) + self.embed.description = "\n".join(p) await self.message.edit(embed=self.embed) return - p.append('') - p.append('Confused? React with \N{INFORMATION SOURCE} for more info.') - self.embed.description = '\n'.join(p) + p.append("") + p.append("Confused? React with \N{INFORMATION SOURCE} for more info.") + self.embed.description = "\n".join(p) self.message = await self.channel.send(embed=self.embed) for (reaction, _) in self.reaction_emojis: - if self.maximum_pages == 2 and reaction in ('\u23ed', '\u23ee'): + if self.maximum_pages == 2 and reaction in ("\u23ed", "\u23ee"): # no |<< or >>| buttons if we only have two pages # we can't forbid it if someone ends up using it but remove # it from the default set @@ -150,15 +158,19 @@ class Pages: async def numbered_page(self): """lets you type a page number to go to""" to_delete = [] - to_delete.append(await self.channel.send('What page do you want to go to?')) + to_delete.append(await self.channel.send("What page do you want to go to?")) def message_check(m): - return m.author == self.author and self.channel == m.channel and m.content.isdigit() + return ( + m.author == self.author + and self.channel == m.channel + and m.content.isdigit() + ) try: - msg = await self.bot.wait_for('message', check=message_check, timeout=30.0) + msg = await self.bot.wait_for("message", check=message_check, timeout=30.0) except asyncio.TimeoutError: - to_delete.append(await self.channel.send('Took too long.')) + to_delete.append(await self.channel.send("Took too long.")) await asyncio.sleep(5) else: page = int(msg.content) @@ -166,7 +178,11 @@ class Pages: if page != 0 and page <= self.maximum_pages: await self.show_page(page) else: - to_delete.append(await self.channel.send(f'Invalid page given. ({page}/{self.maximum_pages})')) + to_delete.append( + await self.channel.send( + f"Invalid page given. ({page}/{self.maximum_pages})" + ) + ) await asyncio.sleep(5) try: @@ -176,16 +192,20 @@ class Pages: async def show_help(self): """shows this message""" - messages = ['Welcome to the interactive paginator!\n'] - messages.append('This interactively allows you to see pages of text by navigating with ' - 'reactions. They are as follows:\n') + messages = ["Welcome to the interactive paginator!\n"] + messages.append( + "This interactively allows you to see pages of text by navigating with " + "reactions. They are as follows:\n" + ) for (emoji, func) in self.reaction_emojis: - messages.append(f'{emoji} {func.__doc__}') + messages.append(f"{emoji} {func.__doc__}") - self.embed.description = '\n'.join(messages) + self.embed.description = "\n".join(messages) self.embed.clear_fields() - self.embed.set_footer(text=f'We were on page {self.current_page} before this message.') + self.embed.set_footer( + text=f"We were on page {self.current_page} before this message." + ) await self.message.edit(embed=self.embed) async def go_back_to_current_page(): @@ -223,7 +243,9 @@ class Pages: while self.paginating: try: - reaction, user = await self.bot.wait_for('reaction_add', check=self.react_check, timeout=120.0) + reaction, user = await self.bot.wait_for( + "reaction_add", check=self.react_check, timeout=120.0 + ) except asyncio.TimeoutError: self.paginating = False try: @@ -245,6 +267,7 @@ class FieldPages(Pages): """Similar to Pages except entries should be a list of tuples having (key, value) to show as embed fields instead. """ + async def show_page(self, page, *, first=False): self.current_page = page entries = self.get_page(page) @@ -257,9 +280,9 @@ class FieldPages(Pages): if self.maximum_pages > 1: if self.show_entry_count: - text = f'Page {page}/{self.maximum_pages} ({len(self.entries)} entries)' + text = f"Page {page}/{self.maximum_pages} ({len(self.entries)} entries)" else: - text = f'Page {page}/{self.maximum_pages}' + text = f"Page {page}/{self.maximum_pages}" self.embed.set_footer(text=text) @@ -272,7 +295,7 @@ class FieldPages(Pages): self.message = await self.channel.send(embed=self.embed) for (reaction, _) in self.reaction_emojis: - if self.maximum_pages == 2 and reaction in ('\u23ed', '\u23ee'): + if self.maximum_pages == 2 and reaction in ("\u23ed", "\u23ee"): # no |<< or >>| buttons if we only have two pages # we can't forbid it if someone ends up using it but remove # it from the default set @@ -286,7 +309,7 @@ class FieldPages(Pages): # ?help command # -> could be a subcommand -_mention = re.compile(r'<@\!?([0-9]{1,19})>') +_mention = re.compile(r"<@\!?([0-9]{1,19})>") def cleanup_prefix(bot, prefix): @@ -294,7 +317,7 @@ def cleanup_prefix(bot, prefix): if m: user = bot.get_user(int(m.group(1))) if user: - return f'@{user.name} ' + return f"@{user.name} " return prefix @@ -311,33 +334,39 @@ def _command_signature(cmd): result = [cmd.qualified_name] if cmd.usage: result.append(cmd.usage) - return ' '.join(result) + return " ".join(result) params = cmd.clean_params if not params: - return ' '.join(result) + return " ".join(result) for name, param in params.items(): if param.default is not param.empty: # We don't want None or '' to trigger the [name=value] case and instead it should # do [name] since [name=None] or [name=] are not exactly useful for the user. - should_print = param.default if isinstance(param.default, str) else param.default is not None + should_print = ( + param.default + if isinstance(param.default, str) + else param.default is not None + ) if should_print: - result.append(f'[{name}={param.default!r}]') + result.append(f"[{name}={param.default!r}]") else: - result.append(f'[{name}]') + result.append(f"[{name}]") elif param.kind == param.VAR_POSITIONAL: - result.append(f'[{name}...]') + result.append(f"[{name}...]") else: - result.append(f'<{name}>') + result.append(f"<{name}>") - return ' '.join(result) + return " ".join(result) class HelpPaginator(Pages): def __init__(self, ctx, entries, *, per_page=4): super().__init__(ctx, entries=entries, per_page=per_page) - self.reaction_emojis.append(('\N{WHITE QUESTION MARK ORNAMENT}', self.show_bot_help)) + self.reaction_emojis.append( + ("\N{WHITE QUESTION MARK ORNAMENT}", self.show_bot_help) + ) self.total = len(entries) @classmethod @@ -348,10 +377,12 @@ class HelpPaginator(Pages): entries = sorted(ctx.bot.get_cog_commands(cog_name), key=lambda c: c.name) # remove the ones we can't run - entries = [cmd for cmd in entries if (await _can_run(cmd, ctx)) and not cmd.hidden] + entries = [ + cmd for cmd in entries if (await _can_run(cmd, ctx)) and not cmd.hidden + ] self = cls(ctx, entries) - self.title = f'{cog_name} Commands' + self.title = f"{cog_name} Commands" self.description = inspect.getdoc(cog) self.prefix = cleanup_prefix(ctx.bot, ctx.prefix) @@ -364,15 +395,17 @@ class HelpPaginator(Pages): except AttributeError: entries = [] else: - entries = [cmd for cmd in entries if (await _can_run(cmd, ctx)) and not cmd.hidden] + entries = [ + cmd for cmd in entries if (await _can_run(cmd, ctx)) and not cmd.hidden + ] self = cls(ctx, entries) self.title = command.signature if command.description: - self.description = f'{command.description}\n\n{command.help}' + self.description = f"{command.description}\n\n{command.help}" else: - self.description = command.help or 'No help given.' + self.description = command.help or "No help given." self.prefix = cleanup_prefix(ctx.bot, ctx.prefix) return self @@ -380,7 +413,7 @@ class HelpPaginator(Pages): @classmethod async def from_bot(cls, ctx): def key(c): - return c.cog_name or '\u200bMisc' + return c.cog_name or "\u200bMisc" entries = sorted(ctx.bot.commands, key=key) nested_pages = [] @@ -391,7 +424,9 @@ class HelpPaginator(Pages): # ... for cog, commands in itertools.groupby(entries, key=key): - plausible = [cmd for cmd in commands if (await _can_run(cmd, ctx)) and not cmd.hidden] + plausible = [ + cmd for cmd in commands if (await _can_run(cmd, ctx)) and not cmd.hidden + ] if len(plausible) == 0: continue @@ -401,7 +436,10 @@ class HelpPaginator(Pages): else: description = inspect.getdoc(description) or discord.Embed.Empty - nested_pages.extend((cog, description, plausible[i:i + per_page]) for i in range(0, len(plausible), per_page)) + nested_pages.extend( + (cog, description, plausible[i : i + per_page]) + for i in range(0, len(plausible), per_page) + ) self = cls(ctx, nested_pages, per_page=1) # this forces the pagination session self.prefix = cleanup_prefix(ctx.bot, ctx.prefix) @@ -416,7 +454,7 @@ class HelpPaginator(Pages): def get_bot_page(self, page): cog, description, commands = self.entries[page - 1] - self.title = f'{cog} Commands' + self.title = f"{cog} Commands" self.description = description return commands @@ -428,19 +466,27 @@ class HelpPaginator(Pages): self.embed.description = self.description self.embed.title = self.title - if hasattr(self, '_is_bot'): - value = 'For more help, join the official bot support server: https://discord.gg/f6uzJEj' - self.embed.add_field(name='Support', value=value, inline=False) + if hasattr(self, "_is_bot"): + value = "For more help, join the official bot support server: https://discord.gg/f6uzJEj" + self.embed.add_field(name="Support", value=value, inline=False) - self.embed.set_footer(text=f'Use "{self.prefix}help command" for more info on a command.') + self.embed.set_footer( + text=f'Use "{self.prefix}help command" for more info on a command.' + ) signature = _command_signature for entry in entries: - self.embed.add_field(name=signature(entry), value=entry.short_doc or "No help given", inline=False) + self.embed.add_field( + name=signature(entry), + value=entry.short_doc or "No help given", + inline=False, + ) if self.maximum_pages: - self.embed.set_author(name=f'Page {page}/{self.maximum_pages} ({self.total} commands)') + self.embed.set_author( + name=f"Page {page}/{self.maximum_pages} ({self.total} commands)" + ) if not self.paginating: return await self.channel.send(embed=self.embed) @@ -451,7 +497,7 @@ class HelpPaginator(Pages): self.message = await self.channel.send(embed=self.embed) for (reaction, _) in self.reaction_emojis: - if self.maximum_pages == 2 and reaction in ('\u23ed', '\u23ee'): + if self.maximum_pages == 2 and reaction in ("\u23ed", "\u23ee"): # no |<< or >>| buttons if we only have two pages # we can't forbid it if someone ends up using it but remove # it from the default set @@ -462,14 +508,20 @@ class HelpPaginator(Pages): async def show_help(self): """shows this message""" - self.embed.title = 'Paginator help' - self.embed.description = 'Hello! Welcome to the help page.' + self.embed.title = "Paginator help" + self.embed.description = "Hello! Welcome to the help page." - messages = [f'{emoji} {func.__doc__}' for emoji, func in self.reaction_emojis] + messages = [f"{emoji} {func.__doc__}" for emoji, func in self.reaction_emojis] self.embed.clear_fields() - self.embed.add_field(name='What are these reactions for?', value='\n'.join(messages), inline=False) + self.embed.add_field( + name="What are these reactions for?", + value="\n".join(messages), + inline=False, + ) - self.embed.set_footer(text=f'We were on page {self.current_page} before this message.') + self.embed.set_footer( + text=f"We were on page {self.current_page} before this message." + ) await self.message.edit(embed=self.embed) async def go_back_to_current_page(): @@ -481,25 +533,33 @@ class HelpPaginator(Pages): async def show_bot_help(self): """shows how to use the bot""" - self.embed.title = 'Using the bot' - self.embed.description = 'Hello! Welcome to the help page.' + self.embed.title = "Using the bot" + self.embed.description = "Hello! Welcome to the help page." self.embed.clear_fields() entries = ( - ('', 'This means the argument is __**required**__.'), - ('[argument]', 'This means the argument is __**optional**__.'), - ('[A|B]', 'This means the it can be __**either A or B**__.'), - ('[argument...]', 'This means you can have multiple arguments.\n' - 'Now that you know the basics, it should be noted that...\n' - '__**You do not type in the brackets!**__') + ("", "This means the argument is __**required**__."), + ("[argument]", "This means the argument is __**optional**__."), + ("[A|B]", "This means the it can be __**either A or B**__."), + ( + "[argument...]", + "This means you can have multiple arguments.\n" + "Now that you know the basics, it should be noted that...\n" + "__**You do not type in the brackets!**__", + ), ) - self.embed.add_field(name='How do I use this bot?', value='Reading the bot signature is pretty simple.') + self.embed.add_field( + name="How do I use this bot?", + value="Reading the bot signature is pretty simple.", + ) for name, value in entries: self.embed.add_field(name=name, value=value, inline=False) - self.embed.set_footer(text=f'We were on page {self.current_page} before this message.') + self.embed.set_footer( + text=f"We were on page {self.current_page} before this message." + ) await self.message.edit(embed=self.embed) async def go_back_to_current_page():