import os
import logging

import streamlit as st
import torch
from transformers import pipeline
from transformers import GPT2Tokenizer, GPT2LMHeadModel

def load_model():
    """Load the text-generation pipeline into Streamlit session state (only once)."""
    keys = ['generator']
    if any(st.session_state.get(key) is None for key in keys):
        with st.spinner('Loading the model might take a couple of seconds...'):
            try:
                # Prefer a remote model path if one is configured, otherwise fall back to the local path
                if os.environ.get('remote_model_path'):
                    model_path = os.environ.get('remote_model_path')
                else:
                    model_path = os.getenv('model_path')
                st.session_state.generator = pipeline(task='text-generation', model=model_path, tokenizer=model_path)
                logging.info('Loaded models and tokenizer!')
            except Exception as e:
                logging.error(f'Error while loading models/tokenizer: {e}')
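
# Note on configuration (assumption based on the lookups above): load_model()
# resolves the model location from the 'remote_model_path' or 'model_path'
# environment variable, so one of them should be set before launching the app,
# e.g. (illustrative value only):
#   export model_path=gpt2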

def generate_items(constructs, prefix='', **kwargs):
    """Generate item stem(s) for one or more constructs using the loaded pipeline."""
    with st.spinner(f'Generating item(s) for `{constructs}`...'):
        # Encode the construct(s) and optional prefix into the prompt format the model expects
        construct_sep = '#'
        item_sep = '@'
        constructs = constructs if isinstance(constructs, list) else [constructs]
        encoded_constructs = construct_sep + construct_sep.join([x.lower() for x in constructs])
        encoded_prompt = f'{encoded_constructs}{item_sep}{prefix}'
        outputs = st.session_state.generator(encoded_prompt, **kwargs)
        # Strip the encoded prompt from each generated text so only the item stem remains
        truncate_str = f'{encoded_constructs}{item_sep}'
        item_stems = []
        for output in outputs:
            item_stems.append(output['generated_text'].replace(truncate_str, ''))
        return item_stems
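
# Worked example of the prompt encoding used in generate_items() (the construct
# names and prefix below are illustrative, not from the original source):
#   generate_items(['Anxiety', 'Worry'], prefix='I often')
#   builds encoded_prompt == '#anxiety#worry@I often', and the same
#   '#anxiety#worry@' string is stripped from each generated text before the
#   item stems are returned.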

def get_next_tokens(prefix, breadth=5):
    # Load tokenizer and model
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    model = GPT2LMHeadModel.from_pretrained('gpt2')
    # Encode the prefix
    inputs = tokenizer(prefix, return_tensors='pt')
    # Get the model's predictions
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    # Only consider the last token for next token predictions
    last_token_logits = logits[:, -1, :]
    # Get the indices of the top 'breadth' possible next tokens
    top_tokens = torch.topk(last_token_logits, breadth, dim=1).indices.tolist()[0]
    # Decode the token IDs to tokens
    next_tokens = [tokenizer.decode([token_id]) for token_id in top_tokens]
    return next_tokens
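
# Hypothetical usage sketch (not part of the original file): one way the helpers
# above could be wired into a minimal Streamlit page. Widget labels and the
# generation kwargs are illustrative assumptions, not the app's actual UI.
if __name__ == '__main__':
    load_model()
    construct = st.text_input('Construct', value='conscientiousness')
    prefix = st.text_input('Item stem prefix (optional)', value='')
    if st.button('Generate items') and st.session_state.get('generator') is not None:
        items = generate_items(construct, prefix=prefix,
                               do_sample=True, num_return_sequences=3)
        for item in items:
            st.write(item)
        # Suggest a few possible continuations for the first generated item
        st.write(get_next_tokens(items[0], breadth=5))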