Text Generation
Transformers
English
Russian
legal
SkillForge45 commited on
Commit
e618d88
·
verified ·
1 Parent(s): 4decd29

Create model.py

Browse files
Files changed (1) hide show
  1. model.py +167 -0
model.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.optim as optim
4
+ from torch.utils.data import Dataset, DataLoader
5
+ from datasets import load_dataset
6
+ from transformers import AutoTokenizer
7
+ from tqdm import tqdm
8
+ import math
9
+ import speech_recognition as sr
10
+ import pyttsx3
11
+
12
class FullChatDataset(Dataset):
    """Concatenation of several Hugging Face conversational corpora.

    Each item is a dialog: all turns but the last are joined with " [SEP] "
    as the context, the last turn is the response. The pair is tokenized
    with bert-base-uncased; ``labels`` mirrors ``input_ids`` (LM-style target).
    """

    # NOTE: default changed from a list to a tuple — same values, but avoids
    # the mutable-default-argument pitfall.
    def __init__(self, dataset_names=("blended_skill_talk", "conv_ai_2", "social_i_qa"),
                 max_length=128):
        self.datasets = []
        for name in dataset_names:
            try:
                self.datasets.append(load_dataset(name, split="train"))
            except Exception as e:
                # Best-effort: skip corpora that fail to download/parse.
                print(f"Failed to load dataset {name}: {e}")

        self.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        self.max_length = max_length

    def __len__(self):
        # Total length is the sum over all successfully loaded corpora.
        return sum(len(d) for d in self.datasets)

    def __getitem__(self, idx):
        # Map the flat index onto the right corpus by walking the list.
        item = None
        for dataset in self.datasets:
            if idx < len(dataset):
                item = dataset[idx]
                break
            idx -= len(dataset)
        if item is None:
            # Original code left `item` unbound here (UnboundLocalError);
            # IndexError is the contract DataLoader expects for a bad index.
            raise IndexError("index out of range for concatenated datasets")

        # The three corpora have different schemas; fall back to "every
        # string field" when neither known key is present.
        if 'dialog' in item:
            dialog = item['dialog']
        elif 'messages' in item:
            dialog = [msg['text'] for msg in item['messages']]
        else:
            dialog = [v for k, v in item.items() if isinstance(v, str)]

        # Guard: an empty record would make `dialog[-1]` raise.
        if not dialog:
            dialog = [""]
        context = " [SEP] ".join(dialog[:-1])
        response = dialog[-1]

        inputs = self.tokenizer(
            context,
            text_pair=response,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors="pt"
        )

        return {
            'input_ids': inputs['input_ids'].flatten(),
            'attention_mask': inputs['attention_mask'].flatten(),
            'labels': inputs['input_ids'].flatten()
        }
61
+
62
class SimpleTransformerModel(nn.Module):
    """Small encoder-only transformer with an LM head over the vocabulary.

    forward() takes (batch, seq) token ids and an optional (batch, seq)
    HuggingFace-style attention mask (1 = real token, 0 = padding) and
    returns (batch, seq, vocab_size) logits.
    """

    def __init__(self, vocab_size, d_model=256, nhead=4, num_layers=3):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model)
        # Fixed: batch_first=True — every caller in this file feeds
        # (batch, seq) tensors; the previous default expected (seq, batch).
        encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)
        self.fc = nn.Linear(d_model, vocab_size)

    def forward(self, x, mask=None):
        h = self.embedding(x)
        h = self.pos_encoder(h)
        # Fixed: the (batch, seq) attention mask was previously passed as
        # `src_mask`, which expects an (S, S) attention matrix. It is a
        # padding mask: src_key_padding_mask wants True where positions
        # should be IGNORED, i.e. where the HF mask is 0.
        key_padding = (mask == 0) if mask is not None else None
        h = self.transformer(h, src_key_padding_mask=key_padding)
        return self.fc(h)
