# Liangrj5
# init
# ebf5d87
import torch
import torch.nn as nn
from baselines.mixture_embedding_experts.model_components import NetVLAD, MaxMarginRankingLoss, GatedEmbeddingUnit
from easydict import EasyDict as edict
# Default configuration for the MEE baseline (video-only context by default).
mee_base_cfg = edict({
    "ctx_mode": "video",     # context streams to use: contains "video" and/or "sub"
    "text_input_size": 768,  # dim of query (and subtitle) text features
    "vid_input_size": 1024,  # dim of video features
    "output_size": 256,      # joint embedding dimension
    "margin": 0.2,           # margin for the max-margin ranking loss
})
class MEE(nn.Module):
    """Mixture-of-Embedding-Experts style retrieval model.

    The query token sequence is pooled with NetVLAD, each enabled context
    stream (video and/or subtitle, per ``config.ctx_mode``) is embedded
    with a gated embedding unit, and every query is scored against every
    context. When both streams are enabled, per-query mixture weights
    from a linear layer blend the two score matrices. The forward pass
    returns a max-margin ranking loss over the resulting score matrix.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.use_video = "video" in config.ctx_mode
        self.use_sub = "sub" in config.ctx_mode
        # Pools the (N, L, D) query tokens into a single vector per query.
        self.query_pooling = NetVLAD(feature_size=config.text_input_size, cluster_size=2)
        if self.use_sub:
            self.sub_query_gu = GatedEmbeddingUnit(
                input_dimension=self.query_pooling.out_dim,
                output_dimension=config.output_size)
            self.sub_gu = GatedEmbeddingUnit(
                input_dimension=config.text_input_size,
                output_dimension=config.output_size)
        if self.use_video:
            self.video_query_gu = GatedEmbeddingUnit(
                input_dimension=self.query_pooling.out_dim,
                output_dimension=config.output_size)
            self.video_gu = GatedEmbeddingUnit(
                input_dimension=config.vid_input_size,
                output_dimension=config.output_size)
        if self.use_video and self.use_sub:
            # One mixture weight per stream, predicted from the pooled query.
            self.moe_fc = nn.Linear(self.query_pooling.out_dim, 2)
        self.max_margin_loss = MaxMarginRankingLoss(margin=config.margin)

    def forward(self, query_feat, query_mask, video_feat, sub_feat):
        """Compute the max-margin ranking loss for one batch.

        Args:
            query_feat: (N, L, D_q) query token features.
            query_mask: (N, L) validity mask. NOTE(review): currently
                unused — pooling is applied to all tokens; confirm
                whether NetVLAD should receive the mask.
            video_feat: (N, Dv) video features.
            sub_feat: (N, Dt) subtitle features.
        """
        pooled = self.query_pooling(query_feat)  # (N, Dt)
        vid_ctx, sub_ctx = self.encode_context(video_feat, sub_feat)
        score_matrix = self.get_score_from_pooled_query_with_encoded_ctx(pooled, vid_ctx, sub_ctx)
        return self.max_margin_loss(score_matrix)

    def encode_context(self, video_feat, sub_feat):
        """Embed each enabled context stream; disabled streams yield None."""
        vid_ctx = None
        sub_ctx = None
        if self.use_video:
            vid_ctx = self.video_gu(video_feat)
        if self.use_sub:
            sub_ctx = self.sub_gu(sub_feat)
        return vid_ctx, sub_ctx

    def compute_single_stream_scores_with_encoded_ctx(self, pooled_query, encoded_ctx, module_name="video"):
        """Score every query against every context for a single stream.

        Args:
            pooled_query: (Nq, Dt) pooled query vectors.
            encoded_ctx: (Nc, D) embedded context for this stream.
            module_name: "video" or "sub"; selects the query-side GEU.

        Returns:
            (Nq, Nc) similarity matrix (dot products in the joint space).
        """
        query_gu = getattr(self, "{}_query_gu".format(module_name))
        projected_query = query_gu(pooled_query)  # (Nq, D)
        return torch.matmul(projected_query, encoded_ctx.t())  # (Nq, Nc)

    def get_score_from_pooled_query_with_encoded_ctx(self, pooled_query, encoded_video, encoded_sub):
        """Combine per-stream score matrices into one (Nq, Nc) matrix.

        Nq (number of queries) may differ from Nc (number of contexts).

        Args:
            pooled_query: (Nq, Dt)
            encoded_video: (Nc, Dc) or None when video is disabled.
            encoded_sub: (Nc, Dc) or None when subtitles are disabled.
        """
        video_scores = 0
        sub_scores = 0
        if self.use_video:
            video_scores = self.compute_single_stream_scores_with_encoded_ctx(
                pooled_query, encoded_video, module_name="video")
        if self.use_sub:
            sub_scores = self.compute_single_stream_scores_with_encoded_ctx(
                pooled_query, encoded_sub, module_name="sub")
        if not (self.use_video and self.use_sub):
            # Single-stream case: exactly one addend is a matrix.
            return video_scores + sub_scores
        # NOTE(review): raw linear outputs are used as mixture weights
        # (no softmax normalization) — confirm this is intentional.
        weights = self.moe_fc(pooled_query)  # (Nq, 2)
        return weights[:, 0:1] * video_scores + weights[:, 1:2] * sub_scores