import os

from langchain_community.chat_models import ChatOpenAI
from langchain_community.document_loaders import PyPDFLoader, Docx2txtLoader, TextLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain_community.embeddings import OpenAIEmbeddings
from langchain_community.vectorstores import Chroma
from langchain.chains import ConversationalRetrievalChain
import streamlit as st
from streamlit_chat import message
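# Pipeline overview: load files from ./docs, split them into chunks, embed the
# chunks with OpenAI, store them in a persistent Chroma index, then answer
# questions over them with a ConversationalRetrievalChain behind a Streamlit UI.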
def load_docs():
    """Load every PDF, Word, and plain-text file from ./docs as LangChain Documents."""
    documents = []
    for file in os.listdir('docs'):
        if file.endswith('.pdf'):
            pdf_path = './docs/' + file
            loader = PyPDFLoader(pdf_path)
            documents.extend(loader.load())
        elif file.endswith('.docx') or file.endswith('.doc'):
            doc_path = './docs/' + file
            loader = Docx2txtLoader(doc_path)
            documents.extend(loader.load())
        elif file.endswith('.txt'):
            text_path = './docs/' + file  # was '.docs/' + file, which raised FileNotFoundError
            loader = TextLoader(text_path)
            documents.extend(loader.load())
    return documents
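# Each loader records the originating file path in metadata['source'], which the
# citation helper below relies on when printing sources.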
| os.environ["OPENAI_API_KEY"] = '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' | |
| llm_model = "gpt-3.5-turbo" | |
| llm = ChatOpenAI(temperature=.7, model=llm_model) | |
#======================================================================================================================
# Load documents
documents = load_docs()

# 1. Text splitter
text_splitter = CharacterTextSplitter(
    chunk_size=100,
    chunk_overlap=20,
    length_function=len
)
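# Tuning note: chunk_size is measured in characters here, and 100 characters is
# quite small; many RAG setups use chunks of roughly 500-1000 characters so each
# chunk carries enough context to answer a question on its own.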
# 2. Embedding
embeddings = OpenAIEmbeddings()
docs = text_splitter.split_documents(documents)
#=====================================================================================================================
# 3. Storage
vector_store = Chroma.from_documents(
    documents=docs,
    embedding=embeddings,
    persist_directory='./data'
)
vector_store.persist()
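# persist() flushes the index to ./data so embeddings are not recomputed on every
# run; recent Chroma releases persist automatically and deprecate this call.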
# ====================================================================================================================
# 4. Retrieve
retriever = vector_store.as_retriever(search_kwargs={"k": 6})
# docs = retriever.get_relevant_documents("Tell me more about Data Science")

# Make a chain to answer questions
qa_chain = ConversationalRetrievalChain.from_llm(
    llm,
    retriever,  # reuse the retriever built above instead of constructing a second one
    return_source_documents=True,
    verbose=False
)
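# ConversationalRetrievalChain first condenses the new question plus the chat
# history into a standalone query, retrieves the top-k chunks for it, and then
# asks the LLM to answer from those chunks.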
# cite sources - helper to pretty-print a response with its source documents
def process_llm_response(llm_response):
    print(llm_response['answer'])  # ConversationalRetrievalChain returns 'answer', not 'result'
    print('\n\nSources:')
    for source in llm_response['source_documents']:
        print(source.metadata['source'])
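# Note: this helper prints to stdout for console debugging; the Streamlit UI
# below never calls it, so sources are not yet shown in the chat.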
#==============================FRONTEND=======================================
st.title("ViTo chatbot👠")
st.header("Ask anything about ViTo company...")

# Streamlit reruns this script on every interaction, so conversation state must
# live in st.session_state rather than in module-level variables.
if 'generated' not in st.session_state:
    st.session_state['generated'] = []
if 'past' not in st.session_state:
    st.session_state['past'] = []
if 'chat_history' not in st.session_state:
    st.session_state['chat_history'] = []

def get_query():
    input_text = st.chat_input("Ask a question about your documents...")
    return input_text
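# st.chat_input returns None until the user submits a message, so the block
# below only runs once a question has been entered.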
# retrieve the user input
user_input = get_query()
if user_input:
    result = qa_chain({'question': user_input, 'chat_history': st.session_state['chat_history']})
    st.session_state.past.append(user_input)
    st.session_state.generated.append(result['answer'])
    # Record the turn so follow-up questions can be condensed against it
    # (the original module-level chat_history was reset on every rerun).
    st.session_state['chat_history'].append((user_input, result['answer']))

if st.session_state['generated']:
    for i in range(len(st.session_state['generated'])):
        message(st.session_state['past'][i], is_user=True, key=str(i) + '_user')
        message(st.session_state['generated'][i], key=str(i))
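# Launch locally with: streamlit run app.py
# (assumes this file is named app.py and that a ./docs folder with your files exists)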