Spaces: Running on Zero
Update app.py
app.py CHANGED
@@ -14,11 +14,6 @@ import numpy as np
 from PIL import Image
 import edge_tts
 
-import sambanova_gradio
-# Load the reasoning model from sambanova_gradio.
-# This returns a callable interface for inference.
-reasoning_model = gr.load("DeepSeek-R1-Distill-Llama-70B", src=sambanova_gradio.registry, accept_token=True)
-
 from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,
@@ -29,6 +24,9 @@ from transformers import (
 from transformers.image_utils import load_image
 from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler
 
+# -----------------------------
+# Existing global variables and model setup
+# -----------------------------
 MAX_MAX_NEW_TOKENS = 2048
 DEFAULT_MAX_NEW_TOKENS = 1024
 MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
@@ -56,6 +54,38 @@ model_m = Qwen2VLForConditionalGeneration.from_pretrained(
     torch_dtype=torch.float16
 ).to("cuda").eval()
 
+# -----------------------------
+# New reasoning feature setup
+# -----------------------------
+from openai import OpenAI
+
+api_key = os.getenv("SAMBANOVA_API_KEY")
+client_reasoning = OpenAI(
+    base_url="https://api.sambanova.ai/v1/",
+    api_key=api_key,
+)
+
+def reasoning_predict(message, history):
+    """
+    This function appends the user's reasoning request to the history,
+    then streams the response from the Sambanova API using the model
+    'DeepSeek-R1-Distill-Llama-70B'.
+    """
+    history.append({"role": "user", "content": message})
+    stream = client_reasoning.chat.completions.create(
+        messages=history,
+        model="DeepSeek-R1-Distill-Llama-70B",
+        stream=True,
+    )
+    chunks = []
+    for chunk in stream:
+        # Accumulate streamed content and yield the current full response
+        chunks.append(chunk.choices[0].delta.content or "")
+        yield "".join(chunks)
+
+# -----------------------------
+# Utility Functions and Checks
+# -----------------------------
 async def text_to_speech(text: str, voice: str, output_file="output.mp3"):
     communicate = edge_tts.Communicate(text, voice)
     await communicate.save(output_file)
@@ -94,6 +124,9 @@ ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
 
 dtype = torch.float16 if device.type == "cuda" else torch.float32
 
+# -----------------------------
+# Image Generation Models Setup
+# -----------------------------
 if torch.cuda.is_available():
     # Lightning 5 model
     pipe = StableDiffusionXLPipeline.from_pretrained(
@@ -179,6 +212,9 @@ def save_image(img: Image.Image) -> str:
     img.save(unique_name)
     return unique_name
 
+# -----------------------------
+# Main Generation Function with Reasoning Integration
+# -----------------------------
 @spaces.GPU
 def generate(
     input_dict: dict,
@@ -193,6 +229,7 @@ def generate(
     files = input_dict.get("files", [])
 
     lower_text = text.lower().strip()
+
     # Check if the prompt is an image generation command using model flags.
     if (lower_text.startswith("@lightningv5") or
         lower_text.startswith("@lightningv4") or
@@ -245,17 +282,20 @@ def generate(
             yield gr.Image(image_path)
         return
 
-    #
-
-
-
-
-
-
-
+    # -----------------------------
+    # NEW: Reasoning Branch
+    # -----------------------------
+    if lower_text.startswith("@reasoning"):
+        reasoning_text = text.replace("@reasoning", "").strip()
+        reasoning_history = clean_chat_history(chat_history)
+        yield "Reasoning..."
+        for response in reasoning_predict(reasoning_text, reasoning_history):
+            yield response
         return
 
+    # -----------------------------
     # Otherwise, handle text/chat (and TTS) generation.
+    # -----------------------------
     tts_prefix = "@tts"
     is_tts = any(text.strip().lower().startswith(f"{tts_prefix}{i}") for i in range(1, 3))
     voice_index = next((i for i in range(1, 3) if text.strip().lower().startswith(f"{tts_prefix}{i}")), None)
@@ -360,7 +400,7 @@ demo = gr.ChatInterface(
         ['@turbov3 "Abstract art, colorful and vibrant"'],
         ["Write a Python function to check if a number is prime."],
         ["@tts2 What causes rainbows to form?"],
-        ["@reasoning
+        ["@reasoning How does quantum entanglement work and what are its implications?"],
     ],
     cache_examples=False,
     type="messages",
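
For reference, the streaming pattern this commit introduces can be tried outside the Space. The sketch below is a minimal, standalone version of the new reasoning path, assuming the openai package is installed and the SAMBANOVA_API_KEY environment variable is set; the endpoint, model name, and chunk-accumulation logic are taken directly from the diff above, while the example prompt and variable names are illustrative only.

    # Minimal standalone sketch of the new reasoning path (assumptions:
    # `openai` installed, SAMBANOVA_API_KEY set; endpoint and model name
    # come from the diff above).
    import os
    from openai import OpenAI

    client = OpenAI(
        base_url="https://api.sambanova.ai/v1/",
        api_key=os.getenv("SAMBANOVA_API_KEY"),
    )

    history = [{"role": "user", "content": "Why is the sky blue?"}]
    stream = client.chat.completions.create(
        messages=history,
        model="DeepSeek-R1-Distill-Llama-70B",
        stream=True,
    )

    chunks = []
    for chunk in stream:
        # Each streamed chunk carries an incremental delta; `content` can be
        # None on role-only or final chunks, hence the `or ""` fallback.
        chunks.append(chunk.choices[0].delta.content or "")
        print("".join(chunks))  # the progressively growing response

Yielding the accumulated string rather than each raw delta matches how gr.ChatInterface renders streams: every yield replaces the displayed message, so the cumulative text appears as a single answer growing in place.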