Update modeling_vila.py
Browse files- modeling_vila.py +7 -4
modeling_vila.py
CHANGED
@@ -739,15 +739,18 @@ class VILAForCausalLM(VILAPretrainedModel):
|
|
739 |
self.encoders[name].pool_sizes[0][0] = 4 * round_up_to_bucket(num_video_frames / 256)
|
740 |
|
741 |
if num_video_frames > 512:
|
742 |
-
media_split = []
|
743 |
-
frames_split = 4
|
744 |
for video in media[name]:
|
745 |
-
|
|
|
|
|
|
|
746 |
embeds_split = []
|
747 |
for video in media_split:
|
748 |
embeds_split += self.encoders[name]([video], media_config[name])
|
|
|
749 |
embeds_merged = [
|
750 |
-
torch.cat(embeds_split[i *
|
751 |
for i in range(len(media[name]))
|
752 |
]
|
753 |
embeds[name] = deque(embeds_merged)
|
|
|
# Coarsen temporal pooling for long videos: pool size grows with the
# bucketed frame count so the token budget stays bounded.
self.encoders[name].pool_sizes[0][0] = 4 * round_up_to_bucket(num_video_frames / 256)

if num_video_frames > 512:
    # Encode very long videos in chunks of at most 512 frames to bound
    # peak memory, then stitch the per-chunk embeddings back together
    # into one embedding tensor per original video.
    media_split, num_splits = [], []
    for video in media[name]:
        # torch.Tensor.split returns a tuple of <=512-frame chunks along dim 0.
        video_split = video.split(512, dim=0)
        media_split.extend(video_split)
        num_splits.append(len(video_split))

    embeds_split = []
    for video in media_split:
        embeds_split += self.encoders[name]([video], media_config[name])

    # BUG FIX: the previous merge sliced with
    #   embeds_split[i * num_splits[i] : (i + 1) * num_splits[i]]
    # which is only correct when every video yields the SAME number of
    # chunks. With videos of different lengths, num_splits varies and that
    # slice picks wrong/overlapping chunk ranges. Use cumulative offsets
    # so each video's chunks are merged from its own contiguous span.
    offsets = [0]
    for n_chunks in num_splits:
        offsets.append(offsets[-1] + n_chunks)
    embeds_merged = [
        torch.cat(embeds_split[offsets[i] : offsets[i + 1]], dim=0)
        for i in range(len(media[name]))
    ]
    embeds[name] = deque(embeds_merged)