# sport-chatbot/app.py
import os
import torch
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer
import pickle
import gradio as gr
import gdown
import requests
from llama_cpp import Llama
from tqdm import tqdm
import time
import beam  # Beam cloud SDK; required for beam.init and beam.run_on_beam below
# Beam deployment configuration
beam.init(
name="sports-chatbot",
    python_packages=[
        "torch",
        "sentence-transformers",
        "llama-cpp-python",
        "gradio",
        "gdown",
        "tqdm",
        "requests"  # imported directly by the model download helper
    ]
)
@beam.run_on_beam(
memory="16Gi",
cpu="8",
python_version="3.10",
mount={
"embeddings": "./embeddings_cache",
"models": "./models"
}
)
def initialize_llm():
    """Ensure the quantized Mistral-7B GGUF model is on disk, then load it with llama.cpp."""
    from llama_cpp import Llama  # imported inside the function so it resolves in the remote Beam runtime
    model_path = "models/mistral-7b-v0.1.Q4_K_M.gguf"
    if not os.path.exists(model_path):
        os.makedirs(os.path.dirname(model_path), exist_ok=True)
        direct_url = "https://huggingface.co/TheBloke/Mistral-7B-v0.1-GGUF/resolve/main/mistral-7b-v0.1.Q4_K_M.gguf"
        download_file_with_progress(direct_url, model_path)
    return Llama(
        model_path=model_path,
        n_ctx=2048,          # context window in tokens
        n_threads=8,
        n_batch=8,
        n_gpu_layers=0,      # CPU only
        verbose=False,
        rope_freq_scale=0.5,
        seed=42              # fixed seed for reproducible sampling
    )
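# A minimal smoke test for the LLM setup (a sketch, left commented out so it
# never runs on import; it assumes the GGUF file is already downloaded):
#   llm = initialize_llm()
#   out = llm("Q: How many players does a basketball team field?\nA:", max_tokens=16)
#   print(out["choices"][0]["text"])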
def download_file_with_progress(url: str, filename: str):
"""Download a file with progress bar using requests"""
    response = requests.get(url, stream=True, timeout=60)
    response.raise_for_status()  # fail early rather than writing an HTML error page to disk
total_size = int(response.headers.get('content-length', 0))
with open(filename, 'wb') as file, tqdm(
desc=filename,
total=total_size,
unit='iB',
unit_scale=True,
unit_divisor=1024,
) as progress_bar:
for data in response.iter_content(chunk_size=1024):
size = file.write(data)
progress_bar.update(size)
class SentenceTransformerRetriever:
def __init__(self, model_name: str = "sentence-transformers/all-MiniLM-L6-v2", cache_dir: str = "embeddings_cache"):
self.device = torch.device("cpu")
self.model = SentenceTransformer(model_name, device=str(self.device))
self.doc_embeddings = None
self.cache_dir = cache_dir
os.makedirs(cache_dir, exist_ok=True)
def load_specific_cache(self, cache_filename: str, drive_link: str) -> dict:
cache_path = os.path.join(self.cache_dir, cache_filename)
if not os.path.exists(cache_path):
print(f"Cache file not found. Downloading from Google Drive...")
try:
gdown.download(drive_link, cache_path, quiet=False)
except Exception as e:
raise Exception(f"Failed to download cache file: {str(e)}")
if not os.path.exists(cache_path):
raise FileNotFoundError(f"Failed to download cache file to {cache_path}")
print(f"Loading cache from: {cache_path}")
with open(cache_path, 'rb') as f:
return pickle.load(f)
def encode(self, texts: list) -> torch.Tensor:
embeddings = self.model.encode(texts, convert_to_tensor=True, show_progress_bar=True)
return F.normalize(embeddings, p=2, dim=1)
def store_embeddings(self, embeddings: torch.Tensor):
self.doc_embeddings = embeddings
def search(self, query_embedding: torch.Tensor, k: int):
if self.doc_embeddings is None:
raise ValueError("No document embeddings stored!")
similarities = F.cosine_similarity(query_embedding, self.doc_embeddings)
scores, indices = torch.topk(similarities, k=min(k, similarities.shape[0]))
return indices.cpu(), scores.cpu()
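# Hypothetical helper (not called by the app): a sketch of how a cache file such
# as embeddings_2296.pkl could be produced. The pickle layout matches what
# load_specific_cache and process_query expect: {'documents': ..., 'embeddings': ...}.
def build_embeddings_cache(documents: list, cache_path: str):
    retriever = SentenceTransformerRetriever()
    embeddings = retriever.encode(documents)  # normalized document embeddings
    with open(cache_path, 'wb') as f:
        pickle.dump({'documents': documents, 'embeddings': embeddings}, f)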
def process_query(query: str, cache_data: dict, llm) -> str:
MAX_ATTEMPTS = 5
SIMILARITY_THRESHOLD = 0.3
    # NOTE: this re-loads the embedding model on every query; fine for a demo,
    # but it could be hoisted into SportsChatbot and reused across calls.
    retriever = SentenceTransformerRetriever()
    retriever.store_embeddings(cache_data['embeddings'])
documents = cache_data['documents']
for attempt in range(MAX_ATTEMPTS):
try:
print(f"\nAttempt {attempt + 1}/{MAX_ATTEMPTS}")
query_embedding = retriever.encode([query])
indices, _ = retriever.search(query_embedding, k=10)
relevant_docs = [documents[idx] for idx in indices.tolist()]
context = "\n".join(relevant_docs)
prompt = f"""Context information is below in backticks:
```
{context}
```
Given the context above, please answer the following question:
{query}
If you cannot answer based on the context, politely say that you don't know.
Answer in a single paragraph, using only the context information.
Do not repeat any part of this prompt, and avoid repetition in the answer.
Answer:"""
start_time = time.time()
response = llm(
prompt,
max_tokens=512,
temperature=0.1,
top_p=0.1,
echo=False,
stop=["Question:", "\n\n"],
top_k=10,
repeat_penalty=1.1,
stream=False
)
print('Inference Time:', time.time() - start_time)
answer = response['choices'][0]['text'].strip()
if not answer or len(answer) < 2:
print(f"Got empty or too short response: '{answer}'. Retrying...")
continue
response_embedding = retriever.encode([answer])
response_similarity = F.cosine_similarity(query_embedding, response_embedding)
response_score = response_similarity.item()
print(f"Response relevance score: {response_score:.3f}")
if response_score < SIMILARITY_THRESHOLD:
print(f"Response relevance {response_score:.3f} below threshold {SIMILARITY_THRESHOLD}. Retrying...")
continue
print(f"Successful response generated on attempt {attempt + 1}")
return answer
except Exception as e:
print(f"Error on attempt {attempt + 1}: {str(e)}")
continue
return "I apologize, but after multiple attempts, I was unable to generate a satisfactory response. Please try rephrasing your question."
class SportsChatbot:
def __init__(self):
self.cache_filename = "embeddings_2296.pkl"
self.cache_drive_link = "https://drive.google.com/uc?id=1LuJdnwe99C0EgvJpyfHYCKzUvj94FWlC"
self.cache_data = None
self.llm = None
self.initialize_pipeline()
def initialize_pipeline(self):
try:
# Initialize retriever and load cache
retriever = SentenceTransformerRetriever()
self.cache_data = retriever.load_specific_cache(self.cache_filename, self.cache_drive_link)
# Initialize LLM on Beam
self.llm = initialize_llm()
        except Exception as e:
            raise RuntimeError(f"Error initializing the application: {e}") from e
def process_question(self, question: str, progress=gr.Progress()):
if not question.strip():
return "Please enter a question!"
progress(0, desc="Processing query...")
response = process_query(question, self.cache_data, self.llm)
return response
def create_demo():
chatbot = SportsChatbot()
with gr.Blocks(title="The Sport Chatbot") as demo:
gr.Markdown("# The Sport Chatbot")
gr.Markdown("### Using ESPN API")
gr.Markdown("Hey there! 👋 I can help you with information on Ice Hockey, Baseball, American Football, Soccer, and Basketball. With access to the ESPN API, I'm up to date with the latest details for these sports up until October 2024.")
gr.Markdown("Got any general questions? Feel free to ask—I'll do my best to provide answers based on the information I've been trained on!")
with gr.Row():
question_input = gr.Textbox(
label="Enter your question:",
placeholder="Type your sports-related question here...",
lines=2
)
with gr.Row():
submit_btn = gr.Button("Get Answer", variant="primary")
with gr.Row():
answer_output = gr.Markdown(label="Answer")
submit_btn.click(
fn=chatbot.process_question,
inputs=question_input,
outputs=answer_output,
api_name="answer_question"
)
gr.Examples(
examples=[
"Who won the NBA championship in 2023?",
"What are the basic rules of ice hockey?",
"Tell me about the NFL playoffs format.",
],
inputs=question_input,
)
return demo
if __name__ == "__main__":
demo = create_demo()
demo.launch()