76
+
77
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding (Vaswani et al., 2017).

    Precomputes a (max_len, d_model) table `pe`; forward() adds the first
    seq_len rows to a (batch, seq_len, d_model) input via broadcasting.
    """

    def __init__(self, d_model, max_len=500):
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1)
        # Fixed: the original line was missing its closing parenthesis,
        # which made the whole module a SyntaxError.
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)  # even dims: sin
        pe[:, 1::2] = torch.cos(position * div_term)  # odd dims: cos
        # Buffer (not a Parameter): moves with .to(device), not trained.
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: (batch, seq, d_model); pe[:seq] is (seq, d_model) and
        # broadcasts across the batch dimension.
        return x + self.pe[:x.size(1)]
89
+
90
class VoiceInterface:
    """Thin voice front-end: microphone speech-to-text in, TTS speech out."""

    def __init__(self):
        self.recognizer = sr.Recognizer()
        self.engine = pyttsx3.init()

    def listen(self):
        """Capture one utterance from the default microphone.

        Returns the recognized text, or None if recognition fails.
        """
        with sr.Microphone() as source:
            print("Listening...")
            captured = self.recognizer.listen(source)
            try:
                spoken = self.recognizer.recognize_google(captured)
            except Exception as e:
                print(f"Error recognizing speech: {e}")
                return None
            print(f"You said: {spoken}")
            return spoken

    def speak(self, text):
        """Print the reply, then speak it aloud (blocks until done)."""
        print(f"Bot: {text}")
        self.engine.say(text)
        self.engine.runAndWait()
111
+
112
def train_model(model, dataloader, epochs=3, lr=3e-4):
    """Train `model` on `dataloader` with Adam and cross-entropy.

    Positions whose label is 0 (the pad id) are excluded from the loss.
    Shows a tqdm progress bar per epoch and prints the epoch's average loss.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    loss_fn = nn.CrossEntropyLoss(ignore_index=0)  # ignore pad targets
    opt = optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epochs):
        model.train()
        running = 0.0
        progress = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")

        for batch in progress:
            ids = batch['input_ids'].to(device)
            attn = batch['attention_mask'].to(device)
            targets = batch['labels'].to(device)

            opt.zero_grad()
            logits = model(ids, attn)
            # Flatten (batch, seq, vocab) -> (batch*seq, vocab) for CE.
            batch_loss = loss_fn(logits.view(-1, logits.size(-1)), targets.view(-1))
            batch_loss.backward()
            opt.step()

            running += batch_loss.item()
            progress.set_postfix({'loss': batch_loss.item()})

        print(f"Epoch {epoch+1} - Avg loss: {running/len(dataloader):.4f}")
138
+
139
def generate_response(model, tokenizer, prompt, max_length=50, voice_interface=None):
    """Generate a text reply to `prompt`, optionally speaking it aloud.

    Fixed: the original unconditionally called ``model.generate(...)``, but
    ``SimpleTransformerModel`` defines no such method, so this always raised
    AttributeError. We now use ``generate`` when the model provides one
    (e.g. HF models) and otherwise fall back to a manual autoregressive
    top-k / top-p sampling loop over the model's forward pass, with the same
    sampling parameters (top_k=50, top_p=0.95, temperature=0.7).

    Also fixed: the prompt is no longer padded to max_length before
    generation — right-padding a prompt makes the model continue after the
    pad tokens, corrupting the output.

    Returns the decoded response string.
    """
    device = next(model.parameters()).device
    model.eval()

    encoded = tokenizer(prompt, return_tensors="pt", max_length=128, truncation=True)
    input_ids = encoded['input_ids'].to(device)
    attention_mask = encoded['attention_mask'].to(device)

    with torch.no_grad():
        if hasattr(model, 'generate'):
            output_ids = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_length=max_length,
                do_sample=True,
                top_k=50,
                top_p=0.95,
                temperature=0.7
            )[0]
        else:
            output_ids = _sample_tokens(model, input_ids, max_length)[0]

    response = tokenizer.decode(output_ids, skip_special_tokens=True)

    if voice_interface:
        voice_interface.speak(response)

    return response


def _sample_tokens(model, input_ids, max_length, temperature=0.7, top_k=50, top_p=0.95):
    """Autoregressive sampling: append one top-k/top-p sampled token per step.

    `max_length` counts the prompt tokens, matching HF generate() semantics.
    Returns the (1, total_len) tensor of prompt + sampled ids.
    """
    ids = input_ids
    while ids.size(1) < max_length:
        logits = model(ids)[:, -1, :] / temperature
        # Top-k: drop everything below the k-th largest logit.
        if top_k:
            kth = torch.topk(logits, min(top_k, logits.size(-1)))[0][..., -1, None]
            logits = logits.masked_fill(logits < kth, float('-inf'))
        # Top-p (nucleus): keep the smallest prefix of sorted probs whose
        # cumulative mass exceeds top_p, then renormalize and sample.
        probs = torch.softmax(logits, dim=-1)
        sorted_probs, sorted_idx = torch.sort(probs, descending=True)
        cumulative = torch.cumsum(sorted_probs, dim=-1)
        sorted_probs = sorted_probs.masked_fill(cumulative - sorted_probs > top_p, 0.0)
        sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
        choice = torch.multinomial(sorted_probs, 1)
        next_id = sorted_idx.gather(-1, choice)
        ids = torch.cat([ids, next_id], dim=1)
    return ids