import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
from transformers import AutoTokenizer
from tqdm import tqdm
import math
import speech_recognition as sr
import pyttsx3
from googlesearch import search
import warnings
from typing import List, Dict, Union

# Ignore warnings
warnings.filterwarnings("ignore")


class WebSearchWrapper:
    """Wrapper for web search with caching."""

    def __init__(self, cache_size: int = 100):
        self.cache: Dict[str, List[str]] = {}
        self.cache_size = cache_size

    def search(self, query: str, num_results: int = 3) -> List[str]:
        """Perform a web search, returning cached results when available."""
        if query.lower() in self.cache:
            return self.cache[query.lower()]
        try:
            # googlesearch-python's search() takes num_results; the older
            # `google` package used num/stop/pause instead. Mixing both
            # signatures raises a TypeError, so only num_results is passed.
            search_results = list(search(query, num_results=num_results))
            self._add_to_cache(query, search_results)
            return search_results
        except Exception as e:
            print(f"Web search error: {e}")
            return []

    def _add_to_cache(self, query: str, results: List[str]):
        """Add results to the cache, evicting the oldest entry (FIFO) when full."""
        if len(self.cache) >= self.cache_size:
            # dicts preserve insertion order, so this pops the oldest entry.
            self.cache.pop(next(iter(self.cache)))
        self.cache[query.lower()] = results


class FullChatDataset(Dataset):
    def __init__(self, dataset_names=["blended_skill_talk", "conv_ai_2", "social_i_qa"],
                 max_length=256):
        self.datasets = []
        for name in dataset_names:
            try:
                dataset = load_dataset(name, split="train")
                self.datasets.append(dataset)
            except Exception as e:
                print(f"Failed to load dataset {name}: {e}")
        self.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        # bert-base-uncased already defines [PAD] (id 0); this is a no-op safeguard.
        self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        self.max_length = max_length

    def __len__(self):
        return sum(len(d) for d in self.datasets)

    def __getitem__(self, idx):
        # Map the flat index onto the concatenated datasets.
        item = None
        for dataset in self.datasets:
            if idx < len(dataset):
                item = dataset[idx]
                break
            idx -= len(dataset)

        # The datasets expose dialog turns under different keys.
        if 'dialog' in item:
            dialog = item['dialog']
        elif 'messages' in item:
            dialog = [msg['text'] for msg in item['messages']]
        else:
            dialog = [v for k, v in item.items() if isinstance(v, str)]

        context = " [SEP] ".join(dialog[:-1])
        response = dialog[-1]

        inputs = self.tokenizer(
            context,
            text_pair=response,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors="pt"
        )
        return {
            'input_ids': inputs['input_ids'].flatten(),
            'attention_mask': inputs['attention_mask'].flatten(),
            # Labels mirror input_ids, so training below is a token-reconstruction
            # objective rather than shifted next-token prediction.
            'labels': inputs['input_ids'].flatten()
        }


class SimpleTransformerModel(nn.Module):
    def __init__(self, vocab_size, d_model=256, nhead=4, num_layers=3):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model)
        # batch_first=True keeps tensors (batch, seq, d_model), matching the
        # (batch, seq) token tensors produced by the tokenizer.
        encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)
        self.fc = nn.Linear(d_model, vocab_size)

    def forward(self, x, attention_mask=None):
        x = self.embedding(x)
        x = self.pos_encoder(x)
        # The tokenizer's attention_mask is 1 for real tokens, but
        # TransformerEncoder wants a key-padding mask that is True for
        # positions to ignore, hence the inversion.
        padding_mask = (attention_mask == 0) if attention_mask is not None else None
        x = self.transformer(x, src_key_padding_mask=padding_mask)
        return self.fc(x)

    @torch.no_grad()
    def generate(self, input_ids, attention_mask=None, max_length=100,
                 do_sample=True, top_k=50, temperature=0.7):
        """Minimal top-k sampling loop so generate_response() has something to
        call; the original code invoked a Hugging Face-style generate() that
        this custom module never defined. attention_mask is accepted for
        interface parity but unused, since generation prompts are unpadded."""
        generated = input_ids
        for _ in range(max_length):
            # Score the next token from the last position's logits.
            logits = self.forward(generated)[:, -1, :] / temperature
            if do_sample:
                top_logits, top_idx = logits.topk(top_k, dim=-1)
                probs = torch.softmax(top_logits, dim=-1)
                next_token = top_idx.gather(-1, torch.multinomial(probs, 1))
            else:
                next_token = logits.argmax(dim=-1, keepdim=True)
            generated = torch.cat([generated, next_token], dim=-1)
        return generated


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=500):
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x is (batch, seq, d_model); pe broadcasts across the batch dimension.
        return x + self.pe[:x.size(1)]


class VoiceInterface:
    def __init__(self):
        self.recognizer = sr.Recognizer()
        self.engine = pyttsx3.init()

    def listen(self) -> Union[str, None]:
        with sr.Microphone() as source:
print("Listening...") audio = self.recognizer.listen(source) try: text = self.recognizer.recognize_google(audio) print(f"You said: {text}") return text except Exception as e: print(f"Error recognizing speech: {e}") return None def speak(self, text: str): print(f"Bot: {text}") self.engine.say(text) self.engine.runAndWait() class ChatBot: def __init__(self): self.dataset = FullChatDataset() self.model = SimpleTransformerModel(len(self.dataset.tokenizer)) self.voice_interface = VoiceInterface() self.web_searcher = WebSearchWrapper() def train(self, epochs=3, lr=3e-4): device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') self.model = self.model.to(device) criterion = nn.CrossEntropyLoss(ignore_index=0) optimizer = optim.Adam(self.model.parameters(), lr=lr) dataloader = DataLoader(self.dataset, batch_size=8, shuffle=True) for epoch in range(epochs): self.model.train() total_loss = 0 pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}") for batch in pbar: inputs = batch['input_ids'].to(device) masks = batch['attention_mask'].to(device) labels = batch['labels'].to(device) optimizer.zero_grad() outputs = self.model(inputs, masks) loss = criterion(outputs.view(-1, outputs.size(-1)), labels.view(-1)) loss.backward() optimizer.step() total_loss += loss.item() pbar.set_postfix({'loss': loss.item()}) print(f"Epoch {epoch+1} - Avg loss: {total_loss/len(dataloader):.4f}") def generate_response(self, prompt: str, max_length: int = 100, use_web: bool = True) -> str: device = next(self.model.parameters()).device self.model.eval() # Add web context if needed if use_web and self._needs_web_search(prompt): web_results = self.web_searcher.search(prompt) if web_results: prompt = f"Web context: {', '.join(web_results[:3])}. User question: {prompt}" inputs = self.dataset.tokenizer( prompt, return_tensors="pt", max_length=256, truncation=True, padding='max_length' ).to(device) with torch.no_grad(): outputs = self.model.generate( input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'], max_length=max_length, do_sample=True, top_k=50, top_p=0.95, temperature=0.7 ) response = self.dataset.tokenizer.decode(outputs[0], skip_special_tokens=True) return response def _needs_web_search(self, text: str) -> bool: """Determine if a query needs web search""" question_words = ['what', 'when', 'where', 'who', 'why', 'how', 'which', '?'] return any(word in text.lower() for word in question_words)