prithivMLmods committed
Commit c11a2d2 · verified · 1 parent: 5b2ec8e

Update app.py

Files changed (1)
  1. app.py  (+26, -128)
app.py CHANGED
@@ -4,9 +4,7 @@ import uuid
 import json
 import time
 import asyncio
-import tempfile
 from threading import Thread
-import base64
 
 import gradio as gr
 import spaces
@@ -14,7 +12,6 @@ import torch
 import numpy as np
 from PIL import Image
 import edge_tts
-import trimesh
 
 from transformers import (
     AutoModelForCausalLM,
@@ -24,85 +21,8 @@ from transformers import (
     AutoProcessor,
 )
 from transformers.image_utils import load_image
-
 from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler
-from diffusers import ShapEImg2ImgPipeline, ShapEPipeline
-from diffusers.utils import export_to_ply
-
-
-MAX_SEED = np.iinfo(np.int32).max
-
-def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
-    if randomize_seed:
-        seed = random.randint(0, MAX_SEED)
-    return seed
 
-def glb_to_data_url(glb_path: str) -> str:
-    """
-    Reads a GLB file from disk and returns a data URL with a base64 encoded representation.
-    This data URL can be used as the `src` for an HTML <model-viewer> tag.
-    """
-    with open(glb_path, "rb") as f:
-        data = f.read()
-    b64_data = base64.b64encode(data).decode("utf-8")
-    return f"data:model/gltf-binary;base64,{b64_data}"
-
-class Model:
-    def __init__(self):
-        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        self.pipe = ShapEPipeline.from_pretrained("openai/shap-e", torch_dtype=torch.float16)
-        self.pipe.to(self.device)
-        # Ensure the text encoder is in half precision to avoid dtype mismatches.
-        if torch.cuda.is_available():
-            try:
-                self.pipe.text_encoder = self.pipe.text_encoder.half()
-            except AttributeError:
-                pass
-
-        self.pipe_img = ShapEImg2ImgPipeline.from_pretrained("openai/shap-e-img2img", torch_dtype=torch.float16)
-        self.pipe_img.to(self.device)
-        # Use getattr with a default value to avoid AttributeError if text_encoder is missing.
-        if torch.cuda.is_available():
-            text_encoder_img = getattr(self.pipe_img, "text_encoder", None)
-            if text_encoder_img is not None:
-                self.pipe_img.text_encoder = text_encoder_img.half()
-
-    def to_glb(self, ply_path: str) -> str:
-        mesh = trimesh.load(ply_path)
-        # Rotate the mesh for proper orientation
-        rot = trimesh.transformations.rotation_matrix(-np.pi / 2, [1, 0, 0])
-        mesh.apply_transform(rot)
-        rot = trimesh.transformations.rotation_matrix(np.pi, [0, 1, 0])
-        mesh.apply_transform(rot)
-        mesh_path = tempfile.NamedTemporaryFile(suffix=".glb", delete=False)
-        mesh.export(mesh_path.name, file_type="glb")
-        return mesh_path.name
-
-    def run_text(self, prompt: str, seed: int = 0, guidance_scale: float = 15.0, num_steps: int = 64) -> str:
-        generator = torch.Generator(device=self.device).manual_seed(seed)
-        images = self.pipe(
-            prompt,
-            generator=generator,
-            guidance_scale=guidance_scale,
-            num_inference_steps=num_steps,
-            output_type="mesh",
-        ).images
-        ply_path = tempfile.NamedTemporaryFile(suffix=".ply", delete=False, mode="w+b")
-        export_to_ply(images[0], ply_path.name)
-        return self.to_glb(ply_path.name)
-
-    def run_image(self, image: Image.Image, seed: int = 0, guidance_scale: float = 3.0, num_steps: int = 64) -> str:
-        generator = torch.Generator(device=self.device).manual_seed(seed)
-        images = self.pipe_img(
-            image,
-            generator=generator,
-            guidance_scale=guidance_scale,
-            num_inference_steps=num_steps,
-            output_type="mesh",
-        ).images
-        ply_path = tempfile.NamedTemporaryFile(suffix=".ply", delete=False, mode="w+b")
-        export_to_ply(images[0], ply_path.name)
-        return self.to_glb(ply_path.name)
 
 DESCRIPTION = """
 # QwQ Edge 💬
@@ -128,6 +48,7 @@ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
 
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
+# Load text-only model and tokenizer
 model_id = "prithivMLmods/FastThink-0.5B-Tiny"
 tokenizer = AutoTokenizer.from_pretrained(model_id)
 model = AutoModelForCausalLM.from_pretrained(
@@ -137,13 +58,11 @@ model = AutoModelForCausalLM.from_pretrained(
 )
 model.eval()
 
-# Voices for text-to-speech
 TTS_VOICES = [
     "en-US-JennyNeural",  # @tts1
     "en-US-GuyNeural",  # @tts2
 ]
 
-# Load multimodal processor and model (e.g. for OCR and image processing)
 MODEL_ID = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
 processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
 model_m = Qwen2VLForConditionalGeneration.from_pretrained(
@@ -169,12 +88,14 @@ def clean_chat_history(chat_history):
             cleaned.append(msg)
     return cleaned
 
+# Environment variables and parameters for Stable Diffusion XL
 MODEL_ID_SD = os.getenv("MODEL_VAL_PATH")  # SDXL Model repository path via env variable
 MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
 USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
 ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
 BATCH_SIZE = int(os.getenv("BATCH_SIZE", "1"))  # For batched image generation
 
+# Load the SDXL pipeline
 sd_pipe = StableDiffusionXLPipeline.from_pretrained(
     MODEL_ID_SD,
     torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
@@ -183,21 +104,31 @@ sd_pipe = StableDiffusionXLPipeline.from_pretrained(
 ).to(device)
 sd_pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(sd_pipe.scheduler.config)
 
+# Ensure that the text encoder is in half-precision if using CUDA.
 if torch.cuda.is_available():
     sd_pipe.text_encoder = sd_pipe.text_encoder.half()
 
+# Optional: compile the model for speedup if enabled
 if USE_TORCH_COMPILE:
     sd_pipe.compile()
 
+# Optional: offload parts of the model to CPU if needed
 if ENABLE_CPU_OFFLOAD:
     sd_pipe.enable_model_cpu_offload()
 
+MAX_SEED = np.iinfo(np.int32).max
+
 def save_image(img: Image.Image) -> str:
     """Save a PIL image with a unique filename and return the path."""
     unique_name = str(uuid.uuid4()) + ".png"
     img.save(unique_name)
     return unique_name
 
+def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
+    if randomize_seed:
+        seed = random.randint(0, MAX_SEED)
+    return seed
+
 @spaces.GPU(duration=60, enable_queue=True)
 def generate_image_fn(
     prompt: str,
@@ -237,6 +168,7 @@ def generate_image_fn(
         batch_options["prompt"] = options["prompt"][i:i+BATCH_SIZE]
         if "negative_prompt" in batch_options and batch_options["negative_prompt"] is not None:
             batch_options["negative_prompt"] = options["negative_prompt"][i:i+BATCH_SIZE]
+        # Wrap the pipeline call in autocast if using CUDA
         if device.type == "cuda":
             with torch.autocast("cuda", dtype=torch.float16):
                 outputs = sd_pipe(**batch_options)
@@ -246,23 +178,6 @@ def generate_image_fn(
     image_paths = [save_image(img) for img in images]
     return image_paths, seed
 
-@spaces.GPU(duration=120, enable_queue=True)
-def generate_3d_fn(
-    prompt: str,
-    seed: int = 1,
-    guidance_scale: float = 15.0,
-    num_steps: int = 64,
-    randomize_seed: bool = False,
-):
-    """
-    Generate a 3D model from text using the ShapE pipeline.
-    Returns a tuple of (glb_file_path, used_seed).
-    """
-    seed = int(randomize_seed_fn(seed, randomize_seed))
-    model3d = Model()
-    glb_path = model3d.run_text(prompt, seed=seed, guidance_scale=guidance_scale, num_steps=num_steps)
-    return glb_path, seed
-
 @spaces.GPU
 def generate(
     input_dict: dict,
@@ -274,39 +189,16 @@ def generate(
     repetition_penalty: float = 1.2,
 ):
     """
-    Generates chatbot responses with support for multimodal input, TTS, image generation,
-    and 3D model generation.
-
+    Generates chatbot responses with support for multimodal input, TTS, and image generation.
     Special commands:
     - "@tts1" or "@tts2": triggers text-to-speech.
     - "@image": triggers image generation using the SDXL pipeline.
-    - "@3d": triggers 3D model generation using the ShapE pipeline.
     """
     text = input_dict["text"]
     files = input_dict.get("files", [])
 
-    # --- 3D Generation branch ---
-    if text.strip().lower().startswith("@3d"):
-        prompt = text[len("@3d"):].strip()
-        yield "Generating 3D model..."
-        glb_path, used_seed = generate_3d_fn(
-            prompt=prompt,
-            seed=1,
-            guidance_scale=15.0,
-            num_steps=64,
-            randomize_seed=True,
-        )
-        # Convert the GLB file to a base64 data URL and embed it in an HTML <model-viewer> tag.
-        data_url = glb_to_data_url(glb_path)
-        html_output = f'''
-        <model-viewer src="{data_url}" alt="3D Model" auto-rotate camera-controls style="width: 100%; height: 400px;"></model-viewer>
-        <script type="module" src="https://unpkg.com/@google/model-viewer/dist/model-viewer.min.js"></script>
-        '''
-        yield gr.HTML(html_output)
-        return
-
-    # --- Image Generation branch ---
     if text.strip().lower().startswith("@image"):
+        # Remove the "@image" tag and use the rest as prompt
         prompt = text[len("@image"):].strip()
         yield "Generating image..."
         image_paths, used_seed = generate_image_fn(
@@ -322,10 +214,10 @@ def generate(
             use_resolution_binning=True,
             num_images=1,
         )
+        # Yield the generated image so that the chat interface displays it.
         yield gr.Image(image_paths[0])
-        return
+        return  # Exit early
 
-    # --- Text and TTS branch ---
     tts_prefix = "@tts"
     is_tts = any(text.strip().lower().startswith(f"{tts_prefix}{i}") for i in range(1, 3))
     voice_index = next((i for i in range(1, 3) if text.strip().lower().startswith(f"{tts_prefix}{i}")), None)
@@ -333,9 +225,11 @@ def generate(
     if is_tts and voice_index:
         voice = TTS_VOICES[voice_index - 1]
         text = text.replace(f"{tts_prefix}{voice_index}", "").strip()
+        # Clear previous chat history for a fresh TTS request.
         conversation = [{"role": "user", "content": text}]
     else:
         voice = None
+        # Remove any stray @tts tags and build the conversation history.
         text = text.replace(tts_prefix, "").strip()
         conversation = clean_chat_history(chat_history)
         conversation.append({"role": "user", "content": text})
@@ -369,6 +263,7 @@ def generate(
             time.sleep(0.01)
             yield buffer
     else:
+
         input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
         if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
             input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
@@ -397,6 +292,7 @@ def generate(
         final_response = "".join(outputs)
         yield final_response
 
+        # If TTS was requested, convert the final response to speech.
         if is_tts and voice:
             output_file = asyncio.run(text_to_speech(final_response, voice))
             yield gr.Audio(output_file, autoplay=True)
@@ -412,11 +308,12 @@ demo = gr.ChatInterface(
     ],
     examples=[
         ["@tts1 Who is Nikola Tesla, and why did he die?"],
-        ["@3d A birthday cupcake with cherry"],
+        [{"text": "Extract JSON from the image", "files": ["examples/document.jpg"]}],
         [{"text": "summarize the letter", "files": ["examples/1.png"]}],
         ["@image Chocolate dripping from a donut against a yellow background, in the style of brocore, hyper-realistic"],
         ["Write a Python function to check if a number is prime."],
         ["@tts2 What causes rainbows to form?"],
+
     ],
     cache_examples=False,
     type="messages",
@@ -429,4 +326,5 @@ demo = gr.ChatInterface(
 )
 
 if __name__ == "__main__":
-    demo.queue(max_size=30).launch(share=True)
+    # To create a public link, set share=True in launch().
+    demo.queue(max_size=20).launch(share=True)
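Note on the new seed handling: the commit moves `MAX_SEED` and `randomize_seed_fn` to module level in app.py, but the call site inside `generate_image_fn` is outside the visible hunks. The snippet below is only a minimal sketch of how such a helper is typically wired to a seeded diffusers generator; the `txt2img` wrapper and the use of `pipe.device` are illustrative assumptions, not code from this commit.

import random

import numpy as np
import torch

MAX_SEED = np.iinfo(np.int32).max

def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    # Pick a fresh random seed when requested; otherwise keep the caller's seed.
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    return seed

def txt2img(pipe, prompt: str, seed: int = 0, randomize_seed: bool = True):
    # Hypothetical wrapper: resolve the effective seed, then drive the pipeline
    # through a seeded torch.Generator so a (prompt, seed) pair is reproducible.
    seed = randomize_seed_fn(seed, randomize_seed)
    generator = torch.Generator(device=pipe.device).manual_seed(seed)
    image = pipe(prompt, generator=generator, num_inference_steps=30).images[0]
    return image, seed

Returning the resolved seed alongside the image mirrors what generate_image_fn does in the diff (`return image_paths, seed`), so the UI can report which seed actually produced the output.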
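The SDXL block in the new file is configured entirely through environment variables (MODEL_VAL_PATH, MAX_IMAGE_SIZE, USE_TORCH_COMPILE, ENABLE_CPU_OFFLOAD, BATCH_SIZE), all read at import time. A minimal sketch of how a Space might set them before app.py is imported follows; the checkpoint id and the values shown are illustrative placeholders, not values referenced by this commit.

import os

# Hypothetical configuration; set these before importing app.py, since the
# module reads them when it is first imported.
os.environ["MODEL_VAL_PATH"] = "stabilityai/stable-diffusion-xl-base-1.0"  # placeholder SDXL repo id
os.environ["MAX_IMAGE_SIZE"] = "2048"        # upper bound on requested resolution
os.environ["USE_TORCH_COMPILE"] = "0"        # "1" enables sd_pipe.compile()
os.environ["ENABLE_CPU_OFFLOAD"] = "0"       # "1" enables enable_model_cpu_offload()
os.environ["BATCH_SIZE"] = "1"               # images generated per pipeline call

import app  # noqa: E402  (app.py builds sd_pipe from these variables on import)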