Spaces:
Runtime error
Runtime error
# coding=utf-8 | |
# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team. | |
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
""" PyTorch Transformer XL model. | |
Adapted from https://github.com/kimiyoung/transformer-xl. | |
In particular https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/mem_transformer.py | |
""" | |
from __future__ import absolute_import, division, print_function, unicode_literals | |
import os | |
import json | |
import math | |
import logging | |
import collections | |
import sys | |
from io import open | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torch.nn import CrossEntropyLoss | |
from torch.nn.parameter import Parameter | |
from .modeling_utils import PreTrainedModel, Conv1D, prune_conv1d_layer, SequenceSummary | |
from .configuration_transfo_xl import TransfoXLConfig | |
from .modeling_transfo_xl_utilities import ProjectedAdaptiveLogSoftmax, sample_logits | |
from .file_utils import add_start_docstrings | |
# Module-level logger, named after the importing module.
logger = logging.getLogger(__name__)

# Shortcut name of each pretrained checkpoint -> URL of its PyTorch weights.
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP = {
    'transfo-xl-wt103':
        "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-pytorch_model.bin",
}
def build_tf_to_pytorch_map(model, config):
    """ A map of modules from TF to PyTorch.
        This time I use a map to keep the PyTorch model as identical to the original PyTorch model as possible.

        Args:
            model: either the bare transformer, or a wrapper exposing it as ``model.transformer``
                together with a ``model.crit`` adaptive-softmax module.
            config: configuration object; ``untie_r`` is always read, ``tie_projs`` and
                ``tie_weight`` are read only when ``model`` has a ``transformer`` attribute.

        Returns:
            dict: TF variable name -> PyTorch parameter. The two relative-position bias entries
            map to a *list* of parameters (one per layer when ``config.untie_r`` is set,
            otherwise a single-element list).

        Raises:
            NotImplementedError: if the wrapper is used with ``config.tie_weight`` disabled.
    """
    tf_to_pt_map = {}

    if hasattr(model, 'transformer'):
        # We are loading in a TransfoXLLMHeadModel => we will load also the Adaptive Softmax
        crit = model.crit
        tf_to_pt_map["transformer/adaptive_softmax/cutoff_0/cluster_W"] = crit.cluster_weight
        tf_to_pt_map["transformer/adaptive_softmax/cutoff_0/cluster_b"] = crit.cluster_bias
        for i, (out_l, proj_l, tie_proj) in enumerate(zip(
                crit.out_layers,
                crit.out_projs,
                config.tie_projs)):
            layer_str = "transformer/adaptive_softmax/cutoff_%d/" % i
            if not config.tie_weight:
                # The original author noted this is probably not implemented in the TF code
                # either; the mapping statements that used to follow the raise were unreachable.
                raise NotImplementedError(
                    "Loading untied adaptive softmax weights (tie_weight=False) is not supported")
            # With tied weights only the output bias has its own TF variable.
            tf_to_pt_map[layer_str + 'b'] = out_l.bias
            if not tie_proj:
                tf_to_pt_map[layer_str + 'proj'] = proj_l
        # Now load the rest of the transformer
        model = model.transformer

    # Embeddings
    for i, (embed_l, proj_l) in enumerate(zip(model.word_emb.emb_layers, model.word_emb.emb_projs)):
        layer_str = "transformer/adaptive_embed/cutoff_%d/" % i
        tf_to_pt_map.update({
            layer_str + 'lookup_table': embed_l.weight,
            layer_str + 'proj_W': proj_l,
        })

    # Transformer blocks
    for i, b in enumerate(model.layers):
        layer_str = "transformer/layer_%d/" % i
        tf_to_pt_map.update({
            layer_str + "rel_attn/LayerNorm/gamma": b.dec_attn.layer_norm.weight,
            layer_str + "rel_attn/LayerNorm/beta": b.dec_attn.layer_norm.bias,
            layer_str + "rel_attn/o/kernel": b.dec_attn.o_net.weight,
            layer_str + "rel_attn/qkv/kernel": b.dec_attn.qkv_net.weight,
            layer_str + "rel_attn/r/kernel": b.dec_attn.r_net.weight,
            layer_str + "ff/LayerNorm/gamma": b.pos_ff.layer_norm.weight,
            layer_str + "ff/LayerNorm/beta": b.pos_ff.layer_norm.bias,
            layer_str + "ff/layer_1/kernel": b.pos_ff.CoreNet[0].weight,
            layer_str + "ff/layer_1/bias": b.pos_ff.CoreNet[0].bias,
            layer_str + "ff/layer_2/kernel": b.pos_ff.CoreNet[3].weight,
            layer_str + "ff/layer_2/bias": b.pos_ff.CoreNet[3].bias,
        })

    # Relative positioning biases
    if config.untie_r:
        # One bias pair per layer.
        r_r_list = [b.dec_attn.r_r_bias for b in model.layers]
        r_w_list = [b.dec_attn.r_w_bias for b in model.layers]
    else:
        # A single bias pair held by the model itself.
        r_r_list = [model.r_r_bias]
        r_w_list = [model.r_w_bias]
    tf_to_pt_map.update({
        'transformer/r_r_bias': r_r_list,
        'transformer/r_w_bias': r_w_list})
    return tf_to_pt_map
def load_tf_weights_in_transfo_xl(model, config, tf_path):
    """ Load tf checkpoints in a pytorch model

        Args:
            model: PyTorch model whose parameters are overwritten in place.
            config: model configuration, forwarded to ``build_tf_to_pytorch_map``.
            tf_path: path of the TensorFlow checkpoint to read.

        Returns:
            The same ``model`` object, with the mapped parameters replaced.

        Raises:
            ImportError: if NumPy or TensorFlow cannot be imported (logged, then re-raised).
            AssertionError: if a mapped variable is missing from the checkpoint or a shape
                does not match; the two offending shapes are appended to the exception args.
    """
    try:
        import numpy as np
        import tensorflow as tf
    except ImportError:
        logger.error("Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
                     "https://www.tensorflow.org/install/ for installation instructions.")
        raise
    # Build TF to PyTorch weights loading map
    tf_to_pt_map = build_tf_to_pytorch_map(model, config)

    # Load weights from TF model
    init_vars = tf.train.list_variables(tf_path)
    tf_weights = {}
    for name, shape in init_vars:
        logger.info("Loading TF weight {} with shape {}".format(name, shape))
        array = tf.train.load_variable(tf_path, name)
        tf_weights[name] = array

    for name, pointer in tf_to_pt_map.items():
        assert name in tf_weights
        array = tf_weights[name]
        # Variables whose name contains 'kernel' or 'proj' are transposed before copying.
        if 'kernel' in name or 'proj' in name:
            array = np.transpose(array)
        # The relative-position bias entries of the map are lists of parameters.
        is_rel_bias = 'r_r_bias' in name or 'r_w_bias' in name
        if is_rel_bias and len(pointer) > 1:
            # Here we will split the TF weights: one slice along axis 0 per layer
            assert len(pointer) == array.shape[0]
            for i, p_i in enumerate(pointer):
                arr_i = array[i, ...]
                try:
                    assert p_i.shape == arr_i.shape
                except AssertionError as e:
                    e.args += (p_i.shape, arr_i.shape)
                    raise
                logger.info("Initialize PyTorch weight {} for layer {}".format(name, i))
                p_i.data = torch.from_numpy(arr_i)
        else:
            if is_rel_bias and isinstance(pointer, (list, tuple)):
                # Single-element list (e.g. biases shared by all layers): unwrap it, otherwise
                # `.shape` / `.data` below would be looked up on the Python list itself.
                pointer = pointer[0]
                if array.ndim == pointer.dim() + 1 and array.shape[0] == 1:
                    # Array carries a leading axis of size one; drop it to match the parameter.
                    array = array[0, ...]
            try:
                assert pointer.shape == array.shape
            except AssertionError as e:
                e.args += (pointer.shape, array.shape)
                raise
            logger.info("Initialize PyTorch weight {}".format(name))
            pointer.data = torch.from_numpy(array)
        # Drop the variable and its '/Adam' and '/Adam_1' companions (presumably optimizer
        # slots) so that only genuinely unused variables are reported below.
        tf_weights.pop(name, None)
        tf_weights.pop(name + '/Adam', None)
        tf_weights.pop(name + '/Adam_1', None)

    logger.info("Weights not copied to PyTorch model: {}".format(', '.join(tf_weights.keys())))
    return model
