Spaces:
Sleeping
Sleeping
File size: 7,088 Bytes
656fdc6 ddd7ff1 656fdc6 c52b691 656fdc6 c52b691 351c7d2 c52b691 ddd7ff1 351c7d2 656fdc6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 |
# Change log:
# - Removed `documents=[]`.
# - Moved every st.session_state variable to the place where it is first
#   used, instead of declaring them all as None at the very top.
# - Changed
#       pdf_files = st.file_uploader("Choose your PDF Files and Press OK", type=['pdf'], accept_multiple_files=True)
#   to
#       if "pdf_files" not in st.session_state:
#           st.session_state.pdf_files = st.file_uploader("Choose your PDF Files and Press OK", type=['pdf'], accept_multiple_files=True)
# - `if not st.session_state.pdf_files:` means: if st.session_state.pdf_files
#   is empty, stop executing the script.
import streamlit as st
from llama_index import VectorStoreIndex, SimpleDirectoryReader
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from llama_index import LangchainEmbedding, ServiceContext
from llama_index import StorageContext, load_index_from_storage
from llama_index import LLMPredictor
#from transformers import HuggingFaceHub
from langchain import HuggingFaceHub
from streamlit.components.v1 import html
from pathlib import Path
from time import sleep
import random
import string
import os
from dotenv import load_dotenv
load_dotenv()
import timeit
# Page chrome: browser-tab title, wide layout and the visible heading.
st.set_page_config(page_title="Open AI Doc-Chat Assistant", layout="wide")
st.subheader("Open AI Doc-Chat Assistant: Life Enhancing with AI!")

# Inline the app's stylesheet into the page inside a <style> element.
css_file = "main.css"
stylesheet = Path(css_file).read_text()
st.markdown(f"<style>{stylesheet}</style>", unsafe_allow_html=True)

# Hugging Face Hub token taken from the process environment
# (load_dotenv() is called earlier in this file).
HUGGINGFACEHUB_API_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")
def generate_random_string(length):
    """Return a string of ``length`` random lowercase ASCII letters.

    Characters are drawn independently with ``random.choice`` from
    ``string.ascii_lowercase``; a ``length`` of 0 yields ``""``.
    """
    alphabet = string.ascii_lowercase
    picked = [random.choice(alphabet) for _ in range(length)]
    return "".join(picked)
#random_string = generate_random_string(20)
#directory_path=random_string
#if "directory_path" not in st.session_state:
# st.session_state.directory_path = generate_random_string(20)
# Sidebar: collect the PDFs and copy them into a per-session working directory.
with st.sidebar:
    st.subheader("Upload your Documents Here: ")
    pdf_files = st.file_uploader("Choose your PDF Files and Press OK", type=['pdf'], accept_multiple_files=True)

    # Mirror the widget's *current* value on every rerun. Storing it only the
    # first time froze the initial empty value in session state, so the
    # emptiness check below could never pass once the user uploaded files.
    st.session_state.pdf_files = pdf_files

    # Nothing uploaded yet: stop the script here so that the later stages
    # never run without documents.
    if not pdf_files:
        st.warning("请上传文档文件")
        st.stop()

    # One randomly named working directory per browser session.
    if "directory_path" not in st.session_state:
        st.session_state.directory_path = generate_random_string(20)
    # exist_ok: the directory normally already exists on reruns.
    os.makedirs(st.session_state.directory_path, exist_ok=True)

    for pdf_file in pdf_files:
        # The file name is supplied by the client; keep only its final path
        # component so it cannot point outside the working directory.
        safe_name = os.path.basename(pdf_file.name)
        file_path = os.path.join(st.session_state.directory_path, safe_name)
        with open(file_path, 'wb') as f:
            # getvalue() returns the whole buffer regardless of the current
            # read position, unlike read().
            f.write(pdf_file.getvalue())
        st.success(f"File '{pdf_file.name}' saved successfully.")

# NOTE(review): the stages below cache their results in st.session_state, so
# files uploaded after the first successful run are saved to disk but are not
# added to the already-built index — confirm whether that is acceptable.
# Stage 1: parse the saved files into documents (cached per session).
start_1 = timeit.default_timer()  # Start timer
st.write(f"QA文档加载开始:{start_1}")
if "documents" not in st.session_state:
    try:
        st.session_state.documents = SimpleDirectoryReader(st.session_state.directory_path).load_data()
    except Exception as e:
        # Top-level boundary: report the failure and halt. Continuing would
        # only fail later with an AttributeError on the missing
        # st.session_state.documents.
        print("文档加载出现问题/Waiting for path creation.")
        st.error(f"Failed to load the uploaded documents: {e}")
        st.stop()
