# app.py
import gradio as gr
import torch
from transformers import AutoModel, AutoProcessor, AutoTokenizer
from PIL import Image
import cv2
import tempfile
import os
import subprocess
# Load your custom VLM model from Hugging Face
MODEL_ID = "enpeizhao/qwen2_5-3b-instruct-trl-sft-vlm-odd-12-nf4-merged"
BASE_MODEL_ID = "Qwen/Qwen2.5-3B-Instruct"
device = "cuda" if torch.cuda.is_available() else "cpu"

def load_model():
    """Load the model and processor from Hugging Face"""
    try:
        model = AutoModel.from_pretrained(MODEL_ID, trust_remote_code=True).to(device)
        processor = AutoProcessor.from_pretrained(BASE_MODEL_ID, trust_remote_code=True)
        tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, trust_remote_code=True)
        return model, processor, tokenizer
    except Exception as e:
        print(f"Error loading model: {e}")
        return None, None, None
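
# NOTE (assumption): depending on the checkpoint, AutoModel may load a backbone
# without a generation head. If model.generate() fails at inference time, loading the
# checkpoint with AutoModelForVision2Seq (or its specific *ForConditionalGeneration
# class) is a possible alternative; this is a suggestion, not verified for this model.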
# Load model at startup
model, processor, tokenizer = load_model()

def convert_video_format(input_path, output_path):
    """Convert video to MP4 format using ffmpeg"""
    try:
        cmd = [
            "ffmpeg",
            "-i", input_path,
            "-c:v", "libx264",
            "-c:a", "aac",
            "-strict", "experimental",
            "-preset", "fast",
            "-y",  # Overwrite output file
            output_path
        ]
        result = subprocess.run(cmd, capture_output=True, text=True)
        if result.returncode != 0:
            print(f"FFmpeg error: {result.stderr}")
            return False
        return True
    except Exception as e:
        print(f"Error converting video: {e}")
        return False
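
# NOTE: this assumes an ffmpeg binary is available on PATH. On a Hugging Face Space
# that usually means adding a line "ffmpeg" to packages.txt; nothing in this script
# installs or checks for it.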

def extract_frames(video_path, max_frames=10):
    """Extract frames from video"""
    try:
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            return []
        frames = []
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        # Calculate frame indices to sample
        if total_frames <= max_frames:
            frame_indices = range(total_frames)
        else:
            frame_indices = [int(i * total_frames / max_frames) for i in range(max_frames)]
        for i in frame_indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, i)
            ret, frame = cap.read()
            if ret:
                # Convert BGR to RGB
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frames.append(Image.fromarray(frame_rgb))
        cap.release()
        return frames
    except Exception as e:
        print(f"Error extracting frames: {e}")
        return []
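
# Illustrative usage (hypothetical file name):
#   frames = extract_frames("sample.mp4", max_frames=8)
# returns up to 8 uniformly sampled RGB PIL.Image frames from the clip.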

def process_video_frames(video_path, prompt):
    """
    Process video frames with your VLM model
    """
    if model is None or processor is None or tokenizer is None:
        return "Model not loaded properly"
    try:
        # Extract frames from video
        frames = extract_frames(video_path, max_frames=8)
        if not frames:
            return "No frames extracted from video"
        # Prepare conversation messages
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "video", "video": frames},
                    {"type": "text", "text": prompt},
                ],
            }
        ]
        # Process inputs (this is model-specific)
        try:
            # Try Qwen-VL style processing first
            text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            inputs = processor(text=text, videos=frames, return_tensors="pt")
            inputs = inputs.to(device)
            # Generate response
            with torch.no_grad():
                generated_ids = model.generate(**inputs, max_new_tokens=512)
            # Decode output, keeping only the newly generated tokens
            generated_ids = [
                output_ids[len(input_ids):] for input_ids, output_ids in zip(inputs.input_ids, generated_ids)
            ]
            response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
            return response
        except Exception as e:
            # Fallback to simpler processing
            print(f"Qwen-VL style processing failed: {e}")
            # Process first frame with text prompt
            first_frame = frames[0]
            inputs = processor(text=prompt, videos=[first_frame], return_tensors="pt").to(device)
            # Generate response
            with torch.no_grad():
                outputs = model.generate(**inputs, max_new_tokens=100)
            # Decode output
            response = tokenizer.decode(outputs[0], skip_special_tokens=True)
            return f"[Processed first frame only] {response}"
    except Exception as e:
        return f"Error processing video: {str(e)}"

def process_media(media, prompt):
    """
    Generic handler that accepts either an image (PIL.Image) or a video (file path).
    """
    if model is None or processor is None or tokenizer is None:
        return "Model not loaded properly"
    # Determine the input type
    if isinstance(media, Image.Image):
        # Single image
        frames = [media]
    elif isinstance(media, str) and os.path.exists(media):
        # Video path: extract frames
        frames = extract_frames(media, max_frames=8)
        if not frames:
            return "No frames extracted from video"
    else:
        return "Unsupported media type"
    # Build the conversation messages
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "video", "video": frames},
                {"type": "text", "text": prompt},
            ],
        }
    ]
    try:
        # Qwen-VL style processing
        text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = processor(text=text, videos=frames, return_tensors="pt")
        inputs = inputs.to(device)
        with torch.no_grad():
            generated_ids = model.generate(**inputs, max_new_tokens=512)
        generated_ids = [
            output_ids[len(input_ids):] for input_ids, output_ids in zip(inputs.input_ids, generated_ids)
        ]
        response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
        return response
    except Exception as e:
        print(f"Qwen-VL style processing failed: {e}")
        # Fall back to processing only the first frame
        first_frame = frames[0]
        try:
            inputs = processor(text=prompt, videos=[first_frame], return_tensors="pt").to(device)
            with torch.no_grad():
                outputs = model.generate(**inputs, max_new_tokens=100)
            response = tokenizer.decode(outputs[0], skip_special_tokens=True)
            return f"[Processed first frame only] {response}"
        except Exception as e2:
            return f"Error processing media: {str(e2)}"

def video_qa(video, prompt):
    """Main function for Gradio interface"""
    if video is None:
        return "Please upload a video"
    if not prompt:
        return "Please enter a question"
    try:
        # Create temporary files
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as tmp_input:
            input_path = tmp_input.name
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as tmp_output:
            output_path = tmp_output.name
        try:
            # Save uploaded video
            with open(input_path, "wb") as f:
                with open(video, "rb") as uploaded_file:
                    f.write(uploaded_file.read())
            # Convert video to a compatible format
            if not convert_video_format(input_path, output_path):
                # If conversion fails, fall back to the original file
                output_path = input_path
            # Process video with the model
            result = process_video_frames(output_path, prompt)
            return result
        finally:
            # Clean up temporary files
            for path in [input_path, output_path]:
                if os.path.exists(path):
                    os.unlink(path)
    except Exception as e:
        return f"Error processing video: {str(e)}"

def media_qa(media, prompt):
    """Main function for the Gradio interface; supports images or videos."""
    if media is None:
        return "Please upload an image or video"
    if not prompt:
        return "Please enter a question"
    # Check whether the input is a video file path
    if isinstance(media, str) and os.path.exists(media):
        # Video pipeline (same flow as video_qa)
        try:
            with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as tmp_input:
                input_path = tmp_input.name
            with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as tmp_output:
                output_path = tmp_output.name
            try:
                with open(input_path, "wb") as f:
                    with open(media, "rb") as uploaded_file:
                        f.write(uploaded_file.read())
                if not convert_video_format(input_path, output_path):
                    output_path = input_path
                result = process_media(output_path, prompt)
                return result
            finally:
                for path in [input_path, output_path]:
                    if os.path.exists(path):
                        os.unlink(path)
        except Exception as e:
            return f"Error processing video: {str(e)}"
    else:
        # Process images directly
        try:
            return process_media(media, prompt)
        except Exception as e:
            return f"Error processing image: {str(e)}"

# Create Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Image/Video Question Answering with Custom VLM")
    gr.Markdown(f"Model: {MODEL_ID}")
    with gr.Row():
        with gr.Column():
            media_input = gr.File(label="Upload Image or Video", file_types=["image", "video"], interactive=True)
            text_input = gr.Textbox(label="Question", placeholder="What is happening in this image or video?")
            submit_btn = gr.Button("Process")
        with gr.Column():
            output_text = gr.Textbox(label="Answer", lines=10)
    gr.Examples(
        examples=[
            [None, "Describe what you see in the image or video"],
            [None, "What objects are present in the scene?"]
        ],
        inputs=[media_input, text_input],
        outputs=output_text
    )
    submit_btn.click(
        fn=media_qa,
        inputs=[media_input, text_input],
        outputs=output_text
    )

if __name__ == "__main__":
    demo.launch()
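
# Rough local run, assuming the usual dependencies are installed (not pinned here):
#   pip install gradio torch transformers opencv-python pillow
#   python app.py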