#!/usr/bin/env python3
"""
This script is a simple text generator using the SmollmV2 model.
It uses Gradio to create a web interface for generating text.
"""
# Standard library imports
import os
import warnings
from pathlib import Path

# Third-party imports
import torch
import torch.nn.functional as F
import gradio as gr
from transformers import GPT2Tokenizer
import spaces

# Local imports
from smollmv2 import SmollmV2
from config import SmollmConfig, DataConfig
from smollv2_lightning import LitSmollmv2

# Configure PyTorch to handle the device properties issue
torch._dynamo.config.suppress_errors = True
warnings.filterwarnings('ignore', category=UserWarning)


def combine_model_parts(model_dir="split_models", output_file="checkpoints/last.ckpt"):
    """
    Combine split model parts into a single checkpoint file.

    :param model_dir: Directory containing the checkpoint parts.
    :param output_file: Path of the combined checkpoint to write.
    :return: Path to the combined checkpoint file.
    """
    # Create the checkpoints directory if it doesn't exist
    os.makedirs(os.path.dirname(output_file), exist_ok=True)

    # Skip the work if the combined model already exists
    if os.path.exists(output_file):
        print(f"Model already combined at: {output_file}")
        return output_file

    # Ensure the directory with the model parts exists
    if not os.path.exists(model_dir):
        raise FileNotFoundError(f"Model directory {model_dir} not found")

    # Combine the parts in order
    parts = sorted(Path(model_dir).glob("last.ckpt.part_*"))
    if not parts:
        raise FileNotFoundError("No model parts found")

    print("Combining model parts...")
    with open(output_file, 'wb') as outfile:
        for part in parts:
            print(f"Processing part: {part}")
            with open(part, 'rb') as infile:
                outfile.write(infile.read())

    print(f"Model combined successfully: {output_file}")
    return output_file


def load_model():
    """
    Load the SmollmV2 model and tokenizer.

    :return: Tuple of (model, tokenizer, device).
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Load the model directly from the checkpoint
    checkpoint_path = "last.ckpt"
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(
            f"Model checkpoint {checkpoint_path} not found. "
            "Please ensure the model checkpoint file 'last.ckpt' is present in the root directory."
        )

    try:
        # Load the model from the checkpoint using the Lightning module
        model = LitSmollmv2.load_from_checkpoint(
            checkpoint_path,
            model_config=SmollmConfig,
            strict=False
        )
        model.to(device)
        model.eval()

        # Initialize the tokenizer
        tokenizer = GPT2Tokenizer.from_pretrained(DataConfig.tokenizer_path)
        tokenizer.pad_token = tokenizer.eos_token

        return model, tokenizer, device
    except Exception as e:
        raise RuntimeError(f"Error loading model: {e}") from e


# Load the model globally
model, tokenizer, device = load_model()


@spaces.GPU(enable_queue=True)
def generate_text(prompt, num_tokens, temperature=0.8, top_p=0.9):
    """
    Generate text using the SmollmV2 model.

    :param prompt: The initial text prompt to start the generation from.
    :param num_tokens: The number of tokens to generate.
    :param temperature: The temperature parameter for controlling randomness.
    :param top_p: The top-p parameter for nucleus sampling.
    :return: The generated text.
""" try: # Ensure num_tokens doesn't exceed model's block size num_tokens = min(num_tokens, SmollmConfig.block_size) # Tokenize input prompt input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device) # Generate tokens one at a time with torch.inference_mode(): # Use inference_mode instead of no_grad for _ in range(num_tokens): # Get the model's predictions with torch.autocast(device_type=device, dtype=torch.float16): # Changed to float16 outputs = model(input_ids) logits = outputs[0] if isinstance(outputs, tuple) else outputs # Get the next token probabilities logits = logits[:, -1, :] / temperature probs = F.softmax(logits, dim=-1) # Apply top-p sampling if top_p > 0: sorted_probs, sorted_indices = torch.sort(probs, descending=True) cumsum_probs = torch.cumsum(sorted_probs, dim=-1) sorted_indices_to_keep = cumsum_probs <= top_p sorted_indices_to_keep[..., 1:] = sorted_indices_to_keep[..., :-1].clone() sorted_indices_to_keep[..., 0] = 1 indices_to_keep = torch.zeros_like(probs, dtype=torch.bool).scatter_(-1, sorted_indices, sorted_indices_to_keep) probs = torch.where(indices_to_keep, probs, torch.zeros_like(probs)) probs = probs / probs.sum(dim=-1, keepdim=True) # Sample next token next_token = torch.multinomial(probs, num_samples=1) # Append to input_ids input_ids = torch.cat([input_ids, next_token], dim=-1) # Stop if we generate an EOS token if next_token.item() == tokenizer.eos_token_id: break # Decode and return the generated text generated_text = tokenizer.decode(input_ids[0], skip_special_tokens=True) return generated_text except Exception as e: return f"Error during text generation: {str(e)}" # Create the Gradio interface demo = gr.Interface( fn=generate_text, inputs=[ gr.Textbox(label="Enter your prompt", value="Once upon a time"), gr.Slider(minimum=1, maximum=SmollmConfig.block_size//2, value=100, step=1, label="Number of tokens to generate"), gr.Slider(minimum=0.1, maximum=2.0, value=0.8, step=0.1, label="Temperature (higher = more random)"), gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.1, label="Top-p (nucleus sampling)") ], outputs=gr.Textbox(label="Generated Text"), title="SmoLLMv2 Text Generator", description="Generate text using the SmoLLMv2-135M model", allow_flagging="never", cache_examples=True ) if __name__ == "__main__": demo.launch()