"""
Simple text generator built on the SmollmV2 model.

Loads a trained checkpoint and serves a Gradio web interface for
interactive text generation.
"""

import os
import warnings
from pathlib import Path

import gradio as gr
import spaces
import torch
import torch.nn.functional as F
from transformers import GPT2Tokenizer

from smollmv2 import SmollmV2
from config import SmollmConfig, DataConfig
from smollv2_lightning import LitSmollmv2

# Fall back to eager execution instead of raising when torch.compile fails.
torch._dynamo.config.suppress_errors = True
warnings.filterwarnings('ignore', category=UserWarning)


def combine_model_parts(model_dir="split_models", output_file="checkpoints/last.ckpt"):
    """
    Combine split model parts into a single checkpoint file.
    """
    os.makedirs(os.path.dirname(output_file), exist_ok=True)

    # Skip the work if the combined checkpoint already exists.
    if os.path.exists(output_file):
        print(f"Model already combined at: {output_file}")
        return output_file

    if not os.path.exists(model_dir):
        raise FileNotFoundError(f"Model directory {model_dir} not found")

    # Concatenate the parts in sorted (alphabetical) order to restore the original file.
    parts = sorted(Path(model_dir).glob("last.ckpt.part_*"))
    if not parts:
        raise FileNotFoundError("No model parts found")

    print("Combining model parts...")
    with open(output_file, 'wb') as outfile:
        for part in parts:
            print(f"Processing part: {part}")
            with open(part, 'rb') as infile:
                outfile.write(infile.read())

    print(f"Model combined successfully: {output_file}")
    return output_file
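
# Note: combine_model_parts() is not called automatically in this script; it is
# kept for deployments where the checkpoint ships in pieces. A hypothetical way
# to produce such pieces is the Unix `split` utility, e.g.:
#   split -b 100M last.ckpt last.ckpt.part_
# after which they can be reassembled with:
#   combine_model_parts("split_models", "checkpoints/last.ckpt")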


def load_model():
    """
    Load the SmollmV2 model and tokenizer.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    checkpoint_path = "last.ckpt"
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(
            f"Model checkpoint {checkpoint_path} not found. "
            "Please ensure the model checkpoint file 'last.ckpt' is present in the root directory."
        )

    try:
        # strict=False tolerates minor mismatches between the checkpoint and the current model definition.
        model = LitSmollmv2.load_from_checkpoint(
            checkpoint_path,
            model_config=SmollmConfig,
            strict=False
        )
        model.to(device)
        model.eval()

        # GPT-2 BPE tokenizer; reuse the EOS token for padding since GPT-2 defines no pad token.
        tokenizer = GPT2Tokenizer.from_pretrained(DataConfig.tokenizer_path)
        tokenizer.pad_token = tokenizer.eos_token

        return model, tokenizer, device

    except Exception as e:
        raise RuntimeError(f"Error loading model: {str(e)}") from e


# Load the model once at import time so all Gradio requests reuse the same instance.
model, tokenizer, device = load_model()
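
# Optional sanity check (commented out): the GPT-2 tokenizer should round-trip plain text.
# assert tokenizer.decode(tokenizer.encode("hello")) == "hello"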


@spaces.GPU(enable_queue=True)
def generate_text(prompt, num_tokens, temperature=0.8, top_p=0.9):
    """
    Generate text using the SmollmV2 model.

    :param prompt: The initial text prompt to start the generation from.
    :param num_tokens: The number of tokens to generate.
    :param temperature: The temperature parameter for controlling randomness.
    :param top_p: The top-p parameter for nucleus sampling.
    :return: The generated text.
    """
    try:
        # Never ask for more tokens than the model's context window.
        num_tokens = min(num_tokens, SmollmConfig.block_size)

        input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

        # float16 autocast is only supported on CUDA; fall back to bfloat16 on CPU.
        autocast_dtype = torch.float16 if device == 'cuda' else torch.bfloat16

        with torch.inference_mode():
            for _ in range(num_tokens):
                # Crop the context to the last block_size tokens so the input
                # never exceeds the model's context window.
                context = input_ids[:, -SmollmConfig.block_size:]

                with torch.autocast(device_type=device, dtype=autocast_dtype):
                    outputs = model(context)
                    logits = outputs[0] if isinstance(outputs, tuple) else outputs

                # Temperature scaling on the logits of the last position.
                logits = logits[:, -1, :] / temperature
                probs = F.softmax(logits, dim=-1)

                # Nucleus (top-p) sampling: keep the smallest set of tokens whose
                # cumulative probability exceeds top_p, always retaining the top token.
                if top_p > 0:
                    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
                    cumsum_probs = torch.cumsum(sorted_probs, dim=-1)
                    sorted_indices_to_keep = cumsum_probs <= top_p
                    # Shift right so the first token above the threshold is still kept.
                    sorted_indices_to_keep[..., 1:] = sorted_indices_to_keep[..., :-1].clone()
                    sorted_indices_to_keep[..., 0] = True
                    indices_to_keep = torch.zeros_like(probs, dtype=torch.bool).scatter_(
                        -1, sorted_indices, sorted_indices_to_keep
                    )
                    probs = torch.where(indices_to_keep, probs, torch.zeros_like(probs))
                    probs = probs / probs.sum(dim=-1, keepdim=True)

                # Sample the next token and append it to the running sequence.
                next_token = torch.multinomial(probs, num_samples=1)
                input_ids = torch.cat([input_ids, next_token], dim=-1)

                # Stop early if the model emits the end-of-sequence token.
                if next_token.item() == tokenizer.eos_token_id:
                    break

        generated_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
        return generated_text

    except Exception as e:
        return f"Error during text generation: {str(e)}"
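
# Quick local smoke test (optional, not part of the Gradio app): uncomment to
# sample a short completion directly, assuming the checkpoint above loaded successfully.
# print(generate_text("Once upon a time", num_tokens=50, temperature=0.8, top_p=0.9))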


# Gradio UI: expose prompt, length, temperature, and top-p controls.
demo = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(label="Enter your prompt", value="Once upon a time"),
        gr.Slider(minimum=1, maximum=SmollmConfig.block_size // 2, value=100, step=1,
                  label="Number of tokens to generate"),
        gr.Slider(minimum=0.1, maximum=2.0, value=0.8, step=0.1,
                  label="Temperature (higher = more random)"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.1,
                  label="Top-p (nucleus sampling)")
    ],
    outputs=gr.Textbox(label="Generated Text"),
    title="SmoLLMv2 Text Generator",
    description="Generate text using the SmoLLMv2-135M model",
    allow_flagging="never"
)


if __name__ == "__main__":
    demo.launch()
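    # On Hugging Face Spaces the app is typically launched by the platform; when
    # running locally you could also enable request queuing or a public link,
    # e.g. (hypothetical): demo.queue().launch(share=True)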