mirror of
https://github.com/brandons209/Red-bot-Cogs.git
synced 2024-05-06 13:32:35 +12:00
final updates to chatbot
This commit is contained in:
parent
f2bb5c99d4
commit
bc56b1b8be
|
@ -13,6 +13,7 @@ import os
|
|||
import asyncio
|
||||
import glob
|
||||
import io
|
||||
import functools
|
||||
|
||||
from typing import Literal
|
||||
|
||||
|
@ -111,6 +112,7 @@ class ActivityLogger(commands.Cog):
|
|||
}
|
||||
self.bot.remove_command("userinfo")
|
||||
self.load_task = asyncio.create_task(self.initialize())
|
||||
self.loop = asyncio.get_event_loop()
|
||||
|
||||
def cog_unload(self):
|
||||
self.lock = True
|
||||
|
@ -412,7 +414,15 @@ class ActivityLogger(commands.Cog):
|
|||
end_time = date
|
||||
|
||||
# get messages split by channel
|
||||
messages = self.log_handler(log_files, end_time, split_channels=True)
|
||||
messages = await self.loop.run_in_executor(
|
||||
None,
|
||||
functools.partial(
|
||||
self.log_handler,
|
||||
log_files,
|
||||
end_time,
|
||||
split_channels=True,
|
||||
),
|
||||
)
|
||||
|
||||
### set up data dictionary
|
||||
num_messages = {}
|
||||
|
@ -600,7 +610,15 @@ class ActivityLogger(commands.Cog):
|
|||
|
||||
await ctx.send(warning("**__Generating logs, please wait...__**"))
|
||||
# runs in descending order, with most recent log file first
|
||||
messages = self.log_handler(log_files, end_time, start=start)
|
||||
messages = await self.loop.run_in_executor(
|
||||
None,
|
||||
functools.partial(
|
||||
self.log_handler,
|
||||
log_files,
|
||||
end_time,
|
||||
start=start,
|
||||
),
|
||||
)
|
||||
|
||||
if user:
|
||||
messages = [message for message in messages if str(user.id) in message]
|
||||
|
|
|
@ -7,6 +7,7 @@ from aitextgen import aitextgen
|
|||
from typing import Literal
|
||||
from datetime import datetime, timedelta
|
||||
import asyncio, os, time, random
|
||||
import functools
|
||||
|
||||
|
||||
class Chatbot(commands.Cog):
|
||||
|
@ -27,7 +28,7 @@ class Chatbot(commands.Cog):
|
|||
"dead_revive_time": 3000,
|
||||
}
|
||||
default_channel = {"autoreply": False, "randomness": 0.25, "timeout": 1500}
|
||||
default_global = {"use_gpu": False, "autoboot": False}
|
||||
default_global = {"use_gpu": False, "autoboot": False, "total_response_time": 0, "num_responses": 0}
|
||||
|
||||
self.config.register_guild(**default_guild)
|
||||
self.config.register_channel(**default_channel)
|
||||
|
@ -41,6 +42,8 @@ class Chatbot(commands.Cog):
|
|||
self.history = {}
|
||||
# when generating for a channel, ignore new messages
|
||||
self.channel_lock = []
|
||||
# stat tracking
|
||||
self.stats = {"total_response_time": 0, "num_responses": 0}
|
||||
self.special_tokens = {
|
||||
"end_convo": "<end_convo>",
|
||||
"start_convo": "<start_convo>",
|
||||
|
@ -67,7 +70,14 @@ class Chatbot(commands.Cog):
|
|||
if await self.config.autoboot():
|
||||
await self.load_model()
|
||||
|
||||
self.stats["total_response_time"] = await self.config.total_response_time()
|
||||
self.stats["num_responses"] = await self.config.num_responses()
|
||||
|
||||
while True:
|
||||
if self.model is None:
|
||||
await asyncio.sleep(60)
|
||||
continue
|
||||
|
||||
for guild in self.bot.guilds:
|
||||
if await self.bot.cog_disabled_in_guild(self, guild):
|
||||
continue
|
||||
|
@ -96,6 +106,9 @@ class Chatbot(commands.Cog):
|
|||
output = self.get_ai_response(context, max_len, temp)
|
||||
await channel.send(output)
|
||||
|
||||
# save stats off
|
||||
await self.config.total_response_time.set(self.stats["total_response_time"])
|
||||
await self.config.num_responses.set(self.stats["num_responses"])
|
||||
await asyncio.sleep(60)
|
||||
|
||||
def cog_unload(self):
|
||||
|
@ -116,6 +129,17 @@ class Chatbot(commands.Cog):
|
|||
|
||||
await message.edit(content=content)
|
||||
|
||||
@commands.command(name="aistats")
|
||||
async def ai_stats(self, ctx):
|
||||
"""
|
||||
See some stats on my chatbot!
|
||||
"""
|
||||
if self.stats["num_responses"] <= 0:
|
||||
return await ctx.maybe_send_embed("I haven't responded to anyone yet!")
|
||||
|
||||
avg_response = self.stats["total_response_time"] / self.stats["num_responses"]
|
||||
await ctx.maybe_send_embed(f"**Average response time:** {avg_response:.2f} seconds")
|
||||
|
||||
@commands.group(name="ai")
|
||||
@commands.guild_only()
|
||||
@checks.admin()
|
||||
|
@ -142,6 +166,11 @@ class Chatbot(commands.Cog):
|
|||
"""
|
||||
Load the model and set it to load on startup
|
||||
"""
|
||||
if not lets_boot:
|
||||
await self.config.autoboot.set(False)
|
||||
await ctx.tick()
|
||||
return
|
||||
|
||||
await self.load_model()
|
||||
await self.config.autoboot.set(True)
|
||||
|
||||
|
@ -419,13 +448,49 @@ class Chatbot(commands.Cog):
|
|||
|
||||
data = [l for i, l in enumerate(data) if i not in to_delete]
|
||||
|
||||
with open(os.path.join(cog_data_path(cog_instance=self), f"{ctx.guild.id}-cleaned.txt"), "w") as f:
|
||||
save_file_name = os.path.join(cog_data_path(cog_instance=self), f"{ctx.guild.id}-cleaned.txt")
|
||||
with open(save_file_name, "w") as f:
|
||||
f.write("\n".join(data))
|
||||
|
||||
try:
|
||||
await status_msg.edit(content=info("Done. Saved to the cog's data path."))
|
||||
await status_msg.edit(content=info(f"Done. Saved to the cog's data path as {ctx.guild.id}-cleaned.txt"))
|
||||
except:
|
||||
await ctx.send(info("Done. Saved to the cog's data path."))
|
||||
await ctx.send(info(f"Done. Saved to the cog's data path as {ctx.guild.id}-cleaned.txt"))
|
||||
|
||||
@ai.command(name="train", usage="<name of data file> <steps to train (should leave default)>")
|
||||
@checks.is_owner()
|
||||
async def ai_train(self, ctx, data_file: str, num_steps: int = 50000):
|
||||
"""
|
||||
Train the chatbot model, will use loaded model or it will a create a new one if none is loaded
|
||||
|
||||
Data file should be in this cog's data directory.
|
||||
|
||||
**MAKE SURE TO HAVE __TENSORFLOW__ INSTALLED BEFORE TRAINING!**
|
||||
|
||||
**WARNING** this will overwrite the current model if loaded!
|
||||
**WARNING** this will use a lot of resources! Make sure you have a lot of memory and a GPU, set the gpu option before training!
|
||||
"""
|
||||
if self.model is None:
|
||||
self.model = aitextgen(tf_gpt2="124M", to_gpu=(await self.config.use_gpu()))
|
||||
|
||||
await ctx.send(info("Starting training, see console for training output."))
|
||||
# finetune
|
||||
await self.loop.run_in_executor(
|
||||
None,
|
||||
functools.partial(
|
||||
self.model.train,
|
||||
os.path.join(cog_data_path(cog_instance=self), data_file),
|
||||
output_dir=cog_data_path(cog_instance=self),
|
||||
line_by_line=False,
|
||||
from_cache=False,
|
||||
num_steps=num_steps,
|
||||
generate_every=num_steps,
|
||||
save_every=1000,
|
||||
save_gdrive=False,
|
||||
learning_rate=1e-3,
|
||||
batch_size=1,
|
||||
),
|
||||
)
|
||||
|
||||
def process_input(self, message: str) -> str:
|
||||
"""
|
||||
|
@ -437,7 +502,6 @@ class Chatbot(commands.Cog):
|
|||
# Remove bot's @s from input
|
||||
processed_input = message.replace(("<@!" + str(self.bot.user.id) + ">"), "").strip()
|
||||
processed_input = message.replace(str(self.bot.user), "").strip()
|
||||
print(processed_input)
|
||||
|
||||
# strip spaces at beginning of text
|
||||
processed_input = "\n".join([s.strip() for s in processed_input.split("\n")])
|
||||
|
@ -460,8 +524,8 @@ class Chatbot(commands.Cog):
|
|||
numtokens = len(self.model.tokenizer(message)["input_ids"])
|
||||
|
||||
output = ""
|
||||
i = 0 # in case of inf loop, three tries to generate a non-empty messages TODO: make configurable
|
||||
while output == "" and i < 3:
|
||||
i = 0 # in case of inf loop, two tries to generate a non-empty messages TODO: make configurable
|
||||
while output == "" and i < 2:
|
||||
text = self.model.generate(
|
||||
max_length=numtokens + 70 + 5 * max_len,
|
||||
prompt=message + "\n",
|
||||
|
@ -506,21 +570,13 @@ class Chatbot(commands.Cog):
|
|||
except:
|
||||
pass
|
||||
|
||||
print(self.history[message.channel])
|
||||
|
||||
@commands.Cog.listener()
|
||||
async def on_message(self, message: discord.Message):
|
||||
if await self.bot.cog_disabled_in_guild(self, message.guild):
|
||||
return
|
||||
if not self.model and (await self.config.autoboot()):
|
||||
await self.bot.send_to_owners(
|
||||
error(
|
||||
"Your model for cog `chatbot` could not be found. Make sure to have two files, `pytorch_model.bin` and `config.json` in the cog's data directory."
|
||||
)
|
||||
)
|
||||
if not self.model:
|
||||
return
|
||||
|
||||
start = time.time()
|
||||
author = message.author
|
||||
guild = message.guild
|
||||
channel = message.channel
|
||||
|
@ -534,7 +590,7 @@ class Chatbot(commands.Cog):
|
|||
self.history[channel] = []
|
||||
|
||||
self.history[channel].append(message)
|
||||
print(self.channel_lock)
|
||||
|
||||
if channel in self.channel_lock:
|
||||
return
|
||||
|
||||
|
@ -565,7 +621,6 @@ class Chatbot(commands.Cog):
|
|||
if ran_chat:
|
||||
self.talking_channels[channel] = message.created_at
|
||||
|
||||
print(f"checks: {time.time() - start}")
|
||||
start = time.time()
|
||||
self.channel_lock.append(channel)
|
||||
async with channel.typing():
|
||||
|
@ -574,8 +629,6 @@ class Chatbot(commands.Cog):
|
|||
max_len = await self.config.guild(guild).max_len()
|
||||
temp = await self.config.guild(guild).temp()
|
||||
|
||||
print(f"history: {time.time() - start}")
|
||||
start = time.time()
|
||||
context = ""
|
||||
# remove old messages
|
||||
self.history[channel] = self.history[channel][-1 * history_len :]
|
||||
|
@ -590,18 +643,14 @@ class Chatbot(commands.Cog):
|
|||
if not context:
|
||||
return
|
||||
|
||||
print(f"process context: {time.time() - start}")
|
||||
start = time.time()
|
||||
|
||||
response = await self.loop.run_in_executor(None, self.get_ai_response, context, max_len, temp)
|
||||
|
||||
print(f"response: {time.time() - start}")
|
||||
|
||||
self.stats["total_response_time"] += time.time() - start
|
||||
self.stats["num_responses"] += 1
|
||||
try:
|
||||
self.channel_lock.remove(channel)
|
||||
except ValueError:
|
||||
pass
|
||||
print(self.channel_lock)
|
||||
|
||||
return await message.reply(response, mention_author=False)
|
||||
|
||||
async def red_delete_data_for_user(
|
||||
|
|
Loading…
Reference in a new issue