# AIDocChat / app.py
import streamlit as st
from llama_index import VectorStoreIndex, SimpleDirectoryReader
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from llama_index import LangchainEmbedding, ServiceContext
from llama_index import StorageContext, load_index_from_storage
from llama_index import LLMPredictor
#from transformers import HuggingFaceHub
from langchain import HuggingFaceHub
from streamlit.components.v1 import html
from pathlib import Path
from time import sleep
import random
import string
import time
import os
from dotenv import load_dotenv
load_dotenv()
import timeit
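# NOTE: this app targets the legacy pre-0.10 llama_index API (ServiceContext,
# LLMPredictor, LangchainEmbedding) and the old top-level langchain import of
# HuggingFaceHub; newer releases of both libraries moved or removed these, so
# pin matching older versions when reproducing this Space.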
st.set_page_config(page_title="Open AI Doc-Chat Assistant", layout="wide")
st.subheader("Open AI Doc-Chat Assistant: Life Enhancing with AI!")
css_file = "main.css"
with open(css_file) as f:
    st.markdown("<style>{}</style>".format(f.read()), unsafe_allow_html=True)
HUGGINGFACEHUB_API_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")
documents=[]
def generate_random_string(length):
    letters = string.ascii_lowercase
    return ''.join(random.choice(letters) for i in range(length))
#random_string = generate_random_string(20)
#directory_path=random_string
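# Session-scoped state below: the random directory_path gives each browser
# session its own scratch folder for uploads and the persisted index, so
# concurrent users do not overwrite each other's files.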
if "directory_path" not in st.session_state:
st.session_state.directory_path = generate_random_string(20)
if "pdf_files" not in st.session_state:
st.session_state.pdf_files = None
if "documents" not in st.session_state:
st.session_state.documents = None
if "embed_model" not in st.session_state:
st.session_state.embed_model = None
if "llm_predictor" not in st.session_state:
st.session_state.llm_predictor = None
if "service_context" not in st.session_state:
st.session_state.service_context = None
if "new_index" not in st.session_state:
st.session_state.new_index = None
if "storage_context" not in st.session_state:
st.session_state.storage_context = None
if "loadedindex" not in st.session_state:
st.session_state.loadedindex = None
if "query_engine" not in st.session_state:
st.session_state.query_engine = None
if "user_question " not in st.session_state:
st.session_state.user_question = ""
with st.sidebar:
    st.subheader("Upload your Documents Here: ")
    st.session_state.pdf_files = st.file_uploader("Choose your PDF Files and Press OK", type=['pdf'], accept_multiple_files=True)
    if not st.session_state.pdf_files:
        st.warning("Please upload your documents first.")
        st.stop()  # Halt this run; Streamlit reruns the script once files are uploaded.
    # Handle the uploaded files: save each PDF into the per-session directory.
    if not os.path.exists(st.session_state.directory_path):
        os.makedirs(st.session_state.directory_path)
    for pdf_file in st.session_state.pdf_files:
        file_path = os.path.join(st.session_state.directory_path, pdf_file.name)
        with open(file_path, 'wb') as f:
            f.write(pdf_file.read())
        st.success(f"File '{pdf_file.name}' saved successfully.")
try:
    start_1 = timeit.default_timer()  # Start timer
    st.write(f"Document loading started: {start_1}")
    st.session_state.documents = SimpleDirectoryReader(st.session_state.directory_path).load_data()
    end_1 = timeit.default_timer()  # End timer
    st.write(f"Document loading finished: {end_1}")
    st.write(f"Document loading took: {end_1 - start_1}")
except Exception as e:
    print(f"Problem loading documents / waiting for path creation: {e}")
# Load documents from a directory
#documents = SimpleDirectoryReader('data').load_data()
start_2 = timeit.default_timer()  # Start timer
st.write(f"Embedding model loading started: {start_2}")
st.session_state.embed_model = LangchainEmbedding(HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-L6-v2'))
end_2 = timeit.default_timer()  # End timer
st.write(f"Embedding model loading finished: {end_2}")
st.write(f"Embedding model loading took: {end_2 - start_2}")
st.session_state.llm_predictor = LLMPredictor(HuggingFaceHub(repo_id="HuggingFaceH4/starchat-beta", model_kwargs={"min_length":100, "max_new_tokens":1024, "do_sample":True, "temperature":0.2,"top_k":50, "top_p":0.95, "eos_token_id":49155}))
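# starchat-beta marks the end of each turn with '<|end|>' (token id 49155);
# passing it as eos_token_id stops generation there, and the same marker is
# used to trim the raw response near the bottom of this file.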
st.session_state.service_context = ServiceContext.from_defaults(llm_predictor=st.session_state.llm_predictor, embed_model=st.session_state.embed_model)
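# ServiceContext bundles the LLM predictor and embedding model so that the
# index build and queries below use this pair instead of llama_index defaults.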
start_3 = timeit.default_timer()  # Start timer
st.write(f"Vector index build started: {start_3}")
st.session_state.new_index = VectorStoreIndex.from_documents(
    st.session_state.documents,
    service_context=st.session_state.service_context,
)
end_3 = timeit.default_timer()  # End timer
st.write(f"Vector index build finished: {end_3}")
st.write(f"Vector index build took: {end_3 - start_3}")
st.session_state.new_index.storage_context.persist(st.session_state.directory_path)
st.session_state.storage_context = StorageContext.from_defaults(persist_dir=st.session_state.directory_path)
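# Persisting and immediately reloading the index is redundant within a single
# run, but it demonstrates the StorageContext round-trip and leaves a copy of
# the index on disk under the session directory.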
start_4 = timeit.default_timer()  # Start timer
st.write(f"Vector index loading started: {start_4}")
st.session_state.loadedindex = load_index_from_storage(storage_context=st.session_state.storage_context, service_context=st.session_state.service_context)
end_4 = timeit.default_timer()  # End timer
st.write(f"Vector index loading finished: {end_4}")
st.write(f"Vector index loading took: {end_4 - start_4}")
st.session_state.query_engine = st.session_state.loadedindex.as_query_engine()
st.session_state.user_question=st.text_input("Enter your query:")
if st.session_state.user_question.strip():
    print("user question: " + st.session_state.user_question)
    with st.spinner("AI Thinking...Please wait a while to Cheers!"):
        start_5 = timeit.default_timer()  # Start timer
        st.write(f"Query Engine - AI QA started: {start_5}")
        initial_response = st.session_state.query_engine.query(st.session_state.user_question)
        end_5 = timeit.default_timer()  # End timer
        st.write(f"Query Engine - AI QA took: {end_5 - start_5}")
        temp_ai_response = str(initial_response)
        # starchat-beta appends an '<|end|>' marker; keep only the text before it.
        final_ai_response = temp_ai_response.partition('<|end|>')[0]
        print("AI Response:\n" + final_ai_response)
        st.write("AI Response:\n\n" + final_ai_response)