# item-generator / modeling.py
# (Hugging Face Space file header — init commit 773685b by bjorn-hommel)
import functools
import logging
import os

import streamlit as st
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer, pipeline
def load_model():
    """Load the text-generation pipeline into the Streamlit session state.

    The model path is taken from the ``remote_model_path`` environment
    variable, falling back to ``model_path``. The loaded pipeline is stored
    in ``st.session_state.generator``. Errors are logged, never raised, so
    a failed load does not crash the UI.
    """
    keys = ['generator']
    # Only load when the pipeline is not already cached in the session.
    if any(st.session_state.get(key) is None for key in keys):
        with st.spinner('Loading the model might take a couple of seconds...'):
            try:
                # Prefer the remote path; fall back to the local one.
                # An empty string counts as unset (same truthiness as before).
                model_path = os.environ.get('remote_model_path') or os.getenv('model_path')
                if not model_path:
                    # Fail early with a clear message instead of passing
                    # model=None into pipeline() and logging an opaque error.
                    logging.error('No model path set: define remote_model_path or model_path')
                    return
                st.session_state.generator = pipeline(
                    task='text-generation', model=model_path, tokenizer=model_path
                )
                logging.info('Loaded models and tokenizer!')
            except Exception as e:
                # logging.exception records the traceback as well as the message.
                logging.exception(f'Error while loading models/tokenizer: {e}')
def generate_items(constructs, prefix='', **kwargs):
    """Generate item stems for one or more psychological constructs.

    Encodes the construct name(s) into the model's prompt format
    (``#construct@prefix``), runs the session's text-generation pipeline,
    and strips the prompt scaffolding from each generated text.

    Args:
        constructs: a construct name or a list of construct names.
        prefix: optional item-stem prefix appended after the separator.
        **kwargs: forwarded verbatim to the generation pipeline.

    Returns:
        List of generated item stems (one per pipeline output).
    """
    with st.spinner(f'Generating item(s) for `{constructs}`...'):
        construct_sep = '#'
        item_sep = '@'
        # Normalize a single construct into a one-element list.
        if not isinstance(constructs, list):
            constructs = [constructs]
        encoded_constructs = construct_sep + construct_sep.join(
            c.lower() for c in constructs
        )
        encoded_prompt = f'{encoded_constructs}{item_sep}{prefix}'
        outputs = st.session_state.generator(encoded_prompt, **kwargs)
        # The model echoes the prompt scaffolding; remove it from each output.
        truncate_str = f'{encoded_constructs}{item_sep}'
        return [
            output['generated_text'].replace(truncate_str, '')
            for output in outputs
        ]
@functools.lru_cache(maxsize=1)
def _load_gpt2():
    # Load the GPT-2 tokenizer and model once per process; every subsequent
    # call reuses the cached pair instead of re-reading the weights.
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    model = GPT2LMHeadModel.from_pretrained('gpt2')
    model.eval()
    return tokenizer, model

def get_next_tokens(prefix, breadth=5):
    """Return the `breadth` most likely next tokens for `prefix` under GPT-2.

    Args:
        prefix: text to condition the prediction on.
        breadth: number of top candidate tokens to return (default 5).

    Returns:
        List of `breadth` decoded token strings, most likely first.

    Note:
        The original version referenced ``torch`` without importing it
        (NameError at runtime) and reloaded the model on every call; both
        are fixed here.
    """
    tokenizer, model = _load_gpt2()
    # Encode the prefix into model inputs.
    inputs = tokenizer(prefix, return_tensors='pt')
    # Inference only — no gradients needed.
    with torch.no_grad():
        outputs = model(**inputs)
    # Only the last position's logits predict the next token.
    last_token_logits = outputs.logits[:, -1, :]
    # Indices of the top-`breadth` candidate next tokens.
    top_token_ids = torch.topk(last_token_logits, breadth, dim=1).indices.tolist()[0]
    # Decode each token id back to its surface string.
    return [tokenizer.decode([token_id]) for token_id in top_token_ids]