from typing import Optional, final

import torch
from fairseq2.nn.projection import Linear
from fairseq2.typing import DataType, Device, finaloverride
from torch import Tensor
from torch.nn import AvgPool1d, Module, ModuleList, ReLU
from torch.nn.parameter import Parameter


class EnergyProjection(Module):
    """Computes energies by applying a stack of Linear projections, each
    followed by a ReLU activation."""
|
    def __init__(
        self,
        model_dim: int,
        num_layers: int,
        bias: bool = True,
        device: Optional[Device] = None,
        dtype: Optional[DataType] = None,
    ) -> None:
        super().__init__()

        if num_layers < 1:
            raise ValueError(
                f"`num_layers` must be greater than or equal to 1, but is {num_layers} instead."
            )

        self.layers = ModuleList()

        for _ in range(num_layers):
            self.layers.append(
                Linear(model_dim, model_dim, bias, device=device, dtype=dtype)
            )
            self.layers.append(ReLU())

    def forward(self, seqs: Tensor) -> Tensor:
        for layer in self.layers:
            seqs = layer(seqs)

        return seqs


@final
class PChooseLayer(Module):
"""Represents a PChoose layer.""" |
|
|
|
model_dim: int |
|
num_heads: int |
|
energy_bias: Parameter |
|
monotonic_temperature: float |
|
q_energy_proj: EnergyProjection |
|
k_energy_proj: EnergyProjection |
|
keys_pooling: AvgPool1d |
|
|
|
    def __init__(
        self,
        model_dim: int,
        num_heads: int,
        energy_bias_value: float,
        monotonic_temperature: float,
        num_monotonic_energy_layers: int,
        pre_decision_ratio: int,
        *,
        bias: bool = True,
        device: Optional[Device] = None,
        dtype: Optional[DataType] = None,
    ) -> None:
        """
        :param model_dim:
            The dimensionality of the model.
        :param num_heads:
            The number of attention heads.
        :param energy_bias_value:
            The initial value of the learnable energy bias; if 0.0, no bias is
            used.
        :param monotonic_temperature:
            The temperature by which monotonic energies are divided before the
            sigmoid.
        :param num_monotonic_energy_layers:
            The number of layers in the query and key energy projections.
        :param pre_decision_ratio:
            The kernel size and stride of the average pooling applied to keys.
        :param bias:
            If ``True``, query and key energy projection layers learn an
            additive bias.
        """
        super().__init__()
|
        self.model_dim = model_dim
        self.num_heads = num_heads

        if energy_bias_value != 0.0:
            self.energy_bias = Parameter(
                torch.full([1], energy_bias_value, device=device, dtype=dtype)
            )
        else:
            self.register_module("energy_bias", None)

        self.monotonic_temperature = monotonic_temperature

        if num_monotonic_energy_layers <= 0:
            raise ValueError(
                f"`num_monotonic_energy_layers` must be greater than 0, but is {num_monotonic_energy_layers} instead."
            )

        self.q_energy_proj = EnergyProjection(
            self.model_dim,
            num_monotonic_energy_layers,
            bias,
            device=device,
            dtype=dtype,
        )
        self.k_energy_proj = EnergyProjection(
            self.model_dim,
            num_monotonic_energy_layers,
            bias,
            device=device,
            dtype=dtype,
        )

        self.keys_pooling = AvgPool1d(
            kernel_size=pre_decision_ratio,
            stride=pre_decision_ratio,
            ceil_mode=True,
        )
|
    @finaloverride
    def forward(self, seqs: Tensor, keys: Tensor) -> Tensor:
        # Project queries: (N, S, M) -> (N, S, M)
        q = self.q_energy_proj(seqs)

        # Split into heads: (N, S, M) -> (N, H, S, K) where K = M / H
        q = q.unflatten(-1, (self.num_heads, -1)).transpose(1, 2)

        # Pool keys along the sequence dimension by `pre_decision_ratio`:
        # (N, S_kv, M) -> (N, M, S_kv) -> (N, M, S_p)
        pooled_keys = self.keys_pooling(keys.transpose(1, 2))

        # (N, M, S_p) -> (N, S_p, M)
        pooled_keys = pooled_keys.transpose(1, 2)

        # Project pooled keys: (N, S_p, M) -> (N, S_p, M)
        k = self.k_energy_proj(pooled_keys)

        # Split into heads: (N, S_p, M) -> (N, H, S_p, K)
        k = k.unflatten(-1, (self.num_heads, -1)).transpose(1, 2)

        # Scaled dot-product monotonic energies:
        # (N, H, S, K) @ (N, H, K, S_p) -> (N, H, S, S_p)
        monotonic_energy = torch.matmul(q, k.transpose(-1, -2))

        monotonic_energy = monotonic_energy * (q.size(-1) ** -0.5)

        if self.energy_bias is not None:
            monotonic_energy += self.energy_bias

        # Per-head read/write probabilities over pooled key positions.
        p_choose = torch.sigmoid(monotonic_energy / self.monotonic_temperature)

        return p_choose
|
|
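# Minimal usage sketch (illustrative only, not part of the module): shows how a
# PChooseLayer might be instantiated and applied to dummy query/key sequences.
# The hyperparameter values and tensor shapes below are assumptions chosen for
# the example, not values prescribed by this file.
if __name__ == "__main__":
    layer = PChooseLayer(
        model_dim=8,
        num_heads=2,
        energy_bias_value=-0.5,
        monotonic_temperature=0.2,
        num_monotonic_energy_layers=2,
        pre_decision_ratio=2,
    )

    seqs = torch.randn(3, 5, 8)   # (N, S, model_dim)
    keys = torch.randn(3, 10, 8)  # (N, S_kv, model_dim)

    p_choose = layer(seqs, keys)

    # (N, num_heads, S, ceil(S_kv / pre_decision_ratio))
    print(p_choose.shape)  # torch.Size([3, 2, 5, 5])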