Text Generation
Transformers
English
Russian
legal
SkillForge45 committed · verified
Commit 088ce82 · 1 Parent(s): 3c5cfb7

Create model.py

Files changed (1)
  1. model.py +214 -0
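(The file assumes a fairly broad dependency set: torch, datasets, transformers, tqdm, SpeechRecognition, pyttsx3, and a googlesearch package, plus PyAudio for microphone access; none of these are pinned in the commit.)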
model.py ADDED
@@ -0,0 +1,214 @@
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
from transformers import AutoTokenizer
from tqdm import tqdm
import math
import speech_recognition as sr
import pyttsx3
from googlesearch import search
import warnings
from typing import List, Dict, Union

# Ignore warnings
warnings.filterwarnings("ignore")

class WebSearchWrapper:
    """Wrapper for web search with caching"""
    def __init__(self, cache_size: int = 100):
        self.cache: Dict[str, List[str]] = {}
        self.cache_size = cache_size

    def search(self, query: str, num_results: int = 3) -> List[str]:
        """Perform a web search, returning cached results when available"""
        if query.lower() in self.cache:
            return self.cache[query.lower()]

        try:
            # Assumes the googlesearch-python package; the older `google`
            # package takes num/stop/pause instead of num_results.
            search_results = list(search(query, num_results=num_results))
            self._add_to_cache(query, search_results)
            return search_results
        except Exception as e:
            print(f"Web search error: {e}")
            return []

    def _add_to_cache(self, query: str, results: List[str]):
        """Add results to the cache, evicting the oldest entry (FIFO) when full"""
        if len(self.cache) >= self.cache_size:
            self.cache.pop(next(iter(self.cache)))
        self.cache[query.lower()] = results

class FullChatDataset(Dataset):
    def __init__(self, dataset_names=("blended_skill_talk", "conv_ai_2", "social_i_qa"), max_length=256):
        self.datasets = []

        for name in dataset_names:
            try:
                dataset = load_dataset(name, split="train")
                self.datasets.append(dataset)
            except Exception as e:
                print(f"Failed to load dataset {name}: {e}")

        self.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        self.max_length = max_length

    def __len__(self):
        return sum(len(d) for d in self.datasets)

    def __getitem__(self, idx):
        # Map the global index onto the concatenated datasets.
        item = None
        for dataset in self.datasets:
            if idx < len(dataset):
                item = dataset[idx]
                break
            idx -= len(dataset)

        # Extract the list of utterances regardless of dataset schema.
        if 'dialog' in item:
            dialog = item['dialog']
        elif 'messages' in item:
            dialog = [msg['text'] for msg in item['messages']]
        else:
            dialog = [v for k, v in item.items() if isinstance(v, str)]

        context = " [SEP] ".join(dialog[:-1])
        response = dialog[-1]

        inputs = self.tokenizer(
            context,
            text_pair=response,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors="pt"
        )

        # The model is trained to reconstruct its input, so labels == input_ids.
        return {
            'input_ids': inputs['input_ids'].flatten(),
            'attention_mask': inputs['attention_mask'].flatten(),
            'labels': inputs['input_ids'].flatten()
        }

class SimpleTransformerModel(nn.Module):
    def __init__(self, vocab_size, d_model=256, nhead=4, num_layers=3):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model)
        # batch_first=True so inputs are (batch, seq_len, d_model).
        encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)
        self.fc = nn.Linear(d_model, vocab_size)

    def forward(self, x, attention_mask=None):
        x = self.embedding(x)
        x = self.pos_encoder(x)
        # The tokenizer's mask is 1 for real tokens, but PyTorch expects
        # True for positions that should be ignored.
        padding_mask = (attention_mask == 0) if attention_mask is not None else None
        x = self.transformer(x, src_key_padding_mask=padding_mask)
        return self.fc(x)

    @torch.no_grad()
    def generate(self, input_ids, max_length=100, top_k=50, temperature=0.7, eos_token_id=None):
        """Minimal top-k sampling loop; a plain nn.Module has no built-in generate()."""
        generated = input_ids
        for _ in range(max_length):
            logits = self.forward(generated)[:, -1, :] / temperature
            top_logits, top_indices = logits.topk(top_k, dim=-1)
            probs = torch.softmax(top_logits, dim=-1)
            next_token = top_indices.gather(-1, torch.multinomial(probs, 1))
            generated = torch.cat([generated, next_token], dim=-1)
            if eos_token_id is not None and (next_token == eos_token_id).all():
                break
        return generated

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=500):
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: (batch, seq_len, d_model); pe broadcasts across the batch dim.
        return x + self.pe[:x.size(1)]

class VoiceInterface:
    def __init__(self):
        self.recognizer = sr.Recognizer()
        self.engine = pyttsx3.init()

    def listen(self) -> Union[str, None]:
        with sr.Microphone() as source:
            print("Listening...")
            audio = self.recognizer.listen(source)
            try:
                text = self.recognizer.recognize_google(audio)
                print(f"You said: {text}")
                return text
            except Exception as e:
                print(f"Error recognizing speech: {e}")
                return None

    def speak(self, text: str):
        print(f"Bot: {text}")
        self.engine.say(text)
        self.engine.runAndWait()

class ChatBot:
    def __init__(self):
        self.dataset = FullChatDataset()
        self.model = SimpleTransformerModel(len(self.dataset.tokenizer))
        self.voice_interface = VoiceInterface()
        self.web_searcher = WebSearchWrapper()

    def train(self, epochs=3, lr=3e-4):
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = self.model.to(device)
        # Ignore padding positions in the loss ([PAD] is id 0 for bert-base-uncased).
        criterion = nn.CrossEntropyLoss(ignore_index=self.dataset.tokenizer.pad_token_id)
        optimizer = optim.Adam(self.model.parameters(), lr=lr)

        dataloader = DataLoader(self.dataset, batch_size=8, shuffle=True)

        for epoch in range(epochs):
            self.model.train()
            total_loss = 0
            pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")

            for batch in pbar:
                inputs = batch['input_ids'].to(device)
                masks = batch['attention_mask'].to(device)
                labels = batch['labels'].to(device)

                optimizer.zero_grad()
                outputs = self.model(inputs, masks)
                loss = criterion(outputs.view(-1, outputs.size(-1)), labels.view(-1))
                loss.backward()
                optimizer.step()

                total_loss += loss.item()
                pbar.set_postfix({'loss': loss.item()})

            print(f"Epoch {epoch+1} - Avg loss: {total_loss/len(dataloader):.4f}")

    def generate_response(self, prompt: str, max_length: int = 100, use_web: bool = True) -> str:
        device = next(self.model.parameters()).device
        self.model.eval()

        # Prepend web context if the prompt looks like a question.
        if use_web and self._needs_web_search(prompt):
            web_results = self.web_searcher.search(prompt)
            if web_results:
                prompt = f"Web context: {', '.join(web_results[:3])}. User question: {prompt}"

        # No padding here: generation continues from the last real token.
        inputs = self.dataset.tokenizer(
            prompt,
            return_tensors="pt",
            max_length=256,
            truncation=True
        ).to(device)

        with torch.no_grad():
            outputs = self.model.generate(
                input_ids=inputs['input_ids'],
                max_length=max_length,
                top_k=50,
                temperature=0.7,
                eos_token_id=self.dataset.tokenizer.sep_token_id
            )

        response = self.dataset.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return response

    def _needs_web_search(self, text: str) -> bool:
        """Heuristic: queries containing question words trigger a web search"""
        question_words = ['what', 'when', 'where', 'who', 'why', 'how', 'which', '?']
        return any(word in text.lower() for word in question_words)
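
For reference, a minimal way to wire these pieces together might look like the sketch below. It is not part of the committed file: the single-epoch training call, the "exit" voice command, and the loop structure are illustrative assumptions.

# Hypothetical usage sketch (not in the commit): train briefly, then run a
# voice-driven chat loop until the user says "exit".
if __name__ == "__main__":
    bot = ChatBot()
    bot.train(epochs=1)  # illustrative; real training would need more epochs

    while True:
        user_text = bot.voice_interface.listen()
        if user_text is None:
            continue  # speech was not recognized; listen again
        if user_text.lower() == "exit":
            break
        reply = bot.generate_response(user_text)
        bot.voice_interface.speak(reply)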