FlameF0X committed (verified) · Commit ea1d17e · Parent: e09ff9d

Update app.py

Files changed (1):
  1. app.py +101 -62
app.py CHANGED
@@ -6,7 +6,8 @@ from safetensors.torch import load_file # Import safetensors for loading .safet
 import datetime
 
 # Model Constants
-MODEL_ID = "FlameF0X/Snowflake-G0-Release" # HF repo when published
+MODEL_ID_V1 = "FlameF0X/Snowflake-G0-Release"
+MODEL_ID_V2 = "FlameF0X/Snowflake-G0-Release-2"
 MAX_LENGTH = 384
 TEMPERATURE_MIN = 0.1
 TEMPERATURE_MAX = 2.0
@@ -68,46 +69,84 @@ css = """
     margin-top: 0;
     color: #66ffaa;
 }
+.model-select {
+    background-color: #2a2a4a;
+    padding: 10px;
+    border-radius: 8px;
+    margin-bottom: 15px;
+}
 """
 
-# Helper functions to load model
-def load_model_and_tokenizer():
-    global model, tokenizer, pipeline  # Add this line
+# Global variables for models and tokenizers
+model_v1 = None
+tokenizer_v1 = None
+pipeline_v1 = None
+model_v2 = None
+tokenizer_v2 = None
+pipeline_v2 = None
 
-    # Load the tokenizer
-    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
+# Helper functions to load models
+def load_models_and_tokenizers():
+    global model_v1, tokenizer_v1, pipeline_v1, model_v2, tokenizer_v2, pipeline_v2
 
-    # Check if the pad_token is None, set it to eos_token if needed
-    if tokenizer.pad_token is None:
-        tokenizer.pad_token = tokenizer.eos_token
+    # Load the first model
+    print(f"Loading model from {MODEL_ID_V1}...")
+    tokenizer_v1 = AutoTokenizer.from_pretrained(MODEL_ID_V1)
+    if tokenizer_v1.pad_token is None:
+        tokenizer_v1.pad_token = tokenizer_v1.eos_token
 
-    # Check if the model uses safetensors or pytorch .bin model file
-    model_file_path = os.path.join(MODEL_ID, "model.safetensors")  # or model.bin if that's the case
+    model_file_path = os.path.join(MODEL_ID_V1, "model.safetensors")
+
+    if os.path.exists(model_file_path):
+        print("Loading model from safetensors file...")
+        model_v1 = load_file(model_file_path)
+    else:
+        print("Loading model from .bin file...")
+        model_v1 = AutoModelForCausalLM.from_pretrained(
+            MODEL_ID_V1,
+            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+            device_map="auto"
+        )
+
+    pipeline_v1 = TextGenerationPipeline(
+        model=model_v1,
+        tokenizer=tokenizer_v1,
+        return_full_text=False,
+        max_length=MAX_LENGTH
+    )
+
+    # Load the second model
+    print(f"Loading model from {MODEL_ID_V2}...")
+    tokenizer_v2 = AutoTokenizer.from_pretrained(MODEL_ID_V2)
+    if tokenizer_v2.pad_token is None:
+        tokenizer_v2.pad_token = tokenizer_v2.eos_token
 
+    model_file_path = os.path.join(MODEL_ID_V2, "model.safetensors")
+
     if os.path.exists(model_file_path):
-        # Check if safetensors file exists
         print("Loading model from safetensors file...")
-        model = load_file(model_file_path)  # Safetensors loading
+        model_v2 = load_file(model_file_path)
     else:
-        # Load from standard .bin file
         print("Loading model from .bin file...")
-        model = AutoModelForCausalLM.from_pretrained(MODEL_ID,
-            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
-            device_map="auto")
+        model_v2 = AutoModelForCausalLM.from_pretrained(
+            MODEL_ID_V2,
+            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+            device_map="auto"
+        )
 
-    # Initialize the generation pipeline
-    pipeline = TextGenerationPipeline(
-        model=model,
-        tokenizer=tokenizer,
+    pipeline_v2 = TextGenerationPipeline(
+        model=model_v2,
+        tokenizer=tokenizer_v2,
         return_full_text=False,
         max_length=MAX_LENGTH
    )
 
-    return model, tokenizer, pipeline
+    return (model_v1, tokenizer_v1, pipeline_v1), (model_v2, tokenizer_v2, pipeline_v2)
 
 # Helper functions for generation
 def generate_text(
-    prompt,
+    prompt,
+    model_version,
     temperature=TEMPERATURE_DEFAULT,
     top_p=TOP_P_DEFAULT,
     top_k=TOP_K_DEFAULT,
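
Note on the loading hunk above: safetensors.torch.load_file() returns a plain dict of tensors (a state dict), not a model object, so TextGenerationPipeline(model=model_v1, ...) would fail whenever the safetensors branch is taken. In practice that branch is unlikely to fire on a Space, because os.path.exists() is testing a local path built from a Hub repo id, which only exists in a local checkout. A minimal sketch of an alternative (a hypothetical refactor, not part of this commit) lets from_pretrained resolve the repo id and pick model.safetensors or pytorch_model.bin by itself:

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, TextGenerationPipeline

    def load_pipeline(model_id, max_length=384):
        # Hypothetical helper: from_pretrained reads safetensors natively,
        # so no manual file-format branching is needed.
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token  # reuse EOS as padding
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            device_map="auto",  # requires accelerate, as the original code already assumes
        )
        return TextGenerationPipeline(
            model=model,
            tokenizer=tokenizer,
            return_full_text=False,
            max_length=max_length,
        )

Called once per checkpoint (load_pipeline(MODEL_ID_V1), load_pipeline(MODEL_ID_V2)), this would also collapse the duplicated v1/v2 blocks.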
@@ -121,6 +160,16 @@ def generate_text(
     history.append({"role": "user", "content": prompt})
 
     try:
+        # Select the appropriate pipeline based on model version
+        if model_version == "G0-Release":
+            pipeline = pipeline_v1
+            tokenizer = tokenizer_v1
+            model_name = "Snowflake-G0-Release"
+        else:  # "G0-Release-2"
+            pipeline = pipeline_v2
+            tokenizer = tokenizer_v2
+            model_name = "Snowflake-G0-Release-2"
+
         # Generate response
         outputs = pipeline(
             prompt,
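
Note: the if/else above dispatches on the Radio's display strings, and those strings now live in two places (here and in the gr.Radio choices added later in this diff). A small registry dict (a hypothetical sketch, using the pipeline_v1/pipeline_v2 globals defined above) keeps them in one place and makes a third checkpoint a one-line addition:

    # Keys must match the gr.Radio choices; lambdas defer the lookup until
    # the globals have been populated at startup.
    PIPELINES = {
        "G0-Release":   ("Snowflake-G0-Release",   lambda: pipeline_v1),
        "G0-Release-2": ("Snowflake-G0-Release-2", lambda: pipeline_v2),
    }

    def select_pipeline(model_version):
        # Unknown values fall back to v2, mirroring the else branch above.
        model_name, get_pipe = PIPELINES.get(model_version, PIPELINES["G0-Release-2"])
        return model_name, get_pipe()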
@@ -136,19 +185,23 @@ def generate_text(
         response = outputs[0]["generated_text"]
 
         # Add model response to history
-        history.append({"role": "assistant", "content": response})
+        history.append({"role": "assistant", "content": response, "model": model_name})
 
         # Format chat history for display
         formatted_history = []
         for entry in history:
-            role_prefix = "👤 User: " if entry["role"] == "user" else "❄️ Snowflake: "
+            if entry["role"] == "user":
+                role_prefix = "👤 User: "
+            else:
+                model_indicator = f"[{entry.get('model', 'Snowflake')}]"
+                role_prefix = f"❄️ {model_indicator}: "
             formatted_history.append(f"{role_prefix}{entry['content']}")
 
         return response, history, "\n\n".join(formatted_history)
 
     except Exception as e:
         error_msg = f"Error generating response: {str(e)}"
-        history.append({"role": "assistant", "content": f"[ERROR] {error_msg}"})
+        history.append({"role": "assistant", "content": f"[ERROR] {error_msg}", "model": model_version})
         return error_msg, history, str(history)
 
 def clear_conversation():
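
Note: with the "model" key stored on each assistant turn, the history display tags every reply with the checkpoint that produced it; entries written before this change carry no "model" key and fall back to [Snowflake] via entry.get('model', 'Snowflake'). A two-turn conversation renders roughly as (illustrative text):

    👤 User: Hello!

    ❄️ [Snowflake-G0-Release-2]: Hi there!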
@@ -172,36 +225,22 @@ def create_demo():
         # Header
         gr.HTML("""
         <div class="header">
-            <h1><span class="snowflake-icon">❄️</span> Snowflake-G0-Release Demo</h1>
-            <p>Experience the capabilities of the Snowflake-G0-Release language model</p>
+            <h1><span class="snowflake-icon">❄️</span> Snowflake Models Demo</h1>
+            <p>Experience the capabilities of the Snowflake series language models</p>
         </div>
         """)
 
-        # Model info
-        with gr.Accordion("About Snowflake-G0-Release", open=False):
-            gr.Markdown("""
-            ## Snowflake-G0-Release
-
-            This is the initial release of the Snowflake series language models, trained on the DialogMLM-50K dataset with optimized memory usage.
-
-            ### Model details
-            - Architecture: SnowflakeCore
-            - Hidden size: 384
-            - Number of attention heads: 6
-            - Number of layers: 4
-            - Feed-forward dimension: 768
-            - Maximum sequence length: 384
-            - Vocabulary size: 30522 (BERT tokenizer)
-
-            ### Key Features
-            - Efficient memory usage
-            - Fused QKV projection for faster inference
-            - Pre-norm architecture for stable training
-            - Compatible with HuggingFace Transformers
-            """)
-
         # Chat interface
         with gr.Column():
+            # Model selection
+            with gr.Row(elem_classes="model-select"):
+                model_version = gr.Radio(
+                    ["G0-Release", "G0-Release-2"],
+                    label="Select Model Version",
+                    value="G0-Release-2",
+                    info="Choose which Snowflake model to use"
+                )
+
             chat_history_display = gr.Textbox(
                 value="",
                 label="Conversation History",
@@ -288,20 +327,20 @@
         # Footer
         gr.HTML(f"""
         <div class="footer">
-            <p>Snowflake-G0-Release Demo • Created with Gradio • {datetime.datetime.now().year}</p>
+            <p>Snowflake Models Demo • Created with Gradio • {datetime.datetime.now().year}</p>
         </div>
         """)
 
         # Set up interactions
         submit_btn.click(
             fn=generate_text,
-            inputs=[prompt, temperature, top_p, top_k, max_new_tokens, history_state],
+            inputs=[prompt, model_version, temperature, top_p, top_k, max_new_tokens, history_state],
             outputs=[response_output, history_state, chat_history_display]
         )
 
         prompt.submit(
             fn=generate_text,
-            inputs=[prompt, temperature, top_p, top_k, max_new_tokens, history_state],
+            inputs=[prompt, model_version, temperature, top_p, top_k, max_new_tokens, history_state],
             outputs=[response_output, history_state, chat_history_display]
         )
 
@@ -313,19 +352,19 @@
 
     return demo
 
-# Load model and tokenizer
-print("Loading Snowflake-G0-Release model and tokenizer...")
+# Load models and tokenizers
+print("Loading Snowflake models and tokenizers...")
 try:
-    model, tokenizer, pipeline = load_model_and_tokenizer()
-    print("Model loaded successfully!")
+    (model_v1, tokenizer_v1, pipeline_v1), (model_v2, tokenizer_v2, pipeline_v2) = load_models_and_tokenizers()
+    print("Models loaded successfully!")
 except Exception as e:
-    print(f"Error loading model: {str(e)}")
-    # Create a simple error demo if model fails to load
+    print(f"Error loading models: {str(e)}")
+    # Create a simple error demo if models fail to load
     with gr.Blocks(css=css) as error_demo:
         gr.HTML(f"""
         <div class="header" style="background-color: #ffebee;">
-            <h1><span class="snowflake-icon">⚠️</span> Error Loading Model</h1>
-            <p>There was a problem loading the Snowflake-G0-Release model: {str(e)}</p>
+            <h1><span class="snowflake-icon">⚠️</span> Error Loading Models</h1>
+            <p>There was a problem loading the Snowflake models: {str(e)}</p>
         </div>
         """)
     demo = error_demo
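
A final note on the event-wiring hunk (@@ -288,20 +327,20 @@): Gradio passes inputs components to the handler positionally, so both lists must stay in the same order as the parameters of generate_text (prompt, model_version, temperature, top_p, top_k, ...). Since the click and submit wirings are identical, the lists could be shared; a hypothetical factoring, not part of this commit:

    # One definition keeps the click and submit wiring from drifting apart.
    gen_inputs = [prompt, model_version, temperature, top_p, top_k,
                  max_new_tokens, history_state]
    gen_outputs = [response_output, history_state, chat_history_display]

    submit_btn.click(fn=generate_text, inputs=gen_inputs, outputs=gen_outputs)
    prompt.submit(fn=generate_text, inputs=gen_inputs, outputs=gen_outputs)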
 