prithivMLmods committed
Commit a85c4cf · verified
1 Parent(s): eab6c4d

Update app.py

Files changed (1): app.py +330 -329
app.py CHANGED
@@ -1,330 +1,331 @@

The only substantive change in this commit is the text-only model selection. Removed:

- model_id = "prithivMLmods/FastThink-0.5B-Tiny"

Added:

+ #model_id = "prithivMLmods/FastThink-0.5B-Tiny"
+ model_id = "prithivMLmods/SmolLM2_135M_Grpo_Gsm8k"

Every other line of app.py is unchanged.
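For reference, a minimal smoke test of the newly selected text-only model. This snippet is not part of the commit; it assumes the repo id from the diff above and mirrors the loading call and sampling defaults used in app.py:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Same model id and loading arguments as in the updated app.py.
model_id = "prithivMLmods/SmolLM2_135M_Grpo_Gsm8k"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
).eval()

# One-turn conversation through the same chat-template call the app uses.
conversation = [{"role": "user", "content": "What is 17 * 23?"}]
input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt").to(model.device)
output_ids = model.generate(
    input_ids,
    max_new_tokens=128,
    do_sample=True,
    temperature=0.6,
    top_p=0.9,
    top_k=50,
    repetition_penalty=1.2,
)
print(tokenizer.decode(output_ids[0][input_ids.shape[1]:], skip_special_tokens=True))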
Full contents of app.py after this commit:

import os
import random
import uuid
import json
import time
import asyncio
from threading import Thread

import gradio as gr
import spaces
import torch
import numpy as np
from PIL import Image
import edge_tts

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
    Qwen2VLForConditionalGeneration,
    AutoProcessor,
)
from transformers.image_utils import load_image
from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler


DESCRIPTION = """
# QwQ Edge 💬
"""

css = '''
h1 {
  text-align: center;
  display: block;
}

#duplicate-button {
  margin: auto;
  color: #fff;
  background: #1565c0;
  border-radius: 100vh;
}
'''

MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load text-only model and tokenizer
#model_id = "prithivMLmods/FastThink-0.5B-Tiny"
model_id = "prithivMLmods/SmolLM2_135M_Grpo_Gsm8k"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
model.eval()

TTS_VOICES = [
    "en-US-JennyNeural",  # @tts1
    "en-US-GuyNeural",    # @tts2
]

MODEL_ID = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
model_m = Qwen2VLForConditionalGeneration.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    torch_dtype=torch.float16
).to("cuda").eval()

async def text_to_speech(text: str, voice: str, output_file="output.mp3"):
    """Convert text to speech using Edge TTS and save as MP3"""
    communicate = edge_tts.Communicate(text, voice)
    await communicate.save(output_file)
    return output_file

def clean_chat_history(chat_history):
    """
    Filter out any chat entries whose "content" is not a string.
    This helps prevent errors when concatenating previous messages.
    """
    cleaned = []
    for msg in chat_history:
        if isinstance(msg, dict) and isinstance(msg.get("content"), str):
            cleaned.append(msg)
    return cleaned

# Environment variables and parameters for Stable Diffusion XL
MODEL_ID_SD = os.getenv("MODEL_VAL_PATH")  # SDXL model repository path via env variable
MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
BATCH_SIZE = int(os.getenv("BATCH_SIZE", "1"))  # For batched image generation

# Load the SDXL pipeline
sd_pipe = StableDiffusionXLPipeline.from_pretrained(
    MODEL_ID_SD,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    use_safetensors=True,
    add_watermarker=False,
).to(device)
sd_pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(sd_pipe.scheduler.config)

# Ensure that the text encoder is in half-precision if using CUDA.
if torch.cuda.is_available():
    sd_pipe.text_encoder = sd_pipe.text_encoder.half()

# Optional: compile the model for speedup if enabled
if USE_TORCH_COMPILE:
    sd_pipe.compile()

# Optional: offload parts of the model to CPU if needed
if ENABLE_CPU_OFFLOAD:
    sd_pipe.enable_model_cpu_offload()

MAX_SEED = np.iinfo(np.int32).max

def save_image(img: Image.Image) -> str:
    """Save a PIL image with a unique filename and return the path."""
    unique_name = str(uuid.uuid4()) + ".png"
    img.save(unique_name)
    return unique_name

def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    return seed

@spaces.GPU(duration=60, enable_queue=True)
def generate_image_fn(
    prompt: str,
    negative_prompt: str = "",
    use_negative_prompt: bool = False,
    seed: int = 1,
    width: int = 1024,
    height: int = 1024,
    guidance_scale: float = 3,
    num_inference_steps: int = 25,
    randomize_seed: bool = False,
    use_resolution_binning: bool = True,
    num_images: int = 1,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate images using the SDXL pipeline."""
    seed = int(randomize_seed_fn(seed, randomize_seed))
    generator = torch.Generator(device=device).manual_seed(seed)

    options = {
        "prompt": [prompt] * num_images,
        "negative_prompt": [negative_prompt] * num_images if use_negative_prompt else None,
        "width": width,
        "height": height,
        "guidance_scale": guidance_scale,
        "num_inference_steps": num_inference_steps,
        "generator": generator,
        "output_type": "pil",
    }
    if use_resolution_binning:
        options["use_resolution_binning"] = True

    images = []
    # Process in batches
    for i in range(0, num_images, BATCH_SIZE):
        batch_options = options.copy()
        batch_options["prompt"] = options["prompt"][i:i+BATCH_SIZE]
        if "negative_prompt" in batch_options and batch_options["negative_prompt"] is not None:
            batch_options["negative_prompt"] = options["negative_prompt"][i:i+BATCH_SIZE]
        # Wrap the pipeline call in autocast if using CUDA
        if device.type == "cuda":
            with torch.autocast("cuda", dtype=torch.float16):
                outputs = sd_pipe(**batch_options)
        else:
            outputs = sd_pipe(**batch_options)
        images.extend(outputs.images)
    image_paths = [save_image(img) for img in images]
    return image_paths, seed

@spaces.GPU
def generate(
    input_dict: dict,
    chat_history: list[dict],
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
):
    """
    Generates chatbot responses with support for multimodal input, TTS, and image generation.
    Special commands:
    - "@tts1" or "@tts2": triggers text-to-speech.
    - "@image": triggers image generation using the SDXL pipeline.
    """
    text = input_dict["text"]
    files = input_dict.get("files", [])

    if text.strip().lower().startswith("@image"):
        # Remove the "@image" tag and use the rest as prompt
        prompt = text[len("@image"):].strip()
        yield "Generating image..."
        image_paths, used_seed = generate_image_fn(
            prompt=prompt,
            negative_prompt="",
            use_negative_prompt=False,
            seed=1,
            width=1024,
            height=1024,
            guidance_scale=3,
            num_inference_steps=25,
            randomize_seed=True,
            use_resolution_binning=True,
            num_images=1,
        )
        # Yield the generated image so that the chat interface displays it.
        yield gr.Image(image_paths[0])
        return  # Exit early

    tts_prefix = "@tts"
    is_tts = any(text.strip().lower().startswith(f"{tts_prefix}{i}") for i in range(1, 3))
    voice_index = next((i for i in range(1, 3) if text.strip().lower().startswith(f"{tts_prefix}{i}")), None)

    if is_tts and voice_index:
        voice = TTS_VOICES[voice_index - 1]
        text = text.replace(f"{tts_prefix}{voice_index}", "").strip()
        # Clear previous chat history for a fresh TTS request.
        conversation = [{"role": "user", "content": text}]
    else:
        voice = None
        # Remove any stray @tts tags and build the conversation history.
        text = text.replace(tts_prefix, "").strip()
        conversation = clean_chat_history(chat_history)
        conversation.append({"role": "user", "content": text})

    if files:
        if len(files) > 1:
            images = [load_image(image) for image in files]
        elif len(files) == 1:
            images = [load_image(files[0])]
        else:
            images = []
        messages = [{
            "role": "user",
            "content": [
                *[{"type": "image", "image": image} for image in images],
                {"type": "text", "text": text},
            ]
        }]
        prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = processor(text=[prompt], images=images, return_tensors="pt", padding=True).to("cuda")
        streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
        generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens}
        thread = Thread(target=model_m.generate, kwargs=generation_kwargs)
        thread.start()

        buffer = ""
        yield "Thinking..."
        for new_text in streamer:
            buffer += new_text
            buffer = buffer.replace("<|im_end|>", "")
            time.sleep(0.01)
            yield buffer
    else:
        input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
        if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
            input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
            gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
        input_ids = input_ids.to(model.device)
        streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
        generation_kwargs = {
            "input_ids": input_ids,
            "streamer": streamer,
            "max_new_tokens": max_new_tokens,
            "do_sample": True,
            "top_p": top_p,
            "top_k": top_k,
            "temperature": temperature,
            "num_beams": 1,
            "repetition_penalty": repetition_penalty,
        }
        t = Thread(target=model.generate, kwargs=generation_kwargs)
        t.start()

        outputs = []
        for new_text in streamer:
            outputs.append(new_text)
            yield "".join(outputs)

        final_response = "".join(outputs)
        yield final_response

        # If TTS was requested, convert the final response to speech.
        if is_tts and voice:
            output_file = asyncio.run(text_to_speech(final_response, voice))
            yield gr.Audio(output_file, autoplay=True)

demo = gr.ChatInterface(
    fn=generate,
    additional_inputs=[
        gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS),
        gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6),
        gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
        gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50),
        gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2),
    ],
    examples=[
        ["@tts1 Who is Nikola Tesla, and why did he die?"],
        [{"text": "Extract JSON from the image", "files": ["examples/document.jpg"]}],
        [{"text": "summarize the letter", "files": ["examples/1.png"]}],
        ["@image Chocolate dripping from a donut against a yellow background, in the style of brocore, hyper-realistic"],
        ["Write a Python function to check if a number is prime."],
        ["@tts2 What causes rainbows to form?"],
    ],
    cache_examples=False,
    type="messages",
    description=DESCRIPTION,
    css=css,
    fill_height=True,
    textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image"], file_count="multiple"),
    stop_btn="Stop Generation",
    multimodal=True,
)

if __name__ == "__main__":
    # To create a public link, set share=True in launch().
    demo.queue(max_size=20).launch(share=True)
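Running the updated Space locally requires a CUDA GPU (the Qwen2-VL model is moved to "cuda" unconditionally) and an SDXL checkpoint path in the MODEL_VAL_PATH environment variable. A minimal launch sketch, assuming the Space's dependencies are installed; the SDXL repo id below is only an example value, not something set by this commit:

import os
import runpy

# Required: SDXL checkpoint for the "@image" command (example value).
os.environ.setdefault("MODEL_VAL_PATH", "stabilityai/stable-diffusion-xl-base-1.0")
# Optional knobs read by app.py, shown with their in-code defaults.
os.environ.setdefault("MAX_INPUT_TOKEN_LENGTH", "4096")
os.environ.setdefault("MAX_IMAGE_SIZE", "4096")
os.environ.setdefault("USE_TORCH_COMPILE", "0")
os.environ.setdefault("ENABLE_CPU_OFFLOAD", "0")
os.environ.setdefault("BATCH_SIZE", "1")

# Execute app.py as a script so its __main__ guard fires and the Gradio queue launches.
runpy.run_path("app.py", run_name="__main__")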