class PositionalEmbedding(nn.Module):
    """ Sinusoidal embedding of a 1-D sequence of positions.

        The ``inv_freq`` buffer holds ``1 / 10000 ** (k / demb)`` for ``k = 0, 2, 4, ...``.
    """

    def __init__(self, demb):
        super(PositionalEmbedding, self).__init__()
        self.demb = demb

        exponents = torch.arange(0.0, demb, 2.0) / demb
        # Registered as a buffer: stored with the module but not a trainable parameter.
        self.register_buffer('inv_freq', 1 / (10000 ** exponents))

    def forward(self, pos_seq, bsz=None):
        """ Return ``[len(pos_seq), 1, 2 * len(inv_freq)]`` embeddings (sines first, then
            cosines); when ``bsz`` is given the middle axis is expanded to ``bsz``.
        """
        # Outer product: one row per position, one column per frequency.
        angles = torch.ger(pos_seq, self.inv_freq)
        pos_emb = torch.cat([angles.sin(), angles.cos()], dim=-1).unsqueeze(1)
        if bsz is None:
            return pos_emb
        return pos_emb.expand(-1, bsz, -1)
class PositionwiseFF(nn.Module):
    """ Feed-forward sub-layer (Linear -> ReLU -> Dropout -> Linear -> Dropout) with a
        residual connection and layer normalization.

        ``pre_lnorm`` selects whether normalization is applied to the input of the
        feed-forward stack (True) or to the residual sum (False).
    """

    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False):
        super(PositionwiseFF, self).__init__()
        self.d_model, self.d_inner, self.dropout = d_model, d_inner, dropout

        # Order matters: other code addresses the two Linear layers as CoreNet[0] / CoreNet[3].
        core_modules = [
            nn.Linear(d_model, d_inner),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(d_inner, d_model),
            nn.Dropout(dropout),
        ]
        self.CoreNet = nn.Sequential(*core_modules)
        self.layer_norm = nn.LayerNorm(d_model)
        self.pre_lnorm = pre_lnorm

    def forward(self, inp):
        if not self.pre_lnorm:
            # positionwise feed-forward, then residual connection + layer normalization
            return self.layer_norm(inp + self.CoreNet(inp))
        # layer normalization + positionwise feed-forward, then residual connection
        return self.CoreNet(self.layer_norm(inp)) + inp
class MultiHeadAttn(nn.Module):
    """ Multi-head attention with a residual connection and layer normalization.

        Queries come from the current hidden states ``h``; keys and values come from ``h``
        optionally preceded by ``mems`` along the first axis.
    """

    def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
                 pre_lnorm=False, r_r_bias=None, r_w_bias=None, output_attentions=False):
        super(MultiHeadAttn, self).__init__()

        self.output_attentions = output_attentions
        self.n_head = n_head
        self.d_model = d_model
        self.d_head = d_head
        self.dropout = dropout

        inner_dim = n_head * d_head
        self.q_net = nn.Linear(d_model, inner_dim, bias=False)
        self.kv_net = nn.Linear(d_model, 2 * inner_dim, bias=False)

        self.drop = nn.Dropout(dropout)
        self.dropatt = nn.Dropout(dropatt)
        self.o_net = nn.Linear(inner_dim, d_model, bias=False)

        self.layer_norm = nn.LayerNorm(d_model)

        self.scale = 1 / (d_head ** 0.5)
        self.pre_lnorm = pre_lnorm

        biases_are_shared = r_r_bias is not None and r_w_bias is not None
        if biases_are_shared:
            self.r_r_bias = r_r_bias
            self.r_w_bias = r_w_bias
        else:
            # Biases are not shared: allocate this layer's own (uninitialized) parameters.
            self.r_r_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head))
            self.r_w_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head))

    def forward(self, h, attn_mask=None, mems=None, head_mask=None):
        """ Return ``[hidden]``, or ``[hidden, attn_prob]`` when ``output_attentions`` is set. """
        ##### multihead attention
        context = h if mems is None else torch.cat([mems, h], 0)
        if self.pre_lnorm:
            ##### layer normalization (applied to the key/value context only)
            context = self.layer_norm(context)

        head_q = self.q_net(h)
        head_k, head_v = torch.chunk(self.kv_net(context), 2, -1)

        # [len x bsz x n_head x d_head]
        q_shape = (h.size(0), h.size(1), self.n_head, self.d_head)
        kv_shape = (context.size(0), context.size(1), self.n_head, self.d_head)
        head_q = head_q.view(*q_shape)
        head_k = head_k.view(*kv_shape)
        head_v = head_v.view(*kv_shape)

        # [qlen x klen x bsz x n_head]
        attn_score = torch.einsum('ibnd,jbnd->ijbn', (head_q, head_k))
        attn_score.mul_(self.scale)
        if attn_mask is not None and torch.sum(attn_mask).item():
            attn_mask = (attn_mask == 1)  # Switch to bool
            mask_rank = attn_mask.dim()
            if mask_rank == 2:
                attn_score.masked_fill_(attn_mask[None, :, :, None], -float('inf'))
            elif mask_rank == 3:
                attn_score.masked_fill_(attn_mask[:, :, :, None], -float('inf'))

        # [qlen x klen x bsz x n_head], normalized over the key axis
        attn_prob = self.dropatt(F.softmax(attn_score, dim=1))

        # Mask heads if we want to
        if head_mask is not None:
            attn_prob = attn_prob * head_mask

        # [qlen x klen x bsz x n_head] + [klen x bsz x n_head x d_head] -> [qlen x bsz x n_head x d_head]
        attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, head_v))
        attn_vec = attn_vec.contiguous().view(
            attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)

        ##### linear projection
        attn_out = self.drop(self.o_net(attn_vec))

        residual = h + attn_out
        if self.pre_lnorm:
            ##### residual connection only
            outputs = [residual]
        else:
            ##### residual connection + layer normalization
            outputs = [self.layer_norm(residual)]

        if self.output_attentions:
            outputs.append(attn_prob)

        return outputs
