Upload 4 files
- config.json +2 -2
- media.py +47 -6
- media_encoder.py +2 -6
- modeling_vila.py +31 -1
config.json
CHANGED

@@ -6,7 +6,7 @@
   ],
   "chat_template": null,
   "drop_path_rate": 0.0,
-  "fps":
+  "fps": 2.0,
   "hidden_size": 3584,
   "image_aspect_ratio": "resize",
   "image_encoder": {
@@ -177,7 +177,7 @@
   "model_name_or_path": "./LongVILA-R1-7B",
   "model_type": "vila",
   "num_time_tokens": 0,
-  "num_video_frames":
+  "num_video_frames": 2048,
   "resume_path": "./LongVILA-R1-7B",
   "s2": false,
   "s2_max_split_size": 336,
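For reference, a minimal sketch of how the two new fields can be read back from the updated config; the local checkpoint path here is hypothetical:

```python
import json

# Hypothetical local path to the updated checkpoint directory.
with open("LongVILA-R1-7B/config.json") as f:
    config = json.load(f)

# New fields introduced by this change: sample at 2 fps,
# with an overall budget of up to 2048 frames per video.
print(config["fps"])               # 2.0
print(config["num_video_frames"])  # 2048
```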
media.py
CHANGED

@@ -10,11 +10,6 @@ import PIL.Image
 import requests
 from transformers import PretrainedConfig
 
-# from llava.constants import MEDIA_TOKENS
-# from llava.media import Image, Video
-# from llava.utils import make_list
-# from llava.utils.logging import logger
-
 MEDIA_TOKENS = {
     "image": "<image>",
     "video": "<vila/video>",
@@ -86,11 +81,57 @@ def _load_video(video_path: str, *, num_frames: int) -> List[PIL.Image.Image]:
         frames[index] = PIL.Image.fromarray(frame)
     return [frames[index] for index in indices if index in frames]
 
+def _load_video_with_fps(video_path: str, *, num_frames: int, fps: float) -> List[PIL.Image.Image]:
+    # Load video frames from a directory
+    if os.path.isdir(video_path):
+        frame_paths = sorted(glob.glob(os.path.join(video_path, "*")))
+        indices = np.round(np.linspace(0, len(frame_paths) - 1, min(num_frames, len(frame_paths)))).astype(int)
+        return [PIL.Image.open(frame_paths[index]) for index in indices]
+
+    # Load video frames from a video file
+    vidcap = cv2.VideoCapture(video_path)
+    if not vidcap.isOpened():
+        raise ValueError(f"Cannot open video file: {video_path}")
+
+    orig_fps = vidcap.get(cv2.CAP_PROP_FPS)
+    frame_count = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
+
+    # Estimate video duration in seconds
+    duration_sec = frame_count / orig_fps if orig_fps > 0 else 0
+
+    if duration_sec == 0:
+        raise ValueError(f"Video '{video_path}' seems to be empty or corrupted.")
+
+    # Compute total frames to sample based on desired fps
+    sampled_frame_count = int(duration_sec * fps)
+    sampled_frame_count = ((sampled_frame_count + 127) // 128) * 128
+    sampled_frame_count = min(sampled_frame_count, num_frames)
+
+    # Compute which frame indices to sample
+    indices = np.linspace(0, frame_count - 1, sampled_frame_count).astype(int)
+    frames = {}
+    for index in indices:
+        if index in frames:
+            continue
+        vidcap.set(cv2.CAP_PROP_POS_FRAMES, index)
+        success, frame = vidcap.read()
+        if not success:
+            print(f"Failed to read frame {index} from video '{video_path}'. Skipped.")
+            continue
+        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+        frames[index] = PIL.Image.fromarray(frame)
+
+    vidcap.release()
+    return [frames[index] for index in indices if index in frames]
+
 
 def _extract_video(video, config: PretrainedConfig) -> List[PIL.Image.Image]:
     num_frames = config.num_video_frames
     video_path = video.path if isinstance(video, Video) else video["path"]
-    frames = _load_video(video_path, num_frames=num_frames)
+    if getattr(config, "fps") > 0:
+        frames = _load_video_with_fps(video_path, num_frames=num_frames, fps=config.fps)
+    else:
+        frames = _load_video(video_path, num_frames=num_frames)
     return frames
 
 
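To illustrate the frame-budget arithmetic in `_load_video_with_fps` above, a small worked example with assumed numbers (a 10-minute clip, with the `fps` and `num_video_frames` values from config.json):

```python
# Hypothetical clip: 10 minutes long; fps and num_frames mirror config.json.
duration_sec = 600
fps = 2.0
num_frames = 2048

sampled_frame_count = int(duration_sec * fps)                     # 1200 frames at 2 fps
sampled_frame_count = ((sampled_frame_count + 127) // 128) * 128  # 1280, rounded up to a multiple of 128
sampled_frame_count = min(sampled_frame_count, num_frames)        # 1280, still under the 2048 cap
print(sampled_frame_count)
```

Clips of roughly 16 minutes or longer at 2 fps hit the 2048-frame cap.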
media_encoder.py
CHANGED

@@ -101,22 +101,18 @@ class BasicVideoEncoder(BaseEncoder):
         return [process_features(f) for f in features]
 
 def pool(x: torch.Tensor, size: int, dim: int) -> torch.Tensor:
-
-        return x.view(x.shape[:dim] + (-1, size) + x.shape[dim + 1 :]).mean(dim + 1)
-    else:
-        return x.narrow(dim, start=0, length=1)
+    return x.view(x.shape[:dim] + (-1, size) + x.shape[dim + 1 :]).mean(dim + 1)
 
 class TSPVideoEncoder(BasicVideoEncoder):
     def __init__(
         self,
         parent: torch.nn.Module,
-        #pool_sizes: List[Tuple[int, int, int]],
         start_tokens: Optional[str] = None,
         end_tokens: Optional[str] = "\n",
         sep_tokens: Optional[str] = None,
     ) -> None:
         super().__init__(parent, start_tokens=start_tokens, end_tokens=end_tokens)
-        self.pool_sizes = [[8, 1, 1]]
+        self.pool_sizes = [[8, 1, 1]]
         self.sep_tokens = sep_tokens
 
     def _process_features(
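The simplified `pool` above mean-pools groups of `size` consecutive entries along `dim` (and now requires that dimension to be divisible by `size`). A minimal sketch with made-up toy shapes, assuming the default `pool_sizes = [[8, 1, 1]]`:

```python
import torch

def pool(x: torch.Tensor, size: int, dim: int) -> torch.Tensor:
    # Same one-liner as in the diff: group `size` consecutive entries
    # along `dim` and average them.
    return x.view(x.shape[:dim] + (-1, size) + x.shape[dim + 1 :]).mean(dim + 1)

# Toy features: 16 frames, 4 tokens per frame, hidden size 8 (hypothetical shapes).
x = torch.randn(16, 4, 8)
y = pool(x, size=8, dim=0)  # pool_sizes[0][0] == 8 averages every 8 frames into 1
print(y.shape)              # torch.Size([2, 4, 8])
```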
modeling_vila.py
CHANGED

@@ -725,7 +725,37 @@ class VILAForCausalLM(VILAPretrainedModel):
                     dummy = torch.zeros(infos[0]["shape"], dtype=infos[0]["dtype"], device=self.device)
                     embeds["dummy"].extend(self.encoders[name]([dummy], media_config[name]))
                     continue
-            embeds[name] = deque(self.encoders[name](media[name], media_config[name]))
+            def round_up_to_bucket(x: int) -> int:
+                bucket = 1
+                total = 8
+                while bucket < total:
+                    if x <= bucket:
+                        return bucket
+                    bucket *= 2
+                return total
+
+            if "video" in name:
+                num_video_frames = max([video.shape[0] for video in media[name]])
+                if isinstance(self.encoders[name], TSPVideoEncoder):
+                    self.encoders[name].pool_sizes[0][0] = 4 * round_up_to_bucket(num_video_frames / 256)
+
+                if num_video_frames > 512:
+                    media_split = []
+                    frames_split = 4
+                    for video in media[name]:
+                        media_split += video.tensor_split(frames_split, dim=0)
+                    embeds_split = []
+                    for video in media_split:
+                        embeds_split += self.encoders[name]([video], media_config[name])
+                    embeds_merged = [
+                        torch.cat(embeds_split[i * frames_split: (i + 1) * frames_split], dim=0)
+                        for i in range(len(media[name]))
+                    ]
+                    embeds[name] = deque(embeds_merged)
+                else:
+                    embeds[name] = deque(self.encoders[name](media[name], media_config[name]))
+            else:
+                embeds[name] = deque(self.encoders[name](media[name], media_config[name]))
         return embeds
 
     def __truncate_sequence(
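As a quick sanity check of the new bucketing logic (this mirrors the diff above, it is not additional code in the model): `round_up_to_bucket` rounds `num_video_frames / 256` up to the next power of two, capped at 8, so the first temporal pool size grows with the number of frames; videos longer than 512 frames are additionally split into 4 chunks, encoded separately, and concatenated.

```python
def round_up_to_bucket(x: int) -> int:
    # Same helper as in the diff: round up to the next power of two, capped at 8.
    bucket = 1
    total = 8
    while bucket < total:
        if x <= bucket:
            return bucket
        bucket *= 2
    return total

# Resulting temporal pool size (pool_sizes[0][0] = 4 * bucket) for a few frame counts:
for num_video_frames in (256, 512, 1024, 2048):
    print(num_video_frames, 4 * round_up_to_bucket(num_video_frames / 256))
# 256 -> 4, 512 -> 8, 1024 -> 16, 2048 -> 32
```

With the 2048-frame budget from config.json, the TSP encoder therefore pools every 32 frames into one temporal group.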