import gradio as gr
import torch
from PIL import Image
from transformers import AutoModel, AutoTokenizer
import numpy as np
import os
from decord import VideoReader, cpu
from scipy.spatial import cKDTree
import math
import warnings
import spaces

warnings.filterwarnings("ignore")

# Global variables for the lazily loaded model and tokenizer
model = None
tokenizer = None

def load_model():
    """Load the MiniCPM-V-4.5 model and tokenizer"""
    global model, tokenizer
    if model is None:
        print("Loading MiniCPM-V-4.5 model...")
        model = AutoModel.from_pretrained(
            'openbmb/MiniCPM-V-4_5',
            trust_remote_code=True,
            attn_implementation='sdpa',
            torch_dtype=torch.bfloat16,
            device_map="auto"
        )
        model = model.eval()
        tokenizer = AutoTokenizer.from_pretrained(
            'openbmb/MiniCPM-V-4_5',
            trust_remote_code=True
        )
        print("Model loaded successfully!")
    return model, tokenizer

def map_to_nearest_scale(values, scale):
    """Map values to nearest scale for temporal IDs"""
    tree = cKDTree(np.asarray(scale)[:, None])
    _, indices = tree.query(np.asarray(values)[:, None])
    return np.asarray(scale)[indices]

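# A worked example of map_to_nearest_scale: with scale = np.arange(0, 1, 0.1) and
# values = [0.07, 0.12], both values snap to 0.1 (the nearest grid point), so the
# function returns array([0.1, 0.1]). encode_video() below divides the result by
# time_scale to turn these snapped timestamps into integer tick IDs.
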
def group_array(arr, size):
    """Group array into chunks of specified size"""
    return [arr[i:i+size] for i in range(0, len(arr), size)]

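# For example, group_array([0, 1, 2, 3, 4], 2) returns [[0, 1], [2, 3], [4]]; the
# final chunk may be shorter than `size`. encode_video() uses this to pack
# `packing_nums` consecutive frame timestamp IDs into one temporal group.
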
def uniform_sample(l, n):
    """Uniformly sample n items from list l"""
    gap = len(l) / n
    idxs = [int(i * gap + gap / 2) for i in range(n)]
    return [l[i] for i in idxs]

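# For example, uniform_sample(list(range(10)), 3) picks indices [1, 5, 8]: the list
# is split into n equal bins and the element nearest each bin's center is taken.
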
def encode_video(video_path, choose_fps=3, max_frames=180, max_packing=3, time_scale=0.1):
    """Encode video frames with temporal IDs for the model"""
    vr = VideoReader(video_path, ctx=cpu(0))
    fps = vr.get_avg_fps()
    video_duration = len(vr) / fps

    # Decide how many frames to sample and how many to pack per temporal group.
    if choose_fps * int(video_duration) <= max_frames:
        # Short clip: every sampled frame fits within the frame budget unpacked.
        packing_nums = 1
        choose_frames = round(min(choose_fps, round(fps)) * min(max_frames, video_duration))
    else:
        # Long clip: pack several frames per temporal group so the number of
        # groups stays within max_frames, capped at max_packing frames per group.
        packing_nums = math.ceil(video_duration * choose_fps / max_frames)
        if packing_nums <= max_packing:
            choose_frames = round(video_duration * choose_fps)
        else:
            choose_frames = round(max_frames * max_packing)
            packing_nums = max_packing

    frame_idx = [i for i in range(0, len(vr))]
    frame_idx = np.array(uniform_sample(frame_idx, choose_frames))

    print(f'Video duration: {video_duration:.2f}s, frames: {len(frame_idx)}, packing: {packing_nums}')

    frames = vr.get_batch(frame_idx).asnumpy()

    # Snap each sampled frame's timestamp to the time_scale grid, convert to
    # integer tick IDs, then group the IDs to match the packing layout.
    frame_idx_ts = frame_idx / fps
    scale = np.arange(0, video_duration, time_scale)
    frame_ts_id = map_to_nearest_scale(frame_idx_ts, scale) / time_scale
    frame_ts_id = frame_ts_id.astype(np.int32)

    frames = [Image.fromarray(v.astype('uint8')).convert('RGB') for v in frames]
    frame_ts_id_group = group_array(frame_ts_id, packing_nums)
    return frames, frame_ts_id_group

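# A worked example of the budget logic above: for a 60 s clip at 30 fps with
# choose_fps=3, 3 * 60 = 180 <= max_frames, so 180 frames are sampled with
# packing_nums=1 (180 temporal groups). For a 120 s clip, 3 * 120 = 360 > 180,
# so packing_nums = ceil(360 / 180) = 2 and all 360 frames are sampled, packed
# in pairs into 180 temporal groups.
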
@spaces.GPU  # Requests a GPU for this call on ZeroGPU Spaces (the `spaces` import is otherwise unused)
def process_input(
    file_input,
    user_prompt,
    system_prompt,
    fps,
    context_size,
    temperature,
    enable_thinking
):
    """Process user input and generate response"""
    try:
        # Load model if not already loaded
        model, tokenizer = load_model()

        if file_input is None:
            return "Please upload an image or video file."

        # Determine whether the input is an image or a video by file extension
        file_path = file_input
        file_ext = os.path.splitext(file_path)[1].lower()
        is_video = file_ext in ['.mp4', '.avi', '.mov', '.mkv', '.webm', '.m4v']

        # Prepare messages
        msgs = []

        # Add system prompt if provided
        if system_prompt and system_prompt.strip():
            msgs.append({'role': 'system', 'content': system_prompt.strip()})

        if is_video:
            # Process video: sampled frames plus the text prompt in one user turn
            frames, frame_ts_id_group = encode_video(file_path, choose_fps=fps)
            msgs.append({'role': 'user', 'content': frames + [user_prompt]})

            # Generate response for video
            answer = model.chat(
                msgs=msgs,
                tokenizer=tokenizer,
                use_image_id=False,
                max_slice_nums=1,
                temporal_ids=frame_ts_id_group,
                enable_thinking=enable_thinking,
                max_new_tokens=context_size,
                temperature=temperature
            )
        else:
            # Process image
            image = Image.open(file_path).convert('RGB')
            msgs.append({'role': 'user', 'content': [image, user_prompt]})

            # Generate response for image
            answer = model.chat(
                msgs=msgs,
                tokenizer=tokenizer,
                enable_thinking=enable_thinking,
                max_new_tokens=context_size,
                temperature=temperature
            )

        return answer
    except Exception as e:
        return f"Error processing input: {str(e)}"

def create_interface():
    """Create and configure Gradio interface"""
    with gr.Blocks(title="MiniCPM-V-4.5 Multimodal Chat") as iface:
        gr.Markdown("""
        # MiniCPM-V-4.5 Multimodal Chat
        A powerful 8B parameter multimodal model that can understand images and videos with GPT-4V level performance.
        """)

        with gr.Row():
            with gr.Column(scale=1):
                # File input
                file_input = gr.File(
                    label="Upload Image or Video",
                    file_types=["image", "video"]
                )

                # Video sampling rate (frames per second fed to the model)
                fps_slider = gr.Slider(
                    minimum=1,
                    maximum=30,
                    value=5,
                    step=1,
                    label="Video FPS"
                )

                # Maximum number of tokens to generate
                context_size = gr.Slider(
                    minimum=512,
                    maximum=4096,
                    value=2048,
                    step=256,
                    label="Max Output Tokens"
                )

                # Sampling temperature
                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=2.0,
                    value=0.6,
                    step=0.1,
                    label="Temperature"
                )

                # Thinking mode
                enable_thinking = gr.Checkbox(
                    label="Enable Deep Thinking",
                    value=False
                )

            with gr.Column(scale=2):
                # System prompt
                system_prompt = gr.Textbox(
                    label="System Prompt (Optional)",
                    placeholder="Enter system instructions here...",
                    lines=3
                )

                # User prompt
                user_prompt = gr.Textbox(
                    label="Your Question",
                    placeholder="Describe what you see in the image/video, or ask a specific question...",
                    lines=4
                )

                # Submit button
                submit_btn = gr.Button("Generate Response", variant="primary")

                # Output
                output = gr.Textbox(
                    label="Model Response",
                    lines=15
                )

        # Event handlers: the button click and the question box's submit event
        # share the same callback and input list
        inputs = [
            file_input,
            user_prompt,
            system_prompt,
            fps_slider,
            context_size,
            temperature,
            enable_thinking
        ]
        submit_btn.click(fn=process_input, inputs=inputs, outputs=output)
        user_prompt.submit(fn=process_input, inputs=inputs, outputs=output)

    return iface

if __name__ == "__main__":
    # Create and launch the interface; share links are not supported on
    # Hugging Face Spaces, so launch() is called without share=True.
    demo = create_interface()
    demo.launch()