end_1 = timeit.default_timer()  # Stop timer
st.write(f"QA文档加载结束:{end_1}")
st.write(f"QA文档加载耗时:{end_1 - start_1}")
# Load documents from a directory
#documents = SimpleDirectoryReader('data').load_data()
# Stage 2: embedding model, created once per session and timed.
embed_t0 = timeit.default_timer()
st.write(f"向量模型加载开始:{embed_t0}")
if "embed_model" not in st.session_state:
    hf_embeddings = HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-L6-v2')
    st.session_state.embed_model = LangchainEmbedding(hf_embeddings)
embed_t1 = timeit.default_timer()
st.write(f"向量模型加载加载结束:{embed_t1}")
st.write(f"向量模型加载耗时:{embed_t1 - embed_t0}")

# LLM predictor backed by a Hugging Face Hub hosted model, created once per session.
if "llm_predictor" not in st.session_state:
    generation_kwargs = {
        "min_length": 100,
        "max_new_tokens": 1024,
        "do_sample": True,
        "temperature": 0.2,
        "top_k": 50,
        "top_p": 0.95,
        "eos_token_id": 49155,
    }
    hub_llm = HuggingFaceHub(repo_id="HuggingFaceH4/starchat-beta", model_kwargs=generation_kwargs)
    st.session_state.llm_predictor = LLMPredictor(hub_llm)

# Service context tying the predictor and the embedding model together.
if "service_context" not in st.session_state:
    st.session_state.service_context = ServiceContext.from_defaults(
        llm_predictor=st.session_state.llm_predictor,
        embed_model=st.session_state.embed_model,
    )
# Stage 3: build the vector index from the loaded documents (cached per session).
start_3 = timeit.default_timer()  # Start timer
st.write(f"向量库构建开始:{start_3}")
if "new_index" not in st.session_state:
    st.session_state.new_index = VectorStoreIndex.from_documents(
        st.session_state.documents,
        service_context=st.session_state.service_context,
    )
end_3 = timeit.default_timer()  # Stop timer
st.write(f"向量库构建结束:{end_3}")
st.write(f"向量库构建耗时:{end_3 - start_3}")

# Persist the index into this session's own directory and open a storage
# context on it. The path must be the session-state *value*: passing the
# quoted text "st.session_state.directory_path" made every session write to
# (and read from) one shared directory with that literal name.
if "storage_context" not in st.session_state:
    persist_dir = st.session_state.directory_path
    # Persist once, together with creating the storage context, rather than
    # rewriting the index files on every rerun.
    st.session_state.new_index.storage_context.persist(persist_dir)
    st.session_state.storage_context = StorageContext.from_defaults(persist_dir=persist_dir)

# Stage 4: reload the index from the persisted storage (cached per session).
start_4 = timeit.default_timer()  # Start timer
st.write(f"向量库装载开始:{start_4}")
if "loadedindex" not in st.session_state:
    st.session_state.loadedindex = load_index_from_storage(storage_context=st.session_state.storage_context, service_context=st.session_state.service_context)
end_4 = timeit.default_timer()  # Stop timer
st.write(f"向量库装载结束:{end_4}")
st.write(f"向量库装载耗时:{end_4 - start_4}")

# Query engine over the reloaded index, created once per session.
if "query_engine" not in st.session_state:
    st.session_state.query_engine = st.session_state.loadedindex.as_query_engine()
# Stage 5: read the user's question and answer it with the query engine.
# The input widget has to be created on every rerun to stay on the page, so it
# is called unconditionally. The previous guard tested the key
# "user_question " (with a trailing space), which never exists and was
# therefore always true.
st.session_state.user_question = st.text_input("Enter your query:")
user_question = st.session_state.user_question

# Only query when the input contains something other than whitespace.
if user_question.strip():
    print("user question: "+user_question)
    with st.spinner("AI Thinking...Please wait a while to Cheers!"):
        start_5 = timeit.default_timer()  # Start timer
        st.write(f"Query Engine - AI QA开始:{start_5}")
        initial_response = st.session_state.query_engine.query(user_question)
        end_5 = timeit.default_timer()  # Stop timer
        # Report end time and duration, as the other stages do.
        st.write(f"Query Engine - AI QA结束:{end_5}")
        st.write(f"Query Engine - AI QA耗时:{end_5 - start_5}")
    temp_ai_response = str(initial_response)
    # Keep only the text that precedes the first '<|end|>' marker.
    final_ai_response = temp_ai_response.partition('<|end|>')[0]
    print("AI Response:\n"+final_ai_response)
    st.write("AI Response:\n\n"+final_ai_response)