|
import math |
|
import copy |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from utils.model_utils import RNNEncoder |
|
from easydict import EasyDict as edict |
|
|
|
|
|
# Default hyper-parameters for the EXCL model (see class below).
excl_base_cfg = edict(
    visual_input_size=2048,  # dim of per-frame visual features (2048 suggests ResNet pooled features — TODO confirm)
    query_input_size=768,  # dim of query word embeddings (768 suggests BERT-sized features — TODO confirm)
    sub_input_size=768,  # dim of subtitle token embeddings
    hidden_size=256,  # shared hidden size; RNN encoders use hidden_size // 2 per direction
    drop=0.5,  # dropout probability applied to encoder inputs
    ctx_mode="video_sub",  # which context streams to use; checked via substring ("video" / "sub")
    initializer_range=0.02,  # std for normal init of Linear/Embedding weights in reset_parameters
)
|
|
|
|
|
class EXCL(nn.Module):
    """EXCL-style moment localization model.

    Encodes the query with a bidirectional LSTM and fuses it with one or two
    context streams (video features and/or subtitle features, selected by
    ``config.ctx_mode``). Each enabled stream runs a two-stage LSTM encoder and
    two MLP heads that produce per-position start/end logits; logits from the
    enabled streams are averaged.
    """

    def __init__(self, config):
        """
        Args:
            config: edict with keys ``visual_input_size``, ``query_input_size``,
                ``sub_input_size``, ``hidden_size``, ``drop``, ``ctx_mode``,
                ``initializer_range`` (see ``excl_base_cfg``).
        """
        super(EXCL, self).__init__()
        self.config = config
        self.use_video = "video" in config.ctx_mode
        self.use_sub = "sub" in config.ctx_mode

        # Query encoder: only the final hidden state is used (return_hidden=True).
        self.query_encoder = RNNEncoder(
            word_embedding_size=config.query_input_size,
            hidden_size=config.hidden_size // 2,
            bidirectional=True,
            n_layers=1,
            rnn_type="lstm",
            return_outputs=False,
            return_hidden=True
        )

        # The two context streams are structurally identical; build each one
        # through a shared helper instead of duplicating the module setup.
        if self.use_video:
            self._init_stream("video", config.visual_input_size)
        if self.use_sub:
            self._init_stream("sub", config.sub_input_size)

        self.temporal_criterion = nn.CrossEntropyLoss(reduction="mean")

        self.reset_parameters()

    def _init_stream(self, prefix, input_size):
        """Create encoder and predictor modules for one context stream.

        Registers ``{prefix}_encoder``, ``{prefix}_encoder2``,
        ``{prefix}_st_predictor`` and ``{prefix}_ed_predictor`` as attributes,
        matching the names looked up in ``get_prob_single_stream``.
        """
        config = self.config
        setattr(self, prefix + "_encoder", RNNEncoder(
            word_embedding_size=input_size,
            hidden_size=config.hidden_size // 2,
            bidirectional=True,
            n_layers=1,
            rnn_type="lstm",
            return_outputs=True,
            return_hidden=False))
        # Second-stage encoder consumes [first-stage output; tiled query].
        setattr(self, prefix + "_encoder2", RNNEncoder(
            word_embedding_size=2 * config.hidden_size,
            hidden_size=config.hidden_size // 2,
            bidirectional=True,
            n_layers=1,
            rnn_type="lstm",
            return_outputs=True,
            return_hidden=False))
        st_predictor = nn.Sequential(
            nn.Linear(3 * config.hidden_size, config.hidden_size),
            nn.Tanh(),
            nn.Linear(config.hidden_size, 1))
        setattr(self, prefix + "_st_predictor", st_predictor)
        # BUG FIX: the end predictor is a copy of THIS stream's start predictor.
        # Previously the subtitle branch deep-copied ``video_st_predictor``,
        # which crashed with AttributeError in sub-only mode and tied the sub
        # branch's initialization to the video branch.
        setattr(self, prefix + "_ed_predictor", copy.deepcopy(st_predictor))

    def reset_parameters(self):
        """Initialize the weights."""

        def re_init(module):
            if isinstance(module, (nn.Linear, nn.Embedding)):
                module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            elif isinstance(module, nn.LayerNorm):
                module.bias.data.zero_()
                module.weight.data.fill_(1.0)
            elif isinstance(module, nn.Conv1d):
                module.reset_parameters()
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()

        self.apply(re_init)

    def get_prob_single_stream(self, encoded_query, ctx_feat, ctx_mask, module_name=None):
        """Compute start/end logits for one context stream.

        Args:
            encoded_query: (N, L, hidden_size) query state tiled along the
                context length.
            ctx_feat: (N, L, D) context features for this stream.
            ctx_mask: (N, L) binary mask, 1 for valid positions.
            module_name: "video" or "sub"; selects the stream's modules.

        Returns:
            st_probs, ed_probs: (N, L) masked start/end logits
            (invalid positions filled with a large negative value).
        """
        ctx_lengths = ctx_mask.sum(1).long()  # per-example valid length for the RNN
        ctx_feat1 = getattr(self, module_name + "_encoder")(
            F.dropout(ctx_feat, p=self.config.drop, training=self.training),
            ctx_lengths)[0]
        ctx_feat2 = getattr(self, module_name + "_encoder2")(
            F.dropout(torch.cat([ctx_feat1, encoded_query], dim=-1), p=self.config.drop, training=self.training),
            ctx_lengths)[0]
        ctx_feat3 = torch.cat([ctx_feat2, ctx_feat1, encoded_query], dim=2)
        # BUG FIX: squeeze only the trailing singleton dim. A bare ``.squeeze()``
        # also dropped the batch dim when N == 1, breaking masking and the loss.
        st_probs = getattr(self, module_name + "_st_predictor")(ctx_feat3).squeeze(-1)
        ed_probs = getattr(self, module_name + "_ed_predictor")(ctx_feat3).squeeze(-1)
        st_probs = mask_logits(st_probs, ctx_mask)
        ed_probs = mask_logits(ed_probs, ctx_mask)
        return st_probs, ed_probs

    def forward(self, query_feat, query_mask, video_feat, video_mask, sub_feat, sub_mask,
                tef_feat, tef_mask, st_ed_indices, is_training=True):
        """
        Args:
            query_feat: (N, Lq, Dq)
            query_mask: (N, Lq)
            video_feat: (N, Lv, Dv) or None
            video_mask: (N, Lv) or None
            sub_feat: (N, Lv, Ds) or None
            sub_mask: (N, Lv) or None
            tef_feat: (N, Lv, 2) or None, (unused here)
            tef_mask: (N, Lv) or None, (unused here)
            st_ed_indices: (N, 2), torch.LongTensor, 1st, 2nd columns are st, ed labels respectively.
            is_training: if True, return the start/end cross-entropy loss;
                otherwise return the averaged start/end logits.
        """
        query_lengths = query_mask.sum(1).long()
        encoded_query = self.query_encoder(query_feat, query_lengths)[1]
        # BUG FIX: take the context length from whichever stream is enabled;
        # indexing video_feat unconditionally crashed in sub-only mode.
        ctx_len = video_feat.shape[1] if self.use_video else sub_feat.shape[1]
        encoded_query = encoded_query.unsqueeze(1).repeat(1, ctx_len, 1)  # (N, L, hidden)

        video_st_prob, video_ed_prob = self.get_prob_single_stream(
            encoded_query, video_feat, video_mask, module_name="video") if self.use_video else (0, 0)

        sub_st_prob, sub_ed_prob = self.get_prob_single_stream(
            encoded_query, sub_feat, sub_mask, module_name="sub") if self.use_sub else (0, 0)

        # Average logits over the enabled streams (divisor is 1 or 2).
        st_prob = (video_st_prob + sub_st_prob) / (self.use_video + self.use_sub)
        ed_prob = (video_ed_prob + sub_ed_prob) / (self.use_video + self.use_sub)

        if is_training:
            loss_st = self.temporal_criterion(st_prob, st_ed_indices[:, 0])
            loss_ed = self.temporal_criterion(ed_prob, st_ed_indices[:, 1])
            loss_st_ed = loss_st + loss_ed
            return loss_st_ed, {"loss_st_ed": float(loss_st_ed)}, st_prob, ed_prob
        else:
            # BUG FIX: the inference branch previously built a top-k span
            # product, discarded it, and returned None. Return the masked
            # start/end logits so callers can rank candidate spans.
            return st_prob, ed_prob
|
|
|
|
|
def mask_logits(target, mask):
    """Mask out invalid positions in a logit tensor.

    Positions where ``mask`` is 1 keep their value from ``target``; positions
    where ``mask`` is 0 are replaced by a large negative constant, so they
    vanish after a subsequent softmax. ``mask`` must broadcast with ``target``.
    """
    very_negative = -1e10
    kept = target * mask
    filled = (1 - mask) * very_negative
    return kept + filled
|
|