# app.py
import gradio as gr
import torch
from transformers import AutoModel, AutoProcessor, AutoTokenizer
from PIL import Image
import cv2
import tempfile
import os
import subprocess

# Load your custom VLM model from Hugging Face
MODEL_ID = "enpeizhao/qwen2_5-3b-instruct-trl-sft-vlm-odd-12-nf4-merged"
BASE_MODEL_ID = "Qwen/Qwen2.5-3B-Instruct"

device = "cuda" if torch.cuda.is_available() else "cpu"


def load_model():
    """Load the model, processor, and tokenizer from Hugging Face"""
    try:
        model = AutoModel.from_pretrained(MODEL_ID, trust_remote_code=True).to(device)
        processor = AutoProcessor.from_pretrained(BASE_MODEL_ID, trust_remote_code=True)
        tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, trust_remote_code=True)
        return model, processor, tokenizer
    except Exception as e:
        print(f"Error loading model: {e}")
        return None, None, None


# Load model at startup
model, processor, tokenizer = load_model()
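
# NOTE (assumption): AutoModel with trust_remote_code=True is expected to return a
# generation-capable class from the checkpoint's custom code; if model.generate()
# turns out to be unavailable, loading the checkpoint with a task-specific class
# such as AutoModelForVision2Seq may be needed instead.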


def convert_video_format(input_path, output_path):
    """Convert video to MP4 format using ffmpeg"""
    try:
        cmd = [
            "ffmpeg",
            "-i", input_path,
            "-c:v", "libx264",
            "-c:a", "aac",
            "-strict", "experimental",
            "-preset", "fast",
            "-y",  # Overwrite output file
            output_path,
        ]
        result = subprocess.run(cmd, capture_output=True, text=True)
        if result.returncode != 0:
            print(f"FFmpeg error: {result.stderr}")
            return False
        return True
    except Exception as e:
        print(f"Error converting video: {e}")
        return False
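
# convert_video_format shells out to the ffmpeg CLI, so an ffmpeg binary must be
# installed and available on PATH in the environment running this app.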


def extract_frames(video_path, max_frames=10):
    """Extract frames from video"""
    try:
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            return []

        frames = []
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = cap.get(cv2.CAP_PROP_FPS)

        # Calculate frame indices to sample
        if total_frames <= max_frames:
            frame_indices = range(total_frames)
        else:
            frame_indices = [int(i * total_frames / max_frames) for i in range(max_frames)]

        for i in frame_indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, i)
            ret, frame = cap.read()
            if ret:
                # Convert BGR to RGB
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frames.append(Image.fromarray(frame_rgb))

        cap.release()
        return frames
    except Exception as e:
        print(f"Error extracting frames: {e}")
        return []


def process_video_frames(video_path, prompt):
    """
    Process video frames with your VLM model
    """
    if model is None or processor is None or tokenizer is None:
        return "Model not loaded properly"

    try:
        # Extract frames from video
        frames = extract_frames(video_path, max_frames=8)
        if not frames:
            return "No frames extracted from video"

        # Prepare conversation messages
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "video", "video": frames},
                    {"type": "text", "text": prompt},
                ],
            }
        ]

        # Process inputs (this is model-specific)
        try:
            # Try Qwen-VL style processing first
            text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            inputs = processor(text=text, videos=frames, return_tensors="pt")
            inputs = inputs.to(device)

            # Generate response
            with torch.no_grad():
                generated_ids = model.generate(**inputs, max_new_tokens=512)

            # Decode output
            generated_ids = [
                output_ids[len(input_ids):] for input_ids, output_ids in zip(inputs.input_ids, generated_ids)
            ]
            response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
            return response
        except Exception as e:
            # Fallback to simpler processing
            print(f"Qwen-VL style processing failed: {e}")

            # Process first frame with text prompt
            first_frame = frames[0]
            inputs = processor(text=prompt, videos=[first_frame], return_tensors="pt").to(device)

            # Generate response
            with torch.no_grad():
                outputs = model.generate(**inputs, max_new_tokens=100)

            # Decode output
            response = tokenizer.decode(outputs[0], skip_special_tokens=True)
            return f"[Processed first frame only] {response}"
    except Exception as e:
        return f"Error processing video: {str(e)}"


def process_media(media, prompt):
    """
    Generic processing function that accepts either an image (PIL.Image) or a video (file path)
    """
    if model is None or processor is None or tokenizer is None:
        return "Model not loaded properly"

    # Determine the input type
    if isinstance(media, Image.Image):
        # Single image
        frames = [media]
    elif isinstance(media, str) and os.path.exists(media):
        # Video path: extract frames
        frames = extract_frames(media, max_frames=8)
        if not frames:
            return "No frames extracted from video"
    else:
        return "Unsupported media type"

    # Build the conversation messages
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "video", "video": frames},
                {"type": "text", "text": prompt},
            ],
        }
    ]

    try:
        # Qwen-VL style processing
        text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = processor(text=text, videos=frames, return_tensors="pt")
        inputs = inputs.to(device)

        with torch.no_grad():
            generated_ids = model.generate(**inputs, max_new_tokens=512)

        generated_ids = [
            output_ids[len(input_ids):] for input_ids, output_ids in zip(inputs.input_ids, generated_ids)
        ]
        response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
        return response
    except Exception as e:
        print(f"Qwen-VL style processing failed: {e}")

        # Fall back to processing only the first frame
        first_frame = frames[0]
        try:
            inputs = processor(text=prompt, videos=[first_frame], return_tensors="pt").to(device)
            with torch.no_grad():
                outputs = model.generate(**inputs, max_new_tokens=100)
            response = tokenizer.decode(outputs[0], skip_special_tokens=True)
            return f"[Processed first frame only] {response}"
        except Exception as e2:
            return f"Error processing media: {str(e2)}"


def video_qa(video, prompt):
    """Main function for Gradio interface"""
    if video is None:
        return "Please upload a video"
    if not prompt:
        return "Please enter a question"

    try:
        # Create temporary files
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as tmp_input:
            input_path = tmp_input.name
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as tmp_output:
            output_path = tmp_output.name

        try:
            # Save uploaded video
            with open(input_path, "wb") as f:
                with open(video, "rb") as uploaded_file:
                    f.write(uploaded_file.read())

            # Convert video to compatible format
            if not convert_video_format(input_path, output_path):
                # If conversion fails, try to use original
                output_path = input_path

            # Process video with model
            result = process_video_frames(output_path, prompt)
            return result
        finally:
            # Clean up temporary files
            for path in [input_path, output_path]:
                if os.path.exists(path):
                    os.unlink(path)
    except Exception as e:
        return f"Error processing video: {str(e)}"


def media_qa(media, prompt):
    """Main entry point for the Gradio interface; accepts an image or a video"""
    if media is None:
        return "Please upload an image or video"
    if not prompt:
        return "Please enter a question"

    # gr.File typically delivers the upload as a file path, so images are detected
    # by extension and loaded with PIL before being handed to process_media
    image_extensions = {".jpg", ".jpeg", ".png", ".bmp", ".gif", ".webp"}
    if isinstance(media, str) and os.path.splitext(media)[1].lower() in image_extensions:
        try:
            return process_media(Image.open(media).convert("RGB"), prompt)
        except Exception as e:
            return f"Error processing image: {str(e)}"

    # Check whether the input is a video file path
    if isinstance(media, str) and os.path.exists(media):
        # Video pipeline (same flow as video_qa)
        try:
            with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as tmp_input:
                input_path = tmp_input.name
            with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as tmp_output:
                output_path = tmp_output.name

            try:
                with open(input_path, "wb") as f:
                    with open(media, "rb") as uploaded_file:
                        f.write(uploaded_file.read())
                if not convert_video_format(input_path, output_path):
                    output_path = input_path
                result = process_media(output_path, prompt)
                return result
            finally:
                for path in [input_path, output_path]:
                    if os.path.exists(path):
                        os.unlink(path)
        except Exception as e:
            return f"Error processing video: {str(e)}"
    else:
        # Anything else (e.g. a PIL image passed in directly) goes straight to process_media
        try:
            return process_media(media, prompt)
        except Exception as e:
            return f"Error processing image: {str(e)}"


# Create Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Image/Video Question Answering with Custom VLM")
    gr.Markdown(f"Model: {MODEL_ID}")

    with gr.Row():
        with gr.Column():
            media_input = gr.File(label="Upload Image or Video", file_types=["image", "video"], interactive=True)
            text_input = gr.Textbox(label="Question", placeholder="What is happening in this image or video?")
            submit_btn = gr.Button("Process")
        with gr.Column():
            output_text = gr.Textbox(label="Answer", lines=10)

    gr.Examples(
        examples=[
            [None, "Describe what you see in the image or video"],
            [None, "What objects are present in the scene?"],
        ],
        inputs=[media_input, text_input],
        outputs=output_text,
    )

    submit_btn.click(
        fn=media_qa,
        inputs=[media_input, text_input],
        outputs=output_text,
    )


if __name__ == "__main__":
    demo.launch()