add progress bar to correlation commands

This commit is contained in:
Brandon 2022-11-02 03:06:41 -04:00
parent bf08c3e7ee
commit 27db92ea4a

View file

@ -1927,7 +1927,6 @@ class ActivityLogger(commands.Cog):
log_files = glob.glob(os.path.join(PATH, str(guild.id), "*.log"))
log_files = [log for log in log_files if "guild" not in log]
async with ctx.channel.typing():
# get messages split by channel
messages = await self.loop.run_in_executor(
None,
@ -1939,7 +1938,10 @@ class ActivityLogger(commands.Cog):
),
)
def process_messages():
async def process_messages():
progress_msg_str = "Processed {}/{} channels."
progress_msg = await ctx.send(progress_msg_str.format(0, len(messages)))
progress_index = 0
for ch_id, data in messages.items():
channel = guild.get_channel(ch_id)
# channel may be deleted, but still want to include message data
@ -1991,10 +1993,11 @@ class ActivityLogger(commands.Cog):
del joined_at[user]
except IndexError:
pass
except KeyError: # not sure why this happens... TODO figure it out
except KeyError: # happens if user rejoins after running this command
pass
except ValueError:
pass
await asyncio.sleep(0)
else:
to_delete = []
for message in data:
@ -2003,9 +2006,11 @@ class ActivityLogger(commands.Cog):
to_delete.append(message)
elif " deleted message from " in message:
to_delete.append(message)
await asyncio.sleep(0)
for msg in to_delete:
data.remove(msg)
await asyncio.sleep(0)
for i, message in enumerate(data):
try:
@ -2035,7 +2040,7 @@ class ActivityLogger(commands.Cog):
continue
except IndexError:
pass
except KeyError: # not sure why this happens... TODO figure it out
except KeyError: # happens if user rejoins after running this command
pass
except ValueError:
pass
@ -2060,13 +2065,17 @@ class ActivityLogger(commands.Cog):
adj_matrix[members[user1], members[user2]] += corr_weights["messages"][j - i]
except IndexError:
pass
except KeyError: # happens if user rejoins after running this command
pass
await asyncio.sleep(0)
await self.loop.run_in_executor(
None,
functools.partial(
process_messages,
),
)
progress_index += 1
try:
await progress_msg.edit(content=progress_msg_str.format(progress_index, len(messages)))
except:
progress_msg = await ctx.send(progress_msg_str.format(0, len(messages)))
await process_messages()
# define table save paths
table_save_path = str(PATH / f"plot_data_{ctx.message.id}")
@ -2120,7 +2129,6 @@ class ActivityLogger(commands.Cog):
log_files = glob.glob(os.path.join(PATH, str(guild.id), "*.log"))
log_files = [log for log in log_files if "guild" not in log]
async with ctx.channel.typing():
# get messages split by channel
messages = await self.loop.run_in_executor(
None,
@ -2132,7 +2140,10 @@ class ActivityLogger(commands.Cog):
),
)
def process_messages():
async def process_messages():
progress_msg_str = "Processed {}/{} channels."
progress_msg = await ctx.send(progress_msg_str.format(0, len(messages)))
progress_index = 0
for ch_id, data in messages.items():
channel = guild.get_channel(ch_id)
# channel may be deleted, but still want to include message data
@ -2191,10 +2202,11 @@ class ActivityLogger(commands.Cog):
del joined_at[user]
except IndexError:
pass
except KeyError: # not sure why this happens... TODO figure it out
except KeyError: # happens if user rejoins after running this command
pass
except ValueError:
pass
await asyncio.sleep(0)
else:
to_delete = []
for message in data:
@ -2203,9 +2215,11 @@ class ActivityLogger(commands.Cog):
to_delete.append(message)
elif " deleted message from " in message:
to_delete.append(message)
await asyncio.sleep(0)
for msg in to_delete:
data.remove(msg)
await asyncio.sleep(0)
for i, message in enumerate(data):
user1_id = int(message.split("(id:")[1].split(")")[0])
@ -2230,11 +2244,11 @@ class ActivityLogger(commands.Cog):
continue
except IndexError:
pass
except KeyError: # not sure why this happens... TODO figure it out
except KeyError: # happens if user rejoins after running this command
pass
except ValueError:
pass
await asyncio.sleep(0)
# get messages around current message and add weights
for j in range(max(i - 5, 0), i):
try:
@ -2258,20 +2272,22 @@ class ActivityLogger(commands.Cog):
adj_matrix[members[user1], members[user2]] += corr_weights["messages"][j - i]
except IndexError:
pass
except KeyError: # happens if user rejoins after running this command
pass
await asyncio.sleep(0)
await self.loop.run_in_executor(
None,
functools.partial(
process_messages,
),
)
progress_index += 1
try:
await progress_msg.edit(content=progress_msg_str.format(progress_index, len(messages)))
except:
progress_msg = await ctx.send(progress_msg_str.format(0, len(messages)))
await process_messages()
member_names = [m.name for m in members.keys()]
adj_matrix = pd.DataFrame(data=adj_matrix, index=member_names, columns=member_names)
adj_matrix_voice = pd.DataFrame(data=adj_matrix_voice, index=member_names, columns=member_names)
adj_matrix_all = (
adj_matrix + adj_matrix_voice
) # have to add first otherwise tables dont line up for addition
adj_matrix_all = adj_matrix + adj_matrix_voice # have to add first otherwise tables dont line up for addition
# drop users who do not correlate to anyone else
for column in adj_matrix.columns: