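"""Zero-shot inference demo for LaViLa-based video-language models.

Evaluates a pretrained LaViLa baseline and an Ego-VPA model on video
classification (Charades-Ego, EGTEA) or multi-instance retrieval (EK100 MIR).
"""
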
import argparse
import json
import os
from collections import OrderedDict

import numpy as np
import pandas as pd
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.transforms._transforms_video as transforms_video
from sklearn.metrics import confusion_matrix

from lavila.data import datasets
from lavila.data.video_transforms import Permute
from lavila.models import models
from lavila.models.tokenizer import (MyBertTokenizer, MyDistilBertTokenizer,
                                     MyGPT2Tokenizer, SimpleTokenizer)
from lavila.models.utils import inflate_positional_embeds
from lavila.utils.config import load_cfg
from lavila.utils.evaluation import get_mean_accuracy
from lavila.utils.evaluation_charades import charades_map
from lavila.utils.evaluation_ek100mir import (calculate_IDCG, calculate_k_counts,
                                              calculate_mAP, calculate_nDCG)


class VideoModel(nn.Module):
    """Base model for video understanding built on the LaViLa architecture."""

    def __init__(self, config):
        """Initialize the model.

        Args:
            config: path to a YAML config file.
        """
        super().__init__()
        self.cfg = load_cfg(config)
        self.model = self.build_model()
        self.tokenizer = self.get_tokenizer()
        self.templates = ['{}']
        self.dataset = self.cfg['data']['dataset']
        self.eval()  # freeze parameters (see the eval() override below)

    def build_model(self):
        cfg = self.cfg
        if cfg['model'].get('pretrain', False):
            ckpt_path = cfg['model']['pretrain']
        else:
            raise FileNotFoundError('no pretrained checkpoint specified in the config')
        ckpt = torch.load(ckpt_path, map_location='cpu')

        # strip the 'module.' prefix left by DistributedDataParallel
        state_dict = OrderedDict()
        for k, v in ckpt['state_dict'].items():
            state_dict[k.replace('module.', '')] = v

        old_args = vars(ckpt['args'])
        arch = old_args.get('model', 'CLIP_OPENAI_TIMESFORMER_BASE')
        self.arch = arch
        cfg['model']['arch'] = arch
        cfg['model']['norm_embed'] = old_args.get('norm_embed', True)
        print("=> creating model: {}".format(arch))
        model = getattr(models, arch)(
            pretrained=old_args.get('load_visual_pretrained', None),
            pretrained2d=old_args.get('load_visual_pretrained', None) is not None,
            text_use_cls_token=old_args.get('use_cls_token', False),
            project_embed_dim=old_args.get('project_embed_dim', 256),
            timesformer_gated_xattn=False,
            num_frames=cfg['model'].get('num_frames', cfg['data']['clip_length']),
            model_cfg=cfg['model']
        )
        model.logit_scale.requires_grad = False

        if torch.cuda.is_available():
            model.cuda()

        if ('TIMESFORMER' in arch or 'EGOVLP' in arch) and cfg['model'].get('inflat_posemb', True):
            # the checkpoint may have been trained with a different number of
            # frames, so inflate its positional embeddings to match
            print('=> inflating PE in models due to different frame numbers')
            state_dict = inflate_positional_embeds(
                model.state_dict(), state_dict,
                num_frames=cfg['model'].get('num_frames', cfg['data']['clip_length']),
                load_temporal_fix='bilinear',
            )
        model.load_state_dict(state_dict, strict=True)
        print("=> loaded resume checkpoint '{}' (epoch {})".format(ckpt_path, ckpt['epoch']))

        return model

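    # Sketch of the config keys consumed above (illustrative, not an
    # exhaustive schema; values are placeholders):
    #
    #   model:
    #     pretrain: /path/to/checkpoint.pt   # required
    #     num_frames: 16                     # optional, defaults to data.clip_length
    #     inflat_posemb: true                # optional, defaults to true
    #   data:
    #     dataset: charades_ego
    #     clip_length: 16
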
    def eval(self):
        """Freeze all parameters and put the model in inference mode.

        Note: this shadows ``nn.Module.eval`` and does not return ``self``.
        """
        cudnn.benchmark = True
        for p in self.model.parameters():
            p.requires_grad = False
        self.model.eval()

    def get_tokenizer(self):
        # the tokenizer is inferred from the architecture-name suffix
        arch = self.arch
        if arch.endswith('DISTILBERT_BASE'):
            tokenizer = MyDistilBertTokenizer('distilbert-base-uncased')
        elif arch.endswith('BERT_BASE'):
            tokenizer = MyBertTokenizer('bert-base-uncased')
        elif arch.endswith('BERT_LARGE'):
            tokenizer = MyBertTokenizer('bert-large-uncased')
        elif arch.endswith('GPT2'):
            tokenizer = MyGPT2Tokenizer('gpt2')
        elif arch.endswith('GPT2_MEDIUM'):
            tokenizer = MyGPT2Tokenizer('gpt2-medium')
        elif arch.endswith('GPT2_LARGE'):
            tokenizer = MyGPT2Tokenizer('gpt2-large')
        elif arch.endswith('GPT2_XL'):
            tokenizer = MyGPT2Tokenizer('gpt2-xl')
        else:
            print("Using SimpleTokenizer because of model '{}'. "
                  "Please check if this is what you want".format(arch))
            tokenizer = SimpleTokenizer()

        return tokenizer


class VideoCLSModel(VideoModel):
    """Video model for video classification tasks (Charades-Ego, EGTEA)."""

    def __init__(self, config):
        super().__init__(config)
        self.labels, self.mapping_vn2act = self.gen_label_map()
        self.text_features = self.get_text_features()

    def gen_label_map(self):
        labelmap = self.cfg.get('label_map', 'meta/charades_ego/label_map.json')
        if os.path.isfile(labelmap):
            print(f"=> Loading label map from {labelmap}")
            with open(labelmap, 'r') as f:
                meta = json.load(f)
            labels, mapping_vn2act = meta['labels'], meta['mapping_vn2act']
        else:
            from lavila.utils.preprocess import generate_label_map
            labels, mapping_vn2act = generate_label_map(self.dataset)
            meta = {'labels': labels, 'mapping_vn2act': mapping_vn2act}
            meta_dir = f'meta/{self.dataset}'
            os.makedirs(meta_dir, exist_ok=True)
            with open(f'{meta_dir}/label_map.json', 'w') as f:
                json.dump(meta, f)
            print(f"=> Label map generated and saved to {meta_dir}/label_map.json")

        return labels, mapping_vn2act

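    # label_map.json mirrors the dict written above, e.g.
    #   {"labels": [...], "mapping_vn2act": {...}}
    # so deleting the file forces regeneration via generate_label_map().
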
    def load_data(self, idx=None):
        print("=> Creating dataset")
        cfg, dataset = self.cfg, self.dataset
        data_cfg = cfg['data']
        crop_size = 224 if '336PX' not in self.arch else 336
        val_transform = transforms.Compose([
            Permute([3, 0, 1, 2]),  # (T, H, W, C) -> (C, T, H, W)
            transforms.Resize(crop_size),
            transforms.CenterCrop(crop_size),
            # normalization constants on the 0-255 pixel scale
            transforms_video.NormalizeVideo(
                mean=[108.3272985, 116.7460125, 104.09373615],
                std=[68.5005327, 66.6321579, 70.32316305]),
        ])

        if idx is None:
            metadata_val = data_cfg['metadata_val']
        else:
            metadata_val = data_cfg['metadata_val'].format(idx)
        if dataset in ['charades_ego', 'egtea']:
            val_dataset = datasets.VideoClassyDataset(
                dataset, data_cfg['root'], metadata_val,
                transform=val_transform, is_training=False,
                label_mapping=self.mapping_vn2act, is_trimmed=False,
                num_clips=1, clip_length=data_cfg['clip_length'],
                clip_stride=data_cfg['clip_stride'],
                sparse_sample=data_cfg['sparse_sample']
            )
        else:
            raise NotImplementedError

        val_loader = torch.utils.data.DataLoader(
            val_dataset, batch_size=8, shuffle=False,
            num_workers=4, pin_memory=True, sampler=None, drop_last=False
        )

        return val_loader

    @torch.no_grad()
    def get_text_features(self):
        print('=> Extracting text features')
        text_features = []
        for label in self.labels:
            if isinstance(label, list):
                texts = [tmpl.format(lbl) for tmpl in self.templates for lbl in label]
            else:
                texts = [tmpl.format(label) for tmpl in self.templates]
            texts = self.tokenizer(texts)
            if isinstance(texts, tuple):
                # BERT-style tokenizers return (token_ids, attention_mask)
                texts, masks = texts
                texts = texts.cuda(non_blocking=True)
                masks = masks.cuda(non_blocking=True)
            else:
                texts = texts.cuda(non_blocking=True)
                masks = None
            # flatten to (num_texts, 77); 77 is the tokenizer context length
            texts = texts.view(-1, 77).contiguous()
            masks = masks.view(-1, 77).contiguous() if masks is not None else None
            if masks is not None:
                class_embeddings, _ = self.model.encode_text(texts, attention_mask=masks)
            else:
                class_embeddings, _ = self.model.encode_text(texts)
            class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True)
            # average over templates, then re-normalize
            class_embeddings = class_embeddings.mean(dim=0)
            class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True)

            text_features.append(class_embeddings)
        text_features = torch.stack(text_features, dim=0)

        return text_features

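    # The stacked features form a zero-shot classifier: one L2-normalized
    # embedding per class, averaged over the prompt templates.
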
    @torch.no_grad()
    def forward(self, idx=None):
        print('=> Start forwarding')
        val_loader = self.load_data(idx)
        all_outputs = []
        all_targets = []
        for values in val_loader:
            images = values[0]
            target = values[1]

            images = images.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)

            image_features, _ = self.model.encode_image(images)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)

            # cosine similarity against the class text embeddings
            logits_per_image = image_features @ self.text_features.t()
            logits_per_image = torch.softmax(logits_per_image, dim=1)

            all_outputs.append(logits_per_image.cpu())
            all_targets.append(target.cpu())

        all_outputs = torch.cat(all_outputs)
        all_targets = torch.cat(all_targets)

        return all_outputs, all_targets

    @torch.no_grad()
    def predict(self, idx=0):
        all_outputs, all_targets = self.forward(idx)
        preds, targets = all_outputs.numpy(), all_targets.numpy()
        # keep the smallest k predictions whose cumulative probability mass
        # exceeds a heuristic threshold of 0.055
        sel = np.where(np.cumsum(sorted(preds[0].tolist(), reverse=True)) > 0.055)[0][0] + 1

        df = pd.DataFrame(self.labels)
        pred_action = df.iloc[preds[0].argsort()[-sel:]].values.tolist()
        gt_action = df.iloc[np.where(targets[0])[0]].values.tolist()
        pred_action = sorted([x[0] for x in pred_action])
        gt_action = sorted([x[0] for x in gt_action])
        return pred_action, gt_action

    @torch.no_grad()
    def evaluate(self):
        all_outputs, all_targets = self.forward()
        preds, targets = all_outputs.numpy(), all_targets.numpy()
        if self.dataset == 'charades_ego':
            m_ap, _, _ = charades_map(preds, targets)
            print('mAP = {:.3f}'.format(m_ap))
        elif self.dataset == 'egtea':
            cm = confusion_matrix(targets, preds.argmax(axis=1))
            mean_class_acc, acc = get_mean_accuracy(cm)
            print('Mean Acc. = {:.3f}, Top-1 Acc. = {:.3f}'.format(mean_class_acc, acc))
        else:
            raise NotImplementedError


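# Typical usage of VideoCLSModel (config paths as used in main() below):
#   model = VideoCLSModel('configs/charades_ego/zeroshot.yml')
#   model.evaluate()                 # dataset-level metrics
#   pred, gt = model.predict(idx=0)  # per-sample predictions

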
class VideoMIRModel(VideoModel):
    """Video model for multi-instance retrieval tasks (EK100 MIR)."""

    def __init__(self, config):
        super().__init__(config)
        self.narrations = pd.read_csv(self.cfg['data']['narrations']).values[:, 1]
        self.text_features = self.get_text_features()
        self.video_samples = pd.read_csv('meta/ek100_mir/sel_t2v.csv').values[:, 0]

    def load_data(self, idx=None, t2v=False):
        print("=> Creating dataset")
        cfg, dataset = self.cfg, self.dataset
        data_cfg = cfg['data']
        crop_size = 224 if '336PX' not in self.arch else 336
        val_transform = transforms.Compose([
            Permute([3, 0, 1, 2]),  # (T, H, W, C) -> (C, T, H, W)
            transforms.Resize(crop_size),
            transforms.CenterCrop(crop_size),
            # normalization constants on the 0-255 pixel scale
            transforms_video.NormalizeVideo(
                mean=[108.3272985, 116.7460125, 104.09373615],
                std=[68.5005327, 66.6321579, 70.32316305]),
        ])

        if dataset == 'ek100_mir':
            if t2v:
                metadata_val = 'meta/ek100_mir/sel_t2v.csv'
                self.relevancy_mat_v2t = np.load(data_cfg['relevancy_path'].replace('sel', 'sel_v2t'))
                self.relevancy_mat_t2v = np.load(data_cfg['relevancy_path'].replace('sel', 'sel_t2v'))
                val_dataset = datasets.VideoCaptionDatasetCLIP(
                    'ek100_mir_demo', data_cfg['root'], metadata_val, val_transform,
                    is_training=False, tokenizer=self.tokenizer,
                    clip_length=data_cfg['clip_length'], clip_stride=data_cfg['clip_stride']
                )
            elif idx is None:
                metadata_val = data_cfg['metadata_val']
                val_dataset = datasets.get_dataset(val_transform, self.tokenizer, cfg, is_training=False)
            else:
                metadata_val = data_cfg['metadata_val'].format(idx)
                self.relevancy_mat_v2t = np.load(data_cfg['relevancy_path'].replace('sel', 'sel_v2t'))
                self.relevancy_mat_t2v = np.load(data_cfg['relevancy_path'].replace('sel', 'sel_t2v'))
                val_dataset = datasets.VideoCaptionDatasetCLIP(
                    'ek100_mir_demo', data_cfg['root'], metadata_val, val_transform,
                    is_training=False, tokenizer=self.tokenizer,
                    clip_length=data_cfg['clip_length'], clip_stride=data_cfg['clip_stride']
                )
        else:
            raise NotImplementedError

        val_loader = torch.utils.data.DataLoader(
            val_dataset, batch_size=8, shuffle=False,
            num_workers=4, pin_memory=True, sampler=None, drop_last=False
        )

        return val_loader

    @torch.no_grad()
    def get_text_features(self):
        print('=> Extracting text features')
        text_features = []
        for text in self.narrations:
            # assumes a tokenizer that returns a single tensor (no attention mask)
            text = self.tokenizer(text)
            text = text.cuda(non_blocking=True)
            text = text.view(-1, 77).contiguous()
            text_embed, _ = self.model.encode_text(text)
            text_embed = F.normalize(text_embed, dim=-1).squeeze()
            text_features.append(text_embed)

        text_features = torch.stack(text_features, dim=0)

        return text_features

    @torch.no_grad()
    def forward_video(self, text_features=None, idx=None, t2v=False):
        print('=> Start forwarding')
        if t2v:
            val_loader = self.load_data(t2v=t2v)
        else:
            val_loader = self.load_data(idx=idx)
        all_outputs = []
        for values in val_loader:
            images = values[0].cuda(non_blocking=True)

            image_features, _ = self.model.encode_image(images)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
            if t2v:
                # text->video: collect video features, score them all at the end
                all_outputs.append(image_features)
            else:
                # video->text: score each batch against all narration features
                logits_per_image = image_features @ text_features.t()
                logits_per_image = torch.softmax(logits_per_image, dim=1)
                all_outputs.append(logits_per_image.cpu())

        all_outputs = torch.cat(all_outputs)
        if t2v:
            all_outputs = torch.softmax(text_features @ all_outputs.t(), dim=1).cpu()

        return all_outputs

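    # In t2v mode the softmax runs over the video axis (one distribution per
    # query text); in v2t mode it runs over the narration axis per video clip.
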
    @torch.no_grad()
    def predict_v2t(self, idx=0, sid=0):
        all_outputs = self.forward_video(self.text_features, sid)
        preds = all_outputs.numpy()
        relevancy = self.relevancy_mat_v2t[idx]
        sel = 3  # top-3 narrations
        pred_action = self.narrations[(-preds[0]).argsort()[:sel]]
        gt_action = self.narrations[np.where(relevancy == 1)[0]]
        return pred_action, gt_action

    @torch.no_grad()
    def predict_t2v(self, idx=0, sid=0):
        text_features = self.text_features[sid].unsqueeze(0)
        all_outputs = self.forward_video(text_features, t2v=True)
        preds = all_outputs.numpy()
        relevancy = self.relevancy_mat_t2v[idx]
        sel = 1  # top-1 video
        pred_video = self.video_samples[(-preds[0]).argsort()[:sel]]
        gt_video = self.video_samples[np.where(relevancy == 1)[0]]
        return pred_video, gt_video

    @torch.no_grad()
    def evaluate(self):
        val_loader = self.load_data()
        cfg = self.cfg
        if self.dataset == 'ek100_mir':
            all_video_embed = []
            all_text_embed = []
            for inputs in val_loader:
                inputs = [tensor.cuda(non_blocking=True) for tensor in inputs]
                inputs.pop()  # drop relevancies; metrics use the precomputed matrix below

                outputs = self.model(
                    *inputs,
                    use_checkpoint=True,
                    norm_embed=cfg['model']['norm_embed']
                )

                image_features = outputs['image_embed']
                text_features = outputs['text_embed']
                all_video_embed.append(image_features.cpu().numpy())
                all_text_embed.append(text_features.cpu().numpy())

            all_text_embed = np.vstack(all_text_embed)
            all_video_embed = np.vstack(all_video_embed)
            similarity_matrix = np.matmul(all_video_embed, all_text_embed.T)
            # map cosine similarity from [-1, 1] to [0, 1]
            similarity_matrix = (similarity_matrix + 1) / 2
            # align the text axis with the sentence-level metadata
            video_id = pd.read_csv(cfg['data']['metadata'].replace('train', 'test')).values[:, 0]
            text_id = pd.read_csv(cfg['data']['metadata'].replace('train', 'test_sentence')).values[:, 0]
            indexes = [video_id.tolist().index(elem) for elem in text_id]
            similarity_matrix = similarity_matrix[:, indexes]
            print(similarity_matrix.shape)
            rel_matrix = pd.read_pickle(cfg['data']['relevancy_path'])
            vis_map = calculate_mAP(similarity_matrix, rel_matrix)
            txt_map = calculate_mAP(similarity_matrix.T, rel_matrix.T)
            avg_map = (vis_map + txt_map) / 2
            print('mAP: V->T: {:.3f} T->V: {:.3f} AVG: {:.3f}'.format(vis_map, txt_map, avg_map))
            vis_k_counts = calculate_k_counts(rel_matrix)
            txt_k_counts = calculate_k_counts(rel_matrix.T)
            vis_IDCG = calculate_IDCG(rel_matrix, vis_k_counts)
            txt_IDCG = calculate_IDCG(rel_matrix.T, txt_k_counts)
            vis_nDCG = calculate_nDCG(similarity_matrix, rel_matrix, k_counts=vis_k_counts, IDCG=vis_IDCG)
            txt_nDCG = calculate_nDCG(similarity_matrix.T, rel_matrix.T, k_counts=txt_k_counts, IDCG=txt_IDCG)
            avg_nDCG = (vis_nDCG + txt_nDCG) / 2
            print('nDCG: V->T: {:.3f} T->V: {:.3f} AVG: {:.3f}'.format(vis_nDCG, txt_nDCG, avg_nDCG))
        else:
            raise NotImplementedError


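# Typical usage of VideoMIRModel (config paths as used in main() below):
#   model = VideoMIRModel('configs/ek100_mir/zeroshot.yml')
#   model.evaluate()                             # corpus-level mAP / nDCG
#   texts, gt = model.predict_v2t(idx=0, sid=0)  # video-to-text retrieval

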
def main():
    parser = argparse.ArgumentParser(description='Ego-VPA inference', add_help=False)
    parser.add_argument('--dataset', default='charades_ego', type=str,
                        help='charades_ego/ek100_mir')
    args = parser.parse_args()

    if args.dataset in ['charades_ego']:
        lavila = VideoCLSModel(f"configs/{args.dataset}/zeroshot.yml")
        egovpa = VideoCLSModel(f"configs/{args.dataset}/egovpa.yml")
    elif args.dataset == 'ek100_mir':
        lavila = VideoMIRModel(f"configs/{args.dataset}/zeroshot.yml")
        egovpa = VideoMIRModel(f"configs/{args.dataset}/egovpa.yml")
    else:
        raise NotImplementedError

    # evaluate the zero-shot LaViLa baseline, then the Ego-VPA model
    lavila.evaluate()
    egovpa.evaluate()


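# Example invocation (script name is illustrative):
#   python demo.py --dataset charades_ego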
if __name__ == '__main__':
    main()