iedavidcastilloX committed
Commit f7c34c4 · verified · 1 Parent(s): 513e31a

Create ernie_bot.py


Chinese / trained at your expense and experience, to help Americans understand it.

Files changed (1)
  1. ernie_bot.py +76 -0
ernie_bot.py ADDED
@@ -0,0 +1,76 @@
+ import torch
+ import torch.nn as nn
+ from transformers import ErnieModel, AutoTokenizer
+
+ class ErnieBotDeepSearch(nn.Module):
+     def __init__(self):
+         super().__init__()
+         self.name = "ErnieBot Deep Search"
+         self.version = "Original 1.0"
+
+         # Core components: ERNIE 3.0 base (Chinese) from the HF Hub; its tokenizer is BERT-style, so AutoTokenizer resolves it
+         self.ernie = ErnieModel.from_pretrained("nghuyong/ernie-3.0-base-zh")
+         self.tokenizer = AutoTokenizer.from_pretrained("nghuyong/ernie-3.0-base-zh")
+
+         # Deep Search components (batch_first=True matches the (batch, seq, hidden) layout of the ERNIE hidden states)
+         self.search_layers = nn.ModuleList([
+             nn.TransformerEncoderLayer(d_model=768, nhead=12, batch_first=True)
+             for _ in range(6)
+         ])
+
+         self.knowledge_encoder = nn.Linear(768, 1024)
+         self.cross_attention = nn.MultiheadAttention(1024, 16, batch_first=True)
+
+         # Output layers
+         self.classifier = nn.Linear(1024, 2)
+         self.ranking_head = nn.Linear(1024, 1)
+
+     def deep_search(self, query, documents):
+         # Encode the query into token-level hidden states: (1, query_len, 768)
+         query_tokens = self.tokenizer(query, return_tensors="pt")
+         query_embed = self.ernie(**query_tokens)[0]
+
+         # Encode each candidate document: (1, doc_len, 768)
+         doc_embeddings = []
+         for doc in documents:
+             doc_tokens = self.tokenizer(doc, return_tensors="pt")
+             doc_embed = self.ernie(**doc_tokens)[0]
+             doc_embeddings.append(doc_embed)
+
+         # Deep search processing followed by ranking
+         search_results = self._process_deep_search(query_embed, doc_embeddings)
+         return self._rank_results(search_results)
+
+     def _process_deep_search(self, query, documents):
+         query_enhanced = self.knowledge_encoder(query)
+
+         results = []
+         for doc in documents:
+             # Apply search layers
+             for layer in self.search_layers:
+                 doc = layer(doc)
+
+             # Cross-attention between query and document
+             doc_enhanced = self.knowledge_encoder(doc)
+             attention_output, _ = self.cross_attention(
+                 query_enhanced, doc_enhanced, doc_enhanced
+             )
+
+             results.append(attention_output)
+         return results
+
+     def _rank_results(self, search_results):
+         rankings = []
+         for result in search_results:
+             score = self.ranking_head(result.mean(dim=1))  # mean-pool query positions -> one score per document
+             rankings.append(score)
+         return torch.stack(rankings).squeeze()
+
+     def train_step(self, batch):
+         query, positive_docs, negative_docs = batch  # assumes len(positive_docs) == len(negative_docs)
+         pos_scores = self.deep_search(query, positive_docs)
+         neg_scores = self.deep_search(query, negative_docs)
+
+         loss = nn.MarginRankingLoss(margin=1.0)(pos_scores, neg_scores, torch.ones_like(pos_scores))
+         return loss
+
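
For reviewers, a minimal usage sketch (not part of the committed file). It assumes the ernie_bot.py added above is on the import path and that the checkpoint downloads on first use; the query/document strings, optimizer, and learning rate are arbitrary placeholders for illustration.

import torch

from ernie_bot import ErnieBotDeepSearch

model = ErnieBotDeepSearch()
model.eval()

# Rank a few candidate documents for one query (inference only).
query = "什么是深度学习?"  # "What is deep learning?"
documents = [
    "深度学习是机器学习的一个分支。",  # relevant: "Deep learning is a branch of machine learning."
    "ERNIE 是百度发布的预训练语言模型。",  # relevant: "ERNIE is a pretrained language model from Baidu."
    "今天的天气很好。",  # irrelevant: "The weather is nice today."
    "我喜欢吃苹果。",  # irrelevant: "I like eating apples."
]
with torch.no_grad():
    scores = model.deep_search(query, documents)
print(scores)  # one relevance score per document, shape (4,)

# One pairwise training step with equal numbers of positive and negative documents.
model.train()
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
loss = model.train_step((query, documents[:2], documents[2:]))
loss.backward()
optimizer.step()
optimizer.zero_grad()

The margin ranking loss only needs the relative ordering of positive and negative documents, which is why train_step takes paired document lists rather than explicit relevance labels.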