Update app.py
app.py CHANGED
```diff
@@ -73,9 +73,12 @@ def chunk_text(text: str, max_chunk_size: int = 1000):
 
 # Streamlit UI
 st.title("CUDA Documentation QA System")
+
 # Initialize global variables
-vector_store
-
+if 'vector_store' not in st.session_state:
+    st.session_state.vector_store = None
+if 'documents_loaded' not in st.session_state:
+    st.session_state.documents_loaded = False
 
 # Crawling and processing the data
 if st.button('Crawl CUDA Documentation'):
@@ -94,12 +97,13 @@ if st.button('Crawl CUDA Documentation'):
                                            model_kwargs={'device': 'cpu'})
 
         # Store embeddings in FAISS
-        vector_store = FAISS.from_texts(texts, embeddings)
+        st.session_state.vector_store = FAISS.from_texts(texts, embeddings)
+        st.session_state.documents_loaded = True
         st.write("Embeddings stored in FAISS.")
 
 # Asking questions
 query = st.text_input("Enter your question about CUDA:")
-if query:
+if query and st.session_state.documents_loaded:
     with st.spinner('Searching for an answer...'):
         # Initialize Google Generative AI
         llm = GoogleGenerativeAI(model='gemini-1.0-pro', google_api_key="AIzaSyC1AvHnvobbycU8XSCXh-gRq3DUfG0EP98")
@@ -112,7 +116,7 @@ if query:
         # Create the retrieval QA chain
         qa_chain = RetrievalQA.from_chain_type(
             chain_type="map_rerank",
-            retriever=vector_store.as_retriever(),
+            retriever=st.session_state.vector_store.as_retriever(),
             combine_documents_chain=qa_prompt,
             llm=llm
         )
@@ -122,3 +126,6 @@ if query:
         st.write(response['answer'])
         st.write("**Source:**")
         st.write(response['source'])
+elif query:
+    st.warning("Please crawl the CUDA documentation first.")
+
```
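The substance of the change: the old code kept `vector_store` as a module-level variable, but Streamlit re-executes the whole script on every widget interaction, so anything stored that way is gone by the time the user types a question. `st.session_state` is Streamlit's mechanism for values that must survive reruns. A minimal sketch of the pattern, using an illustrative counter rather than anything from app.py:

```python
import streamlit as st

# Streamlit reruns this entire script on every widget interaction, so a plain
# module-level variable would be reset each time. st.session_state persists
# for the lifetime of the browser session.
if "counter" not in st.session_state:
    st.session_state.counter = 0  # runs only on the first execution

if st.button("Increment"):
    st.session_state.counter += 1  # survives the rerun triggered by the click

st.write("Counter:", st.session_state.counter)
```

This guard-then-assign idiom is exactly what the commit introduces for `vector_store` and `documents_loaded`.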
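New lines 100-101 store the freshly built FAISS index in session state and flip the `documents_loaded` flag. A condensed sketch of that indexing flow, hedged: the import paths assume langchain_community (older releases exposed these under langchain.embeddings and langchain.vectorstores), and the embedding model name and sample texts are placeholders, since the diff only shows the `model_kwargs` continuation line:

```python
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS

# Placeholder corpus; app.py builds `texts` by crawling and chunking the CUDA docs.
texts = ["CUDA kernels execute on the GPU.", "cudaMalloc allocates device memory."]

embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2",  # assumed model, not shown in the diff
    model_kwargs={"device": "cpu"},
)
vector_store = FAISS.from_texts(texts, embeddings)

# as_retriever() wraps the index for use in a QA chain; similarity_search
# queries it directly.
docs = vector_store.similarity_search("How do I allocate GPU memory?", k=1)
print(docs[0].page_content)
```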
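One thing worth flagging in the unchanged lines around new line 119: `RetrievalQA.from_chain_type` receives `chain_type`, a `combine_documents_chain`, and `llm` all at once. In the legacy langchain releases I know of, `from_chain_type` builds the combine-documents chain itself from `llm` and `chain_type` (custom prompts go through `chain_type_kwargs`), so the conventional call looks more like this sketch, with `llm` and `query` taken from the surrounding code:

```python
from langchain.chains import RetrievalQA

# Sketch assuming legacy langchain's RetrievalQA API.
qa_chain = RetrievalQA.from_chain_type(
    llm=llm,  # the GoogleGenerativeAI instance from the diff
    chain_type="map_rerank",
    retriever=st.session_state.vector_store.as_retriever(),
    return_source_documents=True,
)
response = qa_chain({"query": query})
st.write(response["result"])  # legacy RetrievalQA returns its answer under "result",
                              # not the "answer"/"source" keys app.py reads
```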
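Separately from the session-state fix: the diff shows a Google API key hard-coded in app.py, and on a public Space anyone who opens the file can read it. Hugging Face Spaces exposes repository secrets as environment variables, so a safer pattern is a sketch like this, with GOOGLE_API_KEY as an assumed secret name:

```python
import os

from langchain_google_genai import GoogleGenerativeAI  # assumed import; app.py's imports are not shown

# The key comes from the Space's secret settings (exposed as an env var)
# instead of living in the source tree.
api_key = os.environ["GOOGLE_API_KEY"]  # assumed secret name, not from the diff

llm = GoogleGenerativeAI(model="gemini-1.0-pro", google_api_key=api_key)
```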