import os

import torch
import gradio as gr
from transformers import AutoTokenizer

from config import SmolLM2Config
from model import SmolLM2Lightning


def load_model(checkpoint_path):
    """Load the trained model from a checkpoint file.

    Builds the configuration from ``config.yaml``, restores the model from
    ``checkpoint_path``, switches it to eval mode and moves it to the GPU
    when CUDA is available.

    Args:
        checkpoint_path: Path to the checkpoint to restore.

    Returns:
        The loaded model, or ``None`` if anything went wrong while loading
        (the error is printed rather than raised, so the UI can still start).
    """
    try:
        config = SmolLM2Config("config.yaml")
        model = SmolLM2Lightning.load_from_checkpoint(checkpoint_path, config=config)
        model.eval()

        if torch.cuda.is_available():
            model = model.cuda()
            print(f"Model loaded on GPU: {torch.cuda.get_device_name()}")
        else:
            print("Model loaded on CPU")

        return model
    except Exception as e:
        # Deliberate best-effort: report the failure and let the caller
        # handle a missing model instead of crashing at import time.
        print(f"Error loading model: {str(e)}")
        return None


def generate_text(prompt, max_length=100, temperature=0.7, top_p=0.9, top_k=50):
    """Generate text from a prompt using the module-level ``model``.

    Args:
        prompt: Text to continue.
        max_length: Value forwarded to ``model.generate`` as ``max_length``.
        temperature: Sampling temperature forwarded to ``model.generate``.
        top_p: Nucleus-sampling threshold forwarded to ``model.generate``.
        top_k: Top-k cutoff forwarded to ``model.generate``.

    Returns:
        The decoded generation, or a human-readable error message when the
        model is not loaded or generation fails. This function never raises,
        so the message is what the Gradio output box displays.
    """
    try:
        if model is None:
            return "Model not loaded. Please check if checkpoint exists."

        inputs = model.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=model.config.model.max_position_embeddings,
            padding=True,
        )

        # Read the tensors by key rather than by attribute. The previous code
        # rebuilt `inputs` as a plain dict on the CUDA path and then accessed
        # `inputs.input_ids`, which raised AttributeError on every GPU request.
        input_ids = inputs["input_ids"]
        attention_mask = inputs["attention_mask"]

        if torch.cuda.is_available():
            input_ids = input_ids.cuda()
            attention_mask = attention_mask.cuda()

        with torch.no_grad():
            outputs = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                # Both values are counts; the UI sliders may deliver them as
                # floats, so coerce before handing them to generation.
                max_length=int(max_length),
                temperature=temperature,
                top_p=top_p,
                top_k=int(top_k),
                do_sample=True,
                pad_token_id=model.tokenizer.pad_token_id,
                bos_token_id=model.tokenizer.bos_token_id,
                eos_token_id=model.tokenizer.eos_token_id,
            )

        return model.tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        # Surface the failure in the UI instead of crashing the request.
        return f"Error generating text: {str(e)}"


# Load the model
print("Loading model...")
checkpoint_path = "checkpoints/smol-lm2-final.ckpt"

if not os.path.exists(checkpoint_path):
    print(f"Warning: Checkpoint not found at {checkpoint_path}")
    print("Please train the model first or specify correct checkpoint path")
    # generate_text checks for None and reports it to the user.
    model = None
else:
    model = load_model(checkpoint_path)

# Create Gradio interface
# The input order below must match generate_text's positional parameters.
demo = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(label="Prompt", placeholder="Enter your prompt here..."),
        gr.Slider(minimum=10, maximum=200, value=100, step=1, label="Max Length"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.1, label="Top-p"),
        gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Top-k"),
    ],
    outputs=gr.Textbox(label="Generated Text"),
    title="SmolLM2 Text Generation",
    description="Enter a prompt and adjust generation parameters to create text with SmolLM2",
    examples=[
        ["Explain what machine learning is:", 100, 0.7, 0.9, 50],
        ["Once upon a time", 150, 0.8, 0.9, 40],
        ["The best way to learn programming is", 120, 0.7, 0.9, 50],
    ],
)

if __name__ == "__main__":
    print("Starting Gradio interface...")
    # Simple launch configuration
    demo.launch(
        server_port=7860,
        share=True,
    )