1
0
Fork 0
mirror of synced 2024-05-18 11:32:26 +12:00
Bonfire/cogs/utils/database.py
2017-06-07 03:30:19 -05:00

142 lines
5 KiB
Python

import asyncio
import logging
from datetime import datetime

import rethinkdb as r

from . import config
from .checks import required_tables
async def _convert_to_list(cursor):
    """Drain an asynchronous RethinkDB cursor into a plain list.

    The asyncio cursor is not iterable here, so rows are pulled one at a
    time with ``await cursor.next()`` until the driver raises
    ``r.ReqlCursorEmpty`` to signal that no rows remain.
    """
    rows = []
    try:
        while True:
            rows.append(await cursor.next())
    except r.ReqlCursorEmpty:
        # Raised once every row has been consumed; anything else propagates.
        pass
    return rows
class Cache:
    """In-memory copy of one database table, refreshed in the background.

    The whole table is loaded through ``db.actual_load(table)`` and kept in
    ``values``. A check runs every ``CHECK_INTERVAL`` seconds on ``loop`` and
    reloads the table once the copy is older than ``REFRESH_INTERVAL``
    seconds.
    """

    # Reload the table once the cached copy is older than this (seconds).
    REFRESH_INTERVAL = 300
    # How often to check whether a reload is due (seconds).
    CHECK_INTERVAL = 60

    def __init__(self, table, key, db, loop):
        self.table = table  # The name of the database table
        self.key = key  # The name of the primary key field
        self.db = db  # The database object loads are made through
        self.loop = loop  # Event loop the refresh checks are scheduled on
        self.values = []  # The rows returned from the database
        self.refreshed_time = None  # When ``values`` was last loaded
        self._schedule_check()

    async def refresh(self):
        """Reload every row of the table into ``values``."""
        self.values = await self.db.actual_load(self.table)
        self.refreshed_time = datetime.now()

    async def check_refresh(self):
        """Reload the table if it is stale, then schedule the next check.

        The next check is scheduled even when the reload raises, so a
        transient database error does not permanently stop the refresh cycle
        (the error still surfaces on this task).
        """
        try:
            if self.refreshed_time is None:
                await self.refresh()
            else:
                age = datetime.now() - self.refreshed_time
                if age.total_seconds() > self.REFRESH_INTERVAL:
                    await self.refresh()
        finally:
            # call_later needs a plain callable. Passing the coroutine object
            # ``self.check_refresh()`` (as before) means it never runs.
            self.loop.call_later(self.CHECK_INTERVAL, self._schedule_check)

    def _schedule_check(self):
        """Start ``check_refresh`` as a task on the cache's event loop."""
        self.loop.create_task(self.check_refresh())

    def get(self, key=None, table_filter=None, pluck=None):
        """Look up cached rows, mimicking the equivalent database query.

        :param key: if truthy, match the row whose primary key equals ``key``.
        :param table_filter: dict of field -> value; match the first row in
            which every listed field exists and equals its value.
        :param pluck: if given, return only this field of the matched row
            (``None`` if the row lacks it).
        :return: all cached rows when neither ``key`` nor ``table_filter`` is
            given; otherwise the first matching row (or its plucked field),
            or ``None`` when nothing matches.
        """
        if key is None and table_filter is None:
            return self.values

        # ``values`` is None when the last load failed; treat as "no rows".
        rows = self.values or ()
        missing = object()  # Sentinel so absent fields never compare equal
        if key:
            match = next(
                (row for row in rows if row.get(self.key, missing) == key),
                None,
            )
        elif table_filter:
            match = next(
                (
                    row
                    for row in rows
                    if all(
                        row.get(field, missing) == wanted
                        for field, wanted in table_filter.items()
                    )
                ),
                None,
            )
        else:
            return None

        if match is None:
            return None
        return match.get(pluck) if pluck else match
