#!/usr/bin/env python3
"""
This script is a simple text generator using the SmollmV2 model.
It uses Gradio to create a web interface for generating text.
"""
# Standard library imports
import os
import warnings
from pathlib import Path

# Third-party imports
import torch
import torch.nn.functional as F
import gradio as gr
import spaces
from transformers import GPT2Tokenizer

# Local imports
from smollmv2 import SmollmV2
from config import SmollmConfig, DataConfig
from smollv2_lightning import LitSmollmv2

# Configure PyTorch to handle the device properties issue
torch._dynamo.config.suppress_errors = True
warnings.filterwarnings('ignore', category=UserWarning)


def combine_model_parts(model_dir="split_models", output_file="checkpoints/last.ckpt"):
    """
    Combine split model parts into a single checkpoint file.
    """
    # Create the checkpoints directory if it doesn't exist
    os.makedirs(os.path.dirname(output_file), exist_ok=True)

    # Skip the work if the combined model already exists
    if os.path.exists(output_file):
        print(f"Model already combined at: {output_file}")
        return output_file

    # Ensure the model parts exist
    if not os.path.exists(model_dir):
        raise FileNotFoundError(f"Model directory {model_dir} not found")

    # Combine the parts in sorted order
    parts = sorted(Path(model_dir).glob("last.ckpt.part_*"))
    if not parts:
        raise FileNotFoundError("No model parts found")

    print("Combining model parts...")
    with open(output_file, 'wb') as outfile:
        for part in parts:
            print(f"Processing part: {part}")
            with open(part, 'rb') as infile:
                outfile.write(infile.read())

    print(f"Model combined successfully: {output_file}")
    return output_file
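
# Note: combine_model_parts assumes the checkpoint was split into pieces whose names
# sort lexicographically into the correct order. As an illustration only (the splitting
# step is not part of this script), such parts could be produced with something like:
#
#   split --bytes=1000M last.ckpt split_models/last.ckpt.part_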


def load_model():
    """
    Load the SmollmV2 model and tokenizer.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Load the model directly from the checkpoint
    checkpoint_path = "last.ckpt"
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(
            f"Model checkpoint {checkpoint_path} not found. "
            "Please ensure the model checkpoint file 'last.ckpt' is present in the root directory."
        )

    try:
        # Load the model from the checkpoint using the Lightning module
        model = LitSmollmv2.load_from_checkpoint(
            checkpoint_path,
            model_config=SmollmConfig,
            strict=False
        )
        model.to(device)
        model.eval()

        # Initialize the tokenizer; GPT-2 has no pad token, so reuse the EOS token
        tokenizer = GPT2Tokenizer.from_pretrained(DataConfig.tokenizer_path)
        tokenizer.pad_token = tokenizer.eos_token

        return model, tokenizer, device
    except Exception as e:
        raise RuntimeError(f"Error loading model: {str(e)}") from e

# Load the model and tokenizer once at import time so they are shared across requests
model, tokenizer, device = load_model()


@spaces.GPU(enable_queue=True)
def generate_text(prompt, num_tokens, temperature=0.8, top_p=0.9):
    """
    Generate text using the SmollmV2 model.

    :param prompt: The initial text prompt to start the generation from.
    :param num_tokens: The number of tokens to generate.
    :param temperature: The temperature parameter for controlling randomness.
    :param top_p: The top-p parameter for nucleus sampling.
    :return: The generated text.
    """
    try:
        # Ensure num_tokens doesn't exceed the model's block size
        num_tokens = min(num_tokens, SmollmConfig.block_size)

        # Tokenize the input prompt
        input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

        # Generate tokens one at a time
        with torch.inference_mode():  # inference_mode is slightly faster than no_grad
            for _ in range(num_tokens):
                # Run the forward pass in mixed precision (float16)
                with torch.autocast(device_type=device, dtype=torch.float16):
                    outputs = model(input_ids)
                    logits = outputs[0] if isinstance(outputs, tuple) else outputs

                # Get the next-token probabilities
                logits = logits[:, -1, :] / temperature
                probs = F.softmax(logits, dim=-1)

                # Apply top-p (nucleus) sampling
                if top_p > 0:
                    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
                    cumsum_probs = torch.cumsum(sorted_probs, dim=-1)
                    sorted_indices_to_keep = cumsum_probs <= top_p
                    # Shift the mask right so at least the most probable token is always kept
                    sorted_indices_to_keep[..., 1:] = sorted_indices_to_keep[..., :-1].clone()
                    sorted_indices_to_keep[..., 0] = 1
                    indices_to_keep = torch.zeros_like(probs, dtype=torch.bool).scatter_(
                        -1, sorted_indices, sorted_indices_to_keep
                    )
                    probs = torch.where(indices_to_keep, probs, torch.zeros_like(probs))
                    probs = probs / probs.sum(dim=-1, keepdim=True)

                # Sample the next token
                next_token = torch.multinomial(probs, num_samples=1)

                # Append it to the running sequence
                input_ids = torch.cat([input_ids, next_token], dim=-1)

                # Stop if we generate an EOS token
                if next_token.item() == tokenizer.eos_token_id:
                    break

        # Decode and return the generated text
        generated_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
        return generated_text
    except Exception as e:
        return f"Error during text generation: {str(e)}"

# Create the Gradio interface
demo = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(label="Enter your prompt", value="Once upon a time"),
        gr.Slider(minimum=1, maximum=SmollmConfig.block_size // 2, value=100, step=1,
                  label="Number of tokens to generate"),
        gr.Slider(minimum=0.1, maximum=2.0, value=0.8, step=0.1,
                  label="Temperature (higher = more random)"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.1,
                  label="Top-p (nucleus sampling)")
    ],
    outputs=gr.Textbox(label="Generated Text"),
    title="SmoLLMv2 Text Generator",
    description="Generate text using the SmoLLMv2-135M model",
    allow_flagging="never",
    cache_examples=True
)

if __name__ == "__main__":
    demo.launch()
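
# Beyond the web UI, generate_text can also be called directly from Python once this
# module has been imported (a sketch, assuming the file is named app.py):
#
#   from app import generate_text
#   print(generate_text("Once upon a time", num_tokens=50, temperature=0.7, top_p=0.9))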