|
import os

import streamlit as st
import spaces  # Hugging Face Spaces ZeroGPU helper; imported by the app but not used directly below

# Note: these import paths follow the classic `langchain` package layout;
# LangChain >= 0.1 moves most of them into `langchain_community`.
from langchain.document_loaders import UnstructuredURLLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import FAISS
from langchain.chains import RetrievalQAWithSourcesChain
from langchain.llms import OpenAI
from langchain_google_genai import GoogleGenerativeAIEmbeddings, ChatGoogleGenerativeAI
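
# The API keys are read from host-level secrets named 'openaiapi' and
# 'geminiapi' (e.g. Hugging Face Spaces secrets) and re-exported under the
# environment variable names that LangChain looks up.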
os.environ['OPENAI_API_KEY'] = os.getenv('openaiapi', '')  # default avoids a TypeError when the secret is unset
os.environ['GOOGLE_API_KEY'] = os.getenv('geminiapi', '')

llm_openai = OpenAI(temperature=0.7, max_tokens=300)
llm_gemini = ChatGoogleGenerativeAI(model="gemini-pro")

st.title("URL Research Tool")

model_selection = st.radio(label='Choose LLM👇', options=['Gemini', 'OpenAI'])
st.write(f"Selected Model: :rainbow[{model_selection}]")

st.sidebar.title("Enter URLs:")
no_of_urls = 3  # number of URL input boxes shown in the sidebar
urls = []
file_name = 'all_url_data_vectors'  # folder the FAISS index is saved to

for i in range(no_of_urls):
    url = st.sidebar.text_input(f"URL {i+1}")
    if url:  # skip empty inputs so the loader only fetches real URLs
        urls.append(url)

query_placeholder = st.empty()
user_query = query_placeholder.text_input("Question: ")
query_button = st.button("Submit Query")
progress_placeholder = st.empty()
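
# Not part of the original script: a minimal input guard (sketch) so the
# pipeline below only runs when at least one URL and a question were supplied.
if query_button and (not urls or not user_query.strip()):
    st.warning("Please enter at least one URL and a question before submitting.")
    st.stop()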

if query_button:
    progress_placeholder.text("Work in Progress...")
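
    # Fetch every URL and parse the page content into LangChain documents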
    url_loader = UnstructuredURLLoader(urls=urls)
    url_data = url_loader.load()
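
    # Split the pages into ~1000-character chunks, preferring paragraph,
    # line, and sentence boundaries, so each chunk embeds cleanly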
    text_splitter = RecursiveCharacterTextSplitter(
        separators=['\n\n', '\n', '.', ' '],
        chunk_size=1000,
    )
    progress_placeholder.text("Work in Progress: Text Splitting")
    chunked_url_data = text_splitter.split_documents(url_data)
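
    # Pair the chosen LLM with its own provider's embedding model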
    if model_selection == "OpenAI":
        selected_model = llm_openai
        embedding_creator = OpenAIEmbeddings()
    else:
        selected_model = llm_gemini
        embedding_creator = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
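
    # Embed every chunk and build an in-memory FAISS similarity index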
    progress_placeholder.text("Work in Progress: Creating Embeddings")
    data_vectors = FAISS.from_documents(chunked_url_data, embedding_creator)
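
    # save_local writes the index to a folder (index.faiss plus a pickled docstore)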
    data_vectors.save_local(file_name)
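
    # Reload the index from disk; allow_dangerous_deserialization is required
    # because load_local unpickles the stored docstore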
    if os.path.exists(file_name):
        progress_placeholder.text("Work in Progress: Loading Results")
        data_vectors_loaded = FAISS.load_local(file_name, embedding_creator, allow_dangerous_deserialization=True)
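
        # Retrieve the most relevant chunks and let the LLM answer with sources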
        main_chain = RetrievalQAWithSourcesChain.from_llm(llm=selected_model, retriever=data_vectors_loaded.as_retriever())
        llm_result = main_chain({'question': user_query})
        progress_placeholder.text("Task Completed: Displaying Results")

        st.header('Answer:')
        st.write(llm_result['answer'])
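
        # 'sources' comes back as a single string; show one line per entry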
        answer_sources = llm_result.get('sources', '')
        if answer_sources:
            answer_sources_list = answer_sources.split('\n')
            st.subheader('Sources:')
            for source in answer_sources_list:
                st.write(source)