# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# MIT_LICENSE file in the root directory of this source tree.
from typing import Optional, final
import torch
from fairseq2.nn.projection import Linear
from fairseq2.typing import DataType, Device, finaloverride
from torch import Tensor
from torch.nn import AvgPool1d, Module, ModuleList, ReLU
from torch.nn.parameter import Parameter


class EnergyProjection(Module):
    """Projects sequences for monotonic energy computation by applying
    ``num_layers`` pairs of Linear and ReLU layers."""

    def __init__(
        self,
        model_dim: int,
        num_layers: int,
        bias: bool = True,
        device: Optional[Device] = None,
        dtype: Optional[DataType] = None,
    ) -> None:
        super().__init__()

        if num_layers < 1:
            raise ValueError(
                f"Invalid `num_layers`: {num_layers} for EnergyProjection."
            )

        self.layers = ModuleList()

        for _ in range(num_layers):
            self.layers.append(
                Linear(model_dim, model_dim, bias, device=device, dtype=dtype)
            )
            self.layers.append(ReLU())

    def forward(self, seqs: Tensor) -> Tensor:
        for layer in self.layers:
            seqs = layer(seqs)

        return seqs
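

# A minimal usage sketch for EnergyProjection (the sizes below are illustrative
# assumptions, not values from any particular configuration). With
# ``num_layers=2`` the stack applies Linear -> ReLU -> Linear -> ReLU and
# preserves the feature dimension:
#
#   proj = EnergyProjection(model_dim=1024, num_layers=2)
#   out = proj(torch.rand(2, 10, 1024))  # (N, S, M) -> (N, S, M)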


@final
class PChooseLayer(Module):
    """Represents a PChoose layer."""

    model_dim: int
    num_heads: int
    energy_bias: Parameter
    monotonic_temperature: float
    q_energy_proj: EnergyProjection
    k_energy_proj: EnergyProjection
    keys_pooling: AvgPool1d

    def __init__(
        self,
        model_dim: int,
        num_heads: int,
        energy_bias_value: float,
        monotonic_temperature: float,
        num_monotonic_energy_layers: int,
        pre_decision_ratio: int,
        *,
        bias: bool = True,
        device: Optional[Device] = None,
        dtype: Optional[DataType] = None,
    ) -> None:
        """
        :param model_dim:
            The dimensionality of the model.
        :param num_heads:
            The number of attention heads.
        :param energy_bias_value:
            The initial value of the learnable energy bias; ``0.0`` disables it.
        :param monotonic_temperature:
            The temperature applied to the monotonic energy before the sigmoid.
        :param num_monotonic_energy_layers:
            The number of layers in the query and key energy projections.
        :param pre_decision_ratio:
            The kernel size and stride of the average pooling over the keys.
        :param bias:
            If ``True``, query and key energy projection layers learn an
            additive bias.
        """
        super().__init__()

        self.model_dim = model_dim
        self.num_heads = num_heads

        if energy_bias_value != 0.0:
            self.energy_bias = Parameter(
                torch.full([1], energy_bias_value, device=device, dtype=dtype)
            )
        else:
            self.register_module("energy_bias", None)

        self.monotonic_temperature = monotonic_temperature

        if num_monotonic_energy_layers <= 0:
            raise ValueError("Number of monotonic energy layers must be > 0.")

        self.q_energy_proj = EnergyProjection(
            self.model_dim,
            num_monotonic_energy_layers,
            bias,
            device=device,
            dtype=dtype,
        )
        self.k_energy_proj = EnergyProjection(
            self.model_dim,
            num_monotonic_energy_layers,
            bias,
            device=device,
            dtype=dtype,
        )

        self.keys_pooling = AvgPool1d(
            kernel_size=pre_decision_ratio,
            stride=pre_decision_ratio,
            ceil_mode=True,
        )
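        # With ``ceil_mode=True`` the pooled key length is
        # S_p = ceil(S_kv / pre_decision_ratio); e.g. an assumed ratio of 2
        # maps S_kv = 7 key positions to S_p = 4.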

    @finaloverride
    def forward(self, seqs: Tensor, keys: Tensor) -> Tensor:
        q = self.q_energy_proj(seqs)

        # (N, S, M) -> (N, H, S, K)
        q = q.unflatten(-1, (self.num_heads, -1)).transpose(1, 2)

        # (N, S_kv, M) -> (N, M, S_kv) -> (N, M, S_p)
        pooled_keys = self.keys_pooling(keys.transpose(1, 2))

        # (N, M, S_p) -> (N, S_p, M)
        pooled_keys = pooled_keys.transpose(1, 2)

        k = self.k_energy_proj(pooled_keys)

        # (N, S_p, M) -> (N, H, S_p, K)
        k = k.unflatten(-1, (self.num_heads, -1)).transpose(1, 2)

        # (N, H, S, K) @ (N, H, K, S_p) = (N, H, S, S_p)
        monotonic_energy = torch.matmul(q, k.transpose(-1, -2))

        monotonic_energy = monotonic_energy * (q.size(-1) ** -0.5)

        if self.energy_bias is not None:
            monotonic_energy += self.energy_bias

        # p_choose: (N, H, S, S_p)
        p_choose = torch.sigmoid(monotonic_energy / self.monotonic_temperature)

        return p_choose
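

# A minimal usage sketch for PChooseLayer (the hyperparameter values are
# illustrative assumptions, not settings taken from a specific configuration):
#
#   pchoose = PChooseLayer(
#       model_dim=1024,
#       num_heads=16,
#       energy_bias_value=-0.5,
#       monotonic_temperature=0.2,
#       num_monotonic_energy_layers=4,
#       pre_decision_ratio=2,
#   )
#
#   seqs = torch.rand(2, 10, 1024)  # (N, S, M) decoder-side queries
#   keys = torch.rand(2, 30, 1024)  # (N, S_kv, M) encoder-side keys
#
#   p_choose = pchoose(seqs, keys)  # (N, H, S, S_p) = (2, 16, 10, 15)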