# LongVILA-R1-7B / media.py

import glob
import os
from collections import defaultdict
from typing import Any, Dict, List, Optional, Union

import cv2
import numpy as np
import PIL
import PIL.Image
import requests
from transformers import PretrainedConfig

MEDIA_TOKENS = {
    "image": "<image>",
    "video": "<vila/video>",
}


class Media:
    """Base class for media items."""


class File(Media):
    """A media item identified by a local path or URL."""

    def __init__(self, path: str) -> None:
        self.path = path


class Image(File):
    """A still image."""


class Video(File):
    """A video clip, or a directory of pre-extracted frame images."""


def make_list(obj: Any) -> List:
    # Wrap a single object in a list; pass lists through unchanged.
    return obj if isinstance(obj, list) else [obj]


def _extract_image(image: Union[Image, PIL.Image.Image]) -> PIL.Image.Image:
    # Resolve an Image wrapper into a PIL image, fetching remote URLs if needed.
    if isinstance(image, Image):
        if image.path.startswith("http://") or image.path.startswith("https://"):
            image = PIL.Image.open(requests.get(image.path, stream=True).raw)
        else:
            image = PIL.Image.open(image.path)
    return image


def _load_video(video_path: str, *, num_frames: int) -> List[PIL.Image.Image]:
    # Load video frames from a directory of frame images
    if os.path.isdir(video_path):
        frame_paths = sorted(glob.glob(os.path.join(video_path, "*")))
        indices = np.round(np.linspace(0, len(frame_paths) - 1, num_frames)).astype(int)
        return [PIL.Image.open(frame_paths[index]) for index in indices]

    # Load video frames from a video file
    vidcap = cv2.VideoCapture(video_path)

    # Find the last readable frame, since the reported frame count may be inaccurate
    frame_count = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
    while frame_count > 0:
        vidcap.set(cv2.CAP_PROP_POS_FRAMES, frame_count - 1)
        if vidcap.grab():
            break
        frame_count -= 1
    else:
        raise ValueError(f"Video '{video_path}' has no frames.")

    # Extract frames uniformly
    indices = np.round(np.linspace(0, frame_count - 1, num_frames)).astype(int)
    frames = {}
    for index in indices:
        if index in frames:
            continue
        vidcap.set(cv2.CAP_PROP_POS_FRAMES, index)
        success, frame = vidcap.read()
        if not success:
            print(f"Failed to read frame {index} from video '{video_path}'. Skipped.")
            continue
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frames[index] = PIL.Image.fromarray(frame)
    vidcap.release()
    return [frames[index] for index in indices if index in frames]


def _load_video_with_fps(video_path: str, *, num_frames: int, fps: float) -> List[PIL.Image.Image]:
    # Load video frames from a directory of frame images
    if os.path.isdir(video_path):
        frame_paths = sorted(glob.glob(os.path.join(video_path, "*")))
        indices = np.round(np.linspace(0, len(frame_paths) - 1, min(num_frames, len(frame_paths)))).astype(int)
        return [PIL.Image.open(frame_paths[index]) for index in indices]

    # Load video frames from a video file
    vidcap = cv2.VideoCapture(video_path)
    if not vidcap.isOpened():
        raise ValueError(f"Cannot open video file: {video_path}")
    orig_fps = vidcap.get(cv2.CAP_PROP_FPS)
    frame_count = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))

    # Estimate video duration in seconds
    duration_sec = frame_count / orig_fps if orig_fps > 0 else 0
    if duration_sec == 0:
        raise ValueError(f"Video '{video_path}' seems to be empty or corrupted.")

    # Compute the number of frames to sample: the fps-derived count, rounded up to a
    # multiple of 128, capped by num_frames rounded up to a multiple of 128
    sampled_frame_count = min(((int(duration_sec * fps) + 127) // 128) * 128, ((num_frames + 127) // 128) * 128)

    # Compute which frame indices to sample
    indices = np.linspace(0, frame_count - 1, sampled_frame_count).astype(int)
    frames = {}
    for index in indices:
        if index in frames:
            continue
        vidcap.set(cv2.CAP_PROP_POS_FRAMES, index)
        success, frame = vidcap.read()
        if not success:
            print(f"Failed to read frame {index} from video '{video_path}'. Skipped.")
            continue
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frames[index] = PIL.Image.fromarray(frame)
    vidcap.release()
    return [frames[index] for index in indices if index in frames]
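
# Worked example of the sampling arithmetic above (a note added here, not from the
# original file): a 60-second clip sampled at fps=2 targets int(60 * 2) = 120 frames,
# which rounds up to 128; with num_frames=512 the cap is also 128-aligned (512), so
# min(128, 512) = 128 frames are sampled uniformly across the clip.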


def _extract_video(video, config: PretrainedConfig) -> List[PIL.Image.Image]:
    # Dispatch to fps-based or uniform frame sampling based on the config.
    num_frames = config.num_video_frames
    video_path = video.path if isinstance(video, Video) else video["path"]
    if getattr(config, "fps", 0) > 0:
        frames = _load_video_with_fps(video_path, num_frames=num_frames, fps=config.fps)
    else:
        frames = _load_video(video_path, num_frames=num_frames)
    return frames


def extract_media(
    messages: List[Dict[str, Any]],
    config: Optional[PretrainedConfig] = None,
    draft: bool = False,
) -> Dict[str, List[Any]]:
    # Pull media items out of chat messages, replace them with media tokens in the
    # text, and return the collected media grouped by type. With draft=True, media
    # is collected as-is without being loaded from disk or the network.
    media = defaultdict(list)
    for message in messages:
        text = ""
        for part in make_list(message["value"]):
            if isinstance(part, str):
                for token in MEDIA_TOKENS.values():
                    if token in part:
                        print(f"Media token '{token}' found in text: '{part}'. Removed.")
                        part = part.replace(token, "").strip()
                text += part
            elif isinstance(part, (Image, PIL.Image.Image)):
                if draft:
                    media["image"].append(part)
                else:
                    media["image"].append(_extract_image(part))
                text += MEDIA_TOKENS["image"]
            elif isinstance(part, (dict, Video)):
                if draft:
                    media["video"].append(part)
                else:
                    media["video"].append(_extract_video(part, config))
                text += MEDIA_TOKENS["video"]
            else:
                raise ValueError(f"Unsupported prompt part type: {type(part)}")
        message["value"] = text

    # Move the video token to the front of the first message, if present.
    if MEDIA_TOKENS["video"] in messages[0]["value"]:
        messages[0]["value"] = MEDIA_TOKENS["video"] + messages[0]["value"].replace(MEDIA_TOKENS["video"], "")
    return media
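

# Minimal usage sketch (an addition, not part of the original module), assuming a
# config object that exposes only the two attributes this file reads:
# num_video_frames and fps. The real model config is a full transformers
# PretrainedConfig; SimpleNamespace stands in for it here, and the file name is a
# placeholder.
if __name__ == "__main__":
    from types import SimpleNamespace

    demo_config = SimpleNamespace(num_video_frames=8, fps=0)  # fps=0 -> uniform sampling
    demo_messages = [{"value": [Image("example.jpg"), "Describe the image."]}]
    # draft=True collects the media objects without touching disk or the network.
    extracted = extract_media(demo_messages, demo_config, draft=True)
    print(demo_messages[0]["value"])  # -> "<image>Describe the image."
    print(extracted["image"])         # -> [<Image object>]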