Update app.py
app.py
CHANGED
@@ -1,144 +1,146 @@
-import streamlit as st
-import os
-from dotenv import load_dotenv
-import time
-from langchain.vectorstores import Chroma
-from langchain.embeddings import HuggingFaceEmbeddings
-from langchain_core.prompts import ChatPromptTemplate
-from langchain_groq import ChatGroq
-from langchain.chains import RetrievalQA
-from langchain.document_loaders import PyPDFLoader
-from langchain.text_splitter import RecursiveCharacterTextSplitter
-# [removed lines 12-143 of the previous app.py are truncated in this diff view]
+import streamlit as st
+import os
+from dotenv import load_dotenv
+import time
+from langchain.vectorstores import Chroma
+from langchain.embeddings import HuggingFaceEmbeddings
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_groq import ChatGroq
+from langchain.chains import RetrievalQA
+from langchain.document_loaders import PyPDFLoader
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+
+# Access the secret
+api_key1 = os.getenv("api_key")
+start_time = time.time()
+# Set page title and icon
+st.set_page_config(page_title="Dr. Radha: The Agro-Homeopath", page_icon="π", layout="wide")
+
+# Center the title
+st.markdown("""
+    <style>
+    #the-title {
+        text-align: center;
+    }
+    </style>
+    """, unsafe_allow_html=True)
+
+# Display the title
+st.title("π Ask Dr. Radha - World's First AI based Agro-Homeopathy Doctor")
+
+# Load images
+human_image = "human.png"
+robot_image = "bot.png"
+
+# Load environment variables
+load_dotenv()
+end_time = time.time()
+print(f"Loading environment variables took {end_time - start_time:.4f} seconds")
+
+start_time = time.time()
+# Set up Groq API
+llm = ChatGroq(api_key=api_key1, max_tokens=None, timeout=None, max_retries=2, temperature=0.5, model="llama-3.1-70b-versatile")
+
+# Set up embeddings
+embeddings = HuggingFaceEmbeddings()
+end_time = time.time()
+print(f"Setting up Groq LLM & Embeddings took {end_time - start_time:.4f} seconds")
+
+# Initialize session state
+if "documents" not in st.session_state:
+    st.session_state["documents"] = None
+if "vector_db" not in st.session_state:
+    st.session_state["vector_db"] = None
+if "query" not in st.session_state:
+    st.session_state["query"] = ""
+
+def load_data():
+    pdf_folder = "docs"
+    loaders = [PyPDFLoader(os.path.join(pdf_folder, fn)) for fn in os.listdir(pdf_folder)]
+    documents = []
+    for loader in loaders:
+        documents.extend(loader.load())
+
+    #text_splitter = RecursiveCharacterTextSplitter(chunk_size=100, chunk_overlap=30)
+    text_splitter = RecursiveCharacterTextSplitter(
+        chunk_size=1000,
+        chunk_overlap=200,
+        length_function=len,
+        separators=["\n\n", "\n", " ", ""]
+    )
+    texts = text_splitter.split_documents(documents)
+    # Set up vector database
+    persist_directory = "db"
+    vector_db = Chroma.from_documents(texts, embeddings, persist_directory=persist_directory)
+
+    return documents, vector_db
+
+# Load and process PDFs
+start_time = time.time()
+# Load data if not already loaded
+if st.session_state["documents"] is None or st.session_state["vector_db"] is None:
+    with st.spinner("Loading data..."):
+        documents, vector_db = load_data()
+        st.session_state["documents"] = documents
+        st.session_state["vector_db"] = vector_db
+else:
+    documents = st.session_state["documents"]
+    vector_db = st.session_state["vector_db"]
+
+end_time = time.time()
+print(f"Loading and processing PDFs & vector database took {end_time - start_time:.4f} seconds")
+
+# Set up retrieval chain
+start_time = time.time()
+retriever = vector_db.as_retriever()
+qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever)
+
+# Chat interface
+chat_container = st.container()
+
+# Create a form for the query input and submit button
+with st.form(key='query_form'):
+    query = st.text_input("Ask your question:", value="")#st.session_state["query"])
+    submit_button = st.form_submit_button(label='Submit')
+
+end_time = time.time()
+print(f"Setting up retrieval chain took {end_time - start_time:.4f} seconds")
+start_time = time.time()
+
+if submit_button and query:
+    with st.spinner("Generating response..."):
+        result = qa({"query": query})
+        if result['result'].strip() == "":
+            response = "I apologize, but I don't have enough information in the provided PDFs to answer your question."
+        else:
+            response = result['result']
+
+    # Display human image and question
+    col1, col2 = st.columns([1, 10])
+    with col1:
+        st.image(human_image, width=80)
+    with col2:
+        st.markdown(f"{query}")
+    # Display robot image and answer
+    col1, col2 = st.columns([1, 10])
+    with col1:
+        st.image(robot_image, width=80)
+    with col2:
+        st.markdown(f"{response}")
+
+    st.markdown("---")
+
+    # Clear the query input
+    st.session_state["query"] = ""
+    #st.rerun()
+
+end_time = time.time()
+print(f"Actual query took {end_time - start_time:.4f} seconds")
+
+# Reload data button
+# if st.button("Reload Data"):
+#     with st.spinner("Reloading data..."):
+#         documents, vector_db = load_data()
+#         st.session_state["documents"] = documents
+#         st.session_state["vector_db"] = vector_db
 #         st.success("Data reloaded successfully!")
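For reference, the persist_directory="db" passed to Chroma.from_documents in load_data() writes the index to disk, so a possible follow-up (not part of this commit) would be to reopen that index on restart instead of re-reading and re-embedding every PDF. A minimal sketch, assuming the same "db" folder and HuggingFaceEmbeddings model; load_existing_vector_db is a hypothetical helper, not code from this Space:

import os
from langchain.vectorstores import Chroma
from langchain.embeddings import HuggingFaceEmbeddings

def load_existing_vector_db(persist_directory="db"):
    # Hypothetical helper: reopen a previously persisted Chroma index if one
    # exists, so the app skips the PDF ingestion done by load_data().
    if os.path.isdir(persist_directory) and os.listdir(persist_directory):
        embeddings = HuggingFaceEmbeddings()
        return Chroma(persist_directory=persist_directory, embedding_function=embeddings)
    return None  # caller falls back to the full load_data() path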