prithivMLmods committed on
Commit fca22b9 · verified · 1 Parent(s): 853899c

Update app.py

Files changed (1)
  1. app.py +102 -63
app.py CHANGED
@@ -1,23 +1,48 @@
 import os
-import time
+from collections.abc import Iterator
 from threading import Thread
 import gradio as gr
+import spaces
 import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
-from transformers.image_utils import load_image
 import edge_tts
 import asyncio
-from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
+from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
+from transformers.image_utils import load_image
+import time
+
+DESCRIPTION = """
+# QwQ Edge 💬
+"""
 
-# Load models
-MODEL_ID = "prithivMLmods/FastThink-0.5B-Tiny"
-tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
-model = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map="auto", torch_dtype=torch.bfloat16).eval()
+css = '''
+h1 {
+  text-align: center;
+  display: block;
+}
 
-# For multimodal OCR processing
-OCR_MODEL_ID = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
-ocr_processor = AutoProcessor.from_pretrained(OCR_MODEL_ID, trust_remote_code=True)
-ocr_model = Qwen2VLForConditionalGeneration.from_pretrained(OCR_MODEL_ID, trust_remote_code=True, torch_dtype=torch.float16).to("cuda").eval()
+#duplicate-button {
+  margin: auto;
+  color: #fff;
+  background: #1565c0;
+  border-radius: 100vh;
+}
+'''
+
+MAX_MAX_NEW_TOKENS = 2048
+DEFAULT_MAX_NEW_TOKENS = 1024
+MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
+
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+model_id = "prithivMLmods/FastThink-0.5B-Tiny"
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+model = AutoModelForCausalLM.from_pretrained(
+    model_id,
+    device_map="auto",
+    torch_dtype=torch.bfloat16,
+)
+model.eval()
 
 TTS_VOICES = [
     "en-US-JennyNeural",  # @tts1
@@ -30,7 +55,14 @@ TTS_VOICES = [
     "en-US-TonyNeural",  # @tts8
 ]
 
-# Handle text-to-speech conversion
+MODEL_ID = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
+processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
+model_m = Qwen2VLForConditionalGeneration.from_pretrained(
+    MODEL_ID,
+    trust_remote_code=True,
+    torch_dtype=torch.float16
+).to(device).eval()
+
 async def text_to_speech(text: str, voice: str, output_file="output.mp3"):
     """Convert text to speech using Edge TTS and save as MP3"""
     communicate = edge_tts.Communicate(text, voice)
@@ -39,25 +71,27 @@ async def text_to_speech(text: str, voice: str, output_file="output.mp3"):
 
 @spaces.GPU
 def generate(
-    input_dict,
-    history,
-    max_new_tokens: int = 1024,
-    temperature: float = 0.6,
-    top_p: float = 0.9,
-    top_k: int = 50,
-    repetition_penalty: float = 1.2
+    input_dict: dict,
+    chat_history: list[dict],
+    max_new_tokens: int = 1024,
+    temperature: float = 0.6,
+    top_p: float = 0.9,
+    top_k: int = 50,
+    repetition_penalty: float = 1.2,
 ):
-    """Generates chatbot response and handles TTS requests with multimodal support"""
-    text = input_dict.get("text", "")
+    """Generates chatbot response and handles TTS requests with multimodal input support"""
+    text = input_dict["text"]
     files = input_dict.get("files", [])
-
-    # Handle multimodal OCR processing
-    if files:
+
+    # Check if input includes image(s)
+    if len(files) > 1:
         images = [load_image(image) for image in files]
+    elif len(files) == 1:
+        images = [load_image(files[0])]
     else:
         images = []
-
-    # Check if the message is TTS request
+
+    # Check if message is for TTS
     tts_prefix = "@tts"
     is_tts = any(text.strip().lower().startswith(f"{tts_prefix}{i}") for i in range(1, 9))
     voice_index = next((i for i in range(1, 9) if text.strip().lower().startswith(f"{tts_prefix}{i}")), None)
@@ -69,29 +103,36 @@ def generate(
         voice = None
         text = text.replace(tts_prefix, "").strip()
 
-    # If images are provided, combine image and text for the prompt
+    conversation = [*chat_history, {"role": "user", "content": text}]
+
     if images:
-        # Prepare images as part of the conversation
+        # Process multimodal input
         messages = [
-            {
-                "role": "user",
-                "content": [
-                    *[{"type": "image", "image": image} for image in images],
-                    {"type": "text", "text": text},
-                ],
-            }
+            {"role": "user", "content": [
+                *[{"type": "image", "image": image} for image in images],
+                {"type": "text", "text": text},
+            ]}
         ]
-        prompt = ocr_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-        inputs = ocr_processor(
-            text=[prompt],
-            images=images,
-            return_tensors="pt",
-            padding=True,
-        ).to("cuda")
+        prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+        inputs = processor(text=[prompt], images=images, return_tensors="pt", padding=True).to(device)
+
+        # Handle generation for multimodal input
+        streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
+        generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=max_new_tokens)
+
+        thread = Thread(target=model_m.generate, kwargs=generation_kwargs)
+        thread.start()
+
+        buffer = ""
+        yield "Thinking..."
+        for new_text in streamer:
+            buffer += new_text
+            buffer = buffer.replace("<|im_end|>", "")
+            time.sleep(0.01)
+            yield buffer
 
     else:
-        # Normal text-only input
-        conversation = [*history, {"role": "user", "content": text}]
+        # Process text-only input
         input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
         if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
             input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
@@ -110,38 +151,32 @@ def generate(
         num_beams=1,
         repetition_penalty=repetition_penalty,
     )
-
-    # Start generation in a separate thread
     t = Thread(target=model.generate, kwargs=generate_kwargs)
     t.start()
 
-    # Collect generated text
     outputs = []
     for text in streamer:
         outputs.append(text)
         yield "".join(outputs)
+
     final_response = "".join(outputs)
 
-    # Handle text-to-speech
-    if is_tts and voice:
-        output_file = asyncio.run(text_to_speech(final_response, voice))
-        yield gr.Audio(output_file, autoplay=True)  # Return playable audio
-    else:
-        yield final_response  # Return text response
+    if is_tts and voice:
+        output_file = asyncio.run(text_to_speech(final_response, voice))
+        yield gr.Audio(output_file, autoplay=True)  # Return playable audio
+    else:
+        yield final_response  # Return text response
 
-# Gradio Interface
-demo = gr.Interface(
+
+demo = gr.ChatInterface(
     fn=generate,
-    inputs=[
-        gr.MultimodalTextbox(label="Query Input", file_types=["image"], file_count="multiple"),  # Multimodal input
-        gr.Textbox(label="Chat History", value="", placeholder="Previous conversation history"),
+    additional_inputs=[
         gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS),
         gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6),
         gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
         gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50),
         gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2),
     ],
-    outputs=["text", "audio"],
     examples=[
         ["@tts1 Who is Nikola Tesla, and why did he die?"],
         ["A train travels 60 kilometers per hour. If it travels for 5 hours, how far will it travel in total?"],
@@ -150,11 +185,15 @@ demo = gr.Interface(
         ["Rewrite the following sentence in passive voice: 'The dog chased the cat.'"],
         ["@tts5 What is the capital of France?"],
     ],
-    stop_btn="Stop Generation",
-    description="QwQ Edge: A Chatbot with Text-to-Speech and Multimodal Support",
+    cache_examples=False,
+    type="messages",
+    description=DESCRIPTION,
    css=css,
    fill_height=True,
+    textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image"], file_count="multiple"),
+    stop_btn="Stop Generation",
+    multimodal=True,
 )
 
 if __name__ == "__main__":
-    demo.launch()
+    demo.queue(max_size=20).launch()
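
The @tts1–@tts8 prefix routing in generate() can be exercised on its own. A minimal sketch, assuming the collapsed hunk lines resolve voice_index to TTS_VOICES[voice_index - 1] and strip the matched prefix; only voices 1 and 8 are visible in the diff, so VISIBLE_VOICES is an illustrative subset, not app.py's full table:

tts_prefix = "@tts"
# Only @tts1 and @tts8 appear in the hunk; the remaining entries are collapsed.
VISIBLE_VOICES = {1: "en-US-JennyNeural", 8: "en-US-TonyNeural"}

def route_tts(text: str) -> tuple[str | None, str]:
    """Mirror the prefix check in generate(): pick a voice and strip the prefix."""
    lowered = text.strip().lower()
    voice_index = next((i for i in range(1, 9) if lowered.startswith(f"{tts_prefix}{i}")), None)
    if voice_index is None:
        return None, text.strip()            # plain chat message, no TTS
    voice = VISIBLE_VOICES.get(voice_index)  # app.py indexes TTS_VOICES[voice_index - 1]
    cleaned = text.strip()[len(f"{tts_prefix}{voice_index}"):].strip()
    return voice, cleaned

print(route_tts("@tts1 Who is Nikola Tesla?"))
# -> ('en-US-JennyNeural', 'Who is Nikola Tesla?')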
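The body of text_to_speech is truncated in both hunks after the Communicate call. A self-contained sketch of the likely completion, using edge-tts's Communicate.save() coroutine; the save-and-return lines are an assumption, not the committed code:

import asyncio
import edge_tts

async def text_to_speech(text: str, voice: str, output_file: str = "output.mp3") -> str:
    """Convert text to speech with Edge TTS and write an MP3, as in app.py."""
    communicate = edge_tts.Communicate(text, voice)
    await communicate.save(output_file)  # assumed completion of the truncated body
    return output_file

if __name__ == "__main__":
    path = asyncio.run(text_to_speech("Hello from QwQ Edge.", "en-US-JennyNeural"))
    print("saved", path)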
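The new multimodal branch can also be reproduced outside Gradio. A minimal sketch using the same model ID and calls as the diff, assuming a CUDA device; "sample.png" is a placeholder path:

from threading import Thread
import torch
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration, TextIteratorStreamer
from transformers.image_utils import load_image

MODEL_ID = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
model_m = Qwen2VLForConditionalGeneration.from_pretrained(
    MODEL_ID, trust_remote_code=True, torch_dtype=torch.float16
).to("cuda").eval()

image = load_image("sample.png")  # placeholder path
messages = [{"role": "user", "content": [
    {"type": "image", "image": image},
    {"type": "text", "text": "OCR the text in this image."},
]}]
prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = processor(text=[prompt], images=[image], return_tensors="pt", padding=True).to("cuda")

# Stream tokens off the main thread, as generate() does above.
streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
Thread(target=model_m.generate, kwargs=dict(inputs, streamer=streamer, max_new_tokens=256)).start()
for chunk in streamer:
    print(chunk, end="", flush=True)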