import os
import logging

import streamlit as st
import torch
from transformers import pipeline
from transformers import GPT2Tokenizer, GPT2LMHeadModel


def load_model():
    keys = ['generator']
    if any(st.session_state.get(key) is None for key in keys):
        with st.spinner('Loading the model might take a couple of seconds...'):
            try:
                # Prefer the remote model path when one is configured; otherwise fall back to the local path.
                if os.environ.get('remote_model_path'):
                    model_path = os.environ.get('remote_model_path')
                else:
                    model_path = os.getenv('model_path')

                st.session_state.generator = pipeline(task='text-generation', model=model_path, tokenizer=model_path)
                logging.info('Loaded model and tokenizer!')
            except Exception as e:
                logging.error(f'Error while loading model/tokenizer: {e}')


def generate_items(constructs, prefix='', **kwargs):
    with st.spinner(f'Generating item(s) for `{constructs}`...'):
        # Separator characters used to build the encoded prompt.
        construct_sep = '#'
        item_sep = '@'

        constructs = constructs if isinstance(constructs, list) else [constructs]
        encoded_constructs = construct_sep + construct_sep.join([x.lower() for x in constructs])
        encoded_prompt = f'{encoded_constructs}{item_sep}{prefix}'

        outputs = st.session_state.generator(encoded_prompt, **kwargs)
        truncate_str = f'{encoded_constructs}{item_sep}'

        # Strip the encoded prompt from each generation so only the item stem remains.
        item_stems = []
        for output in outputs:
            item_stems.append(output['generated_text'].replace(truncate_str, ''))

        return item_stems


def get_next_tokens(prefix, breadth=5):
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    model = GPT2LMHeadModel.from_pretrained('gpt2')
    inputs = tokenizer(prefix, return_tensors='pt')

    # Run a forward pass without tracking gradients.
    with torch.no_grad():
        outputs = model(**inputs)

    logits = outputs.logits
    # Logits for the token that would follow the given prefix.
    last_token_logits = logits[:, -1, :]

    # Take the `breadth` most likely token ids and decode them to strings.
    top_tokens = torch.topk(last_token_logits, breadth, dim=1).indices.tolist()[0]
    next_tokens = [tokenizer.decode([token_id]) for token_id in top_tokens]

    return next_tokens