prithivMLmods committed on
Commit
83a0174
·
verified ·
1 Parent(s): f74b154

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -23
app.py CHANGED
@@ -1,5 +1,4 @@
1
  import os
2
- from collections.abc import Iterator
3
  from threading import Thread
4
  import gradio as gr
5
  import spaces
@@ -35,7 +34,7 @@ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
35
 
36
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
37
 
38
- # Load the text-only model and tokenizer
39
  model_id = "prithivMLmods/FastThink-0.5B-Tiny"
40
  tokenizer = AutoTokenizer.from_pretrained(model_id)
41
  model = AutoModelForCausalLM.from_pretrained(
@@ -54,7 +53,7 @@ TTS_VOICES = [
54
  "en-US-JasonNeural", # @tts6
55
  ]
56
 
57
- # Load the multimodal (OCR) model and processor
58
  MODEL_ID = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
59
  processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
60
  model_m = Qwen2VLForConditionalGeneration.from_pretrained(
@@ -69,6 +68,18 @@ async def text_to_speech(text: str, voice: str, output_file="output.mp3"):
69
  await communicate.save(output_file)
70
  return output_file
71
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  @spaces.GPU
73
  def generate(
74
  input_dict: dict,
@@ -80,14 +91,14 @@ def generate(
80
  repetition_penalty: float = 1.2,
81
  ):
82
  """
83
- Generates chatbot response and handles TTS requests with multimodal input support.
84
- If the query starts with a TTS command (e.g. '@tts1'), the chat history is cleared
85
- to avoid non-text responses (like Audio) interfering with template rendering.
86
  """
87
  text = input_dict["text"]
88
  files = input_dict.get("files", [])
89
 
90
- # Check if input includes image(s)
91
  if len(files) > 1:
92
  images = [load_image(image) for image in files]
93
  elif len(files) == 1:
@@ -95,33 +106,35 @@ def generate(
95
  else:
96
  images = []
97
 
98
- # Check if the message is for TTS
99
  tts_prefix = "@tts"
100
  is_tts = any(text.strip().lower().startswith(f"{tts_prefix}{i}") for i in range(1, 7))
101
  voice_index = next((i for i in range(1, 7) if text.strip().lower().startswith(f"{tts_prefix}{i}")), None)
102
-
103
  if is_tts and voice_index:
104
  voice = TTS_VOICES[voice_index - 1]
105
  text = text.replace(f"{tts_prefix}{voice_index}", "").strip()
106
- # Clear conversation history to avoid issues with non-text outputs.
107
  conversation = [{"role": "user", "content": text}]
108
  else:
109
  voice = None
110
  text = text.replace(tts_prefix, "").strip()
111
- conversation = [*chat_history, {"role": "user", "content": text}]
 
 
112
 
113
- # If there are images, process multimodal input
114
  if images:
115
- messages = [
116
- {"role": "user", "content": [
 
117
  *[{"type": "image", "image": image} for image in images],
118
  {"type": "text", "text": text},
119
- ]}
120
- ]
121
  prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
122
  inputs = processor(text=[prompt], images=images, return_tensors="pt", padding=True).to("cuda")
123
 
124
- # Handle generation for multimodal input using model_m
125
  streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
126
  generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=max_new_tokens)
127
  thread = Thread(target=model_m.generate, kwargs=generation_kwargs)
@@ -134,9 +147,8 @@ def generate(
134
  buffer = buffer.replace("<|im_end|>", "")
135
  time.sleep(0.01)
136
  yield buffer
137
-
138
  else:
139
- # Process text-only input using model
140
  input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
141
  if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
142
  input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
@@ -155,7 +167,7 @@ def generate(
155
  num_beams=1,
156
  repetition_penalty=repetition_penalty,
157
  )
158
- t = Thread(target=model.generate, kwargs=generate_kwargs)
159
  t.start()
160
 
161
  outputs = []
@@ -164,11 +176,9 @@ def generate(
164
  yield "".join(outputs)
165
 
166
  final_response = "".join(outputs)
167
-
168
- # Yield text response first.
169
  yield final_response
170
 
171
- # If TTS was requested, yield audio output separately.
172
  if is_tts and voice:
173
  output_file = asyncio.run(text_to_speech(final_response, voice))
174
  yield gr.Audio(output_file, autoplay=True)
 
1
  import os
 
2
  from threading import Thread
3
  import gradio as gr
4
  import spaces
 
34
 
35
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
36
 
37
+ # Text-only model and tokenizer
38
  model_id = "prithivMLmods/FastThink-0.5B-Tiny"
39
  tokenizer = AutoTokenizer.from_pretrained(model_id)
40
  model = AutoModelForCausalLM.from_pretrained(
 
53
  "en-US-JasonNeural", # @tts6
54
  ]
55
 
56
+ # Multimodal (OCR) model and processor
57
  MODEL_ID = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
58
  processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
59
  model_m = Qwen2VLForConditionalGeneration.from_pretrained(
 
68
  await communicate.save(output_file)
69
  return output_file
70
 
71
def clean_chat_history(chat_history):
    """
    Return only the chat messages whose content is plain text.

    An entry is kept when it is a dict and its "content" value is a str.
    Anything else (such as tuple or Audio entries) is dropped so that it
    cannot end up in a text-only conversation. Kept entries are the same
    dict objects, in their original order, collected into a new list.

    Args:
        chat_history: Iterable of history entries, or None.

    Returns:
        A new list with the text-only message dicts; an empty list when
        chat_history is None or contains no text messages.
    """
    # A missing history is treated like an empty one instead of raising
    # TypeError when iterated.
    if chat_history is None:
        return []
    return [
        msg
        for msg in chat_history
        if isinstance(msg, dict) and isinstance(msg.get("content"), str)
    ]
82
+
83
  @spaces.GPU
84
  def generate(
85
  input_dict: dict,
 
91
  repetition_penalty: float = 1.2,
92
  ):
93
  """
94
+ Generates a chatbot response and handles TTS requests with multimodal input support.
95
+ If the user’s query begins with an @tts command, previous chat history is ignored
96
+ (clearing any non-text outputs). Otherwise, the chat history is cleaned to include only text.
97
  """
98
  text = input_dict["text"]
99
  files = input_dict.get("files", [])
100
 
101
+ # Determine if images are provided
102
  if len(files) > 1:
103
  images = [load_image(image) for image in files]
104
  elif len(files) == 1:
 
106
  else:
107
  images = []
108
 
109
+ # Check for TTS prefix
110
  tts_prefix = "@tts"
111
  is_tts = any(text.strip().lower().startswith(f"{tts_prefix}{i}") for i in range(1, 7))
112
  voice_index = next((i for i in range(1, 7) if text.strip().lower().startswith(f"{tts_prefix}{i}")), None)
113
+
114
  if is_tts and voice_index:
115
  voice = TTS_VOICES[voice_index - 1]
116
  text = text.replace(f"{tts_prefix}{voice_index}", "").strip()
117
+ # Clear any previous chat history when using TTS to avoid type errors
118
  conversation = [{"role": "user", "content": text}]
119
  else:
120
  voice = None
121
  text = text.replace(tts_prefix, "").strip()
122
+ # Clean the chat history to include only messages with string content
123
+ conversation = clean_chat_history(chat_history)
124
+ conversation.append({"role": "user", "content": text})
125
 
126
+ # Multimodal branch if images are provided
127
  if images:
128
+ messages = [{
129
+ "role": "user",
130
+ "content": [
131
  *[{"type": "image", "image": image} for image in images],
132
  {"type": "text", "text": text},
133
+ ]
134
+ }]
135
  prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
136
  inputs = processor(text=[prompt], images=images, return_tensors="pt", padding=True).to("cuda")
137
 
 
138
  streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
139
  generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=max_new_tokens)
140
  thread = Thread(target=model_m.generate, kwargs=generation_kwargs)
 
147
  buffer = buffer.replace("<|im_end|>", "")
148
  time.sleep(0.01)
149
  yield buffer
 
150
  else:
151
+ # Text-only branch using the text model
152
  input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
153
  if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
154
  input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
 
167
  num_beams=1,
168
  repetition_penalty=repetition_penalty,
169
  )
170
# `generation_kwargs` is only assigned in the image branch; the text-only
# branch must pass its own kwargs dict (the pre-change version of this line
# used `generate_kwargs`), otherwise this line fails with an unbound local.
t = Thread(target=model.generate, kwargs=generate_kwargs)
171
  t.start()
172
 
173
  outputs = []
 
176
  yield "".join(outputs)
177
 
178
  final_response = "".join(outputs)
179
+ # Yield text response first
 
180
  yield final_response
181
 
 
182
  if is_tts and voice:
183
  output_file = asyncio.run(text_to_speech(final_response, voice))
184
  yield gr.Audio(output_file, autoplay=True)