# EK100MIR / demo.py
# gina9726's picture
# Update demo.py
# d32d7bb verified
### demo.py
# Define model classes for inference.
###
import argparse
from collections import OrderedDict
import json
import numpy as np
import os
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
import torchvision.transforms._transforms_video as transforms_video
from sklearn.metrics import confusion_matrix
from lavila.data import datasets
from lavila.data.video_transforms import Permute, SpatialCrop, TemporalCrop
from lavila.models import models
from lavila.models.tokenizer import (MyBertTokenizer, MyDistilBertTokenizer, MyGPT2Tokenizer, SimpleTokenizer)
from lavila.models.utils import inflate_positional_embeds
from lavila.utils.config import load_cfg
from lavila.utils.evaluation_charades import charades_map
from lavila.utils.evaluation import get_mean_accuracy
from lavila.utils.evaluation_ek100mir import (calculate_k_counts, calculate_IDCG, calculate_mAP, calculate_nDCG)
class VideoModel(nn.Module):
    """Base model for video understanding based on LaViLa architecture.

    Construction loads the config, rebuilds the architecture recorded in the
    checkpoint, selects a tokenizer from the architecture name and finally
    freezes the wrapped model for inference.
    """

    def __init__(self, config):
        """Initializes the model.

        Parameters:
            config: config file (handed to ``load_cfg``)
        """
        super(VideoModel, self).__init__()
        self.cfg = load_cfg(config)
        self.model = self.build_model()
        self.tokenizer = self.get_tokenizer()
        self.templates = ['{}']
        self.dataset = self.cfg['data']['dataset']
        self.eval()

    def build_model(self):
        """Create the network described by the checkpoint and load its weights.

        Reads the checkpoint path from ``cfg['model']['pretrain']``, records the
        architecture name in ``self.arch`` and in ``cfg['model']['arch']``, and
        stores ``norm_embed`` from the checkpoint arguments in ``cfg['model']``.

        Returns:
            The instantiated model with the checkpoint weights loaded
            (moved to CUDA when a GPU is available).

        Raises:
            ValueError: if the config does not name a checkpoint.
        """
        cfg = self.cfg
        ckpt_path = cfg['model'].get('pretrain', False)
        if not ckpt_path:
            raise ValueError('no checkpoint found')

        # NOTE(review): the checkpoint carries a pickled ``args`` object next to
        # the weights, so it is fully unpickled here -- only load trusted files.
        ckpt = torch.load(ckpt_path, map_location='cpu')
        # Drop the 'module.' fragment from the parameter names
        # (presumably left by a DataParallel-style wrapper -- TODO confirm).
        state_dict = OrderedDict(
            (name.replace('module.', ''), weight)
            for name, weight in ckpt['state_dict'].items()
        )

        old_args = vars(ckpt['args'])
        arch = old_args.get('model', 'CLIP_OPENAI_TIMESFORMER_BASE')
        self.arch = arch
        cfg['model']['arch'] = arch
        cfg['model']['norm_embed'] = old_args.get('norm_embed', True)
        # The same frame count is used to build the model and to inflate the
        # positional embeddings below.
        num_frames = cfg['model'].get('num_frames', cfg['data']['clip_length'])

        print("=> creating model: {}".format(arch))
        visual_pretrained = old_args.get('load_visual_pretrained', None)
        model = getattr(models, arch)(
            pretrained=visual_pretrained,
            pretrained2d=visual_pretrained is not None,
            text_use_cls_token=old_args.get('use_cls_token', False),
            project_embed_dim=old_args.get('project_embed_dim', 256),
            timesformer_gated_xattn=False,
            num_frames=num_frames,
            model_cfg=cfg['model']
        )
        model.logit_scale.requires_grad = False

        if torch.cuda.is_available():
            model.cuda()

        if ('TIMESFORMER' in arch or 'EGOVLP' in arch) and cfg['model'].get('inflat_posemb', True):
            # inflate weight
            print('=> inflating PE in models due to different frame numbers')
            state_dict = inflate_positional_embeds(
                model.state_dict(), state_dict,
                num_frames=num_frames,
                load_temporal_fix='bilinear',
            )

        model.load_state_dict(state_dict, strict=True)
        print("=> loaded resume checkpoint '{}' (epoch {})".format(ckpt_path, ckpt['epoch']))
        return model

    def eval(self):
        """Freeze the wrapped model and switch to evaluation mode.

        Enables ``cudnn.benchmark``, disables gradients for every parameter of
        ``self.model`` and puts this module (and therefore ``self.model``) in
        evaluation mode.

        Returns:
            ``self``, matching the contract of ``nn.Module.eval`` so that
            ``model = VideoModel(cfg).eval()`` keeps working.
        """
        cudnn.benchmark = True
        for param in self.model.parameters():
            param.requires_grad = False
        # nn.Module.eval() recurses into children, which covers self.model.
        return super().eval()

    def get_tokenizer(self):
        """Pick the tokenizer that matches the suffix of ``self.arch``.

        Returns:
            A tokenizer instance; ``SimpleTokenizer`` when no suffix matches.
        """
        arch = self.arch
        # Checked in order: 'DISTILBERT_BASE' must come before 'BERT_BASE'
        # because every name ending in the former also ends in the latter.
        suffix_table = (
            ('DISTILBERT_BASE', MyDistilBertTokenizer, 'distilbert-base-uncased'),
            ('BERT_BASE', MyBertTokenizer, 'bert-base-uncased'),
            ('BERT_LARGE', MyBertTokenizer, 'bert-large-uncased'),
            ('GPT2', MyGPT2Tokenizer, 'gpt2'),
            ('GPT2_MEDIUM', MyGPT2Tokenizer, 'gpt2-medium'),
            ('GPT2_LARGE', MyGPT2Tokenizer, 'gpt2-large'),
            ('GPT2_XL', MyGPT2Tokenizer, 'gpt2-xl'),
        )
        for suffix, tokenizer_cls, pretrained_name in suffix_table:
            if arch.endswith(suffix):
                return tokenizer_cls(pretrained_name)

        print("Using SimpleTokenizer because of model '{}'. "
              "Please check if this is what you want".format(arch))
        return SimpleTokenizer()
