# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math

import torch
import torch.nn as nn
from torch.nn import TransformerEncoder, TransformerEncoderLayer


class Transformer(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        dropout = self.cfg.dropout
        nhead = self.cfg.n_heads
        nlayers = self.cfg.n_layers
        input_dim = self.cfg.input_dim
        output_dim = self.cfg.output_dim

        d_model = input_dim
        self.pos_encoder = PositionalEncoding(d_model, dropout)
        encoder_layers = TransformerEncoderLayer(
            d_model, nhead, dropout=dropout, batch_first=True
        )
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
        self.output_mlp = nn.Linear(d_model, output_dim)

    def forward(self, x, mask=None):
        """
        Args:
            x: (N, seq_len, input_dim)
        Returns:
            output: (N, seq_len, output_dim)
        """
        # (N, seq_len, d_model)
        src = self.pos_encoder(x)
        # model_stats["pos_embedding"] = x
        # (N, seq_len, d_model)
        output = self.transformer_encoder(src)
        # (N, seq_len, output_dim)
        output = self.output_mlp(output)
        return output


class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
        )

        # Assume that x is (seq_len, N, d)
        # pe = torch.zeros(max_len, 1, d_model)
        # pe[:, 0, 0::2] = torch.sin(position * div_term)
        # pe[:, 0, 1::2] = torch.cos(position * div_term)
        # Assume that x is (N, seq_len, d)
        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        pe[0, :, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe)

    def forward(self, x):
        """
        Args:
            x: Tensor, shape [N, seq_len, d]
        """
        # Old: assume that x is (seq_len, N, d) and self.pe is (max_len, 1, d_model)
        # x = x + self.pe[: x.size(0)]

        # Now: self.pe is (1, max_len, d)
        x = x + self.pe[:, : x.size(1), :]
        return self.dropout(x)
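

# Minimal usage sketch (illustrative only): cfg can be any object exposing the
# attributes read in __init__ (dropout, n_heads, n_layers, input_dim,
# output_dim); a SimpleNamespace stands in here for Amphion's real config
# object. Note that n_heads must divide input_dim, since d_model = input_dim.
if __name__ == "__main__":
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        dropout=0.1, n_heads=4, n_layers=2, input_dim=256, output_dim=128
    )
    model = Transformer(cfg)

    # A batch of 8 sequences, 100 frames each, with input_dim features per frame.
    x = torch.randn(8, 100, cfg.input_dim)
    y = model(x)
    print(y.shape)  # torch.Size([8, 100, 128])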
