|
import gradio as gr |
|
import torch |
|
import torch.nn.functional as F |
|
import tiktoken |
|
from huggingface_hub import hf_hub_download |
|
from transformer import GPT, GPTConfig |
|
|
|
|
|
# Pick the compute device once at import time: the GPU when PyTorch can see
# one, otherwise the CPU. Both the model and the input tensors are placed here.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
|
def load_model_from_hf(model_id="sudhakar272/transformer_model",
                       filename="transformer_model.pt"):
    """Download a checkpoint from the Hugging Face Hub and build a frozen GPT model.

    Args:
        model_id: Hub repository to download from.
        filename: Checkpoint file within that repository. It is expected to be
            a dict with a ``'config'`` entry (passed to ``GPT``) and a
            ``'model_state_dict'`` entry (loaded into the model).

    Returns:
        A ``GPT`` instance moved to ``device``, switched to eval mode, with
        ``requires_grad`` disabled on every parameter.
    """
    checkpoint_path = hf_hub_download(repo_id=model_id, filename=filename)

    # checkpoint['config'] may be a pickled Python object (presumably a
    # GPTConfig instance -- confirm against the training script). Since
    # PyTorch 2.6, torch.load defaults to weights_only=True, which refuses to
    # unpickle such objects, so full unpickling is requested explicitly.
    # SECURITY: weights_only=False executes arbitrary pickle code from the
    # file; only point this at repositories you trust.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

    model = GPT(checkpoint['config'])
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()

    # The model is used for inference only, so freeze all parameters.
    for param in model.parameters():
        param.requires_grad = False

    return model
|
|
|
# Load the model once at import time so every request reuses the same
# instance. nn.Module.eval() is train(False) and returns the module itself,
# so this keeps the model in inference mode.
model = load_model_from_hf().eval()
|
|
|
def generate_text(prompt, max_length=100, num_samples=1, temperature=0.8, top_k=50):
    """Continue ``prompt`` with text sampled autoregressively from ``model``.

    Each step scales the last-position logits by ``temperature``, keeps the
    ``top_k`` most likely tokens, and samples one of them, independently for
    each of ``num_samples`` sequences.

    Args:
        prompt: Text to continue. An empty prompt is seeded with the GPT-2
            end-of-text token, which is removed again from the output.
        max_length: Maximum number of new tokens to generate per sample.
            Generation also stops once a sequence reaches 1024 tokens.
        num_samples: Number of independent continuations to produce.
        temperature: Softmax temperature. Values below 1 sharpen the
            distribution and values above 1 flatten it. A value <= 0 selects
            the most likely token greedily.
        top_k: Sample only among this many most likely tokens. ``None`` or a
            value <= 0 disables the top-k filter. The value is clamped to the
            vocabulary size.

    Returns:
        The decoded samples (each including the prompt), joined by
        ``'\\n\\n---\\n\\n'``.
    """
    # Gradio sliders can deliver floats; range() and Tensor.repeat() need ints.
    max_length = int(max_length)
    num_samples = int(num_samples)

    enc = tiktoken.get_encoding('gpt2')
    prompt_tokens = enc.encode(prompt)

    # An empty prompt would hand the model a zero-length sequence. Seed it
    # with the end-of-text token instead and strip that token when decoding.
    # NOTE(review): assumes the checkpoint was trained on the GPT-2 vocabulary,
    # as the tokenizer choice above implies.
    n_seed = 0
    if not prompt_tokens:
        prompt_tokens = [enc.eot_token]
        n_seed = 1

    tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=device)
    tokens = tokens.unsqueeze(0).repeat(num_samples, 1)

    with torch.no_grad():
        for _ in range(max_length):
            # Hard cap on sequence length (presumably the model's context
            # window -- TODO confirm against the checkpoint config).
            if tokens.size(1) >= 1024:
                break

            # The first element of the model output is indexed as
            # (batch, seq, vocab); keep only the last position's logits.
            logits = model(tokens)[0][:, -1, :]

            if temperature <= 0:
                next_token = torch.argmax(logits, dim=-1, keepdim=True)
            else:
                probs = F.softmax(logits / temperature, dim=-1)
                vocab_size = probs.size(-1)
                k = vocab_size if top_k is None or top_k <= 0 else min(top_k, vocab_size)
                topk_probs, topk_indices = torch.topk(probs, k, dim=-1)
                # multinomial renormalizes the truncated distribution itself.
                ix = torch.multinomial(topk_probs, 1)
                next_token = torch.gather(topk_indices, -1, ix)

            tokens = torch.cat((tokens, next_token), dim=1)

    generated_texts = [enc.decode(row) for row in tokens[:, n_seed:].tolist()]
    return '\n\n---\n\n'.join(generated_texts)
|
|
|
|
|
# Web UI: the three inputs map positionally onto generate_text's
# (prompt, max_length, num_samples) parameters; temperature and top_k keep
# their defaults. Each example row supplies values for those same three inputs.
iface = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(label="Prompt", value="Good night, good night! Parting is such sweet sorrow"),
        gr.Slider(minimum=10, maximum=200, value=100, step=1, label="Max Length"),
        gr.Slider(minimum=1, maximum=5, value=1, step=1, label="Number of Samples"),
    ],
    outputs=gr.Textbox(label="Generated Text"),
    title="Shakespeare Text Generator",
    description="Enter a prompt and the model will continue it in the style of Shakespeare.",
    examples=[
        ["To be, or not to be: that is the question.", 100, 1],
        ["Love all, trust a few, do wrong to none.", 60, 2],
        ["It's not enough to speak, but to speak true", 50, 3],
        ["There are more things in heaven and earth, Horatio, than are dreamt of in your philosophy.", 100, 1],
        ["If you can look into the seeds of time, and say which grain will grow and which will not, speak then to me", 100, 1],
        ["Love sought is good, but given unsought is better.", 100, 1],
    ],
)


# Start the Gradio server only when run as a script, not when imported.
if __name__ == "__main__":
    iface.launch()