class VideoCLSModel(VideoModel):
    """Video model for video classification tasks (Charades-Ego, EGTEA)."""

    def __init__(self, config):
        """Build the base model, then the label map and per-class text features.

        Parameters:
            config: config file
        """
        super(VideoCLSModel, self).__init__(config)
        self.labels, self.mapping_vn2act = self.gen_label_map()
        self.text_features = self.get_text_features()

    def _device(self):
        """Return the device holding the wrapped model's parameters."""
        # build_model only moves the model to CUDA when a GPU is available, so
        # inputs must follow the model instead of assuming CUDA.
        return next(self.model.parameters()).device

    def gen_label_map(self):
        """Load the label map from disk, generating and saving it when missing.

        The path comes from ``cfg['label_map']`` and defaults to
        ``meta/<dataset>/label_map.json``, the location a generated map is
        written to.

        Returns:
            Tuple ``(labels, mapping_vn2act)``.
        """
        labelmap = self.cfg.get('label_map', f'meta/{self.dataset}/label_map.json')
        if os.path.isfile(labelmap):
            print(f"=> Loading label maps from {labelmap}")
            with open(labelmap, 'r') as fin:
                meta = json.load(fin)
            return meta['labels'], meta['mapping_vn2act']

        from lavila.utils.preprocess import generate_label_map
        labels, mapping_vn2act = generate_label_map(self.dataset)
        meta = {'labels': labels, 'mapping_vn2act': mapping_vn2act}
        meta_dir = f'meta/{self.dataset}'
        os.makedirs(meta_dir, exist_ok=True)
        # Close the file explicitly so the JSON is flushed before anyone reads it.
        with open(f'{meta_dir}/label_map.json', 'w') as fout:
            json.dump(meta, fout)
        print(f"=> Label map is generated and saved to {meta_dir}/label_map.json")
        return labels, mapping_vn2act

    def load_data(self, idx=None):
        """Create the validation data loader.

        Parameters:
            idx: when given, it is substituted into ``cfg['data']['metadata_val']``
                via ``str.format``; when ``None`` the path is used as is.

        Returns:
            A ``torch.utils.data.DataLoader`` over the validation set.

        Raises:
            NotImplementedError: for datasets other than charades_ego / egtea.
        """
        print(f"=> Creating dataset")
        dataset = self.dataset
        data_cfg = self.cfg['data']
        crop_size = 336 if '336PX' in self.arch else 224
        val_transform = transforms.Compose([
            Permute([3, 0, 1, 2]),  # T H W C -> C T H W
            transforms.Resize(crop_size),
            transforms.CenterCrop(crop_size),
            transforms_video.NormalizeVideo(
                mean=[108.3272985, 116.7460125, 104.09373615000001],
                std=[68.5005327, 66.6321579, 70.32316305]),
        ])

        metadata_val = data_cfg['metadata_val']
        if idx is not None:
            metadata_val = metadata_val.format(idx)

        if dataset not in ['charades_ego', 'egtea']:
            raise NotImplementedError

        val_dataset = datasets.VideoClassyDataset(
            dataset, data_cfg['root'], metadata_val,
            transform=val_transform, is_training=False,
            label_mapping=self.mapping_vn2act, is_trimmed=False,
            num_clips=1, clip_length=data_cfg['clip_length'], clip_stride=data_cfg['clip_stride'],
            sparse_sample=data_cfg['sparse_sample']
        )
        return torch.utils.data.DataLoader(
            val_dataset, batch_size=8, shuffle=False,
            num_workers=4, pin_memory=True, sampler=None, drop_last=False
        )

    @torch.no_grad()
    def get_text_features(self):
        """Encode every label into one L2-normalized text embedding.

        A label may be a single string or a list of strings; all template
        expansions of a label are encoded, normalized, averaged and
        re-normalized.

        Returns:
            Tensor with one row per entry of ``self.labels``.
        """
        print('=> Extracting text features')
        device = self._device()
        text_features = []
        for label in self.labels:
            variants = label if isinstance(label, list) else [label]
            texts = [tmpl.format(lbl) for tmpl in self.templates for lbl in variants]
            tokens = self.tokenizer(texts)
            masks = None
            if isinstance(tokens, tuple):
                # Bert-style tokenizer will output both ids and mask
                tokens, masks = tokens
                masks = masks.to(device, non_blocking=True).view(-1, 77).contiguous()
            # 77 is the token length hard-coded by this file
            # (presumably the tokenizer's context length -- TODO confirm).
            tokens = tokens.to(device, non_blocking=True).view(-1, 77).contiguous()

            if masks is not None:
                class_embeddings, _ = self.model.encode_text(tokens, attention_mask=masks)
            else:
                class_embeddings, _ = self.model.encode_text(tokens)
            class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True)
            class_embeddings = class_embeddings.mean(dim=0)
            class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True)
            text_features.append(class_embeddings)
        return torch.stack(text_features, dim=0)

    @torch.no_grad()
    def forward(self, idx=None):
        """Score every validation clip against every label.

        Parameters:
            idx: forwarded to ``load_data``.

        Returns:
            Tuple ``(all_outputs, all_targets)`` of CPU tensors: the softmax over
            the image/text similarities and the targets, concatenated over batches.
        """
        print('=> Start forwarding')
        val_loader = self.load_data(idx)
        device = self._device()
        all_outputs = []
        all_targets = []
        for values in val_loader:
            images = values[0].to(device, non_blocking=True)
            target = values[1].to(device, non_blocking=True)

            # encode images
            image_features, _ = self.model.encode_image(images)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
            # cosine similarity as logits
            logits_per_image = image_features @ self.text_features.t()
            logits_per_image = torch.softmax(logits_per_image, dim=1)

            all_outputs.append(logits_per_image.cpu())
            all_targets.append(target.cpu())
        return torch.cat(all_outputs), torch.cat(all_targets)

    @torch.no_grad()
    def predict(self, idx=0):
        """Predict the actions of the first sample of metadata file ``idx``.

        Parameters:
            idx: forwarded to ``forward`` / ``load_data``.

        Returns:
            Tuple ``(pred_action, gt_action)`` of sorted label lists.
        """
        all_outputs, all_targets = self.forward(idx)
        preds, targets = all_outputs.numpy(), all_targets.numpy()
        scores = preds[0]
        # Position at which the cumulative score, in descending order, first
        # exceeds the threshold.
        sel = np.where(np.cumsum(sorted(scores.tolist(), reverse=True)) > 0.055)[0][0]
        # A slice of `[-0:]` would return *every* label, so always keep at
        # least the top-1 prediction.
        sel = max(int(sel), 1)
        df = pd.DataFrame(self.labels)
        pred_action = df.iloc[scores.argsort()[-sel:]].values.tolist()
        gt_action = df.iloc[np.where(targets[0])[0]].values.tolist()
        pred_action = sorted([row[0] for row in pred_action])
        gt_action = sorted([row[0] for row in gt_action])
        return pred_action, gt_action

    @torch.no_grad()
    def evaluate(self):
        """Run the full validation set and print the dataset's metric.

        Prints mAP for charades_ego, and mean-class / top-1 accuracy for egtea.

        Raises:
            NotImplementedError: for any other dataset.
        """
        all_outputs, all_targets = self.forward()
        preds, targets = all_outputs.numpy(), all_targets.numpy()
        if self.dataset == 'charades_ego':
            m_ap, _, _ = charades_map(preds, targets)
            print('mAP = {:.3f}'.format(m_ap))
        elif self.dataset == 'egtea':
            cm = confusion_matrix(targets, preds.argmax(axis=1))
            mean_class_acc, acc = get_mean_accuracy(cm)
            print('Mean Acc. = {:.3f}, Top-1 Acc. = {:.3f}'.format(mean_class_acc, acc))
        else:
            raise NotImplementedError
