FlameF0X committed
Commit 1467791 · verified · 1 Parent(s): a7311db

Update app.py

Files changed (1):
  app.py  +101 -231
app.py CHANGED
@@ -1,177 +1,85 @@
  import os
  import torch
  import gradio as gr
- from transformers import AutoTokenizer, AutoModelForCausalLM, TextGenerationPipeline
- from safetensors.torch import load_file  # Import safetensors for loading .safetensors models
  import datetime
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TextGenerationPipeline
+ from safetensors.torch import load_file
+
+ # Constants
+ MODEL_CONFIG = {
+     "G0-Release": "FlameF0X/Snowflake-G0-Release",
+     "G0-Release-2": "FlameF0X/Snowflake-G0-Release-2",
+     "G0-Release-2.5": "FlameF0X/Snowflake-G0-Release-2.5"
+ }

- # Model Constants
- MODEL_ID_V1 = "FlameF0X/Snowflake-G0-Release"
- MODEL_ID_V2 = "FlameF0X/Snowflake-G0-Release-2"
- MODEL_ID_V3 = "FlameF0X/Snowflake-G0-Release-2.5"
  MAX_LENGTH = 384
- TEMPERATURE_MIN = 0.1
- TEMPERATURE_MAX = 2.0
  TEMPERATURE_DEFAULT = 0.7
- TOP_P_MIN = 0.1
- TOP_P_MAX = 1.0
  TOP_P_DEFAULT = 0.9
- TOP_K_MIN = 1
- TOP_K_MAX = 100
  TOP_K_DEFAULT = 40
- MAX_NEW_TOKENS_MIN = 16
- MAX_NEW_TOKENS_MAX = 1024
  MAX_NEW_TOKENS_DEFAULT = 256

- # CSS for the app
+ # UI parameter bounds
+ TEMPERATURE_MIN, TEMPERATURE_MAX = 0.1, 2.0
+ TOP_P_MIN, TOP_P_MAX = 0.1, 1.0
+ TOP_K_MIN, TOP_K_MAX = 1, 100
+ MAX_NEW_TOKENS_MIN, MAX_NEW_TOKENS_MAX = 16, 1024
+
+ # Styling
  css = """
- .gradio-container {
-     background-color: #1e1e2f !important;
-     color: #e0e0e0 !important;
- }
- .header {
-     background-color: #2b2b3c;
-     padding: 20px;
-     margin-bottom: 20px;
-     border-radius: 10px;
-     text-align: center;
- }
- .header h1 {
-     color: #66ccff;
-     margin-bottom: 10px;
- }
- .snowflake-icon {
-     font-size: 24px;
-     margin-right: 10px;
- }
- .footer {
-     text-align: center;
-     margin-top: 20px;
-     font-size: 0.9em;
-     color: #999;
- }
- .parameter-section {
-     background-color: #2a2a3a;
-     padding: 15px;
-     border-radius: 8px;
-     margin-bottom: 15px;
- }
- .parameter-section h3 {
-     margin-top: 0;
-     color: #66ccff;
- }
- .example-section {
-     background-color: #223344;
-     padding: 15px;
-     border-radius: 8px;
-     margin-bottom: 15px;
- }
- .example-section h3 {
-     margin-top: 0;
-     color: #66ffaa;
- }
- .model-select {
-     background-color: #2a2a4a;
-     padding: 10px;
-     border-radius: 8px;
-     margin-bottom: 15px;
- }
+ .gradio-container { background-color: #1e1e2f !important; color: #e0e0e0 !important; }
+ .header { background-color: #2b2b3c; padding: 20px; margin-bottom: 20px; border-radius: 10px; text-align: center; }
+ .header h1 { color: #66ccff; margin-bottom: 10px; }
+ .snowflake-icon { font-size: 24px; margin-right: 10px; }
+ .footer { text-align: center; margin-top: 20px; font-size: 0.9em; color: #999; }
+ .parameter-section { background-color: #2a2a3a; padding: 15px; border-radius: 8px; margin-bottom: 15px; }
+ .parameter-section h3 { margin-top: 0; color: #66ccff; }
+ .example-section { background-color: #223344; padding: 15px; border-radius: 8px; margin-bottom: 15px; }
+ .example-section h3 { margin-top: 0; color: #66ffaa; }
+ .model-select { background-color: #2a2a4a; padding: 10px; border-radius: 8px; margin-bottom: 15px; }
  """

- # Global variables for models and tokenizers
- model_v1 = None
- tokenizer_v1 = None
- pipeline_v1 = None
- model_v2 = None
- tokenizer_v2 = None
- pipeline_v2 = None
+ # Model registry
+ model_registry = {}

- # Helper functions to load models
- def load_models_and_tokenizers():
-     global model_v1, tokenizer_v1, pipeline_v1, model_v2, tokenizer_v2, pipeline_v2
-
-     # Load the first model
-     print(f"Loading model from {MODEL_ID_V1}...")
-     tokenizer_v1 = AutoTokenizer.from_pretrained(MODEL_ID_V1)
-     if tokenizer_v1.pad_token is None:
-         tokenizer_v1.pad_token = tokenizer_v1.eos_token
+ def load_all_models():
+     for name, model_id in MODEL_CONFIG.items():
+         print(f"Loading model: {name} from {model_id}")
+         tokenizer = AutoTokenizer.from_pretrained(model_id)
+         if tokenizer.pad_token is None:
+             tokenizer.pad_token = tokenizer.eos_token

-     model_file_path = os.path.join(MODEL_ID_V1, "model.safetensors")
-
-     if os.path.exists(model_file_path):
-         print("Loading model from safetensors file...")
-         model_v1 = load_file(model_file_path)
-     else:
-         print("Loading model from .bin file...")
-         model_v1 = AutoModelForCausalLM.from_pretrained(
-             MODEL_ID_V1,
-             torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
-             device_map="auto"
-         )
-
-     pipeline_v1 = TextGenerationPipeline(
-         model=model_v1,
-         tokenizer=tokenizer_v1,
-         return_full_text=False,
-         max_length=MAX_LENGTH
-     )
-
-     # Load the second model
-     print(f"Loading model from {MODEL_ID_V2}...")
-     tokenizer_v2 = AutoTokenizer.from_pretrained(MODEL_ID_V2)
-     if tokenizer_v2.pad_token is None:
-         tokenizer_v2.pad_token = tokenizer_v2.eos_token
+         safetensor_path = os.path.join(model_id, "model.safetensors")
+         if os.path.exists(safetensor_path):
+             print("Loading from safetensors...")
+             model = load_file(safetensor_path)
+         else:
+             print("Loading from Hugging Face or .bin...")
+             model = AutoModelForCausalLM.from_pretrained(
+                 model_id,
+                 torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+                 device_map="auto"
+             )

-     model_file_path = os.path.join(MODEL_ID_V2, "model.safetensors")
-
-     if os.path.exists(model_file_path):
-         print("Loading model from safetensors file...")
-         model_v2 = load_file(model_file_path)
-     else:
-         print("Loading model from .bin file...")
-         model_v2 = AutoModelForCausalLM.from_pretrained(
-             MODEL_ID_V2,
-             torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
-             device_map="auto"
+         pipeline = TextGenerationPipeline(
+             model=model,
+             tokenizer=tokenizer,
+             return_full_text=False,
+             max_length=MAX_LENGTH
          )
-
-     pipeline_v2 = TextGenerationPipeline(
-         model=model_v2,
-         tokenizer=tokenizer_v2,
-         return_full_text=False,
-         max_length=MAX_LENGTH
-     )
-
-     return (model_v1, tokenizer_v1, pipeline_v1), (model_v2, tokenizer_v2, pipeline_v2)

- # Helper functions for generation
- def generate_text(
-     prompt,
-     model_version,
-     temperature=TEMPERATURE_DEFAULT,
-     top_p=TOP_P_DEFAULT,
-     top_k=TOP_K_DEFAULT,
-     max_new_tokens=MAX_NEW_TOKENS_DEFAULT,
-     history=None
- ):
+         model_registry[name] = (model, tokenizer, pipeline)
+
+ def generate_text(prompt, model_version, temperature, top_p, top_k, max_new_tokens, history=None):
      if history is None:
          history = []
-
-     # Add current prompt to history
      history.append({"role": "user", "content": prompt})

      try:
-         # Select the appropriate pipeline based on model version
-         if model_version == "G0-Release":
-             pipeline = pipeline_v1
-             tokenizer = tokenizer_v1
-             model_name = "Snowflake-G0-Release"
-         else:  # "G0-Release-2"
-             pipeline = pipeline_v2
-             tokenizer = tokenizer_v2
-             model_name = "Snowflake-G0-Release-2"
-
-         # Generate response
+         if model_version not in model_registry:
+             raise ValueError(f"Model '{model_version}' not found.")
+
+         _, tokenizer, pipeline = model_registry[model_version]
+
          outputs = pipeline(
              prompt,
              do_sample=temperature > 0,
@@ -182,22 +90,15 @@ def generate_text(
              pad_token_id=tokenizer.pad_token_id,
              num_return_sequences=1
          )
-
+
          response = outputs[0]["generated_text"]
-
-         # Add model response to history
-         history.append({"role": "assistant", "content": response, "model": model_name})
-
-         # Format chat history for display
+         history.append({"role": "assistant", "content": response, "model": model_version})
+
          formatted_history = []
          for entry in history:
-             if entry["role"] == "user":
-                 role_prefix = "👤 User: "
-             else:
-                 model_indicator = f"[{entry.get('model', 'Snowflake')}]"
-                 role_prefix = f"❄️ {model_indicator}: "
-             formatted_history.append(f"{role_prefix}{entry['content']}")
-
+             prefix = "👤 User: " if entry["role"] == "user" else f"❄️ [{entry.get('model', 'Model')}]: "
+             formatted_history.append(f"{prefix}{entry['content']}")
+
          return response, history, "\n\n".join(formatted_history)

      except Exception as e:
@@ -208,19 +109,6 @@ def generate_text(
  def clear_conversation():
      return "", [], ""

- def apply_preset_example(example, history):
-     return example, history
-
- # Example prompts
- examples = [
-     "Write a short story about a snowflake that comes to life.",
-     "Explain the concept of artificial neural networks to a 10-year-old.",
-     "What are some interesting applications of natural language processing?",
-     "Write a haiku about programming.",
-     "Create a dialogue between two AI researchers discussing the future of language models."
- ]
-
- # Main function
  def create_demo():
      with gr.Blocks(css=css) as demo:
          # Header
@@ -231,14 +119,12 @@ def create_demo():
          </div>
          """)

-         # Chat interface
          with gr.Column():
-             # Model selection
              with gr.Row(elem_classes="model-select"):
                  model_version = gr.Radio(
-                     ["G0-Release", "G0-Release-2"],
+                     choices=list(MODEL_CONFIG.keys()),
+                     value=list(MODEL_CONFIG.keys())[0],
                      label="Select Model Version",
-                     value="G0-Release-2",
                      info="Choose which Snowflake model to use"
                  )

@@ -250,10 +136,8 @@
                  interactive=False
              )

-             # Invisible state variables
              history_state = gr.State([])

-             # Input and output
              with gr.Row():
                  with gr.Column(scale=4):
                      prompt = gr.Textbox(
@@ -264,7 +148,7 @@
                  with gr.Column(scale=1):
                      submit_btn = gr.Button("Send", variant="primary")
                      clear_btn = gr.Button("Clear Conversation")
-
+
              response_output = gr.Textbox(
                  value="",
                  label="Model Response",
@@ -273,106 +157,92 @@
                  interactive=False
              )

-             # Advanced parameters
+             # Generation Parameters
              with gr.Accordion("Generation Parameters", open=False):
                  with gr.Column(elem_classes="parameter-section"):
                      with gr.Row():
                          with gr.Column():
                              temperature = gr.Slider(
-                                 minimum=TEMPERATURE_MIN,
-                                 maximum=TEMPERATURE_MAX,
-                                 value=TEMPERATURE_DEFAULT,
-                                 step=0.05,
-                                 label="Temperature",
-                                 info="Higher = more creative, Lower = more deterministic"
+                                 minimum=TEMPERATURE_MIN, maximum=TEMPERATURE_MAX,
+                                 value=TEMPERATURE_DEFAULT, step=0.05,
+                                 label="Temperature"
                              )
-
                              top_p = gr.Slider(
-                                 minimum=TOP_P_MIN,
-                                 maximum=TOP_P_MAX,
-                                 value=TOP_P_DEFAULT,
-                                 step=0.05,
-                                 label="Top-p (nucleus sampling)",
-                                 info="Controls diversity via cumulative probability"
+                                 minimum=TOP_P_MIN, maximum=TOP_P_MAX,
+                                 value=TOP_P_DEFAULT, step=0.05,
+                                 label="Top-p"
                              )
-
                          with gr.Column():
                              top_k = gr.Slider(
-                                 minimum=TOP_K_MIN,
-                                 maximum=TOP_K_MAX,
-                                 value=TOP_K_DEFAULT,
-                                 step=1,
-                                 label="Top-k",
-                                 info="Limits word choice to top k options"
+                                 minimum=TOP_K_MIN, maximum=TOP_K_MAX,
+                                 value=TOP_K_DEFAULT, step=1,
+                                 label="Top-k"
                              )
-
                              max_new_tokens = gr.Slider(
-                                 minimum=MAX_NEW_TOKENS_MIN,
-                                 maximum=MAX_NEW_TOKENS_MAX,
-                                 value=MAX_NEW_TOKENS_DEFAULT,
-                                 step=8,
-                                 label="Maximum New Tokens",
-                                 info="Controls the length of generated response"
+                                 minimum=MAX_NEW_TOKENS_MIN, maximum=MAX_NEW_TOKENS_MAX,
+                                 value=MAX_NEW_TOKENS_DEFAULT, step=8,
+                                 label="Maximum New Tokens"
                              )

-             # Examples
+             # Example prompts
+             examples = [
+                 "Write a short story about a snowflake that comes to life.",
+                 "Explain the concept of artificial neural networks to a 10-year-old.",
+                 "What are some interesting applications of natural language processing?",
+                 "Write a haiku about programming.",
+                 "Create a dialogue between two AI researchers discussing the future of language models."
+             ]
+
              with gr.Accordion("Example Prompts", open=True):
                  with gr.Column(elem_classes="example-section"):
-                     example_btn = gr.Examples(
+                     gr.Examples(
                          examples=examples,
                          inputs=prompt,
                          label="Click on an example to try it",
                          examples_per_page=5
                      )

-         # Footer
          gr.HTML(f"""
          <div class="footer">
              <p>Snowflake Models Demo • Created with Gradio • {datetime.datetime.now().year}</p>
          </div>
          """)

-         # Set up interactions
+         # Interactions
          submit_btn.click(
              fn=generate_text,
              inputs=[prompt, model_version, temperature, top_p, top_k, max_new_tokens, history_state],
              outputs=[response_output, history_state, chat_history_display]
          )
-
          prompt.submit(
              fn=generate_text,
              inputs=[prompt, model_version, temperature, top_p, top_k, max_new_tokens, history_state],
              outputs=[response_output, history_state, chat_history_display]
          )
-
          clear_btn.click(
              fn=clear_conversation,
              inputs=[],
              outputs=[prompt, history_state, chat_history_display]
          )
-
+
          return demo

- # Load models and tokenizers
- print("Loading Snowflake models and tokenizers...")
+ # Initialize
+ print("Loading Snowflake models...")
  try:
-     (model_v1, tokenizer_v1, pipeline_v1), (model_v2, tokenizer_v2, pipeline_v2) = load_models_and_tokenizers()
-     print("Models loaded successfully!")
+     load_all_models()
+     print("All models loaded successfully!")
+     demo = create_demo()
  except Exception as e:
-     print(f"Error loading models: {str(e)}")
-     # Create a simple error demo if models fail to load
-     with gr.Blocks(css=css) as error_demo:
+     print(f"Error loading models: {e}")
+     with gr.Blocks(css=css) as demo:
          gr.HTML(f"""
          <div class="header" style="background-color: #ffebee;">
              <h1><span class="snowflake-icon">⚠️</span> Error Loading Models</h1>
              <p>There was a problem loading the Snowflake models: {str(e)}</p>
          </div>
          """)
-         demo = error_demo
-
-     # Create and launch the demo
-     demo = create_demo()

- # Launch the app
+ # Run app
  if __name__ == "__main__":
-     demo.launch()
+     demo.launch()
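
Review note: in load_all_models(), the safetensors branch assigns the raw return value of safetensors.torch.load_file() to model. That function returns a plain state dict (dict[str, torch.Tensor]), not a model object, so any repo that actually hit this branch would hand a dict to TextGenerationPipeline and fail at generation time. In practice the entries in MODEL_CONFIG are Hub repo ids rather than local directories, so os.path.exists(...) is false and the from_pretrained branch runs; from_pretrained already resolves the checkpoint format itself, including model.safetensors when the repo ships one. A minimal single-path loader could look like the sketch below (load_model is an illustrative name, not part of this commit):

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, TextGenerationPipeline

    def load_model(model_id: str, max_length: int = 384):
        # Causal LMs often ship without a pad token; fall back to EOS.
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # from_pretrained picks the checkpoint format itself (safetensors or .bin),
        # so no manual load_file() branch is needed.
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            device_map="auto",  # requires the accelerate package
        )
        pipeline = TextGenerationPipeline(
            model=model,
            tokenizer=tokenizer,
            return_full_text=False,
            max_length=max_length,
        )
        return model, tokenizer, pipeline

With that, the loop body in load_all_models() reduces to model_registry[name] = load_model(model_id). A smaller observation: since the temperature slider is floored at TEMPERATURE_MIN = 0.1, do_sample=temperature > 0 is always true as wired; the guard only matters if a temperature of 0 is ever allowed.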