prithivMLmods committed (verified)
Commit d144c92 · 1 Parent(s): b55b5cd

Update app.py

Files changed (1): app.py (+54 -14)
app.py CHANGED
@@ -14,11 +14,6 @@ import numpy as np
 from PIL import Image
 import edge_tts
 
-import sambanova_gradio
-# Load the reasoning model from sambanova_gradio.
-# This returns a callable interface for inference.
-reasoning_model = gr.load("DeepSeek-R1-Distill-Llama-70B", src=sambanova_gradio.registry, accept_token=True)
-
 from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,
@@ -29,6 +24,9 @@ from transformers import (
 from transformers.image_utils import load_image
 from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler
 
+# -----------------------------
+# Existing global variables and model setup
+# -----------------------------
 MAX_MAX_NEW_TOKENS = 2048
 DEFAULT_MAX_NEW_TOKENS = 1024
 MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
@@ -56,6 +54,38 @@ model_m = Qwen2VLForConditionalGeneration.from_pretrained(
     torch_dtype=torch.float16
 ).to("cuda").eval()
 
+# -----------------------------
+# New reasoning feature setup
+# -----------------------------
+from openai import OpenAI
+
+api_key = os.getenv("SAMBANOVA_API_KEY")
+client_reasoning = OpenAI(
+    base_url="https://api.sambanova.ai/v1/",
+    api_key=api_key,
+)
+
+def reasoning_predict(message, history):
+    """
+    Append the user's reasoning request to the history, then stream the
+    response from the SambaNova API using the model
+    'DeepSeek-R1-Distill-Llama-70B'.
+    """
+    history.append({"role": "user", "content": message})
+    stream = client_reasoning.chat.completions.create(
+        messages=history,
+        model="DeepSeek-R1-Distill-Llama-70B",
+        stream=True,
+    )
+    chunks = []
+    for chunk in stream:
+        # Accumulate streamed content and yield the current full response
+        chunks.append(chunk.choices[0].delta.content or "")
+        yield "".join(chunks)
+
+# -----------------------------
+# Utility Functions and Checks
+# -----------------------------
 async def text_to_speech(text: str, voice: str, output_file="output.mp3"):
     communicate = edge_tts.Communicate(text, voice)
     await communicate.save(output_file)
@@ -94,6 +124,9 @@ ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
 
 dtype = torch.float16 if device.type == "cuda" else torch.float32
 
+# -----------------------------
+# Image Generation Models Setup
+# -----------------------------
 if torch.cuda.is_available():
     # Lightning 5 model
     pipe = StableDiffusionXLPipeline.from_pretrained(
@@ -179,6 +212,9 @@ def save_image(img: Image.Image) -> str:
     img.save(unique_name)
     return unique_name
 
+# -----------------------------
+# Main Generation Function with Reasoning Integration
+# -----------------------------
 @spaces.GPU
 def generate(
     input_dict: dict,
@@ -193,6 +229,7 @@ def generate(
     files = input_dict.get("files", [])
 
     lower_text = text.lower().strip()
+
     # Check if the prompt is an image generation command using model flags.
     if (lower_text.startswith("@lightningv5") or
         lower_text.startswith("@lightningv4") or
@@ -245,17 +282,20 @@ def generate(
         yield gr.Image(image_path)
         return
 
-    # New reasoning branch.
-    elif lower_text.startswith("@reasoning"):
-        # Remove the reasoning flag and clean the prompt.
-        prompt_clean = re.sub(r"@reasoning", "", text, flags=re.IGNORECASE).strip().strip('"')
-        yield "Processing reasoning request..."
-        # Call the reasoning model (this call might be synchronous; adjust if needed).
-        reasoning_response = reasoning_model(prompt_clean)
-        yield reasoning_response
+    # -----------------------------
+    # NEW: Reasoning Branch
+    # -----------------------------
+    if lower_text.startswith("@reasoning"):
+        reasoning_text = text.replace("@reasoning", "").strip()
+        reasoning_history = clean_chat_history(chat_history)
+        yield "Reasoning..."
+        for response in reasoning_predict(reasoning_text, reasoning_history):
+            yield response
         return
 
+    # -----------------------------
     # Otherwise, handle text/chat (and TTS) generation.
+    # -----------------------------
     tts_prefix = "@tts"
     is_tts = any(text.strip().lower().startswith(f"{tts_prefix}{i}") for i in range(1, 3))
     voice_index = next((i for i in range(1, 3) if text.strip().lower().startswith(f"{tts_prefix}{i}")), None)
@@ -360,7 +400,7 @@ demo = gr.ChatInterface(
         ['@turbov3 "Abstract art, colorful and vibrant"'],
         ["Write a Python function to check if a number is prime."],
         ["@tts2 What causes rainbows to form?"],
-        ["@reasoning Explain the significance of Gödel's incompleteness theorems."],
+        ["@reasoning How does quantum entanglement work and what are its implications?"],
     ],
     cache_examples=False,
     type="messages",
 
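To sanity-check the new SambaNova path outside the Space, the client portion of the diff can be exercised on its own. A minimal sketch, assuming SAMBANOVA_API_KEY is exported and the openai package is installed; the base URL and model name are taken verbatim from the diff above:

import os
from openai import OpenAI

client = OpenAI(
    base_url="https://api.sambanova.ai/v1/",
    api_key=os.getenv("SAMBANOVA_API_KEY"),
)

# Same call shape as reasoning_predict(); print deltas as they stream in.
stream = client.chat.completions.create(
    messages=[{"role": "user", "content": "Why is the sky blue?"}],
    model="DeepSeek-R1-Distill-Llama-70B",
    stream=True,
)
for chunk in stream:
    print(chunk.choices[0].delta.content or "", end="", flush=True)
print()

Note that reasoning_predict() itself yields the cumulative string on every chunk rather than each delta; that is the shape gr.ChatInterface expects when it redraws the in-progress assistant message during streaming.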
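One behavioral note on the new branch: the guard tests lower_text.startswith("@reasoning"), but the flag is stripped from the original text with a case-sensitive str.replace, so an input like "@Reasoning explain..." passes the guard while leaving the flag in the prompt that reaches the model. The branch this commit removes handled that with re.sub(..., flags=re.IGNORECASE); a small sketch of the same idea (strip_flag is a hypothetical helper, not part of the commit):

import re

def strip_flag(text: str, flag: str = "@reasoning") -> str:
    # Case-insensitive removal of the first occurrence of the flag,
    # mirroring the re.sub approach used by the removed branch.
    return re.sub(re.escape(flag), "", text, count=1, flags=re.IGNORECASE).strip()

assert strip_flag("@Reasoning Why is the sky blue?") == "Why is the sky blue?"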