class RelMultiHeadAttn(nn.Module):
    """ Base class for the attention layers that use relative positions.

        It owns the shared sub-modules (fused query/key/value projection, output projection,
        dropouts, layer normalization, the two bias parameters) and the shifting helpers.
        Subclasses implement ``forward``.

        ``tgt_len``, ``ext_len`` and ``mem_len`` are accepted by the constructor but not
        used in it.
    """

    def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
                 tgt_len=None, ext_len=None, mem_len=None, pre_lnorm=False,
                 r_r_bias=None, r_w_bias=None, output_attentions=False):
        super(RelMultiHeadAttn, self).__init__()

        self.output_attentions = output_attentions
        self.n_head = n_head
        self.d_model = d_model
        self.d_head = d_head
        self.dropout = dropout

        self.qkv_net = nn.Linear(d_model, 3 * n_head * d_head, bias=False)

        self.drop = nn.Dropout(dropout)
        self.dropatt = nn.Dropout(dropatt)
        self.o_net = nn.Linear(n_head * d_head, d_model, bias=False)

        self.layer_norm = nn.LayerNorm(d_model)

        self.scale = 1 / (d_head ** 0.5)

        self.pre_lnorm = pre_lnorm

        if r_r_bias is None or r_w_bias is None:  # Biases are not shared
            # Uninitialized storage; values must be set by the model's weight initialization.
            self.r_r_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head))
            self.r_w_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head))
        else:
            self.r_r_bias = r_r_bias
            self.r_w_bias = r_w_bias

    def _parallelogram_mask(self, h, w, left=False):
        """ Build an ``h x w`` byte mask: ones, with the top-left ``m x m`` corner made
            upper-triangular and the bottom-right ``m x m`` corner lower-triangular
            (``m = min(h, w)``). Returned as is when ``left`` is true, otherwise flipped
            along the first axis.
            NOTE(review): not called anywhere in the visible part of this file.
        """
        mask = torch.ones((h, w)).byte()
        m = min(h, w)
        mask[:m, :m] = torch.triu(mask[:m, :m])
        mask[-m:, -m:] = torch.tril(mask[-m:, -m:])

        if left:
            return mask
        else:
            return mask.flip(0)

    def _shift(self, x, qlen, klen, mask, left=False):
        """ Pad ``x`` with ``qlen - 1`` zeros along axis 1 (on the left when ``left`` is true,
            otherwise on the right), expand it to ``qlen`` along axis 0 and keep the entries
            selected by ``mask``, reshaped to ``[qlen, klen, x.size(2), x.size(3)]``.
            NOTE(review): not called anywhere in the visible part of this file.
        """
        if qlen > 1:
            zero_pad = torch.zeros((x.size(0), qlen - 1, x.size(2), x.size(3)),
                                   device=x.device, dtype=x.dtype)
        else:
            zero_pad = torch.zeros(0, device=x.device, dtype=x.dtype)

        if left:
            mask = mask.flip(1)
            x_padded = torch.cat([zero_pad, x], dim=1).expand(qlen, -1, -1, -1)
        else:
            x_padded = torch.cat([x, zero_pad], dim=1).expand(qlen, -1, -1, -1)

        x = x_padded.masked_select(mask[:, :, None, None]) \
                    .view(qlen, klen, x.size(2), x.size(3))

        return x

    def _rel_shift(self, x, zero_triu=False):
        """ Shift the rows of ``x`` (at least 2-D, the subclasses pass
            ``[qlen x klen x bsz x n_head]``) by prepending a zero column and reinterpreting
            the padded buffer with its first two sizes swapped.

            With ``zero_triu`` the result is additionally multiplied by a lower-triangular
            mask (diagonal offset ``klen - qlen``); that path needs a 4-D ``x``.
        """
        zero_pad_shape = (x.size(0), 1) + x.size()[2:]
        zero_pad = torch.zeros(zero_pad_shape, device=x.device, dtype=x.dtype)
        x_padded = torch.cat([zero_pad, x], dim=1)

        # Reinterpret [qlen, klen + 1, ...] as [klen + 1, qlen, ...]; dropping the first row
        # then yields the shifted layout once viewed back with x's shape.
        x_padded_shape = (x.size(1) + 1, x.size(0)) + x.size()[2:]
        x_padded = x_padded.view(*x_padded_shape)

        x = x_padded[1:].view_as(x)

        if zero_triu:
            # Create the mask with x's device and dtype: a default (CPU, float32) tensor
            # cannot be multiplied with an input living on another device.
            ones = torch.ones((x.size(0), x.size(1)), device=x.device, dtype=x.dtype)
            x = x * torch.tril(ones, x.size(1) - x.size(0))[:, :, None, None]

        return x

    def forward(self, w, r, attn_mask=None, mems=None):
        """ Abstract: concrete attention variants override this. """
        raise NotImplementedError
class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn):
    """ Relative attention whose positional term is obtained by projecting the positional
        embeddings ``r`` through a linear layer (``r_net``) and combining them with the
        ``r_w_bias`` / ``r_r_bias`` parameters.
    """

    def __init__(self, *args, **kwargs):
        super(RelPartialLearnableMultiHeadAttn, self).__init__(*args, **kwargs)

        self.r_net = nn.Linear(self.d_model, self.n_head * self.d_head, bias=False)

    def forward(self, w, r, attn_mask=None, mems=None, head_mask=None):
        """ Return ``[hidden]``, or ``[hidden, attn_prob]`` when ``output_attentions`` is set. """
        qlen, rlen, bsz = w.size(0), r.size(0), w.size(1)

        # Keys and values span memory + current segment; queries only the current segment.
        attn_input = w if mems is None else torch.cat([mems, w], 0)
        if self.pre_lnorm:
            attn_input = self.layer_norm(attn_input)
        w_heads = self.qkv_net(attn_input)
        r_head_k = self.r_net(r)

        w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)
        if mems is not None:
            # Keep only the query rows of the current segment.
            w_head_q = w_head_q[-qlen:]

        klen = w_head_k.size(0)

        w_head_q = w_head_q.view(qlen, bsz, self.n_head, self.d_head)   # qlen x bsz x n_head x d_head
        w_head_k = w_head_k.view(klen, bsz, self.n_head, self.d_head)   # klen x bsz x n_head x d_head
        w_head_v = w_head_v.view(klen, bsz, self.n_head, self.d_head)   # klen x bsz x n_head x d_head

        r_head_k = r_head_k.view(rlen, self.n_head, self.d_head)        # rlen x n_head x d_head

        #### compute attention score
        # content term: (query + r_w_bias) against the keys
        rw_head_q = w_head_q + self.r_w_bias                            # qlen x bsz x n_head x d_head
        AC = torch.einsum('ibnd,jbnd->ijbn', (rw_head_q, w_head_k))     # qlen x klen x bsz x n_head

        # position term: (query + r_r_bias) against the projected positional embeddings
        rr_head_q = w_head_q + self.r_r_bias
        BD = torch.einsum('ibnd,jnd->ijbn', (rr_head_q, r_head_k))      # qlen x klen x bsz x n_head
        BD = self._rel_shift(BD)

        # [qlen x klen x bsz x n_head]
        attn_score = AC + BD
        attn_score.mul_(self.scale)

        #### compute attention probability
        if attn_mask is not None and torch.sum(attn_mask).item():
            attn_mask = (attn_mask == 1)  # Switch to bool
            mask_rank = attn_mask.dim()
            if mask_rank == 2:
                expanded_mask = attn_mask[None, :, :, None]
            elif mask_rank == 3:
                expanded_mask = attn_mask[:, :, :, None]
            else:
                # Masks of any other rank are ignored.
                expanded_mask = None

            if expanded_mask is not None:
                # A finite fill value is used; float16 parameters get a smaller magnitude.
                if next(self.parameters()).dtype == torch.float16:
                    fill_value = -65000
                else:
                    fill_value = -1e30
                attn_score = attn_score.float().masked_fill(
                    expanded_mask, fill_value).type_as(attn_score)

        # [qlen x klen x bsz x n_head]
        attn_prob = self.dropatt(F.softmax(attn_score, dim=1))

        # Mask heads if we want to
        if head_mask is not None:
            attn_prob = attn_prob * head_mask

        #### compute attention vector
        attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, w_head_v))

        # [qlen x bsz x n_head x d_head] -> [qlen x bsz x n_head * d_head]
        attn_vec = attn_vec.contiguous().view(
            attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)

        ##### linear projection
        attn_out = self.drop(self.o_net(attn_vec))

        residual = w + attn_out
        if self.pre_lnorm:
            ##### residual connection only
            outputs = [residual]
        else:
            ##### residual connection + layer normalization
            outputs = [self.layer_norm(residual)]

        if self.output_attentions:
            outputs.append(attn_prob)

        return outputs
