enpeizhao committed on
Commit
6f0970c
·
1 Parent(s): 205e158
Files changed (1) hide show
  1. app.py +10 -2
app.py CHANGED
@@ -152,11 +152,13 @@ def process_media(media, prompt):
152
  if isinstance(media, Image.Image):
153
  # Single image
154
  frames = [media]
 
155
  elif isinstance(media, str) and os.path.exists(media):
156
  # Video path, extract frames
157
  frames = extract_frames(media, max_frames=8)
158
  if not frames:
159
  return "No frames extracted from video"
 
160
  else:
161
  return "Unsupported media type"
162
 
@@ -174,7 +176,10 @@ def process_media(media, prompt):
174
  try:
175
  # Qwen-VL style processing
176
  text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
177
- inputs = processor(text=text, videos=frames, return_tensors="pt")
 
 
 
178
  inputs = inputs.to(device)
179
  with torch.no_grad():
180
  generated_ids = model.generate(**inputs, max_new_tokens=512)
@@ -187,7 +192,10 @@ def process_media(media, prompt):
187
  print(f"Qwen-VL style processing failed: {e}")
188
  first_frame = frames[0]
189
  try:
190
- inputs = processor(text=prompt, videos=[first_frame], return_tensors="pt").to(device)
 
 
 
191
  with torch.no_grad():
192
  outputs = model.generate(**inputs, max_new_tokens=100)
193
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
152
  if isinstance(media, Image.Image):
153
  # Single image
154
  frames = [media]
155
+ is_image = True
156
  elif isinstance(media, str) and os.path.exists(media):
157
  # Video path, extract frames
158
  frames = extract_frames(media, max_frames=8)
159
  if not frames:
160
  return "No frames extracted from video"
161
+ is_image = False
162
  else:
163
  return "Unsupported media type"
164
 
 
176
  try:
177
  # Qwen-VL style processing
178
  text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
179
+ if is_image:
180
+ inputs = processor(text=text, images=frames, return_tensors="pt")
181
+ else:
182
+ inputs = processor(text=text, videos=frames, return_tensors="pt")
183
  inputs = inputs.to(device)
184
  with torch.no_grad():
185
  generated_ids = model.generate(**inputs, max_new_tokens=512)
 
192
  print(f"Qwen-VL style processing failed: {e}")
193
  first_frame = frames[0]
194
  try:
195
+ if is_image:
196
+ inputs = processor(text=prompt, images=[first_frame], return_tensors="pt").to(device)
197
+ else:
198
+ inputs = processor(text=prompt, videos=[first_frame], return_tensors="pt").to(device)
199
  with torch.no_grad():
200
  outputs = model.generate(**inputs, max_new_tokens=100)
201
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)