Spaces:
Sleeping
Sleeping
| #这个版本有个问题,如果在运行状况下,增删文件,不会重新装载文件并构建向量数据库! | |
| import streamlit as st | |
| from llama_index import VectorStoreIndex, SimpleDirectoryReader | |
| from langchain.embeddings.huggingface import HuggingFaceEmbeddings | |
| from llama_index import LangchainEmbedding, ServiceContext | |
| from llama_index import StorageContext, load_index_from_storage | |
| from llama_index import LLMPredictor | |
| #from transformers import HuggingFaceHub | |
| from langchain import HuggingFaceHub | |
| from streamlit.components.v1 import html | |
| from pathlib import Path | |
| from time import sleep | |
| import random | |
| import string | |
| import os | |
| from dotenv import load_dotenv | |
| load_dotenv() | |
| import timeit | |
# Page chrome: must run before any other Streamlit command.
st.set_page_config(page_title="Open AI Doc-Chat Assistant", layout="wide")
st.subheader("Open AI Doc-Chat Assistant: Life Enhancing with AI!")

# Inject the app's custom stylesheet. Decode explicitly as UTF-8 so loading does not
# depend on the platform's default encoding (e.g. cp936 on a Chinese-locale Windows).
css_file = "main.css"
with open(css_file, encoding="utf-8") as f:
    st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)

# Hugging Face API token from the environment (load_dotenv() has already run).
# NOTE(review): not referenced by the code shown here; presumably HuggingFaceHub reads
# it from the environment itself — confirm before removing.
HUGGINGFACEHUB_API_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")
| #documents=[] | |
def generate_random_string(length):
    """Return a string of ``length`` random lowercase ASCII letters.

    Used to name per-session working directories. ``length <= 0`` yields "".
    Draws from the module-level ``random`` generator (not cryptographically secure).
    """
    alphabet = string.ascii_lowercase
    picked = [random.choice(alphabet) for _ in range(length)]
    return "".join(picked)
| #random_string = generate_random_string(20) | |
| #directory_path=random_string | |
| #if "pdf_files" not in st.session_state: | |
| #st.session_state.pdf_files = None | |
# Make sure the per-session slots used later in the script exist (initially None)
# without overwriting values kept from an earlier rerun of this session.
for _session_key in ("documents", "query_engine"):
    if _session_key not in st.session_state:
        st.session_state[_session_key] = None
with st.sidebar:
    st.subheader("Upload your Documents Here: ")
    pdf_files = st.file_uploader("Choose your PDF Files and Press OK", type=['pdf'], accept_multiple_files=True)
    if not pdf_files:
        # Nothing uploaded yet: end this script run until the user provides files.
        st.warning("请上传文档文件")
        st.stop()

    # Reuse one upload directory per session. Previously a new random directory was
    # created on every rerun (i.e. every widget interaction), leaking directories and
    # duplicate copies of the PDFs on disk.
    if "uploadedfile_path" not in st.session_state:
        st.session_state.uploadedfile_path = generate_random_string(20)
    uploadedfile_path = st.session_state.uploadedfile_path
    os.makedirs(uploadedfile_path, exist_ok=True)

    # basename() strips any path components from the client-supplied file name.
    current_names = {os.path.basename(pdf_file.name) for pdf_file in pdf_files}
    # Remove files the user has since dropped from the uploader, so the directory
    # reader below only sees the current selection.
    for existing_name in os.listdir(uploadedfile_path):
        if existing_name not in current_names:
            os.remove(os.path.join(uploadedfile_path, existing_name))

    for pdf_file in pdf_files:
        file_path = os.path.join(uploadedfile_path, os.path.basename(pdf_file.name))
        with open(file_path, 'wb') as f:
            # getvalue() returns the whole upload regardless of the stream position.
            f.write(pdf_file.getvalue())
        st.success(f"File '{pdf_file.name}' saved successfully.")

    # Parse every file in the upload directory into llama_index documents.
    try:
        start_1 = timeit.default_timer()  # Start timer
        st.write(f"QA文档加载开始:{start_1}")
        st.session_state.documents = SimpleDirectoryReader(uploadedfile_path).load_data()
        end_1 = timeit.default_timer()  # Stop timer
        st.write(f"QA文档加载结束:{end_1}")
        st.write(f"QA文档加载耗时:{end_1 - start_1}")
    except Exception as e:
        # Keep the underlying error in the server log instead of discarding it.
        print("文档加载出现问题/Waiting for path creation.", e)
        st.warning("文档加载出现问题/Waiting for path creation.")
        st.stop()
