Huy0502 commited on
Commit
ea8b3bf
·
verified ·
1 Parent(s): c8846e1

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +108 -0
app.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from langchain_community.chat_models import ChatOpenAI
3
+
4
+ from langchain_community.document_loaders import PyPDFLoader, Docx2txtLoader, TextLoader
5
+ from langchain.text_splitter import CharacterTextSplitter
6
+ from langchain_community.embeddings import OpenAIEmbeddings
7
+ from langchain_community.vectorstores import Chroma
8
+ from langchain.chains import ConversationalRetrievalChain
9
+
10
+ import streamlit as st
11
+ from streamlit_chat import message
12
+
13
+
14
+ @st.cache_data()
15
+ def load_docs():
16
+ documents = []
17
+ for file in os.listdir('docs'):
18
+ if file.endswith('.pdf'):
19
+ pdf_path = "./docs/"+file
20
+ loader = PyPDFLoader(pdf_path)
21
+ documents.extend(loader.load())
22
+ elif file.endswith('.docx') or file.endswith('.doc'):
23
+ doc_path = './docs/'+file
24
+ loader = Docx2txtLoader(doc_path)
25
+ documents.extend(loader.load())
26
+ elif file.endswith('.txt'):
27
+ text_path = '.docs/'+file
28
+ loader = TextLoader(text_path)
29
+ documents.extend(loader.load())
30
+
31
+ return documents
32
+
33
+ os.environ["OPENAI_API_KEY"] = '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'
34
+
35
+ llm_model = "gpt-3.5-turbo"
36
+ llm = ChatOpenAI(temperature=.7, model=llm_model)
37
+ #======================================================================================================================
38
+ # Load documents
39
+ documents = load_docs()
40
+ chat_history = []
41
+
42
+ # 1. Text splitter
43
+ text_splitter = CharacterTextSplitter(
44
+ chunk_size = 100,
45
+ chunk_overlap = 20,
46
+ length_function = len
47
+ )
48
+
49
+ # 2. Embedding
50
+ embeddings = OpenAIEmbeddings()
51
+
52
+ docs = text_splitter.split_documents(documents)
53
+
54
+ #=====================================================================================================================
55
+ # 3. Storage
56
+ vector_store = Chroma.from_documents(
57
+ documents=docs,
58
+ embedding=embeddings,
59
+ persist_directory='./data'
60
+ )
61
+ vector_store.persist()
62
+ # ====================================================================================================================
63
+ # 4. Retrieve
64
+ retriever = vector_store.as_retriever(search_kwargs={"k":6})
65
+ # docs = retriever.get_relevant_documents("Tell me more about Data Science")
66
+
67
+ # Make a chain to answer questions
68
+ qa_chain = ConversationalRetrievalChain.from_llm(
69
+ llm,
70
+ vector_store.as_retriever(search_kwargs={'k':6}),
71
+ return_source_documents=True,
72
+ verbose=False
73
+ )
74
+
75
+
76
+ # cite sources - helper function to prettyfy responses
77
+ def process_llm_response(llm_response):
78
+ print(llm_response['result'])
79
+ print('\n\nSources:')
80
+ for source in llm_response['source_documents']:
81
+ print(source.metadata['source'])
82
+
83
+ #==============================FRONTEND=======================================
84
+ st.title("ViTo chatbot👠")
85
+ st.header("Ask anything about ViTo company...")
86
+
87
+ if 'generated' not in st.session_state:
88
+ st.session_state['generated'] = []
89
+
90
+ if 'past' not in st.session_state:
91
+ st.session_state['past'] = []
92
+
93
+ def get_query():
94
+ input_text = st.chat_input("Ask a question about your documents...")
95
+ return input_text
96
+
97
+ # retrieve the user input
98
+ user_input = get_query()
99
+ if user_input:
100
+ result = qa_chain({'question': user_input, 'chat_history': chat_history})
101
+ st.session_state.past.append(user_input)
102
+ st.session_state.generated.append(result['answer'])
103
+
104
+ if st.session_state['generated']:
105
+ for i in range(len(st.session_state['generated'])):
106
+ message(st.session_state['past'][i], is_user=True, key=str(i)+'_user')
107
+ message(st.session_state['generated'][i], key=str(i))
108
+