"""
Gradio interface for testing the trained nanoGPT model
"""
import os
import gradio as gr
import torch
import tiktoken
from model import GPTConfig, GPT
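# NOTE: "model" here is nanoGPT's model.py, which must sit next to this script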

# Configuration
MODEL_DIR = "out-srs"  # Change this to your model directory
DEVICE = "cpu"  # Hugging Face Spaces uses CPU
MAX_TOKENS = 100
TEMPERATURE = 0.8
TOP_K = 200

def load_model():
    """Load the latest checkpoint from the model directory"""
    print(f"Loading model from {MODEL_DIR}...")
    
    # Use a specific checkpoint that we know is complete
    ckpt_path = os.path.join(MODEL_DIR, 'ckpt_001000.pt')
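    # To pick the newest checkpoint automatically instead, something like this
    # should work (assumes the zero-padded ckpt_NNNNNN.pt naming used above,
    # which sorts lexicographically):
    #   import glob
    #   ckpt_path = sorted(glob.glob(os.path.join(MODEL_DIR, 'ckpt_*.pt')))[-1]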
    
    print(f"Loading checkpoint: {ckpt_path}")
    
    # Load checkpoint
    checkpoint = torch.load(ckpt_path, map_location=DEVICE)
    
    # Create model
    gptconf = GPTConfig(**checkpoint['model_args'])
    model = GPT(gptconf)
    
    # Load weights
    state_dict = checkpoint['model']
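    # checkpoints saved from a torch.compile()'d model prefix every key with
    # '_orig_mod.', so strip it before loading into a plain GPT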
    unwanted_prefix = '_orig_mod.'
    for k in list(state_dict):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
    
    model.load_state_dict(state_dict)
    model.eval()
    model.to("cpu")
    
    print(f"Model loaded successfully! (iteration {checkpoint['iter_num']})")
    return model

def load_tokenizer():
    """Load the tokenizer"""
    # Check if there's a meta.pkl file for custom tokenizer
    meta_path = os.path.join('data', 'srs', 'meta.pkl')
    if os.path.exists(meta_path):
        import pickle
        print(f"Loading tokenizer from {meta_path}")
        with open(meta_path, 'rb') as f:
            meta = pickle.load(f)
        stoi, itos = meta['stoi'], meta['itos']
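        # character-level codec: stoi/itos map single characters to ids
        # (meta.pkl is typically written by the dataset's prepare.py)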
        encode = lambda s: [stoi[c] for c in s]
        decode = lambda l: ''.join([itos[i] for i in l])
    else:
        print("Using GPT-2 tokenizer")
        enc = tiktoken.get_encoding("gpt2")
        encode = lambda s: enc.encode(s, allowed_special={"<|endoftext|>"})
        decode = lambda l: enc.decode(l)
    
    return encode, decode

# Load model and tokenizer once at startup
print("Initializing model...")
model = load_model()
encode, decode = load_tokenizer()
print("Ready!")

def generate_text(prompt, max_tokens, temperature, top_k):
    """Generate text from the model"""
    try:
        # Encode the prompt
        start_ids = encode(prompt)
        x = torch.tensor(start_ids, dtype=torch.long, device=DEVICE)[None, ...]
        
        # Generate
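        # (model.generate is nanoGPT's autoregressive sampler: it crops the
        # context to block_size and samples token-by-token with temperature
        # and top-k filtering)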
        with torch.no_grad():
            # sliders may deliver floats, so cast the integer-valued params
            y = model.generate(x, int(max_tokens), temperature=temperature, top_k=int(top_k))
            generated = decode(y[0].tolist())
        
        return generated
    
    except Exception as e:
        return f"Error generating text: {str(e)}"

# Create Gradio interface
with gr.Blocks(title="SRS Conversational Model") as demo:
    gr.Markdown("# SRS Conversational Model")
    gr.Markdown("This model was trained on conversational data. Enter a prompt to see how it continues the conversation!")
    
    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(
                label="Prompt", 
                placeholder="Enter your prompt here (e.g., 'Hello, how are you?')",
                lines=3
            )
            
            with gr.Row():
                max_tokens_slider = gr.Slider(
                    minimum=10, maximum=200, value=MAX_TOKENS, step=10,
                    label="Max tokens to generate"
                )
                temperature_slider = gr.Slider(
                    minimum=0.1, maximum=2.0, value=TEMPERATURE, step=0.1,
                    label="Temperature (creativity)"
                )
                top_k_slider = gr.Slider(
                    minimum=10, maximum=500, value=TOP_K, step=10,
                    label="Top-k (diversity)"
                )
            
            generate_btn = gr.Button("Generate", variant="primary")
        
        with gr.Column():
            output_text = gr.Textbox(
                label="Generated Text",
                lines=10,
                max_lines=15
            )
    
    # Examples
    gr.Examples(
        examples=[
            ["Hello, how are you?", 100, 0.8, 200],
            ["I think the wedding", 80, 0.7, 150],
            ["So anyway, let's talk about", 120, 0.9, 200],
            ["You know what's interesting", 100, 0.8, 200]
        ],
        inputs=[prompt_input, max_tokens_slider, temperature_slider, top_k_slider]
    )
    
    # Connect the generate button
    generate_btn.click(
        fn=generate_text,
        inputs=[prompt_input, max_tokens_slider, temperature_slider, top_k_slider],
        outputs=output_text
    )

if __name__ == "__main__":
    print("Starting Gradio interface...")
    print("Will be available at http://localhost:7860")
    print("Use share=True for public link")
    
    # Launch for Hugging Face Spaces
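    # (running locally, demo.launch(share=True) would also create a
    # temporary public gradio.live URL)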
    demo.launch()