---
license: apache-2.0
---
|
### Inference Code
|
```python
|
import numpy as np |
|
import pickle |
|
from keras.preprocessing.sequence import pad_sequences |
|
from keras.models import load_model |
|
|
|
def predict_word(
    seed_text: str,
    tokenizer,
    model,
    next_words: int = 2,
    *,
    max_sequence_len: "int | None" = None,
) -> str:
    """Greedily extend ``seed_text`` by up to ``next_words`` predicted words.

    At each step the current text is tokenized with ``tokenizer``, pre-padded
    to the model's input length, scored with ``model.predict``, and the
    highest-scoring index is mapped back to a word and appended.

    Args:
        seed_text: Text to start generating from.
        tokenizer: A fitted tokenizer exposing ``texts_to_sequences`` and a
            ``word_index`` mapping of word -> integer index.
        model: A trained model whose ``predict`` returns one row of scores
            per input sequence.
        next_words: Maximum number of words to append.
        max_sequence_len: Sequence length used when training; inputs are
            padded/truncated to ``max_sequence_len - 1`` tokens. If omitted,
            the input length is read from ``model.input_shape[1]``.

    Returns:
        ``seed_text`` with the predicted words appended, each preceded by a
        single space. Generation stops early if a predicted index has no
        word in ``tokenizer.word_index``.

    Raises:
        ValueError: If ``max_sequence_len`` is omitted and the model does not
            declare a fixed input length.
    """
    if max_sequence_len is not None:
        input_len = max_sequence_len - 1
    else:
        # NOTE(review): assumes a single-input model whose input shape is
        # (batch, timesteps) — confirm for your model.
        input_len = model.input_shape[1]
        if input_len is None:
            raise ValueError(
                "model has no fixed input length; pass max_sequence_len"
            )

    # Build the index -> word lookup once instead of scanning word_index for
    # every generated word; setdefault keeps the first word seen per index.
    index_to_word = {}
    for word, index in tokenizer.word_index.items():
        index_to_word.setdefault(index, word)

    for _ in range(next_words):
        token_list = tokenizer.texts_to_sequences([seed_text])[0]
        padded = pad_sequences([token_list], maxlen=input_len, padding="pre")
        # verbose=0 suppresses the progress bar Keras prints on every call.
        scores = model.predict(padded, verbose=0)
        predicted = int(np.argmax(scores, axis=-1)[0])
        output_word = index_to_word.get(predicted)
        if output_word is None:
            # No vocabulary word for this index (e.g. the padding index).
            # The text would not change, so further steps would repeat the
            # same prediction; stop rather than append empty words.
            break
        seed_text += " " + output_word
    return seed_text
|
```