final updates to chatbot

This commit is contained in:
Brandon 2022-04-29 18:03:21 -04:00
parent f2bb5c99d4
commit bc56b1b8be
2 changed files with 96 additions and 29 deletions

View file

@ -13,6 +13,7 @@ import os
import asyncio import asyncio
import glob import glob
import io import io
import functools
from typing import Literal from typing import Literal
@ -111,6 +112,7 @@ class ActivityLogger(commands.Cog):
} }
self.bot.remove_command("userinfo") self.bot.remove_command("userinfo")
self.load_task = asyncio.create_task(self.initialize()) self.load_task = asyncio.create_task(self.initialize())
self.loop = asyncio.get_event_loop()
def cog_unload(self): def cog_unload(self):
self.lock = True self.lock = True
@ -412,7 +414,15 @@ class ActivityLogger(commands.Cog):
end_time = date end_time = date
# get messages split by channel # get messages split by channel
messages = self.log_handler(log_files, end_time, split_channels=True) messages = await self.loop.run_in_executor(
None,
functools.partial(
self.log_handler,
log_files,
end_time,
split_channels=True,
),
)
### set up data dictionary ### set up data dictionary
num_messages = {} num_messages = {}
@ -600,7 +610,15 @@ class ActivityLogger(commands.Cog):
await ctx.send(warning("**__Generating logs, please wait...__**")) await ctx.send(warning("**__Generating logs, please wait...__**"))
# runs in descending order, with most recent log file first # runs in descending order, with most recent log file first
messages = self.log_handler(log_files, end_time, start=start) messages = await self.loop.run_in_executor(
None,
functools.partial(
self.log_handler,
log_files,
end_time,
start=start,
),
)
if user: if user:
messages = [message for message in messages if str(user.id) in message] messages = [message for message in messages if str(user.id) in message]

View file