class RelLearnableMultiHeadAttn(RelMultiHeadAttn):
    """ Relative attention whose positional embeddings and biases are passed to ``forward``
        by the caller instead of being derived from sinusoids.
    """

    def __init__(self, *args, **kwargs):
        super(RelLearnableMultiHeadAttn, self).__init__(*args, **kwargs)

    def forward(self, w, r_emb, r_w_bias, r_bias, attn_mask=None, mems=None, head_mask=None):
        """ Return ``[hidden]``, or ``[hidden, attn_prob]`` when ``output_attentions`` is set.

            r_emb: [klen, n_head, d_head], used for term B
            r_w_bias: [n_head, d_head], used for term C
            r_bias: [klen, n_head], used for term D
        """
        qlen, bsz = w.size(0), w.size(1)

        # Keys and values span memory + current segment; queries only the current segment.
        attn_input = w if mems is None else torch.cat([mems, w], 0)
        if self.pre_lnorm:
            attn_input = self.layer_norm(attn_input)
        w_head_q, w_head_k, w_head_v = torch.chunk(self.qkv_net(attn_input), 3, dim=-1)
        if mems is not None:
            w_head_q = w_head_q[-qlen:]

        klen = w_head_k.size(0)

        w_head_q = w_head_q.view(qlen, bsz, self.n_head, self.d_head)
        w_head_k = w_head_k.view(klen, bsz, self.n_head, self.d_head)
        w_head_v = w_head_v.view(klen, bsz, self.n_head, self.d_head)

        # Bring the positional tables to exactly klen rows.
        if klen > r_emb.size(0):
            # Too short: repeat the first row in front.
            r_emb_pad = r_emb[0:1].expand(klen - r_emb.size(0), -1, -1)
            r_emb = torch.cat([r_emb_pad, r_emb], 0)
            r_bias_pad = r_bias[0:1].expand(klen - r_bias.size(0), -1)
            r_bias = torch.cat([r_bias_pad, r_bias], 0)
        else:
            # Long enough: keep the last klen rows.
            r_emb = r_emb[-klen:]
            r_bias = r_bias[-klen:]

        #### compute attention score
        rw_head_q = w_head_q + r_w_bias[None]                           # qlen x bsz x n_head x d_head

        AC = torch.einsum('ibnd,jbnd->ijbn', (rw_head_q, w_head_k))     # qlen x klen x bsz x n_head
        B_ = torch.einsum('ibnd,jnd->ijbn', (w_head_q, r_emb))          # qlen x klen x bsz x n_head
        D_ = r_bias[None, :, None]                                      # 1 x klen x 1 x n_head
        BD = self._rel_shift(B_ + D_)

        # [qlen x klen x bsz x n_head]
        attn_score = AC + BD
        attn_score.mul_(self.scale)

        #### compute attention probability
        if attn_mask is not None and torch.sum(attn_mask).item():
            attn_mask = (attn_mask == 1)  # Switch to bool
            mask_rank = attn_mask.dim()
            if mask_rank == 2:
                attn_score.masked_fill_(attn_mask[None, :, :, None], -float('inf'))
            elif mask_rank == 3:
                attn_score.masked_fill_(attn_mask[:, :, :, None], -float('inf'))

        # [qlen x klen x bsz x n_head]
        attn_prob = self.dropatt(F.softmax(attn_score, dim=1))

        if head_mask is not None:
            attn_prob = attn_prob * head_mask

        #### compute attention vector
        attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, w_head_v))

        # [qlen x bsz x n_head x d_head] -> [qlen x bsz x n_head * d_head]
        attn_vec = attn_vec.contiguous().view(
            attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)

        ##### linear projection
        attn_out = self.drop(self.o_net(attn_vec))

        residual = w + attn_out
        if self.pre_lnorm:
            ##### residual connection only
            outputs = [residual]
        else:
            ##### residual connection + layer normalization
            outputs = [self.layer_norm(residual)]

        if self.output_attentions:
            outputs.append(attn_prob)

        return outputs
class DecoderLayer(nn.Module):
    """ One decoder block: ``MultiHeadAttn`` followed by ``PositionwiseFF``.

        Extra keyword arguments are forwarded to the attention module; ``pre_lnorm`` is
        additionally read from them for the feed-forward module.
    """

    def __init__(self, n_head, d_model, d_head, d_inner, dropout, **kwargs):
        super(DecoderLayer, self).__init__()

        self.dec_attn = MultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
        self.pos_ff = PositionwiseFF(
            d_model, d_inner, dropout, pre_lnorm=kwargs.get('pre_lnorm'))

    def forward(self, dec_inp, dec_attn_mask=None, mems=None, head_mask=None):
        """ Return ``[hidden]`` followed by whatever extra outputs the attention returned. """
        attn_outputs = self.dec_attn(
            dec_inp, attn_mask=dec_attn_mask, mems=mems, head_mask=head_mask)
        hidden, extras = attn_outputs[0], attn_outputs[1:]

        return [self.pos_ff(hidden)] + extras
class RelLearnableDecoderLayer(nn.Module):
    """ One decoder block: ``RelLearnableMultiHeadAttn`` followed by ``PositionwiseFF``.

        Extra keyword arguments are forwarded to the attention module; ``pre_lnorm`` is
        additionally read from them for the feed-forward module.
    """

    def __init__(self, n_head, d_model, d_head, d_inner, dropout,
                 **kwargs):
        super(RelLearnableDecoderLayer, self).__init__()

        self.dec_attn = RelLearnableMultiHeadAttn(
            n_head, d_model, d_head, dropout, **kwargs)
        self.pos_ff = PositionwiseFF(
            d_model, d_inner, dropout, pre_lnorm=kwargs.get('pre_lnorm'))

    def forward(self, dec_inp, r_emb, r_w_bias, r_bias, dec_attn_mask=None, mems=None, head_mask=None):
        """ Return ``[hidden]`` followed by whatever extra outputs the attention returned. """
        attn_outputs = self.dec_attn(
            dec_inp, r_emb, r_w_bias, r_bias,
            attn_mask=dec_attn_mask, mems=mems, head_mask=head_mask)
        hidden, extras = attn_outputs[0], attn_outputs[1:]

        return [self.pos_ff(hidden)] + extras
