add progress bar to correlation commands

This commit is contained in:
Brandon 2022-11-02 03:06:41 -04:00
parent bf08c3e7ee
commit 27db92ea4a

View file

@ -1927,7 +1927,6 @@ class ActivityLogger(commands.Cog):
log_files = glob.glob(os.path.join(PATH, str(guild.id), "*.log")) log_files = glob.glob(os.path.join(PATH, str(guild.id), "*.log"))
log_files = [log for log in log_files if "guild" not in log] log_files = [log for log in log_files if "guild" not in log]
async with ctx.channel.typing():
# get messages split by channel # get messages split by channel
messages = await self.loop.run_in_executor( messages = await self.loop.run_in_executor(
None, None,
@ -1939,7 +1938,10 @@ class ActivityLogger(commands.Cog):
), ),
) )
def process_messages(): async def process_messages():
progress_msg_str = "Processed {}/{} channels."
progress_msg = await ctx.send(progress_msg_str.format(0, len(messages)))
progress_index = 0
for ch_id, data in messages.items(): for ch_id, data in messages.items():
channel = guild.get_channel(ch_id) channel = guild.get_channel(ch_id)
# channel may be deleted, but still want to include message data # channel may be deleted, but still want to include message data
@ -1991,10 +1993,11 @@ class ActivityLogger(commands.Cog):
del joined_at[user] del joined_at[user]
except IndexError: except IndexError:
pass pass
except KeyError: # not sure why this happens... TODO figure it out except KeyError: # happens if user rejoins after running this command
pass pass
except ValueError: except ValueError:
pass pass
await asyncio.sleep(0)
else: else:
to_delete = [] to_delete = []
for message in data: for message in data:
@ -2003,9 +2006,11 @@ class ActivityLogger(commands.Cog):
to_delete.append(message) to_delete.append(message)
elif " deleted message from " in message: elif " deleted message from " in message:
to_delete.append(message) to_delete.append(message)
await asyncio.sleep(0)
for msg in to_delete: for msg in to_delete:
data.remove(msg) data.remove(msg)
await asyncio.sleep(0)
for i, message in enumerate(data): for i, message in enumerate(data):
try: try:
@ -2035,7 +2040,7 @@ class ActivityLogger(commands.Cog):
continue continue
except IndexError: except IndexError:
pass pass
except KeyError: # not sure why this happens... TODO figure it out except KeyError: # happens if user rejoins after running this command
pass pass
except ValueError: except ValueError:
pass pass
@ -2060,13 +2065,17 @@ class ActivityLogger(commands.Cog):
adj_matrix[members[user1], members[user2]] += corr_weights["messages"][j - i] adj_matrix[members[user1], members[user2]] += corr_weights["messages"][j - i]
except IndexError: except IndexError:
pass pass
except KeyError: # happens if user rejoins after running this command
pass
await asyncio.sleep(0)
await self.loop.run_in_executor( progress_index += 1
None, try:
functools.partial( await progress_msg.edit(content=progress_msg_str.format(progress_index, len(messages)))
process_messages, except:
), progress_msg = await ctx.send(progress_msg_str.format(0, len(messages)))
)
await process_messages()
# define table save paths # define table save paths
table_save_path = str(PATH / f"plot_data_{ctx.message.id}") table_save_path = str(PATH / f"plot_data_{ctx.message.id}")
@ -2120,7 +2129,6 @@ class ActivityLogger(commands.Cog):
log_files = glob.glob(os.path.join(PATH, str(guild.id), "*.log")) log_files = glob.glob(os.path.join(PATH, str(guild.id), "*.log"))
log_files = [log for log in log_files if "guild" not in log] log_files = [log for log in log_files if "guild" not in log]
async with ctx.channel.typing():
# get messages split by channel # get messages split by channel
messages = await self.loop.run_in_executor( messages = await self.loop.run_in_executor(
None, None,
@ -2132,7 +2140,10 @@ class ActivityLogger(commands.Cog):
), ),
) )
def process_messages(): async def process_messages():
progress_msg_str = "Processed {}/{} channels."
progress_msg = await ctx.send(progress_msg_str.format(0, len(messages)))
progress_index = 0
for ch_id, data in messages.items(): for ch_id, data in messages.items():
channel = guild.get_channel(ch_id) channel = guild.get_channel(ch_id)
# channel may be deleted, but still want to include message data # channel may be deleted, but still want to include message data
@ -2191,10 +2202,11 @@ class ActivityLogger(commands.Cog):
del joined_at[user] del joined_at[user]
except IndexError: except IndexError:
pass pass
except KeyError: # not sure why this happens... TODO figure it out except KeyError: # happens if user rejoins after running this command
pass pass
except ValueError: except ValueError:
pass pass
await asyncio.sleep(0)
else: else:
to_delete = [] to_delete = []
for message in data: for message in data:
@ -2203,9 +2215,11 @@ class ActivityLogger(commands.Cog):
to_delete.append(message) to_delete.append(message)
elif " deleted message from " in message: elif " deleted message from " in message:
to_delete.append(message) to_delete.append(message)
await asyncio.sleep(0)
for msg in to_delete: for msg in to_delete:
data.remove(msg) data.remove(msg)
await asyncio.sleep(0)
for i, message in enumerate(data): for i, message in enumerate(data):
user1_id = int(message.split("(id:")[1].split(")")[0]) user1_id = int(message.split("(id:")[1].split(")")[0])
@ -2230,11 +2244,11 @@ class ActivityLogger(commands.Cog):
continue continue
except IndexError: except IndexError:
pass pass
except KeyError: # not sure why this happens... TODO figure it out except KeyError: # happens if user rejoins after running this command
pass pass
except ValueError: except ValueError:
pass pass
await asyncio.sleep(0)
# get messages around current message and add weights # get messages around current message and add weights
for j in range(max(i - 5, 0), i): for j in range(max(i - 5, 0), i):
try: try:
@ -2258,20 +2272,22 @@ class ActivityLogger(commands.Cog):
adj_matrix[members[user1], members[user2]] += corr_weights["messages"][j - i] adj_matrix[members[user1], members[user2]] += corr_weights["messages"][j - i]
except IndexError: except IndexError:
pass pass
except KeyError: # happens if user rejoins after running this command
pass
await asyncio.sleep(0)
await self.loop.run_in_executor( progress_index += 1
None, try:
functools.partial( await progress_msg.edit(content=progress_msg_str.format(progress_index, len(messages)))
process_messages, except:
), progress_msg = await ctx.send(progress_msg_str.format(0, len(messages)))
)
await process_messages()
member_names = [m.name for m in members.keys()] member_names = [m.name for m in members.keys()]
adj_matrix = pd.DataFrame(data=adj_matrix, index=member_names, columns=member_names) adj_matrix = pd.DataFrame(data=adj_matrix, index=member_names, columns=member_names)
adj_matrix_voice = pd.DataFrame(data=adj_matrix_voice, index=member_names, columns=member_names) adj_matrix_voice = pd.DataFrame(data=adj_matrix_voice, index=member_names, columns=member_names)
adj_matrix_all = ( adj_matrix_all = adj_matrix + adj_matrix_voice # have to add first otherwise tables dont line up for addition
adj_matrix + adj_matrix_voice
) # have to add first otherwise tables dont line up for addition
# drop users who do not correlate to anyone else # drop users who do not correlate to anyone else
for column in adj_matrix.columns: for column in adj_matrix.columns: