# Streamlit "Doc-Chat" app: upload PDF files, index them with LlamaIndex and
# answer questions about them through a HuggingFace Hub hosted LLM.
#
# Change notes (translated from the original header):
#   - Removed the module-level `documents = []`.
#   - Each st.session_state entry is created where it is first needed instead
#     of all being pre-declared as None at the top of the script.
#   - The script stops early (st.stop) while no PDF has been uploaded, so the
#     document loader is never run without input files.
import os
import random
import string
import timeit
from pathlib import Path
from time import sleep

import streamlit as st
from dotenv import load_dotenv
from langchain import HuggingFaceHub
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from llama_index import LangchainEmbedding, ServiceContext
from llama_index import LLMPredictor
from llama_index import StorageContext, load_index_from_storage
from llama_index import VectorStoreIndex, SimpleDirectoryReader
from streamlit.components.v1 import html

load_dotenv()

st.set_page_config(page_title="Open AI Doc-Chat Assistant", layout="wide")
st.subheader("Open AI Doc-Chat Assistant: Life Enhancing with AI!")

# Inject the custom stylesheet. The CSS has to be wrapped in a <style> tag;
# formatting into an empty string would silently discard the whole file.
css_file = "main.css"
with open(css_file) as f:
    st.markdown("<style>{}</style>".format(f.read()), unsafe_allow_html=True)

# Read after load_dotenv() so a value from a local .env file is visible.
HUGGINGFACEHUB_API_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")


def generate_random_string(length):
    """Return a string of `length` random lowercase ASCII letters.

    Used to name the per-session upload directory; not suitable for
    anything security-sensitive because it relies on `random`.
    """
    letters = string.ascii_lowercase
    return ''.join(random.choice(letters) for _ in range(length))


with st.sidebar:
    st.subheader("Upload your Documents Here: ")
    pdf_files = st.file_uploader(
        "Choose your PDF Files and Press OK",
        type=['pdf'],
        accept_multiple_files=True,
    )
    # Refresh on every rerun. Storing the value only when the key is missing
    # would freeze the empty list from the very first run, and the script
    # would then stop below forever, even after files were uploaded.
    st.session_state.pdf_files = pdf_files

    if not st.session_state.pdf_files:
        # Nothing uploaded yet: stop here so no later step runs without documents.
        st.warning("请上传文档文件")
        st.stop()

    # Files are present: save them so SimpleDirectoryReader can load them.
    if "directory_path" not in st.session_state:
        st.session_state.directory_path = generate_random_string(20)
    # exist_ok: the script reruns on every interaction and the directory
    # normally already exists after the first pass.
    os.makedirs(st.session_state.directory_path, exist_ok=True)

    for pdf_file in st.session_state.pdf_files:
        # Keep only the base name so a crafted upload name cannot escape
        # the session directory.
        safe_name = Path(pdf_file.name).name
        file_path = os.path.join(st.session_state.directory_path, safe_name)
        with open(file_path, 'wb') as f:
            # getvalue() returns the full content regardless of the current
            # read position of the uploaded buffer.
            f.write(pdf_file.getvalue())
        st.success(f"File '{pdf_file.name}' saved successfully.")

# --- Load the documents (cached per session) ---
# NOTE: documents and everything derived from them are cached in
# st.session_state, so files uploaded later in the same session are saved to
# disk but are not added to the already-built index.
start_1 = timeit.default_timer()  # Start timer
st.write(f"QA文档加载开始:{start_1}")
if "documents" not in st.session_state:
    try:
        st.session_state.documents = SimpleDirectoryReader(
            st.session_state.directory_path
        ).load_data()
    except Exception as e:
        # Without documents nothing below can work; report and stop instead
        # of failing later on the missing session entry.
        print("文档加载出现问题/Waiting for path creation.")
        st.error(f"文档加载出现问题/Waiting for path creation. ({e})")
        st.stop()
end_1 = timeit.default_timer()  # Stop timer
st.write(f"QA文档加载结束:{end_1}")
st.write(f"QA文档加载耗时:{end_1 - start_1}")

# --- Load the embedding model (cached per session) ---
start_2 = timeit.default_timer()  # Start timer
st.write(f"向量模型加载开始:{start_2}")
if "embed_model" not in st.session_state:
    st.session_state.embed_model = LangchainEmbedding(
        HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-L6-v2')
    )
end_2 = timeit.default_timer()  # Stop timer
st.write(f"向量模型加载加载结束:{end_2}")
st.write(f"向量模型加载耗时:{end_2 - start_2}")

# --- LLM and service context (cached per session) ---
if "llm_predictor" not in st.session_state:
    st.session_state.llm_predictor = LLMPredictor(
        HuggingFaceHub(
            repo_id="HuggingFaceH4/starchat-beta",
            model_kwargs={
                "min_length": 100,
                "max_new_tokens": 1024,
                "do_sample": True,
                "temperature": 0.2,
                "top_k": 50,
                "top_p": 0.95,
                "eos_token_id": 49155,
            },
        )
    )

if "service_context" not in st.session_state:
    st.session_state.service_context = ServiceContext.from_defaults(
        llm_predictor=st.session_state.llm_predictor,
        embed_model=st.session_state.embed_model,
    )

# --- Build the vector index (cached per session) ---
start_3 = timeit.default_timer()  # Start timer
st.write(f"向量库构建开始:{start_3}")
if "new_index" not in st.session_state:
    st.session_state.new_index = VectorStoreIndex.from_documents(
        st.session_state.documents,
        service_context=st.session_state.service_context,
    )
end_3 = timeit.default_timer()  # Stop timer
st.write(f"向量库构建结束:{end_3}")
st.write(f"向量库构建耗时:{end_3 - start_3}")

# --- Persist the index and load it back from storage ---
# Use the session's real directory. A quoted "st.session_state.directory_path"
# is just a literal folder name that every session would share.
persist_dir = st.session_state.directory_path
if "storage_context" not in st.session_state:
    # Persist once, right before the storage context is loaded from disk,
    # instead of rewriting the index files on every rerun.
    st.session_state.new_index.storage_context.persist(persist_dir)
    st.session_state.storage_context = StorageContext.from_defaults(persist_dir=persist_dir)

start_4 = timeit.default_timer()  # Start timer
st.write(f"向量库装载开始:{start_4}")
if "loadedindex" not in st.session_state:
    st.session_state.loadedindex = load_index_from_storage(
        storage_context=st.session_state.storage_context,
        service_context=st.session_state.service_context,
    )
end_4 = timeit.default_timer()  # Stop timer
st.write(f"向量库装载结束:{end_4}")
st.write(f"向量库装载耗时:{end_4 - start_4}")

if "query_engine" not in st.session_state:
    st.session_state.query_engine = st.session_state.loadedindex.as_query_engine()

# --- Question answering ---
# The input widget has to be created on every rerun, otherwise it would
# disappear from the page after the first interaction.
st.session_state.user_question = st.text_input("Enter your query:")

# Only query when the question contains something other than whitespace.
if st.session_state.user_question.strip():
    print("user question: " + st.session_state.user_question)
    with st.spinner("AI Thinking...Please wait a while to Cheers!"):
        start_5 = timeit.default_timer()  # Start timer
        st.write(f"Query Engine - AI QA开始:{start_5}")
        initial_response = st.session_state.query_engine.query(st.session_state.user_question)
        temp_ai_response = str(initial_response)
        # Keep only the text before the model's end-of-turn marker.
        final_ai_response = temp_ai_response.partition('<|end|>')[0]
        print("AI Response:\n" + final_ai_response)
        st.write("AI Response:\n\n" + final_ai_response)