File size: 7,088 Bytes
656fdc6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ddd7ff1
 
656fdc6
 
c52b691
 
656fdc6
c52b691
 
351c7d2
 
 
 
c52b691
ddd7ff1
 
 
351c7d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
656fdc6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
# Changelog:
# - Removed the unused `documents = []` initialisation.
# - Moved each st.session_state variable to where it is first used, instead of
#   declaring them all as None at the top of the script.
# - Replaced
#     pdf_files = st.file_uploader("Choose your PDF Files and Press OK", type=['pdf'], accept_multiple_files=True)
#   with
#     if "pdf_files" not in st.session_state:
#         st.session_state.pdf_files = st.file_uploader("Choose your PDF Files and Press OK", type=['pdf'], accept_multiple_files=True)
#     if not st.session_state.pdf_files:
#   meaning: if st.session_state.pdf_files is empty, stop executing the script.

import streamlit as st
from llama_index import VectorStoreIndex, SimpleDirectoryReader
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from llama_index import LangchainEmbedding, ServiceContext
from llama_index import StorageContext, load_index_from_storage
from llama_index import LLMPredictor
#from transformers import HuggingFaceHub
from langchain import HuggingFaceHub
from streamlit.components.v1 import html
from pathlib import Path
from time import sleep
import random
import string

import os
from dotenv import load_dotenv
load_dotenv()

import timeit

st.set_page_config(page_title="Open AI Doc-Chat Assistant", layout="wide")
st.subheader("Open AI Doc-Chat Assistant: Life Enhancing with AI!")

# Inject the app's stylesheet straight into the rendered page.
css_file = "main.css"
with open(css_file) as css:
    st.markdown(f"<style>{css.read()}</style>", unsafe_allow_html=True)

# HuggingFace Hub token, read from the environment loaded via dotenv above.
HUGGINGFACEHUB_API_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")

def generate_random_string(length):
    """Return a random string of *length* lowercase ASCII letters.

    Used to name a unique per-session upload directory.
    NOTE(review): `random` is not cryptographically secure; if the
    directory name must be unguessable, switch to the `secrets` module.
    """
    return ''.join(random.choices(string.ascii_lowercase, k=length))

#random_string = generate_random_string(20)
#directory_path=random_string

#if "directory_path" not in st.session_state:
#    st.session_state.directory_path = generate_random_string(20)

with st.sidebar:
    st.subheader("Upload your Documents Here: ")
    pdf_files = st.file_uploader("Choose your PDF Files and Press OK", type=['pdf'], accept_multiple_files=True)
    # Process the upload only once per session; afterwards the loaded
    # documents are served from st.session_state.
    if "pdf_files" not in st.session_state:
        st.session_state.pdf_files = pdf_files
        if not st.session_state.pdf_files:
            # Nothing uploaded yet: halt the script so the indexing code
            # further down never runs against an empty document set.
            st.warning("请上传文档文件")
            st.stop()
        # Save the uploaded PDFs into a fresh randomly named directory so
        # SimpleDirectoryReader can load them from disk.
        if "directory_path" not in st.session_state:
            st.session_state.directory_path = generate_random_string(20)
            os.makedirs(st.session_state.directory_path)
            for pdf_file in st.session_state.pdf_files:
                file_path = os.path.join(st.session_state.directory_path, pdf_file.name)
                with open(file_path, 'wb') as f:
                    f.write(pdf_file.read())
                st.success(f"File '{pdf_file.name}' saved successfully.")
            try:
                start_1 = timeit.default_timer()  # start timer: document loading
                st.write(f"QA文档加载开始:{start_1}")
                if "documents" not in st.session_state:
                    st.session_state.documents = SimpleDirectoryReader(st.session_state.directory_path).load_data()
                    end_1 = timeit.default_timer()
                    st.write(f"QA文档加载结束:{end_1}")
                    st.write(f"QA文档加载耗时:{end_1 - start_1}")
            except Exception as e:
                # Surface the failure and stop: downstream code depends on
                # st.session_state.documents existing. (The original printed
                # to stdout and continued, crashing later with AttributeError.)
                st.error(f"文档加载出现问题/Waiting for path creation: {e}")
                st.stop()
