File size: 6,368 Bytes
f42f624
 
 
 
 
 
 
 
 
 
 
 
 
baa4d1d
f42f624
 
 
 
 
 
baa4d1d
 
 
f42f624
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ea06c18
baa4d1d
f42f624
ea06c18
 
 
 
 
 
baa4d1d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f42f624
baa4d1d
 
f42f624
1cb4d80
 
f42f624
 
 
 
 
dbdeb7e
 
 
 
 
f42f624
baa4d1d
 
 
f42f624
baa4d1d
 
f42f624
baa4d1d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f42f624
baa4d1d
 
 
f42f624
baa4d1d
 
f42f624
 
 
 
 
 
dbdeb7e
f42f624
 
 
 
dbdeb7e
 
f42f624
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
#! /usr/bin/env python3
"""
This script is a simple text generator using the SmollmV2 model.
It uses Gradio to create a web interface for generating text.
"""
# Standard-library imports
import os
import re
import shutil
import warnings
from pathlib import Path

# Third-Party Imports
import torch
import torch.nn.functional as F
import gradio as gr
from transformers import GPT2Tokenizer
import spaces

# Local imports
from smollmv2 import SmollmV2
from config import SmollmConfig, DataConfig
from smollv2_lightning import LitSmollmv2

# Configure PyTorch to handle the device properties issue.
# NOTE(review): suppress_errors presumably makes torch._dynamo fall back to
# eager execution instead of raising when compilation fails — confirm it is
# still needed (nothing visible in this file calls torch.compile).
torch._dynamo.config.suppress_errors = True
# Silence every UserWarning process-wide. This hides warnings from all
# libraries (torch, transformers, gradio, ...), not only the device one.
warnings.filterwarnings('ignore', category=UserWarning)

def _natural_sort_key(path):
    """Sort key that orders embedded integers numerically ("part_2" < "part_10")."""
    chunks = re.split(r"(\d+)", path.name)
    # re.split with one capturing group alternates text (even indexes) and
    # digit runs (odd indexes), so ints and strs always line up position by
    # position when two keys are compared. The name breaks ties ("01" vs "1").
    return [int(chunk) if idx % 2 else chunk for idx, chunk in enumerate(chunks)], path.name


def combine_model_parts(model_dir="split_models", output_file="checkpoints/last.ckpt"):
    """
    Combine split model parts into a single checkpoint file.

    Parts are the files in ``model_dir`` matching ``last.ckpt.part_*``. They are
    concatenated in natural order, so ``part_2`` precedes ``part_10``. For
    zero-padded or alphabetic suffixes (``part_aa``, ``part_ab``) this equals
    plain lexicographic order.

    :param model_dir: Directory containing the ``last.ckpt.part_*`` files.
    :param output_file: Path of the combined checkpoint to write. Its parent
        directory is created if needed.
    :return: ``output_file`` (unchanged), whether it was just written or
        already existed.
    :raises FileNotFoundError: If ``model_dir`` does not exist or contains no
        matching part files.
    """
    # Check if combined model already exists
    if os.path.exists(output_file):
        print(f"Model already combined at: {output_file}")
        return output_file

    # Ensure the model parts exist
    if not os.path.exists(model_dir):
        raise FileNotFoundError(f"Model directory {model_dir} not found")

    parts = sorted(
        (p for p in Path(model_dir).glob("last.ckpt.part_*") if p.is_file()),
        key=_natural_sort_key,
    )
    if not parts:
        raise FileNotFoundError("No model parts found")

    # dirname() is "" for a bare filename, and os.makedirs("") would raise.
    out_dir = os.path.dirname(output_file)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # Write to a temporary file and move it into place only after every part
    # has been copied. An interrupted run then never leaves a truncated
    # checkpoint that the existence check above would later accept.
    tmp_file = os.fspath(output_file) + ".tmp"
    print("Combining model parts...")
    try:
        with open(tmp_file, "wb") as outfile:
            for part in parts:
                print(f"Processing part: {part}")
                with open(part, "rb") as infile:
                    # Stream in chunks instead of reading each part fully into memory.
                    shutil.copyfileobj(infile, outfile)
        os.replace(tmp_file, output_file)
    except BaseException:
        if os.path.exists(tmp_file):
            os.remove(tmp_file)
        raise

    print(f"Model combined successfully: {output_file}")
    return output_file