class RelPartialLearnableDecoderLayer(nn.Module):
    """ One decoder block: ``RelPartialLearnableMultiHeadAttn`` followed by ``PositionwiseFF``.

        Extra keyword arguments are forwarded to the attention module; ``pre_lnorm`` is
        additionally read from them for the feed-forward module.
    """

    def __init__(self, n_head, d_model, d_head, d_inner, dropout,
                 **kwargs):
        super(RelPartialLearnableDecoderLayer, self).__init__()

        self.dec_attn = RelPartialLearnableMultiHeadAttn(
            n_head, d_model, d_head, dropout, **kwargs)
        self.pos_ff = PositionwiseFF(
            d_model, d_inner, dropout, pre_lnorm=kwargs.get('pre_lnorm'))

    def forward(self, dec_inp, r, dec_attn_mask=None, mems=None, head_mask=None):
        """ Return ``[hidden]`` followed by whatever extra outputs the attention returned. """
        attn_outputs = self.dec_attn(
            dec_inp, r,
            attn_mask=dec_attn_mask, mems=mems, head_mask=head_mask)
        hidden, extras = attn_outputs[0], attn_outputs[1:]

        return [self.pos_ff(hidden)] + extras
class AdaptiveEmbedding(nn.Module):
    """ Token embedding whose vocabulary can be split into clusters at ``cutoffs``.

        With ``div_val == 1`` a single embedding table of width ``d_embed`` is used (plus a
        projection to ``d_proj`` when the two widths differ). Otherwise cluster ``i`` gets
        its own table of width ``d_embed // div_val ** i`` and its own projection to
        ``d_proj``. The result is always scaled by ``sqrt(d_proj)``.

        The projection parameters are allocated uninitialized and must be initialized by
        the caller before use.
    """

    def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1,
                 sample_softmax=False):
        super(AdaptiveEmbedding, self).__init__()

        self.n_token = n_token
        self.d_embed = d_embed

        # Cluster boundaries: cluster i covers token ids cutoff_ends[i] .. cutoff_ends[i + 1] - 1.
        self.cutoffs = cutoffs + [n_token]
        self.div_val = div_val
        self.d_proj = d_proj

        self.emb_scale = d_proj ** 0.5

        self.cutoff_ends = [0] + self.cutoffs

        self.emb_layers = nn.ModuleList()
        self.emb_projs = nn.ParameterList()
        if div_val == 1:
            self.emb_layers.append(
                nn.Embedding(n_token, d_embed, sparse=sample_softmax > 0)
            )
            if d_proj != d_embed:
                self.emb_projs.append(nn.Parameter(torch.FloatTensor(d_proj, d_embed)))
        else:
            for i in range(len(self.cutoffs)):
                l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
                d_emb_i = d_embed // (div_val ** i)
                self.emb_layers.append(nn.Embedding(r_idx - l_idx, d_emb_i))
                self.emb_projs.append(nn.Parameter(torch.FloatTensor(d_proj, d_emb_i)))

    def forward(self, inp):
        """ Embed the integer tensor ``inp``; the output has shape ``inp.size() + (d_proj,)``. """
        if self.div_val == 1:
            embed = self.emb_layers[0](inp)
            if self.d_proj != self.d_embed:
                embed = F.linear(embed, self.emb_projs[0])
        else:
            param = next(self.parameters())
            inp_flat = inp.view(-1)
            emb_flat = torch.zeros([inp_flat.size(0), self.d_proj],
                                   dtype=param.dtype, device=param.device)
            for i in range(len(self.cutoffs)):
                l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]

                mask_i = (inp_flat >= l_idx) & (inp_flat < r_idx)
                # squeeze(-1) rather than squeeze(): a single match must stay a 1-D index
                # vector instead of collapsing to a 0-dim tensor.
                indices_i = mask_i.nonzero().squeeze(-1)

                if indices_i.numel() == 0:
                    continue

                # Token ids relative to the start of this cluster.
                inp_i = inp_flat.index_select(0, indices_i) - l_idx
                emb_i = self.emb_layers[i](inp_i)
                emb_i = F.linear(emb_i, self.emb_projs[i])

                emb_flat.index_copy_(0, indices_i, emb_i)

            embed_shape = inp.size() + (self.d_proj,)
            embed = emb_flat.view(embed_shape)

        embed.mul_(self.emb_scale)

        return embed
class TransfoXLPreTrainedModel(PreTrainedModel):
    """ An abstract class to handle weights initialization and
        a simple interface for downloading and loading pretrained models.
    """
    config_class = TransfoXLConfig
    pretrained_model_archive_map = TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP
    load_tf_weights = load_tf_weights_in_transfo_xl
    base_model_prefix = "transformer"

    def _init_weight(self, weight):
        """ Fill ``weight`` in place according to ``config.init`` ('uniform' or 'normal');
            any other value leaves the tensor untouched.
        """
        scheme = self.config.init
        if scheme == 'uniform':
            nn.init.uniform_(weight, -self.config.init_range, self.config.init_range)
        elif scheme == 'normal':
            nn.init.normal_(weight, 0.0, self.config.init_std)

    def _init_bias(self, bias):
        """ Zero ``bias`` in place. """
        nn.init.constant_(bias, 0.0)

    def _init_weights(self, m):
        """ Initialize the weights.

            The module kind is matched on its class *name*; the order of the checks matters
            because e.g. 'AdaptiveEmbedding' also contains 'Embedding'.
        """
        classname = m.__class__.__name__
        if 'Linear' in classname:
            if getattr(m, 'weight', None) is not None:
                self._init_weight(m.weight)
            if getattr(m, 'bias', None) is not None:
                self._init_bias(m.bias)
        elif 'AdaptiveEmbedding' in classname:
            if hasattr(m, 'emb_projs'):
                for proj in m.emb_projs:
                    if proj is not None:
                        nn.init.normal_(proj, 0.0, self.config.proj_init_std)
        elif 'Embedding' in classname:
            if hasattr(m, 'weight'):
                self._init_weight(m.weight)
        elif 'ProjectedAdaptiveLogSoftmax' in classname:
            if getattr(m, 'cluster_weight', None) is not None:
                self._init_weight(m.cluster_weight)
            if getattr(m, 'cluster_bias', None) is not None:
                self._init_bias(m.cluster_bias)
            if hasattr(m, 'out_projs'):
                for proj in m.out_projs:
                    if proj is not None:
                        nn.init.normal_(proj, 0.0, self.config.proj_init_std)
        elif 'LayerNorm' in classname:
            if hasattr(m, 'weight'):
                # Scale is drawn around 1.0 rather than set to exactly 1.0.
                nn.init.normal_(m.weight, 1.0, self.config.init_std)
            if getattr(m, 'bias', None) is not None:
                self._init_bias(m.bias)
        else:
            # Any other module: initialize the relative-position tensors it may carry.
            for attr_name in ('r_emb', 'r_w_bias', 'r_r_bias'):
                if hasattr(m, attr_name):
                    self._init_weight(getattr(m, attr_name))
            if hasattr(m, 'r_bias'):
                self._init_bias(m.r_bias)

    def set_num_special_tokens(self, num_special_tokens):
        """ No-op; the argument is ignored. """
        return None
