Enedina / EnedinaModel.py
leitaofilho's picture
Update EnedinaModel.py
c238e32 verified
import torch
import torch.nn as nn
import torch.nn.functional as F
# Definição de uma camada de embedding com atenção esparsa para texto
class SparseTextEmbedding(nn.Module):
"""
Camada de embedding para texto com atenção multi-cabeça.
Realiza embeddings de tokens de texto e aplica atenção multi-cabeça.
"""
def __init__(self, num_tokens, emb_dim):
super().__init__()
self.embedding = nn.Embedding(num_tokens, emb_dim)
self.attention = nn.MultiheadAttention(emb_dim, num_heads=8, batch_first=True)
def forward(self, x):
x = self.embedding(x)
x, _ = self.attention(x, x, x)
return x
# Processador genérico para transformar entradas numéricas em embeddings
class GenericProcessor(nn.Module):
"""
Processador genérico que transforma entradas numéricas em embeddings.
Utiliza uma camada linear seguida por uma ativação ReLU.
"""
def __init__(self, input_dim, emb_dim):
super().__init__()
self.fc = nn.Linear(input_dim, emb_dim)
def forward(self, x):
return F.relu(self.fc(x))
# Especialista em transformador para domínios específicos
class TransformerExpert(nn.Module):
"""
Especialista em domínio específico usando um encoder Transformer.
Projetado para processar embeddings e realizar tarefas específicas de domínio.
"""
def __init__(self, emb_dim, num_heads, num_layers, ff_dim):
super().__init__()
transformer_layer = nn.TransformerEncoderLayer(d_model=emb_dim, nhead=num_heads, dim_feedforward=ff_dim)
self.transformer_encoder = nn.TransformerEncoder(transformer_layer, num_layers=num_layers)
def forward(self, x):
return self.transformer_encoder(x)
# Decodificador Transformer com atenção cruzada
class TransformerDecoderWithCrossAttention(nn.Module):
"""
Decodificador Transformer com atenção cruzada.
Combina informações de múltiplas fontes e projeta o resultado final.
"""
def __init__(self, emb_dim, num_heads, num_layers, ff_dim):
super().__init__()
transformer_layer = nn.TransformerDecoderLayer(d_model=emb_dim, nhead=num_heads, dim_feedforward=ff_dim)
self.transformer_decoder = nn.TransformerDecoder(transformer_layer, num_layers=num_layers)
self.projection = nn.Linear(emb_dim, emb_dim)
def forward(self, x, memory):
output = self.transformer_decoder(x, memory)
return self.projection(output)
# Modelo principal que incorpora os componentes acima
class EnedinaModel(nn.Module):
"""
Modelo principal: Enedina.
Integra diferentes componentes especializados para processar múltiplos tipos de entrada.
"""
def __init__(self, text_num_tokens, image_input_dim, equation_input_dim, diagram_input_dim, emb_dim=1024,
num_heads=16, num_layers=12, ff_dim=4096):
super().__init__()
self.text_embedding = SparseTextEmbedding(text_num_tokens, emb_dim)
self.image_processor = GenericProcessor(image_input_dim, emb_dim)
self.equation_processor = GenericProcessor(equation_input_dim, emb_dim)
self.diagram_processor = GenericProcessor(diagram_input_dim, emb_dim)
self.experts = nn.ModuleList([
TransformerExpert(emb_dim, num_heads, num_layers, ff_dim) for _ in range(4)
])
self.gate = nn.Linear(emb_dim * 4, 4)
self.transformer_decoder = TransformerDecoderWithCrossAttention(emb_dim, num_heads, num_layers, ff_dim)
def forward(self, text_input, image_input, equation_input, diagram_input):
text_emb = self.text_embedding(text_input)
image_emb = self.image_processor(image_input).unsqueeze(1)
equation_emb = self.equation_processor(equation_input).unsqueeze(1)
diagram_emb = self.diagram_processor(diagram_input).unsqueeze(1)
# Estrutura dos especialistas
expert_inputs = [equation_emb, image_emb, diagram_emb, text_emb]
expert_outputs = []
for i, expert in enumerate(self.experts):
expert_output = expert(expert_inputs[i].permute(1, 0, 2))
expert_outputs.append(expert_output.permute(1, 0, 2)[:, -1, :])
# Combina as saídas dos especialistas
combined_expert_outputs = torch.cat(expert_outputs, dim=-1)
# Calcula os pesos do gate e aplica a combinação ponderada das saídas dos especialistas
gate_weights = F.softmax(self.gate(combined_expert_outputs), dim=-1)
expert_outputs_stack = torch.stack(expert_outputs, dim=1)
combined_output = torch.sum(gate_weights.unsqueeze(-1) * expert_outputs_stack, dim=1)
# Ajustes de dimensão antes do TransformerDecoder
combined_output = combined_output.unsqueeze(0)
text_emb = text_emb.permute(1, 0, 2)
# Aplica o decodificador Transformer com atenção cruzada
output = self.transformer_decoder(text_emb, combined_output)
return output
# Configuração dos parâmetros do modelo e simulação de entrada para testes
text_num_tokens = 200000
image_input_dim = 2048
equation_input_dim = 1024
diagram_input_dim = 1024
batch_size = 4
text_seq_len = 1000
image_seq_len = 10
equation_seq_len = 5
diagram_seq_len = 5
# Inicializa o modelo
model = EnedinaModel(text_num_tokens, image_input_dim, equation_input_dim, diagram_input_dim)
# Gera entradas simuladas
text_input = torch.randint(0, text_num_tokens, (batch_size, text_seq_len))
image_input = torch.randn(batch_size, image_input_dim)
equation_input = torch.randn(batch_size, equation_input_dim)
diagram_input = torch.randn(batch_size, diagram_input_dim)
# Executa o modelo com as entradas simuladas
output = model(text_input, image_input, equation_input, diagram_input)
# Verifica a forma da saída
print("A forma de saída do tensor é:", output.shape)