# app.py
import gradio as gr
import torch
from transformers import AutoModel, AutoProcessor, AutoTokenizer
from PIL import Image
import cv2
import tempfile
import os
import subprocess
# Load your custom VLM model from Hugging Face
MODEL_ID = "enpeizhao/qwen2_5-3b-instruct-trl-sft-vlm-odd-12-nf4-merged"
BASE_MODEL_ID = "Qwen/Qwen2.5-3B-Instruct"
device = "cuda" if torch.cuda.is_available() else "cpu"
def load_model():
"""Load the model and processor from Hugging Face"""
try:
model = AutoModel.from_pretrained(MODEL_ID, trust_remote_code=True).to(device)
processor = AutoProcessor.from_pretrained(BASE_MODEL_ID, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, trust_remote_code=True)
return model, processor, tokenizer
except Exception as e:
print(f"Error loading model: {e}")
return None, None, None
# Load model at startup
model, processor, tokenizer = load_model()
def convert_video_format(input_path, output_path):
"""Convert video to MP4 format using ffmpeg"""
try:
cmd = [
"ffmpeg",
"-i", input_path,
"-c:v", "libx264",
"-c:a", "aac",
"-strict", "experimental",
"-preset", "fast",
"-y", # Overwrite output file
output_path
]
result = subprocess.run(cmd, capture_output=True, text=True)
if result.returncode != 0:
print(f"FFmpeg error: {result.stderr}")
return False
return True
except Exception as e:
print(f"Error converting video: {e}")
return False
def extract_frames(video_path, max_frames=10):
"""Extract frames from video"""
try:
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
return []
frames = []
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
fps = cap.get(cv2.CAP_PROP_FPS)
# Calculate frame indices to sample
if total_frames <= max_frames:
frame_indices = range(total_frames)
else:
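            # Evenly sample max_frames indices across the full clip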
frame_indices = [int(i * total_frames / max_frames) for i in range(max_frames)]
for i in frame_indices:
cap.set(cv2.CAP_PROP_POS_FRAMES, i)
ret, frame = cap.read()
if ret:
# Convert BGR to RGB
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
frames.append(Image.fromarray(frame_rgb))
cap.release()
return frames
except Exception as e:
print(f"Error extracting frames: {e}")
return []
def process_video_frames(video_path, prompt):
"""
Process video frames with your VLM model
"""
if model is None or processor is None or tokenizer is None:
return "Model not loaded properly"
try:
# Extract frames from video
frames = extract_frames(video_path, max_frames=8)
if not frames:
return "No frames extracted from video"
# Prepare conversation messages
messages = [
{
"role": "user",
"content": [
{"type": "video", "video": frames},
{"type": "text", "text": prompt},
],
}
]
# Process inputs (this is model-specific)
try:
# Try Qwen-VL style processing first
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
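            # The processor tokenizes the templated prompt and turns the sampled frames into model-ready tensors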
inputs = processor(text=text, videos=frames, return_tensors="pt")
inputs = inputs.to(device)
# Generate response
with torch.no_grad():
generated_ids = model.generate(**inputs, max_new_tokens=512)
            # Trim the prompt tokens from each sequence so only the new answer is decoded
generated_ids = [
output_ids[len(input_ids):] for input_ids, output_ids in zip(inputs.input_ids, generated_ids)
]
response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
return response
except Exception as e:
# Fallback to simpler processing
print(f"Qwen-VL style processing failed: {e}")
# Process first frame with text prompt
first_frame = frames[0]
            inputs = processor(text=prompt, images=[first_frame], return_tensors="pt").to(device)
# Generate response
with torch.no_grad():
outputs = model.generate(**inputs, max_new_tokens=100)
# Decode output
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
return f"[Processed first frame only] {response}"
except Exception as e:
return f"Error processing video: {str(e)}"
def process_media(media, prompt):
"""
通用处理函数,支持图片(PIL.Image)或视频(文件路径)
"""
if model is None or processor is None or tokenizer is None:
return "Model not loaded properly"
    # Determine the input type
if isinstance(media, Image.Image):
        # A single image is handled as a one-frame clip
frames = [media]
elif isinstance(media, str) and os.path.exists(media):
        # Video file path: sample frames from it
frames = extract_frames(media, max_frames=8)
if not frames:
return "No frames extracted from video"
else:
return "Unsupported media type"
    # Build the chat messages
messages = [
{
"role": "user",
"content": [
{"type": "video", "video": frames},
{"type": "text", "text": prompt},
],
}
]
try:
        # Qwen-VL style processing
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = processor(text=text, videos=frames, return_tensors="pt")
inputs = inputs.to(device)
with torch.no_grad():
generated_ids = model.generate(**inputs, max_new_tokens=512)
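        # Keep only the newly generated tokens (drop the echoed prompt) before decoding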
generated_ids = [
output_ids[len(input_ids):] for input_ids, output_ids in zip(inputs.input_ids, generated_ids)
]
response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
return response
except Exception as e:
print(f"Qwen-VL style processing failed: {e}")
first_frame = frames[0]
try:
            inputs = processor(text=prompt, images=[first_frame], return_tensors="pt").to(device)
with torch.no_grad():
outputs = model.generate(**inputs, max_new_tokens=100)
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
return f"[Processed first frame only] {response}"
except Exception as e2:
return f"Error processing media: {str(e2)}"
def video_qa(video, prompt):
"""Main function for Gradio interface"""
if video is None:
return "Please upload a video"
if not prompt:
return "Please enter a question"
try:
# Create temporary files
with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as tmp_input:
input_path = tmp_input.name
with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as tmp_output:
output_path = tmp_output.name
try:
# Save uploaded video
with open(input_path, "wb") as f:
with open(video, "rb") as uploaded_file:
f.write(uploaded_file.read())
# Convert video to compatible format
if not convert_video_format(input_path, output_path):
# If conversion fails, try to use original
output_path = input_path
# Process video with model
result = process_video_frames(output_path, prompt)
return result
finally:
# Clean up temporary files
for path in [input_path, output_path]:
if os.path.exists(path):
os.unlink(path)
except Exception as e:
return f"Error processing video: {str(e)}"
def media_qa(media, prompt):
"""Gradio接口主函数,支持图片或视频"""
if media is None:
return "Please upload an image or video"
if not prompt:
return "Please enter a question"
    # Check whether the input is a path to a video file
if isinstance(media, str) and os.path.exists(media):
        # Video pipeline (same flow as the original video_qa)
try:
with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as tmp_input:
input_path = tmp_input.name
with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as tmp_output:
output_path = tmp_output.name
try:
with open(input_path, "wb") as f:
with open(media, "rb") as uploaded_file:
f.write(uploaded_file.read())
if not convert_video_format(input_path, output_path):
output_path = input_path
result = process_media(output_path, prompt)
return result
finally:
for path in [input_path, output_path]:
if os.path.exists(path):
os.unlink(path)
except Exception as e:
return f"Error processing video: {str(e)}"
else:
        # Process the image directly
try:
return process_media(media, prompt)
except Exception as e:
return f"Error processing image: {str(e)}"
# Create Gradio interface
with gr.Blocks() as demo:
gr.Markdown("# Image/Video Question Answering with Custom VLM")
gr.Markdown(f"Model: {MODEL_ID}")
with gr.Row():
with gr.Column():
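            # The uploaded file is passed to media_qa as a local file path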
media_input = gr.File(label="Upload Image or Video", file_types=["image", "video"], interactive=True)
text_input = gr.Textbox(label="Question", placeholder="What is happening in this image or video?")
submit_btn = gr.Button("Process")
with gr.Column():
output_text = gr.Textbox(label="Answer", lines=10)
gr.Examples(
examples=[
[None, "Describe what you see in the image or video"],
[None, "What objects are present in the scene?"]
],
inputs=[media_input, text_input],
outputs=output_text
)
submit_btn.click(
fn=media_qa,
inputs=[media_input, text_input],
outputs=output_text
)
if __name__ == "__main__":
demo.launch()