import os
import pickle
import time

import beam  # Beam SDK; required for the beam.init() deployment call below
import gdown
import gradio as gr
import requests
import torch
import torch.nn.functional as F
from llama_cpp import Llama
from sentence_transformers import SentenceTransformer
from tqdm import tqdm

# Beam deployment configuration
beam.init(
    name="sports-chatbot",
    python_packages=[
        "torch",
        "sentence-transformers",
        "llama-cpp-python",
        "gradio",
        "gdown",
        "tqdm",
    ],
)

def initialize_llm():
    """Download the quantized Mistral 7B GGUF model if needed and load it via llama.cpp."""
    model_path = "models/mistral-7b-v0.1.Q4_K_M.gguf"
    if not os.path.exists(model_path):
        os.makedirs(os.path.dirname(model_path), exist_ok=True)
        direct_url = "https://huggingface.co/TheBloke/Mistral-7B-v0.1-GGUF/resolve/main/mistral-7b-v0.1.Q4_K_M.gguf"
        download_file_with_progress(direct_url, model_path)
    return Llama(
        model_path=model_path,
        n_ctx=2048,
        n_threads=8,
        n_batch=8,
        n_gpu_layers=0,  # CPU only
        verbose=False,
        rope_freq_scale=0.5,
        seed=42,
    )

def download_file_with_progress(url: str, filename: str):
    """Download a file with a progress bar using requests."""
    response = requests.get(url, stream=True)
    response.raise_for_status()  # fail early on a bad URL instead of writing an error page to disk
    total_size = int(response.headers.get('content-length', 0))
    with open(filename, 'wb') as file, tqdm(
        desc=filename,
        total=total_size,
        unit='iB',
        unit_scale=True,
        unit_divisor=1024,
    ) as progress_bar:
        for data in response.iter_content(chunk_size=1024):
            size = file.write(data)
            progress_bar.update(size)

class SentenceTransformerRetriever:
    """Wraps a sentence-transformers model for encoding text and cosine-similarity search."""

    def __init__(self, model_name: str = "sentence-transformers/all-MiniLM-L6-v2", cache_dir: str = "embeddings_cache"):
        self.device = torch.device("cpu")
        self.model = SentenceTransformer(model_name, device=str(self.device))
        self.doc_embeddings = None
        self.cache_dir = cache_dir
        os.makedirs(cache_dir, exist_ok=True)

    def load_specific_cache(self, cache_filename: str, drive_link: str) -> dict:
        """Load a pickled embeddings cache, downloading it from Google Drive if missing."""
        cache_path = os.path.join(self.cache_dir, cache_filename)
        if not os.path.exists(cache_path):
            print("Cache file not found. Downloading from Google Drive...")
            try:
                gdown.download(drive_link, cache_path, quiet=False)
            except Exception as e:
                raise Exception(f"Failed to download cache file: {str(e)}")
        if not os.path.exists(cache_path):
            raise FileNotFoundError(f"Failed to download cache file to {cache_path}")
        print(f"Loading cache from: {cache_path}")
        with open(cache_path, 'rb') as f:
            return pickle.load(f)

    def encode(self, texts: list) -> torch.Tensor:
        embeddings = self.model.encode(texts, convert_to_tensor=True, show_progress_bar=True)
        return F.normalize(embeddings, p=2, dim=1)  # L2-normalize so dot product equals cosine similarity

    def store_embeddings(self, embeddings: torch.Tensor):
        self.doc_embeddings = embeddings

    def search(self, query_embedding: torch.Tensor, k: int):
        if self.doc_embeddings is None:
            raise ValueError("No document embeddings stored!")
        similarities = F.cosine_similarity(query_embedding, self.doc_embeddings)
        scores, indices = torch.topk(similarities, k=min(k, similarities.shape[0]))
        return indices.cpu(), scores.cpu()

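# A minimal sketch of how a compatible embeddings cache could be built. This is
# an assumption based on the dict keys process_query() reads below; the script
# that actually produced embeddings_2296.pkl is not part of this file.
def build_embeddings_cache(documents: list, cache_path: str):
    """Illustrative helper: pickle a cache in the {'embeddings', 'documents'} layout."""
    retriever = SentenceTransformerRetriever()
    cache = {"embeddings": retriever.encode(documents), "documents": documents}
    with open(cache_path, "wb") as f:
        pickle.dump(cache, f)
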
def process_query(query: str, cache_data: dict, llm) -> str:
    """Retrieve the most relevant documents for the query and generate an answer,
    retrying when the model returns an empty or off-topic response."""
    MAX_ATTEMPTS = 5
    SIMILARITY_THRESHOLD = 0.3
    # Note: this builds a fresh retriever (and reloads the embedding model) per call.
    retriever = SentenceTransformerRetriever()
    retriever.store_embeddings(cache_data['embeddings'])
    documents = cache_data['documents']
    for attempt in range(MAX_ATTEMPTS):
        try:
            print(f"\nAttempt {attempt + 1}/{MAX_ATTEMPTS}")
            query_embedding = retriever.encode([query])
            indices, _ = retriever.search(query_embedding, k=10)
            relevant_docs = [documents[idx] for idx in indices.tolist()]
            context = "\n".join(relevant_docs)
            prompt = f"""Context information is below in backticks:
```
{context}
```
Given the context above, please answer the following question:
{query}
If you cannot answer based on the context, say politely that you don't know.
Answer in paragraph form using only the context information.
Do not repeat any part of this prompt in the answer, and avoid repeating yourself.
Answer:"""
            start_time = time.time()
            response = llm(
                prompt,
                max_tokens=512,
                temperature=0.1,
                top_p=0.1,
                echo=False,
                stop=["Question:", "\n\n"],
                top_k=10,
                repeat_penalty=1.1,
                stream=False,
            )
            print('Inference Time:', time.time() - start_time)
            answer = response['choices'][0]['text'].strip()
            if not answer or len(answer) < 2:
                print(f"Got an empty or too-short response: '{answer}'. Retrying...")
                continue
            # Sanity-check relevance: compare the answer's embedding to the query's.
            response_embedding = retriever.encode([answer])
            response_score = F.cosine_similarity(query_embedding, response_embedding).item()
            print(f"Response relevance score: {response_score:.3f}")
            if response_score < SIMILARITY_THRESHOLD:
                print(f"Response relevance {response_score:.3f} is below the {SIMILARITY_THRESHOLD} threshold. Retrying...")
                continue
            print(f"Successful response generated on attempt {attempt + 1}")
            return answer
        except Exception as e:
            print(f"Error on attempt {attempt + 1}: {str(e)}")
            continue
    return "I apologize, but after multiple attempts, I was unable to generate a satisfactory response. Please try rephrasing your question."

class SportsChatbot:
    def __init__(self):
        self.cache_filename = "embeddings_2296.pkl"
        self.cache_drive_link = "https://drive.google.com/uc?id=1LuJdnwe99C0EgvJpyfHYCKzUvj94FWlC"
        self.cache_data = None
        self.llm = None
        self.initialize_pipeline()

    def initialize_pipeline(self):
        try:
            # Initialize retriever and load cache
            retriever = SentenceTransformerRetriever()
            self.cache_data = retriever.load_specific_cache(self.cache_filename, self.cache_drive_link)
            # Initialize LLM on Beam
            self.llm = initialize_llm()
        except Exception as e:
            raise Exception(f"Error initializing the application: {str(e)}")

    def process_question(self, question: str, progress=gr.Progress()):
        if not question.strip():
            return "Please enter a question!"
        progress(0, desc="Processing query...")
        return process_query(question, self.cache_data, self.llm)

def create_demo():
    chatbot = SportsChatbot()
    with gr.Blocks(title="The Sport Chatbot") as demo:
        gr.Markdown("# The Sport Chatbot")
        gr.Markdown("### Using the ESPN API")
        gr.Markdown("Hey there! 👋 I can help you with information on Ice Hockey, Baseball, American Football, Soccer, and Basketball. With access to the ESPN API, I'm up to date with the latest details for these sports through October 2024.")
        gr.Markdown("Got any general questions? Feel free to ask, and I'll do my best to answer based on the information I've been trained on!")
        with gr.Row():
            question_input = gr.Textbox(
                label="Enter your question:",
                placeholder="Type your sports-related question here...",
                lines=2,
            )
        with gr.Row():
            submit_btn = gr.Button("Get Answer", variant="primary")
        with gr.Row():
            answer_output = gr.Markdown(label="Answer")
        submit_btn.click(
            fn=chatbot.process_question,
            inputs=question_input,
            outputs=answer_output,
            api_name="answer_question",
        )
        gr.Examples(
            examples=[
                "Who won the NBA championship in 2023?",
                "What are the basic rules of ice hockey?",
                "Tell me about the NFL playoffs format.",
            ],
            inputs=question_input,
        )
    return demo

if __name__ == "__main__":
    demo = create_demo()
    demo.launch()
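# Running `python app.py` starts the Gradio UI on the default local address
# (http://127.0.0.1:7860). The first launch is slow: it downloads the ~4 GB
# GGUF model and the embeddings cache before serving requests.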