FlameF0X committed on
Commit
3670892
·
verified ·
1 Parent(s): a985b65

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +312 -44
app.py CHANGED
@@ -1,56 +1,324 @@
 
1
  import gradio as gr
2
- from transformers import AutoTokenizer, AutoModelForCausalLM
3
- from modeling_snowflake import Snowflake4CausalLM
4
  import torch
 
 
5
 
6
- # Load tokenizer and model
7
- MODEL_NAME = "FlameF0X/Snowflake-G0-stable"
8
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
9
- model = AutoModelForCausalLM.from_pretrained(
10
- MODEL_NAME,
11
- torch_dtype=torch.float16 # Use half precision for memory efficiency
12
- )
13
- model.eval()
14
- model.to("cuda" if torch.cuda.is_available() else "cpu")
15
-
16
- # --- Inference Function ---
17
- def generate_text(prompt, max_length=50):
18
- """
19
- Generate text based on the input prompt using the trained model.
20
- """
21
- # Tokenize the input prompt
22
- inputs = tokenizer(prompt, return_tensors="pt", truncation=True, padding=True, max_length=384)
23
- input_ids = inputs["input_ids"].to(model.device)
24
- attention_mask = inputs["attention_mask"].to(model.device)
25
-
26
- # Generate output tokens
27
- with torch.no_grad():
28
- outputs = model.generate(
29
- input_ids=input_ids,
30
- attention_mask=attention_mask,
31
- max_length=max_length,
32
- pad_token_id=tokenizer.eos_token_id # Use EOS token for padding
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
- # Decode the generated tokens
36
- generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
37
- return generated_text
38
 
39
- # --- Gradio Interface ---
40
- with gr.Blocks() as demo:
41
- gr.Markdown("# Snowflake-G0-stable")
42
- gr.Markdown("")
43
 
44
- with gr.Row():
45
- input_prompt = gr.Textbox(label="Input Prompt", placeholder="Enter your text here...")
46
- output_text = gr.Textbox(label="Generated Text")
 
 
 
 
 
47
 
48
- submit_button = gr.Button("Generate")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
- def on_submit(prompt):
51
- return generate_text(prompt)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
- submit_button.click(on_submit, inputs=input_prompt, outputs=output_text)
 
54
 
55
  # Launch the app
56
- demo.launch()
 
 
1
+ import os
2
  import gradio as gr
 
 
3
  import torch
4
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TextGenerationPipeline
5
+ import datetime
6
 
7
+ # Model Constants
8
# Identifier passed to `from_pretrained` for both the tokenizer and the model.
MODEL_ID = "Snowflake-G0-Release"  # Replace with actual HF repo when published
# Default `max_length` handed to the text-generation pipeline at construction.
MAX_LENGTH = 384

# Bounds and defaults for the generation sliders, grouped as (min, max, default).
TEMPERATURE_MIN, TEMPERATURE_MAX, TEMPERATURE_DEFAULT = 0.1, 2.0, 0.7
TOP_P_MIN, TOP_P_MAX, TOP_P_DEFAULT = 0.1, 1.0, 0.9
TOP_K_MIN, TOP_K_MAX, TOP_K_DEFAULT = 1, 100, 40
# NOTE(review): MAX_NEW_TOKENS_MAX (1024) is larger than MAX_LENGTH (384);
# confirm the model can actually generate that far before shipping.
MAX_NEW_TOKENS_MIN, MAX_NEW_TOKENS_MAX, MAX_NEW_TOKENS_DEFAULT = 16, 1024, 256
22
+
23
+ # CSS for the app
24
+ css = """
25
+ .gradio-container {
26
+ background-color: #f0f8ff !important;
27
+ }
28
+ .header {
29
+ background-color: #e6f2ff;
30
+ padding: 20px;
31
+ margin-bottom: 20px;
32
+ border-radius: 10px;
33
+ text-align: center;
34
+ }
35
+ .header h1 {
36
+ color: #0066cc;
37
+ margin-bottom: 10px;
38
+ }
39
+ .snowflake-icon {
40
+ font-size: 24px;
41
+ margin-right: 10px;
42
+ }
43
+ .footer {
44
+ text-align: center;
45
+ margin-top: 20px;
46
+ font-size: 0.9em;
47
+ color: #666;
48
+ }
49
+ .parameter-section {
50
+ background-color: #e6f7ff;
51
+ padding: 15px;
52
+ border-radius: 8px;
53
+ margin-bottom: 15px;
54
+ }
55
+ .parameter-section h3 {
56
+ margin-top: 0;
57
+ color: #0066cc;
58
+ }
59
+ .example-section {
60
+ background-color: #e6fffa;
61
+ padding: 15px;
62
+ border-radius: 8px;
63
+ margin-bottom: 15px;
64
+ }
65
+ .example-section h3 {
66
+ margin-top: 0;
67
+ color: #00997a;
68
+ }
69
+ """
70
+
71
+ # Helper functions
72
def load_model_and_tokenizer():
    """Load the tokenizer and causal LM for ``MODEL_ID`` and wrap them in a pipeline.

    Returns:
        A ``(model, tokenizer, pipeline)`` tuple, where ``pipeline`` is a
        ``TextGenerationPipeline`` configured with ``return_full_text=False``
        and ``max_length=MAX_LENGTH``.
    """
    tok = AutoTokenizer.from_pretrained(MODEL_ID)
    # Fall back to the EOS token when the tokenizer defines no pad token.
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    # Half precision only when CUDA is available; full precision otherwise.
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    lm = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=dtype,
        device_map="auto",
    )

    generator = TextGenerationPipeline(
        model=lm,
        tokenizer=tok,
        return_full_text=False,
        max_length=MAX_LENGTH,
    )
    return lm, tok, generator
