# Ernie_Bot.py / ernie_bot.py
# Uploaded by iedavidcastilloX - commit f7c34c4 ("Create ernie_bot.py"), verified.
import torch
import torch.nn as nn
from transformers import ErnieModel, ErnieTokenizer
class ErnieBotDeepSearch(nn.Module):
    """Score documents against a query with an ERNIE encoder plus extra attention layers.

    Each text is encoded separately by ``self.ernie``. Document token states are
    refined by a stack of Transformer encoder layers and projected to 1024-d. The
    projected query tokens then cross-attend to them. The attended query tokens
    are mean-pooled and fed to ``ranking_head``, giving one relevance score per
    document.

    NOTE(review): assumes the encoder's first output for a single text is a
    batch-first ``(1, seq_len, 768)`` tensor - confirm for non-default encoders.
    """

    def __init__(
        self,
        model_name="ernie-3.0-base-zh",
        *,
        encoder=None,
        tokenizer=None,
        num_search_layers=6,
    ):
        """Build the model.

        Args:
            model_name: checkpoint name passed to ``from_pretrained`` when
                ``encoder`` / ``tokenizer`` are not supplied.
                NOTE(review): "ernie-3.0-base-zh" looks like a PaddleNLP-style
                name; on the HuggingFace Hub the checkpoint may need an
                organisation prefix - confirm.
            encoder: optional pre-built encoder used instead of loading
                ``ErnieModel``. It is called with the tokenizer output as keyword
                arguments, and ``[0]`` of its result must be the 768-d token states.
            tokenizer: optional pre-built tokenizer used instead of loading
                ``ErnieTokenizer``. It is called as
                ``tokenizer(text, return_tensors="pt", truncation=True)``.
                NOTE(review): ``transformers`` may not export ``ErnieTokenizer``
                (ERNIE checkpoints there usually load via ``AutoTokenizer``) -
                confirm the import at the top of this file.
            num_search_layers: number of Transformer encoder layers applied to
                each document's token states.
        """
        super().__init__()
        self.name = "ErnieBot Deep Search"
        self.version = "Original 1.0"

        # Core components: pretrained encoder and its tokenizer.
        # Both are loaded lazily, only when no replacement is injected.
        self.ernie = (
            encoder if encoder is not None else ErnieModel.from_pretrained(model_name)
        )
        self.tokenizer = (
            tokenizer
            if tokenizer is not None
            else ErnieTokenizer.from_pretrained(model_name)
        )

        # Deep-search components. batch_first=True matches the encoder's
        # (batch, seq, dim) output. With the default (seq, batch, dim) layout,
        # self-attention ran over the wrong axis, and cross-attention failed
        # whenever query and document had different token counts.
        # batch_first adds no parameters, so state_dict keys are unchanged.
        self.search_layers = nn.ModuleList(
            [
                nn.TransformerEncoderLayer(d_model=768, nhead=12, batch_first=True)
                for _ in range(num_search_layers)
            ]
        )
        self.knowledge_encoder = nn.Linear(768, 1024)
        self.cross_attention = nn.MultiheadAttention(1024, 16, batch_first=True)

        # Output layers. `classifier` is not used by any method of this class.
        # It is kept because removing it would change the state_dict keys.
        self.classifier = nn.Linear(1024, 2)
        self.ranking_head = nn.Linear(1024, 1)

    def deep_search(self, query, documents):
        """Return one relevance score per document for ``query``.

        Args:
            query: query text.
            documents: iterable of document texts (consumed once).

        Returns:
            1-D tensor of shape ``(len(documents),)``. It is empty when
            ``documents`` is empty.
        """
        query_embed = self._encode(query)
        return self._score_documents(query_embed, documents)

    def _encode(self, text):
        """Tokenize ``text``, run the encoder, and return its first output."""
        device = self.ranking_head.weight.device
        # truncation=True avoids crashing on texts longer than the tokenizer's
        # configured maximum length.
        tokens = self.tokenizer(text, return_tensors="pt", truncation=True)
        # Move inputs to the model's device; otherwise a GPU model would receive
        # CPU tensors.
        tokens = {
            key: value.to(device) if torch.is_tensor(value) else value
            for key, value in tokens.items()
        }
        return self.ernie(**tokens)[0]

    def _score_documents(self, query_embed, documents):
        """Encode ``documents`` and score each one against an already-encoded query."""
        documents = list(documents)
        if not documents:
            # torch.stack/cat cannot handle an empty list; return an empty score vector.
            return self.ranking_head.weight.new_zeros(0)
        doc_embeddings = [self._encode(doc) for doc in documents]
        search_results = self._process_deep_search(query_embed, doc_embeddings)
        return self._rank_results(search_results)

    def _process_deep_search(self, query, documents):
        """Refine each document and cross-attend to it from the query.

        Returns a list with one cross-attention output per document. For
        batch-1 inputs each output is shaped like the projected query,
        ``(1, query_len, 1024)``.
        """
        query_enhanced = self.knowledge_encoder(query)
        results = []
        for doc in documents:
            # Apply the search layers to the document token states.
            for layer in self.search_layers:
                doc = layer(doc)
            # Query tokens (queries) attend over document tokens (keys/values).
            doc_enhanced = self.knowledge_encoder(doc)
            attention_output, _ = self.cross_attention(
                query_enhanced, doc_enhanced, doc_enhanced
            )
            results.append(attention_output)
        return results

    def _rank_results(self, search_results):
        """Reduce each cross-attention output to one score and concatenate them.

        Query tokens are mean-pooled before ``ranking_head``. The result is
        therefore one scalar per document, not one score per query token.
        """
        scores = [
            self.ranking_head(result.mean(dim=1)).reshape(-1)
            for result in search_results
        ]
        return torch.cat(scores)

    def train_step(self, batch):
        """Compute a margin ranking loss for one ``(query, positives, negatives)`` batch.

        Positive and negative scores are compared elementwise with margin 1.0.
        The two score vectors must therefore be broadcast-compatible: equal
        counts, or a single document on one side.

        Returns:
            Scalar loss tensor.
        """
        query, positive_docs, negative_docs = batch
        # Encode the query once and reuse it for both document sets.
        query_embed = self._encode(query)
        pos_scores = self._score_documents(query_embed, positive_docs)
        neg_scores = self._score_documents(query_embed, negative_docs)
        loss = nn.MarginRankingLoss(margin=1.0)(
            pos_scores, neg_scores, torch.ones_like(pos_scores)
        )
        return loss