TRANSFO_XL_START_DOCSTRING = r""" The Transformer-XL model was proposed in | |
`Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context`_ | |
by Zihang Dai*, Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov. | |
It's a causal (uni-directional) transformer with relative positioning (sinusoïdal) embeddings which can reuse | |
previously computed hidden-states to attend to longer context (memory). | |
This model also uses adaptive softmax inputs and outputs (tied). | |
This model is a PyTorch `torch.nn.Module`_ sub-class. Use it as a regular PyTorch Module and | |
refer to the PyTorch documentation for all matter related to general usage and behavior. | |
.. _`Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context`: | |
https://arxiv.org/abs/1901.02860 | |
.. _`torch.nn.Module`: | |
https://pytorch.org/docs/stable/nn.html#module | |
Parameters: | |
config (:class:`~pytorch_transformers.TransfoXLConfig`): Model configuration class with all the parameters of the model. | |
Initializing with a config file does not load the weights associated with the model, only the configuration. | |
Check out the :meth:`~pytorch_transformers.PreTrainedModel.from_pretrained` method to load the model weights. | |
""" | |
TRANSFO_XL_INPUTS_DOCSTRING = r""" | |
Inputs: | |
**input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``: | |
Indices of input sequence tokens in the vocabulary. | |
Transformer-XL is a model with relative position embeddings so you can either pad the inputs on | |
the right or on the left. | |
Indices can be obtained using :class:`pytorch_transformers.TransfoXLTokenizer`. | |
See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and | |
:func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details. | |
**mems**: (`optional`) | |
list of ``torch.FloatTensor`` (one for each layer): | |
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model | |
(see `mems` output below). Can be used to speed up sequential decoding and attend to longer context. | |
**head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``: | |
Mask to nullify selected heads of the self-attention modules. | |
Mask values selected in ``[0, 1]``: | |
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**. | |
""" | |
class TransfoXLModel(TransfoXLPreTrainedModel): | |
r""" | |
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: | |
**last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)`` | |
Sequence of hidden-states at the last layer of the model. | |
**mems**: | |
list of ``torch.FloatTensor`` (one for each layer): | |
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model | |
(see `mems` input above). Can be used to speed up sequential decoding and attend to longer context. | |
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``) | |
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings) | |
of shape ``(batch_size, sequence_length, hidden_size)``: | |
Hidden-states of the model at the output of each layer plus the initial embedding outputs. | |
**attentions**: (`optional`, returned when ``config.output_attentions=True``) | |
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: | |
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. | |
Examples:: | |
tokenizer = TransfoXLTokenizer.from_pretrained('transfo-xl-wt103') | |
model = TransfoXLModel.from_pretrained('transfo-xl-wt103') | |
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1 | |
outputs = model(input_ids) | |
last_hidden_states, mems = outputs[:2] | |
""" | |
def __init__(self, config):
    """ Build the embedding, the stack of decoder layers selected by ``config.attn_type``
        and the matching positional parameters, then run ``init_weights()``.
    """
    super(TransfoXLModel, self).__init__(config)
    self.output_attentions = config.output_attentions
    self.output_hidden_states = config.output_hidden_states

    self.n_token = config.n_token

    self.d_embed = config.d_embed
    self.d_model = config.d_model
    self.n_head = config.n_head
    self.d_head = config.d_head

    self.word_emb = AdaptiveEmbedding(config.n_token, config.d_embed, config.d_model, config.cutoffs,
                                      div_val=config.div_val)

    self.drop = nn.Dropout(config.dropout)

    self.n_layer = config.n_layer

    self.tgt_len = config.tgt_len
    self.mem_len = config.mem_len
    self.ext_len = config.ext_len
    self.max_klen = config.tgt_len + config.ext_len + config.mem_len

    self.attn_type = config.attn_type

    if not config.untie_r:
        # Tied biases: one pair owned by the model and handed to every layer.
        self.r_w_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head))
        self.r_r_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head))

    self.layers = nn.ModuleList()

    # Keyword arguments common to every layer type; passing None for the biases makes
    # each layer allocate its own.
    layer_kwargs = dict(
        dropatt=config.dropatt,
        pre_lnorm=config.pre_lnorm,
        r_w_bias=None if config.untie_r else self.r_w_bias,
        r_r_bias=None if config.untie_r else self.r_r_bias,
        output_attentions=self.output_attentions,
    )
    # Only the relative-attention layers take the length arguments.
    length_kwargs = dict(tgt_len=config.tgt_len, ext_len=config.ext_len, mem_len=config.mem_len)

    if config.attn_type == 0:  # the default attention
        layer_cls = RelPartialLearnableDecoderLayer
        layer_kwargs.update(length_kwargs)
    elif config.attn_type == 1:  # learnable embeddings
        layer_cls = RelLearnableDecoderLayer
        layer_kwargs.update(length_kwargs)
    elif config.attn_type in [2, 3]:  # absolute embeddings
        layer_cls = DecoderLayer
    else:
        # Unknown attention type: the layer list stays empty.
        layer_cls = None

    if layer_cls is not None:
        for _ in range(config.n_layer):
            self.layers.append(
                layer_cls(config.n_head, config.d_model, config.d_head, config.d_inner,
                          config.dropout, **layer_kwargs)
            )

    self.same_length = config.same_length
    self.clamp_len = config.clamp_len

    if self.attn_type in (0, 2):  # default attention / absolute standard
        self.pos_emb = PositionalEmbedding(self.d_model)
    elif self.attn_type == 1:  # learnable
        self.r_emb = nn.Parameter(torch.FloatTensor(
            self.n_layer, self.max_klen, self.n_head, self.d_head))
        self.r_bias = nn.Parameter(torch.FloatTensor(
            self.n_layer, self.max_klen, self.n_head))
    elif self.attn_type == 3:  # absolute deeper SA
        self.r_emb = nn.Parameter(torch.FloatTensor(
            self.n_layer, self.max_klen, self.n_head, self.d_head))

    self.init_weights()
def _resize_token_embeddings(self, new_num_tokens):
    """Return the input embedding module as-is.

    ``new_num_tokens`` is accepted for interface compatibility but is not
    used: this method performs no resizing and simply hands back
    ``self.word_emb``.
    """
    embeddings = self.word_emb
    return embeddings
def backward_compatible(self):
    """Force ``sample_softmax`` to ``-1`` on this instance."""
    # -1 is not > 0, so any `sample_softmax > 0` check is false afterwards.
    self.sample_softmax = -1
def reset_length(self, tgt_len, ext_len, mem_len):
    """Overwrite the stored target, extended-context and memory lengths.

    Args:
        tgt_len: new value for ``self.tgt_len``.
        ext_len: new value for ``self.ext_len``.
        mem_len: new value for ``self.mem_len``.
    """
    self.tgt_len, self.ext_len, self.mem_len = tgt_len, ext_len, mem_len
def _prune_heads(self, heads):
    """Placeholder for head pruning: only logs that it is unsupported.

    ``heads`` is ignored and nothing on the model is modified.
    """
    logger.info("Head pruning is not implemented for Transformer-XL model")
def init_mems(self, data):
    """Create zero-filled memory tensors, one per layer.

    Args:
        data: tensor whose second dimension (``data.size(1)``) sizes the
            middle axis of each memory tensor (the batch axis when called
            from ``forward`` with ``[len, bsz]`` ids).

    Returns:
        ``None`` when ``self.mem_len`` is not positive; otherwise a list of
        ``self.n_layer`` tensors of shape
        ``(self.mem_len, data.size(1), self.config.d_model)`` that share the
        dtype and device of the model's first parameter.
    """
    if self.mem_len <= 0:
        return None

    # Borrow dtype/device from an existing parameter so the memory lives
    # wherever the model does.
    reference = next(self.parameters())
    return [
        torch.zeros(self.mem_len, data.size(1), self.config.d_model,
                    dtype=reference.dtype, device=reference.device)
        for _ in range(self.n_layer)
    ]
