|
import torch |
|
import torch.nn as nn |
|
from transformers import ErnieModel, ErnieTokenizer |
|
|
|
class ErnieBotDeepSearch(nn.Module):
    """Score documents against a query with ERNIE embeddings and cross-attention.

    Pipeline (see ``deep_search``):

    1. The query and every document are tokenized and embedded by the
       pretrained ERNIE encoder.
    2. Each document embedding is refined by a stack of Transformer encoder
       layers and projected by ``knowledge_encoder``.
    3. The projected query attends over each projected document
       (``cross_attention``).
    4. The attended query representation is mean-pooled over its tokens and
       mapped to one scalar relevance score per document by ``ranking_head``.

    All attention modules use the batch-first layout ``(batch, seq, dim)``
    because that is the layout of the encoder output consumed here.
    """

    def __init__(self, model_name: str = "ernie-3.0-base-zh", num_search_layers: int = 6):
        """Load the pretrained encoder/tokenizer and build the search heads.

        Args:
            model_name: Identifier or path handed to ``from_pretrained`` for
                both the encoder and the tokenizer.
            num_search_layers: Number of Transformer encoder layers applied to
                each document embedding.
        """
        super().__init__()
        self.name = "ErnieBot Deep Search"
        self.version = "Original 1.0"

        self.ernie = ErnieModel.from_pretrained(model_name)
        # NOTE(review): Hugging Face `transformers` documents ERNIE checkpoints
        # with BertTokenizer/AutoTokenizer - confirm that `ErnieTokenizer` is
        # really exported by the installed version.
        self.tokenizer = ErnieTokenizer.from_pretrained(model_name)

        # Derive the width from the loaded checkpoint so the heads always
        # match the encoder output instead of relying on a hard-coded 768.
        self._build_search_modules(
            hidden_size=self.ernie.config.hidden_size,
            num_layers=num_search_layers,
        )

    def _build_search_modules(self, hidden_size, num_layers, num_heads=12,
                              knowledge_dim=1024, cross_heads=16):
        """Create the trainable layers that sit on top of the encoder.

        Attribute names (and therefore ``state_dict`` keys) are part of the
        class's interface and must not be renamed.
        """
        # batch_first=True: inputs are (batch, seq, dim). With the default
        # (seq, batch, dim) layout a batch of one sequence would be read as a
        # sequence of length one, so no token could attend to any other.
        self.search_layers = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=hidden_size, nhead=num_heads, batch_first=True,
            )
            for _ in range(num_layers)
        ])
        self.knowledge_encoder = nn.Linear(hidden_size, knowledge_dim)
        self.cross_attention = nn.MultiheadAttention(
            knowledge_dim, cross_heads, batch_first=True,
        )
        # `classifier` is not used by any method of this class; it is kept so
        # existing checkpoints and external callers keep working.
        self.classifier = nn.Linear(knowledge_dim, 2)
        self.ranking_head = nn.Linear(knowledge_dim, 1)

    def forward(self, query, documents):
        """Alias for ``deep_search`` so the module can be called directly."""
        return self.deep_search(query, documents)

    def _encode(self, text):
        """Tokenize ``text`` and return the encoder's first output for it."""
        tokens = self.tokenizer(
            text,
            return_tensors="pt",
            # Over-long inputs would index past the position-embedding table.
            truncation=True,
            max_length=self.ernie.config.max_position_embeddings,
        )
        # The tokenizer builds CPU tensors; move them to wherever the model is.
        device = next(self.parameters()).device
        tokens = {key: value.to(device) for key, value in tokens.items()}
        return self.ernie(**tokens)[0]

    def deep_search(self, query, documents):
        """Return one relevance score per document for ``query``.

        Args:
            query: The query text (a single string).
            documents: Iterable of document texts. A bare string is treated as
                a single document rather than iterated character by character.

        Returns:
            A 1-D tensor with one score per document, in input order (empty
            when ``documents`` is empty).
        """
        if isinstance(documents, str):
            documents = [documents]

        query_embed = self._encode(query)
        doc_embeddings = [self._encode(doc) for doc in documents]

        search_results = self._process_deep_search(query_embed, doc_embeddings)
        return self._rank_results(search_results)

    def _process_deep_search(self, query, documents):
        """Attend from the query over each refined document embedding.

        Args:
            query: Query embedding, ``(batch, query_len, hidden)``.
            documents: Iterable of document embeddings, each
                ``(batch, doc_len, hidden)``; lengths may differ per document.

        Returns:
            A list with one ``(batch, query_len, knowledge_dim)`` tensor per
            document.
        """
        query_enhanced = self.knowledge_encoder(query)

        results = []
        for doc in documents:
            for layer in self.search_layers:
                doc = layer(doc)

            doc_enhanced = self.knowledge_encoder(doc)
            # Only the attended values are used, so skip computing the
            # averaged attention weights.
            attention_output, _ = self.cross_attention(
                query_enhanced, doc_enhanced, doc_enhanced, need_weights=False,
            )
            results.append(attention_output)
        return results

    def _rank_results(self, search_results):
        """Reduce each cross-attention output to a scalar score.

        Each result is mean-pooled over its token axis before the ranking
        head. Because the head is linear this equals the average of the
        per-token scores, but it yields exactly one score per document.

        Returns:
            Tensor of shape ``(num_docs,)`` when the batch size is 1
            (``(num_docs, batch)`` otherwise); an empty tensor when
            ``search_results`` is empty.
        """
        if not search_results:
            return torch.empty(0, device=self.ranking_head.weight.device)

        rankings = [
            self.ranking_head(result.mean(dim=-2)).squeeze(-1)
            for result in search_results
        ]
        scores = torch.stack(rankings).reshape(len(rankings), -1)
        # Drop only the batch axis (when it is 1); a bare squeeze() would also
        # drop the document axis for a single document.
        return scores.squeeze(-1)

    @staticmethod
    def _pairwise_margin_loss(pos_scores, neg_scores, margin=1.0):
        """Margin ranking loss over every (positive, negative) score pair.

        Comparing all pairs lets the positive and negative lists have
        different lengths.
        """
        pos = pos_scores.reshape(-1, 1)
        neg = neg_scores.reshape(1, -1)
        pair_shape = (pos.size(0), neg.size(1))
        pos, neg = pos.expand(pair_shape), neg.expand(pair_shape)
        # Target +1: the first input (positive) should rank higher.
        target = torch.ones_like(pos)
        return nn.MarginRankingLoss(margin=margin)(pos, neg, target)

    def train_step(self, batch):
        """Compute the ranking loss for one ``(query, positives, negatives)`` batch.

        Args:
            batch: Tuple ``(query, positive_docs, negative_docs)``.

        Returns:
            Scalar loss tensor (margin 1.0) that pushes every positive
            document's score above every negative document's score.

        Raises:
            ValueError: If either document list is empty.
        """
        query, positive_docs, negative_docs = batch
        pos_scores = self.deep_search(query, positive_docs)
        neg_scores = self.deep_search(query, negative_docs)

        if pos_scores.numel() == 0 or neg_scores.numel() == 0:
            raise ValueError(
                "train_step needs at least one positive and one negative document"
            )

        return self._pairwise_margin_loss(pos_scores, neg_scores, margin=1.0)
|
|
|
|