94
+
95
def _format_history(entries):
    """Render chat entries as a blank-line-separated, role-prefixed transcript."""
    return "\n\n".join(
        ("👤 User: " if entry["role"] == "user" else "❄️ Snowflake: ") + str(entry["content"])
        for entry in entries
    )


def generate_text(
    prompt,
    temperature=TEMPERATURE_DEFAULT,
    top_p=TOP_P_DEFAULT,
    top_k=TOP_K_DEFAULT,
    max_new_tokens=MAX_NEW_TOKENS_DEFAULT,
    history=None,
):
    """Generate a reply to ``prompt`` and append the exchange to the chat history.

    Args:
        prompt: Text sent to the module-level ``pipeline``.
        temperature: Sampling temperature; sampling is enabled only when > 0.
        top_p: Nucleus-sampling threshold forwarded to the pipeline.
        top_k: Top-k cutoff forwarded to the pipeline (coerced to ``int``).
        max_new_tokens: Generation length limit (coerced to ``int``).
        history: Optional list of ``{"role", "content"}`` dicts from earlier
            turns. It is copied, never mutated in place.

    Returns:
        A ``(response, history, formatted_history)`` tuple: the generated text
        (or an error message), the updated history list, and the transcript
        string for display. A blank prompt returns an empty response and the
        history unchanged without calling the model.
    """
    # Copy so the caller's list is not modified behind its back.
    history = list(history) if history else []

    # Nothing to generate for an empty / whitespace-only prompt.
    if not prompt or not prompt.strip():
        return "", history, _format_history(history)

    history.append({"role": "user", "content": prompt})

    try:
        outputs = pipeline(
            prompt,
            do_sample=temperature > 0,
            temperature=temperature,
            top_p=top_p,
            # UI sliders may hand back floats; these two must be integers.
            top_k=int(top_k),
            max_new_tokens=int(max_new_tokens),
            pad_token_id=tokenizer.pad_token_id,
            num_return_sequences=1,
        )
        response = outputs[0]["generated_text"]
    except Exception as e:
        # UI boundary: surface the failure in the chat instead of crashing the app.
        response = f"Error generating response: {str(e)}"
        history.append({"role": "assistant", "content": f"[ERROR] {response}"})
    else:
        history.append({"role": "assistant", "content": response})

    # Success and error paths share the same transcript formatting.
    return response, history, _format_history(history)
139
 
140
def clear_conversation():
    """Return blank values that reset the conversation.

    Returns:
        A ``(prompt, history, display)`` tuple: an empty string, a new empty
        list, and an empty string.
    """
    blank_prompt = ""
    fresh_history = []
    blank_display = ""
    return blank_prompt, fresh_history, blank_display
 
142
 
143
def apply_preset_example(example, history):
    """Return the chosen example prompt paired with the untouched history.

    Args:
        example: The example prompt text.
        history: The current chat history, passed through as-is.

    Returns:
        The ``(example, history)`` tuple.
    """
    selection = (example, history)
    return selection
 
 
145
 
146
+ # Example prompts
147
+ examples = [
148
+ "Write a short story about a snowflake that comes to life.",
149
+ "Explain the concept of artificial neural networks to a 10-year-old.",
150
+ "What are some interesting applications of natural language processing?",
151
+ "Write a haiku about programming.",
152
+ "Create a dialogue between two AI researchers discussing the future of language models."
153
+ ]
154
 
