import os import random import uuid import json import time import asyncio import tempfile from threading import Thread import base64 import shutil import gradio as gr import spaces import torch import numpy as np from PIL import Image import edge_tts import trimesh from transformers import ( AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, Qwen2VLForConditionalGeneration, AutoProcessor, ) from transformers.image_utils import load_image from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler from diffusers import ShapEImg2ImgPipeline, ShapEPipeline from diffusers.utils import export_to_ply # ----------------------------------------------------------------------------- # Global constants and helper functions # ----------------------------------------------------------------------------- MAX_SEED = np.iinfo(np.int32).max def randomize_seed_fn(seed: int, randomize_seed: bool) -> int: if randomize_seed: seed = random.randint(0, MAX_SEED) return seed def glb_to_data_url(glb_path: str) -> str: """ Reads a GLB file from disk and returns a data URL with a base64 encoded representation. (Not used in this method.) """ with open(glb_path, "rb") as f: data = f.read() b64_data = base64.b64encode(data).decode("utf-8") return f"data:model/gltf-binary;base64,{b64_data}" # ----------------------------------------------------------------------------- # Model class for Text-to-3D Generation (ShapE) # ----------------------------------------------------------------------------- class Model: def __init__(self): self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.pipe = ShapEPipeline.from_pretrained("openai/shap-e", torch_dtype=torch.float16) self.pipe.to(self.device) # Ensure the text encoder is in half precision to avoid dtype mismatches. if torch.cuda.is_available(): try: self.pipe.text_encoder = self.pipe.text_encoder.half() except AttributeError: pass self.pipe_img = ShapEImg2ImgPipeline.from_pretrained("openai/shap-e-img2img", torch_dtype=torch.float16) self.pipe_img.to(self.device) # Use getattr with a default value to avoid AttributeError if text_encoder is missing. if torch.cuda.is_available(): text_encoder_img = getattr(self.pipe_img, "text_encoder", None) if text_encoder_img is not None: self.pipe_img.text_encoder = text_encoder_img.half() def to_glb(self, ply_path: str) -> str: mesh = trimesh.load(ply_path) # Rotate the mesh for proper orientation rot = trimesh.transformations.rotation_matrix(-np.pi / 2, [1, 0, 0]) mesh.apply_transform(rot) rot = trimesh.transformations.rotation_matrix(np.pi, [0, 1, 0]) mesh.apply_transform(rot) mesh_path = tempfile.NamedTemporaryFile(suffix=".glb", delete=False) mesh.export(mesh_path.name, file_type="glb") return mesh_path.name def run_text(self, prompt: str, seed: int = 0, guidance_scale: float = 15.0, num_steps: int = 64) -> str: generator = torch.Generator(device=self.device).manual_seed(seed) images = self.pipe( prompt, generator=generator, guidance_scale=guidance_scale, num_inference_steps=num_steps, output_type="mesh", ).images ply_path = tempfile.NamedTemporaryFile(suffix=".ply", delete=False, mode="w+b") export_to_ply(images[0], ply_path.name) return self.to_glb(ply_path.name) def run_image(self, image: Image.Image, seed: int = 0, guidance_scale: float = 3.0, num_steps: int = 64) -> str: generator = torch.Generator(device=self.device).manual_seed(seed) images = self.pipe_img( image, generator=generator, guidance_scale=guidance_scale, num_inference_steps=num_steps, output_type="mesh", ).images ply_path = tempfile.NamedTemporaryFile(suffix=".ply", delete=False, mode="w+b") export_to_ply(images[0], ply_path.name) return self.to_glb(ply_path.name) # ----------------------------------------------------------------------------- # Gradio UI configuration # ----------------------------------------------------------------------------- DESCRIPTION = """ # QwQ Edge 💬 """ css = ''' h1 { text-align: center; display: block; } #duplicate-button { margin: auto; color: #fff; background: #1565c0; border-radius: 100vh; } ''' MAX_MAX_NEW_TOKENS = 2048 DEFAULT_MAX_NEW_TOKENS = 1024 MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096")) device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # ----------------------------------------------------------------------------- # Load Models and Pipelines for Chat, Image, and Multimodal Processing # ----------------------------------------------------------------------------- # Load the text-only model and tokenizer (for pure text chat) model_id = "prithivMLmods/FastThink-0.5B-Tiny" tokenizer = AutoTokenizer.from_pretrained(model_id) model = AutoModelForCausalLM.from_pretrained( model_id, device_map="auto", torch_dtype=torch.bfloat16, ) model.eval() # Voices for text-to-speech TTS_VOICES = [ "en-US-JennyNeural", # @tts1 "en-US-GuyNeural", # @tts2 ] # Load multimodal processor and model (e.g. for OCR and image processing) MODEL_ID = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct" processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True) model_m = Qwen2VLForConditionalGeneration.from_pretrained( MODEL_ID, trust_remote_code=True, torch_dtype=torch.float16 ).to("cuda").eval() # ----------------------------------------------------------------------------- # Asynchronous text-to-speech # ----------------------------------------------------------------------------- async def text_to_speech(text: str, voice: str, output_file="output.mp3"): """Convert text to speech using Edge TTS and save as MP3""" communicate = edge_tts.Communicate(text, voice) await communicate.save(output_file) return output_file # ----------------------------------------------------------------------------- # Utility function to clean conversation history # ----------------------------------------------------------------------------- def clean_chat_history(chat_history): """ Filter out any chat entries whose "content" is not a string. This helps prevent errors when concatenating previous messages. """ cleaned = [] for msg in chat_history: if isinstance(msg, dict) and isinstance(msg.get("content"), str): cleaned.append(msg) return cleaned # ----------------------------------------------------------------------------- # Stable Diffusion XL Pipeline for Image Generation # ----------------------------------------------------------------------------- MODEL_ID_SD = os.getenv("MODEL_VAL_PATH") # SDXL Model repository path via env variable MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096")) USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1" ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1" BATCH_SIZE = int(os.getenv("BATCH_SIZE", "1")) # For batched image generation sd_pipe = StableDiffusionXLPipeline.from_pretrained( MODEL_ID_SD, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, use_safetensors=True, add_watermarker=False, ).to(device) sd_pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(sd_pipe.scheduler.config) if torch.cuda.is_available(): sd_pipe.text_encoder = sd_pipe.text_encoder.half() if USE_TORCH_COMPILE: sd_pipe.compile() if ENABLE_CPU_OFFLOAD: sd_pipe.enable_model_cpu_offload() def save_image(img: Image.Image) -> str: """Save a PIL image with a unique filename and return the path.""" unique_name = str(uuid.uuid4()) + ".png" img.save(unique_name) return unique_name @spaces.GPU(duration=60, enable_queue=True) def generate_image_fn( prompt: str, negative_prompt: str = "", use_negative_prompt: bool = False, seed: int = 1, width: int = 1024, height: int = 1024, guidance_scale: float = 3, num_inference_steps: int = 25, randomize_seed: bool = False, use_resolution_binning: bool = True, num_images: int = 1, progress=gr.Progress(track_tqdm=True), ): """Generate images using the SDXL pipeline.""" seed = int(randomize_seed_fn(seed, randomize_seed)) generator = torch.Generator(device=device).manual_seed(seed) options = { "prompt": [prompt] * num_images, "negative_prompt": [negative_prompt] * num_images if use_negative_prompt else None, "width": width, "height": height, "guidance_scale": guidance_scale, "num_inference_steps": num_inference_steps, "generator": generator, "output_type": "pil", } if use_resolution_binning: options["use_resolution_binning"] = True images = [] # Process in batches for i in range(0, num_images, BATCH_SIZE): batch_options = options.copy() batch_options["prompt"] = options["prompt"][i:i+BATCH_SIZE] if "negative_prompt" in batch_options and batch_options["negative_prompt"] is not None: batch_options["negative_prompt"] = options["negative_prompt"][i:i+BATCH_SIZE] if device.type == "cuda": with torch.autocast("cuda", dtype=torch.float16): outputs = sd_pipe(**batch_options) else: outputs = sd_pipe(**batch_options) images.extend(outputs.images) image_paths = [save_image(img) for img in images] return image_paths, seed # ----------------------------------------------------------------------------- # Text-to-3D Generation using the ShapE Pipeline # ----------------------------------------------------------------------------- @spaces.GPU(duration=120, enable_queue=True) def generate_3d_fn( prompt: str, seed: int = 1, guidance_scale: float = 15.0, num_steps: int = 64, randomize_seed: bool = False, ): """ Generate a 3D model from text using the ShapE pipeline. Returns a tuple of (glb_file_path, used_seed). """ seed = int(randomize_seed_fn(seed, randomize_seed)) model3d = Model() glb_path = model3d.run_text(prompt, seed=seed, guidance_scale=guidance_scale, num_steps=num_steps) return glb_path, seed # ----------------------------------------------------------------------------- # Chat Generation Function with support for @tts, @image, and @3d commands # ----------------------------------------------------------------------------- @spaces.GPU def generate( input_dict: dict, chat_history: list[dict], max_new_tokens: int = 1024, temperature: float = 0.6, top_p: float = 0.9, top_k: int = 50, repetition_penalty: float = 1.2, ): """ Generates chatbot responses with support for multimodal input, TTS, image generation, and 3D model generation. Special commands: - "@tts1" or "@tts2": triggers text-to-speech. - "@image": triggers image generation using the SDXL pipeline. - "@3d": triggers 3D model generation using the ShapE pipeline. """ text = input_dict["text"] files = input_dict.get("files", []) # --- 3D Generation branch --- if text.strip().lower().startswith("@3d"): prompt = text[len("@3d"):].strip() yield "Hold tight, generating a 3D mesh GLB file....." glb_path, used_seed = generate_3d_fn( prompt=prompt, seed=1, guidance_scale=15.0, num_steps=64, randomize_seed=True, ) # Copy the GLB file to a static folder. static_folder = os.path.join(os.getcwd(), "static") if not os.path.exists(static_folder): os.makedirs(static_folder) new_filename = f"mesh_{uuid.uuid4()}.glb" new_filepath = os.path.join(static_folder, new_filename) shutil.copy(glb_path, new_filepath) #yield gr.File(new_filepath) # After generating the GLB file... yield gr.Model3D(new_filepath, clear_color=[1, 1, 1, 1]) return # --- Image Generation branch --- if text.strip().lower().startswith("@image"): prompt = text[len("@image"):].strip() yield "Generating image..." image_paths, used_seed = generate_image_fn( prompt=prompt, negative_prompt="", use_negative_prompt=False, seed=1, width=1024, height=1024, guidance_scale=3, num_inference_steps=25, randomize_seed=True, use_resolution_binning=True, num_images=1, ) yield gr.Image(image_paths[0]) return # --- Text and TTS branch --- tts_prefix = "@tts" is_tts = any(text.strip().lower().startswith(f"{tts_prefix}{i}") for i in range(1, 3)) voice_index = next((i for i in range(1, 3) if text.strip().lower().startswith(f"{tts_prefix}{i}")), None) if is_tts and voice_index: voice = TTS_VOICES[voice_index - 1] text = text.replace(f"{tts_prefix}{voice_index}", "").strip() conversation = [{"role": "user", "content": text}] else: voice = None text = text.replace(tts_prefix, "").strip() conversation = clean_chat_history(chat_history) conversation.append({"role": "user", "content": text}) if files: if len(files) > 1: images = [load_image(image) for image in files] elif len(files) == 1: images = [load_image(files[0])] else: images = [] messages = [{ "role": "user", "content": [ *[{"type": "image", "image": image} for image in images], {"type": "text", "text": text}, ] }] prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) inputs = processor(text=[prompt], images=images, return_tensors="pt", padding=True).to("cuda") streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True) generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens} thread = Thread(target=model_m.generate, kwargs=generation_kwargs) thread.start() buffer = "" yield "Thinking..." for new_text in streamer: buffer += new_text buffer = buffer.replace("<|im_end|>", "") time.sleep(0.01) yield buffer else: input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt") if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH: input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:] gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.") input_ids = input_ids.to(model.device) streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True) generation_kwargs = { "input_ids": input_ids, "streamer": streamer, "max_new_tokens": max_new_tokens, "do_sample": True, "top_p": top_p, "top_k": top_k, "temperature": temperature, "num_beams": 1, "repetition_penalty": repetition_penalty, } t = Thread(target=model.generate, kwargs=generation_kwargs) t.start() outputs = [] for new_text in streamer: outputs.append(new_text) yield "".join(outputs) final_response = "".join(outputs) yield final_response if is_tts and voice: output_file = asyncio.run(text_to_speech(final_response, voice)) yield gr.Audio(output_file, autoplay=True) # ----------------------------------------------------------------------------- # Gradio Chat Interface Setup and Launch # ----------------------------------------------------------------------------- demo = gr.ChatInterface( fn=generate, additional_inputs=[ gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS), gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6), gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9), gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50), gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2), ], examples=[ ["@tts1 Who is Nikola Tesla, and why did he die?"], ["@3d A birthday cupcake with cherry"], [{"text": "summarize the letter", "files": ["examples/1.png"]}], ["@image Chocolate dripping from a donut against a yellow background, in the style of brocore, hyper-realistic"], ["Write a Python function to check if a number is prime."], ["@tts2 What causes rainbows to form?"], ], cache_examples=False, type="messages", description=DESCRIPTION, css=css, fill_height=True, textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image"], file_count="multiple"), stop_btn="Stop Generation", multimodal=True, ) # Ensure the static folder exists if not os.path.exists("static"): os.makedirs("static") # Mount the static folder onto the FastAPI app so that GLB files are served. from fastapi.staticfiles import StaticFiles demo.app.mount("/static", StaticFiles(directory="static"), name="static") if __name__ == "__main__": # Launch without the unsupported static_dirs parameter. demo.queue(max_size=20).launch(share=True)