From 75af7dcd6f23151b5afc0591e661cd552d47b844 Mon Sep 17 00:00:00 2001 From: phxntxm Date: Sun, 27 Jan 2019 20:58:39 -0600 Subject: [PATCH] Rewrite of database/configuration --- bot.py | 110 ++--- cogs/admin.py | 959 ++++++-------------------------------------- cogs/birthday.py | 214 +++++----- cogs/config.py | 556 +++++++++++++++++++++++++ cogs/events.py | 86 ++-- cogs/hangman.py | 10 +- cogs/images.py | 8 +- cogs/interaction.py | 329 ++++++++------- cogs/links.py | 12 +- cogs/misc.py | 58 +-- cogs/osu.py | 32 +- cogs/overwatch.py | 4 +- cogs/picarto.py | 346 ++++------------ cogs/raffle.py | 323 +++++---------- cogs/roles.py | 2 +- cogs/spades.py | 10 +- cogs/spotify.py | 2 +- cogs/stats.py | 133 +++--- cogs/tags.py | 199 ++++----- cogs/tutorial.py | 8 +- requirements.txt | 2 +- utils/__init__.py | 4 +- utils/checks.py | 64 +-- utils/config.py | 25 +- utils/database.py | 202 +++++----- utils/utilities.py | 101 +++-- 26 files changed, 1604 insertions(+), 2195 deletions(-) create mode 100644 cogs/config.py diff --git a/bot.py b/bot.py index b6c1f5e..018bb15 100644 --- a/bot.py +++ b/bot.py @@ -1,4 +1,3 @@ -#!/usr/local/bin/python3.5 import discord import traceback import logging @@ -24,35 +23,23 @@ bot = commands.AutoShardedBot(**opts) logging.basicConfig(level=logging.INFO, filename='bonfire.log') +@bot.before_invoke +async def start_typing(ctx): + await ctx.trigger_typing() + + @bot.event async def on_command_completion(ctx): - author = ctx.message.author - server = ctx.message.guild - command = ctx.command + author = ctx.author.id + guild = ctx.guild.id if ctx.guild else None + command = ctx.command.qualified_name - command_usage = await bot.db.actual_load( - 'command_usage', key=command.qualified_name - ) or {'command': command.qualified_name} - - # Add one to the total usage for this command, basing it off 0 to start with (obviously) - total_usage = command_usage.get('total_usage', 0) + 1 - command_usage['total_usage'] = total_usage - - # Add one to the author's usage for this command - total_member_usage = command_usage.get('member_usage', {}) - member_usage = total_member_usage.get(str(author.id), 0) + 1 - total_member_usage[str(author.id)] = member_usage - command_usage['member_usage'] = total_member_usage - - # Add one to the server's usage for this command - if ctx.message.guild is not None: - total_server_usage = command_usage.get('server_usage', {}) - server_usage = total_server_usage.get(str(server.id), 0) + 1 - total_server_usage[str(server.id)] = server_usage - command_usage['server_usage'] = total_server_usage - - # Save all the changes - await bot.db.save('command_usage', command_usage) + await bot.db.execute( + "INSERT INTO command_usage(command, guild, author) VALUES ($1, $2, $3)", + command, + guild, + author + ) # Now add credits to a users amount # user_credits = bot.db.load('credits', key=ctx.author.id, pluck='credits') or 1000 @@ -66,43 +53,35 @@ async def on_command_completion(ctx): @bot.event async def on_command_error(ctx, error): - if isinstance(error, commands.CommandNotFound): - return - if isinstance(error, commands.DisabledCommand): - return - try: - if isinstance(error.original, discord.Forbidden): - return - elif isinstance(error.original, discord.HTTPException) and ( - 'empty message' in str(error.original) or - 'INTERNAL SERVER ERROR' in str(error.original) or - 'REQUEST ENTITY TOO LARGE' in str(error.original) or - 'Unknown Message' in str(error.original) or - 'Origin Time-out' in str(error.original) or - 'Bad Gateway' in str(error.original) or - 'Gateway Time-out' in str(error.original) or - 'Explicit content' in str(error.original)): - return - elif isinstance(error.original, aiohttp.ClientOSError): - return - elif isinstance(error.original, discord.NotFound) and 'Unknown Channel' in str(error.original): - return + error = error.original if hasattr(error, "original") else error + ignored_errors = ( + commands.CommandNotFound, + commands.DisabledCommand, + discord.Forbidden, + aiohttp.ClientOSError, + commands.CheckFailure, + commands.CommandOnCooldown, + ) - except AttributeError: - pass + if isinstance(error, ignored_errors): + return + elif isinstance(error, discord.HTTPException) and ( + 'empty message' in str(error) or + 'INTERNAL SERVER ERROR' in str(error) or + 'REQUEST ENTITY TOO LARGE' in str(error) or + 'Unknown Message' in str(error) or + 'Origin Time-out' in str(error) or + 'Bad Gateway' in str(error) or + 'Gateway Time-out' in str(error) or + 'Explicit content' in str(error)): + return + elif isinstance(error, discord.NotFound) and 'Unknown Channel' in str(error): + return try: if isinstance(error, commands.BadArgument): fmt = "Please provide a valid argument to pass to the command: {}".format(error) await ctx.message.channel.send(fmt) - elif isinstance(error, commands.CheckFailure): - fmt = "You can't tell me what to do!" - # await ctx.message.channel.send(fmt) - elif isinstance(error, commands.CommandOnCooldown): - m, s = divmod(error.retry_after, 60) - fmt = "This command is on cooldown! Hold your horses! >:c\nTry again in {} minutes and {} seconds" \ - .format(round(m), round(s)) - # await ctx.message.channel.send(fmt) elif isinstance(error, commands.NoPrivateMessage): fmt = "This command cannot be used in a private message" await ctx.message.channel.send(fmt) @@ -113,21 +92,20 @@ async def on_command_error(ctx, error): with open("error_log", 'a') as f: print("In server '{0.message.guild}' at {1}\nFull command: `{0.message.content}`".format(ctx, str(now)), file=f) - try: - traceback.print_tb(error.original.__traceback__, file=f) - print('{0.__class__.__name__}: {0}'.format(error.original), file=f) - except Exception: - traceback.print_tb(error.__traceback__, file=f) - print('{0.__class__.__name__}: {0}'.format(error), file=f) + traceback.print_tb(error.__traceback__, file=f) + print('{0.__class__.__name__}: {0}'.format(error), file=f) except discord.HTTPException: pass if __name__ == '__main__': - bot.loop.create_task(utils.db_check()) bot.remove_command('help') + # Setup our bot vars, db and cache bot.db = utils.DB() - + bot.cache = utils.Cache(bot.db) + # Start our startup tasks + bot.loop.create_task(bot.db.setup()) + bot.loop.create_task(bot.cache.setup()) for e in utils.extensions: bot.load_extension(e) diff --git a/cogs/admin.py b/cogs/admin.py index f806e28..9bc0fbb 100644 --- a/cogs/admin.py +++ b/cogs/admin.py @@ -1,18 +1,14 @@ -from discord.ext import commands - +import discord import utils -import discord -import asyncio -import re +from asyncpg import UniqueViolationError +from discord.ext import commands valid_perms = [p for p in dir(discord.Permissions) if isinstance(getattr(discord.Permissions, p), property)] -class Administration: - """Handles the administration of the bot for a server; this is mainly different settings for the bot""" - def __init__(self, bot): - self.bot = bot +class Admin: + """These are commands that allow more intuitive configuration, that don't fit into the config command""" @commands.command() @commands.guild_only() @@ -23,61 +19,41 @@ class Administration: await ctx.send("You cannot disable `{}`".format(command)) return - cmd = self.bot.get_command(command) + cmd = ctx.bot.get_command(command) if cmd is None: await ctx.send("No command called `{}`".format(command)) return - from_entry = { - 'source': cmd.qualified_name, - 'destination': "everyone", - } - - restrictions = self.bot.db.load('server_settings', key=ctx.message.guild.id, pluck='restrictions') or {} - _from = restrictions.get('from', []) - if from_entry not in _from: - _from.append(from_entry) - update = { - 'server_id': str(ctx.message.guild.id), - 'restrictions': { - 'from': _from - } - } - await self.bot.db.save('server_settings', update) - await ctx.send("I have disabled `{}`".format(cmd.qualified_name)) + try: + await ctx.bot.db.execute( + "INSERT INTO restrictions (source, destination, from_to, guild) VALUES ($1, 'everyone', 'from', $2)", + cmd.qualified_name, + ctx.guild.id + ) + except UniqueViolationError: + await ctx.send(f"{cmd.qualified_name} is already disabled") else: - await ctx.send("That command is already disabled") + await ctx.send(f"{cmd.qualified_name} is now disabled") @commands.command() @commands.guild_only() @utils.can_run(manage_guild=True) async def enable(self, ctx, *, command): """Enables the use of a command on this server""" - cmd = self.bot.get_command(command) + cmd = ctx.bot.get_command(command) if cmd is None: await ctx.send("No command called `{}`".format(command)) return - from_entry = { - 'source': cmd.qualified_name, - 'destination': "everyone", - } - - restrictions = self.bot.db.load('server_settings', key=ctx.message.guild.id, pluck='restrictions') or {} - _from = restrictions.get('from', []) - try: - _from.remove(from_entry) - except ValueError: - await ctx.send("That command is not disabled") - else: - update = { - 'server_id': str(ctx.message.guild.id), - 'restrictions': { - 'from': _from - } - } - await self.bot.db.save('server_settings', update) - await ctx.send("I have enabled `{}`".format(cmd.qualified_name)) + query = f""" +DELETE FROM restrictions WHERE +source=$1 AND +from_to='from' AND +destination='everyone' AND +guild=$2 +""" + await ctx.bot.db.execute(query, cmd.qualified_name, ctx.guild.id) + await ctx.send(f"{cmd.qualified_name} is no longer disabled") @commands.command() @commands.guild_only() @@ -88,7 +64,7 @@ class Administration: This sets the role to mentionable, mentions the role, then sets it back """ if not ctx.me.guild_permissions.manage_roles: - await ctx.send("I do not have permissions to edit roles") + await ctx.send("I do not have permissions to edit roles (this is required to complete this command)") return try: await role.edit(mentionable=True) @@ -96,301 +72,11 @@ class Administration: await ctx.send("I do not have permissions to edit that role. " "(I either don't have manage roles permissions, or it is higher on the hierarchy)") else: - fmt = "{}\n{}".format(role.mention, message) + fmt = f"{role.mention}\n{message}" await ctx.send(fmt) await role.edit(mentionable=False) await ctx.message.delete() - @commands.group(invoke_without_command=True) - @commands.guild_only() - @utils.can_run(send_messages=True) - async def battles(self, ctx): - """Used to list the server specific battles messages on this server - - EXAMPLE: !battles - RESULT: A list of the battle messages that can be used on this server""" - msgs = self.bot.db.load('server_settings', key=ctx.message.guild.id, pluck='battles') - if msgs: - try: - pages = utils.Pages(ctx, entries=msgs) - await pages.paginate() - except utils.CannotPaginate as e: - await ctx.send(str(e)) - else: - await ctx.send("There are no server specific battles on this server!") - - @battles.command(name='add') - @commands.guild_only() - @utils.can_run(manage_guild=True) - async def add_battles(self, ctx, *, message): - """Used to add a battle message to the server specific battle messages - Use {winner} or {loser} in order to display the winner/loser's display name - - EXAMPLE: !battles add {winner} has beaten {loser} - RESULT: Player1 has beaten Player2""" - # Try to simulate the message, to ensure they haven't provided an invalid phrase - try: - message.format(loser="player1", winner="player2") - except Exception: - await ctx.send("That is an invalid format! The winner needs to be " - "labeled with {winner} and the loser with {loser}") - return - - # Now simply load the current messages - msgs = self.bot.db.load('server_settings', key=ctx.message.guild.id, pluck='battles') or [] - # Append this one - msgs.append("*{}*".format(message)) - # And save it - update = { - 'server_id': str(ctx.message.guild.id), - 'battles': msgs - } - await self.bot.db.save('server_settings', update) - fmt = "I have just saved your new battle message, it will appear like this: \n\n*{}*".format(message) - await ctx.send(fmt.format(loser=ctx.message.author.display_name, winner=ctx.message.guild.me.display_name)) - - @battles.command(name='remove', aliases=['delete']) - @commands.guild_only() - @utils.can_run(manage_guild=True) - async def remove_battles(self, ctx): - """Used to remove one of the custom hugs from the server's list of hug messages - - EXAMPLE: !hugs remove - RESULT: I'll ask which hug you want to remove""" - # First just send the hugs - await ctx.invoke(self.battles) - # Then let them know to respond with the number needed - await ctx.send("Please respond with the number matching the battle message you want to remove") - # The check to ensure it's in this channel...and what's provided is an int - - def check(m): - if m.author == ctx.message.author and m.channel == ctx.message.channel: - try: - return bool(int(m.content)) - except Exception: - return False - else: - return False - - # Get the message - try: - msg = await self.bot.wait_for('message', check=check, timeout=60.0) - except asyncio.TimeoutError: - await ctx.send("You took too long. I'm impatient, don't make me wait") - return - - # Get the number needed - num = int(msg.content) - 1 - msgs = self.bot.db.load('server_settings', key=ctx.message.guild.id, pluck='battles') - # Try to remove it, if it fails then it doesn't match - try: - msgs.pop(num) - except (IndexError, AttributeError): - await ctx.send("That is not a valid match!") - return - - entry = { - 'server_id': str(ctx.message.guild.id), - 'battles': msgs - } - await self.bot.db.save('server_settings', entry) - await ctx.send("I have just removed that battle message") - - @battles.command(name='default') - @commands.guild_only() - @utils.can_run(send_messages=True) - async def default_battles(self, ctx): - """Used to toggle if battles should include default messages as well as server-custom messages - - EXAMPLE: !hugs default - RESULT: No longer uses both defaults!""" - # Get the setting - setting = self.bot.db.load('server_settings', key=ctx.message.guild.id, pluck='default_battles') - if setting is None: - setting = True - # Now reverse it - setting = not setting - entry = { - 'server_id': str(ctx.message.guild.id), - 'default_battles': setting - } - await self.bot.db.save('server_settings', entry) - fmt = "" if setting else "not " - await ctx.send("Default messages will {}be used as well as custom messages".format(fmt)) - - @commands.group(invoke_without_command=True) - @commands.guild_only() - @utils.can_run(send_messages=True) - async def hugs(self, ctx): - """Used to list the server specific hug messages on this server - - EXAMPLE: !hugs - RESULT: A list of the hug messages that can be used on this server""" - msgs = self.bot.db.load('server_settings', key=ctx.message.guild.id, pluck='hugs') - if msgs: - try: - pages = utils.Pages(ctx, entries=msgs) - await pages.paginate() - except utils.CannotPaginate as e: - await ctx.send(str(e)) - else: - await ctx.send("There are no server specific hugs on this server!") - - @hugs.command(name='add') - @commands.guild_only() - @utils.can_run(manage_guild=True) - async def add_hugs(self, ctx, *, message): - """Used to add a hug to the server specific hug messages - Use {user} in order to display the user's display name - - EXAMPLE: !hugs add I hugged {user} - RESULT: *new hug message that says I hugged UserName*""" - # Try to simulate the message, to ensure they haven't provided an invalid phrase - try: - message.format(user="user") - except Exception: - await ctx.send("That is an invalid format! The user being hugged needs to be labeled with {user}") - return - - msgs = self.bot.db.load('server_settings', key=ctx.message.guild.id, pluck='hugs') or [] - msgs.append("*{}*".format(message)) - update = { - 'server_id': str(ctx.message.guild.id), - 'hugs': msgs - } - await self.bot.db.save('server_settings', update) - fmt = "I have just saved your new hug message, it will appear like this: \n\n*{}*".format(message) - await ctx.send(fmt.format(user=ctx.message.author.display_name)) - - @hugs.command(name='remove', aliases=['delete']) - @commands.guild_only() - @utils.can_run(manage_guild=True) - async def remove_hugs(self, ctx): - """Used to remove one of the custom hugs from the server's list of hug messages - - EXAMPLE: !hugs remove - RESULT: I'll ask which hug you want to remove""" - # First just send the hugs - await ctx.invoke(self.hugs) - # Then let them know to respond with the number needed - await ctx.send("Please respond with the number matching the hug message you want to remove") - # The check to ensure it's in this channel...and what's provided is an int - - def check(m): - if m.author == ctx.message.author and m.channel == ctx.message.channel: - try: - return bool(int(m.content)) - except Exception: - return False - else: - return False - - # Get the message - try: - msg = await self.bot.wait_for('message', check=check, timeout=60.0) - except asyncio.TimeoutError: - await ctx.send("You took too long. I'm impatient, don't make me wait") - return - - # Get the number needed - num = int(msg.content) - 1 - msgs = self.bot.db.load('server_settings', key=ctx.message.guild.id, pluck='hugs') - # Try to remove it, if it fails then it doesn't match - try: - msgs.pop(num) - except (IndexError, AttributeError): - await ctx.send("That is not a valid match!") - return - - entry = { - 'server_id': str(ctx.message.guild.id), - 'hugs': msgs - } - await self.bot.db.save('server_settings', entry) - await ctx.send("I have just removed that hug message") - - @hugs.command(name='default') - @commands.guild_only() - @utils.can_run(send_messages=True) - async def default_hugs(self, ctx): - """Used to toggle if hugs should include default messages as well as server-custom messages - - EXAMPLE: !hugs default - RESULT: No longer uses both defaults!""" - # Get the setting - setting = self.bot.db.load('server_settings', key=ctx.message.guild.id, pluck='default_hugs') - if setting is None: - setting = True - # Now reverse it - setting = not setting - entry = { - 'server_id': str(ctx.message.guild.id), - 'default_hugs': setting - } - await self.bot.db.save('server_settings', entry) - fmt = "" if setting else "not " - await ctx.send("Default messages will {}be used as well as custom messages".format(fmt)) - - @commands.command() - @commands.guild_only() - @utils.can_run(manage_guild=True) - async def allowbirthdays(self, ctx, setting): - """Turns on/off the birthday announcements in this server - - EXAMPLE: !allowbirthdays on - RESULT: Birthdays will now be announced""" - if setting.lower() in ['on', 'yes', 'true']: - allowed = True - else: - allowed = False - entry = { - 'server_id': str(ctx.message.guild.id), - 'birthdays_allowed': allowed - } - await self.bot.db.save('server_settings', entry) - fmt = "The birthday announcements have just been turned {}".format("on" if allowed else "off") - await ctx.send(fmt) - - @commands.command() - @commands.guild_only() - @utils.can_run(manage_guild=True) - async def allowcolours(self, ctx, setting): - """Turns on/off the ability to use colour roles in this server - - EXAMPLE: !allowcolours on - RESULT: Colour roles can now be used in this server""" - if setting.lower() in ['on', 'yes', 'true']: - allowed = True - else: - allowed = False - entry = { - 'server_id': str(ctx.message.guild.id), - 'colour_roles_allowed': allowed - } - await self.bot.db.save('server_settings', entry) - fmt = "The ability to use colour roles have just been turned {}".format("on" if allowed else "off") - await ctx.send(fmt) - - @commands.command() - @commands.guild_only() - @utils.can_run(manage_guild=True) - async def allowplaylists(self, ctx, setting): - """Turns on/off the ability to playlists - - EXAMPLE: !allowplaylists on - RESULT: Playlists can now be used""" - if setting.lower() in ['on', 'yes', 'true']: - allowed = True - else: - allowed = False - entry = { - 'server_id': str(ctx.message.guild.id), - 'playlists_allowed': allowed - } - await self.bot.db.save('server_settings', entry) - fmt = "The ability to use playlists has just been turned {}".format("on" if allowed else "off") - await ctx.send(fmt) - @commands.command() @commands.guild_only() @utils.can_run(kick_members=True) @@ -399,24 +85,20 @@ class Administration: EXAMPLE: !restrictions RESULT: All the current restrictions""" - # Get the restrictions - restrictions = self.bot.db.load('server_settings', key=ctx.message.guild.id, pluck='restrictions') or {} + restrictions = await ctx.bot.db.fetch( + "SELECT source, destination, from_to FROM restrictions WHERE guild=$1", + ctx.guild.id + ) + entries = [] - # Loop through all the from restrictions - for _from in restrictions.get('from', []): - source = _from.get('source') - # Resolve our destination based on the ID - dest = await utils.convert(ctx, _from.get('destination')) - # Don't add it if it doesn't exist + for restriction in restrictions: + # Check whether it's from or to to change what the format looks like + dest = restriction["destination"] + if dest != "everyone": + dest = await utils.convert(ctx, restriction["destination"]) + # If it doesn't exist, don't add it if dest: - entries.append("{} from {}".format(source, dest)) - for _to in restrictions.get('to', []): - source = _to.get('source') - # Resolve our destination based on the ID - dest = await utils.convert(ctx, _to.get('destination')) - # Don't add it if it doesn't exist - if dest: - entries.append("{} to {}".format(source, dest)) + entries.append(f"{restriction['source']} {'from' if restriction['from_to'] == 'from' else 'to'} {dest}") if entries: # Then paginate @@ -448,8 +130,7 @@ class Administration: await ctx.send("You need to provide 3 options! Such as `command from @User`") return elif ctx.message.mention_everyone: - arg1, arg2, arg3 = options - await ctx.send("Please do not restrict something {} everyone".format(arg2)) + await ctx.send("Please do not use this command to 'disable from everyone'. Use the `disable` command") return else: # Get the three arguments from this list, then make sure the 2nd is either from or to @@ -467,8 +148,9 @@ class Administration: await ctx.send("Sorry, but I don't know how to restrict {} {} {}".format(arg1, arg2, arg3)) return - from_entry = None - to_entry = None + from_to = arg2 + source = None + destination = None overwrites = None # The possible options: @@ -484,19 +166,15 @@ class Administration: # Roles - Command can't be ran by anyone in this role (least likely, but still possible uses) if arg2 == "from": if isinstance(option2, (discord.Member, discord.Role, discord.TextChannel)): - from_entry = { - 'source': option1.qualified_name, - 'destination': str(option2.id) - } + source = option1.qualified_name + destination = str(option2.id) # To: # Channels - Command can only be run in this channel # Roles - This role is required in order to run this command else: if isinstance(option2, (discord.Role, discord.TextChannel)): - to_entry = { - 'source': option1.qualified_name, - 'destination': str(option2.id) - } + source = option1.qualified_name + destination = str(option2.id) elif isinstance(option1, discord.Member): # From: # Channels - Setup an overwrite for this channel so that they cannot read it @@ -514,10 +192,8 @@ class Administration: option1: ov } elif isinstance(option2, (commands.core.Command, commands.core.Group)): - from_entry = { - 'source': option2.qualified_name, - 'destination': str(option1.id) - } + source = option2.qualified_name + destination = str(option1.id) elif isinstance(option1, (discord.TextChannel, discord.VoiceChannel)): # From: # Command - Command cannot be used in this channel @@ -537,20 +213,16 @@ class Administration: } elif isinstance(option2, (commands.core.Command, commands.core.Group)) \ and isinstance(option1, discord.TextChannel): - from_entry = { - 'source': option2.qualified_name, - 'destination': str(option1.id) - } + source = option2.qualified_name + destination = str(option1.id) # To: # Command - Command can only be used in this channel # Role - Setup an overwrite so only this role can read this channel else: if isinstance(option2, (commands.core.Command, commands.core.Group)) \ and isinstance(option1, discord.TextChannel): - to_entry = { - 'source': option2.qualified_name, - 'destination': str(option1.id) - } + source = option2.qualified_name + destination = str(option1.id) elif isinstance(option2, (discord.Member, discord.Role)): ov = discord.utils.find(lambda t: t[0] == option2, option1.overwrites) if ov: @@ -576,10 +248,8 @@ class Administration: # Channel - Setup an overwrite for this channel so that this Role cannot read it if arg2 == "from": if isinstance(option2, (commands.core.Command, commands.core.Group)): - from_entry = { - 'source': option2.qualified_name, - 'destination': str(option1.id) - } + source = option2.qualified_name + destination = option1.id elif isinstance(option2, (discord.TextChannel, discord.VoiceChannel)): ov = discord.utils.find(lambda t: t[0] == option1, option2.overwrites) if ov: @@ -615,35 +285,22 @@ class Administration: ctx.message.guild.default_role: ov2 } elif isinstance(option2, (commands.core.Command, commands.core.Group)): - to_entry = { - 'source': option2.qualified_name, - 'destination': str(option1.id) - } + source = option2.qualified_name + destination = str(option1.id) - if to_entry: - restrictions = self.bot.db.load('server_settings', key=ctx.message.guild.id, pluck='restrictions') or {} - to = restrictions.get('to', []) - if to_entry not in to: - to.append(to_entry) - update = { - 'server_id': str(ctx.message.guild.id), - 'restrictions': { - 'to': to - } - } - await self.bot.db.save('server_settings', update) - elif from_entry: - restrictions = self.bot.db.load('server_settings', key=ctx.message.guild.id, pluck='restrictions') or {} - _from = restrictions.get('from', []) - if from_entry not in _from: - _from.append(from_entry) - update = { - 'server_id': str(ctx.message.guild.id), - 'restrictions': { - 'from': _from - } - } - await self.bot.db.save('server_settings', update) + if source is not None and destination is not None: + try: + await ctx.bot.db.execute( + "INSERT INTO restrictions (guild, source, destination, from_to) VALUES ($1, $2, $3, $4)", + ctx.guild.id, + source, + destination, + from_to + ) + except UniqueViolationError: + # If it's already inserted, then nothing needs to be updated + # It just meansthis particular restriction is already set + pass elif overwrites: channel = overwrites.pop('channel') for target, setting in overwrites.items(): @@ -692,32 +349,22 @@ class Administration: # The source should always be the command, so just set this based on which order is given (either is # allowed) if isinstance(option1, (commands.core.Command, commands.core.Group)): - restriction = { - 'source': option1.qualified_name, - 'destination': str(option2.id) - } + source = option1.qualified_name + destination = str(option2.id) else: - restriction = { - 'source': option2.qualified_name, - 'destination': str(option1.id) - } + source = option2.qualified_name + destination = str(option1.id) + + # Now just try to remove it + await ctx.bot.db.execute(""" +DELETE FROM + restrictions +WHERE + source=$1 AND + destination=$2 AND + from_to=$3 AND + guild=$4""", source, destination, arg2, ctx.guild.id) - # Load restrictions - restrictions = self.bot.db.load('server_settings', key=ctx.message.guild.id, pluck='restrictions') or {} - # Attempt to remove the restriction provided - try: - restrictions.get(arg2, []).remove(restriction) - # If it doesn't exist, nothing is needed to be done - except ValueError: - await ctx.send("The restriction {} {} {} does not exist!".format(arg1, arg2, arg3)) - return - # If it was removed succesfully, save the change and let the author know this has been done - else: - entry = { - 'server_id': str(ctx.message.guild.id), - 'restrictions': restrictions - } - await self.bot.db.save('server_settings', entry) # If this isn't a blacklist/whitelist, then we are attempting to remove an overwrite else: # Get the source and destination based on whatever order is provided @@ -751,288 +398,17 @@ class Administration: await ctx.send("I have just unrestricted {} {} {}".format(arg1, arg2, arg3)) - @commands.command(aliases=['nick']) - @commands.guild_only() - @utils.can_run(kick_members=True) - async def nickname(self, ctx, *, name=None): - """Used to set the nickname for Bonfire (provide no nickname and it will reset) - - EXAMPLE: !nick Music Bot - RESULT: My nickname is now Music Bot""" - try: - await ctx.message.guild.me.edit(nick=name) - except discord.HTTPException: - await ctx.send("Sorry but I can't change my nickname to {}".format(name)) - else: - await ctx.send("\N{OK HAND SIGN}") - - @commands.command() - @commands.guild_only() - @utils.can_run(manage_guild=True) - async def ignore(self, ctx, member_or_channel): - """This command can be used to have Bonfire ignore certain members/channels - - EXAMPLE: !ignore #general - RESULT: Bonfire will ignore commands sent in the general channel""" - key = ctx.message.guild.id - - converter = commands.converter.MemberConverter() - member = None - channel = None - try: - member = await converter.convert(ctx, member_or_channel) - except commands.converter.BadArgument: - converter = commands.converter.TextChannelConverter() - try: - channel = await converter.convert(ctx, member_or_channel) - except commands.converter.BadArgument: - await ctx.send("{} does not appear to be a member or channel!".format(member_or_channel)) - return - - settings = self.bot.db.load('server_settings', key=key, pluck='ignored') or {} - ignored = settings.get('ignored', {'members': [], 'channels': []}) - if member: - if str(member.id) in ignored['members']: - await ctx.send("I am already ignoring {}!".format(member.display_name)) - return - elif member.guild_permissions >= ctx.message.author.guild_permissions: - await ctx.send("You cannot make me ignore someone at equal or higher rank than you!") - return - else: - ignored['members'].append(str(member.id)) - fmt = "Ignoring {}".format(member.display_name) - else: - if str(channel.id) in ignored['channels']: - await ctx.send("I am already ignoring {}!".format(channel.mention)) - return - else: - ignored['channels'].append(str(channel.id)) - fmt = "Ignoring {}".format(channel.mention) - - entry = { - 'ignored': ignored, - 'server_id': str(key) - } - - await self.bot.db.save('server_settings', entry) - await ctx.send(fmt) - - @commands.command() - @commands.guild_only() - @utils.can_run(manage_guild=True) - async def unignore(self, ctx, member_or_channel): - """This command can be used to have Bonfire stop ignoring certain members/channels - - EXAMPLE: !unignore #general - RESULT: Bonfire will no longer ignore commands sent in the general channel""" - key = str(ctx.message.guild.id) - - converter = commands.converter.MemberConverter() - member = None - channel = None - try: - member = await converter.convert(ctx, member_or_channel) - except commands.converter.BadArgument: - converter = commands.converter.TextChannelConverter() - try: - channel = await converter.convert(ctx, member_or_channel) - except commands.converter.BadArgument: - await ctx.send("{} does not appear to be a member or channel!".format(member_or_channel)) - return - - settings = self.bot.db.load('server_settings', key=key) or {} - ignored = settings.get('ignored', {'members': [], 'channels': []}) - if member: - if str(member.id) not in ignored['members']: - await ctx.send("I'm not even ignoring {}!".format(member.display_name)) - return - - ignored['members'].remove(str(member.id)) - fmt = "I am no longer ignoring {}".format(member.display_name) - else: - if str(channel.id) not in ignored['channels']: - await ctx.send("I'm not even ignoring {}!".format(channel.mention)) - return - - ignored['channels'].remove(str(channel.id)) - fmt = "I am no longer ignoring {}".format(channel.mention) - - entry = { - 'ignored': ignored, - 'server_id': str(key) - } - - await self.bot.db.save('server_settings', entry) - await ctx.send(fmt) - - @commands.command(aliases=['notifications']) - @commands.guild_only() - @utils.can_run(manage_guild=True) - async def alerts(self, ctx, channel: discord.TextChannel): - """This command is used to set a channel as the server's default 'notifications' channel - Any notifications (like someone going live on Twitch, or Picarto) will go to that channel by default - This can be overridden with specific alerts command, such as `!picarto alerts #channel` - This command is just the default; the one used if there is no other one set. - - EXAMPLE: !alerts #alerts - RESULT: No more alerts spammed in #general!""" - entry = { - 'server_id': str(ctx.message.guild.id), - 'notifications': { - 'default': str(channel.id) - } - } - - await self.bot.db.save('server_settings', entry) - await ctx.send("I have just changed this server's default 'notifications' channel" - "\nAll notifications will now default to `{}`".format(channel)) - - @commands.group(invoke_without_command=True, aliases=['goodbye']) - @commands.guild_only() - @utils.can_run(manage_guild=True) - async def welcome(self, ctx, on_off: str): - """This command can be used to set whether or not you want user notificaitons to show - Provide on, yes, or true to set it on; otherwise it will be turned off - - EXAMPLE: !welcome on - RESULT: Annoying join/leave notifications! Yay!""" - # Join/Leave notifications can be kept separate from normal alerts - # So we base this channel on it's own and not from alerts - # When mod logging becomes available, that will be kept to it's own channel if wanted as well - on_off = True if re.search("(on|yes|true)", on_off.lower()) else False - - entry = { - 'server_id': str(ctx.message.guild.id), - 'join_leave': on_off - } - - await self.bot.db.save('server_settings', entry) - fmt = "notify" if on_off else "not notify" - await ctx.send("This server will now {} if someone has joined or left".format(fmt)) - - @welcome.command(name='alerts', aliases=['notifications']) - @commands.guild_only() - @utils.can_run(manage_guild=True) - async def _welcome_alerts(self, ctx, *, channel: discord.TextChannel): - """A command used to set the override for notifications about users joining/leaving - - EXAMPLE: !welcome alerts #notifications - RESULT: All user joins/leaves will be sent to the #notificatoins channel""" - entry = { - 'server_id': str(ctx.message.guild.id), - 'notifications': { - 'welcome': str(channel.id) - } - } - - await self.bot.db.save('server_settings', entry) - await ctx.send( - "I have just changed this server's welcome/goodbye notifications channel to {}".format(channel.name)) - - @welcome.command(name='message') - @commands.guild_only() - @utils.can_run(manage_guild=True) - async def _welcome_message(self, ctx, *, msg): - """A command to customize the welcome/goodbye message - There are a couple things that can be set to customize the message - {member} - Will mention the user joining - {server} - Will display the server's name - - Give no message and it will be set to the default - EXAMPLE: !welcome message {member} to {server} - RESULT: Welcome Member#1234 to ServerName""" - parent = ctx.message.content.split()[0] - parent = parent[len(ctx.prefix):] - - if re.search("{.*token.*}", msg): - await ctx.send("Illegal content in {} message".format(parent)) - else: - try: - msg.format(member='test', server='test') - except KeyError: - await ctx.send("Illegal keyword in {0} message. Please use `{1.prefix}help {0} message` " - "for what keywords can be used".format(parent, ctx)) - return - entry = { - 'server_id': str(ctx.message.guild.id), - parent + '_message': msg - } - await self.bot.db.save('server_settings', entry) - await ctx.send("I have just updated your {} message".format(parent)) - - @commands.group() - async def nsfw(self, ctx): - """Handles adding or removing a channel as a nsfw channel""" - # This command isn't meant to do anything, so just send an error if an invalid subcommand is passed - pass - - @nsfw.command(name="add") - @utils.can_run(kick_members=True) - async def nsfw_add(self, ctx): - """Registers this channel as a 'nsfw' channel - - EXAMPLE: !nsfw add - RESULT: ;)""" - - if type(ctx.message.channel) is discord.DMChannel: - key = 'DMs' - else: - key = str(ctx.message.guild.id) - - channels = self.bot.db.load('server_settings', key=key, pluck='nsfw_channels') or [] - channels.append(str(ctx.message.channel.id)) - - entry = { - 'server_id': key, - 'nsfw_channels': channels - } - - await self.bot.db.save('server_settings', entry) - - await ctx.send("This channel has just been registered as 'nsfw'! Have fun you naughties ;)") - - @nsfw.command(name="remove", aliases=["delete"]) - @utils.can_run(kick_members=True) - async def nsfw_remove(self, ctx): - """Removes this channel as a 'nsfw' channel - - EXAMPLE: !nsfw remove - RESULT: ;(""" - channel = str(ctx.message.channel.id) - if type(ctx.message.channel) is discord.DMChannel: - key = 'DMs' - else: - key = str(ctx.message.guild.id) - - channels = self.bot.db.load('server_settings', key=key, pluck='nsfw_channels') or [] - if channel in channels: - channels.remove(channel) - - entry = { - 'server_id': key, - 'nsfw_channels': channels - } - await self.bot.db.save('server_settings', entry) - await ctx.send("This channel has just been unregistered as a nsfw channel") - else: - await ctx.send("This channel is not registerred as a nsfw channel!") - @commands.group(invoke_without_command=True) @commands.guild_only() @utils.can_run(send_messages=True) - async def perms(self, ctx, *, command: str = None): + async def perms(self, ctx, *, command: str): """This command can be used to print the current allowed permissions on a specific command This supports groups as well as subcommands; pass no argument to print a list of available permissions EXAMPLE: !perms help RESULT: Hopefully a result saying you just need send_messages permissions; otherwise lol this server's admin doesn't like me """ - if command is None: - await ctx.send( - "Valid permissions are: ```\n{}```".format("\n".join("{}".format(i) for i in valid_perms))) - return - - cmd = self.bot.get_command(command) + cmd = ctx.bot.get_command(command) if cmd is None: # If a command wasn't provided, see if a user was @@ -1054,10 +430,13 @@ class Administration: embed.add_field(name="Allowed permissions", value="\n".join(perms)) await ctx.send(embed=embed) return + result = await ctx.bot.db.fetchrow( + "SELECT permission FROM custom_permissions WHERE guild = $1 AND command = $2", + ctx.guild.id, + command + ) + perms_value = result["permission"] if result else None - server_perms = self.bot.db.load('server_settings', key=ctx.message.guild.id, pluck='permissions') or {} - - perms_value = server_perms.get(cmd.qualified_name) if perms_value is None: # If we don't find custom permissions, get the required permission for a command # based on what we set in utils.can_run, if can_run isn't found, we'll get an IndexError @@ -1092,7 +471,7 @@ class Administration: @perms.command(name="add", aliases=["setup,create"]) @commands.guild_only() @commands.has_permissions(manage_guild=True) - async def add_perms(self, ctx, *msg: str): + async def add_perms(self, ctx, *, msg: str): """Sets up custom permissions on the provided command Format must be 'perms add ' If you want to open the command to everyone, provide 'none' as the permission @@ -1101,19 +480,13 @@ class Administration: RESULT: No more random people voting to skip a song""" # Since subcommands exist, base the last word in the list as the permission, and the rest of it as the command - command = " ".join(msg[0:len(msg) - 1]) + command, _, permission = msg.rpartition(" ") if command == "": await ctx.send("Please provide the permissions you want to setup, the format for this must be in:\n" "`perms add `") return - try: - permissions = msg[len(msg) - 1] - except IndexError: - await ctx.send("Please provide the permissions you want to setup, the format for this must be in:\n" - "`perms add `") - return - cmd = self.bot.get_command(command) + cmd = ctx.bot.get_command(command) if cmd is None: await ctx.send( @@ -1121,17 +494,17 @@ class Administration: return # If a user can run a command, they have to have send_messages permissions; so use this as the base - if permissions.lower() == "none": - permissions = "send_messages" + if permission.lower() == "none": + permission = "send_messages" # Convert the string to an int value of the permissions object, based on the required permission # If we hit an attribute error, that means the permission given was not correct perm_obj = discord.Permissions.none() try: - setattr(perm_obj, permissions, True) + setattr(perm_obj, permission, True) except AttributeError: await ctx.send("{} does not appear to be a valid permission! Valid permissions are: ```\n{}```" - .format(permissions, "\n".join(valid_perms))) + .format(permission, "\n".join(valid_perms))) return perm_value = perm_obj.value @@ -1144,15 +517,15 @@ class Administration: await ctx.send("This command cannot have custom permissions setup!") return - entry = { - 'server_id': str(ctx.message.guild.id), - 'permissions': {cmd.qualified_name: perm_value} - } - - await self.bot.db.save('server_settings', entry) + await ctx.bot.db.execute( + "INSERT INTO custom_permissions (guild, command, permission) VALUES ($1, $2, $3)", + ctx.guild.id, + cmd.qualified_name, + perm_value + ) await ctx.send("I have just added your custom permissions; " - "you now need to have `{}` permissions to use the command `{}`".format(permissions, command)) + "you now need to have `{}` permissions to use the command `{}`".format(permission, command)) @perms.command(name="remove", aliases=["delete"]) @commands.guild_only() @@ -1163,122 +536,34 @@ class Administration: EXAMPLE: !perms remove play RESULT: Freedom!""" - cmd = self.bot.get_command(command) + cmd = ctx.bot.get_command(command) if cmd is None: await ctx.send( "That command does not exist! You can't have custom permissions on a non-existant command....") return - entry = { - 'server_id': str(ctx.message.guild.id), - 'permissions': {cmd.qualified_name: None} - } + await ctx.bot.db.execute( + "DELETE FROM custom_permissions WHERE guild=$1 AND command=$2", ctx.guild.id, cmd.qualified_name + ) - await self.bot.db.save('server_settings', entry) await ctx.send("I have just removed the custom permissions for {}!".format(cmd)) - @commands.command() + @commands.command(aliases=['nick']) @commands.guild_only() - @utils.can_run(manage_guild=True) - async def prefix(self, ctx, *, prefix: str): - """This command can be used to set a custom prefix per server + @utils.can_run(kick_members=True) + async def nickname(self, ctx, *, name=None): + """Used to set the nickname for Bonfire (provide no nickname and it will reset) - EXAMPLE: !prefix $ - RESULT: You now need to call commands like: $help""" - key = str(ctx.message.guild.id) - if len(prefix.strip()) > 20: - await ctx.send("Please keep prefixes under 20 characters") - return - if prefix.lower().strip() == "none": - prefix = None - - entry = { - 'server_id': key, - 'prefix': prefix - } - - await self.bot.db.save('server_settings', entry) - - if prefix is None: - fmt = "I have just cleared your custom prefix, the default prefix will have to be used now" - else: - fmt = "I have just updated the prefix for this server; you now need to call commands with `{0}`. " \ - "For example, you can call this command again with {0}prefix".format(prefix) - await ctx.send(fmt) - - @commands.group(aliases=['rule'], invoke_without_command=True) - @commands.guild_only() - @utils.can_run(send_messages=True) - async def rules(self, ctx, rule: int = None): - """This command can be used to view the current rules on the server - - EXAMPLE: !rules 5 - RESULT: Rule 5 is printed""" - rules = self.bot.db.load('server_settings', key=ctx.message.guild.id, pluck='rules') - - if rules is None: - await ctx.send("This server currently has no rules on it! I see you like to live dangerously...") - return - - if rule is None: - try: - pages = utils.Pages(ctx, entries=rules, per_page=5) - pages.title = "Rules for {}".format(ctx.message.guild.name) - await pages.paginate() - except utils.CannotPaginate as e: - await ctx.send(str(e)) - else: - try: - fmt = rules[rule - 1] - except IndexError: - await ctx.send("That rules does not exist.") - return - await ctx.send("Rule {}: \"{}\"".format(rule, fmt)) - - @rules.command(name='add', aliases=['create']) - @commands.guild_only() - @utils.can_run(manage_guild=True) - async def rules_add(self, ctx, *, rule: str): - """Adds a rule to this server's rules - - EXAMPLE: !rules add No fun allowed in this server >:c - RESULT: No more fun...unless they break the rules!""" - key = str(ctx.message.guild.id) - rules = self.bot.db.load('server_settings', key=key, pluck='rules') or [] - rules.append(rule) - - entry = { - 'server_id': key, - 'rules': rules - } - - await self.bot.db.save('server_settings', entry) - - await ctx.send("I have just saved your new rule, use the rules command to view this server's current rules") - - @rules.command(name='remove', aliases=['delete']) - @commands.guild_only() - @utils.can_run(manage_guild=True) - async def rules_delete(self, ctx, rule: int): - """Removes one of the rules from the list of this server's rules - Provide a number to delete that rule - - EXAMPLE: !rules delete 5 - RESULT: Freedom from opression!""" - key = str(ctx.message.guild.id) - rules = self.bot.db.load('server_settings', key=key, pluck='rules') or [] + EXAMPLE: !nick Music Bot + RESULT: My nickname is now Music Bot""" try: - rules.pop(rule - 1) - entry = { - 'server_id': key, - 'rules': rules - } - await self.bot.db.save('server_settings', entry) - await ctx.send("I have just removed that rule from your list of rules!") - except IndexError: - await ctx.send("That is not a valid rule number, try running the command again.") + await ctx.message.guild.me.edit(nick=name) + except discord.HTTPException: + await ctx.send("Sorry but I can't change my nickname to {}".format(name)) + else: + await ctx.send("\N{OK HAND SIGN}") def setup(bot): - bot.add_cog(Administration(bot)) + bot.add_cog(Admin()) diff --git a/cogs/birthday.py b/cogs/birthday.py index 3c59eb1..c1f90b4 100644 --- a/cogs/birthday.py +++ b/cogs/birthday.py @@ -1,42 +1,17 @@ import discord -import pendulum +import datetime import asyncio import traceback import re +import calendar from discord.ext import commands +from asyncpg import UniqueViolationError import utils -tzmap = { - 'us-central': pendulum.timezone('US/Central'), - 'eu-central': pendulum.timezone('Europe/Paris'), - 'hongkong': pendulum.timezone('Hongkong'), - -} - - -def sort_birthdays(bds): - # First sort the birthdays based on the comparison of the actual date - bds = sorted(bds, key=lambda x: x['birthday']) - # We want to split this into birthdays after and before todays date - # We can then use this to sort based on "whose is closest" - later_bds = [] - previous_bds = [] - # Loop through each birthday - for bd in bds: - # If it is after or equal to today, insert into our later list - if bd['birthday'] >= pendulum.today().date(): - later_bds.append(bd) - # Otherwise, insert into our previous list - else: - previous_bds.append(bd) - # At this point we have 2 lists, in order, one from all of dates before today, and one after - # So all we need to do is put them in order all of "laters" then all of "befores" - return later_bds + previous_bds - def parse_string(date): - year = pendulum.now().year + today = datetime.date.today() month = None day = None month_map = { @@ -74,84 +49,104 @@ def parse_string(date): elif part in month_map: month = month_map.get(part) if month and day: - return pendulum.date(year, month, day) + year = today.year + if month < today.month: + year += 1 + elif month == today.month and day <= today.day: + year += 1 + return datetime.date(year, month, day) class Birthday: - """Track and announcebirthdays""" + """Track and announce birthdays""" def __init__(self, bot): self.bot = bot self.task = self.bot.loop.create_task(self.birthday_task()) - def get_birthdays_for_server(self, server, today=False): - bds = self.bot.db.load('birthdays') - # Get a list of the ID's to compare against - member_ids = [str(m.id) for m in server.members] + async def get_birthdays_for_server(self, server, today=False): + members = ", ".join(f"{m.id}" for m in server.members) + query = f""" +SELECT + id, birthday +FROM + users +WHERE + id IN ({members}) +""" + if today: + query += """ +AND + birthday = CURRENT_DATE +""" + query += """ +ORDER BY + birthday +""" - # Now create a list comparing to the server's list of member IDs - bds = [ - bd - for member_id, bd in bds.items() - if str(member_id) in member_ids - ] - - _entries = [] - - for bd in bds: - if not bd['birthday']: - continue - - day = parse_string(bd['birthday']) - # tz = tzmap.get(server.region) - # Check if it's today, and we want to only get todays birthdays - if (today and day == pendulum.today().date()) or not today: - # If so, get the member and add them to the entry - member = server.get_member(int(bd['member_id'])) - _entries.append({ - 'birthday': day, - 'member': member - }) - - return sort_birthdays(_entries) + return await self.bot.db.fetch(query) async def birthday_task(self): - while True: + await self.bot.wait_until_ready() + + while not self.bot.is_closed(): try: await self.notify_birthdays() except Exception as error: with open("error_log", 'a') as f: traceback.print_tb(error.__traceback__, file=f) - print('{0.__class__.__name__}: {0}'.format(error), file=f) + print(f"{error.__class__.__name__}: {error}", file=f) finally: - # Every 12 hours, this is not something that needs to happen often - await asyncio.sleep(60 * 60 * 12) + # Every day + await asyncio.sleep(60 * 60 * 24) async def notify_birthdays(self): - tfilter = {'birthdays_allowed': True} - servers = await self.bot.db.actual_load('server_settings', table_filter=tfilter) + query = """ +SELECT + id, COALESCE(birthday_alerts, default_alerts) AS channel +FROM + guilds +WHERE + birthday_notifications=True +AND + COALESCE(birthday_alerts, default_alerts) IS NOT NULL +""" + servers = await self.bot.db.fetch(query) + update_bds = [] + if not servers: + return + for s in servers: - server = self.bot.get_guild(int(s['server_id'])) - if not server: + # Get the channel based on the birthday alerts, or default alerts channel + channel = self.bot.get_channel(s['channel']) + if not channel: continue - # Set our default to either the one set - default_channel_id = s.get('notifications', {}).get('default') - # If it is has been overriden by picarto notifications setting, use this - channel_id = s.get('notifications', {}).get('birthday') or default_channel_id - if not channel_id: - continue + bds = await self.get_birthdays_for_server(channel.guild, today=True) - # Now get the channel based on that ID - channel = server.get_channel(int(channel_id)) - - bds = self.get_birthdays_for_server(server, today=True) + # A list of the id's that will get updated for bd in bds: try: - await channel.send("It is {}'s birthday today! " - "Wish them a happy birthday! \N{SHORTCAKE}".format(bd['member'].mention)) - except (discord.Forbidden, discord.HTTPException, AttributeError): + await channel.send(f"It is {bd['member'].mention}'s birthday today! " + "Wish them a happy birthday! \N{SHORTCAKE}") + except (discord.Forbidden, discord.HTTPException): pass + finally: + update_bds.append(bd['id']) + + if not update_bds: + return + + query = f""" +UPDATE + users +SET + birthday = birthday + interval '1 year' +WHERE + id IN ({", ".join(f"'{bd}'" for bd in update_bds)}) +""" + print(query) + await self.bot.db.execute(query) @commands.group(aliases=['birthdays'], invoke_without_command=True) @commands.guild_only() @@ -162,20 +157,29 @@ class Birthday: EXAMPLE: !birthdays RESULT: A printout of the birthdays from everyone on this server""" if member: - date = self.bot.db.load('birthdays', key=member.id, pluck='birthday') + date = await self.bot.db.fetchrow("SELECT birthday FROM users WHERE id=$1", member.id) + date = date['birthday'] if date: - await ctx.send("{}'s birthday is {}".format(member.display_name, date)) + await ctx.send(f"{member.display_name}'s birthday is {calendar.month_name[date.month]} {date.day}") else: - await ctx.send("I do not have {}'s birthday saved!".format(member.display_name)) + await ctx.send(f"I do not have {member.display_name}'s birthday saved!") else: # Get this server's birthdays - bds = self.get_birthdays_for_server(ctx.message.guild) + bds = await self.get_birthdays_for_server(ctx.guild) # Create entries based on the user's display name and their birthday - entries = ["{} ({})".format(bd['member'].display_name, bd['birthday'].strftime("%B %-d")) for bd in bds] + entries = [ + f"{ctx.guild.get_member(bd['id']).display_name} ({bd['birthday'].strftime('%B %-d')})" + for bd in bds + if bd['birthday'] + ] + if not entries: + await ctx.send("I don't know anyone's birthday in this server!") + return + # Create our pages object try: pages = utils.Pages(ctx, entries=entries, per_page=5) - pages.title = "Birthdays for {}".format(ctx.message.guild.name) + pages.title = f"Birthdays for {ctx.guild.name}" await pages.paginate() except utils.CannotPaginate as e: await ctx.send(str(e)) @@ -196,13 +200,11 @@ class Birthday: await ctx.send("Please provide date in a valid format, such as December 1st!") return - date = date.strftime("%B %-d") - entry = { - 'member_id': str(ctx.message.author.id), - 'birthday': date - } - await self.bot.db.save('birthdays', entry) - await ctx.send("I have just saved your birthday as {}".format(date)) + await ctx.send(f"I have just saved your birthday as {date}") + try: + await self.bot.db.execute("INSERT INTO users (id, birthday) VALUES ($1, $2)", ctx.author.id, date) + except UniqueViolationError: + await self.bot.db.execute("UPDATE users SET birthday = $1 WHERE id = $2", date, ctx.author.id) @birthday.command(name='remove') @utils.can_run(send_messages=True) @@ -211,30 +213,8 @@ class Birthday: EXAMPLE: !birthday remove RESULT: I have magically forgotten your birthday""" - entry = { - 'member_id': str(ctx.message.author.id), - 'birthday': None - } - await self.bot.db.save('birthdays', entry) await ctx.send("I don't know your birthday anymore :(") - - @birthday.command(name='alerts', aliases=['notifications']) - @commands.guild_only() - @utils.can_run(manage_guild=True) - async def birthday_alerts_channel(self, ctx, channel: discord.TextChannel): - """Sets the notifications channel for birthday notifications - - EXAMPLE: !birthday alerts #birthday - RESULT: birthday notifications will go to this channel - """ - entry = { - 'server_id': str(ctx.message.guild.id), - 'notifications': { - 'birthday': str(channel.id) - } - } - await self.bot.db.save('server_settings', entry) - await ctx.send("All birthday notifications will now go to {}".format(channel.mention)) + await self.bot.db.execute("UPDATE users SET birthday=NULL WHERE id=$1", ctx.author.id) def setup(bot): diff --git a/cogs/config.py b/cogs/config.py new file mode 100644 index 0000000..e48af42 --- /dev/null +++ b/cogs/config.py @@ -0,0 +1,556 @@ +from discord.ext import commands +from asyncpg import UniqueViolationError + +import utils + +import discord + +valid_perms = [p for p in dir(discord.Permissions) if isinstance(getattr(discord.Permissions, p), property)] + + +class ConfigException(Exception): + pass + + +class WrongSettingType(ConfigException): + + def __init__(self, message): + self.message = message + + +class MessageFormatError(ConfigException): + + def __init__(self, original, keys): + self.original = original + self.keys = keys + + +class GuildConfiguration: + """Handles configuring the different settings that can be used on the bot""" + + def _str_to_bool(self, opt, setting): + setting = setting.title() + if setting.title() not in ["True", "False"]: + raise WrongSettingType( + f"The {opt} setting requires either 'True' or 'False', not {setting}" + ) + + return setting.title() == "True" + + async def _get_channel(self, ctx, setting): + converter = commands.converter.TextChannelConverter() + return await converter.convert(ctx, setting) + + async def _set_db_guild_opt(self, opt, setting, ctx): + try: + return await ctx.bot.db.execute(f"INSERT INTO guilds (id, {opt}) VALUES ($1, $2)", ctx.guild.id, setting) + except UniqueViolationError: + return await ctx.bot.db.execute(f"UPDATE guilds SET {opt} = $1 WHERE id = $2", setting, ctx.guild.id) + + # These are handles for each setting type + async def _handle_set_birthday_notifications(self, ctx, setting): + opt = "birthday_notifications" + setting = self._str_to_bool(opt, setting) + return await self._set_db_guild_opt(opt, setting, ctx) + + async def _handle_set_welcome_notifications(self, ctx, setting): + opt = "welcome_notifications" + setting = self._str_to_bool(opt, setting) + return await self._set_db_guild_opt(opt, setting, ctx) + + async def _handle_set_goodbye_notifications(self, ctx, setting): + opt = "goodbye_notifications" + setting = self._str_to_bool(opt, setting) + return await self._set_db_guild_opt(opt, setting, ctx) + + async def _handle_set_colour_roles(self, ctx, setting): + opt = "colour_roles" + setting = self._str_to_bool(opt, setting) + return await self._set_db_guild_opt(opt, setting, ctx) + + async def _handle_set_include_default_battles(self, ctx, setting): + opt = "include_default_battles" + setting = self._str_to_bool(opt, setting) + return await self._set_db_guild_opt(opt, setting, ctx) + + async def _handle_set_include_default_hugs(self, ctx, setting): + opt = "include_default_hugs" + setting = self._str_to_bool(opt, setting) + return await self._set_db_guild_opt(opt, setting, ctx) + + async def _handle_set_welcome_msg(self, ctx, setting): + try: + setting.format(member='test', server='test') + except KeyError as e: + raise MessageFormatError(e, ["member", "server"]) + else: + return await self._set_db_guild_opt("welcome_msg", setting, ctx) + + async def _handle_set_goodbye_msg(self, ctx, setting): + try: + setting.format(member='test', server='test') + except KeyError as e: + raise MessageFormatError(e, ["member", "server"]) + else: + return await self._set_db_guild_opt("goodbye_msg", setting, ctx) + + async def _handle_set_prefix(self, ctx, setting): + if len(setting) > 20: + raise WrongSettingType("Please keep the prefix under 20 characters") + if setting.lower().strip() == "none": + setting = None + + result = await self._set_db_guild_opt("prefix", setting, ctx) + # We want to update our cache for prefixes + ctx.bot.cache.update_prefix(ctx.guild, setting) + return result + + async def _handle_set_default_alerts(self, ctx, setting): + channel = await self._get_channel(ctx, setting) + return await self._set_db_guild_opt("default_alerts", channel.id, ctx) + + async def _handle_set_welcome_alerts(self, ctx, setting): + channel = await self._get_channel(ctx, setting) + return await self._set_db_guild_opt("welcome_alerts", channel.id, ctx) + + async def _handle_set_goodbye_alerts(self, ctx, setting): + channel = await self._get_channel(ctx, setting) + return await self._set_db_guild_opt("goodbye_alerts", channel.id, ctx) + + async def _handle_set_picarto_alerts(self, ctx, setting): + channel = await self._get_channel(ctx, setting) + return await self._set_db_guild_opt("picarto_alerts", channel.id, ctx) + + async def _handle_set_birthday_alerts(self, ctx, setting): + channel = await self._get_channel(ctx, setting) + return await self._set_db_guild_opt("birthday_alerts", channel.id, ctx) + + async def _handle_set_raffle_alerts(self, ctx, setting): + channel = await self._get_channel(ctx, setting) + return await self._set_db_guild_opt("raffle_alerts", channel.id, ctx) + + async def _handle_set_followed_picarto_channels(self, ctx, setting): + user = await utils.request(f"http://api.picarto.tv/v1/channel/name/{setting}") + if user is None: + raise WrongSettingType(f"Could not find a picarto user with the username {setting}") + + query = """ +UPDATE + guilds +SET + followed_picarto_channels = array_append(followed_picarto_channels, $1) +WHERE + id=$2 AND + NOT $1 = ANY(followed_picarto_channels); +""" + return await ctx.bot.db.execute(query, setting, ctx.guild.id) + + async def _handle_set_ignored_channels(self, ctx, setting): + channel = await self._get_channel(ctx, setting) + + query = """ +UPDATE + guilds +SET + ignored_channels = array_append(ignored_channels, $1) +WHERE + id=$2 AND + NOT $1 = ANY(ignored_channels); +""" + return await ctx.bot.db.execute(query, channel.id, ctx.guild.id) + + async def _handle_set_ignored_members(self, ctx, setting): + # We want to make it possible to have members that aren't in the server ignored + # So first check if it's a digit (the id) + if not setting.isdigit(): + converter = commands.converter.MemberConverter() + member = await converter.convert(ctx, setting) + setting = member.id + + query = """ +UPDATE + guilds +SET + ignored_members = array_append(ignored_members, $1) +WHERE + id=$2 AND + NOT $1 = ANY(ignored_members); +""" + return await ctx.bot.db.execute(query, setting, ctx.guild.id) + + async def _handle_set_rules(self, ctx, setting): + query = """ +UPDATE + guilds +SET + rules = array_append(rules, $1) +WHERE + id=$2 AND + NOT $1 = ANY(rules); +""" + return await ctx.bot.db.execute(query, setting, ctx.guild.id) + + async def _handle_set_assignable_roles(self, ctx, setting): + converter = commands.converter.RoleConverter() + role = await converter.convert(ctx, setting) + + query = """ +UPDATE + guilds +SET + assignable_roles = array_append(assignable_roles, $1) +WHERE + id=$2 AND + NOT $1 = ANY(assignable_roles); +""" + return await ctx.bot.db.execute(query, role.id, ctx.guild.id) + + async def _handle_set_custom_battles(self, ctx, setting): + try: + setting.format(loser="player1", winner="player2") + except KeyError as e: + raise MessageFormatError(e, ["loser", "winner"]) + else: + query = """ +UPDATE + guilds +SET + custom_battles = array_append(custom_battles, $1) +WHERE + id=$2 AND + NOT $1 = ANY(custom_battles); +""" + return await ctx.bot.db.execute(query, setting, ctx.guild.id) + + async def _handle_set_custom_hugs(self, ctx, setting): + try: + setting.format(user="user") + except KeyError as e: + raise MessageFormatError(e, ["user"]) + else: + query = """ +UPDATE + guilds +SET + custom_hugs = array_append(custom_hugs, $1) +WHERE + id=$2 AND + NOT $1 = ANY(custom_hugs); +""" + return await ctx.bot.db.execute(query, setting, ctx.guild.id) + + async def _handle_remove_birthday_notifications(self, ctx, setting=None): + return await self._set_db_guild_opt("birthday_notifications", False, ctx) + + async def _handle_remove_welcome_notifications(self, ctx, setting=None): + return await self._set_db_guild_opt("welcome_notifications", False, ctx) + + async def _handle_remove_goodbye_notifications(self, ctx, setting=None): + return await self._set_db_guild_opt("goodbye_notifications", False, ctx) + + async def _handle_remove_colour_roles(self, ctx, setting=None): + return await self._set_db_guild_opt("colour_roles", False, ctx) + + async def _handle_remove_include_default_battles(self, ctx, setting=None): + return await self._set_db_guild_opt("include_default_battles", False, ctx) + + async def _handle_remove_include_default_hugs(self, ctx, setting=None): + return await self._set_db_guild_opt("include_default_hugs", False, ctx) + + async def _handle_remove_welcome_msg(self, ctx, setting=None): + return await self._set_db_guild_opt("welcome_msg", None, ctx) + + async def _handle_remove_goodbye_msg(self, ctx, setting=None): + return await self._set_db_guild_opt("goodbye_msg", None, ctx) + + async def _handle_remove_prefix(self, ctx, setting=None): + return await self._set_db_guild_opt("prefix", None, ctx) + + async def _handle_remove_default_alerts(self, ctx, setting=None): + return await self._set_db_guild_opt("default_alerts", None, ctx) + + async def _handle_remove_welcome_alerts(self, ctx, setting=None): + return await self._set_db_guild_opt("welcome_alerts", None, ctx) + + async def _handle_remove_goodbye_alerts(self, ctx, setting=None): + return await self._set_db_guild_opt("goodbye_alerts", None, ctx) + + async def _handle_remove_picarto_alerts(self, ctx, setting=None): + return await self._set_db_guild_opt("picarto_alerts", None, ctx) + + async def _handle_remove_birthday_alerts(self, ctx, setting=None): + return await self._set_db_guild_opt("birthday_alerts", None, ctx) + + async def _handle_remove_raffle_alerts(self, ctx, setting=None): + return await self._set_db_guild_opt("raffle_alerts", None, ctx) + + async def _handle_remove_followed_picarto_channels(self, ctx, setting=None): + if setting is None: + raise WrongSettingType("Specifying which channel you want to remove is required") + + query = """ +UPDATE + guilds +SET + followed_picarto_channels = array_remove(followed_picarto_channels, $1) +WHERE + id=$2 +""" + return await ctx.bot.db.execute(query, setting, ctx.guild.id) + + async def _handle_remove_ignored_channels(self, ctx, setting=None): + if setting is None: + raise WrongSettingType("Specifying which channel you want to remove is required") + + channel = await self._get_channel(ctx, setting) + + query = """ +UPDATE + guilds +SET + ignored_channels = array_remove(ignored_channels, $1) +WHERE + id=$2 +""" + return await ctx.bot.db.execute(query, channel.id, ctx.guild.id) + + async def _handle_remove_ignored_members(self, ctx, setting=None): + if setting is None: + raise WrongSettingType("Specifying which channel you want to remove is required") + # We want to make it possible to have members that aren't in the server ignored + # So first check if it's a digit (the id) + if not setting.isdigit(): + converter = commands.converter.MemberConverter() + member = await converter.convert(ctx, setting) + setting = member.id + + query = """ +UPDATE + guilds +SET + ignored_members = array_remove(ignored_members, $1) +WHERE + id=$2 +""" + return await ctx.bot.db.execute(query, setting, ctx.guild.id) + + async def _handle_remove_rules(self, ctx, setting=None): + if setting is None or not setting.isdigit(): + raise WrongSettingType("Please provide the number of the rule you want to remove") + + query = """ +UPDATE + guilds +SET + rules = array_remove(rules, rules[$1]) +WHERE + id=$2 +""" + return await ctx.bot.db.execute(query, setting, ctx.guild.id) + + async def _handle_remove_assignable_roles(self, ctx, setting=None): + if setting is None: + raise WrongSettingType("Specifying which channel you want to remove is required") + if not setting.isdigit(): + converter = commands.converter.RoleConverter() + role = await converter.convert(ctx, setting) + setting = role.id + + query = """ +UPDATE + guilds +SET + assignable_roles = array_remove(assignable_roles, $1) +WHERE + id=$2 +""" + return await ctx.bot.db.execute(query, setting, ctx.guild.id) + + async def _handle_remove_custom_battles(self, ctx, setting=None): + if setting is None or not setting.isdigit(): + raise WrongSettingType("Please provide the number of the custom battle you want to remove") + + query = """ +UPDATE + guilds +SET + custom_battles = array_remove(custom_battles, rules[$1]) +WHERE + id=$2 +""" + return await ctx.bot.db.execute(query, setting, ctx.guild.id) + + async def _handle_remove_custom_hugs(self, ctx, setting=None): + if setting is None or not setting.isdigit(): + raise WrongSettingType("Please provide the number of the custom hug you want to remove") + + query = """ +UPDATE + guilds +SET + custom_hugs = array_remove(custom_hugs, rules[$1]) +WHERE + id=$2 +""" + return await ctx.bot.db.execute(query, setting, ctx.guild.id) + + async def __after_invoke(self, ctx): + """Here we will facilitate cleaning up settings, will remove channels/roles that no longer exist, etc.""" + pass + + @commands.group(invoke_without_command=True) + @commands.guild_only() + @utils.can_run(manage_guild=True) + async def config(self, ctx, *, opt=None): + """Handles the configuration of the bot for this server""" + if opt: + setting = await ctx.bot.db.fetchrow("SELECT * FROM guilds WHERE id=$1", ctx.guild.id) + if setting and opt in setting: + setting = await utils.convert(ctx, str(setting[opt])) or setting[opt] + + await ctx.send(f"{opt} is set to:\n{setting}") + return + + settings = await ctx.bot.db.fetchrow("SELECT * FROM guilds WHERE id=$1", ctx.guild.id) + + # For convenience, if it's None, just create it and return the default values + if settings is None: + await ctx.bot.db.execute("INSERT INTO guilds (id) VALUES ($1)", ctx.guild.id) + settings = await ctx.bot.db.fetchrow("SELECT * FROM guilds WHERE id=$1", ctx.guild.id) + + alerts = {} + # This is dirty I know, but oh well... + for alert_type in ["default", "welcome", "goodbye", "picarto", "birthday", "raffle"]: + channel = ctx.guild.get_channel(settings.get(f"{alert_type}_alerts")) + name = channel.name if channel else None + alerts[alert_type] = name + + fmt = f""" +**Notification Settings** + birthday_notifications + *Notify on the birthday that users in this guild have saved* + **{settings.get("birthday_notifications")}** + + welcome_notifications + *Notify when someone has joined this guild* + **{settings.get("welcome_notifications")}** + + goodbye_notifications + *Notify when someone has left this guild + **{settings.get("goodbye_notifications")}** + + welcome_msg + *A message that can be customized and used when someone joins the server* + **{"Set" if settings.get("welcome_msg") is not None else "Not set"}** + + goodbye_msg + *A message that can be customized and used when someone leaves the server* + **{"Set" if settings.get("goodbye_msg") is not None else "Not set"}** + +**Alert Channels** + default_alerts + *The channel to default alert messages to* + **{alerts.get("default_alerts")}** + + welcome_alerts + *The channel to send welcome alerts to (when someone joins the server)* + **{alerts.get("welcome_alerts")}** + + goodbye_alerts + *The channel to send goodbye alerts to (when someone leaves the server)* + **{alerts.get("goodbye_alerts")}** + + picarto_alerts + *The channel to send Picarto alerts to (when a channel the server follows goes on/offline)* + **{alerts.get("picarto_alerts")}** + + birthday_alerts + *The channel to send birthday alerts to (on the day of someone's birthday)* + **{alerts.get("birthday_alerts")}** + + raffle_alerts + *The channel to send alerts for server raffles to* + **{alerts.get("raffle_alerts")}** + + +**Misc Settings** + followed_picarto_channels + *Channels for the bot to "follow" and notify this server when they go live* + **{len(settings.get("followed_picarto_channels"))}** + + ignored_channels + *Channels that the bot ignores* + **{len(settings.get("ignored_channels"))}** + + ignored_members + *Members that the bot ignores* + **{len(settings.get("ignored_members"))}** + + rules + *Rules for this server* + **{len(settings.get("rules"))}** + + assignable_roles + *Roles that can be self-assigned by users* + **{len(settings.get("assignable_roles"))}** + + custom_battles + *Possible outcomes to battles that can be received on this server* + **{len(settings.get("custom_battles"))}** + + custom_hugs + *Possible outcomes to hugs that can be received on this server* + **{len(settings.get("custom_hugs"))}** +""".strip() + + embed = discord.Embed(title=f"Configuration for {ctx.guild.name}", description=fmt) + embed.set_image(url=ctx.guild.icon_url) + await ctx.send(embed=embed) + + @config.command(name="set", aliases=["add"]) + @commands.guild_only() + @utils.can_run(manage_guild=True) + async def _set_setting(self, ctx, option, *, setting): + """Sets one of the configuration settings for this server""" + try: + coro = getattr(self, f"_handle_set_{option}") + except AttributeError: + await ctx.send(f"{option} is not a valid config option. Use {ctx.prefix}config to list all config options") + else: + try: + await coro(ctx, setting=setting) + except WrongSettingType as exc: + await ctx.send(exc.message) + except MessageFormatError as exc: + fmt = f""" +Failed to parse the format string provided, possible keys are: {', '.join(k for k in exc.keys)} +Extraneous args provided: {', '.join(k for k in exc.original.args)} +""" + await ctx.send(fmt) + except commands.BadArgument: + pass + else: + await ctx.send(f"{option} has succesfully been set to {setting}") + + @config.command(name="unset", aliases=["remove"]) + @commands.guild_only() + @utils.can_run(manage_guild=True) + async def _remove_setting(self, ctx, option, *, setting=None): + """Unsets/removes an option from one of the settings.""" + try: + coro = getattr(self, f"_handle_remove_{option}") + except AttributeError: + await ctx.send(f"{option} is not a valid config option. Use {ctx.prefix}config to list all config options") + else: + try: + await coro(ctx, setting=setting) + except WrongSettingType as exc: + await ctx.send(exc.message) + except commands.BadArgument: + pass + else: + await ctx.send(f"{option} has succesfully been unset") + + +def setup(bot): + bot.add_cog(GuildConfiguration()) diff --git a/cogs/events.py b/cogs/events.py index c1c9c2c..6f7c959 100644 --- a/cogs/events.py +++ b/cogs/events.py @@ -69,64 +69,46 @@ class StatsUpdate: await self.update() async def on_member_join(self, member): - guild = member.guild - server_settings = self.bot.db.load('server_settings', key=str(guild.id)) - + query = """ +SELECT + COALESCE(welcome_alerts, default_alerts) AS channel, + welcome_msg AS msg +FROM + guilds +WHERE + welcome_notifications = True +AND + id = $1 +AND + COALESCE(welcome_alerts, default_alerts) IS NOT NULL +""" + settings = await self.bot.db.fetchrow(query, member.guild.id) + message = settings['msg'] or "Welcome to the '{server}' server {member}!" + channel = member.guild.get_channel(settings['channel']) try: - join_leave_on = server_settings['join_leave'] - if join_leave_on: - # Get the notifications settings, get the welcome setting - notifications = self.bot.db.load('server_settings', key=guild.id, pluck='notifications') or {} - # Set our default to either the one set, or the default channel of the server - default_channel_id = notifications.get('default') - # If it is has been overriden by picarto notifications setting, use this - channel_id = notifications.get('welcome') or default_channel_id - # Get the message if it exists - join_message = self.bot.db.load('server_settings', key=guild.id, pluck='welcome_message') - if not join_message: - join_message = "Welcome to the '{server}' server {member}!" - else: - return - except (IndexError, TypeError, KeyError): - return - - if channel_id: - channel = guild.get_channel(int(channel_id)) - else: - return - try: - await channel.send(join_message.format(server=guild.name, member=member.mention)) + await channel.send(message.format(server=member.guild.name, member=member.mention)) except (discord.Forbidden, discord.HTTPException, AttributeError): pass async def on_member_remove(self, member): - guild = member.guild - server_settings = self.bot.db.load('server_settings', key=str(guild.id)) - + query = """ +SELECT + COALESCE(goodbye_alerts, default_alerts) AS channel, + goodbye_msg AS msg +FROM + guilds +WHERE + welcome_notifications = True +AND + id = $1 +AND + COALESCE(goodbye_alerts, default_alerts) IS NOT NULL +""" + settings = await self.bot.db.fetchrow(query, member.guild.id) + message = settings['msg'] or "{member} has left the server, I hope it wasn't because of something I said :c" + channel = member.guild.get_channel(settings['channel']) try: - join_leave_on = server_settings['join_leave'] - if join_leave_on: - # Get the notifications settings, get the welcome setting - notifications = self.bot.db.load('server_settings', key=guild.id, pluck='notifications') or {} - # Set our default to either the one set, or the default channel of the server - default_channel_id = notifications.get('default') - # If it is has been overriden by picarto notifications setting, use this - channel_id = notifications.get('welcome') or default_channel_id - # Get the message if it exists - leave_message = self.bot.db.load('server_settings', key=guild.id, pluck='goodbye_message') - if not leave_message: - leave_message = "{member} has left the server, I hope it wasn't because of something I said :c" - else: - return - except (IndexError, TypeError, KeyError): - return - - if channel_id: - channel = guild.get_channel(int(channel_id)) - else: - return - try: - await channel.send(leave_message.format(server=guild.name, member=member.name)) + await channel.send(message.format(server=member.guild.name, member=member.mention)) except (discord.Forbidden, discord.HTTPException, AttributeError): pass diff --git a/cogs/hangman.py b/cogs/hangman.py index 781ffe9..74a1b32 100644 --- a/cogs/hangman.py +++ b/cogs/hangman.py @@ -70,6 +70,7 @@ class Hangman: def __init__(self, bot): self.bot = bot self.games = {} + self.pending_games = [] def create(self, word, ctx): # Create a new game, then save it as the server's game @@ -101,7 +102,7 @@ class Hangman: # We're creating a fmt variable, so that we can add a message for if a guess was correct or not # And also add a message for a loss/win if len(guess) == 1: - if guess in game.guessed_letters: + if guess.lower() in game.guessed_letters: ctx.command.reset_cooldown(ctx) await ctx.send("That letter has already been guessed!") # Return here as we don't want to count this as a failure @@ -142,6 +143,9 @@ class Hangman: if self.games.get(ctx.message.guild.id) is not None: await ctx.send("Sorry but only one Hangman game can be running per server!") return + if ctx.guild.id in self.pending_games: + await ctx.send("Someone has already started one, and I'm now waiting for them...") + return try: msg = await ctx.message.author.send( @@ -160,12 +164,16 @@ class Hangman: def check(m): return m.channel == msg.channel and len(m.content) <= 30 + self.pending_games.append(ctx.guild.id) try: msg = await self.bot.wait_for('message', check=check, timeout=60) except asyncio.TimeoutError: + self.pending_games.remove(ctx.guild.id) await ctx.send( "You took too long! Please look at your DM's as that's where I'm asking for the phrase you want to use") return + else: + self.pending_games.remove(ctx.guild.id) forbidden_phrases = ['stop', 'delete', 'remove', 'end', 'create', 'start'] if msg.content in forbidden_phrases: diff --git a/cogs/images.py b/cogs/images.py index 3e6d17e..c9bd3c5 100644 --- a/cogs/images.py +++ b/cogs/images.py @@ -41,7 +41,7 @@ class Images: result = await utils.request('https://random.dog/woof.json') try: url = result.get("url") - filename = re.match("https:\/\/random.dog\/(.*)", url).group(1) + filename = re.match("https://random.dog/(.*)", url).group(1) except AttributeError: await ctx.send("I couldn't connect! Sorry no dogs right now ;w;") return @@ -130,7 +130,6 @@ class Images: EXAMPLE: !derpi Rainbow Dash RESULT: A picture of Rainbow Dash!""" - await ctx.message.channel.trigger_typing() if len(search) > 0: url = 'https://derpibooru.org/search.json' @@ -139,7 +138,7 @@ class Images: query = ' '.join(value for value in search if not re.search('&?filter_id=[0-9]+', value)) params = {'q': query} - nsfw = await utils.channel_is_nsfw(ctx.message.channel, self.bot.db) + nsfw = utils.channel_is_nsfw(ctx.message.channel) # If this is a nsfw channel, we just need to tack on 'explicit' to the terms # Also use the custom filter that I have setup, that blocks some certain tags # If the channel is not nsfw, we don't need to do anything, as the default filter blocks explicit @@ -200,7 +199,6 @@ class Images: EXAMPLE: !e621 dragon RESULT: A picture of a dragon (hopefully, screw your tagging system e621)""" - await ctx.message.channel.trigger_typing() # This changes the formatting for queries, so we don't # Have to use e621's stupid formatting when using the command @@ -214,7 +212,7 @@ class Images: 'tags': tags } - nsfw = await utils.channel_is_nsfw(ctx.message.channel, self.bot.db) + nsfw = utils.channel_is_nsfw(ctx.message.channel) # e621 by default does not filter explicit content, so tack on # safe/explicit based on if this channel is nsfw or not diff --git a/cogs/interaction.py b/cogs/interaction.py index dd47eef..4cbe227 100644 --- a/cogs/interaction.py +++ b/cogs/interaction.py @@ -1,13 +1,12 @@ -import rethinkdb as r from discord.ext import commands from discord.ext.commands.cooldowns import BucketType +from collections import defaultdict import utils import discord import random import functools -import asyncio battle_outcomes = \ ["A meteor fell on {loser}, {winner} is left standing and has been declared the victor!", @@ -91,68 +90,36 @@ class Interaction: def __init__(self, bot): self.bot = bot - self.battles = {} - self.bot.br = BattleRankings(self.bot) - self.bot.br.update_start() + self.battles = defaultdict(list) - def get_battle(self, player): - battles = self.battles.get(player.guild.id) - - if battles is None: - return None - - for battle in battles: - if battle['p2'] == player.id: + def get_receivers_battle(self, receiver): + for battle in self.battles.get(receiver.guild.id, []): + if battle.is_receiver(receiver): return battle - def can_battle(self, player): - battles = self.battles.get(player.guild.id) - - if battles is None: - return True - - for x in battles: - if x['p1'] == player.id: + def can_initiate_battle(self, player): + for battle in self.battles.get(player.guild.id, []): + if battle.is_initiator(player): return False return True - def can_be_battled(self, player): - battles = self.battles.get(player.guild.id) - - if battles is None: - return True - - for x in battles: - if x['p2'] == player.id: + def can_receive_battle(self, player): + for battle in self.battles.get(player.guild.id, []): + if battle.is_receiver(player): return False return True - def start_battle(self, player1, player2): - battles = self.battles.get(player1.guild.id, []) - entry = { - 'p1': player1.id, - 'p2': player2.id - } - battles.append(entry) - self.battles[player1.guild.id] = battles + def start_battle(self, initiator, receiver): + battle = Battle(initiator, receiver) + self.battles[initiator.guild.id].append(battle) + return battle # Handles removing the author from the dictionary of battles - def battling_off(self, player1=None, player2=None): - if player1: - guild = player1.guild.id - else: - guild = player2.guild.id - battles = self.battles.get(guild, []) - # Create a new list, exactly the way the last one was setup - # But don't include the one start with player's ID - new_battles = [] - for b in battles: - if player1 and b['p1'] == player1.id: - continue - if player2 and b['p2'] == player2.id: - continue - new_battles.append(b) - self.battles[guild] = new_battles + def battling_off(self, battle): + for guild, battles in self.battles.items(): + if battle in battles: + battles.remove(battle) + return @commands.command() @commands.guild_only() @@ -166,7 +133,7 @@ class Interaction: await ctx.send("Your arms aren't big enough") return if user is None: - user = ctx.message.author + user = ctx.author else: converter = commands.converter.MemberConverter() try: @@ -175,12 +142,12 @@ class Interaction: await ctx.send("Error: Could not find user: {}".format(user)) return - # Lets get the settings - settings = self.bot.db.load('server_settings', key=ctx.message.guild.id) or {} - # Get the custom messages we can use - custom_msgs = settings.get('hugs') - default_on = settings.get('default_hugs') - # if they exist, then we want to see if we want to use default as well + settings = await self.bot.db.fetchrow( + "SELECT custom_hugs, include_default_hugs FROM guilds WHERE id = $1", + ctx.guild.id + ) + custom_msgs = settings["custom_hugs"] + default_on = settings["include_default_hugs"] if custom_msgs: if default_on or default_on is None: msgs = hugs + custom_msgs @@ -205,7 +172,7 @@ class Interaction: # First check if everyone was mentioned if ctx.message.mention_everyone: await ctx.send("You want to battle {} people? Good luck with that...".format( - len(ctx.message.channel.members) - 1) + len(ctx.channel.members) - 1) ) return # Then check if nothing was provided @@ -221,7 +188,7 @@ class Interaction: await ctx.send("Error: Could not find user: {}".format(player2)) return # Then check if the person used is the author - if ctx.message.author.id == player2.id: + if ctx.author.id == player2.id: ctx.command.reset_cooldown(ctx) await ctx.send("Why would you want to battle yourself? Suicide is not the answer") return @@ -231,24 +198,24 @@ class Interaction: await ctx.send("I always win, don't even try it.") return # Next two checks are to see if the author or person battled can be battled - if not self.can_battle(ctx.message.author): + if not self.can_initiate_battle(ctx.author): ctx.command.reset_cooldown(ctx) await ctx.send("You are already battling someone!") return - if not self.can_be_battled(player2): + if not self.can_receive_battle(player2): ctx.command.reset_cooldown(ctx) await ctx.send("{} is already being challenged to a battle!".format(player2)) return # Add the author and player provided in a new battle - self.start_battle(ctx.message.author, player2) + battle = self.start_battle(ctx.author, player2) - fmt = "{0.message.author.mention} has challenged you to a battle {1.mention}\n" \ - "{0.prefix}accept or {0.prefix}decline" + fmt = f"{ctx.author.mention} has challenged you to a battle {player2.mention}\n" \ + f"{ctx.prefix}accept or {ctx.prefix}decline" # Add a call to turn off battling, if the battle is not accepted/declined in 3 minutes - part = functools.partial(self.battling_off, player1=ctx.message.author) + part = functools.partial(self.battling_off, battle) self.bot.loop.call_later(180, part) - await ctx.send(fmt.format(ctx, player2)) + await ctx.send(fmt) @commands.command() @commands.guild_only() @@ -260,23 +227,23 @@ class Interaction: RESULT: Hopefully the other person's death""" # This is a check to make sure that the author is the one being BATTLED # And not the one that started the battle - battle = self.get_battle(ctx.message.author) + battle = self.get_receivers_battle(ctx.author) if battle is None: await ctx.send("You are not currently being challenged to a battle!") return - battleP1 = discord.utils.find(lambda m: m.id == battle['p1'], ctx.message.guild.members) - if battleP1 is None: + if ctx.guild.get_member(battle.initiator.id) is None: await ctx.send("The person who challenged you to a battle has apparently left the server....why?") + self.battling_off(battle) return - battleP2 = ctx.message.author - # Lets get the settings - settings = self.bot.db.load('server_settings', key=ctx.message.guild.id) or {} - # Get the custom messages we can use - custom_msgs = settings.get('battles') - default_on = settings.get('default_battles') + settings = await self.bot.db.fetchrow( + "SELECT custom_battles, include_default_battles FROM guilds WHERE id = $1", + ctx.guild.id + ) + custom_msgs = settings["custom_battles"] + default_on = settings["include_default_battles"] # if they exist, then we want to see if we want to use default as well if custom_msgs: if default_on or default_on is None: @@ -289,46 +256,106 @@ class Interaction: fmt = random.SystemRandom().choice(msgs) # Due to our previous checks, the ID should only be in the dictionary once, in the current battle we're checking - self.battling_off(player2=ctx.message.author) - await self.bot.br.update() + self.battling_off(battle) # Randomize the order of who is printed/sent to the update system - if random.SystemRandom().randint(0, 1): - winner = battleP1 - loser = battleP2 + winner, loser = battle.choose() + + member_list = [m.id for m in ctx.guild.members] + query = """ +SELECT id, rank, battle_rating, battle_wins, battle_losses +FROM + (SELECT + id, + ROW_NUMBER () OVER (ORDER BY battle_rating DESC) as "rank", + battle_rating, + battle_wins, + battle_losses + FROM + users + WHERE + id = any($1::bigint[]) AND + battle_rating IS NOT NULL + ) AS sub +WHERE id = any($2) + """ + results = await self.bot.db.fetch(query, member_list, [winner.id, loser.id]) + + old_winner = old_loser = None + for result in results: + if result['id'] == loser.id: + old_loser = result + else: + old_winner = result + + winner_rating, loser_rating, = utils.update_rating( + old_winner["battle_rating"] if old_winner else 1000, + old_loser["battle_rating"] if old_loser else 1000, + ) + print(old_winner, old_loser) + + update_query = """ +UPDATE + users +SET + battle_rating = $1, + battle_wins = $2, + battle_losses = $3 +WHERE + id=$4 +""" + insert_query = """ +INSERT INTO + users (id, battle_rating, battle_wins, battle_losses) +VALUES + ($1, $2, $3, $4) +""" + if old_loser: + await self.bot.db.execute( + update_query, + loser_rating, + old_loser['battle_wins'], + old_loser['battle_losses'] + 1, + loser.id + ) else: - winner = battleP2 - loser = battleP1 + await self.bot.db.execute(insert_query, loser.id, loser_rating, 0, 1) + if old_winner: + await self.bot.db.execute( + update_query, + winner_rating, + old_winner['battle_wins'] + 1, + old_winner['battle_losses'] , + winner.id + ) + else: + await self.bot.db.execute(insert_query, winner.id, winner_rating, 1, 0) - msg = await ctx.send(fmt.format(winner=winner.display_name, loser=loser.display_name)) - old_winner_rank, _ = self.bot.br.get_server_rank(winner) - old_loser_rank, _ = self.bot.br.get_server_rank(loser) + results = await self.bot.db.fetch(query, member_list, [winner.id, loser.id]) + print(results) - # Update our records; this will update our cache - await utils.update_records('battle_records', self.bot.db, winner, loser) - # Now wait a couple seconds to ensure cache is updated - await asyncio.sleep(2) - await self.bot.br.update() + new_winner_rank = new_loser_rank = None + for result in results: + if result['id'] == loser.id: + new_loser_rank = result['rank'] + else: + new_winner_rank = result['rank'] - # Now get the new ranks after this stuff has been updated - new_winner_rank, _ = self.bot.br.get_server_rank(winner) - new_loser_rank, _ = self.bot.br.get_server_rank(loser) - fmt = msg.content - if old_winner_rank: + fmt = fmt.format(winner=winner.display_name, loser=loser.display_name) + if old_winner: fmt += "\n{} - Rank: {} ( +{} )".format( - winner.display_name, new_winner_rank, old_winner_rank - new_winner_rank + winner.display_name, new_winner_rank, old_winner["rank"] - new_winner_rank ) else: fmt += "\n{} - Rank: {}".format(winner.display_name, new_winner_rank) - if old_loser_rank: - fmt += "\n{} - Rank: {} ( -{} )".format(loser.display_name, new_loser_rank, new_loser_rank - old_loser_rank) + if old_winner: + fmt += "\n{} - Rank: {} ( -{} )".format( + loser.display_name, new_loser_rank, new_loser_rank - old_winner["rank"] + ) else: fmt += "\n{} - Rank: {}".format(loser.display_name, new_loser_rank) - try: - await msg.edit(content=fmt) - except Exception: - pass + await ctx.send(fmt) @commands.command() @commands.guild_only() @@ -340,21 +367,13 @@ class Interaction: RESULT: You chicken out""" # This is a check to make sure that the author is the one being BATTLED # And not the one that started the battle - battle = self.get_battle(ctx.message.author) + battle = self.get_receivers_battle(ctx.author) if battle is None: await ctx.send("You are not currently being challenged to a battle!") return - battleP1 = discord.utils.find(lambda m: m.id == battle['p1'], ctx.message.guild.members) - if battleP1 is None: - await ctx.send("The person who challenged you to a battle has apparently left the server....why?") - return - - battleP2 = ctx.message.author - - # There's no need to update the stats for the members if they declined the battle - self.battling_off(player2=battleP2) - await ctx.send("{} has chickened out! What a loser~".format(battleP2.mention)) + self.battling_off(battle) + await ctx.send("{} has chickened out! What a loser~".format(ctx.author.mention)) @commands.command() @commands.guild_only() @@ -365,7 +384,7 @@ class Interaction: EXAMPLE: !boop @OtherPerson RESULT: You do a boop o3o""" - booper = ctx.message.author + booper = ctx.author if boopee is None: ctx.command.reset_cooldown(ctx) await ctx.send("You try to boop the air, the air boops back. Be afraid....") @@ -382,64 +401,40 @@ class Interaction: await ctx.send("Why the heck are you booping me? Get away from me >:c") return - key = str(booper.id) - boops = self.bot.db.load('boops', key=key, pluck='boops') or {} - amount = boops.get(str(boopee.id), 0) + 1 - entry = { - 'member_id': str(booper.id), - 'boops': { - str(boopee.id): amount - } - } - await self.bot.db.save('boops', entry) + query = "SELECT amount FROM boops WHERE booper = $1 AND boopee = $2" + amount = await self.bot.db.fetchrow(query, booper.id, boopee.id) + if amount is None: + amount = 1 + replacement_query = "INSERT INTO boops (booper, boopee, amount) VALUES($1, $2, $3)" + else: + replacement_query = "UPDATE boops SET amount=$3 WHERE booper=$1 AND boopee=$2" + amount = amount['amount'] + 1 - fmt = "{0.mention} has just booped {1.mention}{3}! That's {2} times now!" - await ctx.send(fmt.format(booper, boopee, amount, message)) + await ctx.send(f"{booper.mention} has just booped {boopee.mention}{message}! That's {amount} times now!") + await self.bot.db.execute(replacement_query, booper.id, boopee.id, amount) -# noinspection PyMethodMayBeStatic -class BattleRankings: - def __init__(self, bot): - self.db = bot.db - self.loop = bot.loop - self.ratings = None +class Battle: - def build_dict(self, seq, key): - return dict((d[key], dict(d, rank=index + 1)) for (index, d) in enumerate(seq[::-1])) + def __init__(self, initiator, receiver): + self.initiator = initiator + self.receiver = receiver + self.rand = random.SystemRandom() - def update_start(self): - self.loop.create_task(self.update()) + def is_initiator(self, player): + return player.id == self.initiator.id and player.guild.id == self.initiator.guild.id - async def update(self): - ratings = await self.db.query(r.table('battle_records').order_by('rating')) + def is_receiver(self, player): + return player.id == self.receiver.id and player.guild.id == self.receiver.guild.id - # Create a dictionary so that we have something to "get" from easily - self.ratings = self.build_dict(ratings, 'member_id') + def is_battling(self, player): + return self.is_initiator(player) or self.is_receiver(player) - def get_record(self, member): - data = self.ratings.get(str(member.id), {}) - fmt = "{} - {}".format(data.get('wins'), data.get('losses')) - return fmt - - def get_rating(self, member): - data = self.ratings.get(str(member.id), {}) - return data.get('rating') - - def get_rank(self, member): - data = self.ratings.get(str(member.id), {}) - return data.get('rank'), len(self.ratings) - - def get_server_rank(self, member): - # Get the id's of all the members to compare to - server_ids = [str(m.id) for m in member.guild.members] - # Get all the ratings for members in this server - ratings = [x for x in self.ratings.values() if x['member_id'] in server_ids] - # Since we went from a dictionary to a list, we're no longer sorted, sort this - ratings = sorted(ratings, key=lambda x: x['rating']) - # Build our dictionary to get correct rankings - server_ratings = self.build_dict(ratings, 'member_id') - # Return the rank - return server_ratings.get(str(member.id), {}).get('rank'), len(server_ratings) + def choose(self): + """Returns the two users in the order winner, loser""" + choices = [self.initiator, self.receiver] + self.rand.shuffle(choices) + return choices def setup(bot): diff --git a/cogs/links.py b/cogs/links.py index 2f18eba..fc8b8be 100644 --- a/cogs/links.py +++ b/cogs/links.py @@ -22,12 +22,10 @@ class Links: EXAMPLE: !g Random cat pictures! RESULT: Links to sites with random cat pictures!""" - await ctx.message.channel.trigger_typing() - url = "https://www.google.com/search" # Turn safe filter on or off, based on whether or not this is a nsfw channel - nsfw = await utils.channel_is_nsfw(ctx.message.channel, self.bot.db) + nsfw = utils.channel_is_nsfw(ctx.message.channel) safe = 'off' if nsfw else 'on' params = {'q': query, @@ -76,8 +74,6 @@ class Links: EXAMPLE: !youtube Cat videos! RESULT: Cat videos!""" - await ctx.message.channel.trigger_typing() - key = utils.youtube_key url = "https://www.googleapis.com/youtube/v3/search" params = {'key': key, @@ -111,8 +107,6 @@ class Links: EXAMPLE: !wiki Test RESULT: A link to the wikipedia article for the word test""" - await ctx.message.channel.trigger_typing() - # All we need to do is search for the term provided, so the action, list, and format never need to change base_url = "https://en.wikipedia.org/w/api.php" params = {"action": "query", @@ -150,9 +144,7 @@ class Links: EXAMPLE: !urban a normal phrase RESULT: Probably something lewd; this is urban dictionary we're talking about""" - if await utils.channel_is_nsfw(ctx.message.channel, self.bot.db): - await ctx.message.channel.trigger_typing() - + if utils.channel_is_nsfw(ctx.message.channel): url = "http://api.urbandictionary.com/v0/define" params = {"term": msg} try: diff --git a/cogs/misc.py b/cogs/misc.py index 3fc1d0d..432481e 100644 --- a/cogs/misc.py +++ b/cogs/misc.py @@ -11,6 +11,33 @@ import datetime import psutil +def _command_signature(cmd): + result = [cmd.qualified_name] + if cmd.usage: + result.append(cmd.usage) + return ' '.join(result) + + params = cmd.clean_params + if not params: + return ' '.join(result) + + for name, param in params.items(): + if param.default is not param.empty: + # We don't want None or '' to trigger the [name=value] case and instead it should + # do [name] since [name=None] or [name=] are not exactly useful for the user. + should_print = param.default if isinstance(param.default, str) else param.default is not None + if should_print: + result.append(f'[{name}={param.default!r}]') + else: + result.append(f'[{name}]') + elif param.kind == param.VAR_POSITIONAL: + result.append(f'[{name}...]') + else: + result.append(f'<{name}>') + + return ' '.join(result) + + class Miscallaneous: """Core commands, these are the miscallaneous commands that don't fit into other categories'""" @@ -19,32 +46,6 @@ class Miscallaneous: self.process = psutil.Process() self.process.cpu_percent() - def _command_signature(self, cmd): - result = [cmd.qualified_name] - if cmd.usage: - result.append(cmd.usage) - return ' '.join(result) - - params = cmd.clean_params - if not params: - return ' '.join(result) - - for name, param in params.items(): - if param.default is not param.empty: - # We don't want None or '' to trigger the [name=value] case and instead it should - # do [name] since [name=None] or [name=] are not exactly useful for the user. - should_print = param.default if isinstance(param.default, str) else param.default is not None - if should_print: - result.append(f'[{name}={param.default!r}]') - else: - result.append(f'[{name}]') - elif param.kind == param.VAR_POSITIONAL: - result.append(f'[{name}...]') - else: - result.append(f'<{name}>') - - return ' '.join(result) - @commands.command() @commands.cooldown(1, 3, commands.cooldowns.BucketType.user) @utils.can_run(send_messages=True) @@ -75,7 +76,8 @@ class Miscallaneous: if entity: entity = self.bot.get_cog(entity) or self.bot.get_command(entity) if entity is None: - fmt = "Hello! Here is a list of the sections of commands that I have (there are a lot of commands so just start with the sections...I know, I'm pretty great)\n" + fmt = "Hello! Here is a list of the sections of commands that I have " \ + "(there are a lot of commands so just start with the sections...I know, I'm pretty great)\n" fmt += "To use a command's paramaters, you need to know the notation for them:\n" fmt += "\t This means the argument is __**required**__.\n" fmt += "\t[argument] This means the argument is __**optional**__.\n" @@ -96,7 +98,7 @@ class Miscallaneous: else: chunks[len(chunks) - 1] += tmp elif isinstance(entity, (commands.core.Command, commands.core.Group)): - tmp = "**{}**".format(self._command_signature(entity)) + tmp = "**{}**".format(_command_signature(entity)) tmp += "\n{}".format(entity.help) chunks.append(tmp) else: diff --git a/cogs/osu.py b/cogs/osu.py index 6f9a331..775d7fb 100644 --- a/cogs/osu.py +++ b/cogs/osu.py @@ -46,15 +46,13 @@ class Osu: async def get_users(self): """A task used to 'cache' all member's and their Osu profile's""" - data = await self.bot.db.actual_load('osu') - if data is None: - return + query = "SELECT id, osu FROM users WHERE osu IS NOT NULL;" + rows = await self.bot.db.fetch(query) - for result in data: - member = int(result['member_id']) - user = await self.get_user_from_api(result['osu_username']) + for row in rows: + user = await self.get_user_from_api(row['osu']) if user: - self.osu_users[member] = user + self.osu_users[row['id']] = user @commands.group(invoke_without_command=True) @utils.can_run(send_messages=True) @@ -63,7 +61,7 @@ class Osu: EXAMPLE: !osu @Person RESULT: Informationa bout that person's osu account""" - await ctx.message.channel.trigger_typing() + if member is None: member = ctx.message.author @@ -95,21 +93,19 @@ class Osu: EXAMPLE: !osu add username RESULT: Links your username to your account, and allows stats to be pulled from it""" - await ctx.message.channel.trigger_typing() + author = ctx.message.author user = await self.get_user(author, username) if user is None: await ctx.send("I couldn't find an osu user that matches {}".format(username)) return - entry = { - 'member_id': str(author.id), - 'osu_username': user.username - } - - await self.bot.db.save('osu', entry) - await ctx.send("I have just saved your Osu user {}".format(author.display_name)) + update = { + "id": author.id, + "osu": user.username + } + await self.bot.db.upsert("users", update) @osu.command(name='score', aliases=['scores']) @utils.can_run(send_messages=True) @@ -119,7 +115,7 @@ class Osu: EXAMPLE: !osu scores @Person 5 RESULT: The top 5 maps for the user @Person""" - await ctx.message.channel.trigger_typing() + # Set the defaults before we go through our passed data to figure out what we want limit = 5 member = ctx.message.author @@ -135,7 +131,7 @@ class Osu: limit = 50 elif limit < 1: limit = 5 - except: + except Exception: converter = commands.converter.MemberConverter() try: member = await converter.convert(ctx, piece) diff --git a/cogs/overwatch.py b/cogs/overwatch.py index 5ad6e9c..788bd1b 100644 --- a/cogs/overwatch.py +++ b/cogs/overwatch.py @@ -38,8 +38,6 @@ class Overwatch: EXAMPLE: !ow stats @OtherPerson Junkrat RESULT: Whether or not you should unfriend this person because they're a dirty rat""" - await ctx.message.channel.trigger_typing() - user = user or ctx.message.author bt = self.bot.db.load('overwatch', key=str(user.id), pluck='battletag') @@ -99,7 +97,7 @@ class Overwatch: EXAMPLE: !ow add Username#1234 RESULT: Your battletag is now saved""" - await ctx.message.channel.trigger_typing() + # Battletags are normally provided like name#id # However the API needs this to be a -, so repliace # with - if it exists diff --git a/cogs/picarto.py b/cogs/picarto.py index 8152c32..f742d88 100644 --- a/cogs/picarto.py +++ b/cogs/picarto.py @@ -1,21 +1,33 @@ import asyncio import discord -import re import traceback -from discord.ext import commands - import utils BASE_URL = 'https://api.picarto.tv/v1' +def produce_embed(*channels): + description = "" + # Loop through each channel and produce the information that will go in the description + for channel in channels: + url = f"https://picarto.tv/{channel.get('name')}" + description = f"""{description}\n\n**Title:** [{channel.get("title")}]({url}) +**Channel:** [{channel.get("name")}]({url}) +**Adult:** {"Yes" if channel.get("adult") else "No"} +**Gaming:** {"Yes" if channel.get("gaming") else "No"} +**Commissions:** {"Yes" if channel.get("commissions") else "No"}""" + + return discord.Embed(title="Channels that have gone online!", description=description.strip()) + + class Picarto: """Pretty self-explanatory""" def __init__(self, bot): self.bot = bot self.task = self.bot.loop.create_task(self.picarto_task()) + self.channel_info = {} # noinspection PyAttributeOutsideInit async def get_online_users(self): @@ -25,49 +37,38 @@ class Picarto: 'adult': 'true', 'gaming': 'true' } - self.online_channels = await utils.request(url, payload=payload) + channel_info = {} + channels = await utils.request(url, payload=payload) + if channels: + for channel in channels: + name = channel["name"] + previous = self.channel_info.get("name") + # There are three statuses, on, remained, and off + # On means they were off previously, but are now online + # Remained means they were on previous, and are still on + # Off means they were on preivous, but are now offline + # If they weren't included in the online channels...well they're off + if previous is None: + channel_info[name] = channel + channel_info[name]["status"] = "on" + elif previous["status"] in ["on", "remaining"]: + channel_info[name] = channel + channel_info[name]["status"] = "remaining" + # After loop has finished successfully, we want to override the statuses of the channels + self.channel_info = channel_info - async def channel_embed(self, channel): - # First make sure the picarto URL is actually given - if not channel: - return None - # Use regex to get the actual username so that we can make a request to the API - stream = re.search("(?<=picarto.tv/)(.*)", channel).group(1) - url = BASE_URL + '/channel/name/{}'.format(stream) + def produce_embed(self, *channels): + description = "" + # Loop through each channel and produce the information that will go in the description + for channel in channels: + url = f"https://picarto.tv/{channel.get('name')}" + description = f"""{description}\n\n**Title:** [{channel.get("title")}]({url}) +**Channel:** [{channel.get("name")}]({url}) +**Adult:** {"Yes" if channel.get("adult") else "No"} +**Gaming:** {"Yes" if channel.get("gaming") else "No"} +**Commissions:** {"Yes" if channel.get("commissions") else "No"}""" - data = await utils.request(url) - if data is None: - return None - - # Not everyone has all these settings, so use this as a way to print information if it does, otherwise ignore it - things_to_print = ['comissions', 'adult', 'followers', 'category', 'online'] - - embed = discord.Embed(title='{}\'s Picarto'.format(data['name']), url=channel) - avatar_url = 'https://picarto.tv/user_data/usrimg/{}/dsdefault.jpg'.format(data['name'].lower()) - embed.set_thumbnail(url=avatar_url) - - for i, result in data.items(): - if i in things_to_print and str(result): - i = i.title().replace('_', ' ') - embed.add_field(name=i, value=str(result)) - - # Social URL's can be given if a user wants them to show - # Print them if they exist, otherwise don't try to include them - social_links = data.get('social_urls', {}) - - for i, result in social_links.items(): - embed.add_field(name=i.title(), value=result) - - return embed - - def channel_online(self, channel): - # Channel is the name we are checking against that - # This creates a list of all users that match this channel name (should only ever be 1) - # And returns True as long as it is more than 0 - if not self.online_channels or channel is None: - return False - channel = re.search("(?<=picarto.tv/)(.*)", channel).group(1) - return channel.lower() in [stream['name'].lower() for stream in self.online_channels] + return discord.Embed(title="Channels that have gone online!", description=description.strip()) async def picarto_task(self): try: @@ -82,237 +83,34 @@ class Picarto: await asyncio.sleep(30) async def check_channels(self): + query = """ +SELECT + id, followed_picarto_channels, COALESCE(picarto_alerts, default_alerts) AS channel +FROM + guilds +WHERE + COALESCE(picarto_alerts, default_alerts) IS NOT NULL +""" + # Recheck who is currently online await self.get_online_users() - picarto = await self.bot.db.actual_load('picarto', table_filter={'notifications_on': 1}) - for data in picarto: - m_id = int(data['member_id']) - url = data['picarto_url'] - # Check if they are online - online = self.channel_online(url) - # If they're currently online, but saved as not then we'll let servers know they are now online - if online and data['live'] == 0: - msg = "{member.display_name} has just gone live!" - await self.bot.db.save('picarto', {'live': 1, 'member_id': str(m_id)}) - # Otherwise our notification will say they've gone offline - elif not online and data['live'] == 1: - msg = "{member.display_name} has just gone offline!" - await self.bot.db.save('picarto', {'live': 0, 'member_id': str(m_id)}) - else: - continue - - embed = await self.channel_embed(url) - # Loop through each server that they are set to notify - for s_id in data['servers']: - server = self.bot.get_guild(int(s_id)) - # If we can't find it, ignore this one - if server is None: - continue - member = server.get_member(m_id) - # If we can't find them in this server, also ignore - if member is None: - continue - - # Get the notifications settings, get the picarto setting - notifications = self.bot.db.load('server_settings', key=s_id, pluck='notifications') or {} - # Set our default to either the one set, or the default channel of the server - default_channel_id = notifications.get('default') - # If it is has been overriden by picarto notifications setting, use this - channel_id = notifications.get('picarto') or default_channel_id - # Now get the channel - if channel_id: - channel = server.get_channel(int(channel_id)) - else: - continue - - # Then just send our message - try: - await channel.send(msg.format(member=member), embed=embed) - except (discord.Forbidden, discord.HTTPException, AttributeError): - pass - - @commands.group(invoke_without_command=True) - @utils.can_run(send_messages=True) - async def picarto(self, ctx, member: discord.Member = None): - """This command can be used to view Picarto stats about a certain member - - EXAMPLE: !picarto @otherPerson - RESULT: Info about their picarto stream""" - await ctx.message.channel.trigger_typing() - - # If member is not given, base information on the author - member = member or ctx.message.author - member_url = self.bot.db.load('picarto', key=member.id, pluck='picarto_url') - if member_url is None: - await ctx.send("That user does not have a picarto url setup!") - return - - embed = await self.channel_embed(member_url) - - await ctx.send(embed=embed) - - @picarto.command(name='add') - @commands.guild_only() - @utils.can_run(send_messages=True) - async def add_picarto_url(self, ctx, url: str): - """Saves your user's picarto URL - - EXAMPLE: !picarto add MyUsername - RESULT: Your picarto stream is saved, and notifications should go to this guild""" - await ctx.message.channel.trigger_typing() - - # This uses a lookbehind to check if picarto.tv exists in the url given - # If it does, it matches picarto.tv/user and sets the url as that - # Then (in the else) add https://www. to that - # Otherwise if it doesn't match, we'll hit an AttributeError due to .group(0) - # This means that the url was just given as a user (or something complete invalid) - # So set URL as https://www.picarto.tv/[url] - # Even if this was invalid such as https://www.picarto.tv/picarto.tv/user - # For example, our next check handles that - try: - url = re.search("((?<=://)?picarto.tv/)+(.*)", url).group(0) - except AttributeError: - url = "https://www.picarto.tv/{}".format(url) - else: - url = "https://www.{}".format(url) - channel = re.search("https://www.picarto.tv/(.*)", url).group(1) - api_url = BASE_URL + '/channel/name/{}'.format(channel) - - data = await utils.request(api_url) - if not data: - await ctx.send("That Picarto user does not exist! What would be the point of adding a nonexistant Picarto " - "user? Silly") - return - - key = str(ctx.message.author.id) - - # Check if it exists first, if it does we don't want to override some of the settings - result = self.bot.db.load('picarto', key=key) - if result: - entry = { - 'picarto_url': url, - 'member_id': key - } - else: - entry = { - 'picarto_url': url, - 'servers': [str(ctx.message.guild.id)], - 'notifications_on': 1, - 'live': 0, - 'member_id': key - } - await self.bot.db.save('picarto', entry) - await ctx.send( - "I have just saved your Picarto URL {}, this guild will now be notified when you go live".format( - ctx.message.author.mention)) - - @picarto.command(name='remove', aliases=['delete']) - @utils.can_run(send_messages=True) - async def remove_picarto_url(self, ctx): - """Removes your picarto URL""" - key = str(ctx.message.author.id) - - result = self.bot.db.load('picarto', key=key) - if result: - entry = { - 'picarto_url': None, - 'member_id': str(ctx.message.author.id) - } - - await self.bot.db.save('picarto', entry) - await ctx.send("I am no longer saving your picarto URL {}".format(ctx.message.author.mention)) - else: - await ctx.send("I cannot remove something that I don't have (you've never saved your Picarto URL)") - - @picarto.command(name='alerts') - @commands.guild_only() - @utils.can_run(manage_guild=True) - async def picarto_alerts_channel(self, ctx, channel: discord.TextChannel): - """Sets the notifications channel for picarto notifications - - EXAMPLE: !picarto alerts #picarto - RESULT: Picarto notifications will go to this channel - """ - entry = { - 'server_id': str(ctx.message.guild.id), - 'notifications': { - 'picarto': str(channel.id) - } - } - await self.bot.db.save('server_settings', entry) - await ctx.send("All Picarto notifications will now go to {}".format(channel.mention)) - - @picarto.group(invoke_without_command=True) - @commands.guild_only() - @utils.can_run(send_messages=True) - async def notify(self, ctx): - """This can be used to turn picarto notifications on or off - Call this command by itself, to add this guild to the list of guilds to be notified - - EXAMPLE: !picarto notify - RESULT: This guild will now be notified of you going live""" - key = str(ctx.message.author.id) - servers = self.bot.db.load('picarto', key=key, pluck='servers') - # Check if this user is saved at all - if servers is None: - await ctx.send( - "I do not have your Picarto URL added {}. You can save your Picarto url with !picarto add".format( - ctx.message.author.mention)) - # Then check if this guild is already added as one to notify in - elif str(ctx.message.guild.id) in servers: - await ctx.send("I am already set to notify in this guild...") - else: - servers.append(str(ctx.message.guild.id)) - entry = { - 'member_id': key, - 'servers': servers - } - await self.bot.db.save('picarto', entry) - await ctx.send("This server will now be notified if you go live") - - @notify.command(name='on', aliases=['start,yes']) - @commands.guild_only() - @utils.can_run(send_messages=True) - async def notify_on(self, ctx): - """Turns picarto notifications on - - EXAMPLE: !picarto notify on - RESULT: Notifications are sent when you go live""" - key = str(ctx.message.author.id) - result = self.bot.db.load('picarto', key=key) - if result: - entry = { - 'member_id': key, - 'notifications_on': 1 - } - await self.bot.db.save('picarto', entry) - await ctx.send("I will notify if you go live {}, you'll get a bajillion followers I promise c:".format( - ctx.message.author.mention)) - else: - await ctx.send("I can't notify if you go live if I don't know your picarto URL yet!") - - @notify.command(name='off', aliases=['stop,no']) - @commands.guild_only() - @utils.can_run(send_messages=True) - async def notify_off(self, ctx): - """Turns picarto notifications off - - EXAMPLE: !picarto notify off - RESULT: No more notifications sent when you go live""" - key = str(ctx.message.author.id) - result = self.bot.db.load('picarto', key=key) - if result: - entry = { - 'member_id': key, - 'notifications_on': 0 - } - await self.bot.db.save('picarto', entry) - await ctx.send( - "I will not notify if you go live anymore {}, " - "are you going to stream some lewd stuff you don't want people to see?~".format( - ctx.message.author.mention)) - else: - await ctx.send( - "I'm already not going to notify anyone, because I don't have your picarto URL saved...") + # Now get all guilds and their picarto channels they follow and loop through them + results = await self.bot.db.fetch(query) or [] + for result in results: + # Get all the channels that have gone online + gone_online = [ + self.channel_info.get(name) + for name in result["followed_picarto_channels"] + if self.channel_info.get(name) == "on" + ] + # If they've gone online, produce the embed for them and send it + if gone_online: + embed = produce_embed(*gone_online) + channel = self.bot.get_channel(result["channel"]) + if channel is not None: + try: + await channel.send(embed=embed) + except (discord.Forbidden, discord.HTTPException, AttributeError): + pass def setup(bot): diff --git a/cogs/raffle.py b/cogs/raffle.py index 5ab5d06..b9ebc06 100644 --- a/cogs/raffle.py +++ b/cogs/raffle.py @@ -1,13 +1,12 @@ from discord.ext import commands -import discord +from collections import defaultdict import utils -import random -import pendulum +import discord import re import asyncio -import traceback +import random class Raffle: @@ -15,174 +14,57 @@ class Raffle: def __init__(self, bot): self.bot = bot - self.bot.loop.create_task(self.raffle_task()) + self.raffles = defaultdict(list) - async def raffle_task(self): - while True: - try: - await self.check_raffles() - except Exception as error: - with open("error_log", 'a') as f: - traceback.print_tb(error.__traceback__, file=f) - print('{0.__class__.__name__}: {0}'.format(error), file=f) - finally: - await asyncio.sleep(60) + def create_raffle(self, ctx, title, num): + raffle = GuildRaffle(ctx, title, num) + self.raffles[ctx.guild.id].append(raffle) + raffle.start() - async def check_raffles(self): - # This is used to periodically check the current raffles, and see if they have ended yet - # If the raffle has ended, we'll pick a winner from the entrants - raffles = self.bot.db.load('raffles') - - if raffles is None: - return - - for server_id, raffle in raffles.items(): - server = self.bot.get_guild(int(server_id)) - - # Check to see if this cog can find the server in question - if server is None: - continue - for r in raffle['raffles']: - title = r['title'] - entrants = r['entrants'] - - now = pendulum.now(tz="UTC") - expires = pendulum.parse(r['expires']) - - # Now lets compare and see if this raffle has ended, if not just continue - if expires > now: - continue - - # Make sure there are actually entrants - if len(entrants) == 0: - fmt = 'Sorry, but there were no entrants for the raffle `{}`!'.format(title) - else: - winner = None - count = 0 - while winner is None: - winner = server.get_member(int(random.SystemRandom().choice(entrants))) - - # Lets make sure we don't get caught in an infinite loop - # Realistically having more than 50 random entrants found that aren't in the server anymore - # Isn't something that should be an issue, but better safe than sorry - count += 1 - if count >= 50: - break - - if winner is None: - fmt = 'I couldn\'t find an entrant that is still in this server, for the raffle `{}`!'.format( - title) - else: - fmt = 'The raffle `{}` has just ended! The winner is {}!'.format(title, winner.display_name) - - # Get the notifications settings, get the raffle setting - notifications = self.bot.db.load('server_settings', key=server.id, pluck='notifications') or {} - # Set our default to either the one set - default_channel_id = notifications.get('default') - # If it is has been overriden by picarto notifications setting, use this - channel_id = notifications.get('raffle') or default_channel_id - if channel_id: - channel = self.bot.get_channel(int(channel_id)) - else: - continue - try: - await channel.send(fmt) - except (discord.Forbidden, AttributeError): - pass - - # No matter which one of these matches were met, the raffle has ended and we want to remove it - raffle['raffles'].remove(r) - entry = { - 'server_id': raffle['server_id'], - 'raffles': raffle['raffles'] - } - await self.bot.db.save('raffles', entry) - - @commands.command() + @commands.command(name="raffles") @commands.guild_only() @utils.can_run(send_messages=True) - async def raffles(self, ctx): + async def _raffles(self, ctx): """Used to print the current running raffles on the server EXAMPLE: !raffles RESULT: A list of the raffles setup on this server""" - raffles = self.bot.db.load('raffles', key=ctx.message.guild.id, pluck='raffles') - if not raffles: + raffles = self.raffles[ctx.guild.id] + if len(raffles) == 0: await ctx.send("There are currently no raffles setup on this server!") return - # For EVERY OTHER COG, when we get one result, it is nice to have it return that exact object - # This is the only cog where that is different, so just to make this easier lets throw it - # back in a one-indexed list, for easier parsing - if isinstance(raffles, dict): - raffles = [raffles] - fmt = "\n\n".join("**Raffle:** {}\n**Title:** {}\n**Total Entrants:** {}\n**Ends:** {} UTC".format( - num + 1, - raffle['title'], - len(raffle['entrants']), - raffle['expires']) for num, raffle in enumerate(raffles)) - await ctx.send(fmt) + embed = discord.Embed(title=f"Raffles in {ctx.guild.name}") + + for num, raffle in enumerate(raffles): + embed.add_field( + name=f"Raffle {num + 1}", + value=f"Title: {raffle.title}\n" + f"Total Entrants: {len(raffle.entrants)}\n" + f"Ends in {raffle.remaining}", + inline=False + ) + await ctx.send(embed=embed) @commands.group(invoke_without_command=True) @commands.guild_only() @utils.can_run(send_messages=True) - async def raffle(self, ctx, raffle_id: int = 0): + async def raffle(self, ctx, raffle_id: int): """Used to enter a raffle running on this server If there is more than one raffle running, provide an ID of the raffle you want to enter EXAMPLE: !raffle 1 RESULT: You've entered the first raffle!""" - # Lets let people use 1 - (length of raffles) and handle 0 base ourselves - raffle_id -= 1 - author = ctx.message.author - key = str(ctx.message.guild.id) - - raffles = self.bot.db.load('raffles', key=key, pluck='raffles') - if raffles is None: - await ctx.send("There are currently no raffles setup on this server!") - return - - raffle_count = len(raffles) - - # There is only one raffle, so use the first's info - if raffle_count == 1: - entrants = raffles[0]['entrants'] - # Lets make sure that the user hasn't already entered the raffle - if str(author.id) in entrants: - await ctx.send("You have already entered this raffle!") - return - entrants.append(str(author.id)) - - update = { - 'raffles': raffles, - 'server_id': key - } - await self.bot.db.save('raffles', update) - await ctx.send("{} you have just entered the raffle!".format(author.mention)) - # Otherwise, make sure the author gave a valid raffle_id - elif raffle_id in range(raffle_count): - entrants = raffles[raffle_id]['entrants'] - - # Lets make sure that the user hasn't already entered the raffle - if str(author.id) in entrants: - await ctx.send("You have already entered this raffle!") - return - entrants.append(str(author.id)) - - # Since we have no good thing to filter things off of, lets use the internal rethinkdb id - - update = { - 'raffles': raffles, - 'server_id': key - } - await self.bot.db.save('raffles', update) - await ctx.send("{} you have just entered the raffle!".format(author.mention)) + try: + raffle = self.raffles[ctx.guild.id][raffle_id - 1] + except IndexError: + await ctx.send(f"I could not find a raffle for ID {raffle_id}") + await self._raffles.invoke(ctx) else: - fmt = "Please provide a valid raffle ID, as there are more than one setup on the server! " \ - "There are currently `{}` raffles running, use {}raffles to view the current running raffles".format( - raffle_count, ctx.prefix - ) - await ctx.send(fmt) + if raffle.enter(ctx.author): + await ctx.send(f"You have just joined the raffle {raffle['title']}") + else: + await ctx.send("You have already entered this raffle!") @raffle.command(name='create', aliases=['start', 'begin', 'add']) @commands.guild_only() @@ -193,10 +75,8 @@ class Raffle: EXAMPLE: !raffle create RESULT: A follow-along for setting up a new raffle""" - author = ctx.message.author - server = ctx.message.guild - channel = ctx.message.channel - now = pendulum.now(tz="UTC") + author = ctx.author + channel = ctx.channel await ctx.send( "Ready to start a new raffle! Please respond with the title you would like to use for this raffle!") @@ -212,13 +92,13 @@ class Raffle: fmt = "Alright, your new raffle will be titled:\n\n{}\n\nHow long would you like this raffle to run for? " \ "The format should be [number] [length] for example, `2 days` or `1 hour` or `30 minutes` etc. " \ - "The minimum for this is 10 minutes, and the maximum is 3 months" + "The minimum for this is 10 minutes, and the maximum is 3 days" await ctx.send(fmt.format(title)) # Our check to ensure that a proper length of time was passed def check(m): if m.author == author and m.channel == channel: - return re.search("\d+ (minutes?|hours?|days?|weeks?|months?)", m.content.lower()) is not None + return re.search("\d+ (minutes?|hours?|days?)", m.content.lower()) is not None else: return False @@ -229,73 +109,86 @@ class Raffle: return # Lets get the length provided, based on the number and type passed - num, term = re.search("\d+ (minutes?|hours?|days?|weeks?|months?)", msg.content.lower()).group(0).split(' ') + num, term = re.search("(\d+) (minutes?|hours?|days?)", msg.content.lower()).groups() # This should be safe to convert, we already made sure with our check earlier this would match num = int(num) # Now lets ensure this meets our min/max - if "minute" in term and (num < 10 or num > 129600): + if "minute" in term: + num = num * 60 + elif "hour" in term: + num = num * 60 * 60 + elif "day" in term: + num = num * 24 * 60 * 60 + + if 60 < num < 259200: await ctx.send( - "Length provided out of range! The minimum for this is 10 minutes, and the maximum is 3 months") - return - elif "hour" in term and num > 2160: - await ctx.send( - "Length provided out of range! The minimum for this is 10 minutes, and the maximum is 3 months") - return - elif "day" in term and num > 90: - await ctx.send( - "Length provided out of range! The minimum for this is 10 minutes, and the maximum is 3 months") - return - elif "week" in term and num > 12: - await ctx.send( - "Length provided out of range! The minimum for this is 10 minutes, and the maximum is 3 months") - return - elif "month" in term and num > 3: - await ctx.send( - "Length provided out of range! The minimum for this is 10 minutes, and the maximum is 3 months") + "Length provided out of range! The minimum for this is 10 minutes, and the maximum is 3 days") return - # Pendulum only accepts the plural version of terms, lets make sure this is added - term = term if term.endswith('s') else term + 's' - # If we're in the range, lets just pack this in a dictionary we can pass to set the time we want, then set that - payload = {term: num} - expires = now.add(**payload) - - # Now we're ready to add this as a new raffle - entry = { - 'title': title, - 'expires': expires.to_datetime_string(), - 'entrants': [], - 'author': str(author.id), - } - - raffles = self.bot.db.load('raffles', key=server.id, pluck='raffles') or [] - raffles.append(entry) - update = { - 'server_id': str(server.id), - 'raffles': raffles - } - await self.bot.db.save('raffles', update) + self.create_raffle(ctx, title, num) await ctx.send("I have just saved your new raffle!") - @raffle.command(name='alerts') - @commands.guild_only() - @utils.can_run(manage_guild=True) - async def raffle_alerts_channel(self, ctx, channel: discord.TextChannel): - """Sets the notifications channel for raffle notifications - - EXAMPLE: !raffle alerts #raffle - RESULT: raffle notifications will go to this channel - """ - entry = { - 'server_id': str(ctx.message.guild.id), - 'notifications': { - 'raffle': str(channel.id) - } - } - await self.bot.db.save('server_settings', entry) - await ctx.send("All raffle notifications will now go to {}".format(channel.mention)) - def setup(bot): bot.add_cog(Raffle(bot)) + + +class GuildRaffle: + + def __init__(self, ctx, title, expires): + self._ctx = ctx + self.title = title + self.expires = expires + self.entrants = set() + self.task = None + + @property + def guild(self): + return self._ctx.guild + + @property + def db(self): + return self._ctx.bot.db + + def start(self): + self.task = self._ctx.bot.loop.call_later(self.expires, self.end_raffle()) + + @property + def remaining(self): + minutes, seconds = divmod(self.task.when(), 60) + hours, minutes = divmod(minutes, 60) + days, hours = divmod(hours, 24) + return f"{days} days, {hours} hours, {minutes} minutes, {seconds} seconds" + + def enter(self, entrant): + self.entrants.add(entrant) + + async def end_raffle(self): + entrants = {e for e in self.entrants if self.guild.get_member(e.id)} + + query = """ +SELECT + COALESCE(raffle_alerts, default_alerts) AS channel, +FROM + guilds +WHERE + id = $1 +AND + COALESCE(raffle_alerts, default_alerts) IS NOT NULL + """ + channel = None + result = await self.db.fetch(query, self.guild.id) + + if result: + channel = self.guild.get_channel(result['channel']) + if channel is None: + return + + if entrants: + winner = random.SystemRandom().choice(self.entrants) + await channel.send(f"The winner of the raffle `{self.title}` is {winner.mention}! Congratulations!") + else: + await channel.send( + f"There were no entrants to the raffle `{self.title}`, who are in this server currently!" + ) diff --git a/cogs/roles.py b/cogs/roles.py index a7cf1e2..f940a6d 100644 --- a/cogs/roles.py +++ b/cogs/roles.py @@ -85,7 +85,7 @@ class Roles: total_members = len(role.members) embed.add_field(name="Total members", value=str(total_members)) # If there are only a few members in this role, display them - if total_members <= 5 and total_members > 0: + if 5 >= total_members > 0: embed.add_field(name="Members", value="\n".join(m.display_name for m in role.members)) await ctx.send(embed=embed) else: diff --git a/cogs/spades.py b/cogs/spades.py index 69f8d4f..cf6272e 100644 --- a/cogs/spades.py +++ b/cogs/spades.py @@ -299,7 +299,7 @@ class Game: await p.show_table() await p.get_bid() - self.order_turns(self.get_highest_bidder()) + self.order_turns(self.get_highest_bidder()) # Bids are complete, time to start the game await self.clean_messages() @@ -319,13 +319,14 @@ class Game: await self.update_table() # Get the winner after the round, increase their tricks winner = self.get_round_winner() + winning_card = winner.played_card winner.tricks += 1 # Order players based off the winner self.order_turns(winner) # Reset the round await self.reset_round() - fmt = "{} won with a {}".format(winner.discord_member.display_name, winner.played_card) + fmt = "{} won with a {}".format(winner.discord_member.display_name, winning_card) for p in self.players: await p.send_message(content=fmt) @@ -352,9 +353,12 @@ class Game: highest_bid = -1 highest_player = None for player in self.players: + print(player.bid_num, player.discord_member.display_name) if player.bid_num > highest_bid: highest_player = player + print(highest_player.discord_member.display_name) + return highest_player def order_turns(self, player): @@ -432,7 +436,7 @@ class Spades: # If so add the player to it self.pending_game.join(author) # If we've hit 4 players, we want to start the game, add it to our list of games, and wipe our pending game - if len(self.pending_game.players) == 4: + if len(self.pending_game.players) == 2: task = self.bot.loop.create_task(self.pending_game.start()) self.games.append((self.pending_game, task)) self.pending_game = None diff --git a/cogs/spotify.py b/cogs/spotify.py index d35d869..9be4c52 100644 --- a/cogs/spotify.py +++ b/cogs/spotify.py @@ -27,10 +27,10 @@ class Spotify: async def api_token_task(self): while True: + delay = 2400 try: delay = await self.get_api_token() except Exception as error: - delay = 2400 with open("error_log", 'a') as f: traceback.print_tb(error.__traceback__, file=f) print('{0.__class__.__name__}: {0}'.format(error), file=f) diff --git a/cogs/stats.py b/cogs/stats.py index d5d0bda..022a579 100644 --- a/cogs/stats.py +++ b/cogs/stats.py @@ -92,8 +92,6 @@ class Stats: EXAMPLE: !command stats play RESULT: The realization that this is the only reason people use me ;-;""" - await ctx.message.channel.trigger_typing() - cmd = self.bot.get_command(command) if cmd is None: await ctx.send("`{}` is not a valid command".format(command)) @@ -124,8 +122,6 @@ class Stats: EXAMPLE: !command leaderboard me RESULT: The realization of how little of a life you have""" - await ctx.message.channel.trigger_typing() - if re.search('(author|me)', option): mid = str(ctx.message.author.id) # First lets get all the command usage @@ -176,31 +172,35 @@ class Stats: EXAMPLE: !mostboops RESULT: You've booped @OtherPerson 351253897120935712093572193057310298 times!""" + query = """ +SELECT + boopee, amount +FROM + boops +WHERE + booper=$1 +AND + boopee IN ($2) +ORDER BY + amount DESC +LIMIT 1 +""" + members = ", ".join(f"{m.id}" for m in ctx.guild.members) + most = await self.bot.db.fetchrow(query, ctx.author.id, members) + boops = self.bot.db.load('boops', key=ctx.message.author.id) if boops is None or "boops" not in boops: await ctx.send("You have not booped anyone {} Why the heck not...?".format(ctx.message.author.mention)) return - # Just to make this easier, just pay attention to the boops data, now that we have the right entry - boops = boops['boops'] - - sorted_boops = sorted( - ((ctx.guild.get_member(int(member_id)), amount) - for member_id, amount in boops.items() - if ctx.guild.get_member(int(member_id))), - reverse=True, - key=lambda k: k[1] - ) - - # Since this is sorted, we just need to get the following information on the first user in the list - try: - member, most_boops = sorted_boops[0] - except IndexError: - await ctx.send("You have not booped anyone in this server {}".format(ctx.message.author.mention)) - return + if len(most) == 0: + await ctx.send(f"You have not booped anyone in this server {ctx.author.mention}") else: - await ctx.send("{0} you have booped {1} the most amount of times, coming in at {2} times".format( - ctx.message.author.mention, member.display_name, most_boops)) + member = ctx.guild.get_member(most['boopee']) + await ctx.send( + f"{ctx.author.mention} you have booped {member.display_name} the most amount of times, " + f"coming in at {most['amount']} times" + ) @commands.command() @commands.guild_only() @@ -210,28 +210,30 @@ class Stats: EXAMPLE: !listboops RESULT: The list of your booped members!""" - await ctx.message.channel.trigger_typing() - boops = self.bot.db.load('boops', key=ctx.message.author.id) - if not boops: - await ctx.send("You have not booped anyone {} Why the heck not...?".format(ctx.message.author.mention)) - return + query = """ +SELECT + boopee, amount +FROM + boops +WHERE + booper=$1 +AND + boopee IN ($2) +ORDER BY + amount DESC +LIMIT 10 + """ - # Just to make this easier, just pay attention to the boops data, now that we have the right entry - boops = boops['boops'] + members = ", ".join(f"{m.id}" for m in ctx.guild.members) + most = await self.bot.db.fetch(query, ctx.author.id, members) - sorted_boops = sorted( - ((ctx.guild.get_member(int(member_id)), amount) - for member_id, amount in boops.items() - if ctx.guild.get_member(int(member_id))), - reverse=True, - key=lambda k: k[1] - ) - if sorted_boops: + if len(most) != 0: embed = discord.Embed(title="Your booped victims", colour=ctx.author.colour) embed.set_author(name=str(ctx.author), icon_url=ctx.author.avatar_url) - for member, amount in sorted_boops: - embed.add_field(name=member.display_name, value=amount) + for row in most: + member = ctx.guild.get_member(row['boopee']) + embed.add_field(name=member.display_name, value=row['amount']) await ctx.send(embed=embed) else: await ctx.send("You haven't booped anyone in this server!") @@ -244,35 +246,34 @@ class Stats: EXAMPLE: !leaderboard RESULT: A leaderboard of this server's battle records""" - await ctx.message.channel.trigger_typing() - # Create a list of the ID's of all members in this server, for comparison to the records saved - server_member_ids = [member.id for member in ctx.message.guild.members] - battles = self.bot.db.load('battle_records') - if battles is None or len(battles) == 0: + query = """ +SELECT + id, battle_rating +FROM + users +WHERE + id = any($1::bigint[]) +ORDER BY + battle_rating DESC +""" + + results = await self.bot.db.fetch(query, [m.id for m in ctx.guild.members]) + + if len(results) == 0: await ctx.send("No one has battled on this server!") + else: - battles = [ - battle - for member_id, battle in battles.items() - if int(member_id) in server_member_ids - ] + output = [] + for row in results: + member = ctx.guild.get_member(row['id']) + output.append(f"{member.display_name} (Rating: {row['battle_rating']})") - # Sort the members based on their rating - sorted_members = sorted(battles, key=lambda k: k['rating'], reverse=True) - - output = [] - for x in sorted_members: - member_id = int(x['member_id']) - rating = x['rating'] - member = ctx.message.guild.get_member(member_id) - output.append("{} (Rating: {})".format(member.display_name, rating)) - - try: - pages = utils.Pages(ctx, entries=output) - await pages.paginate() - except utils.CannotPaginate as e: - await ctx.send(str(e)) + try: + pages = utils.Pages(ctx, entries=output) + await pages.paginate() + except utils.CannotPaginate as e: + await ctx.send(str(e)) @commands.command() @commands.guild_only() @@ -282,8 +283,6 @@ class Stats: EXAMPLE: !stats @OtherPerson RESULT: How good they are at winning a completely luck based game""" - await ctx.message.channel.trigger_typing() - member = member or ctx.message.author # Get the different data that we'll display server_rank = "{}/{}".format(*self.bot.br.get_server_rank(member)) diff --git a/cogs/tags.py b/cogs/tags.py index a1969fc..e6c62b3 100644 --- a/cogs/tags.py +++ b/cogs/tags.py @@ -20,8 +20,9 @@ class Tags: EXAMPLE: !tags RESULT: All tags setup on this server""" - tags = self.bot.db.load('tags', key=ctx.message.guild.id, pluck='tags') - if tags: + tags = await self.bot.db.fetch("SELECT trigger FROM tags WHERE guild=$1", ctx.guild.id) + + if len(tags) > 0: entries = [t['trigger'] for t in tags] pages = utils.Pages(ctx, entries=entries) await pages.paginate() @@ -36,16 +37,18 @@ class Tags: EXAMPLE: !mytags RESULT: All your tags setup on this server""" - tags = self.bot.db.load('tags', key=ctx.message.guild.id, pluck='tags') - if tags: - entries = [t['trigger'] for t in tags if t['author'] == str(ctx.message.author.id)] - if len(entries) == 0: - await ctx.send("You have no tags setup on this server!") - else: - pages = utils.Pages(ctx, entries=entries) - await pages.paginate() + tags = await self.bot.db.fetch( + "SELECT trigger FROM tags WHERE guild=$1 AND creator=$2", + ctx.guild.id, + ctx.author.id + ) + + if len(tags) > 0: + entries = [t['trigger'] for t in tags] + pages = utils.Pages(ctx, entries=entries) + await pages.paginate() else: - await ctx.send("There are no tags setup on this server!") + await ctx.send("You have no tags on this server!") @commands.group(invoke_without_command=True) @commands.guild_only() @@ -56,16 +59,17 @@ class Tags: EXAMPLE: !tag butts RESULT: Whatever you setup for the butts tag!!""" - tag = tag.lower().strip() - tags = self.bot.db.load('tags', key=ctx.message.guild.id, pluck='tags') - if tags: - for t in tags: - if t['trigger'].lower().strip() == tag: - await ctx.send("\u200B{}".format(t['result'])) - return - await ctx.send("There is no tag called {}".format(tag)) + tag = await self.bot.db.fetchrow( + "SELECT id, result FROM tags WHERE guild=$1 AND trigger=$2", + ctx.guild.id, + tag.lower().strip() + ) + + if tag: + await ctx.send("\u200B{}".format(tag['result'])) + await self.bot.db.execute("UPDATE tags SET uses = uses + 1 WHERE id = $1", tag['id']) else: - await ctx.send("There are no tags setup on this server!") + await ctx.send("There is no tag called {}".format(tag)) @tag.command(name='add', aliases=['create', 'setup']) @commands.guild_only() @@ -88,22 +92,24 @@ class Tags: return trigger = msg.content.lower().strip() - forbidden_tags = ['add', 'create', 'setup', 'edit', ''] + forbidden_tags = ['add', 'create', 'setup', 'edit', 'info', 'delete', 'remove', 'stop'] if len(trigger) > 100: await ctx.send("Please keep tag triggers under 100 characters") return - elif trigger in forbidden_tags: + elif trigger.lower() in forbidden_tags: await ctx.send( "Sorry, but your tag trigger was detected to be forbidden. " "Current forbidden tag triggers are: \n{}".format("\n".join(forbidden_tags))) return - tags = self.bot.db.load('tags', key=ctx.message.guild.id, pluck='tags') or [] - if tags: - for t in tags: - if t['trigger'].lower().strip() == trigger: - await ctx.send("There is already a tag setup called {}!".format(trigger)) - return + tag = await self.bot.db.fetchrow( + "SELECT result FROM tags WHERE guild=$1 AND trigger=$2", + ctx.guild.id, + trigger.lower().strip() + ) + if tag: + await ctx.send("There is already a tag setup called {}!".format(trigger)) + return try: await my_msg.delete() @@ -111,10 +117,6 @@ class Tags: except (discord.Forbidden, discord.HTTPException): pass - if trigger.lower() in ['edit', 'delete', 'remove', 'stop']: - await ctx.send("You can't create a tag with {}!".format(trigger)) - return - my_msg = await ctx.send( "Alright, your new tag can be called with {}!\n\nWhat do you want to be displayed with this tag?".format( trigger)) @@ -132,92 +134,97 @@ class Tags: except (discord.Forbidden, discord.HTTPException): pass - # The different DB settings - tag = { - 'author': str(ctx.message.author.id), - 'trigger': trigger, - 'result': result - } - tags.append(tag) - entry = { - 'server_id': str(ctx.message.guild.id), - 'tags': tags - } - await self.bot.db.save('tags', entry) await ctx.send("I have just setup a new tag for this server! You can call your tag with {}".format(trigger)) + await self.bot.db.execute( + "INSERT INTO tags(guild, creator, trigger, result) VALUES ($1, $2, $3, $4)", + ctx.guild.id, + ctx.author.id, + trigger, + result + ) @tag.command(name='edit') @commands.guild_only() @utils.can_run(send_messages=True) - async def edit_tag(self, ctx, *, tag: str): + async def edit_tag(self, ctx, *, trigger: str): """This will allow you to edit a tag that you have created EXAMPLE: !tag edit this tag RESULT: I'll ask what you want the new result to be""" - tags = self.bot.db.load('tags', key=ctx.message.guild.id, pluck='tags') - def check(m): return m.channel == ctx.message.channel and m.author == ctx.message.author and len(m.content) > 0 - if tags: - for i, t in enumerate(tags): - if t['trigger'] == tag: - if t['author'] == str(ctx.message.author.id): - my_msg = await ctx.send( - "Alright, what do you want the new result for the tag {} to be".format(tag)) - try: - msg = await self.bot.wait_for("message", check=check, timeout=60) - except asyncio.TimeoutError: - await ctx.send("You took too long!") - return - new_tag = t.copy() - new_tag['result'] = msg.content - tags[i] = new_tag - try: - await my_msg.delete() - await msg.delete() - except discord.Forbidden: - pass - entry = { - 'server_id': str(ctx.message.guild.id), - 'tags': tags - } - await self.bot.db.save('tags', entry) - await ctx.send("Alright, the tag {} has been updated".format(tag)) - return - else: - await ctx.send("You can't edit someone else's tag!") - return - await ctx.send("There isn't a tag called {}!".format(tag)) + tag = await self.bot.db.fetchrow( + "SELECT id, trigger FROM tags WHERE guild=$1 AND creator=$2 AND trigger=$3", + ctx.guild.id, + ctx.author.id, + trigger + ) + + if tag: + my_msg = await ctx.send(f"Alright, what do you want the new result for the tag {tag} to be") + try: + msg = await self.bot.wait_for("message", check=check, timeout=60) + except asyncio.TimeoutError: + await ctx.send("You took too long!") + return + + new_result = msg.content + + try: + await my_msg.delete() + await msg.delete() + except (discord.Forbidden, discord.HTTPException): + pass + + await ctx.send(f"Alright, the tag {trigger} has been updated") + await self.bot.db.execute("UPDATE tags SET result=$1 WHERE id=$2", new_result, tag['id']) else: - await ctx.send("There are no tags setup on this server!") + await ctx.send(f"You do not have a tag called {trigger} on this server!") @tag.command(name='delete', aliases=['remove', 'stop']) @commands.guild_only() @utils.can_run(send_messages=True) - async def del_tag(self, ctx, *, tag: str): + async def del_tag(self, ctx, *, trigger: str): """Use this to remove a tag from use for this server Format to delete a tag is !tag delete EXAMPLE: !tag delete stupid_tag RESULT: Deletes that stupid tag""" - tags = self.bot.db.load('tags', key=ctx.message.guild.id, pluck='tags') - if tags: - for t in tags: - if t['trigger'].lower().strip() == tag: - if ctx.message.author.permissions_in(ctx.message.channel).manage_guild or str( - ctx.message.author.id) == t['author']: - tags.remove(t) - entry = { - 'server_id': str(ctx.message.guild.id), - 'tags': tags - } - await self.bot.db.save('tags', entry) - await ctx.send("I have just removed the tag {}".format(tag)) - else: - await ctx.send("You don't own that tag! You can't remove it!") - return + + tag = await self.bot.db.fetchrow( + "SELECT id FROM tags WHERE guild=$1 AND creator=$2 AND trigger=$3", + ctx.guild.id, + ctx.author.id, + trigger + ) + + if tag: + await ctx.send(f"I have just deleted the tag {trigger}") + await self.bot.db.execute("DELETE FROM tags WHERE id=$1", tag['id']) else: - await ctx.send("There are no tags setup on this server!") + await ctx.send(f"You do not own a tag called {trigger} on this server!") + + @tag.command(name="info") + @commands.guild_only() + @utils.can_run(send_messages=True) + async def info_tag(self, ctx, *, trigger: str): + """Shows some information a bout the tag given""" + + tag = await self.bot.db.fetchrow( + "SELECT creator, uses, trigger FROM tags WHERE guild=$1 AND trigger=$3", + ctx.guild.id, + trigger + ) + + embed = discord.Embed(title=tag['trigger']) + creator = ctx.guild.get_member(tag['creator']) + if creator: + embed.set_author(name=creator.display_name, url=creator.avatar_url) + embed.add_field(name="Uses", value=tag['uses']) + embed.add_field(name="Owner", value=creator.mention) + + await ctx.send(embed=embed) + def setup(bot): diff --git a/cogs/tutorial.py b/cogs/tutorial.py index efd18a5..db6a62b 100644 --- a/cogs/tutorial.py +++ b/cogs/tutorial.py @@ -26,13 +26,17 @@ class Tutorial: await ctx.send("Could not find a command or a cog for {}".format(cmd_or_cog)) return - commands = [c for c in utils.get_all_commands(self.bot) if c.cog_name == cmd_or_cog.title()] + commands = set([ + c + for c in self.bot.walk_commands() + if c.cog_name == cmd_or_cog.title() + ]) # Specific command else: commands = [cmd] # Use all commands else: - commands = list(utils.get_all_commands(self.bot)) + commands = set(self.bot.walk_commands()) # Loop through all the commands that we want to use for command in commands: diff --git a/requirements.txt b/requirements.txt index dff3063..8f451af 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,8 +1,8 @@ -Pillow==4.2.0 rethinkdb pyyaml psutil pendulum beautifulsoup4 osuapi +asyncpg -e git+https://github.com/Rapptz/discord.py@rewrite#egg=discord.py \ No newline at end of file diff --git a/utils/__init__.py b/utils/__init__.py index a9ea23b..3055ffc 100644 --- a/utils/__init__.py +++ b/utils/__init__.py @@ -1,6 +1,6 @@ from .cards import Deck, Face, Suit -from .checks import can_run, db_check +from .checks import can_run from .config import * from .utilities import * from .paginator import Pages, CannotPaginate, HelpPaginator -from .database import DB +from .database import DB, Cache diff --git a/utils/checks.py b/utils/checks.py index ab10b68..518dae4 100644 --- a/utils/checks.py +++ b/utils/checks.py @@ -8,69 +8,14 @@ from . import utilities loop = asyncio.get_event_loop() -# The tables needed for the database, as well as their primary keys -required_tables = { - 'battle_records': 'member_id', - 'boops': 'member_id', - 'command_usage': 'command', - 'overwatch': 'member_id', - 'picarto': 'member_id', - 'server_settings': 'server_id', - 'raffles': 'server_id', - 'strawpolls': 'server_id', - 'osu': 'member_id', - 'tags': 'server_id', - 'tictactoe': 'member_id', - 'twitch': 'member_id', - 'user_playlists': 'member_id', - 'birthdays': 'member_id' -} - - -async def db_check(): - """Used to check if the required database/tables are setup""" - db_opts = config.db_opts - - r.set_loop_type('asyncio') - # First try to connect, and see if the correct information was provided - try: - conn = await r.connect(**db_opts) - except r.errors.ReqlDriverError: - print("Cannot connect to the RethinkDB instance with the following information: {}".format(db_opts)) - - print("The RethinkDB instance you have setup may be down, otherwise please ensure you setup a" - " RethinkDB instance, and you have provided the correct database information in config.yml") - quit() - return - - # Get the current databases and check if the one we need is there - dbs = await r.db_list().run(conn) - if db_opts['db'] not in dbs: - # If not, we want to create it - print('Couldn\'t find database {}...creating now'.format(db_opts['db'])) - await r.db_create(db_opts['db']).run(conn) - # Then add all the tables - for table, key in required_tables.items(): - print("Creating table {}...".format(table)) - await r.table_create(table, primary_key=key).run(conn) - print("Done!") - else: - # Otherwise, if the database is setup, make sure all the required tables are there - tables = await r.table_list().run(conn) - for table, key in required_tables.items(): - if table not in tables: - print("Creating table {}...".format(table)) - await r.table_create(table, primary_key=key).run(conn) - print("Done checking tables!") - def should_ignore(ctx): if ctx.message.guild is None: return False - ignored = ctx.bot.db.load('server_settings', key=ctx.message.guild.id, pluck='ignored') + ignored = ctx.bot.cache.ignored[ctx.guild.id] if not ignored: return False - return str(ctx.message.author.id) in ignored['members'] or str(ctx.message.channel.id) in ignored['channels'] + return ctx.message.author.id in ignored['members'] or ctx.message.channel.id in ignored['channels'] async def check_not_restricted(ctx): @@ -79,7 +24,7 @@ async def check_not_restricted(ctx): return True # First get all the restrictions - restrictions = ctx.bot.db.load('server_settings', key=ctx.message.guild.id, pluck='restrictions') or {} + restrictions = ctx.bot.cache.restrictions[ctx.guild.id] # Now lets check the "from" restrictions for from_restriction in restrictions.get('from', []): # Get the source and destination @@ -169,8 +114,7 @@ def has_perms(ctx, **perms): for perm, setting in perms.items(): setattr(required_perm, perm, setting) - required_perm_value = ctx.bot.db.load('server_settings', key=ctx.message.guild.id, pluck='permissions') or {} - required_perm_value = required_perm_value.get(ctx.command.qualified_name) + required_perm_value = ctx.bot.cache.custom_permissions[ctx.guild.id].get(ctx.command.qualified_name) if required_perm_value: required_perm = discord.Permissions(required_perm_value) diff --git a/utils/config.py b/utils/config.py index 0eec099..2c7d202 100644 --- a/utils/config.py +++ b/utils/config.py @@ -57,6 +57,7 @@ extensions = [ 'cogs.misc', 'cogs.mod', 'cogs.admin', + 'cogs.config', 'cogs.images', 'cogs.birthday', 'cogs.owner', @@ -80,25 +81,21 @@ extensions = [ # The default status the bot will use default_status = global_config.get("default_status", None) -# The rethinkdb hostname -db_host = global_config.get('db_host', 'localhost') -# The rethinkdb database name -db_name = global_config.get('db_name', 'Discord_Bot') -# The rethinkdb certification -db_cert = global_config.get('db_cert', '') -# The rethinkdb port -db_port = global_config.get('db_port', 28015) +# The database hostname +db_host = global_config.get('db_host', None) +# The database name +db_name = global_config.get('db_name', 'bonfire') +# The database port +db_port = global_config.get('db_port', None) # The user and password assigned -db_user = global_config.get('db_user', 'admin') -db_pass = global_config.get('db_pass', '') +db_user = global_config.get('db_user', None) +db_pass = global_config.get('db_pass', None) # We've set all the options we need to be able to connect # so create a dictionary that we can use to unload to connect -# db_opts = {'host': db_host, 'db': db_name, 'port': db_port, 'ssl': -# {'ca_certs': db_cert}, 'user': db_user, 'password': db_pass} -db_opts = {'host': db_host, 'db': db_name, 'port': db_port, 'user': db_user, 'password': db_pass} +db_opts = {'host': db_host, 'database': db_name, 'port': db_port, 'user': db_user, 'password': db_pass} def command_prefix(bot, message): if not message.guild: return default_prefix - return bot.db.load('server_settings', key=message.guild.id, pluck='prefix') or default_prefix + return bot.cache.prefixes.get(message.guild.id, default_prefix) diff --git a/utils/database.py b/utils/database.py index 78638e3..0cfd86b 100644 --- a/utils/database.py +++ b/utils/database.py @@ -1,61 +1,87 @@ import asyncio -import rethinkdb as r -from datetime import datetime -from .checks import required_tables +import asyncpg + +from collections import defaultdict from . import config -async def _convert_to_list(cursor): - # This method is here because atm, AsyncioCursor is not iterable - # For our purposes, we want a list, so we need to do this manually - cursor_list = [] - while True: - try: - val = await cursor.next() - cursor_list.append(val) - except r.ReqlCursorEmpty: - break - return cursor_list - - class Cache: - """A class to hold the cached database entries""" + """A class to hold the entires that are called on every message/command""" - def __init__(self, table, key, db, loop): - self.table = table # The name of the database table - self.key = key # The name of primary key - self.db = db # The database class connections are made through - self.loop = loop - self.values = {} # The values returned from the database - self.refreshed_time = None - self.loop.create_task(self.refresh_task()) + def __init__(self, db): + self.db = db + self.prefixes = {} + self.ignored = defaultdict(dict) + self.custom_permissions = defaultdict(dict) + self.restrictions = defaultdict(dict) - async def refresh(self): - self.values = await self.db.query(r.table(self.table).group(self.key)[0]) - self.refreshed_time = datetime.now() + async def setup(self): + await self.load_prefixes() + await self.load_custom_permissions() + await self.load_restrictions() + await self.load_ignored() - async def refresh_task(self): - await self.check_refresh() - await asyncio.sleep(60) + async def load_ignored(self): + query = """ +SELECT + id, ignored_channels, ignored_members +FROM + guilds +WHERE + array_length(ignored_channels, 1) > 0 OR + array_length(ignored_members, 1) > 0 +""" + rows = await self.db.fetch(query) + for row in rows: + self.ignored[row['guild']]['members'] = row['ignored_members'] + self.ignored[row['guild']]['channels'] = row['ignored_channels'] - async def check_refresh(self): - if self.refreshed_time is None: - await self.refresh() - else: - difference = datetime.now() - self.refreshed_time - if difference.total_seconds() > 300: - await self.refresh() + async def load_prefixes(self): + query = """ +SELECT + id, prefix +FROM + guilds +WHERE + prefix IS NOT NULL +""" + rows = await self.db.fetch(query) + for row in rows: + self.prefixes[row['id']] = row['prefix'] - def get(self, key=None, pluck=None): - """This simulates the database call, to make it easier to get the data""" - value = self.values - if key: - value = value.get(str(key), {}) - if pluck: - value = value.get(pluck) + def update_prefix(self, guild, prefix): + self.prefixes[guild.id] = prefix - return value + async def load_custom_permissions(self): + query = """ +SELECT + guild, command, permission +FROM + custom_permissions +WHERE + permission IS NOT NULL +""" + rows = await self.db.fetch(query) + for row in rows: + self.custom_permissions[row['guild']][row['command']] = row['permission'] + + def update_custom_permission(self, guild, command, permission): + self.custom_permissions[guild.id][command.qualified_name] = permission + + async def load_restrictions(self): + query = """ +SELECT + guild, source, from_to, destination +FROM + restrictions +""" + rows = await self.db.fetch(query) + for row in rows: + opt = {"source": row['source'], "destination": row['destination']} + from_restrictions = self.restrictions[row['guild']].get(row['from_to'], []) + from_restrictions.append(opt) + self.restrictions[row['guild']][row['from_to']] = from_restrictions class DB: @@ -63,66 +89,40 @@ class DB: self.loop = asyncio.get_event_loop() self.opts = config.db_opts self.cache = {} + self._pool = None - for table, key in required_tables.items(): - self.cache[table] = Cache(table, key, self, self.loop) + async def connect(self): + self._pool = await asyncpg.create_pool(**self.opts) - async def query(self, query): - """Lets you run a manual query""" - r.set_loop_type("asyncio") - conn = await r.connect(**self.opts) - try: - cursor = await query.run(conn) - except (r.ReqlOpFailedError, r.ReqlNonExistenceError): - cursor = None - if isinstance(cursor, r.Cursor): - cursor = await _convert_to_list(cursor) - await conn.close() - return cursor + async def setup(self): + await self.connect() -# def save(self, table, content): -# """A synchronous task to throw saving content into a task""" -# self.loop.create_task(self._save(table, content)) + async def _query(self, call, query, *args, **kwargs): + """this will acquire a connection and make the call, then return the result""" + async with self._pool.acquire() as connection: + async with connection.transaction(): + return await getattr(connection, call)(query, *args, **kwargs) - async def save(self, table, content): - """Saves data in the table""" + async def execute(self, *args, **kwargs): + return await self._query("execute", *args, **kwargs) - index = await self.query(r.table(table).info()) - index = index.get("primary_key") - key = content.get(index) - if key: - cur_content = await self.query(r.table(table).get(key)) - if cur_content: - # We have content...we either need to update it, or replace - # Update will typically be more common so lets try that first - result = await self.query(r.table(table).get(key).update(content)) - if result.get('replaced', 0) == 0 and result.get('unchanged', 0) == 0: - await self.query(r.table(table).get(key).replace(content)) - else: - await self.query(r.table(table).insert(content)) - else: - await self.query(r.table(table).insert(content)) + async def fetch(self, *args, **kwargs): + return await self._query("fetch", *args, **kwargs) - await self.cache.get(table).refresh() + async def fetchrow(self, *args, **kwargs): + return await self._query("fetchrow", *args, **kwargs) - def load(self, table, **kwargs): - return self.cache.get(table).get(**kwargs) + async def fetchval(self, *args, **kwargs): + return await self._query("fetchval", *args, **kwargs) - async def actual_load(self, table, key=None, table_filter=None, pluck=None): - """Loads the specified content from the specific table""" - query = r.table(table) - - # If a key has been provided, get content with that key - if key: - query = query.get(str(key)) - # A key and a filter shouldn't be combined for any case we'll ever use, so seperate these - elif table_filter: - query = query.filter(table_filter) - - # If we want to pluck something specific, do that - if pluck: - query = query.pluck(pluck).values()[0] - - cursor = await self.query(query) - - return cursor + async def upsert(self, table, data): + keys = values = "" + for num, k in enumerate(data.keys()): + if num > 0: + keys += ", " + values += ", " + keys += k + values += f"${num}" + query = f"INSERT INTO {table} ({keys}) VALUES ({values}) ON CONFLICT DO UPDATE" + print(query) + return await self.execute(query, *data.values()) diff --git a/utils/utilities.py b/utils/utilities.py index e8e1cc5..473a469 100644 --- a/utils/utilities.py +++ b/utils/utilities.py @@ -5,47 +5,10 @@ import discord from discord.ext import commands from . import config -from PIL import Image -def convert_to_jpeg(pfile): - # Open the file given - img = Image.open(pfile) - # Create the BytesIO object we'll use as our new "file" - new_file = BytesIO() - # Save to this file as jpeg - img.save(new_file, format='JPEG') - # In order to use the file, we need to seek back to the 0th position - new_file.seek(0) - return new_file - - -def get_all_commands(bot): - """Returns a list of all command names for the bot""" - # First lets create a set of all the parent names - for cmd in bot.commands: - yield from get_all_subcommands(cmd) - - -def get_all_subcommands(command): - yield command - if type(command) is discord.ext.commands.core.Group: - for subcmd in command.commands: - yield from get_all_subcommands(subcmd) - - -async def channel_is_nsfw(channel, db): - if type(channel) is discord.DMChannel: - server = 'DMs' - elif channel.is_nsfw(): - return True - else: - server = str(channel.guild.id) - - channel = str(channel.id) - - channels = db.load('server_settings', key=server, pluck='nsfw_channels') or [] - return channel in channels +def channel_is_nsfw(channel): + return isinstance(channel, discord.DMChannel) or channel.is_nsfw() async def download_image(url): @@ -103,8 +66,11 @@ async def request(url, *, headers=None, payload=None, method='GET', attr='json', except: continue + async def convert(ctx, option): """Tries to convert a string to an object of useful representiation""" + # Due to id's being ints, it's very possible that an int is passed + option = str(option) cmd = ctx.bot.get_command(option) if cmd: return cmd @@ -132,25 +98,52 @@ async def convert(ctx, option): return role +def update_rating(winner_rating, loser_rating): + # The scale is based off of increments of 25, increasing the change by 1 for each increment + # That is all this loop does, increment the "change" for every increment of 25 + # The change caps off at 300 however, so break once we are over that limit + difference = abs(winner_rating - loser_rating) + rating_change = 0 + count = 25 + while count <= difference: + if count > 300: + break + rating_change += 1 + count += 25 + + # 16 is the base change, increased or decreased based on whoever has the higher current rating + if winner_rating > loser_rating: + winner_rating += 16 - rating_change + loser_rating -= 16 - rating_change + else: + winner_rating += 16 + rating_change + loser_rating -= 16 + rating_change + + return winner_rating, loser_rating + + async def update_records(key, db, winner, loser): # We're using the Harkness scale to rate # http://opnetchessclub.wikidot.com/harkness-rating-system - r_filter = lambda row: (row['member_id'] == str(winner.id)) | (row['member_id'] == str(loser.id)) - matches = await db.actual_load(key, table_filter=r_filter) + wins = f"{key}_wins" + losses = f"{key}_losses" + key = f"{key}_rating" + query = """ +SELECT + id, $1, $2, $3 +FROM + users +WHERE + id = any($4::bigint[]) +""" + results = await db.fetch(key, wins, losses, [winner.id, loser.id]) - winner_stats = {} - loser_stats = {} - try: - for stat in matches: - if stat.get('member_id') == str(winner.id): - winner_stats = stat - elif stat.get('member_id') == str(loser.id): - loser_stats = stat - except TypeError: - pass - - winner_rating = winner_stats.get('rating') or 1000 - loser_rating = loser_stats.get('rating') or 1000 + winner_rating = loser_rating = 1000 + for result in results: + if result['id'] == winner.id: + winner_rating = result[key] + else: + loser_rating = result[key] # The scale is based off of increments of 25, increasing the change by 1 for each increment # That is all this loop does, increment the "change" for every increment of 25