import torch
import torch.nn as nn
import torch.nn.functional as F

class SparseTextEmbedding(nn.Module):
    """
    Embedding layer for text with multi-head attention.

    Embeds text tokens and applies multi-head self-attention over the
    embedded sequence.
    """

    def __init__(self, num_tokens, emb_dim):
        super().__init__()
        self.embedding = nn.Embedding(num_tokens, emb_dim)
        self.attention = nn.MultiheadAttention(emb_dim, num_heads=8, batch_first=True)

    def forward(self, x):
        # (batch, seq_len) token ids -> (batch, seq_len, emb_dim).
        x = self.embedding(x)
        # Self-attention: the embedded sequence is used as query, key, and value.
        x, _ = self.attention(x, x, x)
        return x
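

# Illustrative shape check for SparseTextEmbedding (a minimal sketch; the small
# dimensions and underscore-prefixed names are arbitrary, not part of the
# original configuration):
_text_layer = SparseTextEmbedding(num_tokens=100, emb_dim=64)
_tokens = torch.randint(0, 100, (2, 16))            # (batch, seq_len)
assert _text_layer(_tokens).shape == (2, 16, 64)    # (batch, seq_len, emb_dim)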


class GenericProcessor(nn.Module):
    """
    Generic processor that turns numeric inputs into embeddings.

    Applies a linear layer followed by a ReLU activation.
    """

    def __init__(self, input_dim, emb_dim):
        super().__init__()
        self.fc = nn.Linear(input_dim, emb_dim)

    def forward(self, x):
        return F.relu(self.fc(x))
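

# Minimal usage sketch for GenericProcessor (arbitrary toy dimensions):
_proc = GenericProcessor(input_dim=32, emb_dim=64)
_features = torch.randn(2, 32)              # (batch, input_dim)
assert _proc(_features).shape == (2, 64)    # (batch, emb_dim)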


class TransformerExpert(nn.Module):
    """
    Domain-specific expert built on a Transformer encoder.

    Designed to process embeddings and perform domain-specific tasks.
    """

    def __init__(self, emb_dim, num_heads, num_layers, ff_dim):
        super().__init__()
        transformer_layer = nn.TransformerEncoderLayer(d_model=emb_dim, nhead=num_heads, dim_feedforward=ff_dim)
        self.transformer_encoder = nn.TransformerEncoder(transformer_layer, num_layers=num_layers)

    def forward(self, x):
        # Expects sequence-first input, (seq_len, batch, emb_dim), since the
        # encoder layers are built without batch_first=True.
        return self.transformer_encoder(x)
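

# Minimal sketch: the expert consumes and returns sequence-first tensors
# (toy dimensions, not the model's defaults):
_expert = TransformerExpert(emb_dim=64, num_heads=8, num_layers=2, ff_dim=128)
_seq = torch.randn(16, 2, 64)               # (seq_len, batch, emb_dim)
assert _expert(_seq).shape == (16, 2, 64)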


class TransformerDecoderWithCrossAttention(nn.Module):
    """
    Transformer decoder with cross-attention.

    Combines information from multiple sources and projects the final result.
    """

    def __init__(self, emb_dim, num_heads, num_layers, ff_dim):
        super().__init__()
        transformer_layer = nn.TransformerDecoderLayer(d_model=emb_dim, nhead=num_heads, dim_feedforward=ff_dim)
        self.transformer_decoder = nn.TransformerDecoder(transformer_layer, num_layers=num_layers)
        self.projection = nn.Linear(emb_dim, emb_dim)

    def forward(self, x, memory):
        # x: target sequence, memory: encoder-side context, both sequence-first.
        output = self.transformer_decoder(x, memory)
        return self.projection(output)
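

# Minimal sketch: the decoder attends from a target sequence to a (possibly
# length-1) memory, both sequence-first (toy dimensions):
_decoder = TransformerDecoderWithCrossAttention(emb_dim=64, num_heads=8, num_layers=2, ff_dim=128)
_tgt = torch.randn(16, 2, 64)               # (tgt_len, batch, emb_dim)
_memory = torch.randn(1, 2, 64)             # (mem_len, batch, emb_dim)
assert _decoder(_tgt, _memory).shape == (16, 2, 64)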


class EnedinaModel(nn.Module):
    """
    Main model: Enedina.

    Integrates specialized components to process multiple input types: text,
    image, equation, and diagram processors, a mixture of Transformer experts,
    and a cross-attention decoder.
    """

    def __init__(self, text_num_tokens, image_input_dim, equation_input_dim, diagram_input_dim, emb_dim=1024,
                 num_heads=16, num_layers=12, ff_dim=4096):
        super().__init__()
        self.text_embedding = SparseTextEmbedding(text_num_tokens, emb_dim)
        self.image_processor = GenericProcessor(image_input_dim, emb_dim)
        self.equation_processor = GenericProcessor(equation_input_dim, emb_dim)
        self.diagram_processor = GenericProcessor(diagram_input_dim, emb_dim)
        # One Transformer expert per input modality.
        self.experts = nn.ModuleList([
            TransformerExpert(emb_dim, num_heads, num_layers, ff_dim) for _ in range(4)
        ])
        # Gating network: maps the concatenated expert summaries to one weight per expert.
        self.gate = nn.Linear(emb_dim * 4, 4)
        self.transformer_decoder = TransformerDecoderWithCrossAttention(emb_dim, num_heads, num_layers, ff_dim)

    def forward(self, text_input, image_input, equation_input, diagram_input):
        # Text: (batch, seq_len) -> (batch, seq_len, emb_dim).
        text_emb = self.text_embedding(text_input)
        # Non-text inputs: (batch, input_dim) -> (batch, 1, emb_dim), i.e. length-1 sequences.
        image_emb = self.image_processor(image_input).unsqueeze(1)
        equation_emb = self.equation_processor(equation_input).unsqueeze(1)
        diagram_emb = self.diagram_processor(diagram_input).unsqueeze(1)

        # Run each modality through its own expert. The experts expect
        # sequence-first input, so permute in and out, then keep only the last
        # position of each sequence as that expert's summary vector.
        expert_inputs = [equation_emb, image_emb, diagram_emb, text_emb]
        expert_outputs = []
        for i, expert in enumerate(self.experts):
            expert_output = expert(expert_inputs[i].permute(1, 0, 2))
            expert_outputs.append(expert_output.permute(1, 0, 2)[:, -1, :])

        # (batch, 4 * emb_dim): all expert summaries side by side, fed to the gate.
        combined_expert_outputs = torch.cat(expert_outputs, dim=-1)

        # Soft mixture of experts: weight each expert's summary and sum.
        gate_weights = F.softmax(self.gate(combined_expert_outputs), dim=-1)
        expert_outputs_stack = torch.stack(expert_outputs, dim=1)  # (batch, 4, emb_dim)
        combined_output = torch.sum(gate_weights.unsqueeze(-1) * expert_outputs_stack, dim=1)

        # The mixed expert output becomes a length-1, sequence-first memory;
        # the text embeddings become the sequence-first decoder target.
        combined_output = combined_output.unsqueeze(0)  # (1, batch, emb_dim)
        text_emb = text_emb.permute(1, 0, 2)            # (seq_len, batch, emb_dim)

        output = self.transformer_decoder(text_emb, combined_output)

        return output
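

# Standalone sketch of the soft mixture-of-experts gating used in
# EnedinaModel.forward (toy numbers, unrelated to the model instance below):
_summaries = torch.randn(2, 4, 8)                   # (batch, n_experts, emb_dim)
_gates = F.softmax(torch.randn(2, 4), dim=-1)       # one weight per expert, rows sum to 1
_mixed = torch.sum(_gates.unsqueeze(-1) * _summaries, dim=1)
assert _mixed.shape == (2, 8)                       # (batch, emb_dim)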


text_num_tokens = 200000
image_input_dim = 2048
equation_input_dim = 1024
diagram_input_dim = 1024
batch_size = 4
text_seq_len = 1000
# Note: the *_seq_len values below are unused in this demo, since the image,
# equation, and diagram inputs are single feature vectors rather than sequences.
image_seq_len = 10
equation_seq_len = 5
diagram_seq_len = 5

model = EnedinaModel(text_num_tokens, image_input_dim, equation_input_dim, diagram_input_dim)

text_input = torch.randint(0, text_num_tokens, (batch_size, text_seq_len))
image_input = torch.randn(batch_size, image_input_dim)
equation_input = torch.randn(batch_size, equation_input_dim)
diagram_input = torch.randn(batch_size, diagram_input_dim)

output = model(text_input, image_input, equation_input, diagram_input)

print("Output tensor shape:", output.shape)
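# The decoder returns a sequence-first tensor, so the expected shape is
# (text_seq_len, batch_size, emb_dim) = (1000, 4, 1024).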