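"""Gradio demo: Shakespeare-style text generation with a custom GPT model.

Loads a trained checkpoint from the Hugging Face Hub, generates continuations
with top-k sampling, and serves them through a simple Gradio interface.
"""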
import gradio as gr
import torch
import torch.nn.functional as F
import tiktoken
from huggingface_hub import hf_hub_download
from transformer import GPT, GPTConfig  # local model definition; GPTConfig must be importable so the pickled checkpoint config can load

# Load the model from Hugging Face Hub
device = 'cuda' if torch.cuda.is_available() else 'cpu'
def load_model_from_hf():
    # Replace with your Hugging Face model ID (username/model-name)
    model_id = "sudhakar272/transformer_model"
    checkpoint_path = hf_hub_download(repo_id=model_id, filename="transformer_model.pt")
    
    # weights_only=False: the checkpoint stores a pickled config object, not just tensors
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    config = checkpoint['config']
    model = GPT(config)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()  # Set to evaluation mode
    
    # Disable gradient computation
    for param in model.parameters():
        param.requires_grad = False
        
    return model

model = load_model_from_hf()


def generate_text(prompt, max_length=100, num_samples=1, temperature=0.8):
    enc = tiktoken.get_encoding('gpt2')
    tokens = enc.encode(prompt)
    tokens = torch.tensor(tokens, dtype=torch.long)
    tokens = tokens.unsqueeze(0).repeat(num_samples, 1)
    tokens = tokens.to(device)
    
    with torch.no_grad():
        for _ in range(max_length):
            if tokens.size(1) >= 1024:  # GPT context length
                break
                
            logits = model(tokens)[0]  # the model's forward returns (logits, loss)
            logits = logits[:, -1, :] / temperature  # scale the last position's logits by temperature
            probs = F.softmax(logits, dim=-1)
            
            # Top-k sampling: keep only the 50 most likely tokens, then draw one of them
            # (torch.multinomial renormalizes the truncated probabilities)
            topk_probs, topk_indices = torch.topk(probs, 50, dim=-1)
            ix = torch.multinomial(topk_probs, 1)
            next_token = torch.gather(topk_indices, -1, ix)
            
            tokens = torch.cat((tokens, next_token), dim=1)
            
            # No stop-token check: generate until max_length or the context limit is reached
    
    generated_texts = []
    for i in range(num_samples):
        text = enc.decode(tokens[i].tolist())
        generated_texts.append(text)
    
    return '\n\n---\n\n'.join(generated_texts)
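
# Quick sanity check (hypothetical usage; requires the checkpoint to download):
#   print(generate_text("To be, or not to be", max_length=40))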

# Create Gradio interface
iface = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(label="Prompt", value="Good night, good night! Parting is such sweet sorrow"),
        gr.Slider(minimum=10, maximum=200, value=100, step=1, label="Max Length"),
        gr.Slider(minimum=1, maximum=5, value=1, step=1, label="Number of Samples"),
    ],
    outputs=gr.Textbox(label="Generated Text"),
    title="Shakesphere Text Generator",
    description="Enter text for Shakesphere way of text and continue the same",
    examples=[
        ["To be, or not to be: that is the question.", 100, 1],
        ["Love all, trust a few, do wrong to none.", 60, 2],
        ["It's not enough to speak, but to speak true", 50, 3],
        ["There are more things in heaven and earth, Horatio, than are dreamt of in your philosophy.", 100, 1],
        ["If you can look into the seeds of time, and say which grain will grow and which will not, speak then to me", 100, 1],
        ["Love sought is good, but given unsought is better.", 100, 1],
    ]
)

if __name__ == "__main__":
    iface.launch()
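    # Note: pass share=True to iface.launch() to get a temporary public URL,
    # e.g. when running locally rather than on Hugging Face Spaces.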