Update app.py
app.py
CHANGED
@@ -1,24 +1,11 @@
-#!/usr/bin/env python3
-# Hugging Face Space app.py for Snowflake-G0-Release demo
-
 import os
-import sys
 import gradio as gr
 import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM, TextGenerationPipeline
 import datetime
-import logging
-
-# Configure logging
-logging.basicConfig(
-    level=logging.INFO,
-    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
-    handlers=[logging.StreamHandler(sys.stdout)]
-)
-logger = logging.getLogger("snowflake-demo")
 
 # Model Constants
-MODEL_ID =
+MODEL_ID = "FlameF0X/Snowflake-G0-Release"  # Replace with actual HF repo when published
 MAX_LENGTH = 384
 TEMPERATURE_MIN = 0.1
 TEMPERATURE_MAX = 2.0
@@ -30,7 +17,7 @@ TOP_K_MIN = 1
 TOP_K_MAX = 100
 TOP_K_DEFAULT = 40
 MAX_NEW_TOKENS_MIN = 16
-MAX_NEW_TOKENS_MAX =
+MAX_NEW_TOKENS_MAX = 1024
 MAX_NEW_TOKENS_DEFAULT = 256
 
 # CSS for the app
@@ -81,44 +68,29 @@ css = """
 }
 """
 
-# Global variables for model, tokenizer, and pipeline
-model = None
-tokenizer = None
-pipeline = None
-
 # Helper functions
 def load_model_and_tokenizer():
-    global model, tokenizer, pipeline
-
-    logger.info(f"Loading model and tokenizer from: {MODEL_ID}")
-
     # Load tokenizer
     tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token
 
-    # Determine device and precision
-    device_map = "auto"
-    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
-    logger.info(f"Using device: {device_map}, dtype: {dtype}")
-
     # Load model with optimizations
     model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
-        torch_dtype=dtype,
-        device_map=device_map,
-        low_cpu_mem_usage=True
+        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+        device_map="auto"
    )
 
     # Create pipeline
     pipeline = TextGenerationPipeline(
        model=model,
        tokenizer=tokenizer,
-        return_full_text=False
+        return_full_text=False,
+        max_length=MAX_LENGTH
    )
 
-
-    return True
+    return model, tokenizer, pipeline
 
 def generate_text(
     prompt,
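
The refactored loader now hands back the model, tokenizer, and pipeline instead of mutating globals, so it can be smoke-tested on its own. A minimal sketch, illustration only and not part of the commit: it assumes the definitions above are in scope, the MODEL_ID repo is downloadable, and the accelerate package is installed (required when device_map="auto" is used).

    # Quick local smoke test of the refactored loader.
    model, tokenizer, pipeline = load_model_and_tokenizer()
    assert tokenizer.pad_token is not None          # falls back to eos_token above
    outputs = pipeline("Once upon a winter night,", do_sample=False)
    # With return_full_text=False the pipeline returns only the continuation,
    # as a list of dicts: [{"generated_text": "..."}]
    print(outputs[0]["generated_text"])
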
@@ -131,16 +103,10 @@ def generate_text(
     if history is None:
         history = []
 
-    if not prompt.strip():
-        return "", history, "\n\n".join([f"{'👤 User: ' if h['role'] == 'user' else '❄️ Snowflake: '}{h['content']}" for h in history])
-
     # Add current prompt to history
     history.append({"role": "user", "content": prompt})
 
     try:
-        logger.info(f"Generating text for prompt: {prompt[:50]}...")
-        logger.info(f"Parameters: temp={temperature}, top_p={top_p}, top_k={top_k}, max_tokens={max_new_tokens}")
-
         # Generate response
         outputs = pipeline(
             prompt,
@@ -154,7 +120,6 @@ def generate_text(
         )
 
         response = outputs[0]["generated_text"]
-        logger.info(f"Generated response: {response[:50]}...")
 
         # Add model response to history
         history.append({"role": "assistant", "content": response})
@@ -165,23 +130,19 @@ def generate_text(
             role_prefix = "👤 User: " if entry["role"] == "user" else "❄️ Snowflake: "
             formatted_history.append(f"{role_prefix}{entry['content']}")
 
-        return
+        return response, history, "\n\n".join(formatted_history)
 
     except Exception as e:
-        logger.error(f"Error generating response: {str(e)}", exc_info=True)
         error_msg = f"Error generating response: {str(e)}"
         history.append({"role": "assistant", "content": f"[ERROR] {error_msg}"})
-
-        formatted_history = []
-        for entry in history:
-            role_prefix = "👤 User: " if entry["role"] == "user" else "❄️ Snowflake: "
-            formatted_history.append(f"{role_prefix}{entry['content']}")
-
-        return "", history, "\n\n".join(formatted_history)
+        return error_msg, history, str(history)
 
 def clear_conversation():
     return "", [], ""
 
+def apply_preset_example(example, history):
+    return example, history
+
 # Example prompts
 examples = [
     "Write a short story about a snowflake that comes to life.",
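
With the logging and the empty-prompt early return gone, generate_text() now always returns three values: the response text, the updated history list, and a formatted transcript (on the error path the third value is str(history) rather than the formatted transcript). A hedged usage sketch, with the argument order taken from the Gradio inputs wiring below and made-up values:

    # Illustration of the three-value contract consumed by the Gradio outputs below.
    history = []
    response, history, transcript = generate_text(
        "Tell me about snowflakes.",  # prompt
        0.7,                          # temperature
        0.9,                          # top_p
        40,                           # top_k
        64,                           # max_new_tokens
        history,
    )
    # response   -> model continuation (str)
    # history    -> [{"role": "user", ...}, {"role": "assistant", ...}]
    # transcript -> "👤 User: ...\n\n❄️ Snowflake: ..." shown in the conversation box
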
@@ -193,7 +154,7 @@ examples = [
 
 # Main function
 def create_demo():
-    with gr.Blocks(css=css
+    with gr.Blocks(css=css) as demo:
         # Header
         gr.HTML("""
             <div class="header">
@@ -229,8 +190,8 @@ def create_demo():
            with gr.Column():
                chat_history_display = gr.Textbox(
                    value="",
-                    label="Conversation",
-                    lines=
+                    label="Conversation History",
+                    lines=10,
                    max_lines=30,
                    interactive=False
                )
@@ -247,9 +208,16 @@ def create_demo():
                    lines=2
                )
            with gr.Column(scale=1):
-
-
-
+                submit_btn = gr.Button("Send", variant="primary")
+                clear_btn = gr.Button("Clear Conversation")
+
+        response_output = gr.Textbox(
+            value="",
+            label="Model Response",
+            lines=5,
+            max_lines=10,
+            interactive=False
+        )
 
        # Advanced parameters
        with gr.Accordion("Generation Parameters", open=False):
@@ -299,7 +267,8 @@ def create_demo():
        example_btn = gr.Examples(
            examples=examples,
            inputs=prompt,
-            label="Click on an example to try it"
+            label="Click on an example to try it",
+            examples_per_page=5
        )
 
        # Footer
@@ -313,13 +282,13 @@ def create_demo():
        submit_btn.click(
            fn=generate_text,
            inputs=[prompt, temperature, top_p, top_k, max_new_tokens, history_state],
-            outputs=[
+            outputs=[response_output, history_state, chat_history_display]
        )
 
        prompt.submit(
            fn=generate_text,
            inputs=[prompt, temperature, top_p, top_k, max_new_tokens, history_state],
-            outputs=[
+            outputs=[response_output, history_state, chat_history_display]
        )
 
        clear_btn.click(
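
Each value returned by generate_text() fills the matching component in the outputs list, in order: response_output, then history_state, then chat_history_display. A self-contained toy sketch of the same wiring pattern, using a hypothetical echo function rather than the real model:

    import gradio as gr

    def echo(prompt, history):
        history = (history or []) + [{"role": "user", "content": prompt}]
        reply = f"echo: {prompt}"
        history.append({"role": "assistant", "content": reply})
        transcript = "\n\n".join(f"{h['role']}: {h['content']}" for h in history)
        return reply, history, transcript

    with gr.Blocks() as toy:
        state = gr.State([])
        box_in = gr.Textbox(label="Prompt")
        box_out = gr.Textbox(label="Model Response")
        box_hist = gr.Textbox(label="Conversation History")
        btn = gr.Button("Send")
        # Each returned value fills the matching output component, in order.
        btn.click(fn=echo, inputs=[box_in, state], outputs=[box_out, state, box_hist])
    # toy.launch()
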
@@ -330,13 +299,13 @@ def create_demo():
 
     return demo
 
-# Load model and
+# Load model and tokenizer
+print("Loading Snowflake-G0-Release model and tokenizer...")
 try:
-
-
-    demo = create_demo()
+    model, tokenizer, pipeline = load_model_and_tokenizer()
+    print("Model loaded successfully!")
 except Exception as e:
-
+    print(f"Error loading model: {str(e)}")
     # Create a simple error demo if model fails to load
     with gr.Blocks(css=css) as error_demo:
         gr.HTML(f"""
@@ -347,6 +316,9 @@ except Exception as e:
        """)
        demo = error_demo
 
+# Create and launch the demo
+demo = create_demo()
+
 # Launch the app
 if __name__ == "__main__":
     demo.launch()
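
Note that demo = create_demo() runs unconditionally after the try/except, so the fallback page built in the except block is immediately overwritten even when loading failed. A hedged alternative, not what this commit does, that keeps the error page on failure (the error markup here is hypothetical; the real one lives in the unchanged lines above):

    print("Loading Snowflake-G0-Release model and tokenizer...")
    try:
        model, tokenizer, pipeline = load_model_and_tokenizer()
        print("Model loaded successfully!")
        demo = create_demo()
    except Exception as e:
        print(f"Error loading model: {str(e)}")
        with gr.Blocks(css=css) as error_demo:
            gr.HTML(f"<h1>⚠️ Model failed to load</h1><pre>{e}</pre>")
        demo = error_demo

    if __name__ == "__main__":
        demo.launch()
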