final updates to chatbot

This commit is contained in:
Brandon 2022-04-29 18:03:21 -04:00
parent f2bb5c99d4
commit bc56b1b8be
2 changed files with 96 additions and 29 deletions

View file

@ -13,6 +13,7 @@ import os
import asyncio
import glob
import io
import functools
from typing import Literal
@ -111,6 +112,7 @@ class ActivityLogger(commands.Cog):
}
self.bot.remove_command("userinfo")
self.load_task = asyncio.create_task(self.initialize())
self.loop = asyncio.get_event_loop()
def cog_unload(self):
self.lock = True
@ -412,7 +414,15 @@ class ActivityLogger(commands.Cog):
end_time = date
# get messages split by channel
messages = self.log_handler(log_files, end_time, split_channels=True)
messages = await self.loop.run_in_executor(
None,
functools.partial(
self.log_handler,
log_files,
end_time,
split_channels=True,
),
)
### set up data dictionary
num_messages = {}
@ -600,7 +610,15 @@ class ActivityLogger(commands.Cog):
await ctx.send(warning("**__Generating logs, please wait...__**"))
# runs in descending order, with most recent log file first
messages = self.log_handler(log_files, end_time, start=start)
messages = await self.loop.run_in_executor(
None,
functools.partial(
self.log_handler,
log_files,
end_time,
start=start,
),
)
if user:
messages = [message for message in messages if str(user.id) in message]

View file