@ -7,6 +7,7 @@ from aitextgen import aitextgen
from typing import Literal from typing import Literal
from datetime import datetime, timedelta from datetime import datetime, timedelta
import asyncio, os, time, random import asyncio, os, time, random
import functools
class Chatbot(commands.Cog): class Chatbot(commands.Cog):
@ -27,7 +28,7 @@ class Chatbot(commands.Cog):
"dead_revive_time": 3000, "dead_revive_time": 3000,
} }
default_channel = {"autoreply": False, "randomness": 0.25, "timeout": 1500} default_channel = {"autoreply": False, "randomness": 0.25, "timeout": 1500}
default_global = {"use_gpu": False, "autoboot": False} default_global = {"use_gpu": False, "autoboot": False, "total_response_time": 0, "num_responses": 0}
self.config.register_guild(**default_guild) self.config.register_guild(**default_guild)
self.config.register_channel(**default_channel) self.config.register_channel(**default_channel)
@ -41,6 +42,8 @@ class Chatbot(commands.Cog):
self.history = {} self.history = {}
# when generating for a channel, ignore new messages # when generating for a channel, ignore new messages
self.channel_lock = [] self.channel_lock = []
# stat tracking
self.stats = {"total_response_time": 0, "num_responses": 0}
self.special_tokens = { self.special_tokens = {
"end_convo": "<end_convo>", "end_convo": "<end_convo>",
"start_convo": "<start_convo>", "start_convo": "<start_convo>",
@ -67,7 +70,14 @@ class Chatbot(commands.Cog):
if await self.config.autoboot(): if await self.config.autoboot():
await self.load_model() await self.load_model()
self.stats["total_response_time"] = await self.config.total_response_time()
self.stats["num_responses"] = await self.config.num_responses()
while True: while True:
if self.model is None:
await asyncio.sleep(60)
continue
for guild in self.bot.guilds: for guild in self.bot.guilds:
if await self.bot.cog_disabled_in_guild(self, guild): if await self.bot.cog_disabled_in_guild(self, guild):
continue continue
@ -96,6 +106,9 @@ class Chatbot(commands.Cog):
output = self.get_ai_response(context, max_len, temp) output = self.get_ai_response(context, max_len, temp)
await channel.send(output) await channel.send(output)
# save stats off
await self.config.total_response_time.set(self.stats["total_response_time"])
await self.config.num_responses.set(self.stats["num_responses"])
await asyncio.sleep(60) await asyncio.sleep(60)
def cog_unload(self): def cog_unload(self):
@ -116,6 +129,17 @@ class Chatbot(commands.Cog):
await message.edit(content=content) await message.edit(content=content)
@commands.command(name="aistats")
async def ai_stats(self, ctx):
"""
See some stats on my chatbot!
"""
if self.stats["num_responses"] <= 0:
return await ctx.maybe_send_embed("I haven't responded to anyone yet!")
avg_response = self.stats["total_response_time"] / self.stats["num_responses"]
await ctx.maybe_send_embed(f"**Average response time:** {avg_response:.2f} seconds")
@commands.group(name="ai") @commands.group(name="ai")
@commands.guild_only() @commands.guild_only()
@checks.admin() @checks.admin()
@ -142,6 +166,11 @@ class Chatbot(commands.Cog):
""" """
Load the model and set it to load on startup Load the model and set it to load on startup
""" """
if not lets_boot:
await self.config.autoboot.set(False)
await ctx.tick()
return
await self.load_model() await self.load_model()
await self.config.autoboot.set(True) await self.config.autoboot.set(True)
@ -419,13 +448,49 @@ class Chatbot(commands.Cog):
data = [l for i, l in enumerate(data) if i not in to_delete] data = [l for i, l in enumerate(data) if i not in to_delete]
with open(os.path.join(cog_data_path(cog_instance=self), f"{ctx.guild.id}-cleaned.txt"), "w") as f: save_file_name = os.path.join(cog_data_path(cog_instance=self), f"{ctx.guild.id}-cleaned.txt")
with open(save_file_name, "w") as f:
f.write("\n".join(data)) f.write("\n".join(data))
try: try:
await status_msg.edit(content=info("Done. Saved to the cog's data path.")) await status_msg.edit(content=info(f"Done. Saved to the cog's data path as {ctx.guild.id}-cleaned.txt"))
except: except:
await ctx.send(info("Done. Saved to the cog's data path.")) await ctx.send(info(f"Done. Saved to the cog's data path as {ctx.guild.id}-cleaned.txt"))
@ai.command(name="train", usage="<name of data file> <steps to train (should leave default)>")
@checks.is_owner()
async def ai_train(self, ctx, data_file: str, num_steps: int = 50000):
"""
Train the chatbot model, will use loaded model or it will a create a new one if none is loaded
Data file should be in this cog's data directory.
**MAKE SURE TO HAVE __TENSORFLOW__ INSTALLED BEFORE TRAINING!**
**WARNING** this will overwrite the current model if loaded!
**WARNING** this will use a lot of resources! Make sure you have a lot of memory and a GPU, set the gpu option before training!
"""
if self.model is None:
self.model = aitextgen(tf_gpt2="124M", to_gpu=(await self.config.use_gpu()))
await ctx.send(info("Starting training, see console for training output."))
# finetune
await self.loop.run_in_executor(
None,
functools.partial(
self.model.train,
os.path.join(cog_data_path(cog_instance=self), data_file),
output_dir=cog_data_path(cog_instance=self),
line_by_line=False,
from_cache=False,
num_steps=num_steps,
generate_every=num_steps,
save_every=1000,
save_gdrive=False,
learning_rate=1e-3,
batch_size=1,
),
)
def process_input(self, message: str) -> str: def process_input(self, message: str) -> str:
""" """
@ -437,7 +502,6 @@ class Chatbot(commands.Cog):
# Remove bot's @s from input # Remove bot's @s from input
processed_input = message.replace(("<@!" + str(self.bot.user.id) + ">"), "").strip() processed_input = message.replace(("<@!" + str(self.bot.user.id) + ">"), "").strip()
processed_input = message.replace(str(self.bot.user), "").strip() processed_input = message.replace(str(self.bot.user), "").strip()
print(processed_input)
# strip spaces at beginning of text # strip spaces at beginning of text
processed_input = "\n".join([s.strip() for s in processed_input.split("\n")]) processed_input = "\n".join([s.strip() for s in processed_input.split("\n")])
@ -460,8 +524,8 @@ class Chatbot(commands.Cog):
numtokens = len(self.model.tokenizer(message)["input_ids"]) numtokens = len(self.model.tokenizer(message)["input_ids"])
output = "" output = ""
i = 0 # in case of inf loop, three tries to generate a non-empty messages TODO: make configurable i = 0 # in case of inf loop, two tries to generate a non-empty messages TODO: make configurable
while output == "" and i < 3: while output == "" and i < 2:
text = self.model.generate( text = self.model.generate(
max_length=numtokens + 70 + 5 * max_len, max_length=numtokens + 70 + 5 * max_len,
prompt=message + "\n", prompt=message + "\n",
@ -506,21 +570,13 @@ class Chatbot(commands.Cog):
except: except:
pass pass
print(self.history[message.channel])
@commands.Cog.listener() @commands.Cog.listener()
async def on_message(self, message: discord.Message): async def on_message(self, message: discord.Message):
if await self.bot.cog_disabled_in_guild(self, message.guild): if await self.bot.cog_disabled_in_guild(self, message.guild):
return return
if not self.model and (await self.config.autoboot()): if not self.model:
await self.bot.send_to_owners(
error(
"Your model for cog `chatbot` could not be found. Make sure to have two files, `pytorch_model.bin` and `config.json` in the cog's data directory."
)
)
return return
start = time.time()
author = message.author author = message.author
guild = message.guild guild = message.guild
channel = message.channel channel = message.channel
@ -534,7 +590,7 @@ class Chatbot(commands.Cog):
self.history[channel] = [] self.history[channel] = []
self.history[channel].append(message) self.history[channel].append(message)
print(self.channel_lock)
if channel in self.channel_lock: if channel in self.channel_lock:
return return
@ -565,7 +621,6 @@ class Chatbot(commands.Cog):
if ran_chat: if ran_chat:
self.talking_channels[channel] = message.created_at self.talking_channels[channel] = message.created_at
print(f"checks: {time.time() - start}")
start = time.time() start = time.time()
self.channel_lock.append(channel) self.channel_lock.append(channel)
async with channel.typing(): async with channel.typing():
@ -574,8 +629,6 @@ class Chatbot(commands.Cog):
max_len = await self.config.guild(guild).max_len() max_len = await self.config.guild(guild).max_len()
temp = await self.config.guild(guild).temp() temp = await self.config.guild(guild).temp()
print(f"history: {time.time() - start}")
start = time.time()
context = "" context = ""
# remove old messages # remove old messages
self.history[channel] = self.history[channel][-1 * history_len :] self.history[channel] = self.history[channel][-1 * history_len :]
@ -590,18 +643,14 @@ class Chatbot(commands.Cog):
if not context: if not context:
return return
print(f"process context: {time.time() - start}")
start = time.time()
response = await self.loop.run_in_executor(None, self.get_ai_response, context, max_len, temp) response = await self.loop.run_in_executor(None, self.get_ai_response, context, max_len, temp)
self.stats["total_response_time"] += time.time() - start
print(f"response: {time.time() - start}") self.stats["num_responses"] += 1
try: try:
self.channel_lock.remove(channel) self.channel_lock.remove(channel)
except ValueError: except ValueError:
pass pass
print(self.channel_lock)
return await message.reply(response, mention_author=False) return await message.reply(response, mention_author=False)
async def red_delete_data_for_user( async def red_delete_data_for_user(