
128 lines
5.2 KiB
Raw Normal View History

2019-01-20 13:34:03 +13:00
import os
import pickle

import discord
from discord.ext import commands
from keras.models import load_model
from keras.preprocessing.sequence import pad_sequences
import numpy as np
def _load_dict(path):
    """Unpickle and return the object stored in the file at ``path``.

    NOTE: ``pickle.load`` can execute arbitrary code while deserialising,
    so ``path`` must only ever point at trusted files.
    """
    with open(path, 'rb') as file:
        # Named ``loaded`` rather than ``dict`` so the builtin is not shadowed.
        loaded = pickle.load(file)
    return loaded
# Lookup tables for punctuation handling.
# punctuation_to_tokens: punctuation mark -> space-padded placeholder token.
punctuation_to_tokens = {
    '!': ' ||exclaimation_mark|| ',
    ',': ' ||comma|| ',
    '"': ' ||quotation_mark|| ',
    ';': ' ||semicolon|| ',
    '.': ' ||period|| ',
    '?': ' ||question_mark|| ',
    '(': ' ||left_parentheses|| ',
    ')': ' ||right_parentheses|| ',
    '--': ' ||dash|| ',
    '\n': ' ||return|| ',
    ':': ' ||colon|| ',
}
# tokens_to_punctuation: placeholder token (padding stripped) -> punctuation mark.
tokens_to_punctuation = dict(
    (placeholder.strip(), mark)
    for mark, placeholder in punctuation_to_tokens.items()
)
def _tokenize_punctuation(text):
    """Return ``text`` with every known punctuation mark replaced by its placeholder token."""
    # Marks are substituted one after another, in this fixed order.
    marks = ('.', ',', '!', '"', ';', '?', '(', ')', '--', '\n', ':')
    for mark in marks:
        placeholder = punctuation_to_tokens[mark]
        text = text.replace(mark, placeholder)
    return text
def _untokenize_punctuation(text):
    """Return ``text`` with placeholder tokens turned back into punctuation marks.

    Each token is replaced together with the single space in front of it, so
    the punctuation mark ends up attached to the preceding word.
    """
    ordered_tokens = (
        '||period||', '||comma||', '||exclaimation_mark||', '||quotation_mark||',
        '||semicolon||', '||question_mark||', '||left_parentheses||',
        '||right_parentheses||', '||dash||', '||return||', '||colon||',
    )
    for token in ordered_tokens:
        mark = tokens_to_punctuation[token]
        spaced_token = ' ' + token
        if token == '||left_parentheses||':
            # An opening parenthesis also absorbs the space that follows it.
            text = text.replace(spaced_token + ' ', mark)
        text = text.replace(spaced_token, mark)
    return text
def _sample(prediction, temp=0):
    """Pick an index from a model prediction, optionally with temperature sampling.

    Instead of always taking the argmax, a positive temperature re-weights the
    predicted probabilities and draws one index at random from the result.
    A lower temperature keeps the pick closer to the highest-scoring entry.

    Args:
        prediction: batch of probability rows; only the first row is sampled.
        temp: sampling temperature. Values <= 0 disable sampling and return
            the plain argmax of ``prediction``.

    Returns:
        The selected index (a NumPy integer).
    """
    if temp <= 0:
        return np.argmax(prediction)
    probs = np.asarray(prediction[0]).astype('float64')
    # log(0) is -inf, which exp() maps back to 0; silence only that warning.
    with np.errstate(divide='ignore'):
        scaled = np.log(probs) / temp
    # Shift so the largest exponent is 0. Without this, a small temperature
    # makes every exp() underflow to 0 and the normalisation becomes 0/0.
    scaled = scaled - np.max(scaled)
    weights = np.exp(scaled)
    probs = weights / np.sum(weights)
    draw = np.random.multinomial(1, probs, 1)
    return np.argmax(draw)
class ScriptCog:
    """Cog that generates scripts from an imported model (a Keras model here).

    The model and its vocabulary files are read from ``data/scriptcog/`` when
    the cog is constructed.
    """

    # NOTE(review): no @commands.command decorators are visible on the
    # coroutines below -- confirm how they are registered with the bot.

    def __init__(self, bot):
        """Store the bot, make sure the data folders exist and load the model.

        Args:
            bot: the bot instance this cog reports back through.
        """
        self.bot = bot
        os.makedirs("data/scriptcog/", exist_ok=True)
        os.makedirs("data/scriptcog/dicts", exist_ok=True)
        self.model_path = "data/scriptcog/model.h5"
        self.dict_path = "data/scriptcog/dicts/"
        self.model = load_model(self.model_path)
        # Upper bound on the number of words one genscript call may request.
        self.word_limit = 100
        # Keep the loaded vocabularies: resetting them to None straight after
        # loading would make every later lookup in genscript fail.
        self.word_to_int = _load_dict(self.dict_path + 'word_to_int.pkl')
        self.int_to_word = _load_dict(self.dict_path + 'int_to_word.pkl')
        self.sequence_length = _load_dict(self.dict_path + 'sequence_length.pkl')

    async def setwordlimit(self, ctx, num_words: int = 100):
        """Set the maximum number of words genscript may generate.

        Args:
            ctx: invocation context of the command.
            num_words: the new word limit.
        """
        # NOTE(review): presumably invoked_subcommand is always None for a
        # plain (non-group) command, so this hint is sent on every call --
        # confirm that is intended.
        if ctx.invoked_subcommand is None:
            await self.bot.say("Usage: setwordlimit limit")
        self.word_limit = num_words
        await self.bot.say("Maximum number of words is now {}".format(self.word_limit))

    async def genscript(self, ctx, num_words: int = 100, temp: float = 0.5,
                        seed: str = "pinkie pie::"):
        """Generate a script of ``num_words`` words, starting from ``seed``.

        Args:
            ctx: invocation context of the command.
            num_words: number of words to generate; must not exceed the
                configured word limit.
            temp: sampling randomness, clamped to the range [0, 1].
            seed: text the generation starts from; every word in it must be
                in the model vocabulary.
        """
        if ctx.invoked_subcommand is None:
            await self.bot.say("Usage: genscript num_words randomness(between 0 and 1)")
        if num_words > self.word_limit:
            await self.bot.say("Please keep script sizes to {} words or less.".format(self.word_limit))
            # Enforce the limit instead of only reporting it.
            return
        if temp > 1.0:
            temp = 1.0
        elif temp < 0:
            temp = 0
        await self.bot.say("Generating script, please wait...")
        input_text = seed
        for _ in range(num_words):
            # Tokenize the text generated so far into vocabulary ids.
            words = _tokenize_punctuation(input_text).lower().split()
            try:
                int_text = np.array([self.word_to_int[word] for word in words], dtype=np.int32)
            except KeyError:
                await self.bot.say("Sorry, that seed word is not in my vocabulary.\nPlease try an English word from the show.\n")
                return
            # Pad with zeros at the beginning if the text is too short, so
            # little noise is added next to the most recent words.
            int_text = pad_sequences([int_text], maxlen=self.sequence_length)
            # Predict the next word and append it to the running text.
            prediction = self.model.predict(int_text, verbose=0)
            output_word = self.int_to_word[_sample(prediction, temp=temp)]
            input_text += ' ' + output_word
        # Convert tokenized punctuation and other characters back.
        result = _untokenize_punctuation(input_text)
        # Deliver the generated script; without this the result is discarded.
        await self.bot.say(result)
def setup(bot):