#! /usr/bin/env python3
"""
This script is a simple text generator using the SmollmV2 model.
It uses Gradio to create a web interface for generating text.
"""
# Standard Library Imports
import os
import warnings
from pathlib import Path

# Third-Party Imports
import torch
import torch.nn.functional as F
import gradio as gr
from transformers import GPT2Tokenizer
import spaces

# Local Imports
from smollmv2 import SmollmV2
from config import SmollmConfig, DataConfig
from smollv2_lightning import LitSmollmv2

# Work around the device-properties issue with torch.compile: fall back to eager
# execution on dynamo errors and silence the resulting UserWarnings
torch._dynamo.config.suppress_errors = True
warnings.filterwarnings('ignore', category=UserWarning)
def combine_model_parts(model_dir="split_models", output_file="checkpoints/last.ckpt"):
    """
    Combine split model parts into a single checkpoint file.
    """
    # Create checkpoints directory if it doesn't exist
    os.makedirs(os.path.dirname(output_file), exist_ok=True)

    # Check if combined model already exists
    if os.path.exists(output_file):
        print(f"Model already combined at: {output_file}")
        return output_file

    # Ensure the model parts exist
    if not os.path.exists(model_dir):
        raise FileNotFoundError(f"Model directory {model_dir} not found")

    # Combine the parts in sorted order
    parts = sorted(Path(model_dir).glob("last.ckpt.part_*"))
    if not parts:
        raise FileNotFoundError("No model parts found")

    print("Combining model parts...")
    with open(output_file, 'wb') as outfile:
        for part in parts:
            print(f"Processing part: {part}")
            with open(part, 'rb') as infile:
                outfile.write(infile.read())

    print(f"Model combined successfully: {output_file}")
    return output_file
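# Note: combine_model_parts() is not called anywhere in this script. If the
# checkpoint is shipped as split parts (split_models/last.ckpt.part_*), it is
# assumed to be reassembled into a single .ckpt file before the app starts.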
def load_model():
    """
    Load the SmollmV2 model and tokenizer.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Load model directly from checkpoint
    checkpoint_path = "last.ckpt"
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(
            f"Model checkpoint {checkpoint_path} not found. "
            "Please ensure the model checkpoint file 'last.ckpt' is present in the root directory."
        )

    try:
        # Load the model from checkpoint using the Lightning module
        model = LitSmollmv2.load_from_checkpoint(
            checkpoint_path,
            model_config=SmollmConfig,
            strict=False
        )
        model.to(device)
        model.eval()

        # Initialize the tokenizer and reuse the EOS token for padding
        tokenizer = GPT2Tokenizer.from_pretrained(DataConfig.tokenizer_path)
        tokenizer.pad_token = tokenizer.eos_token

        return model, tokenizer, device
    except Exception as e:
        raise RuntimeError(f"Error loading model: {str(e)}") from e
# Load the model globally
model, tokenizer, device = load_model()
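# Because load_model() runs at import time, a missing 'last.ckpt' makes the app
# fail fast with a FileNotFoundError instead of erroring on the first request.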
@spaces.GPU(enable_queue=True)
def generate_text(prompt, num_tokens, temperature=0.8, top_p=0.9):
    """
    Generate text using the SmollmV2 model.

    :param prompt: The initial text prompt to start the generation from.
    :param num_tokens: The number of tokens to generate.
    :param temperature: The temperature parameter for controlling randomness.
    :param top_p: The top-p parameter for nucleus sampling.
    :return: The generated text.
    """
    try:
        # Ensure num_tokens doesn't exceed the model's block size
        # (cast to int in case the Gradio slider passes a float)
        num_tokens = int(min(num_tokens, SmollmConfig.block_size))

        # Tokenize the input prompt
        input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

        # Generate tokens one at a time
        with torch.inference_mode():  # inference_mode disables autograd bookkeeping entirely
            for _ in range(num_tokens):
                # Keep the context fed to the model within its block size
                context = input_ids[:, -SmollmConfig.block_size:]

                # Get the model's predictions in half precision to reduce memory use
                with torch.autocast(device_type=device, dtype=torch.float16):
                    outputs = model(context)
                    logits = outputs[0] if isinstance(outputs, tuple) else outputs

                # Get the next-token probabilities for the last position
                logits = logits[:, -1, :] / temperature
                probs = F.softmax(logits, dim=-1)

                # Apply top-p (nucleus) sampling
                if top_p > 0:
                    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
                    cumsum_probs = torch.cumsum(sorted_probs, dim=-1)
                    # Keep tokens inside the nucleus; shift the mask right so the
                    # token that crosses the threshold is kept, and always keep
                    # the most probable token
                    sorted_indices_to_keep = cumsum_probs <= top_p
                    sorted_indices_to_keep[..., 1:] = sorted_indices_to_keep[..., :-1].clone()
                    sorted_indices_to_keep[..., 0] = True
                    indices_to_keep = torch.zeros_like(probs, dtype=torch.bool).scatter_(-1, sorted_indices, sorted_indices_to_keep)
                    probs = torch.where(indices_to_keep, probs, torch.zeros_like(probs))
                    probs = probs / probs.sum(dim=-1, keepdim=True)

                # Sample the next token
                next_token = torch.multinomial(probs, num_samples=1)

                # Append it to the running sequence
                input_ids = torch.cat([input_ids, next_token], dim=-1)

                # Stop if we generate an EOS token
                if next_token.item() == tokenizer.eos_token_id:
                    break

        # Decode and return the generated text
        generated_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
        return generated_text
    except Exception as e:
        return f"Error during text generation: {str(e)}"
# Create the Gradio interface
demo = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(label="Enter your prompt", value="Once upon a time"),
        gr.Slider(minimum=1, maximum=SmollmConfig.block_size//2, value=100, step=1, label="Number of tokens to generate"),
        gr.Slider(minimum=0.1, maximum=2.0, value=0.8, step=0.1, label="Temperature (higher = more random)"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.1, label="Top-p (nucleus sampling)")
    ],
    outputs=gr.Textbox(label="Generated Text"),
    title="SmoLLMv2 Text Generator",
    description="Generate text using the SmoLLMv2-135M model",
    allow_flagging="never",
    cache_examples=True
)
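# Running the script directly (e.g. `python app.py`) starts the Gradio server via
# launch(); passing share=True to launch() would also create a temporary public
# URL, which is optional and not used here.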
if __name__ == "__main__":
    demo.launch()