import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, TextGenerationPipeline
import datetime
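# Assumed runtime dependencies (not pinned anywhere in this file):
#   pip install torch gradio transformers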

# Model Constants
MODEL_ID = "FlameF0X/Snowflake-G0-Release"  # HF repo when published
MAX_LENGTH = 384
TEMPERATURE_MIN = 0.1
TEMPERATURE_MAX = 2.0
TEMPERATURE_DEFAULT = 0.7
TOP_P_MIN = 0.1
TOP_P_MAX = 1.0
TOP_P_DEFAULT = 0.9
TOP_K_MIN = 1
TOP_K_MAX = 100
TOP_K_DEFAULT = 40
MAX_NEW_TOKENS_MIN = 16
MAX_NEW_TOKENS_MAX = 1024
MAX_NEW_TOKENS_DEFAULT = 256

# CSS for the app
css = """
.gradio-container {
    background-color: #1e1e2f !important;
    color: #e0e0e0 !important;
}
.header {
    background-color: #2b2b3c;
    padding: 20px;
    margin-bottom: 20px;
    border-radius: 10px;
    text-align: center;
}
.header h1 {
    color: #66ccff;
    margin-bottom: 10px;
}
.snowflake-icon {
    font-size: 24px;
    margin-right: 10px;
}
.footer {
    text-align: center;
    margin-top: 20px;
    font-size: 0.9em;
    color: #999;
}
.parameter-section {
    background-color: #2a2a3a;
    padding: 15px;
    border-radius: 8px;
    margin-bottom: 15px;
}
.parameter-section h3 {
    margin-top: 0;
    color: #66ccff;
}
.example-section {
    background-color: #223344;
    padding: 15px;
    border-radius: 8px;
    margin-bottom: 15px;
}
.example-section h3 {
    margin-top: 0;
    color: #66ffaa;
}
"""

# Helper function to load the model, tokenizer, and generation pipeline
def load_model_and_tokenizer():
    global model, tokenizer, pipeline

    # Load the tokenizer
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

    # Some causal-LM tokenizers ship without a pad token; fall back to EOS
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # from_pretrained resolves the checkpoint format itself (preferring
    # model.safetensors over pytorch_model.bin), so no manual file check
    # or raw safetensors state-dict loading is needed here
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        device_map="auto"
    )

    # Initialize the generation pipeline
    pipeline = TextGenerationPipeline(
        model=model,
        tokenizer=tokenizer,
        return_full_text=False,
        max_length=MAX_LENGTH
    )

    return model, tokenizer, pipeline

# Helper functions for generation
def generate_text(
    prompt, 
    temperature=TEMPERATURE_DEFAULT, 
    top_p=TOP_P_DEFAULT, 
    top_k=TOP_K_DEFAULT, 
    max_new_tokens=MAX_NEW_TOKENS_DEFAULT,
    history=None
):
    if history is None:
        history = []
    
    # Add current prompt to history
    history.append({"role": "user", "content": prompt})
    
    try:
        # Generate response
        outputs = pipeline(
            prompt,
            do_sample=temperature > 0,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            max_new_tokens=max_new_tokens,
            pad_token_id=tokenizer.pad_token_id,
            num_return_sequences=1
        )
        
        response = outputs[0]["generated_text"]
        
        # Add model response to history
        history.append({"role": "assistant", "content": response})
        
        # Format chat history for display
        formatted_history = []
        for entry in history:
            role_prefix = "👤 User: " if entry["role"] == "user" else "❄️ Snowflake: "
            formatted_history.append(f"{role_prefix}{entry['content']}")
        
        return response, history, "\n\n".join(formatted_history)
    
    except Exception as e:
        error_msg = f"Error generating response: {str(e)}"
        history.append({"role": "assistant", "content": f"[ERROR] {error_msg}"})
        # Keep the display format consistent with the success path instead of
        # dumping the raw history repr into the conversation box
        formatted_history = []
        for entry in history:
            role_prefix = "👤 User: " if entry["role"] == "user" else "❄️ Snowflake: "
            formatted_history.append(f"{role_prefix}{entry['content']}")
        return error_msg, history, "\n\n".join(formatted_history)

def clear_conversation():
    return "", [], ""

# Note: currently unused; gr.Examples below fills the prompt box directly,
# but this helper is kept for wiring examples through a custom callback
def apply_preset_example(example, history):
    return example, history
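
# Hedged usage sketch: generate_text can be exercised outside the UI, e.g.
# as a smoke test after the model loads. The history format ({"role": ...,
# "content": ...} dicts) matches what the function builds internally.
#
#   response, history, transcript = generate_text(
#       "Write a haiku about programming.",
#       temperature=0.7, top_p=0.9, top_k=40, max_new_tokens=64, history=[]
#   )
#   print(transcript)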

# Example prompts
examples = [
    "Write a short story about a snowflake that comes to life.",
    "Explain the concept of artificial neural networks to a 10-year-old.",
    "What are some interesting applications of natural language processing?",
    "Write a haiku about programming.",
    "Create a dialogue between two AI researchers discussing the future of language models."
]

# Main function
def create_demo():
    with gr.Blocks(css=css) as demo:
        # Header
        gr.HTML("""
        <div class="header">
            <h1><span class="snowflake-icon">❄️</span> Snowflake-G0-Release Demo</h1>
            <p>Experience the capabilities of the Snowflake-G0-Release language model</p>
        </div>
        """)
        
        # Model info
        with gr.Accordion("About Snowflake-G0-Release", open=False):
            gr.Markdown("""
            ## Snowflake-G0-Release
            
            This is the initial release of the Snowflake series language models, trained on the DialogMLM-50K dataset with optimized memory usage.
            
            ### Model details
            - Architecture: SnowflakeCore
            - Hidden size: 384
            - Number of attention heads: 6
            - Number of layers: 4
            - Feed-forward dimension: 768
            - Maximum sequence length: 384
            - Vocabulary size: 30522 (BERT tokenizer)
            
            ### Key Features
            - Efficient memory usage
            - Fused QKV projection for faster inference
            - Pre-norm architecture for stable training
            - Compatible with HuggingFace Transformers
            """)
        
        # Chat interface
        with gr.Column():
            chat_history_display = gr.Textbox(
                value="", 
                label="Conversation History", 
                lines=10, 
                max_lines=30,
                interactive=False
            )
            
            # Invisible state variables
            history_state = gr.State([])
            
            # Input and output
            with gr.Row():
                with gr.Column(scale=4):
                    prompt = gr.Textbox(
                        placeholder="Type your message here...", 
                        label="Your Input",
                        lines=2
                    )
                with gr.Column(scale=1):
                    submit_btn = gr.Button("Send", variant="primary")
                    clear_btn = gr.Button("Clear Conversation")
            
            response_output = gr.Textbox(
                value="", 
                label="Model Response", 
                lines=5,
                max_lines=10,
                interactive=False
            )
        
        # Advanced parameters
        with gr.Accordion("Generation Parameters", open=False):
            with gr.Column(elem_classes="parameter-section"):
                with gr.Row():
                    with gr.Column():
                        temperature = gr.Slider(
                            minimum=TEMPERATURE_MIN, 
                            maximum=TEMPERATURE_MAX,
                            value=TEMPERATURE_DEFAULT,
                            step=0.05,
                            label="Temperature",
                            info="Higher = more creative, Lower = more deterministic"
                        )
                        
                        top_p = gr.Slider(
                            minimum=TOP_P_MIN,
                            maximum=TOP_P_MAX,
                            value=TOP_P_DEFAULT,
                            step=0.05,
                            label="Top-p (nucleus sampling)",
                            info="Controls diversity via cumulative probability"
                        )
                    
                    with gr.Column():
                        top_k = gr.Slider(
                            minimum=TOP_K_MIN,
                            maximum=TOP_K_MAX,
                            value=TOP_K_DEFAULT,
                            step=1,
                            label="Top-k",
                            info="Limits word choice to top k options"
                        )
                        
                        max_new_tokens = gr.Slider(
                            minimum=MAX_NEW_TOKENS_MIN,
                            maximum=MAX_NEW_TOKENS_MAX,
                            value=MAX_NEW_TOKENS_DEFAULT,
                            step=8,
                            label="Maximum New Tokens",
                            info="Controls the length of generated response"
                        )
        
        # Examples
        with gr.Accordion("Example Prompts", open=True):
            with gr.Column(elem_classes="example-section"):
                example_btn = gr.Examples(
                    examples=examples,
                    inputs=prompt,
                    label="Click on an example to try it",
                    examples_per_page=5
                )
        
        # Footer
        gr.HTML(f"""
        <div class="footer">
            <p>Snowflake-G0-Release Demo • Created with Gradio • {datetime.datetime.now().year}</p>
        </div>
        """)
        
        # Set up interactions
        submit_btn.click(
            fn=generate_text,
            inputs=[prompt, temperature, top_p, top_k, max_new_tokens, history_state],
            outputs=[response_output, history_state, chat_history_display]
        )
        
        prompt.submit(
            fn=generate_text,
            inputs=[prompt, temperature, top_p, top_k, max_new_tokens, history_state],
            outputs=[response_output, history_state, chat_history_display]
        )
        
        clear_btn.click(
            fn=clear_conversation,
            inputs=[],
            outputs=[prompt, history_state, chat_history_display]
        )
        
    return demo

# Load model and tokenizer, then build the appropriate demo
print("Loading Snowflake-G0-Release model and tokenizer...")
try:
    model, tokenizer, pipeline = load_model_and_tokenizer()
    print("Model loaded successfully!")
    demo = create_demo()
except Exception as e:
    print(f"Error loading model: {str(e)}")
    # Fall back to a simple error page instead of building the full chat UI
    with gr.Blocks(css=css) as error_demo:
        gr.HTML(f"""
        <div class="header" style="background-color: #ffebee;">
            <h1><span class="snowflake-icon">⚠️</span> Error Loading Model</h1>
            <p>There was a problem loading the Snowflake-G0-Release model: {str(e)}</p>
        </div>
        """)
    demo = error_demo

# Launch the app
if __name__ == "__main__":
    demo.launch()