Update to work with new database settings
This commit is contained in:
parent
8e89847ed2
commit
5612f6e25a
|
@ -7,10 +7,22 @@ from . import config
|
|||
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
# The list of tables needed for the database
|
||||
table_list = ['battle_records', 'battling', 'boops', 'bot_data', 'command_usage', 'custom_permissions',
|
||||
'deviantart', 'motd', 'nsfw_channels', 'overwatch', 'picarto', 'prefixes', 'raffles',
|
||||
'rules', 'server_alerts', 'strawpolls', 'tags', 'tictactoe', 'twitch', 'user_notifications']
|
||||
# The tables needed for the database, as well as their primary keys
|
||||
required_tables = {
|
||||
'battle_records': 'member_id',
|
||||
'boops': 'member_id',
|
||||
'command_usage': 'command',
|
||||
'motd': 'date',
|
||||
'overwatch': 'member_id',
|
||||
'picarto': 'member_id',
|
||||
'server_settings': 'server_id',
|
||||
'raffles': 'id',
|
||||
'strawpolls': 'server_id',
|
||||
'osu': 'member_id',
|
||||
'tags': 'server_id',
|
||||
'tictactoe': 'member_id',
|
||||
'twitch': 'member_id'
|
||||
}
|
||||
|
||||
|
||||
async def db_check():
|
||||
|
@ -24,9 +36,10 @@ async def db_check():
|
|||
except r.errors.ReqlDriverError:
|
||||
print("Cannot connect to the RethinkDB instance with the following information: {}".format(db_opts))
|
||||
|
||||
print("The RethinkDB instance you have setup may be down, otherwise please ensure you setup a"\
|
||||
" RethinkDB instance, and you have provided the correct database information in config.yml")
|
||||
print("The RethinkDB instance you have setup may be down, otherwise please ensure you setup a"
|
||||
" RethinkDB instance, and you have provided the correct database information in config.yml")
|
||||
quit()
|
||||
return
|
||||
|
||||
# Get the current databases and check if the one we need is there
|
||||
dbs = await r.db_list().run(conn)
|
||||
|
@ -35,19 +48,20 @@ async def db_check():
|
|||
print('Couldn\'t find database {}...creating now'.format(db_opts['db']))
|
||||
await r.db_create(db_opts['db']).run(conn)
|
||||
# Then add all the tables
|
||||
for table in table_list:
|
||||
for table, key in required_tables.items():
|
||||
print("Creating table {}...".format(table))
|
||||
await r.table_create(table).run(conn)
|
||||
await r.table_create(table, primary_key=key).run(conn)
|
||||
print("Done!")
|
||||
else:
|
||||
# Otherwise, if the database is setup, make sure all the required tables are there
|
||||
tables = await r.table_list().run(conn)
|
||||
for table in table_list:
|
||||
for table, key in required_tables.items():
|
||||
if table not in tables:
|
||||
print("Creating table {}...".format(table))
|
||||
await r.table_create(table).run(conn)
|
||||
await r.table_create(table, primary_key=key).run(conn)
|
||||
print("Done checking tables!")
|
||||
|
||||
|
||||
def is_owner(ctx):
|
||||
return ctx.message.author.id in config.owner_ids
|
||||
|
||||
|
@ -66,24 +80,15 @@ def custom_perms(**perms):
|
|||
for perm, setting in perms.items():
|
||||
setattr(required_perm, perm, setting)
|
||||
|
||||
perm_values = config.cache.get('custom_permissions').values
|
||||
|
||||
# Loop through and find this server's entry for custom permissions
|
||||
# Find the command we're using, if it exists, then overwrite
|
||||
# The required permissions, based on the value saved
|
||||
for x in perm_values:
|
||||
if x['server_id'] == ctx.message.server.id and x.get(ctx.command.qualified_name):
|
||||
required_perm = discord.Permissions(x[ctx.command.qualified_name])
|
||||
try:
|
||||
server_settings = config.cache.get('server_settings').values
|
||||
required_perm_value = [x for x in server_settings if x['server_id'] == ctx.message.server.id][0]['permissions'][ctx.command.qualified_name]
|
||||
required_perm = discord.Permissions(required_perm_value)
|
||||
except (TypeError, IndexError, KeyError):
|
||||
pass
|
||||
|
||||
# Now just check if the person running the command has these permissions
|
||||
return member_perms >= required_perm
|
||||
|
||||
predicate.perms = perms
|
||||
return commands.check(predicate)
|
||||
|
||||
|
||||
def is_pm():
|
||||
def predicate(ctx):
|
||||
return ctx.message.channel.is_private
|
||||
|
||||
return commands.check(predicate)
|
||||
|
|
|
@ -9,7 +9,7 @@ global_config = {}
|
|||
# Ensure that the required config.yml file actually exists
|
||||
try:
|
||||
with open("config.yml", "r") as f:
|
||||
global_config = yaml.load(f)
|
||||
global_config = yaml.safe_load(f)
|
||||
except FileNotFoundError:
|
||||
print("You have no config file setup! Please use config.yml.sample to setup a valid config file")
|
||||
quit()
|
||||
|
@ -70,10 +70,6 @@ user_agent = global_config.get('user_agent', "")
|
|||
# The extensions to load
|
||||
extensions = global_config.get('extensions', [])
|
||||
|
||||
# The variables needed for sharding
|
||||
shard_count = global_config.get('shard_count', 1)
|
||||
shard_id = global_config.get('shard_id', 0)
|
||||
|
||||
# The default status the bot will use
|
||||
default_status = global_config.get("default_status", None)
|
||||
# The URL that will be used to link to for the help command
|
||||
|
@ -95,10 +91,6 @@ db_pass = global_config.get('db_pass', '')
|
|||
# {'ca_certs': db_cert}, 'user': db_user, 'password': db_pass}
|
||||
db_opts = {'host': db_host, 'db': db_name, 'port': db_port, 'user': db_user, 'password': db_pass}
|
||||
|
||||
possible_keys = ['prefixes', 'battle_records', 'boops', 'server_alerts', 'user_notifications', 'nsfw_channels',
|
||||
'custom_permissions', 'rules', 'overwatch', 'picarto', 'twitch', 'strawpolls', 'tags',
|
||||
'tictactoe', 'bot_data', 'command_manage']
|
||||
|
||||
# This will be a dictionary that holds the cache object, based on the key that is saved
|
||||
cache = {}
|
||||
|
||||
|
@ -110,7 +102,7 @@ cache = {}
|
|||
|
||||
# We still need 'cache' for prefixes and custom permissions however, so for now, just include that
|
||||
cache['prefixes'] = Cache('prefixes')
|
||||
cache['custom_permissions'] = Cache('custom_permissions')
|
||||
cache['server_settings'] = Cache('server_settings')
|
||||
|
||||
async def update_cache():
|
||||
for value in cache.values():
|
||||
|
@ -123,61 +115,46 @@ def command_prefix(bot, message):
|
|||
# If the prefix does exist in the database and isn't in our cache; too bad, something has messed up
|
||||
# But it is not worth a query for every single message the bot detects, to fix
|
||||
try:
|
||||
values = cache['prefixes'].values
|
||||
try:
|
||||
prefix = [data['prefix'] for data in values if message.server.id == data['server_id']][0]
|
||||
except IndexError:
|
||||
prefix = None
|
||||
except AttributeError:
|
||||
prefix = None
|
||||
prefixes = cache['server_settings'].values
|
||||
prefix = [x for x in prefixes if x['server_id'] == message.guild.id][0]['prefix']
|
||||
return prefix or default_prefix
|
||||
except KeyError:
|
||||
except (KeyError, TypeError, IndexError, AttributeError):
|
||||
return default_prefix
|
||||
|
||||
|
||||
async def add_content(table, content, r_filter=None):
|
||||
async def add_content(table, content):
|
||||
r.set_loop_type("asyncio")
|
||||
conn = await r.connect(**db_opts)
|
||||
# First we need to make sure that this entry doesn't exist
|
||||
# For all rethinkDB cares, multiple entries can exist with the same content
|
||||
# For our purposes however, we do not want this
|
||||
try:
|
||||
if r_filter is not None:
|
||||
cursor = await r.table(table).filter(r_filter).run(conn)
|
||||
cur_content = await _convert_to_list(cursor)
|
||||
if len(cur_content) > 0:
|
||||
await conn.close()
|
||||
return False
|
||||
await r.table(table).insert(content).run(conn)
|
||||
result = await r.table(table).insert(content).run(conn)
|
||||
await conn.close()
|
||||
return True
|
||||
except r.ReqlOpFailedError:
|
||||
# This means the table does not exist
|
||||
await r.table_create(table).run(conn)
|
||||
await r.table(table).insert(content).run(conn)
|
||||
await conn.close()
|
||||
return True
|
||||
result = {}
|
||||
return result.get('inserted', 0) > 0
|
||||
|
||||
|
||||
async def remove_content(table, r_filter=None):
|
||||
if r_filter is None:
|
||||
r_filter = {}
|
||||
async def remove_content(table, key):
|
||||
r.set_loop_type("asyncio")
|
||||
conn = await r.connect(**db_opts)
|
||||
try:
|
||||
result = await r.table(table).filter(r_filter).delete().run(conn)
|
||||
result = await r.table(table).get(key).delete().run(conn)
|
||||
except r.ReqlOpFailedError:
|
||||
result = {}
|
||||
pass
|
||||
await conn.close()
|
||||
if table == 'prefixes' or table == 'custom_permissions':
|
||||
if table == 'prefixes' or table == 'server_settings':
|
||||
loop.create_task(cache[table].update())
|
||||
return result.get('deleted', 0) > 0
|
||||
|
||||
|
||||
async def update_content(table, content, r_filter=None):
|
||||
if r_filter is None:
|
||||
r_filter = {}
|
||||
async def update_content(table, content, key):
|
||||
r.set_loop_type("asyncio")
|
||||
conn = await r.connect(**db_opts)
|
||||
# This method is only for updating content, so if we find that it doesn't exist, just return false
|
||||
|
@ -185,36 +162,56 @@ async def update_content(table, content, r_filter=None):
|
|||
# Update based on the content and filter passed to us
|
||||
# rethinkdb allows you to do many many things inside of update
|
||||
# This is why we're accepting a variable and using it, whatever it may be, as the query
|
||||
result = await r.table(table).filter(r_filter).update(content).run(conn)
|
||||
result = await r.table(table).get(key).update(content).run(conn)
|
||||
except r.ReqlOpFailedError:
|
||||
await conn.close()
|
||||
result = {}
|
||||
await conn.close()
|
||||
if table == 'prefixes' or table == 'custom_permissions':
|
||||
if table == 'prefixes' or table == 'server_settings':
|
||||
loop.create_task(cache[table].update())
|
||||
return result.get('replaced', 0) > 0 or result.get('unchanged', 0) > 0
|
||||
|
||||
|
||||
async def replace_content(table, content, r_filter=None):
|
||||
async def replace_content(table, content, key):
|
||||
# This method is here because .replace and .update can have some different functionalities
|
||||
if r_filter is None:
|
||||
r_filter = {}
|
||||
r.set_loop_type("asyncio")
|
||||
conn = await r.connect(**db_opts)
|
||||
try:
|
||||
result = await r.table(table).filter(r_filter).replace(content).run(conn)
|
||||
result = await r.table(table).get(key).replace(content).run(conn)
|
||||
except r.ReqlOpFailedError:
|
||||
await conn.close()
|
||||
result = {}
|
||||
await conn.close()
|
||||
if table == 'prefixes' or table == 'custom_permissions':
|
||||
if table == 'prefixes' or table == 'server_settings':
|
||||
loop.create_task(cache[table].update())
|
||||
return result.get('replaced', 0) > 0 or result.get('unchanged', 0) > 0
|
||||
|
||||
|
||||
async def get_content(table: str, r_filter=None):
|
||||
if r_filter is None:
|
||||
r_filter = {}
|
||||
async def get_content(table, key=None):
|
||||
r.set_loop_type("asyncio")
|
||||
conn = await r.connect(**db_opts)
|
||||
|
||||
try:
|
||||
if key:
|
||||
cursor = await r.table(table).get(key).run(conn)
|
||||
else:
|
||||
cursor = await r.table(table).run(conn)
|
||||
if cursor is None:
|
||||
content = None
|
||||
elif type(cursor) is not dict:
|
||||
content = await _convert_to_list(cursor)
|
||||
if len(content) == 0:
|
||||
content = None
|
||||
else:
|
||||
content = cursor
|
||||
except (IndexError, r.ReqlOpFailedError):
|
||||
content = None
|
||||
await conn.close()
|
||||
if table == 'prefixes' or table == 'server_settings':
|
||||
loop.create_task(cache[table].update())
|
||||
return content
|
||||
|
||||
async def filter_content(table: str, r_filter):
|
||||
r.set_loop_type("asyncio")
|
||||
conn = await r.connect(**db_opts)
|
||||
try:
|
||||
|
@ -225,7 +222,7 @@ async def get_content(table: str, r_filter=None):
|
|||
except (IndexError, r.ReqlOpFailedError):
|
||||
content = None
|
||||
await conn.close()
|
||||
if table == 'prefixes' or table == 'custom_permissions':
|
||||
if table == 'prefixes' or table == 'server_settings':
|
||||
loop.create_task(cache[table].update())
|
||||
return content
|
||||
|
||||
|
|
Loading…
Reference in a new issue