bhanumitt committed
Commit b55e0b8 · 1 Parent(s): 0a6b400

Edited requirements.txt

Files changed (5):
  1. README.md +1 -1
  2. app_faster.py +0 -216
  3. app_local.py +0 -179
  4. app_streamlit.py +0 -243
  5. requirements.txt +7 -7
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: "🏀"
  colorFrom: "blue"
  colorTo: "yellow"
  sdk: "gradio" # Change to "streamlit" if you're using Streamlit instead of Gradio
- sdk_version: "4.19.2" # Replace with the Gradio version you are using
+ sdk_version: "5.4.0" # Replace with the Gradio version you are using
  app_file: app.py # Update this to the main file of your app if it’s named differently
  pinned: false
  ---
app_faster.py DELETED
@@ -1,216 +0,0 @@
- import os
- import torch
- import torch.nn.functional as F
- from sentence_transformers import SentenceTransformer
- import pickle
- import streamlit as st
- import gdown
- import requests
- from llama_cpp import Llama
- from tqdm import tqdm
- import time
- from functools import lru_cache
-
- def download_file_with_progress(url: str, filename: str):
-     """Download a file with progress bar using requests"""
-     response = requests.get(url, stream=True)
-     total_size = int(response.headers.get('content-length', 0))
-
-     with open(filename, 'wb') as file, tqdm(
-         desc=filename,
-         total=total_size,
-         unit='iB',
-         unit_scale=True,
-         unit_divisor=1024,
-     ) as progress_bar:
-         for data in response.iter_content(chunk_size=1024):
-             size = file.write(data)
-             progress_bar.update(size)
-
- @st.cache_resource
- def initialize_llama_model(model_path: str):
-     """Initialize model with CPU-optimized settings"""
-     try:
-         if not os.path.exists(model_path):
-             direct_url = "https://huggingface.co/TheBloke/Mistral-7B-v0.1-GGUF/resolve/main/mistral-7b-v0.1.Q4_K_M.gguf"
-             download_file_with_progress(direct_url, model_path)
-
-         if os.path.getsize(model_path) < 1000000:
-             raise ValueError("Model file appears corrupted")
-
-         # CPU-optimized settings
-         llm = Llama(
-             model_path=model_path,
-             n_ctx=512,  # Reduced context window
-             n_threads=8,  # Optimal for most CPUs
-             n_batch=8,  # Small batch size for CPU
-             n_gpu_layers=0,  # CPU only
-             verbose=False,
-             rope_freq_scale=0.5,
-             seed=42
-         )
-         return llm
-
-     except Exception as e:
-         st.error(f"Error initializing model: {str(e)}")
-         raise
-
- class SentenceTransformerRetriever:
-     def __init__(self, model_name: str = "sentence-transformers/all-MiniLM-L6-v2", cache_dir: str = "embeddings_cache"):
-         self.device = torch.device("cpu")
-         self.model = SentenceTransformer(model_name, device=str(self.device))
-         self.doc_embeddings = None
-         self.cache_dir = cache_dir
-         os.makedirs(cache_dir, exist_ok=True)
-
-     def load_specific_cache(self, cache_filename: str, drive_link: str) -> dict:
-         cache_path = os.path.join(self.cache_dir, cache_filename)
-
-         if not os.path.exists(cache_path):
-             print(f"Cache file not found. Downloading from Google Drive...")
-             try:
-                 gdown.download(drive_link, cache_path, quiet=False)
-             except Exception as e:
-                 raise Exception(f"Failed to download cache file: {str(e)}")
-
-         print(f"Loading cache from: {cache_path}")
-         with open(cache_path, 'rb') as f:
-             return pickle.load(f)
-
-     @lru_cache(maxsize=128)
-     def encode(self, text: str) -> torch.Tensor:
-         embeddings = self.model.encode([text], convert_to_tensor=True, show_progress_bar=False)
-         return F.normalize(embeddings, p=2, dim=1)
-
-     def store_embeddings(self, embeddings: torch.Tensor):
-         self.doc_embeddings = embeddings
-
-     def search(self, query_embedding: torch.Tensor, k: int):
-         if self.doc_embeddings is None:
-             raise ValueError("No document embeddings stored!")
-
-         similarities = F.cosine_similarity(query_embedding, self.doc_embeddings)
-         scores, indices = torch.topk(similarities, k=min(k, similarities.shape[0]))
-
-         return indices.cpu(), scores.cpu()
-
- class RAGPipeline:
-     def __init__(self, cache_filename: str, cache_drive_link: str, k: int = 3):
-         self.cache_filename = cache_filename
-         self.cache_drive_link = cache_drive_link
-         self.k = k
-         self.retriever = SentenceTransformerRetriever()
-         self.documents = []
-         self.model_path = "mistral-7b-v0.1.Q4_K_M.gguf"
-         # Initialize model using cached function
-         self.llm = initialize_llama_model(self.model_path)
-
-         # Performance metrics
-         self.retrieval_time = 0
-         self.inference_time = 0
-
-     def load_cached_embeddings(self):
-         try:
-             cache_data = self.retriever.load_specific_cache(self.cache_filename, self.cache_drive_link)
-             self.documents = cache_data['documents']
-             self.retriever.store_embeddings(cache_data['embeddings'])
-             return True
-         except Exception as e:
-             st.error(f"Error loading cache: {str(e)}")
-             return False
-
-     def _get_response(self, prompt_str: str) -> tuple[str, float]:
-         """Generate response with timing"""
-         start_time = time.time()
-         response = self.llm(
-             prompt_str,
-             max_tokens=256,  # Reduced max tokens
-             temperature=0.1,  # Lower temperature for faster inference
-             top_p=0.1,  # More focused sampling
-             echo=False,
-             stop=["Question:", "\n\n"],
-             top_k=10,
-             repeat_penalty=1.1,
-             stream=False
-         )
-         inference_time = time.time() - start_time
-         return response['choices'][0]['text'].strip(), inference_time
-
-     def process_query(self, query: str) -> tuple[str, dict]:
-         """Process query and return response with timing metrics"""
-         try:
-             # Measure retrieval time
-             retrieval_start = time.time()
-             query_embedding = self.retriever.encode(query)
-             indices, _ = self.retriever.search(query_embedding, self.k)
-
-             # Limit context size
-             relevant_docs = [self.documents[idx] for idx in indices.tolist()][:2]
-             context = " ".join(relevant_docs)[:800]  # Further reduced context length
-
-             # Simplified prompt
-             prompt = f"Context: {context}\nQ: {query}\nA:"
-
-             self.retrieval_time = time.time() - retrieval_start
-
-             # Get response with timing
-             response, self.inference_time = self._get_response(prompt)
-
-             # Prepare timing metrics
-             metrics = {
-                 "retrieval_time": f"{self.retrieval_time:.2f}s",
-                 "inference_time": f"{self.inference_time:.2f}s",
-                 "total_time": f"{(self.retrieval_time + self.inference_time):.2f}s"
-             }
-
-             return response, metrics
-
-         except Exception as e:
-             return f"An error occurred: {str(e)}", {"error": str(e)}
-
- def main():
-     st.set_page_config(page_title="Sport Chatbot", layout="wide")
-
-     if "rag" not in st.session_state:
-         with st.spinner("Loading model... (this may take a minute)"):
-             cache_filename = "embeddings_2296.pkl"
-             cache_drive_link = "https://drive.google.com/uc?id=1LuJdnwe99C0EgvJpyfHYCKzUvj94FWlC"
-             st.session_state.rag = RAGPipeline(cache_filename, cache_drive_link)
-             st.session_state.rag.load_cached_embeddings()
-
-     st.title("Sport Chatbot")
-
-     with st.container():
-         query = st.text_input("Enter your question:", key="query_input")
-
-     if query:
-         if "history" not in st.session_state:
-             st.session_state.history = []
-
-         with st.spinner("Processing..."):
-             response, metrics = st.session_state.rag.process_query(query)
-             st.session_state.history.append((query, response, metrics))
-
-         # Display metrics
-         col1, col2, col3 = st.columns(3)
-         with col1:
-             st.metric("Retrieval Time", metrics["retrieval_time"])
-         with col2:
-             st.metric("Inference Time", metrics["inference_time"])
-         with col3:
-             st.metric("Total Time", metrics["total_time"])
-
-         # Display response
-         st.write("### Answer:")
-         st.write(response)
-
-         # Display history (last 3 interactions)
-         if len(st.session_state.history) > 1:
-             st.write("### Previous Interactions:")
-             for q, r, m in list(st.session_state.history[:-1])[-3:]:
-                 with st.expander(f"Q: {q[:50]}..."):
-                     st.write(f"**A:** {r}")
-                     st.write(f"*Timing: Total {m['total_time']}*")
-
- if __name__ == "__main__":
-     main()
app_local.py DELETED
@@ -1,179 +0,0 @@
- import os
- import torch
- import torch.nn.functional as F
- from sentence_transformers import SentenceTransformer
- import pickle
- from llama_cpp import Llama
- import streamlit as st
-
- class SentenceTransformerRetriever:
-     def __init__(self, model_name: str = "sentence-transformers/all-MiniLM-L6-v2", cache_dir: str = "embeddings_cache"):
-         self.device = torch.device("cpu")
-         self.model = SentenceTransformer(model_name, device=str(self.device))
-         self.doc_embeddings = None
-         self.cache_dir = cache_dir
-
-     def load_specific_cache(self, cache_filename: str) -> dict:
-         cache_path = os.path.join(self.cache_dir, cache_filename)
-         if not os.path.exists(cache_path):
-             raise FileNotFoundError(f"Cache file not found at {cache_path}")
-
-         print(f"Loading cache from: {cache_path}")
-         with open(cache_path, 'rb') as f:
-             return pickle.load(f)
-
-     def encode(self, texts: list) -> torch.Tensor:
-         embeddings = self.model.encode(texts, convert_to_tensor=True, show_progress_bar=True)
-         return F.normalize(embeddings, p=2, dim=1)
-
-     def store_embeddings(self, embeddings: torch.Tensor):
-         self.doc_embeddings = embeddings
-
-     def search(self, query_embedding: torch.Tensor, k: int):
-         if self.doc_embeddings is None:
-             raise ValueError("No document embeddings stored!")
-
-         similarities = F.cosine_similarity(query_embedding, self.doc_embeddings)
-         scores, indices = torch.topk(similarities, k=min(k, similarities.shape[0]))
-
-         return indices.cpu(), scores.cpu()
-
- class RAGPipeline:
-     def __init__(self, cache_filename: str, k: int = 10):
-         self.cache_filename = cache_filename
-         self.k = k
-         self.retriever = SentenceTransformerRetriever()
-         self.documents = []
-
-         # Load the model
-         model_path = "mistral-7b-v0.1.Q4_K_M.gguf"
-         if not os.path.exists(model_path):
-             raise FileNotFoundError(f"Model file {model_path} not found!")
-
-         self.llm = Llama(
-             model_path=model_path,
-             n_ctx=4096,
-             n_gpu_layers=0,  # CPU only
-             verbose=False,
-         )
-
-     def load_cached_embeddings(self):
-         """Load embeddings from specific cache file"""
-         try:
-             cache_data = self.retriever.load_specific_cache(self.cache_filename)
-             self.documents = cache_data['documents']
-             self.retriever.store_embeddings(cache_data['embeddings'])
-             return True
-         except Exception as e:
-             st.error(f"Error loading cache: {str(e)}")
-             return False
-
-     def process_query(self, query: str) -> str:
-         MAX_ATTEMPTS = 5
-         SIMILARITY_THRESHOLD = 0.4
-
-         for attempt in range(MAX_ATTEMPTS):
-             try:
-                 print(f"\nAttempt {attempt + 1}/{MAX_ATTEMPTS}")
-
-                 # Get query embedding and search for relevant docs
-                 query_embedding = self.retriever.encode([query])
-                 indices, _ = self.retriever.search(query_embedding, self.k)
-
-                 relevant_docs = [self.documents[idx] for idx in indices.tolist()]
-                 context = "\n".join(relevant_docs)
-
-                 prompt = f"""Context information is below in backticks:
-
- ```
- {context}
- ```
-
- Given the context above, please answer the following question:
- {query}
-
- If you cannot answer it based on the context, please mention politely that you don't know the answer.
- Prefer to answer whatever information you can give to the user based on the context.
- Answer in a paragraph format.
- Answer using the information available in the context.
- Please don't repeat any part of this prompt in the answer. Feel free to use this information to improve the answer.
- Please avoid repetition.
-
- Answer:"""
-
-                 response = self.llm(
-                     prompt,
-                     max_tokens=1024,
-                     temperature=0.4,
-                     top_p=0.95,
-                     echo=False,
-                     stop=["Question:", "\n\n"]
-                 )
-
-                 answer = response['choices'][0]['text'].strip()
-
-                 # Check if response is empty or too short
-                 if not answer or len(answer) < 2:
-                     print(f"Got empty or too short response: '{answer}'. Retrying...")
-                     continue
-
-                 # Validate response relevance by comparing embeddings
-                 response_embedding = self.retriever.encode([answer])
-                 response_similarity = F.cosine_similarity(query_embedding, response_embedding)
-                 response_score = response_similarity.item()
-                 print(f"Response relevance score: {response_score:.3f}")
-
-                 if response_score < SIMILARITY_THRESHOLD:
-                     print(f"Response: {answer}. Response relevance {response_score:.3f} below threshold {SIMILARITY_THRESHOLD}. Retrying...")
-                     continue
-
-                 print(f"Successful response generated on attempt {attempt + 1}")
-                 return answer
-
-             except Exception as e:
-                 print(f"Error on attempt {attempt + 1}: {str(e)}")
-                 continue
-
-         return "I apologize, but after multiple attempts, I was unable to generate a satisfactory response. Please try rephrasing your question."
-
-
- @st.cache_resource
- def initialize_rag_pipeline(cache_filename: str):
-     """Initialize and load the RAG pipeline with cached embeddings"""
-     rag = RAGPipeline(cache_filename)
-     success = rag.load_cached_embeddings()
-     if not success:
-         st.error("Failed to load cached embeddings. Please check the cache file path.")
-         st.stop()
-     return rag
-
- def main():
-     st.title("The Sport Chatbot")
-     st.subheader("Using ESPN API")
-
-     st.write("Hey there! 👋 I can help you with information on Ice Hockey, Baseball, American Football, Soccer, and Basketball. With access to the ESPN API, I'm up to date with the latest details for these sports up until October 2024.")
-     st.write("Got any general questions? Feel free to ask—I'll do my best to provide answers based on the information I've been trained on!")
-
-     # Use the specific cache file we know exists
-     cache_filename = "embeddings_2296.pkl"
-
-     try:
-         rag = initialize_rag_pipeline(cache_filename)
-     except Exception as e:
-         st.error(f"Error initializing the application: {str(e)}")
-         st.stop()
-
-     # Query input
-     query = st.text_input("Enter your question:")
-
-     if st.button("Get Answer"):
-         if query:
-             with st.spinner("Searching for information..."):
-                 response = rag.process_query(query)
-             st.write("### Answer:")
-             st.write(response)
-         else:
-             st.warning("Please enter a question!")
-
- if __name__ == "__main__":
-     main()
app_streamlit.py DELETED
@@ -1,243 +0,0 @@
- import os
- import torch
- import torch.nn.functional as F
- from sentence_transformers import SentenceTransformer
- import pickle
- import streamlit as st
- import gdown
- import requests
- from llama_cpp import Llama
- from tqdm import tqdm
- import time
- from functools import lru_cache
-
- def download_file_with_progress(url: str, filename: str):
-     """Download a file with progress bar using requests"""
-     response = requests.get(url, stream=True)
-     total_size = int(response.headers.get('content-length', 0))
-
-     with open(filename, 'wb') as file, tqdm(
-         desc=filename,
-         total=total_size,
-         unit='iB',
-         unit_scale=True,
-         unit_divisor=1024,
-     ) as progress_bar:
-         for data in response.iter_content(chunk_size=1024):
-             size = file.write(data)
-             progress_bar.update(size)
-
- class SentenceTransformerRetriever:
-     def __init__(self, model_name: str = "sentence-transformers/all-MiniLM-L6-v2", cache_dir: str = "embeddings_cache"):
-         self.device = torch.device("cpu")
-         self.model = SentenceTransformer(model_name, device=str(self.device))
-         self.doc_embeddings = None
-         self.cache_dir = cache_dir
-         os.makedirs(cache_dir, exist_ok=True)
-
-     def load_specific_cache(self, cache_filename: str, drive_link: str) -> dict:
-         cache_path = os.path.join(self.cache_dir, cache_filename)
-
-         # Download cache file if it doesn't exist
-         if not os.path.exists(cache_path):
-             print(f"Cache file not found. Downloading from Google Drive...")
-             try:
-                 gdown.download(drive_link, cache_path, quiet=False)
-             except Exception as e:
-                 raise Exception(f"Failed to download cache file: {str(e)}")
-
-         if not os.path.exists(cache_path):
-             raise FileNotFoundError(f"Failed to download cache file to {cache_path}")
-
-         print(f"Loading cache from: {cache_path}")
-         with open(cache_path, 'rb') as f:
-             return pickle.load(f)
-
-     def encode(self, texts: list) -> torch.Tensor:
-         embeddings = self.model.encode(texts, convert_to_tensor=True, show_progress_bar=True)
-         return F.normalize(embeddings, p=2, dim=1)
-
-     def store_embeddings(self, embeddings: torch.Tensor):
-         self.doc_embeddings = embeddings
-
-     def search(self, query_embedding: torch.Tensor, k: int):
-         if self.doc_embeddings is None:
-             raise ValueError("No document embeddings stored!")
-
-         similarities = F.cosine_similarity(query_embedding, self.doc_embeddings)
-         scores, indices = torch.topk(similarities, k=min(k, similarities.shape[0]))
-
-         return indices.cpu(), scores.cpu()
-
- class RAGPipeline:
-     def __init__(self, cache_filename: str, cache_drive_link: str, k: int = 10):
-         self.cache_filename = cache_filename
-         self.cache_drive_link = cache_drive_link
-         self.k = k
-         self.retriever = SentenceTransformerRetriever()
-         self.documents = []
-
-         # Model configuration
-         self.model_path = "mistral-7b-v0.1.Q4_K_M.gguf"
-
-         # Initialize model
-         self._initialize_model()
-
-     def _initialize_model(self):
-         """Initialize the model with proper error handling and verification"""
-         try:
-             if not os.path.exists(self.model_path):
-                 direct_url = "https://huggingface.co/TheBloke/Mistral-7B-v0.1-GGUF/resolve/main/mistral-7b-v0.1.Q4_K_M.gguf"
-                 download_file_with_progress(direct_url, self.model_path)
-
-             # Verify file exists and has content
-             if not os.path.exists(self.model_path):
-                 raise FileNotFoundError(f"Model file {self.model_path} not found after download attempts")
-
-             if os.path.getsize(self.model_path) < 1000000:  # Less than 1MB
-                 os.remove(self.model_path)
-                 raise ValueError("Downloaded model file is too small, likely corrupted")
-
-             # Initialize the model
-             self.llm = Llama(
-                 model_path=self.model_path,
-                 n_ctx=2048,  # Reduced context window
-                 n_threads=8,  # Optimal for most CPUs
-                 n_batch=8,  # Small batch size for CPU
-                 n_gpu_layers=0,  # CPU only
-                 verbose=False,
-                 rope_freq_scale=0.5,
-                 seed=42
-             )
-             st.success("Model loaded successfully!")
-
-         except Exception as e:
-             st.error(f"Error initializing model: {str(e)}")
-             raise
-
-     def load_cached_embeddings(self):
-         """Load embeddings from Google Drive cache file"""
-         try:
-             cache_data = self.retriever.load_specific_cache(self.cache_filename, self.cache_drive_link)
-             self.documents = cache_data['documents']
-             self.retriever.store_embeddings(cache_data['embeddings'])
-             return True
-         except Exception as e:
-             st.error(f"Error loading cache: {str(e)}")
-             return False
-
-     def process_query(self, query: str) -> str:
-         MAX_ATTEMPTS = 5
-         SIMILARITY_THRESHOLD = 0.3
-
-         for attempt in range(MAX_ATTEMPTS):
-             try:
-                 print(f"\nAttempt {attempt + 1}/{MAX_ATTEMPTS}")
-
-                 # Get query embedding and search for relevant docs
-                 query_embedding = self.retriever.encode([query])
-                 indices, _ = self.retriever.search(query_embedding, self.k)
-
-                 relevant_docs = [self.documents[idx] for idx in indices.tolist()]
-                 context = "\n".join(relevant_docs)
-
-                 prompt = f"""Context information is below in backticks:
-
- ```
- {context}
- ```
-
- Given the context above, please answer the following question:
- {query}
-
- If you cannot answer based on the context, mention politely that you don't know.
- Answer in a paragraph format using only the context information.
-
- Please don't repeat any part of this prompt in the answer. Feel free to use this information to improve the answer.
- Please avoid repetition.
-
- Answer:"""
-
-                 start_time = time.time()
-
-                 response = self.llm(
-                     prompt,
-                     max_tokens=512,  # Reduced max tokens
-                     temperature=0.1,  # Lower temperature for faster inference
-                     top_p=0.1,  # More focused sampling
-                     echo=False,
-                     stop=["Question:", "\n\n"],
-                     top_k=10,
-                     repeat_penalty=1.1,
-                     stream=False
-                 )
-
-                 print('Inference Time:', time.time() - start_time)
-
-                 answer = response['choices'][0]['text'].strip()
-
-                 # Check if response is empty or too short
-                 if not answer or len(answer) < 2:
-                     print(f"Got empty or too short response: '{answer}'. Retrying...")
-                     continue
-
-                 # Validate response relevance by comparing embeddings
-                 response_embedding = self.retriever.encode([answer])
-                 response_similarity = F.cosine_similarity(query_embedding, response_embedding)
-                 response_score = response_similarity.item()
-                 print(f"Response relevance score: {response_score:.3f}")
-
-                 if response_score < SIMILARITY_THRESHOLD:
-                     print(f"Response: {answer}; Response relevance {response_score:.3f} below threshold {SIMILARITY_THRESHOLD}. Retrying...")
-                     continue
-
-                 print(f"Successful response generated on attempt {attempt + 1}")
-                 return answer
-
-             except Exception as e:
-                 print(f"Error on attempt {attempt + 1}: {str(e)}")
-                 continue
-
-         return "I apologize, but after multiple attempts, I was unable to generate a satisfactory response. Please try rephrasing your question."
-
- @st.cache_resource
- def initialize_rag_pipeline(cache_filename: str, cache_drive_link: str):
-     """Initialize and load the RAG pipeline with cached embeddings from Google Drive"""
-     rag = RAGPipeline(cache_filename, cache_drive_link)
-     success = rag.load_cached_embeddings()
-     if not success:
-         st.error("Failed to load cached embeddings. Please check the cache file and drive link.")
-         st.stop()
-     return rag
-
- def main():
-     st.title("The Sport Chatbot")
-     st.subheader("Using ESPN API")
-
-     st.write("Hey there! 👋 I can help you with information on Ice Hockey, Baseball, American Football, Soccer, and Basketball. With access to the ESPN API, I'm up to date with the latest details for these sports up until October 2024.")
-     st.write("Got any general questions? Feel free to ask—I'll do my best to provide answers based on the information I've been trained on!")
-
-     # Cache file details
-     cache_filename = "embeddings_2296.pkl"
-     cache_drive_link = "https://drive.google.com/uc?id=1LuJdnwe99C0EgvJpyfHYCKzUvj94FWlC"  # Replace with your cache file's Google Drive link
-
-     try:
-         rag = initialize_rag_pipeline(cache_filename, cache_drive_link)
-     except Exception as e:
-         st.error(f"Error initializing the application: {str(e)}")
-         st.stop()
-
-     # Query input
-     query = st.text_input("Enter your question:")
-
-     if st.button("Get Answer"):
-         if query:
-             with st.spinner("Searching for information..."):
-                 response = rag.process_query(query)
-             st.write("### Answer:")
-             st.write(response)
-         else:
-             st.warning("Please enter a question!")
-
- if __name__ == "__main__":
-     main()
requirements.txt CHANGED
@@ -1,7 +1,7 @@
- gradio==4.19.2
- torch>=2.0.0
- sentence-transformers==2.5.1
- llama-cpp-python==0.2.56
- gdown==5.1.0
- tqdm==4.66.2
- requests==2.31.0
+ gdown==5.2.0
+ gradio==5.4.0
+ Requests==2.32.3
+ sentence_transformers==3.2.1
+ llama_cpp_python==0.3.1
+ torch==2.5.0
+ tqdm==4.66.5
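
For reference, a minimal sketch (not part of this commit) of how one might verify that a local environment matches the new pins above. It assumes a recent Python whose importlib.metadata normalizes distribution names, and it takes the package names verbatim from requirements.txt.

# Hypothetical check: compare installed package versions against the pins above.
from importlib.metadata import PackageNotFoundError, version

PINS = {
    "gdown": "5.2.0",
    "gradio": "5.4.0",
    "Requests": "2.32.3",
    "sentence_transformers": "3.2.1",
    "llama_cpp_python": "0.3.1",
    "torch": "2.5.0",
    "tqdm": "4.66.5",
}

for name, expected in PINS.items():
    try:
        installed = version(name)
        status = "ok" if installed == expected else f"mismatch (installed {installed})"
    except PackageNotFoundError:
        status = "not installed"
    print(f"{name}=={expected}: {status}")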