let text generation run asyncrously

This commit is contained in:
brandons209 2019-01-20 02:15:38 -05:00
parent 52055a6474
commit c1a4404420

View file

@ -136,6 +136,27 @@ class ScriptCog:
self._write_config()
await self.bot.say("Script cooldown is now {}.".format(self.cooldown_limit))
async def get_model_output(self, num_words, temp, seed):
input_text = seed
for _ in range(num_words):
#tokenize text to ints
int_text = _tokenize_punctuation(input_text)
int_text = int_text.lower()
int_text = int_text.split()
try:
int_text = np.array([self.word_to_int[word] for word in int_text], dtype=np.int32)
except KeyError:
await self.bot.say("Sorry, that seed word is not in my vocabulary.\nPlease try an English word from the show.\n")
return None
#pad text if it is too short, pads with zeros at beginning of text, so shouldnt have too much noise added
int_text = pad_sequences([int_text], maxlen=self.sequence_length)
#predict next word:
prediction = self.model.predict(int_text, verbose=0)
output_word = self.int_to_word[_sample(prediction, temp=temp)]
#append to the result
input_text += ' ' + output_word
#convert tokenized punctuation and other characters back
return _untokenize_punctuation(input_text)
@commands.command(pass_context=True, no_pm=True)
async def genscript(self, ctx, num_words_to_generate : int = 100, variance : float = 0.5, seed : str = "pinkie pie::"):
@ -157,29 +178,13 @@ class ScriptCog:
variance = 0
await self.bot.say("Generating script, please wait...")
input_text = seed
for _ in range(num_words_to_generate):
#tokenize text to ints
int_text = _tokenize_punctuation(input_text)
int_text = int_text.lower()
int_text = int_text.split()
try:
int_text = np.array([self.word_to_int[word] for word in int_text], dtype=np.int32)
except KeyError:
await self.bot.say("Sorry, that seed word is not in my vocabulary.\nPlease try an English word from the show.\n")
return
#pad text if it is too short, pads with zeros at beginning of text, so shouldnt have too much noise added
int_text = pad_sequences([int_text], maxlen=self.sequence_length)
#predict next word:
prediction = self.model.predict(int_text, verbose=0)
output_word = self.int_to_word[_sample(prediction, temp=variance)]
#append to the result
input_text += ' ' + output_word
#convert tokenized punctuation and other characters back
result = _untokenize_punctuation(input_text)
await self.bot.say("------------------------")
await self.bot.say(result)
await self.bot.say("------------------------")
result = self.get_model_output(num_words_to_generate, variance, seed)
if result is not None:
await self.bot.say("------------------------")
await self.bot.say(result)
await self.bot.say("------------------------")
def setup(bot):
bot.add_cog(ScriptCog(bot))