Mirror of https://github.com/brandons209/Red-bot-Cogs.git, synced 2024-05-14 17:33:07 +12:00
initial upload

parent 4f26c44a64
commit c217b02d80
__init__.py (new file, 4 lines)
@@ -0,0 +1,4 @@
from .script import ScriptCog

def setup(bot):
    bot.add_cog(ScriptCog(bot))
info.json (new file, 10 lines)
@@ -0,0 +1,10 @@
{
    "AUTHOR" : "Brandon Silva",
    "INSTALL_MSG" : "Thanks for installing my cog!",
    "NAME" : "ScriptCog",
    "SHORT" : "Generates text from a keras model.",
    "DESCRIPTION" : "Uses a keras model to generate text.",
    "TAGS" : ["mlp"],
    "REQUIREMENTS" : ["tensorflow", "keras"],
    "HIDDEN" : false
}
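The REQUIREMENTS and INSTALL_MSG fields are read by Red's downloader when the cog is installed. For context, a hedged sketch of how a cog from this repo would typically be added on Red V2; the "[p]" prefix, the repo nickname "brandons209", and the cog folder name "scriptcog" are assumptions, not taken from this commit:

[p]cog repo add brandons209 https://github.com/brandons209/Red-bot-Cogs
[p]cog install brandons209 scriptcog
[p]load scriptcog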
requirements.txt (new file, 2 lines)
@@ -0,0 +1,2 @@
tensorflow==1.5
keras
script.py (new file, 127 lines)
@@ -0,0 +1,127 @@
import discord
from discord.ext import commands
from keras.models import load_model
# used in genscript below to pad the tokenized seed to the model's sequence length
from keras.preprocessing.sequence import pad_sequences

import numpy as np
import os
import pickle


# loads a pickled dictionary from a file
def _load_dict(path):
    with open(path, 'rb') as file:
        data = pickle.load(file)
    return data

# dictionaries for tokenizing punctuation and converting it back
punctuation_to_tokens = {'!':' ||exclaimation_mark|| ', ',':' ||comma|| ', '"':' ||quotation_mark|| ',
                         ';':' ||semicolon|| ', '.':' ||period|| ', '?':' ||question_mark|| ', '(':' ||left_parentheses|| ',
                         ')':' ||right_parentheses|| ', '--':' ||dash|| ', '\n':' ||return|| ', ':':' ||colon|| '}

tokens_to_punctuation = {token.strip(): punc for punc, token in punctuation_to_tokens.items()}

# for all of the punctuation in replace_list, convert it to tokens
def _tokenize_punctuation(text):
    replace_list = ['.', ',', '!', '"', ';', '?', '(', ')', '--', '\n', ':']
    for char in replace_list:
        text = text.replace(char, punctuation_to_tokens[char])
    return text

# convert tokens back to punctuation
def _untokenize_punctuation(text):
    replace_list = ['||period||', '||comma||', '||exclaimation_mark||', '||quotation_mark||',
                    '||semicolon||', '||question_mark||', '||left_parentheses||', '||right_parentheses||',
                    '||dash||', '||return||', '||colon||']
    for char in replace_list:
        if char == '||left_parentheses||':  # added since left parentheses had an extra space
            text = text.replace(' ' + char + ' ', tokens_to_punctuation[char])
        text = text.replace(' ' + char, tokens_to_punctuation[char])
    return text

def _sample(prediction, temp=0):
    """
    Helper function that, instead of just taking the argmax of the prediction, samples the next word from
    the top of the predicted distribution. temp (temperature) controls how random the pick is: lower values
    keep the chosen word closer to the highest-predicted word, and temp <= 0 falls back to a plain argmax.
    """
    if temp <= 0:
        return np.argmax(prediction)
    prediction = prediction[0]
    prediction = np.asarray(prediction).astype('float64')
    prediction = np.log(prediction) / temp
    expo_prediction = np.exp(prediction)
    prediction = expo_prediction / np.sum(expo_prediction)
    probabilities = np.random.multinomial(1, prediction, 1)
    return np.argmax(probabilities)


class ScriptCog:
    """This cog generates scripts based on an imported model; I used a keras model."""

    def __init__(self, bot):
        self.bot = bot

        os.makedirs("data/scriptcog/", exist_ok=True)
        os.makedirs("data/scriptcog/dicts", exist_ok=True)
        self.model_path = "data/scriptcog/model.h5"
        self.dict_path = "data/scriptcog/dicts/"

        self.model = load_model(self.model_path)
        self.word_limit = 100

        try:
            self.word_to_int = _load_dict(self.dict_path + 'word_to_int.pkl')
            self.int_to_word = _load_dict(self.dict_path + 'int_to_word.pkl')
            self.sequence_length = _load_dict(self.dict_path + 'sequence_length.pkl')
        except Exception:
            self.word_to_int = None
            self.int_to_word = None
            self.sequence_length = None

    @commands.command(pass_context=True)
    async def setwordlimit(self, ctx, num_words: int = 100):
        """Usage: setwordlimit limit"""
        self.word_limit = num_words
        await self.bot.say("Maximum number of words is now {}".format(self.word_limit))

    @commands.command(pass_context=True)
    async def genscript(self, ctx, num_words: int = 100, temp: float = 0.5, seed: str = 'the'):
        """Usage: genscript num_words randomness(between 0 and 1) seed_word"""
        # NOTE: 'seed' was referenced but never defined in the original upload; it is assumed
        # here to be a user-supplied word the generated script starts from.
        if num_words > self.word_limit:
            await self.bot.say("Please keep script sizes to {} words or less.".format(self.word_limit))
            return

        if temp > 1.0:
            temp = 1.0
        elif temp < 0:
            temp = 0

        await self.bot.say("Generating script, please wait...")
        input_text = seed
        for _ in range(num_words):
            # tokenize text to ints
            int_text = _tokenize_punctuation(input_text)
            int_text = int_text.lower()
            int_text = int_text.split()
            try:
                int_text = np.array([self.word_to_int[word] for word in int_text], dtype=np.int32)
            except KeyError:
                await self.bot.say("Sorry, that seed word is not in my vocabulary.\nPlease try an English word from the show.\n")
                return
            # pad text if it is too short; pads with zeros at the beginning of the text, so it shouldn't add too much noise
            int_text = pad_sequences([int_text], maxlen=self.sequence_length)
            # predict the next word
            prediction = self.model.predict(int_text, verbose=0)
            output_word = self.int_to_word[_sample(prediction, temp=temp)]
            # append to the result
            input_text += ' ' + output_word
        # convert tokenized punctuation and other characters back
        result = _untokenize_punctuation(input_text)

        await self.bot.say(result)

def setup(bot):
    bot.add_cog(ScriptCog(bot))
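The sampling helper above is what turns the model's softmax output into the next word. A minimal standalone sketch of the same temperature idea, using only numpy and a made-up probability vector (no model or cog state involved; the 1D input is a simplification of _sample, which takes the model's 2D prediction):

import numpy as np

def sample_with_temperature(prediction, temp=0.5):
    # same scheme as _sample above: rescale log-probabilities by 1/temp,
    # renormalize with a softmax, then draw one index from the result
    prediction = np.asarray(prediction).astype('float64')
    if temp <= 0:
        return np.argmax(prediction)
    prediction = np.log(prediction) / temp
    expo = np.exp(prediction)
    prediction = expo / np.sum(expo)
    return np.argmax(np.random.multinomial(1, prediction, 1))

probs = [0.6, 0.3, 0.1]                            # toy next-word distribution
print(sample_with_temperature(probs, temp=0))      # always 0 (plain argmax)
print(sample_with_temperature(probs, temp=1.0))    # usually 0, sometimes 1 or 2

Lower temperatures keep the draw close to the most likely word; higher ones let less likely words through, which is what genscript's randomness argument controls (it is clamped to the range 0 to 1).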