# XML_RVMR/baselines/excl/model_components.py
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
class DepthwiseSeparableConv(nn.Module):
"""
Depth-wise separable convolution uses less parameters to generate output by convolution.
:Examples:
>>> m = DepthwiseSeparableConv(300, 200, 5, dim=1)
>>> input_tensor = torch.randn(32, 300, 20)
>>> output = m(input_tensor)
"""
def __init__(self, in_ch, out_ch, k, dim=1, relu=True):
"""
:param in_ch: input hidden dimension size
:param out_ch: output hidden dimension size
:param k: kernel size
:param dim: default 1. 1D conv or 2D conv
"""
super(DepthwiseSeparableConv, self).__init__()
self.relu = relu
if dim == 1:
self.depthwise_conv = nn.Conv1d(in_channels=in_ch, out_channels=in_ch,
kernel_size=k, groups=in_ch, padding=k//2)
self.pointwise_conv = nn.Conv1d(in_channels=in_ch, out_channels=out_ch,
kernel_size=1, padding=0)
elif dim == 2:
self.depthwise_conv = nn.Conv2d(in_channels=in_ch, out_channels=in_ch,
kernel_size=k, groups=in_ch, padding=k//2)
self.pointwise_conv = nn.Conv2d(in_channels=in_ch, out_channels=out_ch,
kernel_size=1, padding=0)
else:
            raise ValueError("Incorrect dimension! dim must be 1 or 2, got {}".format(dim))
def forward(self, x):
"""
:Input: (N, L_in, D)
:Output: (N, L_out, D)
"""
x = x.transpose(1, 2)
if self.relu:
out = F.relu(self.pointwise_conv(self.depthwise_conv(x)), inplace=True)
else:
out = self.pointwise_conv(self.depthwise_conv(x))
return out.transpose(1, 2) # (N, L, D)
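# Illustrative usage sketch (added for exposition, not part of the original module):
# the layer takes and returns (batch, length, channels) tensors and transposes
# internally, so the last input dimension must match in_ch.
def _demo_depthwise_separable_conv():
    conv = DepthwiseSeparableConv(in_ch=300, out_ch=200, k=5, dim=1)
    x = torch.randn(32, 20, 300)      # (N, L, D_in)
    y = conv(x)                       # (N, L, D_out)
    assert y.shape == (32, 20, 200)   # L preserved since k is odd and padding=k//2
    return y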
class ConvEncoder(nn.Module):
def __init__(self, kernel_size=7, n_filters=128, dropout=0.1):
super(ConvEncoder, self).__init__()
self.dropout = nn.Dropout(dropout)
self.layer_norm = nn.LayerNorm(n_filters)
self.conv = DepthwiseSeparableConv(in_ch=n_filters, out_ch=n_filters, k=kernel_size, relu=True)
def forward(self, x, mask):
"""
:param x: (N, L, D)
:param mask: (N, L), is not used.
:return: (N, L, D)
"""
return self.layer_norm(self.dropout(self.conv(x)) + x) # (N, L, D)
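# Sketch of how ConvEncoder might be called (assumed usage): the mask argument is
# accepted for interface uniformity but ignored, and the residual + LayerNorm keep
# the input shape.
def _demo_conv_encoder():
    enc = ConvEncoder(kernel_size=7, n_filters=128, dropout=0.1)
    x = torch.randn(4, 30, 128)   # (N, L, D) with D == n_filters
    mask = torch.ones(4, 30)      # (N, L); unused by this encoder
    out = enc(x, mask)
    assert out.shape == x.shape
    return out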
class TrainablePositionalEncoding(nn.Module):
"""Construct the embeddings from word, position and token_type embeddings.
"""
def __init__(self, max_position_embeddings, hidden_size, dropout=0.1):
super(TrainablePositionalEncoding, self).__init__()
self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size)
self.LayerNorm = nn.LayerNorm(hidden_size)
self.dropout = nn.Dropout(dropout)
def forward(self, input_feat):
"""
Args:
input_feat: (N, L, D)
"""
bsz, seq_length = input_feat.shape[:2]
position_ids = torch.arange(seq_length, dtype=torch.long, device=input_feat.device)
position_ids = position_ids.unsqueeze(0).repeat(bsz, 1) # (N, L)
position_embeddings = self.position_embeddings(position_ids)
embeddings = self.LayerNorm(input_feat + position_embeddings)
embeddings = self.dropout(embeddings)
return embeddings
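# Minimal sketch (illustrative only): learned position embeddings are added to a
# feature sequence; max_position_embeddings must be >= the runtime sequence length.
def _demo_trainable_positional_encoding():
    pos_enc = TrainablePositionalEncoding(max_position_embeddings=100, hidden_size=128)
    feat = torch.randn(4, 30, 128)   # (N, L, D)
    out = pos_enc(feat)
    assert out.shape == feat.shape
    return out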
class PositionEncoding(nn.Module):
"""
Add positional information to input tensor.
:Examples:
>>> model = PositionEncoding(n_filters=6, max_len=10)
>>> test_input1 = torch.zeros(3, 10, 6)
>>> output1 = model(test_input1)
>>> output1.size()
>>> test_input2 = torch.zeros(5, 3, 9, 6)
>>> output2 = model(test_input2)
>>> output2.size()
"""
def __init__(self, n_filters=128, max_len=500, pe_type="cosine"):
"""
:param n_filters: same with input hidden size
:param max_len: maximum sequence length
:param pe_type: cosine or linear or None
"""
super(PositionEncoding, self).__init__()
self.pe_type = pe_type
if pe_type != "none":
position = torch.arange(0, max_len).float().unsqueeze(1)
if pe_type == "cosine":
# Compute the positional encodings once in log space.
pe = torch.zeros(max_len, n_filters) # (L, D)
div_term = torch.exp(torch.arange(0, n_filters, 2).float() * - (math.log(10000.0) / n_filters))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
elif pe_type == "linear":
pe = position / max_len
else:
                raise ValueError("Unknown pe_type: {}".format(pe_type))
            self.register_buffer("pe", pe)  # buffer, not a learnable parameter; (L, D) for cosine, (L, 1) for linear
def forward(self, x):
"""
:Input: (*, L, D)
:Output: (*, L, D) the same size as input
"""
if self.pe_type != "none":
            pe = self.pe.data[:x.size(-2), :]  # (x.size(-2), D)
extra_dim = len(x.size()) - 2
for _ in range(extra_dim):
pe = pe.unsqueeze(0)
x = x + pe
return x
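# Sketch of the fixed (non-trainable) encodings (illustrative only): the buffer is
# sliced to the sequence length and broadcast over any leading dimensions, as in
# the class docstring examples.
def _demo_position_encoding():
    pe = PositionEncoding(n_filters=6, max_len=10, pe_type="cosine")
    out1 = pe(torch.zeros(3, 10, 6))     # (N, L, D)
    out2 = pe(torch.zeros(5, 3, 9, 6))   # an extra leading dimension also works
    assert out1.shape == (3, 10, 6) and out2.shape == (5, 3, 9, 6)
    return out1, out2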
class LinearLayer(nn.Module):
"""linear layer configurable with layer normalization, dropout, ReLU."""
def __init__(self, in_hsz, out_hsz, layer_norm=True, dropout=0.1, relu=True):
super(LinearLayer, self).__init__()
self.relu = relu
self.layer_norm = layer_norm
if layer_norm:
self.LayerNorm = nn.LayerNorm(in_hsz)
layers = [
nn.Dropout(dropout),
nn.Linear(in_hsz, out_hsz)
]
self.net = nn.Sequential(*layers)
def forward(self, x):
"""(N, L, D)"""
if self.layer_norm:
x = self.LayerNorm(x)
x = self.net(x)
if self.relu:
x = F.relu(x, inplace=True)
return x # (N, L, D)
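# Sketch (illustrative only): LinearLayer applies optional LayerNorm, then
# dropout + Linear, then optional ReLU, mapping (N, L, D_in) to (N, L, D_out).
def _demo_linear_layer():
    proj = LinearLayer(in_hsz=300, out_hsz=128, layer_norm=True, dropout=0.1, relu=True)
    x = torch.randn(4, 30, 300)
    y = proj(x)
    assert y.shape == (4, 30, 128)
    return y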
bert_config = dict(
hidden_size=768,
intermediate_size=768,
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
num_attention_heads=4,
)
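# Note: the Bert* modules below read the config through attribute access
# (config.hidden_size, config.num_attention_heads, ...), so this plain dict must
# be wrapped before use. A minimal stdlib sketch follows (an assumption for
# illustration; the repository may use a different wrapper such as easydict):
def _make_bert_config(**overrides):
    from types import SimpleNamespace
    cfg = dict(bert_config, **overrides)   # copy the defaults, then apply overrides
    return SimpleNamespace(**cfg)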
class BertLayer(nn.Module):
def __init__(self, config, use_self_attention=True):
super(BertLayer, self).__init__()
self.use_self_attention = use_self_attention
if use_self_attention:
self.attention = BertAttention(config)
self.intermediate = BertIntermediate(config)
self.output = BertOutput(config)
def forward(self, hidden_states, attention_mask):
"""
Args:
hidden_states: (N, L, D)
attention_mask: (N, L) with 1 indicate valid, 0 indicates invalid
Returns:
"""
if self.use_self_attention:
attention_output = self.attention(hidden_states, attention_mask)
else:
attention_output = hidden_states
intermediate_output = self.intermediate(attention_output)
layer_output = self.output(intermediate_output, attention_output)
return layer_output
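# Sketch of one transformer block assembled from the modules below (illustrative
# only). The mask is broadcast against the (N, nh, Lq, L) attention scores, so a
# (N, 1, L) padding mask is an assumed, workable calling convention.
def _demo_bert_layer():
    cfg = _make_bert_config()                 # hypothetical helper defined above
    layer = BertLayer(cfg, use_self_attention=True)
    x = torch.randn(2, 30, cfg.hidden_size)   # (N, L, D)
    mask = torch.ones(2, 1, 30)               # broadcastable to (N, Lq, L)
    out = layer(x, mask)
    assert out.shape == x.shape
    return out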
class BertAttention(nn.Module):
def __init__(self, config):
super(BertAttention, self).__init__()
self.self = BertSelfAttention(config)
self.output = BertSelfOutput(config)
def forward(self, input_tensor, attention_mask):
"""
Args:
input_tensor: (N, L, D)
attention_mask: (N, L)
Returns:
"""
self_output = self.self(input_tensor, input_tensor, input_tensor, attention_mask)
attention_output = self.output(self_output, input_tensor)
return attention_output
class BertIntermediate(nn.Module):
def __init__(self, config):
super(BertIntermediate, self).__init__()
        self.dense = nn.Sequential(
            nn.Linear(config.hidden_size, config.intermediate_size),
            nn.ReLU(inplace=True))
def forward(self, hidden_states):
return self.dense(hidden_states)
class BertOutput(nn.Module):
def __init__(self, config):
super(BertOutput, self).__init__()
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.LayerNorm = nn.LayerNorm(config.hidden_size)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
class BertSelfAttention(nn.Module):
def __init__(self, config):
super(BertSelfAttention, self).__init__()
if config.hidden_size % config.num_attention_heads != 0:
raise ValueError(
"The hidden size (%d) is not a multiple of the number of attention "
"heads (%d)" % (config.hidden_size, config.num_attention_heads))
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query = nn.Linear(config.hidden_size, self.all_head_size)
self.key = nn.Linear(config.hidden_size, self.all_head_size)
self.value = nn.Linear(config.hidden_size, self.all_head_size)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) # (N, L, nh, dh)
x = x.view(*new_x_shape)
return x.permute(0, 2, 1, 3) # (N, nh, L, dh)
def forward(self, query_states, key_states, value_states, attention_mask):
"""
Args:
query_states: (N, Lq, D)
key_states: (N, L, D)
value_states: (N, L, D)
attention_mask: (N, Lq, L)
Returns:
"""
        # Only the last dim (keys) needs masking, since softmax is applied there;
        # masked query rows are discarded by downstream computation anyway.
attention_mask = (1 - attention_mask.unsqueeze(1)) * -10000. # (N, 1, Lq, L)
mixed_query_layer = self.query(query_states)
mixed_key_layer = self.key(key_states)
mixed_value_layer = self.value(value_states)
query_layer = self.transpose_for_scores(mixed_query_layer) # (N, nh, Lq, dh)
key_layer = self.transpose_for_scores(mixed_key_layer) # (N, nh, L, dh)
value_layer = self.transpose_for_scores(mixed_value_layer) # (N, nh, L, dh)
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) # (N, nh, Lq, L)
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        # Apply the attention mask (additive bias: masked positions receive a large negative score).
attention_scores = attention_scores + attention_mask
# Normalize the attention scores to probabilities.
        attention_probs = F.softmax(attention_scores, dim=-1)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs)
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(*new_context_layer_shape)
return context_layer
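# Sketch (illustrative only): because queries and keys/values are separate
# arguments, BertSelfAttention also supports cross-attention with different
# query and key lengths; the mask only has to broadcast to (N, nh, Lq, L).
def _demo_cross_attention():
    cfg = _make_bert_config()                  # hypothetical helper defined above
    attn = BertSelfAttention(cfg)
    q = torch.randn(2, 10, cfg.hidden_size)    # (N, Lq, D)
    kv = torch.randn(2, 25, cfg.hidden_size)   # (N, L, D)
    mask = torch.ones(2, 1, 25)                # broadcastable to (N, Lq, L)
    ctx = attn(q, kv, kv, mask)
    assert ctx.shape == (2, 10, cfg.hidden_size)
    return ctx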
class BertSelfOutput(nn.Module):
def __init__(self, config):
super(BertSelfOutput, self).__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.LayerNorm = nn.LayerNorm(config.hidden_size)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
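# Hypothetical end-to-end composition (for illustration only, not the repository's
# actual model): project raw features, add trainable positions, then run one
# self-attention block.
def _demo_encode_sequence():
    cfg = _make_bert_config(hidden_size=128, intermediate_size=128, num_attention_heads=4)
    proj = LinearLayer(in_hsz=300, out_hsz=128)
    pos = TrainablePositionalEncoding(max_position_embeddings=100, hidden_size=128)
    layer = BertLayer(cfg)
    feat = torch.randn(2, 30, 300)   # (N, L, raw feature dim)
    mask = torch.ones(2, 1, 30)      # broadcastable padding mask
    x = pos(proj(feat))              # (N, L, 128)
    return layer(x, mask)            # (N, L, 128)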