"""
Gradio interface for testing the trained nanoGPT model
"""
import os
import gradio as gr
import torch
import tiktoken
from model import GPTConfig, GPT

# Configuration
MODEL_DIR = "out-srs"  # Change this to your model directory
DEVICE = "cpu"  # Hugging Face Spaces runs this app on CPU
MAX_TOKENS = 100
TEMPERATURE = 0.8
TOP_K = 200

def load_model():
    """Load a known-good checkpoint from the model directory"""
    print(f"Loading model from {MODEL_DIR}...")

    # Use a specific checkpoint that we know is complete
    ckpt_path = os.path.join(MODEL_DIR, 'ckpt_001000.pt')
    print(f"Loading checkpoint: {ckpt_path}")

    # Load checkpoint
    checkpoint = torch.load(ckpt_path, map_location=DEVICE)

    # Create model
    gptconf = GPTConfig(**checkpoint['model_args'])
    model = GPT(gptconf)

    # Load weights, stripping the torch.compile prefix if present
    state_dict = checkpoint['model']
    unwanted_prefix = '_orig_mod.'
    for k in list(state_dict.keys()):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
    model.load_state_dict(state_dict)
    model.eval()
    model.to(DEVICE)

    print(f"Model loaded successfully! (iteration {checkpoint['iter_num']})")
    return model
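
# Note (sketch, not part of the original app): to pick the newest checkpoint
# automatically instead of pinning ckpt_001000.pt, something like the
# following would work, assuming checkpoints are named ckpt_NNNNNN.pt with
# zero-padded iteration numbers so lexicographic order matches training order:
#
#     import glob
#     ckpt_path = sorted(glob.glob(os.path.join(MODEL_DIR, 'ckpt_*.pt')))[-1]
#
# Also, PyTorch 2.6+ changed the torch.load default to weights_only=True; if
# loading a trusted checkpoint fails with an unpickling error,
# torch.load(ckpt_path, map_location=DEVICE, weights_only=False) is the
# usual workaround.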

def load_tokenizer():
    """Load the tokenizer"""
    # Check if there's a meta.pkl file for a custom tokenizer
    meta_path = os.path.join('data', 'srs', 'meta.pkl')
    if os.path.exists(meta_path):
        import pickle
        print(f"Loading tokenizer from {meta_path}")
        with open(meta_path, 'rb') as f:
            meta = pickle.load(f)
        stoi, itos = meta['stoi'], meta['itos']
        encode = lambda s: [stoi[c] for c in s]
        decode = lambda l: ''.join([itos[i] for i in l])
    else:
        print("Using GPT-2 tokenizer")
        enc = tiktoken.get_encoding("gpt2")
        encode = lambda s: enc.encode(s, allowed_special={"<|endoftext|>"})
        decode = lambda l: enc.decode(l)
    return encode, decode
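
# Caveat (sketch): with the character-level meta.pkl path above, encode()
# raises a KeyError on any character absent from the training vocabulary.
# A tolerant variant could silently drop unknown characters:
#
#     encode = lambda s: [stoi[c] for c in s if c in stoi]
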
# Load model and tokenizer once at startup
print("Initializing model...")
model = load_model()
encode, decode = load_tokenizer()
print("Ready!")

def generate_text(prompt, max_tokens, temperature, top_k):
    """Generate text from the model"""
    try:
        # Gradio sliders can deliver floats; token counts must be ints
        max_tokens, top_k = int(max_tokens), int(top_k)

        # Encode the prompt
        start_ids = encode(prompt)
        x = torch.tensor(start_ids, dtype=torch.long, device=DEVICE)[None, ...]

        # Generate
        with torch.no_grad():
            y = model.generate(x, max_tokens, temperature=temperature, top_k=top_k)

        generated = decode(y[0].tolist())
        return generated
    except Exception as e:
        return f"Error generating text: {str(e)}"

# Create Gradio interface
with gr.Blocks(title="SRS Conversational Model") as demo:
    gr.Markdown("# SRS Conversational Model")
    gr.Markdown(
        "This model was trained on conversational data. "
        "Enter a prompt to see how it continues the conversation!"
    )

    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(
                label="Prompt",
                placeholder="Enter your prompt here (e.g., 'Hello, how are you?')",
                lines=3
            )
            with gr.Row():
                max_tokens_slider = gr.Slider(
                    minimum=10, maximum=200, value=MAX_TOKENS, step=10,
                    label="Max tokens to generate"
                )
                temperature_slider = gr.Slider(
                    minimum=0.1, maximum=2.0, value=TEMPERATURE, step=0.1,
                    label="Temperature (creativity)"
                )
                top_k_slider = gr.Slider(
                    minimum=1, maximum=500, value=TOP_K, step=10,
                    label="Top-k (diversity)"
                )
            generate_btn = gr.Button("Generate", variant="primary")
        with gr.Column():
            output_text = gr.Textbox(
                label="Generated Text",
                lines=10,
                max_lines=15
            )

    # Examples
    gr.Examples(
        examples=[
            ["Hello, how are you?", 100, 0.8, 200],
            ["I think the wedding", 80, 0.7, 150],
            ["So anyway, let's talk about", 120, 0.9, 200],
            ["You know what's interesting", 100, 0.8, 200]
        ],
        inputs=[prompt_input, max_tokens_slider, temperature_slider, top_k_slider]
    )

    # Connect the generate button
    generate_btn.click(
        fn=generate_text,
        inputs=[prompt_input, max_tokens_slider, temperature_slider, top_k_slider],
        outputs=output_text
    )

if __name__ == "__main__":
    print("Starting Gradio interface...")
    print("Will be available at http://localhost:7860")
    print("Use share=True for public link")
    # Launch for Hugging Face Spaces
    demo.launch()
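
# Note (sketch): when running this script locally, a temporary public URL can
# be requested with demo.launch(share=True); on Hugging Face Spaces the plain
# launch() above is all that's needed.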