Yukang committed on
Commit
fee541b
·
verified ·
1 Parent(s): 17fa0d6

Update modeling_vila.py

Browse files
Files changed (1) hide show
  1. modeling_vila.py +7 -4
modeling_vila.py CHANGED
@@ -739,15 +739,18 @@ class VILAForCausalLM(VILAPretrainedModel):
739
  self.encoders[name].pool_sizes[0][0] = 4 * round_up_to_bucket(num_video_frames / 256)
740
 
741
  if num_video_frames > 512:
742
- media_split = []
743
- frames_split = 4
744
  for video in media[name]:
745
- media_split += video.tensor_split(frames_split, dim=0)
 
 
 
746
  embeds_split = []
747
  for video in media_split:
748
  embeds_split += self.encoders[name]([video], media_config[name])
 
749
  embeds_merged = [
750
- torch.cat(embeds_split[i * frames_split: (i + 1) * frames_split], dim=0)
751
  for i in range(len(media[name]))
752
  ]
753
  embeds[name] = deque(embeds_merged)
 
739
  self.encoders[name].pool_sizes[0][0] = 4 * round_up_to_bucket(num_video_frames / 256)
740
 
741
  if num_video_frames > 512:
742
+ media_split, num_splits = [], []
 
743
  for video in media[name]:
744
+ video_split = video.split(512, dim=0)
745
+ media_split.extend(video_split)
746
+ num_splits.append(len(video_split))
747
+
748
  embeds_split = []
749
  for video in media_split:
750
  embeds_split += self.encoders[name]([video], media_config[name])
751
+
752
  embeds_merged = [
753
+ torch.cat(embeds_split[i * num_splits[i]: (i + 1) * num_splits[i]], dim=0)
754
  for i in range(len(media[name]))
755
  ]
756
  embeds[name] = deque(embeds_merged)