################################################## PACKAGES ############################################################
# PyTorch for deep learning operations
import torch
import torch.nn as nn
# PyTorch multiprocessing utilities for data loading
import torch.multiprocessing
# NumPy for decoded frame arrays and index sampling
import numpy as np
# Hugging Face Transformers: BERT (text) and VideoMAE (video) backbones
from transformers import BertModel, BertTokenizer, AutoImageProcessor, VideoMAEModel
# Hugging Face Datasets for loading example data
from datasets import load_dataset
# PyAV for video decoding
import av  # pip install av
# pandas for tabular data handling
import pandas as pd
# Project configuration and the pretrained text-image OneEncoder
from configs import CFG
from text_image import OneEncoder as TextImageEncoder


def read_video_pyav(container, indices):
    """
    Decode the video with PyAV decoder.
    Args:
        container (`av.container.input.InputContainer`): PyAV container.
        indices (`List[int]`): List of frame indices to decode.
    Returns:
        result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
    """
    frames = []
    container.seek(0)
    start_index = indices[0]
    end_index = indices[-1]
    for i, frame in enumerate(container.decode(video=0)):
        if i > end_index:
            break
        if i >= start_index and i in indices:
            frames.append(frame)
    return np.stack([x.to_ndarray(format="rgb24") for x in frames])


def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
    """
    Sample a given number of frame indices from the video.
    Args:
        clip_len (`int`): Total number of frames to sample.
        frame_sample_rate (`int`): Sample every n-th frame.
        seg_len (`int`): Maximum allowed index of sample's last frame.
    Returns:
        indices (`List[int]`): List of sampled frame indices
    """
    converted_len = int(clip_len * frame_sample_rate)
    end_idx = np.random.randint(converted_len, seg_len)
    start_idx = end_idx - converted_len
    indices = np.linspace(start_idx, end_idx, num=clip_len)
    indices = np.clip(indices, start_idx, end_idx - 1).astype(np.int64)
    return indices
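

# A minimal sketch of how the two helpers above combine. "video.mp4" is a
# hypothetical local path; any container PyAV can open works the same way,
# and seg_len must exceed clip_len * frame_sample_rate for the sampler.
def _example_sample_clip(video_path="video.mp4"):
    container = av.open(video_path)
    seg_len = container.streams.video[0].frames
    indices = sample_frame_indices(clip_len=16, frame_sample_rate=4, seg_len=seg_len)
    clip = read_video_pyav(container, indices)  # (16, H, W, 3) uint8 array
    return clip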


class AlignmentLayer(nn.Module):
    """Projects encoder features token-wise into the shared space of width projection_dim."""
    def __init__(self, input_dim=768, projection_dim=CFG.projection_dim, dropout_rate=CFG.dropout_rate, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Attributes
        self.input_dim = input_dim
        self.projection_dim = projection_dim
        self.dropout_rate = dropout_rate
        # Layers
        self.linear_layer1 = nn.Linear(self.input_dim, self.projection_dim)
        self.gelu = nn.GELU()
        self.linear_layer2 = nn.Linear(self.projection_dim, self.projection_dim)
        self.dropout = nn.Dropout(self.dropout_rate)
        self.normalization_layer = nn.LayerNorm(self.projection_dim)

    def forward(self, inputs):
        x = self.linear_layer1(inputs)
        x = self.gelu(x)
        x = self.linear_layer2(x)
        x = self.dropout(x)
        x = self.normalization_layer(x)
        return x
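

# Shape sketch for AlignmentLayer: features of width input_dim map to
# projection_dim, token-wise. The batch and token counts here are arbitrary.
def _example_alignment_shapes():
    layer = AlignmentLayer(input_dim=768)
    features = torch.randn(2, 1568, 768)  # (batch, tokens, hidden)
    aligned = layer(features)             # (batch, tokens, CFG.projection_dim)
    return aligned.shape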


class VideoEncoder(nn.Module):
    """Wraps a pretrained VideoMAE backbone and aligns its features to the shared space."""
    def __init__(self, model_name=CFG.video_name, projection_dim=CFG.projection_dim,
                 trainable=False, dropout_rate=CFG.dropout_rate, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Attributes
        self.model_name = model_name
        self.projection_dim = projection_dim
        self.dropout_rate = dropout_rate
        self.trainable = trainable
        # Models
        self.pretrained_encoder = VideoMAEModel.from_pretrained(self.model_name)
        self.alignment_layer = AlignmentLayer(
            input_dim=self.pretrained_encoder.config.hidden_size,
            projection_dim=self.projection_dim,
            dropout_rate=self.dropout_rate)
        # Freeze VideoMAE unless trainable=True
        for parameter in self.pretrained_encoder.parameters():
            parameter.requires_grad = self.trainable

    def forward(self, inputs):
        x = self.pretrained_encoder(inputs).last_hidden_state
        x = self.alignment_layer(x)
        return x
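

# Forward-pass sketch for VideoEncoder (downloads the pretrained backbone on
# first use). VideoMAE-base expects pixel values of shape
# (batch, 16, 3, 224, 224); the dummy tensor stands in for a processed clip.
def _example_video_encoder():
    encoder = VideoEncoder()
    pixel_values = torch.randn(1, 16, 3, 224, 224)
    features = encoder(pixel_values)  # (1, 1568, CFG.projection_dim)
    return features.shape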


class ModalityTokenEncoder(nn.Module):
    """Holds a learnable token matrix that identifies the video modality."""
    def __init__(self, projection_dim=CFG.projection_dim, token_size=CFG.token_size, device='cpu', *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Attributes
        self.projection_dim = projection_dim
        self.device = device
        self.token_size = token_size
        # Learnable modality token, initialized with a random std drawn from [0.1, 0.6)
        video_std = torch.rand(1) * 0.5 + 0.1
        self.video_token = nn.Parameter(torch.normal(mean=0, std=video_std.item(),
                                                     size=(self.token_size, self.projection_dim)).to(self.device))

    def forward(self):
        return self.video_token
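

# The modality token is a (token_size, projection_dim) parameter that tags
# video features before they enter the shared projection encoder.
def _example_modality_token():
    token_encoder = ModalityTokenEncoder()
    return token_encoder().shape  # (CFG.token_size, CFG.projection_dim)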


class OneEncoder(nn.Module):
    def __init__(self, device='cpu', modality_token_encoder=None, checkpoint="bilalfaye/OneEncoder-text-image",
                 video_processor=None, video_encoder=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.device = device
        self.checkpoint = checkpoint
        # Build defaults in the body rather than in the signature, so nothing
        # is instantiated or downloaded at import time
        self.modality_token_encoder = modality_token_encoder if modality_token_encoder is not None else ModalityTokenEncoder()
        self.modality_token_encoder.device = self.device
        self.text_image_encoder = TextImageEncoder(device=self.device)
        self.text_image_encoder.from_pretrained(self.checkpoint)
        self.video_processor = video_processor if video_processor is not None else AutoImageProcessor.from_pretrained("MCG-NJU/videomae-base")
        self.video_encoder = video_encoder if video_encoder is not None else VideoEncoder()
        self.temperature = nn.Parameter(torch.tensor(0.07).to(self.device))
        # Freeze the pretrained text-image encoder; only the video branch trains
        for parameter in self.text_image_encoder.parameters():
            parameter.requires_grad = False

    @classmethod
    def load_video(cls, video_path):
        container = av.open(video_path)
        return container

    @classmethod
    def read_video_pyav(cls, container, indices):
        """
        Decode the video with PyAV decoder.
        Args:
            container (`av.container.input.InputContainer`): PyAV container.
            indices (`List[int]`): List of frame indices to decode.
        Returns:
            result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
        """
        frames = []
        container.seek(0)
        start_index = indices[0]
        end_index = indices[-1]
        for i, frame in enumerate(container.decode(video=0)):
            if i > end_index:
                break
            if i >= start_index and i in indices:
                frames.append(frame)
        return np.stack([x.to_ndarray(format="rgb24") for x in frames])

    @classmethod
    def sample_frame_indices(cls, clip_len, frame_sample_rate, seg_len):
        """
        Sample a given number of frame indices from the video.
        Args:
            clip_len (`int`): Total number of frames to sample.
            frame_sample_rate (`int`): Sample every n-th frame.
            seg_len (`int`): Maximum allowed index of sample's last frame.
        Returns:
            indices (`List[int]`): List of sampled frame indices
        """
        converted_len = int(clip_len * frame_sample_rate)
        end_idx = np.random.randint(converted_len, seg_len)
        start_idx = end_idx - converted_len
        indices = np.linspace(start_idx, end_idx, num=clip_len)
        indices = np.clip(indices, start_idx, end_idx - 1).astype(np.int64)
        return indices

    def encode_video(self, videos):
        """
        :param videos: torch.Size([batch, 16, 3, 224, 224]) pixel values
        :return: torch.Size([batch, 1568, projection_dim]) features in the shared space
                 (1568 tokens for a 16-frame 224x224 clip with VideoMAE-base)
        """
        video_features = self.video_encoder(videos.to(self.device))
        modality_token_features = self.modality_token_encoder()
        outputs = self.text_image_encoder.universal_projection_encoder(
            [video_features, modality_token_features]).last_hidden_state
        return outputs
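

# End-to-end sketch: decode a clip, preprocess it with the VideoMAE image
# processor, and project it into the shared space. "video.mp4" is a
# hypothetical path; constructing OneEncoder downloads the pretrained
# checkpoints, and the processor returns pixel_values of shape (1, 16, 3, 224, 224).
def _example_encode_video(video_path="video.mp4"):
    model = OneEncoder(device='cpu')
    container = model.load_video(video_path)
    seg_len = container.streams.video[0].frames
    indices = model.sample_frame_indices(clip_len=16, frame_sample_rate=4, seg_len=seg_len)
    clip = model.read_video_pyav(container, indices)  # (16, H, W, 3)
    inputs = model.video_processor(list(clip), return_tensors="pt")
    features = model.encode_video(inputs["pixel_values"])
    return features.shape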