import os
import re
import time
from collections import deque
from urllib.parse import urldefrag, urljoin

import numpy as np
import requests
import streamlit as st
from bs4 import BeautifulSoup
from langchain.chains import RetrievalQA
#from langchain.chains import load_qa_chain, RetrievalQA
from langchain.prompts import PromptTemplate
from langchain_community.document_loaders import TextLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_google_genai import GoogleGenerativeAI
# Crawling function
def crawl(start_url: str, max_depth: int = 1, delay: float = 0.1,
          url_prefix: "str | None" = None, max_queue: int = 10,
          timeout: float = 10.0):
    """Breadth-first crawl from *start_url*, collecting whitespace-normalized page text.

    Args:
        start_url: First page to fetch.
        max_depth: Maximum link depth to follow (0 fetches only *start_url*).
        delay: Seconds to sleep before every request, to throttle the crawl.
        url_prefix: Only links whose absolute URL starts with this prefix are
            followed. Defaults to *start_url*, keeping the crawl in that subtree.
        max_queue: Stop enqueueing links from a page once the pending queue
            holds more than this many URLs.
        timeout: Per-request timeout in seconds passed to ``requests.get``.

    Returns:
        ``(results, crawled_urls)``: ``results`` is a list of ``(url, text)``
        pairs for HTML pages that yielded non-empty text; ``crawled_urls`` lists
        every URL attempted, in visit order.
    """
    if url_prefix is None:
        url_prefix = start_url
    start = urldefrag(start_url).url
    queue = deque([(start, 0)])
    queued = {start}  # every URL ever enqueued, so duplicates never fill the queue
    visited = set()
    results = []
    crawled_urls = []
    while queue:
        url, depth = queue.popleft()
        if depth > max_depth or url in visited:
            continue
        visited.add(url)
        crawled_urls.append(url)
        try:
            time.sleep(delay)
            response = requests.get(url, timeout=timeout)
            response.raise_for_status()
            # Skip non-HTML payloads (e.g. PDF downloads); their bytes are not useful text.
            content_type = response.headers.get("Content-Type", "")
            if content_type and "html" not in content_type.lower():
                continue
            soup = BeautifulSoup(response.text, 'html.parser')
            # Remove script/style contents so only human-readable text is indexed.
            for tag in soup(["script", "style", "noscript"]):
                tag.decompose()
            text = re.sub(r'\s+', ' ', soup.get_text(" ")).strip()
            if text:
                results.append((url, text))
            if depth < max_depth:
                for link in soup.find_all('a', href=True):
                    # Drop "#fragment" so in-page anchors don't refetch the same page.
                    next_url = urldefrag(urljoin(url, link['href'])).url
                    if next_url.startswith(url_prefix) and next_url not in queued:
                        queued.add(next_url)
                        queue.append((next_url, depth + 1))
                        if len(queue) > max_queue:
                            break
        except Exception as e:  # best-effort crawl: report and continue with the next URL
            print(f"Error crawling {url}: {e}")
    return results, crawled_urls
# Text chunking function
def chunk_text(text: str, max_chunk_size: int = 1000) -> "list[str]":
    """Pack sentences of *text* into chunks of at most *max_chunk_size* characters.

    Sentences are split after ``.``, ``!`` or ``?`` followed by whitespace and
    joined with single spaces. A sentence longer than *max_chunk_size* is
    hard-split into fixed-size pieces so no chunk exceeds the limit. Empty
    chunks are never produced.

    Raises:
        ValueError: if *max_chunk_size* is not positive.
    """
    if max_chunk_size <= 0:
        raise ValueError("max_chunk_size must be positive")
    chunks = []
    current = ""
    for sentence in re.split(r'(?<=[.!?])\s+', text):
        sentence = sentence.strip()
        if not sentence:
            continue
        # Pieces are each <= max_chunk_size, so a fresh chunk always fits one.
        pieces = [sentence[i:i + max_chunk_size]
                  for i in range(0, len(sentence), max_chunk_size)]
        for piece in pieces:
            candidate = f"{current} {piece}" if current else piece
            if len(candidate) <= max_chunk_size:
                current = candidate
            else:
                chunks.append(current)
                current = piece
    if current:
        chunks.append(current)
    return chunks
# Streamlit UI
st.title("CUDA Documentation QA System")

# Seed per-session defaults only when a key is absent, so values stored by an
# earlier run of the script (e.g. a built vector store) are kept.
_SESSION_DEFAULTS = {"vector_store": None, "documents_loaded": False}
for _state_key, _state_default in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
# Crawling and processing the data
# NOTE(review): the start URL is an empty string in this copy; confirm it is set
# to the CUDA documentation root, otherwise the request fails and nothing is indexed.
if st.button('Crawl CUDA Documentation'):
    with st.spinner('Crawling CUDA documentation...'):
        crawled_data, _crawled_urls = crawl("", max_depth=1, delay=0.1)
        st.write(f"Processed {len(crawled_data)} pages.")
        texts = []
        metadatas = []  # parallel to texts: which page each chunk came from
        for url, text in crawled_data:
            for chunk in chunk_text(text, max_chunk_size=1024):
                if chunk:
                    texts.append(chunk)
                    metadatas.append({"source": url})
    if not texts:
        # FAISS.from_texts cannot build an index from an empty list.
        st.error("No text was extracted from the crawl; nothing to index.")
    else:
        st.success("Crawling and processing completed.")
        with st.spinner('Building embeddings...'):
            # Create embeddings
            embeddings = HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-L6-v2', model_kwargs={'device': 'cpu'})
            # Store embeddings in FAISS, keeping the source URL of each chunk.
            st.session_state.vector_store = FAISS.from_texts(texts, embeddings, metadatas=metadatas)
        st.session_state.documents_loaded = True
        st.write("Embeddings stored in FAISS.")
# Asking questions
query = st.text_input("Enter your question about CUDA:")
if query and st.session_state.documents_loaded:
    # Read the key from the environment; never commit API keys to source.
    api_key = os.environ.get("GOOGLE_API_KEY")
    if not api_key:
        st.error("Set the GOOGLE_API_KEY environment variable to query the model.")
        st.stop()
    with st.spinner('Searching for an answer...'):
        # Initialize Google Generative AI
        llm = GoogleGenerativeAI(model='gemini-1.0-pro', google_api_key=api_key)
        # Prompt for the "stuff" chain: retrieved chunks fill {context}, the user query fills {question}.
        qa_prompt = PromptTemplate(template="Answer the following question based on the context provided:\n\nContext: {context}\n\nQuestion: {question}\n\nAnswer:", input_variables=["context", "question"])
        # Create the retrieval QA chain
        qa_chain = RetrievalQA.from_chain_type(
            llm=llm,
            chain_type="stuff",
            retriever=st.session_state.vector_store.as_retriever(),
            chain_type_kwargs={"prompt": qa_prompt},
            return_source_documents=True,
        )
        try:
            # RetrievalQA reads its input under "query" and returns "result"
            # (plus "source_documents" because return_source_documents=True).
            response = qa_chain.invoke({"query": query})
        except Exception as exc:
            st.error(f"Failed to get an answer: {exc}")
            st.stop()
    st.write("**Answer:**")
    st.write(response["result"])
    st.write("**Source:**")
    # Prefer the page URL stored as metadata; fall back to a text snippet.
    sources = []
    for doc in response.get("source_documents", []):
        source = doc.metadata.get("source") or doc.page_content[:200]
        if source not in sources:
            sources.append(source)
    for source in sources:
        st.write(source)
elif query:
    st.warning("Please crawl the CUDA documentation first.")