Yukang committed
Commit 8f2be92 · verified · 1 Parent(s): 4590c99

Upload 4 files

Files changed (4)
  1. config.json +2 -2
  2. media.py +47 -6
  3. media_encoder.py +2 -6
  4. modeling_vila.py +31 -1
config.json CHANGED
@@ -6,7 +6,7 @@
   ],
   "chat_template": null,
   "drop_path_rate": 0.0,
- "fps": 0.0,
+ "fps": 2.0,
   "hidden_size": 3584,
   "image_aspect_ratio": "resize",
   "image_encoder": {
@@ -177,7 +177,7 @@
   "model_name_or_path": "./LongVILA-R1-7B",
   "model_type": "vila",
   "num_time_tokens": 0,
- "num_video_frames": 512,
+ "num_video_frames": 2048,
   "resume_path": "./LongVILA-R1-7B",
   "s2": false,
   "s2_max_split_size": 336,
media.py CHANGED
@@ -10,11 +10,6 @@ import PIL.Image
 import requests
 from transformers import PretrainedConfig
 
-# from llava.constants import MEDIA_TOKENS
-# from llava.media import Image, Video
-# from llava.utils import make_list
-# from llava.utils.logging import logger
-
 MEDIA_TOKENS = {
     "image": "<image>",
     "video": "<vila/video>",
@@ -86,11 +81,57 @@ def _load_video(video_path: str, *, num_frames: int) -> List[PIL.Image.Image]:
         frames[index] = PIL.Image.fromarray(frame)
     return [frames[index] for index in indices if index in frames]
 
+
+def _load_video_with_fps(video_path: str, *, num_frames: int, fps: float) -> List[PIL.Image.Image]:
+    # Load video frames from a directory
+    if os.path.isdir(video_path):
+        frame_paths = sorted(glob.glob(os.path.join(video_path, "*")))
+        indices = np.round(np.linspace(0, len(frame_paths) - 1, min(num_frames, len(frame_paths)))).astype(int)
+        return [PIL.Image.open(frame_paths[index]) for index in indices]
+
+    # Load video frames from a video file
+    vidcap = cv2.VideoCapture(video_path)
+    if not vidcap.isOpened():
+        raise ValueError(f"Cannot open video file: {video_path}")
+
+    orig_fps = vidcap.get(cv2.CAP_PROP_FPS)
+    frame_count = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
+
+    # Estimate video duration in seconds
+    duration_sec = frame_count / orig_fps if orig_fps > 0 else 0
+
+    if duration_sec == 0:
+        raise ValueError(f"Video '{video_path}' seems to be empty or corrupted.")
+
+    # Compute total frames to sample based on desired fps
+    sampled_frame_count = int(duration_sec * fps)
+    sampled_frame_count = ((sampled_frame_count + 127) // 128) * 128
+    sampled_frame_count = min(sampled_frame_count, num_frames)
+
+    # Compute which frame indices to sample
+    indices = np.linspace(0, frame_count - 1, sampled_frame_count).astype(int)
+    frames = {}
+    for index in indices:
+        if index in frames:
+            continue
+        vidcap.set(cv2.CAP_PROP_POS_FRAMES, index)
+        success, frame = vidcap.read()
+        if not success:
+            print(f"Failed to read frame {index} from video '{video_path}'. Skipped.")
+            continue
+        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+        frames[index] = PIL.Image.fromarray(frame)
+
+    vidcap.release()
+    return [frames[index] for index in indices if index in frames]
+
 
 def _extract_video(video, config: PretrainedConfig) -> List[PIL.Image.Image]:
     num_frames = config.num_video_frames
     video_path = video.path if isinstance(video, Video) else video["path"]
-    frames = _load_video(video_path, num_frames=num_frames)
+    if getattr(config, "fps") > 0:
+        frames = _load_video_with_fps(video_path, num_frames=num_frames, fps=config.fps)
+    else:
+        frames = _load_video(video_path, num_frames=num_frames)
     return frames
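
With fps set to 2.0 and num_video_frames raised to 2048 in config.json, _extract_video now routes through the fps-based loader whenever config.fps is positive; note that it reads config.fps via getattr with no default, so the fps entry added to config.json is expected to be present. A minimal sketch of the frame budget these values imply, mirroring the rounding in _load_video_with_fps (the helper name and example durations below are illustrative, not part of the commit):

# Sketch only: frame budget implied by fps=2.0 and num_video_frames=2048.
def sampled_frames(duration_sec: float, fps: float = 2.0, num_video_frames: int = 2048) -> int:
    n = int(duration_sec * fps)       # target frames at the configured fps
    n = ((n + 127) // 128) * 128      # rounded up to a multiple of 128
    return min(n, num_video_frames)   # capped at num_video_frames

print(sampled_frames(60))    # 1-minute clip  -> 128 frames
print(sampled_frames(600))   # 10-minute clip -> 1280 frames
print(sampled_frames(3600))  # 1-hour clip    -> 2048 frames (capped)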
media_encoder.py CHANGED
@@ -101,22 +101,18 @@ class BasicVideoEncoder(BaseEncoder):
         return [process_features(f) for f in features]
 
 def pool(x: torch.Tensor, size: int, dim: int) -> torch.Tensor:
-    if x.shape[dim] % size == 0:
-        return x.view(x.shape[:dim] + (-1, size) + x.shape[dim + 1 :]).mean(dim + 1)
-    else:
-        return x.narrow(dim, start=0, length=1)
+    return x.view(x.shape[:dim] + (-1, size) + x.shape[dim + 1 :]).mean(dim + 1)
 
 class TSPVideoEncoder(BasicVideoEncoder):
     def __init__(
         self,
         parent: torch.nn.Module,
-        #pool_sizes: List[Tuple[int, int, int]],
         start_tokens: Optional[str] = None,
         end_tokens: Optional[str] = "\n",
         sep_tokens: Optional[str] = None,
     ) -> None:
         super().__init__(parent, start_tokens=start_tokens, end_tokens=end_tokens)
-        self.pool_sizes = [[8, 1, 1]] #pool_sizes
+        self.pool_sizes = [[8, 1, 1]]
         self.sep_tokens = sep_tokens
 
     def _process_features(
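
The simplified pool() now always mean-pools consecutive groups of size elements along dim; the old fallback branch for the non-divisible case is gone, so x.shape[dim] is expected to be divisible by size. A small runnable illustration (the tensor shape is made up for the example):

# Sketch only: what the simplified pool() computes.
import torch

def pool(x: torch.Tensor, size: int, dim: int) -> torch.Tensor:
    return x.view(x.shape[:dim] + (-1, size) + x.shape[dim + 1 :]).mean(dim + 1)

x = torch.arange(32.0).reshape(8, 4)  # e.g. 8 frames x 4 feature dims
print(pool(x, size=8, dim=0).shape)   # torch.Size([1, 4]): every 8 frames averaged into 1
# Note: x.shape[dim] must now be divisible by size; the old fallback was removed.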
modeling_vila.py CHANGED
@@ -725,7 +725,37 @@ class VILAForCausalLM(VILAPretrainedModel):
                     dummy = torch.zeros(infos[0]["shape"], dtype=infos[0]["dtype"], device=self.device)
                     embeds["dummy"].extend(self.encoders[name]([dummy], media_config[name]))
                     continue
-            embeds[name] = deque(self.encoders[name](media[name], media_config[name]))
+            def round_up_to_bucket(x: int) -> int:
+                bucket = 1
+                total = 8
+                while bucket < total:
+                    if x <= bucket:
+                        return bucket
+                    bucket *= 2
+                return total
+
+            if "video" in name:
+                num_video_frames = max([video.shape[0] for video in media[name]])
+                if isinstance(self.encoders[name], TSPVideoEncoder):
+                    self.encoders[name].pool_sizes[0][0] = 4 * round_up_to_bucket(num_video_frames / 256)
+
+                if num_video_frames > 512:
+                    media_split = []
+                    frames_split = 4
+                    for video in media[name]:
+                        media_split += video.tensor_split(frames_split, dim=0)
+                    embeds_split = []
+                    for video in media_split:
+                        embeds_split += self.encoders[name]([video], media_config[name])
+                    embeds_merged = [
+                        torch.cat(embeds_split[i * frames_split: (i + 1) * frames_split], dim=0)
+                        for i in range(len(media[name]))
+                    ]
+                    embeds[name] = deque(embeds_merged)
+                else:
+                    embeds[name] = deque(self.encoders[name](media[name], media_config[name]))
+            else:
+                embeds[name] = deque(self.encoders[name](media[name], media_config[name]))
         return embeds
 
     def __truncate_sequence(
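
The new branch adapts the temporal pool size to the longest video in the batch and splits videos longer than 512 frames into 4 chunks that are encoded separately and concatenated back per video. A sketch of the pool sizes this logic produces (illustrative printout, not part of the commit):

# Sketch only: adaptive temporal pooling implied by the code above.
def round_up_to_bucket(x: float) -> int:
    bucket, total = 1, 8
    while bucket < total:
        if x <= bucket:
            return bucket
        bucket *= 2
    return total

for num_video_frames in (256, 512, 1024, 2048):
    pool_size = 4 * round_up_to_bucket(num_video_frames / 256)
    print(num_video_frames, pool_size, num_video_frames // pool_size)
# 256 -> 4, 512 -> 8, 1024 -> 16, 2048 -> 32: the pooled temporal length stays at 64
# frames in each case, so the token count per video stays roughly constant as length grows.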