# --- Embedding model, LLM and service context: created once per session. ---
start_2 = timeit.default_timer()  # Start timer
st.write(f"向量模型加载开始:{start_2}")
if "embed_model" not in st.session_state:
    st.session_state.embed_model = LangchainEmbedding(
        HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-L6-v2')
    )
end_2 = timeit.default_timer()  # Stop timer
st.write(f"向量模型加载结束:{end_2}")
st.write(f"向量模型加载耗时:{end_2 - start_2}")
if "llm_predictor" not in st.session_state:
    st.session_state.llm_predictor = LLMPredictor(
        HuggingFaceHub(
            repo_id="HuggingFaceH4/starchat-beta",
            model_kwargs={
                "min_length": 100,
                "max_new_tokens": 1024,
                "do_sample": True,
                "temperature": 0.2,
                "top_k": 50,
                "top_p": 0.95,
                "eos_token_id": 49155,
            },
        )
    )
if "service_context" not in st.session_state:
    st.session_state.service_context = ServiceContext.from_defaults(
        llm_predictor=st.session_state.llm_predictor,
        embed_model=st.session_state.embed_model,
    )

# --- Vector index: rebuilt whenever the uploaded file set changes. ---
# Identify the current upload set by (name, size). Previously the index was built only
# once per session, so files added or removed while the app was running were never
# indexed. NOTE: replacing a file with a different one of identical name and byte size
# is not detected.
uploaded_fingerprint = tuple(sorted((pdf_file.name, pdf_file.size) for pdf_file in pdf_files))
index_is_stale = st.session_state.get("indexed_fingerprint") != uploaded_fingerprint

start_3 = timeit.default_timer()  # Start timer
st.write(f"向量库构建开始:{start_3}")
if index_is_stale or "new_index" not in st.session_state:
    st.session_state.new_index = VectorStoreIndex.from_documents(
        st.session_state.documents, service_context=st.session_state.service_context
    )
end_3 = timeit.default_timer()  # Stop timer
st.write(f"向量库构建结束:{end_3}")
st.write(f"向量库构建耗时:{end_3 - start_3}")

if "directory_path" not in st.session_state:
    st.session_state.directory_path = generate_random_string(20)
    os.makedirs(st.session_state.directory_path)
if index_is_stale or "storage_context" not in st.session_state:
    # Persist into this session's own directory. Previously the string literal
    # "st.session_state.directory_path" was passed instead of the variable, so every
    # session shared (and overwrote) one directory with that literal name, and the
    # index was re-persisted on every rerun.
    st.session_state.new_index.storage_context.persist(st.session_state.directory_path)
    st.session_state.storage_context = StorageContext.from_defaults(
        persist_dir=st.session_state.directory_path
    )

start_4 = timeit.default_timer()  # Start timer
st.write(f"向量库装载开始:{start_4}")
if index_is_stale or "loadedindex" not in st.session_state:
    st.session_state.loadedindex = load_index_from_storage(
        storage_context=st.session_state.storage_context,
        service_context=st.session_state.service_context,
    )
end_4 = timeit.default_timer()  # Stop timer
st.write(f"向量库装载结束:{end_4}")
st.write(f"向量库装载耗时:{end_4 - start_4}")

# Remember which upload set the current index reflects, for the next rerun's check.
st.session_state.indexed_fingerprint = uploaded_fingerprint
st.session_state.query_engine = st.session_state.loadedindex.as_query_engine()
user_question = st.text_input("Enter your query:")
# strip() yields "" exactly when the input is empty or whitespace-only, which is
# everything the previous five-clause condition ruled out.
if user_question.strip():
    print("user question: " + user_question)
    with st.spinner("AI Thinking...Please wait a while to Cheers!"):
        start_5 = timeit.default_timer()  # Start timer
        st.write(f"Query Engine - AI QA开始:{start_5}")
        initial_response = st.session_state.query_engine.query(user_question)
        end_5 = timeit.default_timer()  # Stop timer
        # Report end time and elapsed time, like every other timed stage.
        st.write(f"Query Engine - AI QA结束:{end_5}")
        st.write(f"Query Engine - AI QA耗时:{end_5 - start_5}")
        temp_ai_response = str(initial_response)
        # Keep only the text before the first "<|end|>" marker
        # (presumably the model's end-of-turn token — confirm).
        final_ai_response = temp_ai_response.partition('<|end|>')[0]
        print("AI Response:\n" + final_ai_response)
        st.write("AI Response:\n\n" + final_ai_response)