def _update_mems(self, hids, mems, qlen, mlen):
    """Build the memory to carry over to the next segment.

    Args:
        hids: per-layer hidden states of the current segment.
        mems: per-layer memory from the previous segment, or ``None``.
        qlen: number of steps in the current segment.
        mlen: number of steps already held in ``mems``.

    Returns:
        ``None`` when ``mems`` is ``None``; otherwise a list with one
        detached tensor per layer.
    """
    # Nothing to update when the model runs without memory.
    if mems is None:
        return None

    assert len(hids) == len(mems), 'len(hids) != len(mems)'

    # `mlen + qlen` steps are available for caching. The last `ext_len` of
    # the `qlen` new steps are reused as extended context next time, so the
    # cached window is
    #   [mlen + qlen - ext_len - mem_len, mlen + qlen - ext_len)
    # clipped at zero on the left.
    stop = mlen + max(0, qlen - self.ext_len)
    start = max(0, stop - self.mem_len)

    updated = []
    with torch.no_grad():
        for layer_idx, hidden in enumerate(hids):
            history = torch.cat([mems[layer_idx], hidden], dim=0)
            updated.append(history[start:stop].detach())
    return updated
def _forward(self, dec_inp, mems=None, head_mask=None):
    """Run the decoder stack on a length-first batch of token ids.

    Args:
        dec_inp: 2-D tensor of token ids laid out as ``[qlen, bsz]``.
        mems: optional list with one memory tensor per layer (first
            dimension is the memory length). The tensors are not modified.
        head_mask: optional tensor of rank 1 (``[n_head]``) or rank 2
            (``[n_layer, n_head]``); 1.0 keeps a head.

    Returns:
        A list ``[last_hidden, new_mems, (hidden_states), (attentions)]``
        where ``last_hidden`` is transposed to ``[bsz, qlen, hidden]``; the
        optional entries are present only when ``self.output_hidden_states``
        / ``self.output_attentions`` are set.
    """
    qlen, bsz = dec_inp.size()

    # Prepare head mask if needed
    # 1.0 in head_mask indicate we keep the head
    # attention_probs has shape bsz x n_heads x N x N
    # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] (a head_mask for each layer)
    # and head_mask is converted to shape [num_hidden_layers x qlen x klen x bsz x n_head]
    if head_mask is not None:
        if head_mask.dim() == 1:
            head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(0).unsqueeze(0)
            head_mask = head_mask.expand(self.n_layer, -1, -1, -1, -1)
        elif head_mask.dim() == 2:
            head_mask = head_mask.unsqueeze(1).unsqueeze(1).unsqueeze(1)
        # switch to float if needed + fp16 compatibility
        head_mask = head_mask.to(dtype=next(self.parameters()).dtype)
    else:
        head_mask = [None] * self.n_layer

    word_emb = self.word_emb(dec_inp)

    mlen = mems[0].size(0) if mems is not None else 0
    klen = mlen + qlen
    if self.same_length:
        # Mask both the future (upper triangle) and the far past (lower
        # triangle) so that every query attends to the same number of keys.
        all_ones = word_emb.new_ones((qlen, klen), dtype=torch.uint8)
        mask_len = klen - self.mem_len
        if mask_len > 0:
            mask_shift_len = qlen - mask_len
        else:
            mask_shift_len = qlen
        dec_attn_mask = (torch.triu(all_ones, 1 + mlen)
                         + torch.tril(all_ones, -mask_shift_len))[:, :, None]  # -1
    else:
        # Causal mask only: position i may not look past key i + mlen.
        dec_attn_mask = torch.triu(
            word_emb.new_ones((qlen, klen), dtype=torch.uint8), diagonal=1 + mlen)[:, :, None]

    hids = []
    attentions = []
    if self.attn_type == 0:  # default
        pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.device,
                               dtype=word_emb.dtype)
        if self.clamp_len > 0:
            pos_seq.clamp_(max=self.clamp_len)
        pos_emb = self.pos_emb(pos_seq)

        core_out = self.drop(word_emb)
        pos_emb = self.drop(pos_emb)

        for i, layer in enumerate(self.layers):
            hids.append(core_out)
            mems_i = None if mems is None else mems[i]
            layer_outputs = layer(core_out, pos_emb, dec_attn_mask=dec_attn_mask,
                                  mems=mems_i, head_mask=head_mask[i])
            core_out = layer_outputs[0]
            if self.output_attentions:
                attentions.append(layer_outputs[1])
    elif self.attn_type == 1:  # learnable
        core_out = self.drop(word_emb)
        for i, layer in enumerate(self.layers):
            hids.append(core_out)
            if self.clamp_len > 0:
                r_emb = self.r_emb[i][-self.clamp_len:]
                r_bias = self.r_bias[i][-self.clamp_len:]
            else:
                r_emb, r_bias = self.r_emb[i], self.r_bias[i]

            mems_i = None if mems is None else mems[i]
            layer_outputs = layer(core_out, r_emb, self.r_w_bias[i],
                                  r_bias, dec_attn_mask=dec_attn_mask,
                                  mems=mems_i, head_mask=head_mask[i])
            core_out = layer_outputs[0]
            if self.output_attentions:
                attentions.append(layer_outputs[1])
    elif self.attn_type == 2:  # absolute
        pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.device,
                               dtype=word_emb.dtype)
        if self.clamp_len > 0:
            pos_seq.clamp_(max=self.clamp_len)
        pos_emb = self.pos_emb(pos_seq)

        core_out = self.drop(word_emb + pos_emb[-qlen:])

        for i, layer in enumerate(self.layers):
            hids.append(core_out)
            mems_i = None if mems is None else mems[i]
            if mems_i is not None and i == 0:
                # Out-of-place add: `+=` would write the positional
                # embedding into the caller's memory tensor.
                mems_i = mems_i + pos_emb[:mlen]
            layer_outputs = layer(core_out, dec_attn_mask=dec_attn_mask,
                                  mems=mems_i, head_mask=head_mask[i])
            core_out = layer_outputs[0]
            if self.output_attentions:
                attentions.append(layer_outputs[1])
    elif self.attn_type == 3:
        core_out = self.drop(word_emb)

        for i, layer in enumerate(self.layers):
            hids.append(core_out)
            mems_i = None if mems is None else mems[i]
            if mems_i is not None and mlen > 0:
                cur_emb = self.r_emb[i][:-qlen]
                cur_size = cur_emb.size(0)
                if cur_size < mlen:
                    # Not enough learned positions for the whole memory:
                    # repeat the first one to pad on the left.
                    cur_emb_pad = cur_emb[0:1].expand(mlen - cur_size, -1, -1)
                    cur_emb = torch.cat([cur_emb_pad, cur_emb], 0)
                else:
                    cur_emb = cur_emb[-mlen:]
                # Out-of-place add so the caller's memory is left untouched.
                mems_i = mems_i + cur_emb.view(mlen, 1, -1)
            # NOTE(review): this in-place add also changes the tensor that was
            # just appended to `hids` (same object) — confirm that is intended.
            core_out += self.r_emb[i][-qlen:].view(qlen, 1, -1)

            layer_outputs = layer(core_out, dec_attn_mask=dec_attn_mask,
                                  mems=mems_i, head_mask=head_mask[i])
            core_out = layer_outputs[0]
            if self.output_attentions:
                attentions.append(layer_outputs[1])

    core_out = self.drop(core_out)

    # Pass by keyword: `_update_mems` is declared as (hids, mems, qlen, mlen)
    # and a positional (mlen, qlen) call swaps the two lengths, which gives
    # a wrong cache window whenever ext_len > 0.
    new_mems = self._update_mems(hids, mems, qlen=qlen, mlen=mlen)

    # We transpose back here to shape [bsz, len, hidden_dim]
    outputs = [core_out.transpose(0, 1).contiguous(), new_mems]
    if self.output_hidden_states:
        # Add last layer and transpose to library standard shape [bsz, len, hidden_dim]
        hids.append(core_out)
        hids = [t.transpose(0, 1).contiguous() for t in hids]
        outputs.append(hids)
    if self.output_attentions:
        # Transpose to library standard shape [bsz, n_heads, query_seq_len, key_seq_len]
        attentions = [t.permute(2, 3, 0, 1).contiguous() for t in attentions]
        outputs.append(attentions)
    return outputs  # last hidden state, new_mems, (all hidden states), (all attentions)