class DB:
    """RethinkDB access layer with an in-memory cache per required table.

    One :class:`Cache` is created for every entry in ``required_tables``.
    Reads through :meth:`load` are served from those caches. Writes through
    :meth:`save` go to the database and then refresh the table's cache.
    """

    def __init__(self):
        self.loop = asyncio.get_event_loop()
        self.opts = config.db_opts  # Keyword arguments passed to r.connect
        self.cache = {}  # table name -> Cache
        for table, key in required_tables.items():
            self.cache[table] = Cache(table, key, self, self.loop)

    async def query(self, query):
        """Run a manual ReQL query on a fresh connection and return its result.

        Cursor results are drained into a list. ``ReqlOpFailedError`` and
        ``ReqlNonExistenceError`` are swallowed and reported as ``None``; other
        errors propagate. The connection is closed in every case.
        """
        r.set_loop_type("asyncio")
        conn = await r.connect(**self.opts)
        try:
            try:
                result = await query.run(conn)
            except (r.ReqlOpFailedError, r.ReqlNonExistenceError):
                result = None
            if isinstance(result, r.Cursor):
                result = await _convert_to_list(result)
        finally:
            # Previously the connection leaked if run() or draining raised.
            await conn.close()
        return result

    def save(self, table, content):
        """Schedule :meth:`_save` as a task and return immediately.

        Any error raised while saving surfaces only on that task.
        """
        self.loop.create_task(self._save(table, content))

    async def _save(self, table, content):
        """Write ``content`` to ``table``, then refresh that table's cache.

        The primary key name comes from the table's ``info()``. If ``content``
        has a value for it and a row with that key exists, the row is updated.
        When the update fails or reports neither ``replaced`` nor
        ``unchanged`` rows, the row is replaced instead. Otherwise ``content``
        is inserted as a new row.
        """
        log = logging.getLogger(__name__)
        info = await self.query(r.table(table).info())
        # query() returns None on failure (e.g. a missing table); fall back
        # to a plain insert rather than crashing on None.get().
        primary_key = info.get("primary_key") if info else None
        key = content.get(primary_key) if primary_key else None
        if key and await self.query(r.table(table).get(key)):
            # We have content...we either need to update it, or replace.
            # Update will typically be more common, so try that first.
            result = await self.query(r.table(table).get(key).update(content))
            log.debug("Update of %s[%r] returned %r", table, key, result)
            if not result or (
                result.get('replaced', 0) == 0 and result.get('unchanged', 0) == 0
            ):
                log.debug("Update of %s[%r] applied nothing; replacing", table, key)
                await self.query(r.table(table).get(key).replace(content))
        else:
            await self.query(r.table(table).insert(content))
        # Only tables listed in required_tables have a cache to refresh.
        cache = self.cache.get(table)
        if cache is not None:
            await cache.refresh()

    def load(self, table, **kwargs):
        """Read from ``table``'s in-memory cache.

        ``kwargs`` are forwarded to :meth:`Cache.get` (``key``,
        ``table_filter``, ``pluck``). A truthy ``key`` is converted to ``str``
        first.
        """
        if kwargs.get('key'):
            kwargs['key'] = str(kwargs.get('key'))
        return self.cache.get(table).get(**kwargs)

    async def actual_load(self, table, key=None, table_filter=None, pluck=None):
        """Query ``table`` directly, bypassing the cache.

        :param key: if truthy, fetch only the row with this primary key
            (converted to ``str``).
        :param table_filter: used only when no ``key`` is given; a ReQL filter
            applied to the table.
        :param pluck: if given, the query is narrowed with
            ``.pluck(pluck).values()[0]``.
        :return: the result of :meth:`query`. A table or filter query comes
            back as a list when the driver returns a cursor. Operation-failed
            and non-existence errors give ``None``.
        """
        query = r.table(table)
        # If a key has been provided, get content with that key
        if key:
            query = query.get(str(key))
        # A key and a filter shouldn't be combined for any case we'll ever use, so separate these
        elif table_filter:
            query = query.filter(table_filter)
        # If we want to pluck something specific, do that
        if pluck:
            query = query.pluck(pluck).values()[0]
        return await self.query(query)