FlameF0X committed on
Commit b7255ad · verified · 1 Parent(s): 1337e63

Update app.py

Files changed (1)
  1. app.py +19 -36
app.py CHANGED
@@ -2,19 +2,9 @@ import os
 import torch
 import gradio as gr
 import datetime
-from spaces import GPU
 from transformers import AutoTokenizer, AutoModelForCausalLM, TextGenerationPipeline
 from safetensors.torch import load_file
 
-import spaces
-
-@spaces.GPU
-def use_gpu():
-    import torch
-    print("Torch CUDA available:", torch.cuda.is_available())
-    return {"cuda_available": torch.cuda.is_available()}
-
-
 # Constants
 MODEL_CONFIG = {
     "G0-Release": "FlameF0X/Snowflake-G0-Release",
@@ -28,13 +18,11 @@ TOP_P_DEFAULT = 0.9
 TOP_K_DEFAULT = 40
 MAX_NEW_TOKENS_DEFAULT = 256
 
-# UI parameter bounds
 TEMPERATURE_MIN, TEMPERATURE_MAX = 0.1, 2.0
 TOP_P_MIN, TOP_P_MAX = 0.1, 1.0
 TOP_K_MIN, TOP_K_MAX = 1, 100
 MAX_NEW_TOKENS_MIN, MAX_NEW_TOKENS_MAX = 16, 1024
 
-# Styling
 css = """
 .gradio-container { background-color: #1e1e2f !important; color: #e0e0e0 !important; }
 .header { background-color: #2b2b3c; padding: 20px; margin-bottom: 20px; border-radius: 10px; text-align: center; }
@@ -48,7 +36,6 @@ css = """
 .model-select { background-color: #2a2a4a; padding: 10px; border-radius: 8px; margin-bottom: 15px; }
 """
 
-# Model registry
 model_registry = {}
 
 def load_all_models():
@@ -64,10 +51,11 @@ def load_all_models():
             model = load_file(safetensor_path)
         else:
             print("Loading from Hugging Face or .bin...")
+            # Key fix: no device_map, load on CPU only
             model = AutoModelForCausalLM.from_pretrained(
                 model_id,
-                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
-                device_map="auto"
+                torch_dtype=torch.float32,
+                device_map=None
             )
 
         pipeline = TextGenerationPipeline(
@@ -83,7 +71,7 @@ def generate_text(prompt, model_version, temperature, top_p, top_k, max_new_toke
     if history is None:
         history = []
     history.append({"role": "user", "content": prompt})
-    
+
     try:
         if model_version not in model_registry:
             raise ValueError(f"Model '{model_version}' not found.")
@@ -110,7 +98,7 @@ def generate_text(prompt, model_version, temperature, top_p, top_k, max_new_toke
             formatted_history.append(f"{prefix}{entry['content']}")
 
         return response, history, "\n\n".join(formatted_history)
-    
+
     except Exception as e:
         error_msg = f"Error generating response: {str(e)}"
         history.append({"role": "assistant", "content": f"[ERROR] {error_msg}", "model": model_version})
@@ -121,14 +109,13 @@ def clear_conversation():
 
 def create_demo():
     with gr.Blocks(css=css) as demo:
-        # Header
         gr.HTML("""
             <div class="header">
                 <h1><span class="snowflake-icon">❄️</span> Snowflake Models Demo</h1>
                 <p>Experience the capabilities of the Snowflake series language models</p>
             </div>
         """)
-        
+
         with gr.Column():
             with gr.Row(elem_classes="model-select"):
                 model_version = gr.Radio(
@@ -137,21 +124,21 @@ def create_demo():
                     label="Select Model Version",
                     info="Choose which Snowflake model to use"
                 )
-            
+
             chat_history_display = gr.Textbox(
-                value="", 
-                label="Conversation History", 
-                lines=10, 
+                value="",
+                label="Conversation History",
+                lines=10,
                 max_lines=30,
                 interactive=False
            )
-            
+
             history_state = gr.State([])
-            
+
             with gr.Row():
                 with gr.Column(scale=4):
                     prompt = gr.Textbox(
-                        placeholder="Type your message here...", 
+                        placeholder="Type your message here...",
                         label="Your Input",
                         lines=2
                     )
@@ -160,14 +147,13 @@ def create_demo():
                     clear_btn = gr.Button("Clear Conversation")
 
             response_output = gr.Textbox(
-                value="", 
-                label="Model Response", 
+                value="",
+                label="Model Response",
                 lines=5,
                 max_lines=10,
                 interactive=False
             )
-            
-            # Generation Parameters
+
             with gr.Accordion("Generation Parameters", open=False):
                 with gr.Column(elem_classes="parameter-section"):
                     with gr.Row():
@@ -193,8 +179,7 @@ def create_demo():
                             value=MAX_NEW_TOKENS_DEFAULT, step=8,
                             label="Maximum New Tokens"
                         )
-            
-            # Example prompts
+
             examples = [
                 "Write a short story about a snowflake that comes to life.",
                 "Explain the concept of artificial neural networks to a 10-year-old.",
@@ -211,14 +196,13 @@ def create_demo():
                 label="Click on an example to try it",
                 examples_per_page=5
             )
-            
+
             gr.HTML(f"""
                 <div class="footer">
                     <p>Snowflake Models Demo • Created with Gradio • {datetime.datetime.now().year}</p>
                 </div>
             """)
-            
-            # Interactions
+
             submit_btn.click(
                 fn=generate_text,
                 inputs=[prompt, model_version, temperature, top_p, top_k, max_new_tokens, history_state],
@@ -253,6 +237,5 @@ except Exception as e:
         </div>
         """)
 
-# Run app
 if __name__ == "__main__":
     demo.launch()
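
The substantive change in this commit is the loading path in load_all_models(): the spaces GPU decorator is dropped and the model is loaded on CPU in full precision, with no device_map dispatch. A minimal standalone sketch of that loading path is below; the single model_id, the tokenizer call, and the test prompt are illustrative assumptions, while the actual app loops over MODEL_CONFIG and can also load .safetensors weights directly via load_file.

# Minimal sketch of the CPU-only loading introduced by this commit (assumed
# single model id from MODEL_CONFIG; tokenizer assumed to live in the same repo).
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextGenerationPipeline

model_id = "FlameF0X/Snowflake-G0-Release"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float32,  # full precision; float16 only pays off on GPU
    device_map=None,            # no accelerate dispatch, weights stay on CPU
)

pipe = TextGenerationPipeline(model=model, tokenizer=tokenizer)
print(pipe("Write a short story about a snowflake", max_new_tokens=64)[0]["generated_text"])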