def load_model(checkpoint_path="last.ckpt", device=None):
    """
    Load the SmollmV2 model and tokenizer.

    :param checkpoint_path: Lightning checkpoint to restore. Defaults to
        ``last.ckpt`` relative to the current working directory.
    :param device: Device to place the model on. ``None`` (the default) picks
        ``'cuda'`` when available, otherwise ``'cpu'``.
    :return: ``(model, tokenizer, device)``. The model is in eval mode and the
        tokenizer's pad token is set to its EOS token.
    :raises FileNotFoundError: If ``checkpoint_path`` does not exist.
    :raises RuntimeError: If restoring the model or tokenizer fails. The
        underlying exception is chained as ``__cause__``.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(
            f"Model checkpoint {checkpoint_path} not found. "
            f"Please ensure the model checkpoint file '{checkpoint_path}' is present in the root directory."
        )

    try:
        # Load the model from checkpoint using the Lightning module.
        # strict=False tolerates missing/unexpected state-dict keys.
        # NOTE(review): a mismatched checkpoint will therefore load silently.
        model = LitSmollmv2.load_from_checkpoint(
            checkpoint_path,
            model_config=SmollmConfig,
            strict=False
        )

        model.to(device)
        model.eval()

        # Initialize tokenizer; EOS doubles as the padding token.
        tokenizer = GPT2Tokenizer.from_pretrained(DataConfig.tokenizer_path)
        tokenizer.pad_token = tokenizer.eos_token

        return model, tokenizer, device

    except Exception as e:
        # Chain the original error so its traceback is not lost.
        raise RuntimeError(f"Error loading model: {str(e)}") from e

# Load the model globally, once, at import time. Every generate_text call
# reuses these module-level objects. Importing this module therefore fails
# if the checkpoint is missing (load_model raises FileNotFoundError).
model, tokenizer, device = load_model()

def _top_p_filter(probs, top_p):
    """
    Restrict ``probs`` to the top-p (nucleus) set and renormalise.

    :param probs: Probability tensor whose last dimension indexes the vocabulary.
    :param top_p: Cumulative-probability threshold in (0, 1).
    :return: Tensor shaped like ``probs``. Tokens outside the nucleus are
        zeroed and the rest are rescaled to sum to 1 along the last dimension.
    """
    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
    cumsum_probs = torch.cumsum(sorted_probs, dim=-1)
    # Shift the mask right by one: a token is kept while the mass *before* it
    # is within top_p. The most likely token is always kept.
    keep_sorted = cumsum_probs <= top_p
    keep_sorted[..., 1:] = keep_sorted[..., :-1].clone()
    keep_sorted[..., 0] = True
    # Scatter the mask from sorted order back to vocabulary order.
    keep = torch.zeros_like(probs, dtype=torch.bool).scatter_(-1, sorted_indices, keep_sorted)
    filtered = probs.masked_fill(~keep, 0.0)
    return filtered / filtered.sum(dim=-1, keepdim=True)


@spaces.GPU(enable_queue=True)
def generate_text(prompt, num_tokens, temperature=0.8, top_p=0.9):
    """
    Generate text using the SmollmV2 model.

    :param prompt: The initial text prompt to start the generation from.
    :param num_tokens: The number of tokens to generate. It is capped at
        ``SmollmConfig.block_size``, and non-integers (e.g. slider floats) are
        truncated.
    :param temperature: Sampling temperature. Values <= 0 select greedy
        (argmax) decoding.
    :param top_p: Nucleus-sampling threshold. Values outside (0, 1) sample
        from the full distribution.
    :return: The decoded prompt plus continuation, or an error message string
        if generation fails.
    """
    try:
        block_size = SmollmConfig.block_size
        # Gradio sliders may deliver floats; range() needs an int.
        num_tokens = max(0, min(int(num_tokens), block_size))

        # Tokenize input prompt
        input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
        if input_ids.numel() == 0:
            # An empty prompt encodes to no tokens; start from EOS instead.
            input_ids = torch.tensor([[tokenizer.eos_token_id]], device=device)

        # torch.autocast expects a bare device type ("cuda"), not "cuda:0".
        device_type = torch.device(device).type
        # float16 autocast support on CPU depends on the torch version, so
        # only enable it on CUDA and keep CPU runs in float32.
        use_autocast = device_type == "cuda"

        with torch.inference_mode():
            for _ in range(num_tokens):
                # Condition on at most the last block_size tokens.
                # Presumably the model cannot attend past its block size
                # (the num_tokens cap above suggests so) — confirm in SmollmV2.
                context = input_ids[:, -block_size:]
                with torch.autocast(device_type=device_type, dtype=torch.float16, enabled=use_autocast):
                    outputs = model(context)
                logits = outputs[0] if isinstance(outputs, tuple) else outputs

                # Upcast the last-position logits before softmax; fp16 loses precision.
                next_logits = logits[:, -1, :].float()

                if temperature <= 0:
                    next_token = torch.argmax(next_logits, dim=-1, keepdim=True)
                else:
                    probs = F.softmax(next_logits / temperature, dim=-1)
                    if 0 < top_p < 1:
                        probs = _top_p_filter(probs, top_p)
                    next_token = torch.multinomial(probs, num_samples=1)

                input_ids = torch.cat([input_ids, next_token], dim=-1)

                # Stop if we generate an EOS token
                if next_token.item() == tokenizer.eos_token_id:
                    break

        # Decode and return the generated text
        return tokenizer.decode(input_ids[0], skip_special_tokens=True)

    except Exception as e:
        # Deliberately surfaced to the UI as text rather than raised.
        return f"Error during text generation: {str(e)}"

# Create the Gradio interface. The four inputs map positionally onto
# generate_text(prompt, num_tokens, temperature, top_p).
demo = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(label="Enter your prompt", value="Once upon a time"),
        # NOTE(review): value=100 assumes SmollmConfig.block_size // 2 >= 100 — confirm.
        gr.Slider(minimum=1, maximum=SmollmConfig.block_size//2, value=100, step=1, label="Number of tokens to generate"),
        gr.Slider(minimum=0.1, maximum=2.0, value=0.8, step=0.1, label="Temperature (higher = more random)"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.1, label="Top-p (nucleus sampling)")
    ],
    outputs=gr.Textbox(label="Generated Text"),
    title="SmoLLMv2 Text Generator",
    description="Generate text using the SmoLLMv2-135M model",
    # NOTE(review): newer Gradio releases may rename this to flagging_mode — check the pinned version.
    allow_flagging="never",
    # NOTE(review): no `examples` are passed, so there is nothing to cache; likely a leftover.
    cache_examples=True
)

# Launch the Gradio server only when run as a script, not on import.
if __name__ == "__main__":
    demo.launch()