# coding=utf-8 # Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ PyTorch Transformer XL model. Adapted from https://github.com/kimiyoung/transformer-xl. In particular https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/mem_transformer.py """ from __future__ import absolute_import, division, print_function, unicode_literals import os import json import math import logging import collections import sys from io import open import torch import torch.nn as nn import torch.nn.functional as F from torch.nn import CrossEntropyLoss from torch.nn.parameter import Parameter from .modeling_utils import PreTrainedModel, Conv1D, prune_conv1d_layer, SequenceSummary from .configuration_transfo_xl import TransfoXLConfig from .modeling_transfo_xl_utilities import ProjectedAdaptiveLogSoftmax, sample_logits from .file_utils import add_start_docstrings logger = logging.getLogger(__name__) TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP = { 'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-pytorch_model.bin", } def build_tf_to_pytorch_map(model, config): """ A map of modules from TF to PyTorch. This time I use a map to keep the PyTorch model as identical to the original PyTorch model as possible. """ tf_to_pt_map = {} if hasattr(model, 'transformer'): # We are loading in a TransfoXLLMHeadModel => we will load also the Adaptive Softmax tf_to_pt_map.update({ "transformer/adaptive_softmax/cutoff_0/cluster_W": model.crit.cluster_weight, "transformer/adaptive_softmax/cutoff_0/cluster_b": model.crit.cluster_bias}) for i, (out_l, proj_l, tie_proj) in enumerate(zip( model.crit.out_layers, model.crit.out_projs, config.tie_projs)): layer_str = "transformer/adaptive_softmax/cutoff_%d/" % i if config.tie_weight: tf_to_pt_map.update({ layer_str + 'b': out_l.bias}) else: raise NotImplementedError # I don't think this is implemented in the TF code tf_to_pt_map.update({ layer_str + 'lookup_table': out_l.weight, layer_str + 'b': out_l.bias}) if not tie_proj: tf_to_pt_map.update({ layer_str + 'proj': proj_l }) # Now load the rest of the transformer model = model.transformer # Embeddings for i, (embed_l, proj_l) in enumerate(zip(model.word_emb.emb_layers, model.word_emb.emb_projs)): layer_str = "transformer/adaptive_embed/cutoff_%d/" % i tf_to_pt_map.update({ layer_str + 'lookup_table': embed_l.weight, layer_str + 'proj_W': proj_l }) # Transformer blocks for i, b in enumerate(model.layers): layer_str = "transformer/layer_%d/" % i tf_to_pt_map.update({ layer_str + "rel_attn/LayerNorm/gamma": b.dec_attn.layer_norm.weight, layer_str + "rel_attn/LayerNorm/beta": b.dec_attn.layer_norm.bias, layer_str + "rel_attn/o/kernel": b.dec_attn.o_net.weight, layer_str + "rel_attn/qkv/kernel": b.dec_attn.qkv_net.weight, layer_str + "rel_attn/r/kernel": b.dec_attn.r_net.weight, layer_str + "ff/LayerNorm/gamma": b.pos_ff.layer_norm.weight, layer_str + "ff/LayerNorm/beta": b.pos_ff.layer_norm.bias, layer_str + "ff/layer_1/kernel": b.pos_ff.CoreNet[0].weight, layer_str + "ff/layer_1/bias": b.pos_ff.CoreNet[0].bias, layer_str + "ff/layer_2/kernel": b.pos_ff.CoreNet[3].weight, layer_str + "ff/layer_2/bias": b.pos_ff.CoreNet[3].bias, }) # Relative positioning biases if config.untie_r: r_r_list = [] r_w_list = [] for b in model.layers: r_r_list.append(b.dec_attn.r_r_bias) r_w_list.append(b.dec_attn.r_w_bias) else: r_r_list = [model.r_r_bias] r_w_list = [model.r_w_bias] tf_to_pt_map.update({ 'transformer/r_r_bias': r_r_list, 'transformer/r_w_bias': r_w_list}) return tf_to_pt_map def load_tf_weights_in_transfo_xl(model, config, tf_path): """ Load tf checkpoints in a pytorch model """ try: import numpy as np import tensorflow as tf except ImportError: logger.error("Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see " "https://www.tensorflow.org/install/ for installation instructions.") raise # Build TF to PyTorch weights loading map tf_to_pt_map = build_tf_to_pytorch_map(model, config) # Load weights from TF model init_vars = tf.train.list_variables(tf_path) tf_weights = {} for name, shape in init_vars: logger.info("Loading TF weight {} with shape {}".format(name, shape)) array = tf.train.load_variable(tf_path, name) tf_weights[name] = array for name, pointer in tf_to_pt_map.items(): assert name in tf_weights array = tf_weights[name] # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v # which are not required for using pretrained model if 'kernel' in name or 'proj' in name: array = np.transpose(array) if ('r_r_bias' in name or 'r_w_bias' in name) and len(pointer) > 1: # Here we will split the TF weigths assert len(pointer) == array.shape[0] for i, p_i in enumerate(pointer): arr_i = array[i, ...] try: assert p_i.shape == arr_i.shape except AssertionError as e: e.args += (p_i.shape, arr_i.shape) raise logger.info("Initialize PyTorch weight {} for layer {}".format(name, i)) p_i.data = torch.from_numpy(arr_i) else: try: assert pointer.shape == array.shape except AssertionError as e: e.args += (pointer.shape, array.shape) raise logger.info("Initialize PyTorch weight {}".format(name)) pointer.data = torch.from_numpy(array) tf_weights.pop(name, None) tf_weights.pop(name + '/Adam', None) tf_weights.pop(name + '/Adam_1', None) logger.info("Weights not copied to PyTorch model: {}".format(', '.join(tf_weights.keys()))) return model class PositionalEmbedding(nn.Module): def __init__(self, demb): super(PositionalEmbedding, self).__init__() self.demb = demb inv_freq = 1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb)) self.register_buffer('inv_freq', inv_freq) def forward(self, pos_seq, bsz=None): sinusoid_inp = torch.ger(pos_seq, self.inv_freq) pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1) if bsz is not None: return pos_emb[:,None,:].expand(-1, bsz, -1) else: return pos_emb[:,None,:] class PositionwiseFF(nn.Module): def __init__(self, d_model, d_inner, dropout, pre_lnorm=False): super(PositionwiseFF, self).__init__() self.d_model = d_model self.d_inner = d_inner self.dropout = dropout self.CoreNet = nn.Sequential( nn.Linear(d_model, d_inner), nn.ReLU(inplace=True), nn.Dropout(dropout), nn.Linear(d_inner, d_model), nn.Dropout(dropout), ) self.layer_norm = nn.LayerNorm(d_model) self.pre_lnorm = pre_lnorm def forward(self, inp): if self.pre_lnorm: ##### layer normalization + positionwise feed-forward core_out = self.CoreNet(self.layer_norm(inp)) ##### residual connection output = core_out + inp else: ##### positionwise feed-forward core_out = self.CoreNet(inp) ##### residual connection + layer normalization output = self.layer_norm(inp + core_out) return output class MultiHeadAttn(nn.Module): def __init__(self, n_head, d_model, d_head, dropout, dropatt=0, pre_lnorm=False, r_r_bias=None, r_w_bias=None, output_attentions=False): super(MultiHeadAttn, self).__init__() self.output_attentions = output_attentions self.n_head = n_head self.d_model = d_model self.d_head = d_head self.dropout = dropout self.q_net = nn.Linear(d_model, n_head * d_head, bias=False) self.kv_net = nn.Linear(d_model, 2 * n_head * d_head, bias=False) self.drop = nn.Dropout(dropout) self.dropatt = nn.Dropout(dropatt) self.o_net = nn.Linear(n_head * d_head, d_model, bias=False) self.layer_norm = nn.LayerNorm(d_model) self.scale = 1 / (d_head ** 0.5) self.pre_lnorm = pre_lnorm if r_r_bias is None or r_w_bias is None: # Biases are not shared self.r_r_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head)) self.r_w_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head)) else: self.r_r_bias = r_r_bias self.r_w_bias = r_w_bias def forward(self, h, attn_mask=None, mems=None, head_mask=None): ##### multihead attention # [hlen x bsz x n_head x d_head] if mems is not None: c = torch.cat([mems, h], 0) else: c = h if self.pre_lnorm: ##### layer normalization c = self.layer_norm(c) head_q = self.q_net(h) head_k, head_v = torch.chunk(self.kv_net(c), 2, -1) head_q = head_q.view(h.size(0), h.size(1), self.n_head, self.d_head) head_k = head_k.view(c.size(0), c.size(1), self.n_head, self.d_head) head_v = head_v.view(c.size(0), c.size(1), self.n_head, self.d_head) # [qlen x klen x bsz x n_head] attn_score = torch.einsum('ibnd,jbnd->ijbn', (head_q, head_k)) attn_score.mul_(self.scale) if attn_mask is not None and torch.sum(attn_mask).item(): attn_mask = (attn_mask == 1) # Switch to bool if attn_mask.dim() == 2: attn_score.masked_fill_(attn_mask[None,:,:,None], -float('inf')) elif attn_mask.dim() == 3: attn_score.masked_fill_(attn_mask[:,:,:,None], -float('inf')) # [qlen x klen x bsz x n_head] attn_prob = F.softmax(attn_score, dim=1) attn_prob = self.dropatt(attn_prob) # Mask heads if we want to if head_mask is not None: attn_prob = attn_prob * head_mask # [qlen x klen x bsz x n_head] + [klen x bsz x n_head x d_head] -> [qlen x bsz x n_head x d_head] attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, head_v)) attn_vec = attn_vec.contiguous().view( attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head) ##### linear projection attn_out = self.o_net(attn_vec) attn_out = self.drop(attn_out) if self.pre_lnorm: ##### residual connection outputs = [h + attn_out] else: ##### residual connection + layer normalization outputs = [self.layer_norm(h + attn_out)] if self.output_attentions: outputs.append(attn_prob) return outputs class RelMultiHeadAttn(nn.Module): def __init__(self, n_head, d_model, d_head, dropout, dropatt=0, tgt_len=None, ext_len=None, mem_len=None, pre_lnorm=False, r_r_bias=None, r_w_bias=None, output_attentions=False): super(RelMultiHeadAttn, self).__init__() self.output_attentions = output_attentions self.n_head = n_head self.d_model = d_model self.d_head = d_head self.dropout = dropout self.qkv_net = nn.Linear(d_model, 3 * n_head * d_head, bias=False) self.drop = nn.Dropout(dropout) self.dropatt = nn.Dropout(dropatt) self.o_net = nn.Linear(n_head * d_head, d_model, bias=False) self.layer_norm = nn.LayerNorm(d_model) self.scale = 1 / (d_head ** 0.5) self.pre_lnorm = pre_lnorm if r_r_bias is None or r_w_bias is None: # Biases are not shared self.r_r_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head)) self.r_w_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head)) else: self.r_r_bias = r_r_bias self.r_w_bias = r_w_bias def _parallelogram_mask(self, h, w, left=False): mask = torch.ones((h, w)).byte() m = min(h, w) mask[:m,:m] = torch.triu(mask[:m,:m]) mask[-m:,-m:] = torch.tril(mask[-m:,-m:]) if left: return mask else: return mask.flip(0) def _shift(self, x, qlen, klen, mask, left=False): if qlen > 1: zero_pad = torch.zeros((x.size(0), qlen-1, x.size(2), x.size(3)), device=x.device, dtype=x.dtype) else: zero_pad = torch.zeros(0, device=x.device, dtype=x.dtype) if left: mask = mask.flip(1) x_padded = torch.cat([zero_pad, x], dim=1).expand(qlen, -1, -1, -1) else: x_padded = torch.cat([x, zero_pad], dim=1).expand(qlen, -1, -1, -1) x = x_padded.masked_select(mask[:,:,None,None]) \ .view(qlen, klen, x.size(2), x.size(3)) return x def _rel_shift(self, x, zero_triu=False): zero_pad_shape = (x.size(0), 1) + x.size()[2:] zero_pad = torch.zeros(zero_pad_shape, device=x.device, dtype=x.dtype) x_padded = torch.cat([zero_pad, x], dim=1) x_padded_shape = (x.size(1) + 1, x.size(0)) + x.size()[2:] x_padded = x_padded.view(*x_padded_shape) x = x_padded[1:].view_as(x) if zero_triu: ones = torch.ones((x.size(0), x.size(1))) x = x * torch.tril(ones, x.size(1) - x.size(0))[:,:,None,None] return x def forward(self, w, r, attn_mask=None, mems=None): raise NotImplementedError class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn): def __init__(self, *args, **kwargs): super(RelPartialLearnableMultiHeadAttn, self).__init__(*args, **kwargs) self.r_net = nn.Linear(self.d_model, self.n_head * self.d_head, bias=False) def forward(self, w, r, attn_mask=None, mems=None, head_mask=None): qlen, rlen, bsz = w.size(0), r.size(0), w.size(1) if mems is not None: cat = torch.cat([mems, w], 0) if self.pre_lnorm: w_heads = self.qkv_net(self.layer_norm(cat)) else: w_heads = self.qkv_net(cat) r_head_k = self.r_net(r) w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1) w_head_q = w_head_q[-qlen:] else: if self.pre_lnorm: w_heads = self.qkv_net(self.layer_norm(w)) else: w_heads = self.qkv_net(w) r_head_k = self.r_net(r) w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1) klen = w_head_k.size(0) w_head_q = w_head_q.view(qlen, bsz, self.n_head, self.d_head) # qlen x bsz x n_head x d_head w_head_k = w_head_k.view(klen, bsz, self.n_head, self.d_head) # qlen x bsz x n_head x d_head w_head_v = w_head_v.view(klen, bsz, self.n_head, self.d_head) # qlen x bsz x n_head x d_head r_head_k = r_head_k.view(rlen, self.n_head, self.d_head) # qlen x n_head x d_head #### compute attention score rw_head_q = w_head_q + self.r_w_bias # qlen x bsz x n_head x d_head AC = torch.einsum('ibnd,jbnd->ijbn', (rw_head_q, w_head_k)) # qlen x klen x bsz x n_head rr_head_q = w_head_q + self.r_r_bias BD = torch.einsum('ibnd,jnd->ijbn', (rr_head_q, r_head_k)) # qlen x klen x bsz x n_head BD = self._rel_shift(BD) # [qlen x klen x bsz x n_head] attn_score = AC + BD attn_score.mul_(self.scale) #### compute attention probability if attn_mask is not None and torch.sum(attn_mask).item(): attn_mask = (attn_mask == 1) # Switch to bool if attn_mask.dim() == 2: if next(self.parameters()).dtype == torch.float16: attn_score = attn_score.float().masked_fill( attn_mask[None,:,:,None], -65000).type_as(attn_score) else: attn_score = attn_score.float().masked_fill( attn_mask[None,:,:,None], -1e30).type_as(attn_score) elif attn_mask.dim() == 3: if next(self.parameters()).dtype == torch.float16: attn_score = attn_score.float().masked_fill( attn_mask[:,:,:,None], -65000).type_as(attn_score) else: attn_score = attn_score.float().masked_fill( attn_mask[:,:,:,None], -1e30).type_as(attn_score) # [qlen x klen x bsz x n_head] attn_prob = F.softmax(attn_score, dim=1) attn_prob = self.dropatt(attn_prob) # Mask heads if we want to if head_mask is not None: attn_prob = attn_prob * head_mask #### compute attention vector attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, w_head_v)) # [qlen x bsz x n_head x d_head] attn_vec = attn_vec.contiguous().view( attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head) ##### linear projection attn_out = self.o_net(attn_vec) attn_out = self.drop(attn_out) if self.pre_lnorm: ##### residual connection outputs = [w + attn_out] else: ##### residual connection + layer normalization outputs = [self.layer_norm(w + attn_out)] if self.output_attentions: outputs.append(attn_prob) return outputs class RelLearnableMultiHeadAttn(RelMultiHeadAttn): def __init__(self, *args, **kwargs): super(RelLearnableMultiHeadAttn, self).__init__(*args, **kwargs) def forward(self, w, r_emb, r_w_bias, r_bias, attn_mask=None, mems=None, head_mask=None): # r_emb: [klen, n_head, d_head], used for term B # r_w_bias: [n_head, d_head], used for term C # r_bias: [klen, n_head], used for term D qlen, bsz = w.size(0), w.size(1) if mems is not None: cat = torch.cat([mems, w], 0) if self.pre_lnorm: w_heads = self.qkv_net(self.layer_norm(cat)) else: w_heads = self.qkv_net(cat) w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1) w_head_q = w_head_q[-qlen:] else: if self.pre_lnorm: w_heads = self.qkv_net(self.layer_norm(w)) else: w_heads = self.qkv_net(w) w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1) klen = w_head_k.size(0) w_head_q = w_head_q.view(qlen, bsz, self.n_head, self.d_head) w_head_k = w_head_k.view(klen, bsz, self.n_head, self.d_head) w_head_v = w_head_v.view(klen, bsz, self.n_head, self.d_head) if klen > r_emb.size(0): r_emb_pad = r_emb[0:1].expand(klen-r_emb.size(0), -1, -1) r_emb = torch.cat([r_emb_pad, r_emb], 0) r_bias_pad = r_bias[0:1].expand(klen-r_bias.size(0), -1) r_bias = torch.cat([r_bias_pad, r_bias], 0) else: r_emb = r_emb[-klen:] r_bias = r_bias[-klen:] #### compute attention score rw_head_q = w_head_q + r_w_bias[None] # qlen x bsz x n_head x d_head AC = torch.einsum('ibnd,jbnd->ijbn', (rw_head_q, w_head_k)) # qlen x klen x bsz x n_head B_ = torch.einsum('ibnd,jnd->ijbn', (w_head_q, r_emb)) # qlen x klen x bsz x n_head D_ = r_bias[None, :, None] # 1 x klen x 1 x n_head BD = self._rel_shift(B_ + D_) # [qlen x klen x bsz x n_head] attn_score = AC + BD attn_score.mul_(self.scale) #### compute attention probability if attn_mask is not None and torch.sum(attn_mask).item(): attn_mask = (attn_mask == 1) # Switch to bool if attn_mask.dim() == 2: attn_score.masked_fill_(attn_mask[None,:,:,None], -float('inf')) elif attn_mask.dim() == 3: attn_score.masked_fill_(attn_mask[:,:,:,None], -float('inf')) # [qlen x klen x bsz x n_head] attn_prob = F.softmax(attn_score, dim=1) attn_prob = self.dropatt(attn_prob) if head_mask is not None: attn_prob = attn_prob * head_mask #### compute attention vector attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, w_head_v)) # [qlen x bsz x n_head x d_head] attn_vec = attn_vec.contiguous().view( attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head) ##### linear projection attn_out = self.o_net(attn_vec) attn_out = self.drop(attn_out) if self.pre_lnorm: ##### residual connection outputs = [w + attn_out] else: ##### residual connection + layer normalization outputs = [self.layer_norm(w + attn_out)] if self.output_attentions: outputs.append(attn_prob) return outputs class DecoderLayer(nn.Module): def __init__(self, n_head, d_model, d_head, d_inner, dropout, **kwargs): super(DecoderLayer, self).__init__() self.dec_attn = MultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs) self.pos_ff = PositionwiseFF(d_model, d_inner, dropout, pre_lnorm=kwargs.get('pre_lnorm')) def forward(self, dec_inp, dec_attn_mask=None, mems=None, head_mask=None): attn_outputs = self.dec_attn(dec_inp, attn_mask=dec_attn_mask, mems=mems, head_mask=head_mask) ff_output = self.pos_ff(attn_outputs[0]) outputs = [ff_output] + attn_outputs[1:] return outputs class RelLearnableDecoderLayer(nn.Module): def __init__(self, n_head, d_model, d_head, d_inner, dropout, **kwargs): super(RelLearnableDecoderLayer, self).__init__() self.dec_attn = RelLearnableMultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs) self.pos_ff = PositionwiseFF(d_model, d_inner, dropout, pre_lnorm=kwargs.get('pre_lnorm')) def forward(self, dec_inp, r_emb, r_w_bias, r_bias, dec_attn_mask=None, mems=None, head_mask=None): attn_outputs = self.dec_attn(dec_inp, r_emb, r_w_bias, r_bias, attn_mask=dec_attn_mask, mems=mems, head_mask=head_mask) ff_output = self.pos_ff(attn_outputs[0]) outputs = [ff_output] + attn_outputs[1:] return outputs class RelPartialLearnableDecoderLayer(nn.Module): def __init__(self, n_head, d_model, d_head, d_inner, dropout, **kwargs): super(RelPartialLearnableDecoderLayer, self).__init__() self.dec_attn = RelPartialLearnableMultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs) self.pos_ff = PositionwiseFF(d_model, d_inner, dropout, pre_lnorm=kwargs.get('pre_lnorm')) def forward(self, dec_inp, r, dec_attn_mask=None, mems=None, head_mask=None): attn_outputs = self.dec_attn(dec_inp, r, attn_mask=dec_attn_mask, mems=mems, head_mask=head_mask) ff_output = self.pos_ff(attn_outputs[0]) outputs = [ff_output] + attn_outputs[1:] return outputs class AdaptiveEmbedding(nn.Module): def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1, sample_softmax=False): super(AdaptiveEmbedding, self).__init__() self.n_token = n_token self.d_embed = d_embed self.cutoffs = cutoffs + [n_token] self.div_val = div_val self.d_proj = d_proj self.emb_scale = d_proj ** 0.5 self.cutoff_ends = [0] + self.cutoffs self.emb_layers = nn.ModuleList() self.emb_projs = nn.ParameterList() if div_val == 1: self.emb_layers.append( nn.Embedding(n_token, d_embed, sparse=sample_softmax>0) ) if d_proj != d_embed: self.emb_projs.append(nn.Parameter(torch.FloatTensor(d_proj, d_embed))) else: for i in range(len(self.cutoffs)): l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i+1] d_emb_i = d_embed // (div_val ** i) self.emb_layers.append(nn.Embedding(r_idx-l_idx, d_emb_i)) self.emb_projs.append(nn.Parameter(torch.FloatTensor(d_proj, d_emb_i))) def forward(self, inp): if self.div_val == 1: embed = self.emb_layers[0](inp) if self.d_proj != self.d_embed: embed = F.linear(embed, self.emb_projs[0]) else: param = next(self.parameters()) inp_flat = inp.view(-1) emb_flat = torch.zeros([inp_flat.size(0), self.d_proj], dtype=param.dtype, device=param.device) for i in range(len(self.cutoffs)): l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1] mask_i = (inp_flat >= l_idx) & (inp_flat < r_idx) indices_i = mask_i.nonzero().squeeze() if indices_i.numel() == 0: continue inp_i = inp_flat.index_select(0, indices_i) - l_idx emb_i = self.emb_layers[i](inp_i) emb_i = F.linear(emb_i, self.emb_projs[i]) emb_flat.index_copy_(0, indices_i, emb_i) embed_shape = inp.size() + (self.d_proj,) embed = emb_flat.view(embed_shape) embed.mul_(self.emb_scale) return embed class TransfoXLPreTrainedModel(PreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for dowloading and loading pretrained models. """ config_class = TransfoXLConfig pretrained_model_archive_map = TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP load_tf_weights = load_tf_weights_in_transfo_xl base_model_prefix = "transformer" def _init_weight(self, weight): if self.config.init == 'uniform': nn.init.uniform_(weight, -self.config.init_range, self.config.init_range) elif self.config.init == 'normal': nn.init.normal_(weight, 0.0, self.config.init_std) def _init_bias(self, bias): nn.init.constant_(bias, 0.0) def _init_weights(self, m): """ Initialize the weights. """ classname = m.__class__.__name__ if classname.find('Linear') != -1: if hasattr(m, 'weight') and m.weight is not None: self._init_weight(m.weight) if hasattr(m, 'bias') and m.bias is not None: self._init_bias(m.bias) elif classname.find('AdaptiveEmbedding') != -1: if hasattr(m, 'emb_projs'): for i in range(len(m.emb_projs)): if m.emb_projs[i] is not None: nn.init.normal_(m.emb_projs[i], 0.0, self.config.proj_init_std) elif classname.find('Embedding') != -1: if hasattr(m, 'weight'): self._init_weight(m.weight) elif classname.find('ProjectedAdaptiveLogSoftmax') != -1: if hasattr(m, 'cluster_weight') and m.cluster_weight is not None: self._init_weight(m.cluster_weight) if hasattr(m, 'cluster_bias') and m.cluster_bias is not None: self._init_bias(m.cluster_bias) if hasattr(m, 'out_projs'): for i in range(len(m.out_projs)): if m.out_projs[i] is not None: nn.init.normal_(m.out_projs[i], 0.0, self.config.proj_init_std) elif classname.find('LayerNorm') != -1: if hasattr(m, 'weight'): nn.init.normal_(m.weight, 1.0, self.config.init_std) if hasattr(m, 'bias') and m.bias is not None: self._init_bias(m.bias) else: if hasattr(m, 'r_emb'): self._init_weight(m.r_emb) if hasattr(m, 'r_w_bias'): self._init_weight(m.r_w_bias) if hasattr(m, 'r_r_bias'): self._init_weight(m.r_r_bias) if hasattr(m, 'r_bias'): self._init_bias(m.r_bias) def set_num_special_tokens(self, num_special_tokens): pass TRANSFO_XL_START_DOCSTRING = r""" The Transformer-XL model was proposed in `Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context`_ by Zihang Dai*, Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov. It's a causal (uni-directional) transformer with relative positioning (sinusoïdal) embeddings which can reuse previously computed hidden-states to attend to longer context (memory). This model also uses adaptive softmax inputs and outputs (tied). This model is a PyTorch `torch.nn.Module`_ sub-class. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior. .. _`Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context`: https://arxiv.org/abs/1901.02860 .. _`torch.nn.Module`: https://pytorch.org/docs/stable/nn.html#module Parameters: config (:class:`~pytorch_transformers.TransfoXLConfig`): Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the :meth:`~pytorch_transformers.PreTrainedModel.from_pretrained` method to load the model weights. """ TRANSFO_XL_INPUTS_DOCSTRING = r""" Inputs: **input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``: Indices of input sequence tokens in the vocabulary. Transformer-XL is a model with relative position embeddings so you can either pad the inputs on the right or on the left. Indices can be obtained using :class:`pytorch_transformers.TransfoXLTokenizer`. See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and :func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details. **mems**: (`optional`) list of ``torch.FloatTensor`` (one for each layer): that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model (see `mems` output below). Can be used to speed up sequential decoding and attend to longer context. **head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``: Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``: ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**. """ @add_start_docstrings("The bare Bert Model transformer outputting raw hidden-states without any specific head on top.", TRANSFO_XL_START_DOCSTRING, TRANSFO_XL_INPUTS_DOCSTRING) class TransfoXLModel(TransfoXLPreTrainedModel): r""" Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: **last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)`` Sequence of hidden-states at the last layer of the model. **mems**: list of ``torch.FloatTensor`` (one for each layer): that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model (see `mems` input above). Can be used to speed up sequential decoding and attend to longer context. **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``) list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings) of shape ``(batch_size, sequence_length, hidden_size)``: Hidden-states of the model at the output of each layer plus the initial embedding outputs. **attentions**: (`optional`, returned when ``config.output_attentions=True``) list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. Examples:: tokenizer = TransfoXLTokenizer.from_pretrained('transfo-xl-wt103') model = TransfoXLModel.from_pretrained('transfo-xl-wt103') input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1 outputs = model(input_ids) last_hidden_states, mems = outputs[:2] """ def __init__(self, config): super(TransfoXLModel, self).__init__(config) self.output_attentions = config.output_attentions self.output_hidden_states = config.output_hidden_states self.n_token = config.n_token self.d_embed = config.d_embed self.d_model = config.d_model self.n_head = config.n_head self.d_head = config.d_head self.word_emb = AdaptiveEmbedding(config.n_token, config.d_embed, config.d_model, config.cutoffs, div_val=config.div_val) self.drop = nn.Dropout(config.dropout) self.n_layer = config.n_layer self.tgt_len = config.tgt_len self.mem_len = config.mem_len self.ext_len = config.ext_len self.max_klen = config.tgt_len + config.ext_len + config.mem_len self.attn_type = config.attn_type if not config.untie_r: self.r_w_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head)) self.r_r_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head)) self.layers = nn.ModuleList() if config.attn_type == 0: # the default attention for i in range(config.n_layer): self.layers.append( RelPartialLearnableDecoderLayer( config.n_head, config.d_model, config.d_head, config.d_inner, config.dropout, tgt_len=config.tgt_len, ext_len=config.ext_len, mem_len=config.mem_len, dropatt=config.dropatt, pre_lnorm=config.pre_lnorm, r_w_bias=None if config.untie_r else self.r_w_bias, r_r_bias=None if config.untie_r else self.r_r_bias, output_attentions=self.output_attentions) ) elif config.attn_type == 1: # learnable embeddings for i in range(config.n_layer): self.layers.append( RelLearnableDecoderLayer( config.n_head, config.d_model, config.d_head, config.d_inner, config.dropout, tgt_len=config.tgt_len, ext_len=config.ext_len, mem_len=config.mem_len, dropatt=config.dropatt, pre_lnorm=config.pre_lnorm, r_w_bias=None if config.untie_r else self.r_w_bias, r_r_bias=None if config.untie_r else self.r_r_bias, output_attentions=self.output_attentions) ) elif config.attn_type in [2, 3]: # absolute embeddings for i in range(config.n_layer): self.layers.append( DecoderLayer( config.n_head, config.d_model, config.d_head, config.d_inner, config.dropout, dropatt=config.dropatt, pre_lnorm=config.pre_lnorm, r_w_bias=None if config.untie_r else self.r_w_bias, r_r_bias=None if config.untie_r else self.r_r_bias, output_attentions=self.output_attentions) ) self.same_length = config.same_length self.clamp_len = config.clamp_len if self.attn_type == 0: # default attention self.pos_emb = PositionalEmbedding(self.d_model) elif self.attn_type == 1: # learnable self.r_emb = nn.Parameter(torch.FloatTensor( self.n_layer, self.max_klen, self.n_head, self.d_head)) self.r_bias = nn.Parameter(torch.FloatTensor( self.n_layer, self.max_klen, self.n_head)) elif self.attn_type == 2: # absolute standard self.pos_emb = PositionalEmbedding(self.d_model) elif self.attn_type == 3: # absolute deeper SA self.r_emb = nn.Parameter(torch.FloatTensor( self.n_layer, self.max_klen, self.n_head, self.d_head)) self.init_weights() def _resize_token_embeddings(self, new_num_tokens): return self.word_emb def backward_compatible(self): self.sample_softmax = -1 def reset_length(self, tgt_len, ext_len, mem_len): self.tgt_len = tgt_len self.mem_len = mem_len self.ext_len = ext_len def _prune_heads(self, heads): logger.info("Head pruning is not implemented for Transformer-XL model") pass def init_mems(self, data): if self.mem_len > 0: mems = [] param = next(self.parameters()) for i in range(self.n_layer): empty = torch.zeros(self.mem_len, data.size(1), self.config.d_model, dtype=param.dtype, device=param.device) mems.append(empty) return mems else: return None def _update_mems(self, hids, mems, qlen, mlen): # does not deal with None if mems is None: return None # mems is not None assert len(hids) == len(mems), 'len(hids) != len(mems)' # There are `mlen + qlen` steps that can be cached into mems # For the next step, the last `ext_len` of the `qlen` tokens # will be used as the extended context. Hence, we only cache # the tokens from `mlen + qlen - self.ext_len - self.mem_len` # to `mlen + qlen - self.ext_len`. with torch.no_grad(): new_mems = [] end_idx = mlen + max(0, qlen - 0 - self.ext_len) beg_idx = max(0, end_idx - self.mem_len) for i in range(len(hids)): cat = torch.cat([mems[i], hids[i]], dim=0) new_mems.append(cat[beg_idx:end_idx].detach()) return new_mems def _forward(self, dec_inp, mems=None, head_mask=None): qlen, bsz = dec_inp.size() # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head # attention_probs has shape bsz x n_heads x N x N # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] (a head_mask for each layer) # and head_mask is converted to shape [num_hidden_layers x qlen x klen x bsz x n_head] if head_mask is not None: if head_mask.dim() == 1: head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(0).unsqueeze(0) head_mask = head_mask.expand(self.n_layer, -1, -1, -1, -1) elif head_mask.dim() == 2: head_mask = head_mask.unsqueeze(1).unsqueeze(1).unsqueeze(1) head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to fload if need + fp16 compatibility else: head_mask = [None] * self.n_layer word_emb = self.word_emb(dec_inp) mlen = mems[0].size(0) if mems is not None else 0 klen = mlen + qlen if self.same_length: all_ones = word_emb.new_ones((qlen, klen), dtype=torch.uint8) mask_len = klen - self.mem_len if mask_len > 0: mask_shift_len = qlen - mask_len else: mask_shift_len = qlen dec_attn_mask = (torch.triu(all_ones, 1+mlen) + torch.tril(all_ones, -mask_shift_len))[:, :, None] # -1 else: dec_attn_mask = torch.triu( word_emb.new_ones((qlen, klen), dtype=torch.uint8), diagonal=1+mlen)[:,:,None] hids = [] attentions = [] if self.attn_type == 0: # default pos_seq = torch.arange(klen-1, -1, -1.0, device=word_emb.device, dtype=word_emb.dtype) if self.clamp_len > 0: pos_seq.clamp_(max=self.clamp_len) pos_emb = self.pos_emb(pos_seq) core_out = self.drop(word_emb) pos_emb = self.drop(pos_emb) for i, layer in enumerate(self.layers): hids.append(core_out) mems_i = None if mems is None else mems[i] layer_outputs = layer(core_out, pos_emb, dec_attn_mask=dec_attn_mask, mems=mems_i, head_mask=head_mask[i]) core_out = layer_outputs[0] if self.output_attentions: attentions.append(layer_outputs[1]) elif self.attn_type == 1: # learnable core_out = self.drop(word_emb) for i, layer in enumerate(self.layers): hids.append(core_out) if self.clamp_len > 0: r_emb = self.r_emb[i][-self.clamp_len :] r_bias = self.r_bias[i][-self.clamp_len :] else: r_emb, r_bias = self.r_emb[i], self.r_bias[i] mems_i = None if mems is None else mems[i] layer_outputs = layer(core_out, r_emb, self.r_w_bias[i], r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i, head_mask=head_mask[i]) core_out = layer_outputs[0] if self.output_attentions: attentions.append(layer_outputs[1]) elif self.attn_type == 2: # absolute pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.device, dtype=word_emb.dtype) if self.clamp_len > 0: pos_seq.clamp_(max=self.clamp_len) pos_emb = self.pos_emb(pos_seq) core_out = self.drop(word_emb + pos_emb[-qlen:]) for i, layer in enumerate(self.layers): hids.append(core_out) mems_i = None if mems is None else mems[i] if mems_i is not None and i == 0: mems_i += pos_emb[:mlen] layer_outputs = layer(core_out, dec_attn_mask=dec_attn_mask, mems=mems_i, head_mask=head_mask[i]) core_out = layer_outputs[0] if self.output_attentions: attentions.append(layer_outputs[1]) elif self.attn_type == 3: core_out = self.drop(word_emb) for i, layer in enumerate(self.layers): hids.append(core_out) mems_i = None if mems is None else mems[i] if mems_i is not None and mlen > 0: cur_emb = self.r_emb[i][:-qlen] cur_size = cur_emb.size(0) if cur_size < mlen: cur_emb_pad = cur_emb[0:1].expand(mlen-cur_size, -1, -1) cur_emb = torch.cat([cur_emb_pad, cur_emb], 0) else: cur_emb = cur_emb[-mlen:] mems_i += cur_emb.view(mlen, 1, -1) core_out += self.r_emb[i][-qlen:].view(qlen, 1, -1) layer_outputs = layer(core_out, dec_attn_mask=dec_attn_mask, mems=mems_i, head_mask=head_mask[i]) core_out = layer_outputs[0] if self.output_attentions: attentions.append(layer_outputs[1]) core_out = self.drop(core_out) new_mems = self._update_mems(hids, mems, mlen, qlen) # We transpose back here to shape [bsz, len, hidden_dim] outputs = [core_out.transpose(0, 1).contiguous(), new_mems] if self.output_hidden_states: # Add last layer and transpose to library standard shape [bsz, len, hidden_dim] hids.append(core_out) hids = list(t.transpose(0, 1).contiguous() for t in hids) outputs.append(hids) if self.output_attentions: # Transpose to library standard shape [bsz, n_heads, query_seq_len, key_seq_len] attentions = list(t.permute(2, 3, 0, 1).contiguous() for t in attentions) outputs.append(attentions) return outputs # last hidden state, new_mems, (all hidden states), (all attentions) def forward(self, input_ids, mems=None, head_mask=None): # the original code for Transformer-XL used shapes [len, bsz] but we want a unified interface in the library # so we transpose here from shape [bsz, len] to shape [len, bsz] input_ids = input_ids.transpose(0, 1).contiguous() if mems is None: mems = self.init_mems(input_ids) outputs = self._forward(input_ids, mems=mems, head_mask=head_mask) return outputs # last hidden state, new_mems, (all hidden states), (all attentions) @add_start_docstrings("""The Transformer-XL Model with a language modeling head on top (adaptive softmax with weights tied to the adaptive input embeddings)""", TRANSFO_XL_START_DOCSTRING, TRANSFO_XL_INPUTS_DOCSTRING) class TransfoXLLMHeadModel(TransfoXLPreTrainedModel): r""" **lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``: Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set ``lm_labels = input_ids`` Indices are selected in ``[-1, 0, ..., config.vocab_size]`` All labels set to ``-1`` are ignored (masked), the loss is only computed for labels in ``[0, ..., config.vocab_size]`` Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: **loss**: (`optional`, returned when ``lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``: Language modeling loss. **prediction_scores**: ``None`` if ``lm_labels`` is provided else ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)`` Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). We don't output them when the loss is computed to speedup adaptive softmax decoding. **mems**: list of ``torch.FloatTensor`` (one for each layer): that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model (see `mems` input above). Can be used to speed up sequential decoding and attend to longer context. **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``) list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings) of shape ``(batch_size, sequence_length, hidden_size)``: Hidden-states of the model at the output of each layer plus the initial embedding outputs. **attentions**: (`optional`, returned when ``config.output_attentions=True``) list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. Examples:: tokenizer = TransfoXLTokenizer.from_pretrained('transfo-xl-wt103') model = TransfoXLLMHeadModel.from_pretrained('transfo-xl-wt103') input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1 outputs = model(input_ids) prediction_scores, mems = outputs[:2] """ def __init__(self, config): super(TransfoXLLMHeadModel, self).__init__(config) self.transformer = TransfoXLModel(config) self.sample_softmax = config.sample_softmax # use sampled softmax if config.sample_softmax > 0: self.out_layer = nn.Linear(config.d_model, config.n_token) self.sampler = LogUniformSampler(config.n_token, config.sample_softmax) # use adaptive softmax (including standard softmax) else: self.crit = ProjectedAdaptiveLogSoftmax(config.n_token, config.d_embed, config.d_model, config.cutoffs, div_val=config.div_val) self.init_weights() self.tie_weights() def tie_weights(self): """ Run this to be sure output and input (adaptive) softmax weights are tied """ # sampled softmax if self.sample_softmax > 0: if self.config.tie_weight: self.out_layer.weight = self.transformer.word_emb.weight # adaptive softmax (including standard softmax) else: if self.config.tie_weight: for i in range(len(self.crit.out_layers)): self._tie_or_clone_weights(self.crit.out_layers[i], self.transformer.word_emb.emb_layers[i]) if self.config.tie_projs: for i, tie_proj in enumerate(self.config.tie_projs): if tie_proj and self.config.div_val == 1 and self.config.d_model != self.config.d_embed: if self.config.torchscript: self.crit.out_projs[i] = nn.Parameter(self.transformer.word_emb.emb_projs[0].clone()) else: self.crit.out_projs[i] = self.transformer.word_emb.emb_projs[0] elif tie_proj and self.config.div_val != 1: if self.config.torchscript: self.crit.out_projs[i] = nn.Parameter(self.transformer.word_emb.emb_projs[i].clone()) else: self.crit.out_projs[i] = self.transformer.word_emb.emb_projs[i] def reset_length(self, tgt_len, ext_len, mem_len): self.transformer.reset_length(tgt_len, ext_len, mem_len) def init_mems(self, data): return self.transformer.init_mems(data) def forward(self, input_ids, mems=None, head_mask=None, labels=None): bsz = input_ids.size(0) tgt_len = input_ids.size(1) transformer_outputs = self.transformer(input_ids, mems=mems, head_mask=head_mask) last_hidden = transformer_outputs[0] pred_hid = last_hidden[:, -tgt_len:] outputs = transformer_outputs[1:] if self.sample_softmax > 0 and self.training: assert self.config.tie_weight logit = sample_logits(self.transformer.word_emb, self.out_layer.bias, labels, pred_hid, self.sampler) softmax_output = -F.log_softmax(logit, -1)[:, :, 0] outputs = [softmax_output] + outputs if labels is not None: # TODO: This is not implemented raise NotImplementedError else: softmax_output = self.crit(pred_hid.view(-1, pred_hid.size(-1)), labels) if labels is None: softmax_output = softmax_output.view(bsz, tgt_len, -1) outputs = [softmax_output] + outputs else: softmax_output = softmax_output.view(bsz, tgt_len) outputs = [softmax_output, None] + outputs return outputs # (loss), logits or None if labels is not None (speed up adaptive softmax), new_mems, (all hidden states), (all attentions)