pankajsingh3012 commited on
Commit
325521f
·
verified ·
1 Parent(s): 6ef000b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -5
app.py CHANGED
@@ -73,9 +73,12 @@ def chunk_text(text: str, max_chunk_size: int = 1000) :
73
 
74
  # Streamlit UI
75
  st.title("CUDA Documentation QA System")
 
76
  # Initialize global variables
77
- vector_store = None
78
- documents_loaded = False
 
 
79
 
80
  # Crawling and processing the data
81
  if st.button('Crawl CUDA Documentation'):
@@ -94,12 +97,13 @@ if st.button('Crawl CUDA Documentation'):
94
  model_kwargs={'device': 'cpu'})
95
 
96
  # Store embeddings in FAISS
97
- vector_store = FAISS.from_texts(texts, embeddings)
 
98
  st.write("Embeddings stored in FAISS.")
99
 
100
  # Asking questions
101
  query = st.text_input("Enter your question about CUDA:")
102
- if query:
103
  with st.spinner('Searching for an answer...'):
104
  # Initialize Google Generative AI
105
  llm = GoogleGenerativeAI(model='gemini-1.0-pro', google_api_key="AIzaSyC1AvHnvobbycU8XSCXh-gRq3DUfG0EP98")
@@ -112,7 +116,7 @@ if query:
112
  # Create the retrieval QA chain
113
  qa_chain = RetrievalQA.from_chain_type(
114
  chain_type="map_rerank",
115
- retriever=vector_store.as_retriever(),
116
  combine_documents_chain=qa_prompt,
117
  llm=llm
118
  )
@@ -122,3 +126,6 @@ if query:
122
  st.write(response['answer'])
123
  st.write("**Source:**")
124
  st.write(response['source'])
 
 
 
 
73
 
74
  # Streamlit UI
75
  st.title("CUDA Documentation QA System")
76
+
77
  # Initialize global variables
78
+ if 'vector_store' not in st.session_state:
79
+ st.session_state.vector_store = None
80
+ if 'documents_loaded' not in st.session_state:
81
+ st.session_state.documents_loaded = False
82
 
83
  # Crawling and processing the data
84
  if st.button('Crawl CUDA Documentation'):
 
97
  model_kwargs={'device': 'cpu'})
98
 
99
  # Store embeddings in FAISS
100
+ st.session_state.vector_store = FAISS.from_texts(texts, embeddings)
101
+ st.session_state.documents_loaded = True
102
  st.write("Embeddings stored in FAISS.")
103
 
104
  # Asking questions
105
  query = st.text_input("Enter your question about CUDA:")
106
+ if query and st.session_state.documents_loaded:
107
  with st.spinner('Searching for an answer...'):
108
  # Initialize Google Generative AI
109
  llm = GoogleGenerativeAI(model='gemini-1.0-pro', google_api_key="AIzaSyC1AvHnvobbycU8XSCXh-gRq3DUfG0EP98")
 
116
  # Create the retrieval QA chain
117
  qa_chain = RetrievalQA.from_chain_type(
118
  chain_type="map_rerank",
119
+ retriever=st.session_state.vector_store.as_retriever(),
120
  combine_documents_chain=qa_prompt,
121
  llm=llm
122
  )
 
126
  st.write(response['answer'])
127
  st.write("**Source:**")
128
  st.write(response['source'])
129
+ elif query:
130
+ st.warning("Please crawl the CUDA documentation first.")
131
+