sudhakar272's picture
Update app.py
ded2cdd verified
import gradio as gr
import torch
import torch.nn.functional as F
import tiktoken
from huggingface_hub import hf_hub_download
from transformer import GPT, GPTConfig # Import your model class
# Compute device shared by checkpoint loading and text generation:
# the GPU when PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
def load_model_from_hf(model_id="sudhakar272/transformer_model", filename="transformer_model.pt"):
    """Download a checkpoint from the Hugging Face Hub and build a frozen GPT for inference.

    Args:
        model_id: Hub repository id (``username/model-name``) holding the checkpoint.
        filename: Name of the checkpoint file inside that repository.

    Returns:
        A ``GPT`` instance moved to ``device``, switched to eval mode, with
        ``requires_grad`` disabled on every parameter.
    """
    checkpoint_path = hf_hub_download(repo_id=model_id, filename=filename)
    # The checkpoint's 'config' entry is handed straight to GPT(), so it is
    # presumably a pickled config object rather than plain tensors. torch>=2.6
    # defaults to weights_only=True and would refuse to unpickle it, so opt out
    # explicitly. This is only acceptable because the repository is our own;
    # never load untrusted checkpoints with weights_only=False.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    model = GPT(checkpoint['config'])
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()
    # Inference only: freeze every parameter so no autograd state is tracked.
    model.requires_grad_(False)
    return model
# Load the weights once at import time so every request reuses the same model.
model = load_model_from_hf()
# Re-assert inference mode (eval() is train(False)); this affects layers such
# as Dropout whose behaviour differs between training and evaluation.
model.eval()
def _sample_next_token(logits, temperature, top_k):
    """Pick one next-token id per row from last-position logits.

    Args:
        logits: Float tensor of shape (batch, vocab) for the final position.
        temperature: The logits are divided by this before softmax; values below 1
            sharpen the distribution, values above 1 flatten it. ``None`` leaves the
            logits unscaled; a value ``<= 0`` selects the arg-max token (greedy).
        top_k: Sample only among the ``top_k`` most likely tokens. ``None``, a value
            ``<= 0`` or a value larger than the vocabulary uses the full vocabulary.

    Returns:
        LongTensor of shape (batch, 1) holding the chosen token ids.
    """
    if temperature is not None and temperature <= 0:
        return torch.argmax(logits, dim=-1, keepdim=True)
    if temperature is not None:
        logits = logits / temperature
    probs = F.softmax(logits, dim=-1)
    vocab_size = probs.size(-1)
    k = vocab_size if top_k is None or top_k <= 0 else min(int(top_k), vocab_size)
    topk_probs, topk_indices = torch.topk(probs, k, dim=-1)
    # multinomial accepts unnormalised weights, so the truncated probabilities
    # need no explicit re-scaling.
    choice = torch.multinomial(topk_probs, 1)
    return torch.gather(topk_indices, -1, choice)


def generate_text(prompt, max_length=100, num_samples=1, temperature=0.8, top_k=50):
    """Continue ``prompt`` with the loaded GPT model and return the samples as text.

    Args:
        prompt: Text to condition on. An empty prompt is seeded with GPT-2's
            end-of-text token, which is stripped again from the returned text.
        max_length: Maximum number of new tokens per sample. Generation also stops
            early once a sequence fills the model's context window.
        num_samples: Number of independent continuations to draw.
        temperature: Softmax temperature; ``<= 0`` means greedy decoding.
        top_k: Number of most likely tokens sampled from at each step.

    Returns:
        The decoded samples (prompt plus continuation), separated by a line
        containing ``---`` with a blank line on each side.
    """
    # Gradio sliders may hand back floats; range() and Tensor.repeat() need ints.
    max_length = int(max_length)
    num_samples = int(num_samples)

    enc = tiktoken.get_encoding('gpt2')
    # By default tiktoken raises ValueError if the text contains special-token
    # strings such as "<|endoftext|>"; treat user input as ordinary text instead.
    prompt_ids = enc.encode(prompt or "", disallowed_special=())
    # The model needs at least one token to condition on.
    n_seed = 0
    if not prompt_ids:
        prompt_ids = [enc.eot_token]
        n_seed = 1

    tokens = torch.tensor(prompt_ids, dtype=torch.long, device=device)
    tokens = tokens.unsqueeze(0).repeat(num_samples, 1)

    # Respect the model's own context size when it exposes one (assumes a
    # nanoGPT-style ``model.config.block_size`` -- TODO confirm); otherwise fall
    # back to GPT-2's 1024-token context.
    block_size = getattr(getattr(model, 'config', None), 'block_size', 1024)

    with torch.no_grad():
        for _ in range(max_length):
            if tokens.size(1) >= block_size:
                break
            # Element 0 of the model output is used as the logits; keep only the
            # last position's distribution.
            logits = model(tokens)[0][:, -1, :]
            next_token = _sample_next_token(logits, temperature, top_k)
            tokens = torch.cat((tokens, next_token), dim=1)

    samples = [enc.decode(row[n_seed:].tolist()) for row in tokens]
    return '\n\n---\n\n'.join(samples)
# Gradio UI: the three inputs map positionally onto generate_text's
# (prompt, max_length, num_samples); its remaining parameters keep their defaults.
iface = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(label="Prompt", value="Good night, good night! Parting is such sweet sorrow"),
        gr.Slider(minimum=10, maximum=200, value=100, step=1, label="Max Length"),
        gr.Slider(minimum=1, maximum=5, value=1, step=1, label="Number of Samples"),
    ],
    outputs=gr.Textbox(label="Generated Text"),
    title="Shakespeare Text Generator",
    description="Enter a prompt and the model will continue it in the style of Shakespeare.",
    # Each row supplies [prompt, max_length, num_samples], matching `inputs`.
    examples=[
        ["To be, or not to be: that is the question.", 100, 1],
        ["Love all, trust a few, do wrong to none.", 60, 2],
        ["It's not enough to speak, but to speak true", 50, 3],
        ["There are more things in heaven and earth, Horatio, than are dreamt of in your philosophy.", 100, 1],
        ["If you can look into the seeds of time, and say which grain will grow and which will not, speak then to me", 100, 1],
        ["Love sought is good, but given unsought is better.", 100, 1],
    ],
)
def _main():
    """Start the Gradio server for the text-generation demo."""
    iface.launch()


if __name__ == "__main__":
    _main()