def forward(self, input_ids, mems=None, head_mask=None):
    """Library-facing entry point taking batch-first token ids.

    Args:
        input_ids: tensor of shape ``[bsz, len]``.
        mems: optional per-layer memory; freshly initialised through
            ``self.init_mems`` when omitted.
        head_mask: optional head mask, forwarded unchanged.

    Returns:
        Whatever ``self._forward`` returns: last hidden state, new_mems,
        (all hidden states), (all attentions).
    """
    # The original Transformer-XL code works on [len, bsz]; the library
    # interface is [bsz, len], so flip the first two axes before delegating.
    len_first_ids = input_ids.transpose(0, 1).contiguous()
    memory = self.init_mems(len_first_ids) if mems is None else mems
    return self._forward(len_first_ids, mems=memory, head_mask=head_mask)
class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
    r"""
        **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            Labels for language modeling.
            Note that the labels **are shifted** inside the model, i.e. you can set ``labels = input_ids``
            (NOTE(review): ``forward`` itself performs no shift; confirm it happens inside the criterion.)
            Indices are selected in ``[-1, 0, ..., config.vocab_size]``
            All labels set to ``-1`` are ignored (masked), the loss is only
            computed for labels in ``[0, ..., config.vocab_size]``

    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
            Language modeling loss.
        **prediction_scores**: ``None`` if ``labels`` is provided else ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
            We don't output them when the loss is computed to speedup adaptive softmax decoding.
        **mems**:
            list of ``torch.FloatTensor`` (one for each layer):
            that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
            (see `mems` input above). Can be used to speed up sequential decoding and attend to longer context.
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        tokenizer = TransfoXLTokenizer.from_pretrained('transfo-xl-wt103')
        model = TransfoXLLMHeadModel.from_pretrained('transfo-xl-wt103')
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids)
        prediction_scores, mems = outputs[:2]

    """
    def __init__(self, config):
        """Build the base transformer and the output head selected by ``config.sample_softmax``."""
        super(TransfoXLLMHeadModel, self).__init__(config)
        self.transformer = TransfoXLModel(config)
        self.sample_softmax = config.sample_softmax
        # use sampled softmax
        if config.sample_softmax > 0:
            # Imported here because the sampler is not among the names pulled
            # in at the top of the file; without it this branch raises NameError.
            from .modeling_transfo_xl_utilities import LogUniformSampler
            self.out_layer = nn.Linear(config.d_model, config.n_token)
            self.sampler = LogUniformSampler(config.n_token, config.sample_softmax)
        # use adaptive softmax (including standard softmax)
        else:
            self.crit = ProjectedAdaptiveLogSoftmax(config.n_token, config.d_embed, config.d_model,
                                                    config.cutoffs, div_val=config.div_val)
        self.init_weights()
        self.tie_weights()

    def tie_weights(self):
        """
        Run this to be sure output and input (adaptive) softmax weights are tied
        """
        # sampled softmax
        if self.sample_softmax > 0:
            if self.config.tie_weight:
                self.out_layer.weight = self.transformer.word_emb.weight
        # adaptive softmax (including standard softmax)
        else:
            if self.config.tie_weight:
                for i, out_layer in enumerate(self.crit.out_layers):
                    self._tie_or_clone_weights(out_layer,
                                               self.transformer.word_emb.emb_layers[i])
            if self.config.tie_projs:
                for i, tie_proj in enumerate(self.config.tie_projs):
                    if not tie_proj:
                        continue
                    if self.config.div_val == 1 and self.config.d_model != self.config.d_embed:
                        # Single shared embedding projection: always entry 0.
                        source_proj = self.transformer.word_emb.emb_projs[0]
                    elif self.config.div_val != 1:
                        # One projection per cluster: tie index to index.
                        source_proj = self.transformer.word_emb.emb_projs[i]
                    else:
                        continue
                    if self.config.torchscript:
                        # With torchscript enabled the projection is copied
                        # into a fresh Parameter instead of being shared.
                        self.crit.out_projs[i] = nn.Parameter(source_proj.clone())
                    else:
                        self.crit.out_projs[i] = source_proj

    def reset_length(self, tgt_len, ext_len, mem_len):
        """Forward the new lengths to the wrapped transformer."""
        self.transformer.reset_length(tgt_len, ext_len, mem_len)

    def init_mems(self, data):
        """Return fresh memory tensors created by the wrapped transformer."""
        return self.transformer.init_mems(data)

    def forward(self, input_ids, mems=None, head_mask=None, labels=None):
        """Run the transformer and the language-modeling head.

        Args:
            input_ids: tensor of shape ``[bsz, tgt_len]``.
            mems: optional per-layer memory passed to the transformer.
            head_mask: optional head mask passed to the transformer.
            labels: optional targets handed to the criterion.

        Returns:
            A list: (loss), logits or ``None`` if ``labels`` is given,
            new_mems, (all hidden states), (all attentions).

        Raises:
            NotImplementedError: in training mode with sampled softmax when
                ``labels`` is provided.
        """
        bsz = input_ids.size(0)
        tgt_len = input_ids.size(1)

        transformer_outputs = self.transformer(input_ids, mems=mems, head_mask=head_mask)

        last_hidden = transformer_outputs[0]
        pred_hid = last_hidden[:, -tgt_len:]
        outputs = transformer_outputs[1:]
        if self.sample_softmax > 0 and self.training:
            assert self.config.tie_weight
            logit = sample_logits(self.transformer.word_emb, self.out_layer.bias, labels, pred_hid, self.sampler)
            softmax_output = -F.log_softmax(logit, -1)[:, :, 0]
            outputs = [softmax_output] + outputs
            if labels is not None:
                # TODO: This is not implemented
                raise NotImplementedError
        else:
            softmax_output = self.crit(pred_hid.view(-1, pred_hid.size(-1)), labels)
            if labels is None:
                softmax_output = softmax_output.view(bsz, tgt_len, -1)
                outputs = [softmax_output] + outputs
            else:
                softmax_output = softmax_output.view(bsz, tgt_len)
                # `None` stands in for the prediction scores, which are not
                # produced when the loss is computed.
                outputs = [softmax_output, None] + outputs

        return outputs  # (loss), logits or None if labels is not None (speed up adaptive softmax), new_mems, (all hidden states), (all attentions)