import torch
import torch.nn as nn
from transformers import ErnieModel, ErnieTokenizer


class ErnieBotDeepSearch(nn.Module):
    """Score documents against a query using ERNIE embeddings.

    Pipeline: the query and every document are embedded with the pretrained
    ERNIE encoder; each document is refined by a stack of Transformer encoder
    layers; query and document are projected to 1024 dimensions; the query
    attends over the document with multi-head cross-attention; the attended
    representation is reduced to one relevance score per document.
    """

    def __init__(self):
        super().__init__()
        self.name = "ErnieBot Deep Search"
        self.version = "Original 1.0"

        # Core components.
        # NOTE(review): confirm that the installed `transformers` release
        # exports `ErnieTokenizer` and that this checkpoint id resolves.
        self.ernie = ErnieModel.from_pretrained("ernie-3.0-base-zh")
        self.tokenizer = ErnieTokenizer.from_pretrained("ernie-3.0-base-zh")

        # Deep search components.
        # batch_first=True: the encoder output consumed below is laid out as
        # (batch, seq, hidden). With the default (seq, batch, hidden) layout
        # these layers would treat the token axis as the batch axis, so
        # self-attention would see length-1 sequences and cross-attention
        # would fail whenever query and document lengths differ.
        self.search_layers = nn.ModuleList(
            [
                nn.TransformerEncoderLayer(
                    d_model=768, nhead=12, batch_first=True
                )
                for _ in range(6)
            ]
        )
        self.knowledge_encoder = nn.Linear(768, 1024)
        self.cross_attention = nn.MultiheadAttention(
            1024, 16, batch_first=True
        )

        # Output layers. `classifier` is not used by any method in this
        # class; it is kept so existing checkpoints still load.
        self.classifier = nn.Linear(1024, 2)
        self.ranking_head = nn.Linear(1024, 1)

    def deep_search(self, query, documents):
        """Return one relevance score per document for ``query``.

        Args:
            query: Query text passed to the tokenizer (assumed to be a
                single string -- TODO confirm against callers).
            documents: Iterable of document texts.

        Returns:
            1-D tensor of shape ``(len(documents),)``; empty when
            ``documents`` is empty.
        """
        query_embed = self._encode(query)
        doc_embeddings = [self._encode(doc) for doc in documents]
        search_results = self._process_deep_search(query_embed, doc_embeddings)
        return self._rank_results(search_results)

    def _encode(self, text):
        """Tokenize ``text`` and return the ERNIE encoder's first output."""
        tokens = self.tokenizer(text, return_tensors="pt")
        # Tokenizer output is created on CPU; move it to wherever this
        # module's weights live so the model also works after `.to(device)`.
        tokens = tokens.to(self.ranking_head.weight.device)
        return self.ernie(**tokens)[0]

    def _process_deep_search(self, query, documents):
        """Cross-attend the projected query over each refined document.

        Args:
            query: Query embedding with 768 features in the last dimension.
            documents: List of document embeddings, same feature size.

        Returns:
            List with one attention output per document, each with 1024
            features in the last dimension.
        """
        # The query projection does not depend on the document: compute once.
        query_enhanced = self.knowledge_encoder(query)

        results = []
        for doc in documents:
            # Refine the document representation with the search layers.
            for layer in self.search_layers:
                doc = layer(doc)

            # Query attends over the projected document (keys == values).
            doc_enhanced = self.knowledge_encoder(doc)
            attention_output, _ = self.cross_attention(
                query_enhanced, doc_enhanced, doc_enhanced
            )
            results.append(attention_output)
        return results

    def _rank_results(self, search_results):
        """Reduce each attention output to a single scalar score.

        The ranking head yields one value per query position; these are
        mean-pooled so every document gets exactly one score. The result is
        always 1-D, including for a single document or a single-token query.

        Args:
            search_results: List of tensors from ``_process_deep_search``.

        Returns:
            1-D tensor with one score per entry of ``search_results``.
        """
        if not search_results:
            # torch.stack rejects an empty list; return an empty score vector
            # with the head's dtype and device instead.
            return self.ranking_head.weight.new_zeros(0)

        scores = [self.ranking_head(result).mean() for result in search_results]
        return torch.stack(scores)

    def train_step(self, batch):
        """Compute the margin ranking loss for one batch.

        Args:
            batch: Tuple ``(query, positive_docs, negative_docs)``. Positive
                and negative documents are compared element-wise, so the two
                lists are expected to have matching (or broadcastable)
                lengths.

        Returns:
            Scalar loss tensor; each positive score is pushed to exceed its
            paired negative score by a margin of 1.0.
        """
        query, positive_docs, negative_docs = batch

        pos_scores = self.deep_search(query, positive_docs)
        neg_scores = self.deep_search(query, negative_docs)

        # Target of +1 means "first input should rank higher than second".
        target = torch.ones_like(pos_scores)
        criterion = nn.MarginRankingLoss(margin=1.0)
        loss = criterion(pos_scores, neg_scores, target)
        return loss