# Load documents from a directory
#documents = SimpleDirectoryReader('data').load_data()
    
start_2 = timeit.default_timer()  # start timer: embedding-model load
st.write(f"向量模型加载开始:{start_2}")
if "embed_model" not in st.session_state:
    # Sentence-transformers MiniLM embeddings, wrapped for llama_index.
    st.session_state.embed_model = LangchainEmbedding(HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-L6-v2'))
end_2 = timeit.default_timer()
st.write(f"向量模型加载结束:{end_2}")
st.write(f"向量模型加载耗时:{end_2 - start_2}")

if "llm_predictor" not in st.session_state:
    # Remote starchat-beta LLM served through the HuggingFace Hub API.
    st.session_state.llm_predictor = LLMPredictor(HuggingFaceHub(repo_id="HuggingFaceH4/starchat-beta", model_kwargs={"min_length":100, "max_new_tokens":1024, "do_sample":True, "temperature":0.2,"top_k":50, "top_p":0.95, "eos_token_id":49155}))

if "service_context" not in st.session_state:
    st.session_state.service_context = ServiceContext.from_defaults(llm_predictor=st.session_state.llm_predictor, embed_model=st.session_state.embed_model)

start_3 = timeit.default_timer()  # start timer: vector-index build
st.write(f"向量库构建开始:{start_3}")
if "new_index" not in st.session_state:
    if "documents" not in st.session_state:
        # Documents never got loaded (e.g. upload/load failed earlier) —
        # stop cleanly instead of crashing on the missing attribute below.
        st.warning("请上传文档文件")
        st.stop()
    st.session_state.new_index = VectorStoreIndex.from_documents(
        st.session_state.documents,
        service_context=st.session_state.service_context,
    )
end_3 = timeit.default_timer()
st.write(f"向量库构建结束:{end_3}")
st.write(f"向量库构建耗时:{end_3 - start_3}")

# Persist the freshly built index to the per-session upload directory.
# BUG FIX: the persist/load path was the LITERAL string
# "st.session_state.directory_path" rather than the variable's value, so
# every session wrote to and read from one shared hard-coded directory.
st.session_state.new_index.storage_context.persist(st.session_state.directory_path)

if "storage_context" not in st.session_state:
    st.session_state.storage_context = StorageContext.from_defaults(persist_dir=st.session_state.directory_path)

start_4 = timeit.default_timer()  # start timer: index reload from disk
st.write(f"向量库装载开始:{start_4}")
if "loadedindex" not in st.session_state:
    st.session_state.loadedindex = load_index_from_storage(storage_context=st.session_state.storage_context, service_context=st.session_state.service_context)
end_4 = timeit.default_timer()
st.write(f"向量库装载结束:{end_4}")
st.write(f"向量库装载耗时:{end_4 - start_4}")

if "query_engine" not in st.session_state:
    st.session_state.query_engine = st.session_state.loadedindex.as_query_engine()
    
if "user_question " not in st.session_state:
    st.session_state.user_question = st.text_input("Enter your query:")
if st.session_state.user_question !="" and not st.session_state.user_question.strip().isspace() and not st.session_state.user_question == "" and not st.session_state.user_question.strip() == "" and not st.session_state.user_question.isspace():
    print("user question: "+st.session_state.user_question)
    with st.spinner("AI Thinking...Please wait a while to Cheers!"):
        start_5 = timeit.default_timer() # Start timer
        st.write(f"Query Engine - AI QA开始:{start_5}") 
        initial_response = st.session_state.query_engine.query(st.session_state.user_question)
        temp_ai_response=str(initial_response)
        final_ai_response=temp_ai_response.partition('<|end|>')[0]
        print("AI Response:\n"+final_ai_response)
        st.write("AI Response:\n\n"+final_ai_response)