"""Gradio app that continues a text prompt with a GPT model from the Hugging Face Hub."""

import gradio as gr
import torch
import torch.nn.functional as F
import tiktoken
from huggingface_hub import hf_hub_download
from transformer import GPT, GPTConfig  # Import your model class

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Sequences are never grown beyond this many tokens (GPT context length).
MAX_CONTEXT_LENGTH = 1024
# Sampling is restricted to this many highest-probability tokens per step.
TOP_K = 50


def load_model_from_hf() -> GPT:
    """Download the checkpoint from the Hugging Face Hub and build the model.

    Returns:
        The GPT model with the checkpoint weights loaded, moved to ``device``,
        switched to evaluation mode and with gradients disabled on every
        parameter.
    """
    # Replace with your Hugging Face model ID (username/model-name)
    model_id = "sudhakar272/transformer_model"
    checkpoint_path = hf_hub_download(repo_id=model_id, filename="transformer_model.pt")

    # NOTE(security): torch.load unpickles the file (the checkpoint carries a
    # 'config' entry next to the weights), and unpickling can run arbitrary
    # code. Only point model_id at a repository you trust.
    checkpoint = torch.load(checkpoint_path, map_location=device)

    model = GPT(checkpoint['config'])
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()  # Set to evaluation mode

    # Inference only: no parameter ever needs a gradient.
    for param in model.parameters():
        param.requires_grad = False

    return model


model = load_model_from_hf()
# Force model to stay in eval mode
model.train(False)


def generate_text(
    prompt: str,
    max_length: int = 100,
    num_samples: int = 1,
    temperature: float = 0.8,
) -> str:
    """Continue ``prompt`` with tokens sampled from the model.

    Args:
        prompt: Text to continue. Must encode to at least one token.
        max_length: Maximum number of tokens to append to the prompt. Fewer
            are appended when the sequence reaches ``MAX_CONTEXT_LENGTH``.
        num_samples: Number of independent continuations to produce.
        temperature: Divisor applied to the logits before the softmax; lower
            values make sampling more conservative. A value of zero or below
            selects the most likely token at every step instead of sampling.

    Returns:
        The decoded sequences (prompt included), separated by a line
        containing ``---``. An empty string when ``num_samples`` is below 1.

    Raises:
        gr.Error: If the prompt is empty.
    """
    # UI widgets may hand these over as floats; range() and Tensor.repeat()
    # both require integers.
    max_length = int(max_length)
    num_samples = int(num_samples)
    if num_samples < 1:
        return ''

    enc = tiktoken.get_encoding('gpt2')
    # disallowed_special=() makes text such as "<|endoftext|>" typed by a user
    # encode as ordinary characters instead of raising.
    prompt_ids = enc.encode(prompt or '', disallowed_special=())
    if not prompt_ids:
        # With no tokens there is no last position to predict from.
        raise gr.Error("Please enter a prompt to continue.")

    tokens = torch.tensor(prompt_ids, dtype=torch.long, device=device)
    tokens = tokens.unsqueeze(0).repeat(num_samples, 1)

    # Never grow past the context window; a prompt that already fills it is
    # returned unchanged.
    steps = max(0, min(max_length, MAX_CONTEXT_LENGTH - tokens.size(1)))

    with torch.no_grad():
        for _ in range(steps):
            # The model's first output holds the logits; keep the last position only.
            logits = model(tokens)[0][:, -1, :]

            if temperature > 0:
                probs = F.softmax(logits / temperature, dim=-1)
                # Top-k sampling (k is clamped in case the vocabulary is small).
                k = min(TOP_K, probs.size(-1))
                topk_probs, topk_indices = torch.topk(probs, k, dim=-1)
                # multinomial treats the truncated probabilities as weights,
                # so no explicit renormalisation is needed.
                choice = torch.multinomial(topk_probs, 1)
                next_token = torch.gather(topk_indices, -1, choice)
            else:
                next_token = torch.argmax(logits, dim=-1, keepdim=True)

            tokens = torch.cat((tokens, next_token), dim=1)

    # No special-token check: generate for the requested length or until the
    # context limit, then decode every sample.
    return '\n\n---\n\n'.join(enc.decode(row.tolist()) for row in tokens)


# Create Gradio interface
# Each example row supplies one value per input, in the order of `inputs`.
# NOTE(review): "Shakesphere" below is presumably meant to read "Shakespeare";
# left unchanged here because it is user-facing text.
iface = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(label="Prompt", value="Good night, good night! Parting is such sweet sorrow"),
        gr.Slider(minimum=10, maximum=200, value=100, step=1, label="Max Length"),
        gr.Slider(minimum=1, maximum=5, value=1, step=1, label="Number of Samples"),
    ],
    outputs=gr.Textbox(label="Generated Text"),
    title="Shakesphere Text Generator",
    description="Enter text for Shakesphere way of text and continue the same",
    examples=[
        ["To be, or not to be: that is the question.", 100, 1],
        ["Love all, trust a few, do wrong to none.", 60, 2],
        ["It's not enough to speak, but to speak true", 50, 3],
        ["There are more things in heaven and earth, Horatio, than are dreamt of in your philosophy.", 100, 1],
        ["If you can look into the seeds of time, and say which grain will grow and which will not, speak then to me", 100, 1],
        ["Love sought is good, but given unsought is better.", 100, 1],
    ]
)

if __name__ == "__main__":
    iface.launch()