class VideoMIRModel(VideoModel):
    """Video model for video multi-instance retrieval tasks (EK100_MIR)."""

    def __init__(self, config):
        """Build the base model, then load narrations and their text features.

        Parameters:
            config: config file
        """
        super(VideoMIRModel, self).__init__(config)
        # Second CSV column holds the narration text, first column of the
        # t2v selection file holds the video sample ids (as indexed below).
        self.narrations = pd.read_csv(self.cfg['data']['narrations']).values[:, 1]
        self.text_features = self.get_text_features()
        self.video_samples = pd.read_csv('meta/ek100_mir/sel_t2v.csv').values[:, 0]

    def _device(self):
        """Return the device holding the wrapped model's parameters."""
        # build_model only moves the model to CUDA when a GPU is available, so
        # inputs must follow the model instead of assuming CUDA.
        return next(self.model.parameters()).device

    def _load_demo_dataset(self, metadata_val, val_transform):
        """Load both relevancy matrices and build the demo caption dataset."""
        data_cfg = self.cfg['data']
        relevancy_path = data_cfg['relevancy_path']
        self.relevancy_mat_v2t = np.load(relevancy_path.replace('sel', 'sel_v2t'))
        self.relevancy_mat_t2v = np.load(relevancy_path.replace('sel', 'sel_t2v'))
        return datasets.VideoCaptionDatasetCLIP(
            'ek100_mir_demo', data_cfg['root'], metadata_val, val_transform,
            is_training=False, tokenizer=self.tokenizer,
            clip_length=data_cfg['clip_length'], clip_stride=data_cfg['clip_stride']
        )

    def load_data(self, idx=None, t2v=False):
        """Create the validation data loader.

        Parameters:
            idx: when given (and ``t2v`` is false), substituted into
                ``cfg['data']['metadata_val']`` to select a demo metadata file.
            t2v: when true, load the fixed text-to-video selection
                ``meta/ek100_mir/sel_t2v.csv``.

        With neither argument the full dataset from ``datasets.get_dataset`` is
        used. The demo variants also (re)load ``self.relevancy_mat_v2t`` and
        ``self.relevancy_mat_t2v``.

        Returns:
            A ``torch.utils.data.DataLoader``.

        Raises:
            NotImplementedError: for datasets other than ek100_mir.
        """
        print(f"=> Creating dataset")
        cfg = self.cfg
        data_cfg = cfg['data']
        if self.dataset != 'ek100_mir':
            raise NotImplementedError

        crop_size = 336 if '336PX' in self.arch else 224
        val_transform = transforms.Compose([
            Permute([3, 0, 1, 2]),  # T H W C -> C T H W
            transforms.Resize(crop_size),
            transforms.CenterCrop(crop_size),
            transforms_video.NormalizeVideo(
                mean=[108.3272985, 116.7460125, 104.09373615000001],
                std=[68.5005327, 66.6321579, 70.32316305]),
        ])

        if t2v:
            val_dataset = self._load_demo_dataset('meta/ek100_mir/sel_t2v.csv', val_transform)
        elif idx is None:
            val_dataset = datasets.get_dataset(val_transform, self.tokenizer, cfg, is_training=False)
        else:
            metadata_val = data_cfg['metadata_val'].format(idx)
            val_dataset = self._load_demo_dataset(metadata_val, val_transform)

        return torch.utils.data.DataLoader(
            val_dataset, batch_size=8, shuffle=False,
            num_workers=4, pin_memory=True, sampler=None, drop_last=False
        )

    @torch.no_grad()
    def get_text_features(self):
        """Encode every narration into an L2-normalized text embedding.

        Returns:
            Tensor with one row per entry of ``self.narrations``.
        """
        print('=> Extracting text features')
        device = self._device()
        text_features = []
        for narration in self.narrations:
            tokens = self.tokenizer(narration).to(device, non_blocking=True)
            # 77 is the token length hard-coded by this file
            # (presumably the tokenizer's context length -- TODO confirm).
            tokens = tokens.view(-1, 77).contiguous()
            text_embed, _ = self.model.encode_text(tokens)
            text_features.append(F.normalize(text_embed, dim=-1).squeeze())
        return torch.stack(text_features, dim=0)

    @torch.no_grad()
    def forward_video(self, text_features=None, idx=None, t2v=False):
        """Encode videos and score them against text features.

        Parameters:
            text_features: text embeddings to compare with; defaults to
                ``self.text_features`` when ``None``.
            idx: forwarded to ``load_data`` when ``t2v`` is false.
            t2v: when true, score each text against all videos of the t2v
                selection; otherwise score each video against all texts.

        Returns:
            CPU tensor of softmax-normalized similarities: one row per video
            (``t2v`` false) or one row per text (``t2v`` true).
        """
        print('=> Start forwarding')
        if text_features is None:
            # The original default of None could not be used in the matmul below.
            text_features = self.text_features
        if t2v:
            val_loader = self.load_data(t2v=t2v)
        else:
            val_loader = self.load_data(idx=idx)

        device = self._device()
        all_outputs = []
        for values in val_loader:
            images = values[0].to(device, non_blocking=True)

            # encode images
            image_features, _ = self.model.encode_image(images)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)

            if t2v:
                # Keep raw embeddings; the softmax over videos needs all of them.
                all_outputs.append(image_features)
            else:
                # cosine similarity as logits
                logits_per_image = image_features @ text_features.t()
                logits_per_image = torch.softmax(logits_per_image, dim=1)
                all_outputs.append(logits_per_image.cpu())

        all_outputs = torch.cat(all_outputs)
        if t2v:
            all_outputs = torch.softmax(text_features @ all_outputs.t(), dim=1).cpu()
        return all_outputs

    @torch.no_grad()
    def predict_v2t(self, idx=0, sid=0):
        """Retrieve the top-3 narrations for the first video of metadata file ``sid``.

        Parameters:
            idx: row of ``self.relevancy_mat_v2t`` used as ground truth.
            sid: substituted into the metadata path by ``load_data``.

        Returns:
            Tuple ``(pred_action, gt_action)`` of narration arrays.
        """
        all_outputs = self.forward_video(self.text_features, sid)
        preds = all_outputs.numpy()
        relevancy = self.relevancy_mat_v2t[idx]
        sel = 3
        pred_action = self.narrations[(-preds[0]).argsort()[:sel]]
        gt_action = self.narrations[np.where(relevancy == 1)[0]]
        return pred_action, gt_action

    @torch.no_grad()
    def predict_t2v(self, idx=0, sid=0):
        """Retrieve the top-1 video for narration ``sid``.

        Parameters:
            idx: row of ``self.relevancy_mat_t2v`` used as ground truth.
            sid: index into ``self.text_features`` of the query narration.

        Returns:
            Tuple ``(pred_video, gt_video)``: predicted sample id(s) from
            ``self.video_samples`` and the indices marked relevant.
        """
        text_features = self.text_features[sid].unsqueeze(0)
        all_outputs = self.forward_video(text_features, t2v=True)
        preds = all_outputs.numpy()
        relevancy = self.relevancy_mat_t2v[idx]
        sel = 1
        pred_video = self.video_samples[(-preds[0]).argsort()[:sel]]
        gt_video = np.where(relevancy == 1)[0]
        return pred_video, gt_video

    @torch.no_grad()
    def evaluate(self):
        """Run the full EK100-MIR evaluation and print mAP and nDCG.

        Raises:
            NotImplementedError: for datasets other than ek100_mir.
            ValueError: if a text id has no matching video id.
        """
        if self.dataset != 'ek100_mir':
            raise NotImplementedError
        val_loader = self.load_data()
        cfg = self.cfg
        device = self._device()

        all_video_embed = []
        all_text_embed = []
        for inputs in val_loader:
            # The last element of each batch (relevancies) is not fed to the model.
            inputs = [tensor.to(device, non_blocking=True) for tensor in inputs[:-1]]

            # compute output
            outputs = self.model(
                *inputs,
                use_checkpoint=True,
                norm_embed=cfg['model']['norm_embed']
            )
            all_video_embed.append(outputs['image_embed'].cpu().numpy())
            all_text_embed.append(outputs['text_embed'].cpu().numpy())

        all_text_embed = np.vstack(all_text_embed)
        all_video_embed = np.vstack(all_video_embed)
        similarity_matrix = np.matmul(all_video_embed, all_text_embed.T)
        # Shift the similarities by +1 and halve them, i.e. [-1, 1] -> [0, 1].
        similarity_matrix = (similarity_matrix + 1) / 2

        video_id = pd.read_csv(cfg['data']['metadata'].replace('train', 'test')).values[:, 0]
        text_id = pd.read_csv(cfg['data']['metadata'].replace('train', 'test_sentence')).values[:, 0]
        # First position of every video id, built once instead of a linear
        # `.index` scan (plus a fresh `.tolist()`) per text id.
        first_pos = {}
        for pos, vid in enumerate(video_id.tolist()):
            first_pos.setdefault(vid, pos)
        try:
            indexes = [first_pos[tid] for tid in text_id.tolist()]
        except KeyError as err:
            # Keep the exception type `list.index` used to raise.
            raise ValueError('{!r} is not in list'.format(err.args[0])) from err
        similarity_matrix = similarity_matrix[:, indexes]
        print(similarity_matrix.shape)

        # NOTE(review): read_pickle executes arbitrary pickled code; only point
        # `relevancy_path` at files from a trusted source.
        rel_matrix = pd.read_pickle(
            cfg['data']['relevancy_path']
        )
        vis_map = calculate_mAP(similarity_matrix, rel_matrix)
        txt_map = calculate_mAP(similarity_matrix.T, rel_matrix.T)
        avg_map = (vis_map + txt_map) / 2
        print('mAP: V->T: {:.3f} T->V: {:.3f} AVG: {:.3f}'.format(vis_map, txt_map, avg_map))

        vis_k_counts = calculate_k_counts(rel_matrix)
        txt_k_counts = calculate_k_counts(rel_matrix.T)
        vis_IDCG = calculate_IDCG(rel_matrix, vis_k_counts)
        txt_IDCG = calculate_IDCG(rel_matrix.T, txt_k_counts)
        vis_nDCG = calculate_nDCG(similarity_matrix, rel_matrix, k_counts=vis_k_counts, IDCG=vis_IDCG)
        txt_nDCG = calculate_nDCG(similarity_matrix.T, rel_matrix.T, k_counts=txt_k_counts, IDCG=txt_IDCG)
        avg_nDCG = (vis_nDCG + txt_nDCG) / 2
        print('nDCG: V->T: {:.3f} T->V: {:.3f} AVG: {:.3f}'.format(vis_nDCG, txt_nDCG, avg_nDCG))
def main():
    """Evaluate the zero-shot LaViLa model and the Ego-VPA model on one dataset.

    The dataset is chosen with ``--dataset``; both configs are read from
    ``configs/<dataset>/``. Raises ``NotImplementedError`` for a dataset that
    has no model class.
    """
    parser = argparse.ArgumentParser(description='Ego-VPA inference', add_help=False)
    parser.add_argument(
        '--dataset', default='charades_ego', type=str,
        help='charades_ego/ek100_mir',
    )
    args = parser.parse_args()

    # Dataset name -> model class able to evaluate it.
    model_classes = {
        'charades_ego': VideoCLSModel,
        'ek100_mir': VideoMIRModel,
    }
    model_cls = model_classes.get(args.dataset)
    if model_cls is None:
        raise NotImplementedError

    # Both models are built before either one is evaluated.
    lavila = model_cls(f"configs/{args.dataset}/zeroshot.yml")
    egovpa = model_cls(f"configs/{args.dataset}/egovpa.yml")

    lavila.evaluate()
    egovpa.evaluate()
    # Single-query demo, e.g.: egovpa.predict_t2v(idx=0, sid=2119)


if __name__ == '__main__':
    main()