mirror of
https://github.com/brandons209/Red-bot-Cogs.git
synced 2024-07-01 04:20:17 +12:00
let text generation run asyncrously
This commit is contained in:
parent
52055a6474
commit
c1a4404420
|
@ -136,6 +136,27 @@ class ScriptCog:
|
||||||
self._write_config()
|
self._write_config()
|
||||||
await self.bot.say("Script cooldown is now {}.".format(self.cooldown_limit))
|
await self.bot.say("Script cooldown is now {}.".format(self.cooldown_limit))
|
||||||
|
|
||||||
|
async def get_model_output(self, num_words, temp, seed):
|
||||||
|
input_text = seed
|
||||||
|
for _ in range(num_words):
|
||||||
|
#tokenize text to ints
|
||||||
|
int_text = _tokenize_punctuation(input_text)
|
||||||
|
int_text = int_text.lower()
|
||||||
|
int_text = int_text.split()
|
||||||
|
try:
|
||||||
|
int_text = np.array([self.word_to_int[word] for word in int_text], dtype=np.int32)
|
||||||
|
except KeyError:
|
||||||
|
await self.bot.say("Sorry, that seed word is not in my vocabulary.\nPlease try an English word from the show.\n")
|
||||||
|
return None
|
||||||
|
#pad text if it is too short, pads with zeros at beginning of text, so shouldnt have too much noise added
|
||||||
|
int_text = pad_sequences([int_text], maxlen=self.sequence_length)
|
||||||
|
#predict next word:
|
||||||
|
prediction = self.model.predict(int_text, verbose=0)
|
||||||
|
output_word = self.int_to_word[_sample(prediction, temp=temp)]
|
||||||
|
#append to the result
|
||||||
|
input_text += ' ' + output_word
|
||||||
|
#convert tokenized punctuation and other characters back
|
||||||
|
return _untokenize_punctuation(input_text)
|
||||||
|
|
||||||
@commands.command(pass_context=True, no_pm=True)
|
@commands.command(pass_context=True, no_pm=True)
|
||||||
async def genscript(self, ctx, num_words_to_generate : int = 100, variance : float = 0.5, seed : str = "pinkie pie::"):
|
async def genscript(self, ctx, num_words_to_generate : int = 100, variance : float = 0.5, seed : str = "pinkie pie::"):
|
||||||
|
@ -157,26 +178,10 @@ class ScriptCog:
|
||||||
variance = 0
|
variance = 0
|
||||||
|
|
||||||
await self.bot.say("Generating script, please wait...")
|
await self.bot.say("Generating script, please wait...")
|
||||||
input_text = seed
|
|
||||||
for _ in range(num_words_to_generate):
|
result = self.get_model_output(num_words_to_generate, variance, seed)
|
||||||
#tokenize text to ints
|
|
||||||
int_text = _tokenize_punctuation(input_text)
|
if result is not None:
|
||||||
int_text = int_text.lower()
|
|
||||||
int_text = int_text.split()
|
|
||||||
try:
|
|
||||||
int_text = np.array([self.word_to_int[word] for word in int_text], dtype=np.int32)
|
|
||||||
except KeyError:
|
|
||||||
await self.bot.say("Sorry, that seed word is not in my vocabulary.\nPlease try an English word from the show.\n")
|
|
||||||
return
|
|
||||||
#pad text if it is too short, pads with zeros at beginning of text, so shouldnt have too much noise added
|
|
||||||
int_text = pad_sequences([int_text], maxlen=self.sequence_length)
|
|
||||||
#predict next word:
|
|
||||||
prediction = self.model.predict(int_text, verbose=0)
|
|
||||||
output_word = self.int_to_word[_sample(prediction, temp=variance)]
|
|
||||||
#append to the result
|
|
||||||
input_text += ' ' + output_word
|
|
||||||
#convert tokenized punctuation and other characters back
|
|
||||||
result = _untokenize_punctuation(input_text)
|
|
||||||
await self.bot.say("------------------------")
|
await self.bot.say("------------------------")
|
||||||
await self.bot.say(result)
|
await self.bot.say(result)
|
||||||
await self.bot.say("------------------------")
|
await self.bot.say("------------------------")
|
||||||
|
|
Loading…
Reference in a new issue