diff --git a/activitylog/activitylog.py b/activitylog/activitylog.py index f3d12d6..451632e 100644 --- a/activitylog/activitylog.py +++ b/activitylog/activitylog.py @@ -13,6 +13,7 @@ import os import asyncio import glob import io +import functools from typing import Literal @@ -111,6 +112,7 @@ class ActivityLogger(commands.Cog): } self.bot.remove_command("userinfo") self.load_task = asyncio.create_task(self.initialize()) + self.loop = asyncio.get_event_loop() def cog_unload(self): self.lock = True @@ -412,7 +414,15 @@ class ActivityLogger(commands.Cog): end_time = date # get messages split by channel - messages = self.log_handler(log_files, end_time, split_channels=True) + messages = await self.loop.run_in_executor( + None, + functools.partial( + self.log_handler, + log_files, + end_time, + split_channels=True, + ), + ) ### set up data dictionary num_messages = {} @@ -600,7 +610,15 @@ class ActivityLogger(commands.Cog): await ctx.send(warning("**__Generating logs, please wait...__**")) # runs in descending order, with most recent log file first - messages = self.log_handler(log_files, end_time, start=start) + messages = await self.loop.run_in_executor( + None, + functools.partial( + self.log_handler, + log_files, + end_time, + start=start, + ), + ) if user: messages = [message for message in messages if str(user.id) in message] diff --git a/chatbot/chatbot.py b/chatbot/chatbot.py index bc59646..e5919cc 100644 --- a/chatbot/chatbot.py +++ b/chatbot/chatbot.py @@ -7,6 +7,7 @@ from aitextgen import aitextgen from typing import Literal from datetime import datetime, timedelta import asyncio, os, time, random +import functools class Chatbot(commands.Cog): @@ -27,7 +28,7 @@ class Chatbot(commands.Cog): "dead_revive_time": 3000, } default_channel = {"autoreply": False, "randomness": 0.25, "timeout": 1500} - default_global = {"use_gpu": False, "autoboot": False} + default_global = {"use_gpu": False, "autoboot": False, "total_response_time": 0, "num_responses": 0} self.config.register_guild(**default_guild) self.config.register_channel(**default_channel) @@ -41,6 +42,8 @@ class Chatbot(commands.Cog): self.history = {} # when generating for a channel, ignore new messages self.channel_lock = [] + # stat tracking + self.stats = {"total_response_time": 0, "num_responses": 0} self.special_tokens = { "end_convo": "", "start_convo": "", @@ -67,7 +70,14 @@ class Chatbot(commands.Cog): if await self.config.autoboot(): await self.load_model() + self.stats["total_response_time"] = await self.config.total_response_time() + self.stats["num_responses"] = await self.config.num_responses() + while True: + if self.model is None: + await asyncio.sleep(60) + continue + for guild in self.bot.guilds: if await self.bot.cog_disabled_in_guild(self, guild): continue @@ -96,6 +106,9 @@ class Chatbot(commands.Cog): output = self.get_ai_response(context, max_len, temp) await channel.send(output) + # save stats off + await self.config.total_response_time.set(self.stats["total_response_time"]) + await self.config.num_responses.set(self.stats["num_responses"]) await asyncio.sleep(60) def cog_unload(self): @@ -116,6 +129,17 @@ class Chatbot(commands.Cog): await message.edit(content=content) + @commands.command(name="aistats") + async def ai_stats(self, ctx): + """ + See some stats on my chatbot! + """ + if self.stats["num_responses"] <= 0: + return await ctx.maybe_send_embed("I haven't responded to anyone yet!") + + avg_response = self.stats["total_response_time"] / self.stats["num_responses"] + await ctx.maybe_send_embed(f"**Average response time:** {avg_response:.2f} seconds") + @commands.group(name="ai") @commands.guild_only() @checks.admin() @@ -142,6 +166,11 @@ class Chatbot(commands.Cog): """ Load the model and set it to load on startup """ + if not lets_boot: + await self.config.autoboot.set(False) + await ctx.tick() + return + await self.load_model() await self.config.autoboot.set(True) @@ -419,13 +448,49 @@ class Chatbot(commands.Cog): data = [l for i, l in enumerate(data) if i not in to_delete] - with open(os.path.join(cog_data_path(cog_instance=self), f"{ctx.guild.id}-cleaned.txt"), "w") as f: + save_file_name = os.path.join(cog_data_path(cog_instance=self), f"{ctx.guild.id}-cleaned.txt") + with open(save_file_name, "w") as f: f.write("\n".join(data)) try: - await status_msg.edit(content=info("Done. Saved to the cog's data path.")) + await status_msg.edit(content=info(f"Done. Saved to the cog's data path as {ctx.guild.id}-cleaned.txt")) except: - await ctx.send(info("Done. Saved to the cog's data path.")) + await ctx.send(info(f"Done. Saved to the cog's data path as {ctx.guild.id}-cleaned.txt")) + + @ai.command(name="train", usage=" ") + @checks.is_owner() + async def ai_train(self, ctx, data_file: str, num_steps: int = 50000): + """ + Train the chatbot model, will use loaded model or it will a create a new one if none is loaded + + Data file should be in this cog's data directory. + + **MAKE SURE TO HAVE __TENSORFLOW__ INSTALLED BEFORE TRAINING!** + + **WARNING** this will overwrite the current model if loaded! + **WARNING** this will use a lot of resources! Make sure you have a lot of memory and a GPU, set the gpu option before training! + """ + if self.model is None: + self.model = aitextgen(tf_gpt2="124M", to_gpu=(await self.config.use_gpu())) + + await ctx.send(info("Starting training, see console for training output.")) + # finetune + await self.loop.run_in_executor( + None, + functools.partial( + self.model.train, + os.path.join(cog_data_path(cog_instance=self), data_file), + output_dir=cog_data_path(cog_instance=self), + line_by_line=False, + from_cache=False, + num_steps=num_steps, + generate_every=num_steps, + save_every=1000, + save_gdrive=False, + learning_rate=1e-3, + batch_size=1, + ), + ) def process_input(self, message: str) -> str: """ @@ -437,7 +502,6 @@ class Chatbot(commands.Cog): # Remove bot's @s from input processed_input = message.replace(("<@!" + str(self.bot.user.id) + ">"), "").strip() processed_input = message.replace(str(self.bot.user), "").strip() - print(processed_input) # strip spaces at beginning of text processed_input = "\n".join([s.strip() for s in processed_input.split("\n")]) @@ -460,8 +524,8 @@ class Chatbot(commands.Cog): numtokens = len(self.model.tokenizer(message)["input_ids"]) output = "" - i = 0 # in case of inf loop, three tries to generate a non-empty messages TODO: make configurable - while output == "" and i < 3: + i = 0 # in case of inf loop, two tries to generate a non-empty messages TODO: make configurable + while output == "" and i < 2: text = self.model.generate( max_length=numtokens + 70 + 5 * max_len, prompt=message + "\n", @@ -506,21 +570,13 @@ class Chatbot(commands.Cog): except: pass - print(self.history[message.channel]) - @commands.Cog.listener() async def on_message(self, message: discord.Message): if await self.bot.cog_disabled_in_guild(self, message.guild): return - if not self.model and (await self.config.autoboot()): - await self.bot.send_to_owners( - error( - "Your model for cog `chatbot` could not be found. Make sure to have two files, `pytorch_model.bin` and `config.json` in the cog's data directory." - ) - ) + if not self.model: return - start = time.time() author = message.author guild = message.guild channel = message.channel @@ -534,7 +590,7 @@ class Chatbot(commands.Cog): self.history[channel] = [] self.history[channel].append(message) - print(self.channel_lock) + if channel in self.channel_lock: return @@ -565,7 +621,6 @@ class Chatbot(commands.Cog): if ran_chat: self.talking_channels[channel] = message.created_at - print(f"checks: {time.time() - start}") start = time.time() self.channel_lock.append(channel) async with channel.typing(): @@ -574,8 +629,6 @@ class Chatbot(commands.Cog): max_len = await self.config.guild(guild).max_len() temp = await self.config.guild(guild).temp() - print(f"history: {time.time() - start}") - start = time.time() context = "" # remove old messages self.history[channel] = self.history[channel][-1 * history_len :] @@ -590,18 +643,14 @@ class Chatbot(commands.Cog): if not context: return - print(f"process context: {time.time() - start}") - start = time.time() - response = await self.loop.run_in_executor(None, self.get_ai_response, context, max_len, temp) - - print(f"response: {time.time() - start}") - + self.stats["total_response_time"] += time.time() - start + self.stats["num_responses"] += 1 try: self.channel_lock.remove(channel) except ValueError: pass - print(self.channel_lock) + return await message.reply(response, mention_author=False) async def red_delete_data_for_user(