@ -7,6 +7,7 @@ from aitextgen import aitextgen
from typing import Literal
from datetime import datetime, timedelta
import asyncio, os, time, random
import functools
class Chatbot(commands.Cog):
@ -27,7 +28,7 @@ class Chatbot(commands.Cog):
"dead_revive_time": 3000,
}
default_channel = {"autoreply": False, "randomness": 0.25, "timeout": 1500}
default_global = {"use_gpu": False, "autoboot": False}
default_global = {"use_gpu": False, "autoboot": False, "total_response_time": 0, "num_responses": 0}
self.config.register_guild(**default_guild)
self.config.register_channel(**default_channel)
@ -41,6 +42,8 @@ class Chatbot(commands.Cog):
self.history = {}
# when generating for a channel, ignore new messages
self.channel_lock = []
# stat tracking
self.stats = {"total_response_time": 0, "num_responses": 0}
self.special_tokens = {
"end_convo": "<end_convo>",
"start_convo": "<start_convo>",
@ -67,7 +70,14 @@ class Chatbot(commands.Cog):
if await self.config.autoboot():
await self.load_model()
self.stats["total_response_time"] = await self.config.total_response_time()
self.stats["num_responses"] = await self.config.num_responses()
while True:
if self.model is None:
await asyncio.sleep(60)
continue
for guild in self.bot.guilds:
if await self.bot.cog_disabled_in_guild(self, guild):
continue
@ -96,6 +106,9 @@ class Chatbot(commands.Cog):
output = self.get_ai_response(context, max_len, temp)
await channel.send(output)
# save stats off
await self.config.total_response_time.set(self.stats["total_response_time"])
await self.config.num_responses.set(self.stats["num_responses"])
await asyncio.sleep(60)
def cog_unload(self):
@ -116,6 +129,17 @@ class Chatbot(commands.Cog):
await message.edit(content=content)
@commands.command(name="aistats")
async def ai_stats(self, ctx):
    """
    See some stats on my chatbot!
    """
    # Number of replies counted so far in the in-memory stats dict.
    answered = self.stats["num_responses"]
    if answered > 0:
        # Mean = accumulated response time divided by the number of replies.
        mean_seconds = self.stats["total_response_time"] / answered
        await ctx.maybe_send_embed(f"**Average response time:** {mean_seconds:.2f} seconds")
        return None
    # Nothing recorded yet, so there is no average to report.
    return await ctx.maybe_send_embed("I haven't responded to anyone yet!")
@commands.group(name="ai")
@commands.guild_only()
@checks.admin()
@ -142,6 +166,11 @@ class Chatbot(commands.Cog):
"""
Load the model and set it to load on startup
"""
if not lets_boot:
await self.config.autoboot.set(False)
await ctx.tick()
return
await self.load_model()
await self.config.autoboot.set(True)
@ -419,13 +448,49 @@ class Chatbot(commands.Cog):
data = [l for i, l in enumerate(data) if i not in to_delete]
with open(os.path.join(cog_data_path(cog_instance=self), f"{ctx.guild.id}-cleaned.txt"), "w") as f:
save_file_name = os.path.join(cog_data_path(cog_instance=self), f"{ctx.guild.id}-cleaned.txt")
with open(save_file_name, "w") as f:
f.write("\n".join(data))
try:
await status_msg.edit(content=info("Done. Saved to the cog's data path."))
await status_msg.edit(content=info(f"Done. Saved to the cog's data path as {ctx.guild.id}-cleaned.txt"))
except:
await ctx.send(info("Done. Saved to the cog's data path."))
await ctx.send(info(f"Done. Saved to the cog's data path as {ctx.guild.id}-cleaned.txt"))
@ai.command(name="train", usage="<name of data file> <steps to train (should leave default)>")
@checks.is_owner()
async def ai_train(self, ctx, data_file: str, num_steps: int = 50000):
    """
    Train the chatbot model, will use loaded model or it will create a new one if none is loaded
    Data file should be in this cog's data directory.
    **MAKE SURE TO HAVE __TENSORFLOW__ INSTALLED BEFORE TRAINING!**
    **WARNING** this will overwrite the current model if loaded!
    **WARNING** this will use a lot of resources! Make sure you have a lot of memory and a GPU, set the gpu option before training!
    """
    # data_file: file name, joined onto this cog's data directory.
    # num_steps: forwarded to the model's train() as both num_steps and generate_every.
    data_dir = cog_data_path(cog_instance=self)
    train_file = os.path.join(data_dir, data_file)

    # Fail fast with a readable message, and before a fresh model is built,
    # instead of letting a missing file surface as an executor exception.
    if not os.path.isfile(train_file):
        return await ctx.send(error(f"Could not find `{data_file}` in the cog's data directory."))

    if self.model is None:
        # NOTE(review): this constructor runs on the event loop; it presumably
        # blocks while the base model is fetched/loaded -- confirm and consider
        # moving it to the executor as well.
        self.model = aitextgen(tf_gpt2="124M", to_gpu=(await self.config.use_gpu()))

    await ctx.send(info("Starting training, see console for training output."))

    # finetune in the default executor so the blocking train() call does not
    # stall the bot's event loop
    await self.loop.run_in_executor(
        None,
        functools.partial(
            self.model.train,
            train_file,
            output_dir=data_dir,
            line_by_line=False,
            from_cache=False,
            num_steps=num_steps,
            generate_every=num_steps,
            save_every=1000,
            save_gdrive=False,
            learning_rate=1e-3,
            batch_size=1,
        ),
    )

    # Training can take a long time; tell the owner when it is over.
    await ctx.send(info("Training finished."))
def process_input(self, message: str) -> str:
"""
@ -437,7 +502,6 @@ class Chatbot(commands.Cog):
# Remove bot's @s from input
processed_input = message.replace(("<@!" + str(self.bot.user.id) + ">"), "").strip()
processed_input = message.replace(str(self.bot.user), "").strip()
print(processed_input)
# strip spaces at beginning of text
processed_input = "\n".join([s.strip() for s in processed_input.split("\n")])
@ -460,8 +524,8 @@ class Chatbot(commands.Cog):
numtokens = len(self.model.tokenizer(message)["input_ids"])
output = ""
i = 0 # in case of inf loop, three tries to generate a non-empty messages TODO: make configurable
while output == "" and i < 3:
i = 0 # in case of inf loop, two tries to generate a non-empty messages TODO: make configurable
while output == "" and i < 2:
text = self.model.generate(
max_length=numtokens + 70 + 5 * max_len,
prompt=message + "\n",
@ -506,21 +570,13 @@ class Chatbot(commands.Cog):
except:
pass
print(self.history[message.channel])
@commands.Cog.listener()
async def on_message(self, message: discord.Message):
if await self.bot.cog_disabled_in_guild(self, message.guild):
return
if not self.model and (await self.config.autoboot()):
await self.bot.send_to_owners(
error(
"Your model for cog `chatbot` could not be found. Make sure to have two files, `pytorch_model.bin` and `config.json` in the cog's data directory."
)
)
if not self.model:
return
start = time.time()
author = message.author
guild = message.guild
channel = message.channel
@ -534,7 +590,7 @@ class Chatbot(commands.Cog):
self.history[channel] = []
self.history[channel].append(message)
print(self.channel_lock)
if channel in self.channel_lock:
return
@ -565,7 +621,6 @@ class Chatbot(commands.Cog):
if ran_chat:
self.talking_channels[channel] = message.created_at
print(f"checks: {time.time() - start}")
start = time.time()
self.channel_lock.append(channel)
async with channel.typing():
@ -574,8 +629,6 @@ class Chatbot(commands.Cog):
max_len = await self.config.guild(guild).max_len()
temp = await self.config.guild(guild).temp()
print(f"history: {time.time() - start}")
start = time.time()
context = ""
# remove old messages
self.history[channel] = self.history[channel][-1 * history_len :]
@ -590,18 +643,14 @@ class Chatbot(commands.Cog):
if not context:
return
print(f"process context: {time.time() - start}")
start = time.time()
response = await self.loop.run_in_executor(None, self.get_ai_response, context, max_len, temp)
print(f"response: {time.time() - start}")
self.stats["total_response_time"] += time.time() - start
self.stats["num_responses"] += 1
try:
self.channel_lock.remove(channel)
except ValueError:
pass
print(self.channel_lock)
return await message.reply(response, mention_author=False)
async def red_delete_data_for_user(