|
import math |
|
|
|
import torch |
|
from torch import nn |
|
|
|
|
|
class AttentionScore(nn.Module): |
|
r""" |
|
    A helper class for attention operations.
    There are no learnable parameters in this module.
    It computes the masked alignment scores and returns only
    the attention scores (logits), without applying a softmax.
|
|
|
The default operation is |
|
|
|
.. math:: |
|
\pmb{u} = \mathrm{Attention}(q,\pmb{k}, \mathrm{mask}) |
|
|
|
where for each key :math:`k_j`, we have |
|
|
|
    .. math::
        u_j =
        \begin{cases}
        \frac{q^Tk_j}{\sqrt{\smash{d_q}}} & \text{ if } j \notin \mathrm{mask}\\
        -\infty & \text{ otherwise. }
        \end{cases}
|
|
|
    If ``use_tanh`` is ``True``, the logits :math:`u_j` are clipped before masking:
|
|
|
    .. math::
        u_j =
        \begin{cases}
        C\tanh\left(\frac{q^Tk_j}{\sqrt{\smash{d_q}}}\right) & \text{ if } j \notin \mathrm{mask}\\
        -\infty & \text{ otherwise. }
        \end{cases}
|
|
|
Args: |
|
        use_tanh: if ``True``, apply tanh clipping to the logits
        C: the clipping range :math:`[-C, C]`
|
Inputs: query, keys, mask |
|
* **query** : [..., 1, h_dim] |
|
* **keys**: [..., graph_size, h_dim] |
|
        * **mask**: [..., graph_size], expandable to the shape of the logits. ``logits[...,j]==-inf`` if ``mask[...,j]==True``.
|
Outputs: logits |
|
* **logits**: [..., 1, graph_size] The attention score for each key. |
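
    Example (a minimal usage sketch; the tensor shapes are chosen for illustration)::

        >>> score = AttentionScore()
        >>> q = torch.rand(2, 1, 128)                      # [batch, 1, h_dim]
        >>> k = torch.rand(2, 5, 128)                      # [batch, graph_size, h_dim]
        >>> mask = torch.zeros(2, 1, 5, dtype=torch.bool)  # True marks keys to exclude
        >>> mask[..., -1] = True                           # e.g. mask out the last key
        >>> logits = score(q, k, mask)                     # [batch, 1, graph_size]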
|
""" |
|
|
|
def __init__(self, use_tanh=False, C=10): |
|
super(AttentionScore, self).__init__() |
|
self.use_tanh = use_tanh |
|
self.C = C |
|
|
|
def forward(self, query, key, mask=torch.zeros([], dtype=torch.bool)): |
|
        # scaled dot product: [..., 1, h_dim] x [..., h_dim, graph_size] -> [..., 1, graph_size]
        u = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(query.size(-1))
|
if self.use_tanh: |
|
logits = torch.tanh(u) * self.C |
|
else: |
|
logits = u |
|
|
|
        # masked entries get -inf so that a subsequent softmax assigns them zero weight
        logits[mask.expand_as(logits)] = float("-inf")
|
return logits |
|
|
|
|
|
class MultiHeadAttention(nn.Module): |
|
r""" |
|
Compute the multi-head attention. |
|
|
|
.. math:: |
|
q^\prime = \mathrm{MultiHeadAttention}(q,\pmb{k},\pmb{v},\mathrm{mask}) |
|
|
|
The following is computed: |
|
|
|
    .. math::
        \begin{aligned}
        \pmb{a}^{(j)} &= \mathrm{Softmax}(\mathrm{AttentionScore}(q^{(j)},\pmb{k}^{(j)}, \mathrm{mask}))\\
        h^{(j)} &= \sum\nolimits_i \pmb{a}^{(j)}_i\pmb{v}^{(j)}_i \\
        q^\prime &= W^O \left[h^{(1)},...,h^{(J)}\right]
        \end{aligned}
|
|
|
Args: |
|
embedding_dim: dimension of the query, keys, values |
|
        n_heads: number of attention heads
|
Inputs: query, keys, value, mask |
|
* **query** : [batch, n_querys, embedding_dim] |
|
* **keys**: [batch, n_keys, embedding_dim] |
|
* **value**: [batch, n_keys, embedding_dim] |
|
* **mask**: [batch, 1, n_keys] ``logits[batch,j]==-inf`` if ``mask[batch, 0, j]==True`` |
|
    Outputs: out
        * **out**: [batch, n_querys, embedding_dim] The output of the multi-head attention
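
    Example (a minimal usage sketch; the tensor shapes are chosen for illustration)::

        >>> mha = MultiHeadAttention(embedding_dim=128, n_heads=8)
        >>> q = torch.rand(2, 3, 128)                      # [batch, n_querys, embedding_dim]
        >>> k = torch.rand(2, 5, 128)                      # [batch, n_keys, embedding_dim]
        >>> v = torch.rand(2, 5, 128)                      # [batch, n_keys, embedding_dim]
        >>> mask = torch.zeros(2, 1, 5, dtype=torch.bool)  # True marks keys to exclude
        >>> out = mha(q, k, v, mask)                       # [batch, n_querys, embedding_dim]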
|
""" |
|
|
|
def __init__(self, embedding_dim, n_heads=8): |
|
super(MultiHeadAttention, self).__init__() |
|
self.n_heads = n_heads |
|
self.attentionScore = AttentionScore() |
|
self.project_out = nn.Linear(embedding_dim, embedding_dim, bias=False) |
|
|
|
def forward(self, query, key, value, mask): |
|
        # split into heads: [batch, n, embedding_dim] -> [n_heads, batch, n, head_dim]
        query_heads = self._make_heads(query)
|
key_heads = self._make_heads(key) |
|
value_heads = self._make_heads(value) |
|
|
|
|
|
        # masked alignment scores per head: [n_heads, batch, n_querys, n_keys]
        compatibility = self.attentionScore(query_heads, key_heads, mask)
|
|
|
|
|
        # attention-weighted sum of the values: [n_heads, batch, n_querys, head_dim]
        out_heads = torch.matmul(torch.softmax(compatibility, dim=-1), value_heads)
|
|
|
|
|
        # concatenate the heads and apply the output projection
        out = self.project_out(self._unmake_heads(out_heads))
|
return out |
|
|
|
def _make_heads(self, v): |
|
batch_size, nkeys, h_dim = v.shape |
|
|
|
        # [batch, n, h_dim] -> [batch, n, n_heads, head_dim] -> [n_heads, batch, n, head_dim]
        out = v.reshape(batch_size, nkeys, self.n_heads, h_dim // self.n_heads).movedim(-2, 0)
|
return out |
|
|
|
def _unmake_heads(self, v): |
|
|
|
        # [n_heads, batch, n, head_dim] -> [batch, n, n_heads, head_dim] -> [batch, n, h_dim]
        out = v.movedim(0, -2).flatten(-2)
|
return out |
|
|
|
|
|
class MultiHeadAttentionProj(nn.Module): |
|
r""" |
|
Compute the multi-head attention with projection. |
|
    Unlike :class:`.MultiHeadAttention`, which expects precomputed query, keys, and values,
    this module first projects its inputs to query, keys, and values with learned linear layers.
|
|
|
.. math:: |
|
q^\prime = \mathrm{MultiHeadAttentionProj}(q_0,\pmb{h},\mathrm{mask}) |
|
|
|
The following is computed: |
|
|
|
    .. math::
        \begin{aligned}
        q, \pmb{k}, \pmb{v} &= W^Q q_0, W^K\pmb{h}, W^V\pmb{h}\\
        \pmb{a}^{(j)} &= \mathrm{Softmax}(\mathrm{AttentionScore}(q^{(j)},\pmb{k}^{(j)}, \mathrm{mask}))\\
        h^{(j)} &= \sum\nolimits_i \pmb{a}^{(j)}_i\pmb{v}^{(j)}_i \\
        q^\prime &= W^O \left[h^{(1)},...,h^{(J)}\right]
        \end{aligned}
|
|
|
    If :math:`\pmb{h}` is not given, this module computes the self-attention of :math:`q_0`.
|
|
|
.. warning:: |
|
        The in-projections of the query, key, and value differ slightly
        (on the order of ``1e-6``) from the original implementation.
        This is a matter of numerical accuracy: the two implementations
        multiply the matrices in a different order, so different internal
        PyTorch routines are called and the results are not bit-identical.
        See the PyTorch notes on
        `numerical accuracy <https://pytorch.org/docs/stable/notes/numerical_accuracy.html>`_ for details.
|
|
|
Args: |
|
embedding_dim: dimension of the query, keys, values |
|
        n_heads: number of attention heads
|
Inputs: q, h, mask |
|
* **q** : [batch, n_querys, embedding_dim] |
|
* **h**: [batch, n_keys, embedding_dim] |
|
        * **mask**: [batch, 1, n_keys] ``logits[batch,j]==-inf`` if ``mask[batch, 0, j]==True``
|
Outputs: out |
|
* **out**: [batch, n_querys, embedding_dim] The output of the multi-head attention |
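
    Example (a minimal usage sketch; the tensor shapes are chosen for illustration)::

        >>> mha = MultiHeadAttentionProj(embedding_dim=128, n_heads=8)
        >>> h = torch.rand(2, 5, 128)                      # [batch, n_keys, embedding_dim]
        >>> mask = torch.zeros(2, 1, 5, dtype=torch.bool)  # True marks keys to exclude
        >>> out = mha(h, mask=mask)                        # self-attention of h
        >>> q0 = torch.rand(2, 3, 128)                     # [batch, n_querys, embedding_dim]
        >>> out = mha(q0, h, mask)                         # [batch, n_querys, embedding_dim]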
|
|
|
|
|
""" |
|
|
|
def __init__(self, embedding_dim, n_heads=8): |
|
super(MultiHeadAttentionProj, self).__init__() |
|
|
|
self.queryEncoder = nn.Linear(embedding_dim, embedding_dim, bias=False) |
|
self.keyEncoder = nn.Linear(embedding_dim, embedding_dim, bias=False) |
|
self.valueEncoder = nn.Linear(embedding_dim, embedding_dim, bias=False) |
|
|
|
self.MHA = MultiHeadAttention(embedding_dim, n_heads) |
|
|
|
def forward(self, q, h=None, mask=torch.zeros([], dtype=torch.bool)): |
|
|
|
        # default to self-attention when no separate key/value input is given
        if h is None:
|
h = q |
|
|
|
        # project the inputs to query, keys, and values
        query = self.queryEncoder(q)
|
key = self.keyEncoder(h) |
|
value = self.valueEncoder(h) |
|
|
|
out = self.MHA(query, key, value, mask) |
|
|
|
return out |
|
|