Spaces:
Running
Running
File size: 9,750 Bytes
1999a98 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 |
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
import copy
from typing import Optional
import torch
from torch import Tensor, nn
from .blocks import RoPEAttention
class MemoryAttentionLayer(nn.Module):
    """
    Implements a memory attention layer with self-attention and cross-attention mechanisms for neural networks.

    This class combines self-attention, cross-attention, and feedforward components to process input tensors and
    generate memory-based attention outputs. Each of the three sub-blocks is pre-normalized (its LayerNorm is applied
    to the sub-block input) and its output is added back to the un-normalized input through a residual connection.

    Attributes:
        d_model (int): Dimensionality of the model (feature size of the target tokens).
        dim_feedforward (int): Dimensionality of the feedforward network's hidden layer.
        dropout_value (float): Dropout rate for regularization.
        self_attn (RoPEAttention): Self-attention mechanism using RoPE (Rotary Position Embedding).
        cross_attn_image (RoPEAttention): Cross-attention from the target tokens to the memory tokens.
        linear1 (nn.Linear): First linear layer of the feedforward network.
        linear2 (nn.Linear): Second linear layer of the feedforward network.
        dropout (nn.Dropout): Dropout applied between the two feedforward linear layers.
        norm1 (nn.LayerNorm): Layer normalization applied to the self-attention input.
        norm2 (nn.LayerNorm): Layer normalization applied to the cross-attention input.
        norm3 (nn.LayerNorm): Layer normalization applied to the feedforward network input.
        dropout1 (nn.Dropout): Dropout layer after self-attention.
        dropout2 (nn.Dropout): Dropout layer after cross-attention.
        dropout3 (nn.Dropout): Dropout layer after feedforward network.
        activation (nn.ReLU): Activation function for the feedforward network.
        pos_enc_at_attn (bool): Add `query_pos` to the self-attention queries and keys.
        pos_enc_at_cross_attn_queries (bool): Add `query_pos` to the cross-attention queries.
        pos_enc_at_cross_attn_keys (bool): Add `pos` to the cross-attention keys.

    Methods:
        forward: Performs the full memory attention operation on input tensors.
        _forward_sa: Performs self-attention on input tensor.
        _forward_ca: Performs cross-attention between target and memory tensors.

    Examples:
        >>> layer = MemoryAttentionLayer(d_model=256, dim_feedforward=2048, dropout=0.1)
        >>> tgt = torch.randn(1, 100, 256)
        >>> memory = torch.randn(1, 100, 64)
        >>> pos = torch.randn(1, 100, 64)  # added to memory, so it must match memory's shape
        >>> query_pos = torch.randn(1, 100, 256)
        >>> output = layer(tgt, memory, pos, query_pos)
        >>> print(output.shape)
        torch.Size([1, 100, 256])
    """

    def __init__(
        self,
        d_model: int = 256,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        pos_enc_at_attn: bool = False,
        pos_enc_at_cross_attn_keys: bool = True,
        pos_enc_at_cross_attn_queries: bool = False,
        kv_in_dim: int = 64,
    ):
        """
        Initialize a memory attention layer with self-attention, cross-attention, and feedforward components.

        Args:
            d_model (int): Feature size of the target tokens; also the embedding size of both attention modules.
            dim_feedforward (int): Hidden size of the feedforward network.
            dropout (float): Dropout probability used by every dropout layer.
            pos_enc_at_attn (bool): Add `query_pos` to the self-attention queries and keys.
            pos_enc_at_cross_attn_keys (bool): Add `pos` to the cross-attention keys.
            pos_enc_at_cross_attn_queries (bool): Add `query_pos` to the cross-attention queries.
            kv_in_dim (int): Feature size of the memory tokens given to cross-attention as keys/values.
        """
        super().__init__()
        self.d_model = d_model
        self.dim_feedforward = dim_feedforward
        self.dropout_value = dropout
        # The attention width must equal d_model so the residual additions and LayerNorm(d_model) line up.
        self.self_attn = RoPEAttention(embedding_dim=d_model, num_heads=1, downsample_rate=1)
        self.cross_attn_image = RoPEAttention(
            rope_k_repeat=True,
            embedding_dim=d_model,
            num_heads=1,
            downsample_rate=1,
            kv_in_dim=kv_in_dim,
        )

        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = nn.ReLU()

        # Where to add pos enc
        self.pos_enc_at_attn = pos_enc_at_attn
        self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries
        self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys

    def _forward_sa(self, tgt: Tensor, query_pos: Optional[Tensor]) -> Tensor:
        """
        Apply pre-norm self-attention to `tgt` and add the result back through a residual connection.

        Args:
            tgt (Tensor): Target tokens.
            query_pos (Optional[Tensor]): Positional encoding for `tgt`; ignored when None.

        Returns:
            (Tensor): `tgt` plus the (dropped-out) self-attention output.
        """
        tgt2 = self.norm1(tgt)
        # Queries and keys optionally carry the positional encoding; values never do.
        q = k = tgt2 + query_pos if self.pos_enc_at_attn and query_pos is not None else tgt2
        tgt2 = self.self_attn(q, k, v=tgt2)
        return tgt + self.dropout1(tgt2)

    def _forward_ca(
        self,
        tgt: Tensor,
        memory: Tensor,
        query_pos: Optional[Tensor],
        pos: Optional[Tensor],
        num_k_exclude_rope: int = 0,
    ) -> Tensor:
        """
        Apply pre-norm cross-attention from `tgt` to `memory` and add the result back through a residual connection.

        Args:
            tgt (Tensor): Target tokens (queries).
            memory (Tensor): Memory tokens (keys/values).
            query_pos (Optional[Tensor]): Positional encoding for `tgt`; ignored when None.
            pos (Optional[Tensor]): Positional encoding for `memory`; ignored when None.
            num_k_exclude_rope (int): Number of key tokens the RoPE cross-attention should exclude from rotary
                encoding; only forwarded when > 0.

        Returns:
            (Tensor): `tgt` plus the (dropped-out) cross-attention output.
        """
        kwds = {}
        if num_k_exclude_rope > 0:
            # Only RoPEAttention understands this keyword.
            assert isinstance(self.cross_attn_image, RoPEAttention)
            kwds = {"num_k_exclude_rope": num_k_exclude_rope}

        # Cross-Attention
        tgt2 = self.norm2(tgt)
        q = tgt2 + query_pos if self.pos_enc_at_cross_attn_queries and query_pos is not None else tgt2
        k = memory + pos if self.pos_enc_at_cross_attn_keys and pos is not None else memory
        tgt2 = self.cross_attn_image(q=q, k=k, v=memory, **kwds)
        return tgt + self.dropout2(tgt2)

    def forward(
        self,
        tgt: Tensor,
        memory: Tensor,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
        num_k_exclude_rope: int = 0,
    ) -> torch.Tensor:
        """
        Process input tensors using self-attention, cross-attention, and an MLP for memory-based attention.

        Args:
            tgt (Tensor): Target tokens with `d_model` features.
            memory (Tensor): Memory tokens attended to by cross-attention.
            pos (Optional[Tensor]): Positional encoding added to the cross-attention keys when enabled.
            query_pos (Optional[Tensor]): Positional encoding added to queries (and self-attention keys) when enabled.
            num_k_exclude_rope (int): Passed to cross-attention; see `_forward_ca`.

        Returns:
            (torch.Tensor): Updated target tokens, same shape as `tgt`.
        """
        tgt = self._forward_sa(tgt, query_pos)
        tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope)
        # Pre-norm MLP with residual connection
        tgt2 = self.norm3(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        return tgt + self.dropout3(tgt2)
class MemoryAttention(nn.Module):
    """
    Memory attention module for processing sequential data with self and cross-attention mechanisms.

    This class implements a multi-layer attention mechanism that combines self-attention and cross-attention
    for processing sequential data, particularly useful in transformer-like architectures. `num_layers` deep copies
    of the supplied layer are applied in sequence, and a final LayerNorm is applied to the result.

    Attributes:
        d_model (int): The dimension of the model's hidden state.
        layers (nn.ModuleList): A list of MemoryAttentionLayer modules.
        num_layers (int): The number of attention layers.
        norm (nn.LayerNorm): Layer normalization applied to the output.
        pos_enc_at_input (bool): Whether to add 0.1 * `curr_pos` to the input before the first layer.
        batch_first (bool): Whether the layers expect batch-first tensors (inputs are transposed accordingly).

    Methods:
        forward: Processes input tensors through the attention layers.

    Examples:
        >>> d_model = 256
        >>> layer = MemoryAttentionLayer(d_model)
        >>> attention = MemoryAttention(d_model, pos_enc_at_input=True, layer=layer, num_layers=3)
        >>> curr = torch.randn(100, 2, d_model)  # (seq_len, batch_size, d_model)
        >>> memory = torch.randn(100, 2, 64)  # (mem_len, batch_size, 64): the layer's cross-attn expects 64-dim memory
        >>> curr_pos = torch.randn(100, 2, d_model)
        >>> memory_pos = torch.randn(100, 2, 64)
        >>> output = attention(curr, memory, curr_pos, memory_pos)
        >>> print(output.shape)
        torch.Size([100, 2, 256])
    """

    def __init__(
        self,
        d_model: int,
        pos_enc_at_input: bool,
        layer: nn.Module,
        num_layers: int,
        batch_first: bool = True,  # Do layers expect batch first input?
    ):
        """
        Initialize MemoryAttention with copies of `layer` and a final normalization.

        Args:
            d_model (int): Feature size of the self-attention stream; sets the output LayerNorm size.
            pos_enc_at_input (bool): Whether to add a scaled `curr_pos` to the input before the first layer.
            layer (nn.Module): Template layer; it is deep-copied `num_layers` times so copies share no parameters.
            num_layers (int): Number of stacked layers.
            batch_first (bool): Whether the layers expect batch-first input.
        """
        super().__init__()
        self.d_model = d_model
        self.layers = nn.ModuleList([copy.deepcopy(layer) for _ in range(num_layers)])
        self.num_layers = num_layers
        self.norm = nn.LayerNorm(d_model)
        self.pos_enc_at_input = pos_enc_at_input
        self.batch_first = batch_first

    def forward(
        self,
        curr: torch.Tensor,  # self-attention inputs
        memory: torch.Tensor,  # cross-attention inputs
        curr_pos: Optional[Tensor] = None,  # pos_enc for self-attention inputs
        memory_pos: Optional[Tensor] = None,  # pos_enc for cross-attention inputs
        num_obj_ptr_tokens: int = 0,  # number of object pointer *tokens*
    ) -> torch.Tensor:
        """
        Process input tensors through multiple attention layers, applying self and cross-attention mechanisms.

        Args:
            curr (torch.Tensor): Self-attention inputs, or a single-element list holding them.
            memory (torch.Tensor): Cross-attention inputs.
            curr_pos (Optional[Tensor]): Positional encoding for `curr` (a single-element list when `curr` is a list).
            memory_pos (Optional[Tensor]): Positional encoding for `memory`.
            num_obj_ptr_tokens (int): Forwarded to each layer as `num_k_exclude_rope` when the layer's
                cross-attention is a RoPEAttention.

        Returns:
            (torch.Tensor): Normalized output, transposed back to the input layout when `batch_first` is set.
        """
        if isinstance(curr, list):
            assert isinstance(curr_pos, list)
            assert len(curr) == len(curr_pos) == 1
            curr, curr_pos = curr[0], curr_pos[0]

        # Dim 1 is treated as the batch dimension (inputs are expected sequence-first).
        assert curr.shape[1] == memory.shape[1], "Batch size must be the same for curr and memory"

        output = curr
        if self.pos_enc_at_input and curr_pos is not None:
            output = output + 0.1 * curr_pos

        if self.batch_first:
            # Convert to batch first; positional encodings are optional and may be None.
            output = output.transpose(0, 1)
            memory = memory.transpose(0, 1)
            if curr_pos is not None:
                curr_pos = curr_pos.transpose(0, 1)
            if memory_pos is not None:
                memory_pos = memory_pos.transpose(0, 1)

        for layer in self.layers:
            kwds = {}
            # Only RoPE cross-attention accepts the excluded-token count.
            if isinstance(layer.cross_attn_image, RoPEAttention):
                kwds = {"num_k_exclude_rope": num_obj_ptr_tokens}
            output = layer(
                tgt=output,
                memory=memory,
                pos=memory_pos,
                query_pos=curr_pos,
                **kwds,
            )

        normed_output = self.norm(output)
        if self.batch_first:
            # Convert back to seq first
            normed_output = normed_output.transpose(0, 1)
        return normed_output
|