# babbleGPT / gradio_demo.py
"""
Gradio interface for testing the trained nanoGPT model
"""
import os
import gradio as gr
import torch
import tiktoken
from model import GPTConfig, GPT
# Configuration
MODEL_DIR = "out-srs" # Change this to your model directory
DEVICE = "cpu" # Hugging Face Spaces uses CPU
MAX_TOKENS = 100
TEMPERATURE = 0.8
TOP_K = 200

def load_model():
    """Load a trained checkpoint from the model directory"""
    print(f"Loading model from {MODEL_DIR}...")

    # Use a specific checkpoint that we know is complete
    ckpt_path = os.path.join(MODEL_DIR, 'ckpt_001000.pt')
    print(f"Loading checkpoint: {ckpt_path}")

    # Load checkpoint
    checkpoint = torch.load(ckpt_path, map_location=DEVICE)

    # Create model
    gptconf = GPTConfig(**checkpoint['model_args'])
    model = GPT(gptconf)

    # Load weights, stripping the '_orig_mod.' prefix that torch.compile adds
    # to state dict keys during training
    state_dict = checkpoint['model']
    unwanted_prefix = '_orig_mod.'
    for k, v in list(state_dict.items()):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
    model.load_state_dict(state_dict)
    model.eval()
    model.to(DEVICE)

    print(f"Model loaded successfully! (iteration {checkpoint['iter_num']})")
    return model
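
# Optional helper (a sketch, not used by the app above): if several checkpoints
# are kept in MODEL_DIR, the newest one could be picked automatically instead of
# hardcoding 'ckpt_001000.pt'. The 'ckpt_*.pt' naming pattern and the function
# name are assumptions for illustration.
def find_latest_checkpoint(model_dir=MODEL_DIR):
    """Return the path of the highest-numbered ckpt_*.pt file, or None (hypothetical helper)."""
    import glob
    candidates = sorted(glob.glob(os.path.join(model_dir, 'ckpt_*.pt')))
    return candidates[-1] if candidates else None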

def load_tokenizer():
    """Load the tokenizer"""
    # Check if there's a meta.pkl file for custom tokenizer
    meta_path = os.path.join('data', 'srs', 'meta.pkl')
    if os.path.exists(meta_path):
        import pickle
        print(f"Loading tokenizer from {meta_path}")
        with open(meta_path, 'rb') as f:
            meta = pickle.load(f)
        stoi, itos = meta['stoi'], meta['itos']
        encode = lambda s: [stoi[c] for c in s]
        decode = lambda l: ''.join([itos[i] for i in l])
    else:
        print("Using GPT-2 tokenizer")
        enc = tiktoken.get_encoding("gpt2")
        encode = lambda s: enc.encode(s, allowed_special={"<|endoftext|>"})
        decode = lambda l: enc.decode(l)
    return encode, decode
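
# Note: with the character-level tokenizer loaded from meta.pkl, encode() raises
# a KeyError for any character that is not in meta['stoi']. A more forgiving
# variant (a sketch, not what the app uses) could skip unknown characters:
#   encode = lambda s: [stoi[c] for c in s if c in stoi]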
# Load model and tokenizer once at startup
print("Initializing model...")
model = load_model()
encode, decode = load_tokenizer()
print("Ready!")

def generate_text(prompt, max_tokens, temperature, top_k):
    """Generate text from the model"""
    try:
        # Encode the prompt
        start_ids = encode(prompt)
        x = torch.tensor(start_ids, dtype=torch.long, device=DEVICE)[None, ...]

        # Generate
        with torch.no_grad():
            y = model.generate(x, max_tokens, temperature=temperature, top_k=top_k)
            generated = decode(y[0].tolist())

        return generated
    except Exception as e:
        return f"Error generating text: {str(e)}"
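
# Quick check outside the UI (a sketch using the defaults defined above):
#   print(generate_text("Hello, how are you?", MAX_TOKENS, TEMPERATURE, TOP_K))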

# Create Gradio interface
with gr.Blocks(title="SRS Conversational Model") as demo:
    gr.Markdown("# SRS Conversational Model")
    gr.Markdown("This model was trained on conversational data. Enter a prompt to see how it continues the conversation!")

    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(
                label="Prompt",
                placeholder="Enter your prompt here (e.g., 'Hello, how are you?')",
                lines=3
            )
            with gr.Row():
                max_tokens_slider = gr.Slider(
                    minimum=10, maximum=200, value=MAX_TOKENS, step=10,
                    label="Max tokens to generate"
                )
                temperature_slider = gr.Slider(
                    minimum=0.1, maximum=2.0, value=TEMPERATURE, step=0.1,
                    label="Temperature (creativity)"
                )
                top_k_slider = gr.Slider(
                    minimum=1, maximum=500, value=TOP_K, step=10,
                    label="Top-k (diversity)"
                )
            generate_btn = gr.Button("Generate", variant="primary")

        with gr.Column():
            output_text = gr.Textbox(
                label="Generated Text",
                lines=10,
                max_lines=15
            )

    # Examples
    gr.Examples(
        examples=[
            ["Hello, how are you?", 100, 0.8, 200],
            ["I think the wedding", 80, 0.7, 150],
            ["So anyway, let's talk about", 120, 0.9, 200],
            ["You know what's interesting", 100, 0.8, 200]
        ],
        inputs=[prompt_input, max_tokens_slider, temperature_slider, top_k_slider]
    )

    # Connect the generate button
    generate_btn.click(
        fn=generate_text,
        inputs=[prompt_input, max_tokens_slider, temperature_slider, top_k_slider],
        outputs=output_text
    )

if __name__ == "__main__":
    print("Starting Gradio interface...")
    print("Will be available at http://localhost:7860")
    print("Use share=True for public link")
    # Launch for Hugging Face Spaces
    demo.launch()
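    # For a temporary public URL when running locally, Gradio's launch() also
    # accepts share=True, e.g. demo.launch(share=True)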