|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import copy |
|
import math |
|
|
|
import numpy as np |
|
import torch |
|
from configuration_bertabs import BertAbsConfig |
|
from torch import nn |
|
from torch.nn.init import xavier_uniform_ |
|
|
|
from transformers import BertConfig, BertModel, PreTrainedModel |
|
|
|
|
|
MAX_SIZE = 5000 |
|
|
|
BERTABS_FINETUNED_MODEL_ARCHIVE_LIST = [ |
|
"remi/bertabs-finetuned-cnndm-extractive-abstractive-summarization", |
|
] |
|
|
|
|
|
class BertAbsPreTrainedModel(PreTrainedModel): |
|
config_class = BertAbsConfig |
|
load_tf_weights = False |
|
base_model_prefix = "bert" |
|
|
|
|
|
class BertAbs(BertAbsPreTrainedModel): |
|
def __init__(self, args, checkpoint=None, bert_extractive_checkpoint=None): |
|
super().__init__(args) |
|
self.args = args |
|
self.bert = Bert() |
|
|
|
|
|
        load_bert_pretrained_extractive = bool(bert_extractive_checkpoint)
|
if load_bert_pretrained_extractive: |
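            # keys are stored as "bert.model.<param>" in the extractive checkpoint;
            # drop the 11-character "bert.model." prefix so they match BertModel's state dict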
|
self.bert.model.load_state_dict( |
|
{n[11:]: p for n, p in bert_extractive_checkpoint.items() if n.startswith("bert.model")}, |
|
strict=True, |
|
) |
|
|
|
self.vocab_size = self.bert.model.config.vocab_size |
|
|
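        # BERT only has 512 learned position embeddings; for longer inputs the table is
        # extended by repeating the last learned position vector for the extra positions.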
|
if args.max_pos > 512: |
|
my_pos_embeddings = nn.Embedding(args.max_pos, self.bert.model.config.hidden_size) |
|
my_pos_embeddings.weight.data[:512] = self.bert.model.embeddings.position_embeddings.weight.data |
|
my_pos_embeddings.weight.data[512:] = self.bert.model.embeddings.position_embeddings.weight.data[-1][ |
|
None, : |
|
].repeat(args.max_pos - 512, 1) |
|
self.bert.model.embeddings.position_embeddings = my_pos_embeddings |
|
tgt_embeddings = nn.Embedding(self.vocab_size, self.bert.model.config.hidden_size, padding_idx=0) |
|
|
|
tgt_embeddings.weight = copy.deepcopy(self.bert.model.embeddings.word_embeddings.weight) |
|
|
|
self.decoder = TransformerDecoder( |
|
self.args.dec_layers, |
|
self.args.dec_hidden_size, |
|
heads=self.args.dec_heads, |
|
d_ff=self.args.dec_ff_size, |
|
dropout=self.args.dec_dropout, |
|
embeddings=tgt_embeddings, |
|
vocab_size=self.vocab_size, |
|
) |
|
|
|
gen_func = nn.LogSoftmax(dim=-1) |
|
self.generator = nn.Sequential(nn.Linear(args.dec_hidden_size, args.vocab_size), gen_func) |
|
self.generator[0].weight = self.decoder.embeddings.weight |
|
|
|
        load_from_checkpoints = checkpoint is not None
|
if load_from_checkpoints: |
|
self.load_state_dict(checkpoint) |
|
|
|
def init_weights(self): |
|
for module in self.decoder.modules(): |
|
if isinstance(module, (nn.Linear, nn.Embedding)): |
|
module.weight.data.normal_(mean=0.0, std=0.02) |
|
elif isinstance(module, nn.LayerNorm): |
|
module.bias.data.zero_() |
|
module.weight.data.fill_(1.0) |
|
if isinstance(module, nn.Linear) and module.bias is not None: |
|
module.bias.data.zero_() |
|
for p in self.generator.parameters(): |
|
if p.dim() > 1: |
|
xavier_uniform_(p) |
|
else: |
|
p.data.zero_() |
|
|
|
def forward( |
|
self, |
|
encoder_input_ids, |
|
decoder_input_ids, |
|
token_type_ids, |
|
encoder_attention_mask, |
|
decoder_attention_mask, |
|
): |
|
encoder_output = self.bert( |
|
input_ids=encoder_input_ids, |
|
token_type_ids=token_type_ids, |
|
attention_mask=encoder_attention_mask, |
|
) |
|
        # self.bert already returns the final hidden states tensor of shape (batch, src_len, hidden)
        encoder_hidden_states = encoder_output
|
dec_state = self.decoder.init_decoder_state(encoder_input_ids, encoder_hidden_states) |
|
decoder_outputs, _ = self.decoder(decoder_input_ids[:, :-1], encoder_hidden_states, dec_state) |
|
return decoder_outputs |
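
# Illustrative sketch of running BertAbs end to end, assuming `args` is a
# BertAbsConfig-like namespace and the ids/masks are LongTensors of shape (batch, seq_len):
#
#     model = BertAbs(args)
#     decoder_hidden = model(
#         encoder_input_ids=src_ids,
#         decoder_input_ids=tgt_ids,
#         token_type_ids=segs,
#         encoder_attention_mask=src_mask,
#         decoder_attention_mask=tgt_mask,
#     )
#     log_probs = model.generator(decoder_hidden)  # (batch, tgt_len - 1, vocab_size)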
|
|
|
|
|
class Bert(nn.Module): |
|
"""This class is not really necessary and should probably disappear.""" |
|
|
|
def __init__(self): |
|
super().__init__() |
|
config = BertConfig.from_pretrained("bert-base-uncased") |
|
self.model = BertModel(config) |
|
|
|
def forward(self, input_ids, attention_mask=None, token_type_ids=None, **kwargs): |
|
self.eval() |
|
with torch.no_grad(): |
|
encoder_outputs, _ = self.model( |
|
input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, **kwargs |
|
) |
|
return encoder_outputs |
|
|
|
|
|
class TransformerDecoder(nn.Module): |
|
""" |
|
The Transformer decoder from "Attention is All You Need". |
|
|
|
Args: |
|
        num_layers (int): number of decoder layers.
        d_model (int): size of the model
        heads (int): number of heads
        d_ff (int): size of the inner FF layer
        dropout (float): dropout probability
        embeddings (:obj:`onmt.modules.Embeddings`):
            embeddings to use, should have positional encodings
        vocab_size (int): size of the output vocabulary
|
""" |
|
|
|
def __init__(self, num_layers, d_model, heads, d_ff, dropout, embeddings, vocab_size): |
|
super().__init__() |
|
|
|
|
|
self.decoder_type = "transformer" |
|
self.num_layers = num_layers |
|
self.embeddings = embeddings |
|
self.pos_emb = PositionalEncoding(dropout, self.embeddings.embedding_dim) |
|
|
|
|
|
self.transformer_layers = nn.ModuleList( |
|
[TransformerDecoderLayer(d_model, heads, d_ff, dropout) for _ in range(num_layers)] |
|
) |
|
|
|
self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) |
|
|
|
|
|
|
|
|
|
def forward( |
|
self, |
|
input_ids, |
|
encoder_hidden_states=None, |
|
state=None, |
|
attention_mask=None, |
|
memory_lengths=None, |
|
step=None, |
|
cache=None, |
|
encoder_attention_mask=None, |
|
): |
|
""" |
|
See :obj:`onmt.modules.RNNDecoderBase.forward()` |
|
memory_bank = encoder_hidden_states |
|
""" |
|
|
|
tgt = input_ids |
|
memory_bank = encoder_hidden_states |
|
memory_mask = encoder_attention_mask |
|
|
|
|
|
src_words = state.src |
|
src_batch, src_len = src_words.size() |
|
|
|
padding_idx = self.embeddings.padding_idx |
|
|
|
|
|
tgt_words = tgt |
|
tgt_batch, tgt_len = tgt_words.size() |
|
tgt_pad_mask = tgt_words.data.eq(padding_idx).unsqueeze(1).expand(tgt_batch, tgt_len, tgt_len) |
|
|
|
|
|
if memory_mask is not None: |
|
src_len = memory_mask.size(-1) |
|
src_pad_mask = memory_mask.expand(src_batch, tgt_len, src_len) |
|
else: |
|
src_pad_mask = src_words.data.eq(padding_idx).unsqueeze(1).expand(src_batch, tgt_len, src_len) |
|
|
|
|
|
emb = self.embeddings(input_ids) |
|
output = self.pos_emb(emb, step) |
|
assert emb.dim() == 3 |
|
|
|
if state.cache is None: |
|
saved_inputs = [] |
|
|
|
for i in range(self.num_layers): |
|
prev_layer_input = None |
|
if state.cache is None: |
|
if state.previous_input is not None: |
|
prev_layer_input = state.previous_layer_inputs[i] |
|
|
|
output, all_input = self.transformer_layers[i]( |
|
output, |
|
memory_bank, |
|
src_pad_mask, |
|
tgt_pad_mask, |
|
previous_input=prev_layer_input, |
|
layer_cache=state.cache["layer_{}".format(i)] if state.cache is not None else None, |
|
step=step, |
|
) |
|
if state.cache is None: |
|
saved_inputs.append(all_input) |
|
|
|
if state.cache is None: |
|
saved_inputs = torch.stack(saved_inputs) |
|
|
|
output = self.layer_norm(output) |
|
|
|
if state.cache is None: |
|
state = state.update_state(tgt, saved_inputs) |
|
|
|
|
|
|
|
return output, state |
|
|
|
def init_decoder_state(self, src, memory_bank, with_cache=False): |
|
"""Init decoder state""" |
|
state = TransformerDecoderState(src) |
|
if with_cache: |
|
state._init_cache(memory_bank, self.num_layers) |
|
return state |
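
# Minimal decoding sketch for the decoder above, assuming `decoder` is the
# TransformerDecoder built in BertAbs.__init__, `encoder_hidden_states` has shape
# (batch, src_len, d_model), and `tgt_ids` are target token ids padded with 0:
#
#     state = decoder.init_decoder_state(src_ids, encoder_hidden_states)
#     output, state = decoder(tgt_ids, encoder_hidden_states, state)
#     # output: (batch, tgt_len, d_model), fed to the generator for token log-probabilities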
|
|
|
|
|
class PositionalEncoding(nn.Module): |
|
def __init__(self, dropout, dim, max_len=5000): |
|
        super().__init__()
        pe = torch.zeros(max_len, dim)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, dim, 2, dtype=torch.float) * -(math.log(10000.0) / dim))
        pe[:, 0::2] = torch.sin(position.float() * div_term)
        pe[:, 1::2] = torch.cos(position.float() * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe)
        self.dropout = nn.Dropout(p=dropout)
        self.dim = dim
|
|
|
def forward(self, emb, step=None): |
|
emb = emb * math.sqrt(self.dim) |
|
if step: |
|
emb = emb + self.pe[:, step][:, None, :] |
|
|
|
else: |
|
emb = emb + self.pe[:, : emb.size(1)] |
|
emb = self.dropout(emb) |
|
return emb |
|
|
|
def get_emb(self, emb): |
|
return self.pe[:, : emb.size(1)] |
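
# Sketch of how PositionalEncoding is applied, with `dim` matching the embedding size:
#
#     pos_enc = PositionalEncoding(dropout=0.1, dim=768)
#     emb = torch.zeros(2, 10, 768)
#     full = pos_enc(emb)                   # adds sinusoidal positions 0..9
#     single = pos_enc(emb[:, :1], step=5)  # incremental decoding: position 5 only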
|
|
|
|
|
class TransformerDecoderLayer(nn.Module): |
|
""" |
|
Args: |
|
        d_model (int): the dimension of keys/values/queries in
            MultiHeadedAttention, also the input size of
            the first layer of the PositionwiseFeedForward.
        heads (int): the number of heads for MultiHeadedAttention.
        d_ff (int): the hidden size of the PositionwiseFeedForward inner layer.
        dropout (float): dropout probability (0-1.0).
|
""" |
|
|
|
def __init__(self, d_model, heads, d_ff, dropout): |
|
super().__init__() |
|
|
|
self.self_attn = MultiHeadedAttention(heads, d_model, dropout=dropout) |
|
|
|
self.context_attn = MultiHeadedAttention(heads, d_model, dropout=dropout) |
|
self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout) |
|
self.layer_norm_1 = nn.LayerNorm(d_model, eps=1e-6) |
|
self.layer_norm_2 = nn.LayerNorm(d_model, eps=1e-6) |
|
self.drop = nn.Dropout(dropout) |
|
mask = self._get_attn_subsequent_mask(MAX_SIZE) |
|
|
|
|
|
self.register_buffer("mask", mask) |
|
|
|
def forward( |
|
self, |
|
inputs, |
|
memory_bank, |
|
src_pad_mask, |
|
tgt_pad_mask, |
|
previous_input=None, |
|
layer_cache=None, |
|
step=None, |
|
): |
|
""" |
|
Args: |
|
inputs (`FloatTensor`): `[batch_size x 1 x model_dim]` |
|
memory_bank (`FloatTensor`): `[batch_size x src_len x model_dim]` |
|
src_pad_mask (`LongTensor`): `[batch_size x 1 x src_len]` |
|
tgt_pad_mask (`LongTensor`): `[batch_size x 1 x 1]` |
|
|
|
Returns: |
|
            (`FloatTensor`, `FloatTensor`):

            * output `[batch_size x 1 x model_dim]`
            * all_input `[batch_size x current_step x model_dim]`
|
|
|
""" |
|
dec_mask = torch.gt(tgt_pad_mask + self.mask[:, : tgt_pad_mask.size(1), : tgt_pad_mask.size(1)], 0) |
|
input_norm = self.layer_norm_1(inputs) |
|
all_input = input_norm |
|
if previous_input is not None: |
|
all_input = torch.cat((previous_input, input_norm), dim=1) |
|
dec_mask = None |
|
|
|
query = self.self_attn( |
|
all_input, |
|
all_input, |
|
input_norm, |
|
mask=dec_mask, |
|
layer_cache=layer_cache, |
|
type="self", |
|
) |
|
|
|
query = self.drop(query) + inputs |
|
|
|
query_norm = self.layer_norm_2(query) |
|
mid = self.context_attn( |
|
memory_bank, |
|
memory_bank, |
|
query_norm, |
|
mask=src_pad_mask, |
|
layer_cache=layer_cache, |
|
type="context", |
|
) |
|
output = self.feed_forward(self.drop(mid) + query) |
|
|
|
return output, all_input |
|
|
|
|
|
def _get_attn_subsequent_mask(self, size): |
|
""" |
|
Get an attention mask to avoid using the subsequent info. |
|
|
|
Args: |
|
size: int |
|
|
|
Returns: |
|
(`LongTensor`): |
|
|
|
* subsequent_mask `[1 x size x size]` |
|
""" |
|
attn_shape = (1, size, size) |
|
subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype("uint8") |
|
subsequent_mask = torch.from_numpy(subsequent_mask) |
|
return subsequent_mask |
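        # e.g. for size=4 the buffer is (1, 4, 4) with ones strictly above the
        # diagonal, so position i cannot attend to any position j > i:
        #     [[0, 1, 1, 1],
        #      [0, 0, 1, 1],
        #      [0, 0, 0, 1],
        #      [0, 0, 0, 0]]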
|
|
|
|
|
class MultiHeadedAttention(nn.Module): |
|
""" |
|
Multi-Head Attention module from |
|
"Attention is All You Need" |
|
:cite:`DBLP:journals/corr/VaswaniSPUJGKP17`. |
|
|
|
Similar to standard `dot` attention but uses |
|
    multiple attention distributions simultaneously
|
to select relevant items. |
|
|
|
.. mermaid:: |
|
|
|
graph BT |
|
A[key] |
|
B[value] |
|
C[query] |
|
O[output] |
|
subgraph Attn |
|
D[Attn 1] |
|
E[Attn 2] |
|
F[Attn N] |
|
end |
|
A --> D |
|
C --> D |
|
A --> E |
|
C --> E |
|
A --> F |
|
C --> F |
|
D --> O |
|
E --> O |
|
F --> O |
|
B --> O |
|
|
|
Also includes several additional tricks. |
|
|
|
Args: |
|
head_count (int): number of parallel heads |
|
model_dim (int): the dimension of keys/values/queries, |
|
must be divisible by head_count |
|
dropout (float): dropout parameter |
|
""" |
|
|
|
def __init__(self, head_count, model_dim, dropout=0.1, use_final_linear=True): |
|
        super().__init__()
        assert model_dim % head_count == 0
        self.dim_per_head = model_dim // head_count
        self.model_dim = model_dim
        self.head_count = head_count
|
|
|
self.linear_keys = nn.Linear(model_dim, head_count * self.dim_per_head) |
|
self.linear_values = nn.Linear(model_dim, head_count * self.dim_per_head) |
|
self.linear_query = nn.Linear(model_dim, head_count * self.dim_per_head) |
|
self.softmax = nn.Softmax(dim=-1) |
|
self.dropout = nn.Dropout(dropout) |
|
self.use_final_linear = use_final_linear |
|
if self.use_final_linear: |
|
self.final_linear = nn.Linear(model_dim, model_dim) |
|
|
|
def forward( |
|
self, |
|
key, |
|
value, |
|
query, |
|
mask=None, |
|
layer_cache=None, |
|
type=None, |
|
predefined_graph_1=None, |
|
): |
|
""" |
|
Compute the context vector and the attention vectors. |
|
|
|
Args: |
|
key (`FloatTensor`): set of `key_len` |
|
key vectors `[batch, key_len, dim]` |
|
value (`FloatTensor`): set of `key_len` |
|
value vectors `[batch, key_len, dim]` |
|
query (`FloatTensor`): set of `query_len` |
|
query vectors `[batch, query_len, dim]` |
|
mask: binary mask indicating which keys have |
|
non-zero attention `[batch, query_len, key_len]` |
|
Returns: |
|
            (`FloatTensor`):

            * output context vectors `[batch, query_len, dim]`
|
""" |
|
batch_size = key.size(0) |
|
dim_per_head = self.dim_per_head |
|
head_count = self.head_count |
|
|
|
def shape(x): |
|
"""projection""" |
|
return x.view(batch_size, -1, head_count, dim_per_head).transpose(1, 2) |
|
|
|
def unshape(x): |
|
"""compute context""" |
|
return x.transpose(1, 2).contiguous().view(batch_size, -1, head_count * dim_per_head) |
|
|
|
|
|
if layer_cache is not None: |
|
if type == "self": |
|
query, key, value = ( |
|
self.linear_query(query), |
|
self.linear_keys(query), |
|
self.linear_values(query), |
|
) |
|
|
|
key = shape(key) |
|
value = shape(value) |
|
|
|
if layer_cache is not None: |
|
device = key.device |
|
if layer_cache["self_keys"] is not None: |
|
key = torch.cat((layer_cache["self_keys"].to(device), key), dim=2) |
|
if layer_cache["self_values"] is not None: |
|
value = torch.cat((layer_cache["self_values"].to(device), value), dim=2) |
|
layer_cache["self_keys"] = key |
|
layer_cache["self_values"] = value |
|
elif type == "context": |
|
query = self.linear_query(query) |
|
if layer_cache is not None: |
|
if layer_cache["memory_keys"] is None: |
|
key, value = self.linear_keys(key), self.linear_values(value) |
|
key = shape(key) |
|
value = shape(value) |
|
else: |
|
key, value = ( |
|
layer_cache["memory_keys"], |
|
layer_cache["memory_values"], |
|
) |
|
layer_cache["memory_keys"] = key |
|
layer_cache["memory_values"] = value |
|
else: |
|
key, value = self.linear_keys(key), self.linear_values(value) |
|
key = shape(key) |
|
value = shape(value) |
|
else: |
|
key = self.linear_keys(key) |
|
value = self.linear_values(value) |
|
query = self.linear_query(query) |
|
key = shape(key) |
|
value = shape(value) |
|
|
|
query = shape(query) |
|
|
|
|
|
query = query / math.sqrt(dim_per_head) |
|
scores = torch.matmul(query, key.transpose(2, 3)) |
|
|
|
if mask is not None: |
|
mask = mask.unsqueeze(1).expand_as(scores) |
|
scores = scores.masked_fill(mask, -1e18) |
|
|
|
|
|
|
|
attn = self.softmax(scores) |
|
|
|
if predefined_graph_1 is not None: |
|
attn_masked = attn[:, -1] * predefined_graph_1 |
|
attn_masked = attn_masked / (torch.sum(attn_masked, 2).unsqueeze(2) + 1e-9) |
|
|
|
attn = torch.cat([attn[:, :-1], attn_masked.unsqueeze(1)], 1) |
|
|
|
drop_attn = self.dropout(attn) |
|
if self.use_final_linear: |
|
context = unshape(torch.matmul(drop_attn, value)) |
|
output = self.final_linear(context) |
|
return output |
|
else: |
|
context = torch.matmul(drop_attn, value) |
|
return context |
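
# Shape sketch for MultiHeadedAttention in the self-attention case (no cache):
#
#     attn = MultiHeadedAttention(head_count=8, model_dim=512, dropout=0.1)
#     x = torch.randn(2, 7, 512)                   # (batch, seq_len, model_dim)
#     out = attn(x, x, x)                          # (batch, seq_len, model_dim)
#     pad = torch.zeros(2, 7, 7, dtype=torch.bool)
#     out = attn(x, x, x, mask=pad)                # True entries in `pad` are masked out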
|
|
|
|
|
class DecoderState(object): |
|
"""Interface for grouping together the current state of a recurrent |
|
decoder. In the simplest case just represents the hidden state of |
|
the model. But can also be used for implementing various forms of |
|
input_feeding and non-recurrent models. |
|
|
|
Modules need to implement this to utilize beam search decoding. |
|
""" |
|
|
|
def detach(self): |
|
"""Need to document this""" |
|
self.hidden = tuple([_.detach() for _ in self.hidden]) |
|
self.input_feed = self.input_feed.detach() |
|
|
|
def beam_update(self, idx, positions, beam_size): |
|
"""Need to document this""" |
|
for e in self._all: |
|
sizes = e.size() |
|
br = sizes[1] |
|
if len(sizes) == 3: |
|
sent_states = e.view(sizes[0], beam_size, br // beam_size, sizes[2])[:, :, idx] |
|
else: |
|
sent_states = e.view(sizes[0], beam_size, br // beam_size, sizes[2], sizes[3])[:, :, idx] |
|
|
|
sent_states.data.copy_(sent_states.data.index_select(1, positions)) |
|
|
|
def map_batch_fn(self, fn): |
|
raise NotImplementedError() |
|
|
|
|
|
class TransformerDecoderState(DecoderState): |
|
"""Transformer Decoder state base class""" |
|
|
|
def __init__(self, src): |
|
""" |
|
Args: |
|
src (FloatTensor): a sequence of source words tensors |
|
with optional feature tensors, of size (len x batch). |
|
""" |
|
self.src = src |
|
self.previous_input = None |
|
self.previous_layer_inputs = None |
|
self.cache = None |
|
|
|
@property |
|
def _all(self): |
|
""" |
|
Contains attributes that need to be updated in self.beam_update(). |
|
""" |
|
if self.previous_input is not None and self.previous_layer_inputs is not None: |
|
return (self.previous_input, self.previous_layer_inputs, self.src) |
|
else: |
|
return (self.src,) |
|
|
|
def detach(self): |
|
if self.previous_input is not None: |
|
self.previous_input = self.previous_input.detach() |
|
if self.previous_layer_inputs is not None: |
|
self.previous_layer_inputs = self.previous_layer_inputs.detach() |
|
self.src = self.src.detach() |
|
|
|
def update_state(self, new_input, previous_layer_inputs): |
|
state = TransformerDecoderState(self.src) |
|
state.previous_input = new_input |
|
state.previous_layer_inputs = previous_layer_inputs |
|
return state |
|
|
|
def _init_cache(self, memory_bank, num_layers): |
|
self.cache = {} |
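        # one sub-dict per decoder layer: "self_keys"/"self_values" grow by one entry per
        # decoding step, "memory_keys"/"memory_values" hold the encoder projections
        # computed once on the first step and reused afterwards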
|
|
|
for l in range(num_layers): |
|
layer_cache = {"memory_keys": None, "memory_values": None} |
|
layer_cache["self_keys"] = None |
|
layer_cache["self_values"] = None |
|
self.cache["layer_{}".format(l)] = layer_cache |
|
|
|
def repeat_beam_size_times(self, beam_size): |
|
"""Repeat beam_size times along batch dimension.""" |
|
self.src = self.src.data.repeat(1, beam_size, 1) |
|
|
|
def map_batch_fn(self, fn): |
|
def _recursive_map(struct, batch_dim=0): |
|
for k, v in struct.items(): |
|
if v is not None: |
|
if isinstance(v, dict): |
|
_recursive_map(v) |
|
else: |
|
struct[k] = fn(v, batch_dim) |
|
|
|
self.src = fn(self.src, 0) |
|
if self.cache is not None: |
|
_recursive_map(self.cache) |
|
|
|
|
|
def gelu(x): |
|
return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) |
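
# `gelu` above is the tanh approximation of the GELU activation
# (Hendrycks & Gimpel, 2016); it tracks the exact formulation closely:
#
#     x = torch.linspace(-3, 3, 7)
#     torch.allclose(gelu(x), torch.nn.functional.gelu(x), atol=1e-3)  # -> True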
|
|
|
|
|
class PositionwiseFeedForward(nn.Module): |
|
"""A two-layer Feed-Forward-Network with residual layer norm. |
|
|
|
Args: |
|
d_model (int): the size of input for the first-layer of the FFN. |
|
        d_ff (int): the hidden size of the FFN's inner layer.
|
dropout (float): dropout probability in :math:`[0, 1)`. |
|
""" |
|
|
|
def __init__(self, d_model, d_ff, dropout=0.1): |
|
super().__init__() |
|
self.w_1 = nn.Linear(d_model, d_ff) |
|
self.w_2 = nn.Linear(d_ff, d_model) |
|
self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) |
|
self.actv = gelu |
|
self.dropout_1 = nn.Dropout(dropout) |
|
self.dropout_2 = nn.Dropout(dropout) |
|
|
|
def forward(self, x): |
|
inter = self.dropout_1(self.actv(self.w_1(self.layer_norm(x)))) |
|
output = self.dropout_2(self.w_2(inter)) |
|
return output + x |
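
# FFN sketch: the input is layer-normalized, expanded to d_ff with a GELU, projected
# back to d_model, and added to the residual:
#
#     ffn = PositionwiseFeedForward(d_model=512, d_ff=2048, dropout=0.1)
#     x = torch.randn(2, 7, 512)
#     y = ffn(x)  # same shape as x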
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def build_predictor(args, tokenizer, symbols, model, logger=None): |
|
|
|
scorer = GNMTGlobalScorer(args.alpha, length_penalty="wu") |
|
translator = Translator(args, model, tokenizer, symbols, global_scorer=scorer, logger=logger) |
|
return translator |
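
# `symbols` is expected to map "BOS" and "EOS" to the token ids the summarization model
# was trained with; the ids come from the tokenizer's vocabulary, not from this module.
# Usage sketch:
#
#     predictor = build_predictor(args, tokenizer, symbols, model)
#     translations = predictor.translate(batch, step)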
|
|
|
|
|
class GNMTGlobalScorer(object): |
|
""" |
|
NMT re-ranking score from |
|
"Google's Neural Machine Translation System" :cite:`wu2016google` |
|
|
|
Args: |
|
        alpha (float): length penalty parameter
        length_penalty (str): name of the length penalty to apply (e.g. "wu")
|
""" |
|
|
|
def __init__(self, alpha, length_penalty): |
|
self.alpha = alpha |
|
penalty_builder = PenaltyBuilder(length_penalty) |
|
self.length_penalty = penalty_builder.length_penalty() |
|
|
|
def score(self, beam, logprobs): |
|
""" |
|
Rescores a prediction based on penalty functions |
|
""" |
|
normalized_probs = self.length_penalty(beam, logprobs, self.alpha) |
|
return normalized_probs |
|
|
|
|
|
class PenaltyBuilder(object): |
|
""" |
|
    Returns the length penalty function for beam search.

    Args:
        length_pen (str): option name of the length penalty
|
""" |
|
|
|
def __init__(self, length_pen): |
|
self.length_pen = length_pen |
|
|
|
def length_penalty(self): |
|
if self.length_pen == "wu": |
|
return self.length_wu |
|
elif self.length_pen == "avg": |
|
return self.length_average |
|
else: |
|
return self.length_none |
|
|
|
""" |
|
Below are all the different penalty terms implemented so far |
|
""" |
|
|
|
def length_wu(self, beam, logprobs, alpha=0.0): |
|
""" |
|
NMT length re-ranking score from |
|
"Google's Neural Machine Translation System" :cite:`wu2016google`. |
|
""" |
|
|
|
modifier = ((5 + len(beam.next_ys)) ** alpha) / ((5 + 1) ** alpha) |
|
return logprobs / modifier |
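        # e.g. with alpha = 0.95 and a 10-token hypothesis the modifier is
        # ((5 + 10) / 6) ** 0.95 ~ 2.39, a milder normalization than dividing by the raw length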
|
|
|
def length_average(self, beam, logprobs, alpha=0.0): |
|
""" |
|
Returns the average probability of tokens in a sequence. |
|
""" |
|
return logprobs / len(beam.next_ys) |
|
|
|
def length_none(self, beam, logprobs, alpha=0.0, beta=0.0): |
|
""" |
|
Returns unmodified scores. |
|
""" |
|
return logprobs |
|
|
|
|
|
class Translator(object): |
|
""" |
|
Uses a model to translate a batch of sentences. |
|
|
|
Args: |
|
        args: command-line arguments holding the decoding hyper-parameters
            (beam_size, min_length, max_length, alpha, block_trigram)
        model (:obj:`BertAbs`): model to use for translation
        vocab: tokenizer used to map token ids back to tokens
        symbols (dict): mapping of the special "BOS" and "EOS" markers to token ids
        global_scorer (:obj:`GNMTGlobalScorer`): object to rescore final translations
        logger (logging.Logger): logger.
|
""" |
|
|
|
def __init__(self, args, model, vocab, symbols, global_scorer=None, logger=None): |
|
self.logger = logger |
|
|
|
self.args = args |
|
self.model = model |
|
self.generator = self.model.generator |
|
self.vocab = vocab |
|
self.symbols = symbols |
|
self.start_token = symbols["BOS"] |
|
self.end_token = symbols["EOS"] |
|
|
|
self.global_scorer = global_scorer |
|
self.beam_size = args.beam_size |
|
self.min_length = args.min_length |
|
self.max_length = args.max_length |
|
|
|
def translate(self, batch, step, attn_debug=False): |
|
"""Generates summaries from one batch of data.""" |
|
self.model.eval() |
|
with torch.no_grad(): |
|
batch_data = self.translate_batch(batch) |
|
translations = self.from_batch(batch_data) |
|
return translations |
|
|
|
def translate_batch(self, batch, fast=False): |
|
""" |
|
Translate a batch of sentences. |
|
|
|
Mostly a wrapper around :obj:`Beam`. |
|
|
|
Args: |
|
batch (:obj:`Batch`): a batch from a dataset object |
|
fast (bool): enables fast beam search (may not support all features) |
|
""" |
|
with torch.no_grad(): |
|
return self._fast_translate_batch(batch, self.max_length, min_length=self.min_length) |
|
|
|
|
|
|
|
def _fast_translate_batch(self, batch, max_length, min_length=0): |
|
"""Beam Search using the encoder inputs contained in `batch`.""" |
|
|
|
|
|
|
|
|
|
|
|
beam_size = self.beam_size |
|
batch_size = batch.batch_size |
|
src = batch.src |
|
segs = batch.segs |
|
mask_src = batch.mask_src |
|
|
|
        # `segs` are the token type ids and `mask_src` the attention mask; pass them by
        # keyword since Bert.forward declares attention_mask before token_type_ids
        src_features = self.model.bert(src, token_type_ids=segs, attention_mask=mask_src)
|
dec_states = self.model.decoder.init_decoder_state(src, src_features, with_cache=True) |
|
device = src_features.device |
|
|
|
|
|
dec_states.map_batch_fn(lambda state, dim: tile(state, beam_size, dim=dim)) |
|
src_features = tile(src_features, beam_size, dim=0) |
|
batch_offset = torch.arange(batch_size, dtype=torch.long, device=device) |
|
beam_offset = torch.arange(0, batch_size * beam_size, step=beam_size, dtype=torch.long, device=device) |
|
alive_seq = torch.full([batch_size * beam_size, 1], self.start_token, dtype=torch.long, device=device) |
|
|
|
|
|
topk_log_probs = torch.tensor([0.0] + [float("-inf")] * (beam_size - 1), device=device).repeat(batch_size) |
|
|
|
|
|
hypotheses = [[] for _ in range(batch_size)] |
|
|
|
results = {} |
|
results["predictions"] = [[] for _ in range(batch_size)] |
|
results["scores"] = [[] for _ in range(batch_size)] |
|
results["gold_score"] = [0] * batch_size |
|
results["batch"] = batch |
|
|
|
for step in range(max_length): |
|
decoder_input = alive_seq[:, -1].view(1, -1) |
|
|
|
|
|
decoder_input = decoder_input.transpose(0, 1) |
|
|
|
dec_out, dec_states = self.model.decoder(decoder_input, src_features, dec_states, step=step) |
|
|
|
|
|
log_probs = self.generator(dec_out.transpose(0, 1).squeeze(0)) |
|
vocab_size = log_probs.size(-1) |
|
|
|
if step < min_length: |
|
log_probs[:, self.end_token] = -1e20 |
|
|
|
|
|
log_probs += topk_log_probs.view(-1).unsqueeze(1) |
|
|
|
alpha = self.global_scorer.alpha |
|
length_penalty = ((5.0 + (step + 1)) / 6.0) ** alpha |
|
|
|
|
|
curr_scores = log_probs / length_penalty |
|
|
|
if self.args.block_trigram: |
|
cur_len = alive_seq.size(1) |
|
if cur_len > 3: |
|
for i in range(alive_seq.size(0)): |
|
fail = False |
|
words = [int(w) for w in alive_seq[i]] |
|
words = [self.vocab.ids_to_tokens[w] for w in words] |
|
words = " ".join(words).replace(" ##", "").split() |
|
if len(words) <= 3: |
|
continue |
|
trigrams = [(words[i - 1], words[i], words[i + 1]) for i in range(1, len(words) - 1)] |
|
trigram = tuple(trigrams[-1]) |
|
if trigram in trigrams[:-1]: |
|
fail = True |
|
if fail: |
|
curr_scores[i] = -10e20 |
|
|
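            # flatten the (beam, vocab) dimensions so topk is taken jointly over all beams
            # of an example; the beam index and token id are recovered below with floor
            # division and modulo by vocab_size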
|
curr_scores = curr_scores.reshape(-1, beam_size * vocab_size) |
|
topk_scores, topk_ids = curr_scores.topk(beam_size, dim=-1) |
|
|
|
|
|
topk_log_probs = topk_scores * length_penalty |
|
|
|
|
|
            # recover the beam index with floor division (topk was over beam_size * vocab_size)
            topk_beam_index = topk_ids.div(vocab_size, rounding_mode="floor")
|
topk_ids = topk_ids.fmod(vocab_size) |
|
|
|
|
|
batch_index = topk_beam_index + beam_offset[: topk_beam_index.size(0)].unsqueeze(1) |
|
select_indices = batch_index.view(-1) |
|
|
|
|
|
alive_seq = torch.cat([alive_seq.index_select(0, select_indices), topk_ids.view(-1, 1)], -1) |
|
|
|
is_finished = topk_ids.eq(self.end_token) |
|
if step + 1 == max_length: |
|
is_finished.fill_(1) |
|
|
|
end_condition = is_finished[:, 0].eq(1) |
|
|
|
if is_finished.any(): |
|
predictions = alive_seq.view(-1, beam_size, alive_seq.size(-1)) |
|
for i in range(is_finished.size(0)): |
|
b = batch_offset[i] |
|
if end_condition[i]: |
|
is_finished[i].fill_(1) |
|
finished_hyp = is_finished[i].nonzero().view(-1) |
|
|
|
for j in finished_hyp: |
|
hypotheses[b].append((topk_scores[i, j], predictions[i, j, 1:])) |
|
|
|
if end_condition[i]: |
|
best_hyp = sorted(hypotheses[b], key=lambda x: x[0], reverse=True) |
|
score, pred = best_hyp[0] |
|
|
|
results["scores"][b].append(score) |
|
results["predictions"][b].append(pred) |
|
non_finished = end_condition.eq(0).nonzero().view(-1) |
|
|
|
if len(non_finished) == 0: |
|
break |
|
|
|
topk_log_probs = topk_log_probs.index_select(0, non_finished) |
|
batch_index = batch_index.index_select(0, non_finished) |
|
batch_offset = batch_offset.index_select(0, non_finished) |
|
alive_seq = predictions.index_select(0, non_finished).view(-1, alive_seq.size(-1)) |
|
|
|
select_indices = batch_index.view(-1) |
|
src_features = src_features.index_select(0, select_indices) |
|
dec_states.map_batch_fn(lambda state, dim: state.index_select(dim, select_indices)) |
|
|
|
return results |
|
|
|
def from_batch(self, translation_batch): |
|
batch = translation_batch["batch"] |
|
assert len(translation_batch["gold_score"]) == len(translation_batch["predictions"]) |
|
batch_size = batch.batch_size |
|
|
|
preds, _, _, tgt_str, src = ( |
|
translation_batch["predictions"], |
|
translation_batch["scores"], |
|
translation_batch["gold_score"], |
|
batch.tgt_str, |
|
batch.src, |
|
) |
|
|
|
translations = [] |
|
for b in range(batch_size): |
|
pred_sents = self.vocab.convert_ids_to_tokens([int(n) for n in preds[b][0]]) |
|
pred_sents = " ".join(pred_sents).replace(" ##", "") |
|
gold_sent = " ".join(tgt_str[b].split()) |
|
raw_src = [self.vocab.ids_to_tokens[int(t)] for t in src[b]][:500] |
|
raw_src = " ".join(raw_src) |
|
translation = (pred_sents, gold_sent, raw_src) |
|
translations.append(translation) |
|
|
|
return translations |
|
|
|
|
|
def tile(x, count, dim=0): |
|
""" |
|
Tiles x on dimension dim count times. |
|
""" |
|
perm = list(range(len(x.size()))) |
|
if dim != 0: |
|
perm[0], perm[dim] = perm[dim], perm[0] |
|
x = x.permute(perm).contiguous() |
|
out_size = list(x.size()) |
|
out_size[0] *= count |
|
batch = x.size(0) |
|
x = x.view(batch, -1).transpose(0, 1).repeat(count, 1).transpose(0, 1).contiguous().view(*out_size) |
|
if dim != 0: |
|
x = x.permute(perm).contiguous() |
|
return x |
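
# `tile` repeats each batch entry `count` times along `dim`, keeping the copies of one
# entry adjacent (unlike `Tensor.repeat`, which concatenates whole batches):
#
#     x = torch.tensor([[1, 1], [2, 2]])
#     tile(x, 2, dim=0)
#     # tensor([[1, 1],
#     #         [1, 1],
#     #         [2, 2],
#     #         [2, 2]])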
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BertSumOptimizer(object): |
|
"""Specific optimizer for BertSum. |
|
|
|
As described in [1], the authors fine-tune BertSum for abstractive |
|
summarization using two Adam Optimizers with different warm-up steps and |
|
learning rate. They also use a custom learning rate scheduler. |
|
|
|
[1] Liu, Yang, and Mirella Lapata. "Text summarization with pretrained encoders." |
|
arXiv preprint arXiv:1908.08345 (2019). |
|
""" |
|
|
|
def __init__(self, model, lr, warmup_steps, beta_1=0.99, beta_2=0.999, eps=1e-8): |
|
self.encoder = model.encoder |
|
self.decoder = model.decoder |
|
self.lr = lr |
|
self.warmup_steps = warmup_steps |
|
|
|
self.optimizers = { |
|
"encoder": torch.optim.Adam( |
|
model.encoder.parameters(), |
|
lr=lr["encoder"], |
|
betas=(beta_1, beta_2), |
|
eps=eps, |
|
), |
|
"decoder": torch.optim.Adam( |
|
model.decoder.parameters(), |
|
lr=lr["decoder"], |
|
betas=(beta_1, beta_2), |
|
eps=eps, |
|
), |
|
} |
|
|
|
self._step = 0 |
|
self.current_learning_rates = {} |
|
|
|
def _update_rate(self, stack): |
|
return self.lr[stack] * min(self._step ** (-0.5), self._step * self.warmup_steps[stack] ** (-1.5)) |
|
|
|
def zero_grad(self): |
|
        for optimizer in self.optimizers.values():
            optimizer.zero_grad()
|
|
|
def step(self): |
|
self._step += 1 |
|
for stack, optimizer in self.optimizers.items(): |
|
new_rate = self._update_rate(stack) |
|
for param_group in optimizer.param_groups: |
|
param_group["lr"] = new_rate |
|
optimizer.step() |
|
self.current_learning_rates[stack] = new_rate |
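
# Usage sketch for BertSumOptimizer, assuming `model` exposes the `encoder` and `decoder`
# sub-modules the constructor expects; the numbers are illustrative values in the spirit
# of [1], not defaults of this class:
#
#     optimizer = BertSumOptimizer(
#         model,
#         lr={"encoder": 2e-3, "decoder": 0.1},
#         warmup_steps={"encoder": 20000, "decoder": 10000},
#     )
#     loss.backward()
#     optimizer.step()
#     optimizer.zero_grad()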
|
|