155
+ # Main function
156
def create_demo():
    """Build the Gradio Blocks UI for the Snowflake-G0-Release demo.

    The layout has a header, a collapsible model-info panel, the chat area
    (history display, prompt box, Send/Clear buttons, latest response),
    collapsible generation-parameter sliders, example prompts, and a footer.
    The Send button and pressing Enter in the prompt box both call
    ``generate_text``. The Clear button resets the prompt, the history state,
    the history display, and the last model response.

    Returns:
        The constructed ``gr.Blocks`` instance; the caller launches it.
    """
    with gr.Blocks(css=css) as demo:
        # Header
        gr.HTML("""
        <div class="header">
        <h1><span class="snowflake-icon">❄️</span> Snowflake-G0-Release Demo</h1>
        <p>Experience the capabilities of the Snowflake-G0-Release language model</p>
        </div>
        """)

        # Model info
        with gr.Accordion("About Snowflake-G0-Release", open=False):
            gr.Markdown("""
            ## Snowflake-G0-Release

            This is the initial release of the Snowflake series language models, trained on the DialogMLM-50K dataset with optimized memory usage.

            ### Model details
            - Architecture: SnowflakeCore
            - Hidden size: 384
            - Number of attention heads: 6
            - Number of layers: 4
            - Feed-forward dimension: 768
            - Maximum sequence length: 384
            - Vocabulary size: 30522 (BERT tokenizer)

            ### Key Features
            - Efficient memory usage
            - Fused QKV projection for faster inference
            - Pre-norm architecture for stable training
            - Compatible with HuggingFace Transformers
            """)

        # Chat interface
        with gr.Column():
            chat_history_display = gr.Textbox(
                value="",
                label="Conversation History",
                lines=10,
                max_lines=30,
                interactive=False,
            )

            # Invisible per-session state holding the list of chat entries
            history_state = gr.State([])

            # Input and output
            with gr.Row():
                with gr.Column(scale=4):
                    prompt = gr.Textbox(
                        placeholder="Type your message here...",
                        label="Your Input",
                        lines=2,
                    )
                with gr.Column(scale=1):
                    submit_btn = gr.Button("Send", variant="primary")
                    clear_btn = gr.Button("Clear Conversation")

            response_output = gr.Textbox(
                value="",
                label="Model Response",
                lines=5,
                max_lines=10,
                interactive=False,
            )

        # Advanced parameters
        with gr.Accordion("Generation Parameters", open=False):
            with gr.Column(elem_classes="parameter-section"):
                with gr.Row():
                    with gr.Column():
                        temperature = gr.Slider(
                            minimum=TEMPERATURE_MIN,
                            maximum=TEMPERATURE_MAX,
                            value=TEMPERATURE_DEFAULT,
                            step=0.05,
                            label="Temperature",
                            info="Higher = more creative, Lower = more deterministic",
                        )

                        top_p = gr.Slider(
                            minimum=TOP_P_MIN,
                            maximum=TOP_P_MAX,
                            value=TOP_P_DEFAULT,
                            step=0.05,
                            label="Top-p (nucleus sampling)",
                            info="Controls diversity via cumulative probability",
                        )

                    with gr.Column():
                        top_k = gr.Slider(
                            minimum=TOP_K_MIN,
                            maximum=TOP_K_MAX,
                            value=TOP_K_DEFAULT,
                            step=1,
                            label="Top-k",
                            info="Limits word choice to top k options",
                        )

                        max_new_tokens = gr.Slider(
                            minimum=MAX_NEW_TOKENS_MIN,
                            maximum=MAX_NEW_TOKENS_MAX,
                            value=MAX_NEW_TOKENS_DEFAULT,
                            step=8,
                            label="Maximum New Tokens",
                            info="Controls the length of generated response",
                        )

        # Examples: clicking one fills the prompt box
        with gr.Accordion("Example Prompts", open=True):
            with gr.Column(elem_classes="example-section"):
                gr.Examples(
                    examples=examples,
                    inputs=prompt,
                    label="Click on an example to try it",
                    examples_per_page=5,
                )

        # Footer
        gr.HTML(f"""
        <div class="footer">
        <p>Snowflake-G0-Release Demo • Created with Gradio • {datetime.datetime.now().year}</p>
        </div>
        """)

        # Set up interactions. Send and Enter share one wiring so they cannot drift.
        generation_inputs = [prompt, temperature, top_p, top_k, max_new_tokens, history_state]
        generation_outputs = [response_output, history_state, chat_history_display]
        for register in (submit_btn.click, prompt.submit):
            register(
                fn=generate_text,
                inputs=generation_inputs,
                outputs=generation_outputs,
            )

        clear_btn.click(
            fn=clear_conversation,
            inputs=[],
            outputs=[prompt, history_state, chat_history_display],
        )
        # clear_conversation resets only three components; also blank the last
        # model response so no stale text remains after clearing.
        clear_btn.click(
            fn=lambda: "",
            inputs=[],
            outputs=[response_output],
        )

    return demo
301
 
302
# Load model and tokenizer, then pick which UI to serve.
print("Loading Snowflake-G0-Release model and tokenizer...")
try:
    model, tokenizer, pipeline = load_model_and_tokenizer()
except Exception as e:
    # Top-level boundary: any load failure falls back to a static error page.
    from html import escape

    print(f"Error loading model: {str(e)}")
    # Create a simple error demo if model fails to load
    with gr.Blocks(css=css) as error_demo:
        # Escape the message so characters like "<" cannot break the markup.
        gr.HTML(f"""
        <div class="header" style="background-color: #ffebee;">
        <h1><span class="snowflake-icon">⚠️</span> Error Loading Model</h1>
        <p>There was a problem loading the Snowflake-G0-Release model: {escape(str(e))}</p>
        </div>
        """)
    demo = error_demo
else:
    print("Model loaded successfully!")
    # Build the full demo only on success, so the error page above is not
    # overwritten when loading fails.
    demo = create_demo()

# Launch the app
if __name__ == "__main__":
    demo.launch()