import os
import gradio as gr
import cohere
from typing import Generator, List
from langchain_chroma import Chroma
# from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.schema.document import Document
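
# The defaults of HFSpaceChatBot.__init__ are read from environment variables
# (EMBEDDING_MODEL, CO_API_KEY, DEVICE); on a Hugging Face Space these are
# typically set as Space variables/secrets.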


class HFSpaceChatBot:
    """
    A chatbot powered by Retrieval-Augmented Generation (RAG), designed to be
    deployed on the Hugging Face Spaces platform.
    """

    def __init__(self,
                 embedding_model_path: str,
                 vector_database_path: str,
                 top_k: int = 10,
                 embedding_model_name: str = os.getenv("EMBEDDING_MODEL"),
                 api_key: str = os.getenv("CO_API_KEY"),
                 device: str = os.getenv("DEVICE"),
                 system_prompt: str = "Answer the user's question",
                 **kwargs) -> None:
        """
        Constructor for the HFSpaceChatBot class.

        Args:
            embedding_model_path (str): The path to the embedding model cache.
            vector_database_path (str): The path to the vector database.
            top_k (int): The number of top documents to retrieve.
            embedding_model_name (str): The name of the embedding model.
            api_key (str): The API key for the Cohere API.
            device (str): The device to run the embedding model on.
            system_prompt (str): The system prompt for the chatbot.
            **kwargs: Additional keyword arguments passed to the Cohere API.
        """
        self.chat_history = []

        # Cohere client used for streaming chat completions.
        self.cclient = cohere.Client(api_key=api_key)

        # Sentence-embedding model, loaded from (or downloaded to) the local
        # cache folder and configured to return normalized embeddings.
        self.embedding_model = HuggingFaceEmbeddings(
            model_name=embedding_model_name,
            model_kwargs={"device": device},
            encode_kwargs={"normalize_embeddings": True},
            cache_folder=embedding_model_path
        )

        # Persisted Chroma vector store queried at answer time.
        self.vector_database = Chroma(
            persist_directory=vector_database_path,
            embedding_function=self.embedding_model
        )

        self.top_k = top_k
        self.system_prompt = system_prompt
        self.model_params = kwargs

    def _get_relevant_information(self,
                                  message: str) -> List[Document]:
        """
        Retrieve the most relevant documents from the vector database.

        Args:
            message (str): The message to search for.

        Returns:
            List[Document]: A list of relevant documents.
        """
        return self.vector_database.similarity_search(message, self.top_k)

    def _fetch_response(self,
                        message: str,
                        *args) -> Generator[str, None, None]:
        """
        Fetch the response from the Cohere API.

        Args:
            message (str): The message of the user.
            *args: Additional arguments passed by the Gradio interface
                (ignored).

        Returns:
            Generator[str, None, None]: A generator yielding the partial
                response as it is streamed.
        """
        # Retrieve the top-k documents and concatenate them into one context
        # block for the prompt.
        docs = self._get_relevant_information(message)
        relevant_information = "\n".join(doc.page_content for doc in docs)

        final_message = (
            f"{self.system_prompt}\n"
            f"With the help of the following context:\n"
            f"{relevant_information}\n"
            f"Answer the following question:\n{message}"
        )

        # Stream the answer; Gradio renders each yielded string as the
        # current state of the reply.
        response = self.cclient.chat_stream(
            message=final_message,
            chat_history=self.chat_history,
            **self.model_params
        )
        current_text = ""
        for event in response:
            if event.event_type == "text-generation":
                current_text += event.text
                yield current_text

        # Keep the exchange so follow-up questions have conversational context.
        self.chat_history.append({
            "role": "USER",
            "text": message
        })
        self.chat_history.append({
            "role": "CHATBOT",
            "text": current_text
        })

    def launch(self,
               title: str,
               description: str) -> None:
        """
        Launch the Gradio chat interface.

        Args:
            title (str): The title of the chat interface.
            description (str): The description of the chat interface.
        """
        # ChatInterface accepts a generator function, which enables the
        # streamed display of the reply.
        gr.ChatInterface(
            fn=self._fetch_response,
            title=title,
            description=description
        ).launch()
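

# On Hugging Face Spaces this file is executed when the Space starts, so the
# chatbot below is built and launched at module level rather than behind a
# __main__ guard.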
# if __name__ == "__main__":

embedding_model_path = os.path.join(os.getcwd(), "model")

system_prompt = """You are now assuming the role of the personal assistant
of Ilan ALIOUCHOUCHE, a French Computer Science student.
Your task is to assist users by answering their
questions about Ilan. You have access to comprehensive
details about Ilan's education, skills, professional
experience, and interests. Don't be too chatty, and
make sure to provide accurate and relevant information.
"""

chatbot = HFSpaceChatBot(
    embedding_model_path=embedding_model_path,
    vector_database_path=os.path.join(os.getcwd(), "chromadb"),
    system_prompt=system_prompt,
    temperature=0.0001
)

title = "🤖 Ilan's Personal Agent 🤖"
description = """
You can ask my assistant (almost) anything about me! :D
You are currently using the Hugging Face Space version 🤗. A Docker image 🐳 for local use, utilizing a GGUF model, is also available [here](https://github.com/ilanaliouchouche/my-ai-cv/pkgs/container/my-cv)
"""  # noqa: E501

chatbot.launch(
    title=title,
    description=description
)
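
# A hedged sketch (not executed here): the persisted "chromadb" directory
# loaded above is assumed to have been built offline with the same embedding
# model, e.g. along these lines (names such as "cv_chunks" are illustrative
# assumptions, not part of this Space):
#
#     from langchain.schema.document import Document
#
#     cv_chunks = [Document(page_content="...")]  # chunks describing the CV
#     Chroma.from_documents(
#         documents=cv_chunks,
#         embedding=HuggingFaceEmbeddings(model_name=os.getenv("EMBEDDING_MODEL")),
#         persist_directory="chromadb"
#     )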