# Known limitation: if files are added or removed while the app is running,
# the documents are not reloaded and the vector database is not rebuilt!
# (A possible fix is sketched at the bottom of this file.)
import streamlit as st
from llama_index import VectorStoreIndex, SimpleDirectoryReader
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from llama_index import LangchainEmbedding, ServiceContext
from llama_index import StorageContext, load_index_from_storage
from llama_index import LLMPredictor
#from transformers import HuggingFaceHub
from langchain import HuggingFaceHub
from streamlit.components.v1 import html
from pathlib import Path
from time import sleep
import random
import string
import os
from dotenv import load_dotenv
load_dotenv()
import timeit
st.set_page_config(page_title="Open AI Doc-Chat Assistant", layout="wide")
st.subheader("Open AI Doc-Chat Assistant: Life Enhancing with AI!")
css_file = "main.css"
with open(css_file) as f:
    st.markdown("<style>{}</style>".format(f.read()), unsafe_allow_html=True)
HUGGINGFACEHUB_API_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")
documents=[]
def generate_random_string(length):
    letters = string.ascii_lowercase
    return ''.join(random.choice(letters) for _ in range(length))
#random_string = generate_random_string(20)
#directory_path=random_string
if "directory_path" not in st.session_state:
st.session_state.directory_path = generate_random_string(20)
if "pdf_files" not in st.session_state:
st.session_state.pdf_files = None
if "documents" not in st.session_state:
st.session_state.documents = None
with st.sidebar:
    st.subheader("Upload your Documents Here: ")
    pdf_files = st.file_uploader("Choose your PDF Files and Press OK", type=['pdf'], accept_multiple_files=True)
    if not pdf_files:
        st.warning("Please upload your document files.")
        st.stop()
    else:
        st.session_state.pdf_files = pdf_files
        if not os.path.exists(st.session_state.directory_path):
            os.makedirs(st.session_state.directory_path)
        # Save each uploaded PDF into the session's scratch directory.
        for pdf_file in st.session_state.pdf_files:
            file_path = os.path.join(st.session_state.directory_path, pdf_file.name)
            with open(file_path, 'wb') as f:
                f.write(pdf_file.read())
            st.success(f"File '{pdf_file.name}' saved successfully.")
        try:
            start_1 = timeit.default_timer()  # Start timer
            st.write(f"Document loading started: {start_1}")
            st.session_state.documents = SimpleDirectoryReader(st.session_state.directory_path).load_data()
            end_1 = timeit.default_timer()  # End timer
            st.write(f"Document loading finished: {end_1}")
            st.write(f"Document loading time: {end_1 - start_1}")
        except Exception as e:
            print("Problem loading documents / waiting for path creation.")
start_2 = timeit.default_timer()  # Start timer
st.write(f"Embedding model loading started: {start_2}")
if "embed_model" not in st.session_state:
    st.session_state.embed_model = LangchainEmbedding(HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-L6-v2'))
end_2 = timeit.default_timer()  # End timer
st.write(f"Embedding model loading finished: {end_2}")
st.write(f"Embedding model loading time: {end_2 - start_2}")
if "llm_predictor" not in st.session_state:
st.session_state.llm_predictor = LLMPredictor(HuggingFaceHub(repo_id="HuggingFaceH4/starchat-beta", model_kwargs={"min_length":100, "max_new_tokens":1024, "do_sample":True, "temperature":0.2,"top_k":50, "top_p":0.95, "eos_token_id":49155}))
if "service_context" not in st.session_state:
st.session_state.service_context = ServiceContext.from_defaults(llm_predictor=st.session_state.llm_predictor, embed_model=st.session_state.embed_model)
start_3 = timeit.default_timer()  # Start timer
st.write(f"Vector store build started: {start_3}")
if "new_index" not in st.session_state:
    st.session_state.new_index = VectorStoreIndex.from_documents(
        st.session_state.documents,
        service_context=st.session_state.service_context,
    )
end_3 = timeit.default_timer()  # End timer
st.write(f"Vector store build finished: {end_3}")
st.write(f"Vector store build time: {end_3 - start_3}")
# Bug fix: the directory path was previously passed as the literal string
# "st.session_state.directory_path" rather than the variable's value.
st.session_state.new_index.storage_context.persist(st.session_state.directory_path)
if "storage_context" not in st.session_state:
    st.session_state.storage_context = StorageContext.from_defaults(persist_dir=st.session_state.directory_path)
start_4 = timeit.default_timer()  # Start timer
st.write(f"Vector store loading started: {start_4}")
if "loadedindex" not in st.session_state:
    st.session_state.loadedindex = load_index_from_storage(storage_context=st.session_state.storage_context, service_context=st.session_state.service_context)
end_4 = timeit.default_timer()  # End timer
st.write(f"Vector store loading finished: {end_4}")
st.write(f"Vector store loading time: {end_4 - start_4}")
if "query_engine" not in st.session_state:
st.session_state.query_engine = st.session_state.loadedindex.as_query_engine()
if "user_question " not in st.session_state:
st.session_state.user_question = st.text_input("Enter your query:")
# The original condition repeated the same emptiness/whitespace checks five
# times; a single truthiness test plus isspace() covers them all.
if st.session_state.user_question and not st.session_state.user_question.isspace():
    print("user question: " + st.session_state.user_question)
    with st.spinner("AI Thinking...Please wait a while to Cheers!"):
        start_5 = timeit.default_timer()  # Start timer
        st.write(f"Query Engine - AI QA started: {start_5}")
        initial_response = st.session_state.query_engine.query(st.session_state.user_question)
        # starchat-beta appends an <|end|> token; keep only the text before it.
        temp_ai_response = str(initial_response)
        final_ai_response = temp_ai_response.partition('<|end|>')[0]
        end_5 = timeit.default_timer()  # End timer (completes the timing pattern used above)
        st.write(f"Query Engine - AI QA finished: {end_5}")
        st.write(f"Query Engine - AI QA time: {end_5 - start_5}")
        print("AI Response:\n" + final_ai_response)
        st.write("AI Response:\n\n" + final_ai_response)