""" Gradio interface for testing the trained nanoGPT model """ import os import gradio as gr import torch import tiktoken from model import GPTConfig, GPT # Configuration MODEL_DIR = "out-srs" # Change this to your model directory DEVICE = "cpu" # Hugging Face Spaces uses CPU MAX_TOKENS = 100 TEMPERATURE = 0.8 TOP_K = 200 def load_model(): """Load the latest checkpoint from the model directory""" print(f"Loading model from {MODEL_DIR}...") # Use a specific checkpoint that we know is complete ckpt_path = os.path.join(MODEL_DIR, 'ckpt_001000.pt') print(f"Loading checkpoint: {ckpt_path}") # Load checkpoint checkpoint = torch.load(ckpt_path, map_location="cpu") # Create model gptconf = GPTConfig(**checkpoint['model_args']) model = GPT(gptconf) # Load weights state_dict = checkpoint['model'] unwanted_prefix = '_orig_mod.' for k, v in list(state_dict.items()): if k.startswith(unwanted_prefix): state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k) model.load_state_dict(state_dict) model.eval() model.to("cpu") print(f"Model loaded successfully! (iteration {checkpoint['iter_num']})") return model def load_tokenizer(): """Load the tokenizer""" # Check if there's a meta.pkl file for custom tokenizer meta_path = os.path.join('data', 'srs', 'meta.pkl') if os.path.exists(meta_path): import pickle print(f"Loading tokenizer from {meta_path}") with open(meta_path, 'rb') as f: meta = pickle.load(f) stoi, itos = meta['stoi'], meta['itos'] encode = lambda s: [stoi[c] for c in s] decode = lambda l: ''.join([itos[i] for i in l]) else: print("Using GPT-2 tokenizer") enc = tiktoken.get_encoding("gpt2") encode = lambda s: enc.encode(s, allowed_special={"<|endoftext|>"}) decode = lambda l: enc.decode(l) return encode, decode # Load model and tokenizer once at startup print("Initializing model...") model = load_model() encode, decode = load_tokenizer() print("Ready!") def generate_text(prompt, max_tokens, temperature, top_k): """Generate text from the model""" try: # Encode the prompt start_ids = encode(prompt) x = torch.tensor(start_ids, dtype=torch.long, device="cpu")[None, ...] # Generate with torch.no_grad(): y = model.generate(x, max_tokens, temperature=temperature, top_k=top_k) generated = decode(y[0].tolist()) return generated except Exception as e: return f"Error generating text: {str(e)}" # Create Gradio interface with gr.Blocks(title="SRS Conversational Model") as demo: gr.Markdown("# SRS Conversational Model") gr.Markdown("This model was trained on conversational data. 
# Create Gradio interface
with gr.Blocks(title="SRS Conversational Model") as demo:
    gr.Markdown("# SRS Conversational Model")
    gr.Markdown(
        "This model was trained on conversational data. "
        "Enter a prompt to see how it continues the conversation!"
    )

    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(
                label="Prompt",
                placeholder="Enter your prompt here (e.g., 'Hello, how are you?')",
                lines=3
            )

            with gr.Row():
                max_tokens_slider = gr.Slider(
                    minimum=10,
                    maximum=200,
                    value=MAX_TOKENS,
                    step=10,
                    label="Max tokens to generate"
                )
                temperature_slider = gr.Slider(
                    minimum=0.1,
                    maximum=2.0,
                    value=TEMPERATURE,
                    step=0.1,
                    label="Temperature (creativity)"
                )
                top_k_slider = gr.Slider(
                    minimum=1,
                    maximum=500,
                    value=TOP_K,
                    step=10,
                    label="Top-k (diversity)"
                )

            generate_btn = gr.Button("Generate", variant="primary")

        with gr.Column():
            output_text = gr.Textbox(
                label="Generated Text",
                lines=10,
                max_lines=15
            )

    # Examples
    gr.Examples(
        examples=[
            ["Hello, how are you?", 100, 0.8, 200],
            ["I think the wedding", 80, 0.7, 150],
            ["So anyway, let's talk about", 120, 0.9, 200],
            ["You know what's interesting", 100, 0.8, 200]
        ],
        inputs=[prompt_input, max_tokens_slider, temperature_slider, top_k_slider]
    )

    # Connect the generate button to the inference function
    generate_btn.click(
        fn=generate_text,
        inputs=[prompt_input, max_tokens_slider, temperature_slider, top_k_slider],
        outputs=output_text
    )


if __name__ == "__main__":
    print("Starting Gradio interface...")
    print("Will be available at http://localhost:7860")
    print("Use share=True for a public link")

    # Launch for Hugging Face Spaces
    demo.launch()
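
# Deployment note: for a Hugging Face Space, a requirements.txt covering this
# script's imports would look roughly like the sketch below (unpinned here;
# pin versions to match your training environment). model.py is nanoGPT's
# model definition and must ship in the Space alongside this file.
#
#     torch
#     gradio
#     tiktoken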