# NOTE(review): removed stray deployment-log artifacts that were pasted at the
# top of the file ("Spaces:" / "Runtime error" lines) — they are not valid
# Python and would break the module at import time.
#
# Optional shim for hosts whose bundled sqlite3 is too old for Chroma;
# uncomment to swap in pysqlite3:
# __import__('pysqlite3')
# import sys
# sys.modules['sqlite3'] = sys.modules.pop('pysqlite3')
import os

import chromadb
import streamlit as st
from chromadb.config import Settings
from langchain.callbacks.base import BaseCallbackHandler
from langchain.chains import RetrievalQA
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import Ollama
from langchain.vectorstores import Chroma

from constants import CHROMA_SETTINGS
class StreamHandler(BaseCallbackHandler):
    """LangChain callback that streams LLM tokens into a Streamlit container.

    Each new token is appended to the accumulated text and the container is
    re-rendered, so the answer appears to "type" itself live in the UI.
    """

    def __init__(self, container, initial_text=""):
        """
        Args:
            container: a Streamlit element (e.g. ``st.empty()``) exposing
                ``.markdown()``.
            initial_text: text already present before streaming starts.
        """
        self.container = container
        self.text = initial_text

    def on_llm_new_token(self, token: str, **kwargs) -> None:
        """Append *token* to the buffer and refresh the rendered markdown."""
        # Fix: the original did `self.text += token + ""` — concatenating an
        # empty string is a no-op, so append the token directly.
        self.text += token
        self.container.markdown(self.text)
def setup_page():
    """Configure the Streamlit page: wide layout, title, attribution, divider."""
    st.set_page_config(layout="wide")
    st.markdown(
        "<h2 style='text-align: center; color: white;'>Your Personal MBA </h2>",
        unsafe_allow_html=True,
    )
    url = 'https://personalmba.com/'
    # Centre the attribution by placing it in the middle of three columns.
    _, middle, _ = st.columns(3)
    with middle:
        st.markdown("""
            <div style="text-align: center;">
                <h5 style='color: white;'>Inspired by </h5>
                <a href="%s">The Personal MBA by Josh Kaufman</a>
            </div>
            """ % url, unsafe_allow_html=True)
    st.divider()
def get_environment_variables():
    """Read runtime configuration from the environment.

    Returns:
        tuple: ``(model, embeddings_model_name, persist_directory,
        target_source_chunks)`` with defaults suitable for a local setup.
    """
    env = os.environ.get
    model = env("MODEL", "mistral")
    embeddings_model_name = env("EMBEDDINGS_MODEL_NAME", "all-MiniLM-L6-v2")
    persist_directory = env("PERSIST_DIRECTORY", "db")
    # TARGET_SOURCE_CHUNKS arrives as a string from the environment; the
    # retriever needs an int.
    target_source_chunks = int(env('TARGET_SOURCE_CHUNKS', 4))
    return model, embeddings_model_name, persist_directory, target_source_chunks
def create_knowledge_base(embeddings_model_name, persist_directory, target_source_chunks):
    """Build a retriever over the persisted Chroma vector store.

    Args:
        embeddings_model_name: HuggingFace sentence-embeddings model name.
        persist_directory: directory where the Chroma DB is persisted.
        target_source_chunks: number of chunks (k) retrieved per query.

    Returns:
        A LangChain retriever configured for top-k similarity search.
    """
    embedding_fn = HuggingFaceEmbeddings(model_name=embeddings_model_name)
    vector_store = Chroma(
        persist_directory=persist_directory,
        embedding_function=embedding_fn,
    )
    return vector_store.as_retriever(search_kwargs={"k": target_source_chunks})
def handle_query(query, model, retriever):
    """Answer *query* with a RAG chain, streaming tokens into the chat UI.

    Args:
        query: the user's question.
        model: Ollama model name to generate with.
        retriever: vector-store retriever supplying context chunks.

    Returns:
        str: the model's final answer text.
    """
    with st.chat_message('assistant'):
        with st.spinner("Generating answer..."):
            # Placeholder that the stream handler progressively fills.
            placeholder = st.empty()
            streamer = StreamHandler(placeholder)
            llm = Ollama(model=model, callbacks=[streamer])
            chain = RetrievalQA.from_chain_type(
                llm=llm,
                chain_type="stuff",
                retriever=retriever,
                return_source_documents=False,
            )
            answer = chain(query)['result']
            # Render the final answer once streaming is done.
            placeholder.markdown(answer)
    return answer
def initialize_session():
    """Ensure a chat-history list exists in Streamlit session state.

    The list acts as the conversation 'memory' displayed across reruns.
    """
    if "messages" not in st.session_state:
        st.session_state["messages"] = []
def display_messages():
    """Replay every stored chat message inside its role's chat bubble."""
    for entry in st.session_state.messages:
        role, content = entry['role'], entry['content']
        with st.chat_message(role):
            st.markdown(content)
def show_examples():
    """Show starter example questions when the app first loads.

    Returns:
        The ``st.empty()`` placeholder wrapping the examples, so the caller
        can clear them with ``.empty()`` once the user sends a first query.
    """
    examples = st.empty()
    with examples.container():
        with st.chat_message('assistant'):
            st.markdown('Example questions:')
            st.markdown(' - How do I know that I am making the right decisions?')
            st.markdown(' - What are the key ideas in Chapter 6 "The Human Mind"?')
            # Fix: corrected the typo "sucessful" -> "successful" in this
            # user-facing string.
            st.markdown(' - What are common traits shared by the most successful individuals in the world?')
            st.markdown(' - I want to be a millionaire, build me a 5 year roadmap based on the top 0.01 percent of the human population.')
            st.markdown('So, how may I help you today?')
    return examples
def main():
    """Wire the UI together and process one chat turn per Streamlit rerun."""
    setup_page()
    initialize_session()
    display_messages()
    examples = show_examples()

    model, embeddings_model_name, persist_directory, target_source_chunks = (
        get_environment_variables()
    )
    retriever = create_knowledge_base(
        embeddings_model_name, persist_directory, target_source_chunks
    )

    query = st.chat_input(placeholder='Ask a question...')
    if not query:
        # No question submitted yet on this rerun; leave the examples visible.
        return

    # A real question arrived: hide the starter examples and record the turn.
    examples.empty()
    st.session_state.messages.append({'role': 'user', 'content': query})
    with st.chat_message('user'):
        st.markdown(query)
    answer = handle_query(query, model, retriever)
    st.session_state.messages.append({'role': 'assistant', 'content': answer})


if __name__ == "__main__":
    main()