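# NOTE: the next three lines are non-code residue found at the top of this file;
# they are kept as comments so the module remains importable.
# ruohguo's picture
# Upload 117 files
# b80ae90 verified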
from math import ceil
import fvcore.nn.weight_init as weight_init
from typing import Optional
import torch
from torch import nn, Tensor
from torch.nn import functional as F
import copy
from detectron2.config import configurable
from detectron2.layers import Conv2d
class SelfAttentionLayer(nn.Module):
    """Residual multi-head self-attention block with a single LayerNorm.

    ``normalize_before`` selects where the LayerNorm sits: on the attention
    input (pre-norm) or on the residual sum (post-norm).
    """

    def __init__(self, d_model, nhead, dropout=0.0,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

        # Resolved and stored, but no forward path of this class applies it.
        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-uniform initialise every parameter with more than one dim."""
        for param in self.parameters():
            if param.dim() <= 1:
                continue
            nn.init.xavier_uniform_(param)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        """Return ``tensor + pos``, or ``tensor`` itself when ``pos`` is None."""
        if pos is None:
            return tensor
        return tensor + pos

    def forward_post(self, tgt,
                     tgt_mask: Optional[Tensor] = None,
                     tgt_key_padding_mask: Optional[Tensor] = None,
                     query_pos: Optional[Tensor] = None):
        """Post-norm path: attention -> dropout -> residual add -> LayerNorm."""
        # Queries and keys share the position-augmented input; values do not
        # receive the positional term.
        qk = self.with_pos_embed(tgt, query_pos)
        attended = self.self_attn(
            qk, qk,
            value=tgt,
            attn_mask=tgt_mask,
            key_padding_mask=tgt_key_padding_mask,
        )[0]
        return self.norm(tgt + self.dropout(attended))

    def forward_pre(self, tgt,
                    tgt_mask: Optional[Tensor] = None,
                    tgt_key_padding_mask: Optional[Tensor] = None,
                    query_pos: Optional[Tensor] = None):
        """Pre-norm path: LayerNorm -> attention -> dropout -> residual add."""
        normed = self.norm(tgt)
        qk = self.with_pos_embed(normed, query_pos)
        attended = self.self_attn(
            qk, qk,
            value=normed,
            attn_mask=tgt_mask,
            key_padding_mask=tgt_key_padding_mask,
        )[0]
        # The residual uses the un-normalised input.
        return tgt + self.dropout(attended)

    def forward(self, tgt,
                tgt_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        """Dispatch to the pre-norm or post-norm path per ``normalize_before``."""
        branch = self.forward_pre if self.normalize_before else self.forward_post
        return branch(tgt, tgt_mask, tgt_key_padding_mask, query_pos)
class CrossAttentionLayer(nn.Module):
    """Residual multi-head cross-attention block with a single LayerNorm.

    ``tgt`` supplies the queries and ``memory`` supplies keys and values.
    ``normalize_before`` selects pre-norm or post-norm placement.
    """

    def __init__(self, d_model, nhead, dropout=0.0,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

        # Resolved and stored, but no forward path of this class applies it.
        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-uniform initialise every parameter with more than one dim."""
        for param in self.parameters():
            if param.dim() <= 1:
                continue
            nn.init.xavier_uniform_(param)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        """Return ``tensor + pos``, or ``tensor`` itself when ``pos`` is None."""
        if pos is None:
            return tensor
        return tensor + pos

    def forward_post(self, tgt, memory,
                     memory_mask: Optional[Tensor] = None,
                     memory_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None,
                     query_pos: Optional[Tensor] = None):
        """Post-norm path: attention -> dropout -> residual add -> LayerNorm."""
        queries = self.with_pos_embed(tgt, query_pos)
        keys = self.with_pos_embed(memory, pos)
        attended = self.multihead_attn(
            query=queries,
            key=keys,
            value=memory,  # values carry no positional term
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
        )[0]
        return self.norm(tgt + self.dropout(attended))

    def forward_pre(self, tgt, memory,
                    memory_mask: Optional[Tensor] = None,
                    memory_key_padding_mask: Optional[Tensor] = None,
                    pos: Optional[Tensor] = None,
                    query_pos: Optional[Tensor] = None):
        """Pre-norm path: LayerNorm(tgt) -> attention -> dropout -> residual add."""
        # Only the query side is normalised; memory is used as given.
        queries = self.with_pos_embed(self.norm(tgt), query_pos)
        keys = self.with_pos_embed(memory, pos)
        attended = self.multihead_attn(
            query=queries,
            key=keys,
            value=memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
        )[0]
        return tgt + self.dropout(attended)

    def forward(self, tgt, memory,
                memory_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        """Dispatch to the pre-norm or post-norm path per ``normalize_before``."""
        branch = self.forward_pre if self.normalize_before else self.forward_post
        return branch(tgt, memory, memory_mask,
                      memory_key_padding_mask, pos, query_pos)
class FFNLayer(nn.Module):
    """Residual two-layer feed-forward block with a single LayerNorm.

    Computes ``linear2(dropout(activation(linear1(x))))`` and adds it back to
    the input; ``normalize_before`` selects pre-norm or post-norm placement.
    """

    def __init__(self, d_model, dim_feedforward=2048, dropout=0.0,
                 activation="relu", normalize_before=False):
        super().__init__()
        # Feed-forward sub-network: d_model -> dim_feedforward -> d_model.
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm = nn.LayerNorm(d_model)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-uniform initialise every parameter with more than one dim."""
        for param in self.parameters():
            if param.dim() <= 1:
                continue
            nn.init.xavier_uniform_(param)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        """Return ``tensor + pos``, or ``tensor`` itself when ``pos`` is None."""
        if pos is None:
            return tensor
        return tensor + pos

    def forward_post(self, tgt):
        """Post-norm path: feed-forward -> dropout -> residual add -> LayerNorm."""
        hidden = self.activation(self.linear1(tgt))
        update = self.linear2(self.dropout(hidden))
        return self.norm(tgt + self.dropout(update))

    def forward_pre(self, tgt):
        """Pre-norm path: LayerNorm -> feed-forward -> dropout -> residual add."""
        hidden = self.activation(self.linear1(self.norm(tgt)))
        update = self.linear2(self.dropout(hidden))
        # The residual uses the un-normalised input.
        return tgt + self.dropout(update)

    def forward(self, tgt):
        """Dispatch to the pre-norm or post-norm path per ``normalize_before``."""
        branch = self.forward_pre if self.normalize_before else self.forward_post
        return branch(tgt)
def _get_activation_fn(activation):
"""Return an activation function given a string"""
if activation == "relu":
return F.relu
if activation == "gelu":
return F.gelu
if activation == "glu":
return F.glu
raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
class MLP(nn.Module):
    """Stack of ``nn.Linear`` layers with ReLU between them (a simple FFN).

    Layer widths run ``input_dim -> hidden_dim -> ... -> output_dim``; no
    activation is applied after the final layer.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        # Consecutive entries of ``widths`` are the (in, out) sizes of each layer.
        widths = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim]
        linears = [
            nn.Linear(fan_in, fan_out)
            for fan_in, fan_out in zip(widths[:-1], widths[1:])
        ]
        self.layers = nn.ModuleList(linears)

    def forward(self, x):
        """Apply every layer in order, with ReLU after all but the last."""
        last_idx = self.num_layers - 1
        for idx, layer in enumerate(self.layers):
            x = layer(x)
            if idx < last_idx:
                x = F.relu(x)
        return x
def _get_clones(module, N):
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
class Avism_COCO(nn.Module):
    """Clip-level decoder that fuses per-frame queries with audio features.

    Pipeline implemented by :meth:`forward`:

    1. frame queries are projected and refined by a self-attention encoder
       (optionally restricted to temporal windows);
    2. projected audio features cross-attend to the encoded frame queries;
    3. the two streams are summed and decoded by learned clip queries into
       class logits and mask embeddings.
    """

    @configurable
    def __init__(
        self,
        in_channels,
        aux_loss,
        *,
        hidden_dim: int,
        num_frame_queries: int,
        num_queries: int,
        nheads: int,
        dim_feedforward: int,
        enc_layers: int,
        dec_layers: int,
        enc_window_size: int,
        pre_norm: bool,
        enforce_input_project: bool,
        num_frames: int,
        num_classes: int,
        clip_last_layer_num: bool,
        conv_dim: int,
        mask_dim: int,
        sim_use_clip: list,
        use_sim: bool,
        audio_dim: int = 128,
    ):
        """
        NOTE: this interface is experimental.
        Args:
            in_channels: channels of the input features; only compared against
                ``hidden_dim`` to decide whether an input projection is built
            aux_loss: when true (and in training mode) every decoder layer
                contributes an output, not only the last one
            hidden_dim: Transformer feature dimension
            num_frame_queries: number of frame-wise queries (size of the
                positional table added to the decoder memory)
            num_queries: number of clip-level queries
            nheads: number of attention heads
            dim_feedforward: feature dimension in feedforward network
            enc_layers: number of Transformer encoder layers
            dec_layers: number of Transformer decoder layers
            enc_window_size: temporal window size of the encoder; 0 disables
                windowing
            pre_norm: whether to use pre-LayerNorm or not
            enforce_input_project: add the input projection even if input
                channels and hidden dim are identical
            num_frames: frames per clip; used in training mode to split the
                flattened batch*time axis
            num_classes: number of classes (one extra logit is predicted)
            clip_last_layer_num: stored on the instance; not read in this class
            conv_dim: input channels of the 1x1 ``vita_mask_features`` conv
            mask_dim: output channels of ``vita_mask_features`` and of the
                mask-embedding MLP
            sim_use_clip: truthiness enables the clip-level similarity head
            use_sim: enables the frame-level similarity head
            audio_dim: size of the last axis of the audio features given to
                ``forward`` (defaults to the previously hard-coded 128)
        """
        super().__init__()

        # Plain configuration values.
        self.num_heads = nheads
        self.num_layers = dec_layers
        self.num_frames = num_frames
        self.num_classes = num_classes
        self.clip_last_layer_num = clip_last_layer_num
        self.enc_layers = enc_layers
        self.window_size = enc_window_size
        self.sim_use_clip = sim_use_clip
        self.use_sim = use_sim
        self.aux_loss = aux_loss

        # Layer factories; every layer in this model shares these settings.
        def new_self_attn():
            return SelfAttentionLayer(
                d_model=hidden_dim,
                nhead=nheads,
                dropout=0.0,
                normalize_before=pre_norm,
            )

        def new_cross_attn():
            return CrossAttentionLayer(
                d_model=hidden_dim,
                nhead=nheads,
                dropout=0.0,
                normalize_before=pre_norm,
            )

        def new_ffn():
            return FFNLayer(
                d_model=hidden_dim,
                dim_feedforward=dim_feedforward,
                dropout=0.0,
                normalize_before=pre_norm,
            )

        # Decoder containers are registered first (and filled further down) so
        # that sub-module registration order matches earlier revisions.
        self.transformer_self_attention_layers = nn.ModuleList()
        self.transformer_cross_attention_layers = nn.ModuleList()
        self.transformer_ffn_layers = nn.ModuleList()

        # Projects audio features into the Transformer feature space.
        self.av_proj = nn.Linear(audio_dim, hidden_dim)

        if enc_layers > 0:
            # Frame-query encoder: self-attention + FFN per layer.
            self.enc_self_attn = nn.ModuleList()
            self.enc_ffn = nn.ModuleList()
            for _ in range(enc_layers):
                self.enc_self_attn.append(new_self_attn())
                self.enc_ffn.append(new_ffn())

            # Audio-visual fusion encoder: cross-attention + FFN per layer.
            # Built in a second loop so layers are created in the same order
            # as before (keeps random initialisation reproducible).
            self.enc_av_cross_attn = nn.ModuleList()
            self.enc_av_ffn = nn.ModuleList()
            for _ in range(enc_layers):
                self.enc_av_cross_attn.append(new_cross_attn())
                self.enc_av_ffn.append(new_ffn())

        # Clip-level Transformer decoder.
        for _ in range(self.num_layers):
            self.transformer_self_attention_layers.append(new_self_attn())
            self.transformer_cross_attention_layers.append(new_cross_attn())
            self.transformer_ffn_layers.append(new_ffn())

        self.vita_mask_features = Conv2d(
            conv_dim,
            mask_dim,
            kernel_size=1,
            stride=1,
            padding=0,
        )
        weight_init.c2_xavier_fill(self.vita_mask_features)

        self.decoder_norm = nn.LayerNorm(hidden_dim)

        self.num_queries = num_queries
        # learnable query features
        self.query_feat = nn.Embedding(num_queries, hidden_dim)
        # learnable query p.e.
        self.query_embed = nn.Embedding(num_queries, hidden_dim)

        # positional embedding for the frame queries used as decoder memory
        self.fq_pos = nn.Embedding(num_frame_queries, hidden_dim)

        # NOTE(review): the projection maps hidden_dim -> hidden_dim even when
        # in_channels != hidden_dim; inputs are presumably already hidden_dim
        # wide — confirm against the caller.
        if in_channels != hidden_dim or enforce_input_project:
            self.input_proj_dec = nn.Linear(hidden_dim, hidden_dim)
        else:
            self.input_proj_dec = nn.Sequential()
        self.src_embed = nn.Identity()

        # Output heads.
        self.class_embed = nn.Linear(hidden_dim, num_classes + 1)
        self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3)
        if self.use_sim:
            self.sim_embed_frame = nn.Linear(hidden_dim, hidden_dim)
            if self.sim_use_clip:
                self.sim_embed_clip = nn.Linear(hidden_dim, hidden_dim)

    @classmethod
    def from_config(cls, cfg, in_channels):
        """Translate ``cfg`` into the keyword arguments of ``__init__``.

        ``aux_loss`` is not set here and must be supplied by the caller.
        """
        avism = cfg.MODEL.AVISM
        seg_head = cfg.MODEL.SEM_SEG_HEAD

        assert avism.DEC_LAYERS >= 1

        return {
            "in_channels": in_channels,
            "hidden_dim": avism.HIDDEN_DIM,
            "num_frame_queries": cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES,
            "num_queries": avism.NUM_OBJECT_QUERIES,
            # Transformer parameters:
            "nheads": avism.NHEADS,
            "dim_feedforward": avism.DIM_FEEDFORWARD,
            "enc_layers": avism.ENC_LAYERS,
            "dec_layers": avism.DEC_LAYERS,
            "enc_window_size": avism.ENC_WINDOW_SIZE,
            "pre_norm": avism.PRE_NORM,
            "enforce_input_project": avism.ENFORCE_INPUT_PROJ,
            "num_classes": seg_head.NUM_CLASSES,
            "num_frames": cfg.INPUT.SAMPLING_FRAME_NUM,
            "clip_last_layer_num": avism.LAST_LAYER_NUM,
            "conv_dim": seg_head.CONVS_DIM,
            "mask_dim": seg_head.MASK_DIM,
            "sim_use_clip": avism.SIM_USE_CLIP,
            # The similarity heads are only built when their loss is weighted.
            "use_sim": avism.SIM_WEIGHT > 0.0,
        }

    def forward(self, frame_query, audio_features):
        """
        L: Number of Layers.
        B: Batch size.
        T: Temporal window size. Number of frames per video.
        C: Channel size.
        fQ: Number of frame-wise queries from IFC.
        cQ: Number of clip-wise queries to decode Q.

        Args:
            frame_query: tensor of shape (L, B*T, fQ, C). Outside training
                only the last entry along L is used and B is taken to be 1.
            audio_features: tensor whose last axis matches ``av_proj``'s input
                size. Assumed to be (T, audio_dim), one row per frame — TODO
                confirm how this is laid out when B > 1.

        Returns:
            dict with keys ``pred_logits``, ``pred_mask_embed``,
            ``pred_fq_embed``, ``pred_cq_embed`` (taken from the last decoder
            output; the two ``*_embed`` similarity entries are ``None`` when
            the corresponding head is disabled) and ``aux_outputs`` (a list
            with one dict per earlier decoder output).
        """
        if not self.training:
            # Keep only the last layer's frame queries at inference time.
            frame_query = frame_query[[-1]]

        L, BT, fQ, C = frame_query.shape
        B = BT // self.num_frames if self.training else 1
        T = self.num_frames if self.training else BT // B
        LB = L * B

        frame_query = frame_query.reshape(LB, T, fQ, C)
        frame_query = frame_query.permute(1, 2, 0, 3).contiguous()
        frame_query = self.input_proj_dec(frame_query)  # T, fQ, LB, C

        audio_feat = self.av_proj(audio_features)  # T, C
        # Broadcast each frame's audio vector over frame queries and batch.
        audio_feat = audio_feat[:, None, None, :].repeat(1, fQ, LB, 1)

        if self.window_size > 0:
            # Pad the time axis up to a multiple of the window size and mark
            # the padded frames (True = padding) for the encoder.
            pad = int(ceil(T / self.window_size)) * self.window_size - T
            padded_T = pad + T
            frame_query = F.pad(frame_query, (0, 0, 0, 0, 0, 0, 0, pad))  # _T, fQ, LB, C
            audio_feat = F.pad(audio_feat, (0, 0, 0, 0, 0, 0, 0, pad))
            enc_mask = frame_query.new_ones(LB, padded_T).bool()  # LB, _T
            enc_mask[:, :T] = False
        else:
            enc_mask = None

        frame_query = self.encode_frame_query(frame_query, enc_mask)
        # audio: fuse with the encoded frame queries
        av_feat = self.encode_av_fusion(frame_query, enc_mask, audio_feat)

        # Drop the temporal padding and merge time with the frame queries.
        frame_query = frame_query[:T].flatten(0, 1)  # TfQ, LB, C
        av_feat = av_feat[:T].flatten(0, 1)
        frame_query = frame_query + av_feat

        if self.use_sim:
            pred_fq_embed = self.sim_embed_frame(frame_query)  # TfQ, LB, C
            pred_fq_embed = pred_fq_embed.transpose(0, 1).reshape(L, B, T, fQ, C)
        else:
            pred_fq_embed = None

        src = self.src_embed(frame_query)  # TfQ, LB, C
        dec_pos = (
            self.fq_pos.weight[None, :, None, :]
            .repeat(T, 1, LB, 1)
            .flatten(0, 1)
        )  # TfQ, LB, C

        # QxNxC
        query_embed = self.query_embed.weight.unsqueeze(1).repeat(1, LB, 1)  # cQ, LB, C
        output = self.query_feat.weight.unsqueeze(1).repeat(1, LB, 1)  # cQ, LB, C

        decoder_outputs = []
        last_layer = self.num_layers - 1
        for i in range(self.num_layers):
            # attention: cross-attention first
            output = self.transformer_cross_attention_layers[i](
                output, src,
                memory_mask=None,
                memory_key_padding_mask=None,
                pos=dec_pos, query_pos=query_embed,
            )
            output = self.transformer_self_attention_layers[i](
                output, tgt_mask=None,
                tgt_key_padding_mask=None,
                query_pos=query_embed,
            )
            # FFN
            output = self.transformer_ffn_layers[i](output)

            # Intermediate layers are only kept when they feed the aux loss.
            if (self.training and self.aux_loss) or (i == last_layer):
                dec_out = self.decoder_norm(output)  # cQ, LB, C
                dec_out = dec_out.transpose(0, 1)  # LB, cQ, C
                decoder_outputs.append(dec_out.view(L, B, self.num_queries, C))

        decoder_outputs = torch.stack(decoder_outputs, dim=0)  # D, L, B, cQ, C

        pred_cls = self.class_embed(decoder_outputs)
        pred_mask_embed = self.mask_embed(decoder_outputs)
        if self.use_sim and self.sim_use_clip:
            pred_cq_embed = self.sim_embed_clip(decoder_outputs)
        else:
            # Placeholder list so the indexing / zipping below still works.
            pred_cq_embed = [None] * self.num_layers

        out = {
            'pred_logits': pred_cls[-1],
            'pred_mask_embed': pred_mask_embed[-1],
            'pred_fq_embed': pred_fq_embed,
            'pred_cq_embed': pred_cq_embed[-1],
            'aux_outputs': self._set_aux_loss(
                pred_cls, pred_mask_embed, pred_cq_embed, pred_fq_embed
            ),
        }
        return out

    @torch.jit.unused
    def _set_aux_loss(
        self, outputs_cls, outputs_mask_embed, outputs_cq_embed, outputs_fq_embed
    ):
        """Build one output dict per decoder output except the last.

        ``outputs_fq_embed`` is shared (not sliced) across all entries.
        """
        aux = []
        for logits, mask_embed, cq_embed in zip(
            outputs_cls[:-1], outputs_mask_embed[:-1], outputs_cq_embed[:-1]
        ):
            aux.append({
                "pred_logits": logits,
                "pred_mask_embed": mask_embed,
                "pred_cq_embed": cq_embed,
                "pred_fq_embed": outputs_fq_embed,
            })
        return aux

    def _build_window_masks(self, attn_mask, T, fQ, LB):
        """Derive the masks used by plain and shifted window attention.

        Args:
            attn_mask: boolean padding mask of shape (LB, T); True = padded.
            T: padded temporal length (a multiple of the window size).
            fQ: number of frame queries per frame.
            LB: size of the merged layer*batch axis.

        Returns:
            ``(window_mask, shift_window_mask)`` where ``window_mask`` is a
            boolean key-padding mask of shape (LB*Nw, W*fQ) and
            ``shift_window_mask`` is a float mask of shape
            (LB*Nw*num_heads, W*fQ, W*fQ) holding -1000 at blocked positions
            and 0 elsewhere.
        """
        W = self.window_size
        Nw = T // W
        half_W = int(ceil(W / 2))

        # One padding flag per (frame, frame-query) token inside each window.
        window_mask = attn_mask.view(LB * Nw, W)[..., None].repeat(1, 1, fQ).flatten(1)

        # Shift the padding mask the same way the tokens are rolled.
        _attn_mask = torch.roll(attn_mask, half_W, 1)
        _attn_mask = _attn_mask.view(LB, Nw, W)[..., None].repeat(1, 1, 1, W)  # LB, Nw, W, W
        _attn_mask[:, 0] = _attn_mask[:, 0] | _attn_mask[:, 0].transpose(-2, -1)
        _attn_mask[:, -1] = _attn_mask[:, -1] | _attn_mask[:, -1].transpose(-2, -1)
        # After the roll the first window holds frames wrapped from the end of
        # the clip next to frames from its start; block attention between them.
        _attn_mask[:, 0, :half_W, half_W:] = True
        _attn_mask[:, 0, half_W:, :half_W] = True
        # Expand frame-level entries to per-token entries and replicate per head.
        _attn_mask = _attn_mask.view(LB * Nw, 1, W, 1, W, 1).repeat(
            1, self.num_heads, 1, fQ, 1, fQ
        ).view(LB * Nw * self.num_heads, W * fQ, W * fQ)
        shift_window_mask = _attn_mask.float() * -1000

        return window_mask, shift_window_mask

    def _partition_windows(self, x):
        """Reshape (T, fQ, LB, C) into per-window sequences (W*fQ, LB*Nw, C)."""
        T, fQ, LB, C = x.shape
        W = self.window_size
        Nw = T // W
        x = x.view(Nw, W, fQ, LB, C)
        return x.permute(1, 2, 3, 0, 4).reshape(W * fQ, LB * Nw, C)

    def _merge_windows(self, x, T, fQ, LB, C):
        """Inverse of ``_partition_windows``: (W*fQ, LB*Nw, C) -> (T, fQ, LB, C)."""
        W = self.window_size
        Nw = T // W
        return x.reshape(W, fQ, LB, Nw, C).permute(3, 0, 1, 2, 4).reshape(T, fQ, LB, C)

    def encode_frame_query(self, frame_query, attn_mask):
        """
        input shape (frame_query)   : T, fQ, LB, C
        output shape (frame_query)  : T, fQ, LB, C

        ``attn_mask`` is the (LB, T) padding mask; it is only read when
        ``window_size`` is non-zero.
        """
        # Not using window-based attention if self.window_size == 0.
        if self.window_size == 0:
            return_shape = frame_query.shape  # T, fQ, LB, C
            frame_query = frame_query.flatten(0, 1)  # TfQ, LB, C

            for i in range(self.enc_layers):
                frame_query = self.enc_self_attn[i](frame_query)
                frame_query = self.enc_ffn[i](frame_query)

            return frame_query.view(return_shape)

        # Using window-based attention if self.window_size > 0.
        T, fQ, LB, C = frame_query.shape
        window_mask, shift_window_mask = self._build_window_masks(attn_mask, T, fQ, LB)

        for layer_idx in range(self.enc_layers):
            # Shifted windows are only used on odd layers outside training;
            # in training mode every layer uses plain windows.
            if self.training or layer_idx % 2 == 0:
                frame_query = self._window_attn(frame_query, window_mask, layer_idx)
            else:
                frame_query = self._shift_window_attn(frame_query, shift_window_mask, layer_idx)
        return frame_query

    def _window_attn(self, frame_query, attn_mask, layer_idx):
        """Run encoder layer ``layer_idx`` independently inside each window.

        ``attn_mask`` is the (LB*Nw, W*fQ) key-padding mask.
        """
        T, fQ, LB, C = frame_query.shape

        tokens = self._partition_windows(frame_query)
        tokens = self.enc_self_attn[layer_idx](tokens, tgt_key_padding_mask=attn_mask)
        tokens = self.enc_ffn[layer_idx](tokens)
        return self._merge_windows(tokens, T, fQ, LB, C)

    def _shift_window_attn(self, frame_query, attn_mask, layer_idx):
        """Run encoder layer ``layer_idx`` on windows shifted by half a window.

        ``attn_mask`` is the (LB*Nw*num_heads, W*fQ, W*fQ) float mask.
        """
        T, fQ, LB, C = frame_query.shape
        half_W = int(ceil(self.window_size / 2))

        # Roll along time so window boundaries fall mid-way through the
        # previous layer's windows.
        tokens = self._partition_windows(torch.roll(frame_query, half_W, 0))
        tokens = self.enc_self_attn[layer_idx](tokens, tgt_mask=attn_mask)
        tokens = self.enc_ffn[layer_idx](tokens)
        merged = self._merge_windows(tokens, T, fQ, LB, C)

        # Undo the roll so frames return to their original positions.
        return torch.roll(merged, -half_W, 0)

    def encode_av_fusion(self, frame_query, attn_mask, audio_feats):
        """
        input shape (frame_query)   : T, fQ, LB, C
        input shape (audio_feats)   : T, fQ, LB, C
        output shape (audio_feats)  : T, fQ, LB, C

        Audio features act as queries attending to the frame queries; only
        the updated audio features are returned.
        """
        # Not using window-based attention if self.window_size == 0.
        if self.window_size == 0:
            return_shape = frame_query.shape  # T, fQ, LB, C
            frame_query = frame_query.flatten(0, 1)  # TfQ, LB, C
            audio_feats = audio_feats.flatten(0, 1)

            for i in range(self.enc_layers):
                audio_feats = self.enc_av_cross_attn[i](audio_feats, frame_query)
                audio_feats = self.enc_av_ffn[i](audio_feats)

            return audio_feats.view(return_shape)

        # Using window-based attention if self.window_size > 0.
        T, fQ, LB, C = frame_query.shape
        window_mask, shift_window_mask = self._build_window_masks(attn_mask, T, fQ, LB)

        for layer_idx in range(self.enc_layers):
            # NOTE(review): unlike encode_frame_query, shifted windows are used
            # on odd layers in training mode as well — confirm this is intended.
            if layer_idx % 2 == 0:
                frame_query, audio_feats = self._window_av_attn(
                    frame_query, window_mask, layer_idx, audio_feats
                )
            else:
                frame_query, audio_feats = self._shift_window_av_attn(
                    frame_query, shift_window_mask, layer_idx, audio_feats
                )
        return audio_feats

    def _window_av_attn(self, frame_query, attn_mask, layer_idx, audio_feats):
        """Cross-attend audio to frame queries independently inside each window.

        Returns ``(frame_query, audio_feats)``; ``frame_query`` is returned
        as given because this step only reads it.
        """
        T, fQ, LB, C = frame_query.shape

        memory = self._partition_windows(frame_query)
        audio_tokens = self._partition_windows(audio_feats)

        audio_tokens = self.enc_av_cross_attn[layer_idx](
            audio_tokens, memory, memory_key_padding_mask=attn_mask
        )
        audio_tokens = self.enc_av_ffn[layer_idx](audio_tokens)

        audio_feats = self._merge_windows(audio_tokens, T, fQ, LB, C)
        return frame_query, audio_feats

    def _shift_window_av_attn(self, frame_query, attn_mask, layer_idx, audio_feats):
        """Cross-attend audio to frame queries on windows shifted by half a window.

        Returns ``(frame_query, audio_feats)``; ``frame_query`` is returned
        as given because this step only reads it.
        """
        T, fQ, LB, C = frame_query.shape
        half_W = int(ceil(self.window_size / 2))

        # Both streams are rolled identically so they stay frame-aligned.
        memory = self._partition_windows(torch.roll(frame_query, half_W, 0))
        audio_tokens = self._partition_windows(torch.roll(audio_feats, half_W, 0))

        audio_tokens = self.enc_av_cross_attn[layer_idx](
            audio_tokens, memory, memory_mask=attn_mask
        )
        audio_tokens = self.enc_av_ffn[layer_idx](audio_tokens)

        audio_feats = self._merge_windows(audio_tokens, T, fQ, LB, C)
        # Undo the roll so frames return to their original positions.
        audio_feats = torch.roll(audio_feats, -half_W, 0)
        return frame_query, audio_feats