File size: 6,308 Bytes
d8b389b
 
05f0071
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a59a206
 
 
 
 
 
 
 
 
 
 
 
 
05f0071
 
93e9504
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
05f0071
 
 
 
607dadf
 
c97b3fb
 
607dadf
 
 
 
c97b3fb
 
607dadf
c97b3fb
 
607dadf
 
 
c97b3fb
 
607dadf
 
 
 
 
 
 
 
 
c97b3fb
 
607dadf
 
 
c97b3fb
 
607dadf
 
 
 
c97b3fb
 
a59a206
c97b3fb
 
607dadf
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
#这个版本有个问题,如果在运行状况下,增删文件,不会重新装载文件并构建向量数据库!

import streamlit as st
from llama_index import VectorStoreIndex, SimpleDirectoryReader
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from llama_index import LangchainEmbedding, ServiceContext
from llama_index import StorageContext, load_index_from_storage
from llama_index import LLMPredictor
#from transformers import HuggingFaceHub
from langchain import HuggingFaceHub
from streamlit.components.v1 import html
from pathlib import Path
from time import sleep
import random
import string

import os
from dotenv import load_dotenv
load_dotenv()

import timeit

st.set_page_config(page_title="Open AI Doc-Chat Assistant", layout="wide")
st.subheader("Open AI Doc-Chat Assistant: Life Enhancing with AI!")

# Inject the app's custom stylesheet into the page. Decode it explicitly as
# UTF-8: the platform default encoding (e.g. cp936/GBK on a Chinese-locale
# Windows host) fails on, or mis-decodes, any non-ASCII character in the CSS.
css_file = "main.css"
with open(css_file, encoding="utf-8") as f:
    st.markdown("<style>{}</style>".format(f.read()), unsafe_allow_html=True)

# Read from the process environment (populated by load_dotenv()).
# NOTE(review): this value is not passed anywhere in the visible code;
# HuggingFaceHub presumably reads the same environment variable itself — confirm.
HUGGINGFACEHUB_API_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")

# NOTE(review): unused; the loaded documents live in st.session_state.documents.
documents=[]

def generate_random_string(length):
    """Return a string of ``length`` random lowercase ASCII letters.

    Args:
        length: Number of characters to produce; 0 yields "".

    Returns:
        A string drawn character-by-character from ``string.ascii_lowercase``
        using the ``random`` module.
    """
    alphabet = string.ascii_lowercase
    picks = [random.choice(alphabet) for _ in range(length)]
    return "".join(picks)

# Each browser session gets its own randomly named scratch directory so that
# concurrent users do not read or overwrite each other's uploads.
if "directory_path" not in st.session_state:
    st.session_state.directory_path = generate_random_string(20)

if "pdf_files" not in st.session_state:
    st.session_state.pdf_files = None

# Parsed documents; stays None until SimpleDirectoryReader succeeds.
if "documents" not in st.session_state:
    st.session_state.documents = None

with st.sidebar:
    st.subheader("Upload your Documents Here: ")
    pdf_files = st.file_uploader("Choose your PDF Files and Press OK", type=['pdf'], accept_multiple_files=True)
    if not pdf_files:
        # Nothing uploaded yet: prompt the user and halt this script run.
        st.warning("请上传文档文件")
        st.stop()
    st.session_state.pdf_files = pdf_files

    # Files are written only when the session directory is first created.
    # Known limitation: files added/removed later in the session are ignored.
    if not os.path.exists(st.session_state.directory_path):
        os.makedirs(st.session_state.directory_path)
        for pdf_file in pdf_files:
            # The file name is supplied by the client; keep only its final
            # component so a crafted name such as "../../x.pdf" cannot
            # escape the session directory.
            safe_name = os.path.basename(pdf_file.name)
            file_path = os.path.join(st.session_state.directory_path, safe_name)
            with open(file_path, 'wb') as f:
                # getvalue() returns the whole upload regardless of the
                # buffer's current read position.
                f.write(pdf_file.getvalue())
            st.success(f"File '{pdf_file.name}' saved successfully.")

    # Attempt to load whenever no documents are loaded yet (rather than only
    # in the run that created the directory), so a failed load is retried
    # on the next rerun instead of leaving the session permanently empty.
    if st.session_state.documents is None:
        try:
            start_1 = timeit.default_timer()  # timer: document loading
            st.write(f"QA文档加载开始:{start_1}")
            st.session_state.documents = SimpleDirectoryReader(st.session_state.directory_path).load_data()
            end_1 = timeit.default_timer()
            st.write(f"QA文档加载结束:{end_1}")
            st.write(f"QA文档加载耗时:{end_1 - start_1}")
        except Exception as e:
            # Surface the failure instead of continuing with no documents,
            # which would only crash later when the index is built.
            print("文档加载出现问题/Waiting for path creation.", e)
            st.error(f"文档加载出现问题/Failed to load documents: {e}")
            st.stop()

# --- Embedding model ----------------------------------------------------------
# Created once per session and cached in st.session_state; on later reruns the
# timing written below measures only the cache check.
start_2 = timeit.default_timer()  # timer: embedding-model load
st.write(f"向量模型加载开始:{start_2}")
if "embed_model" not in st.session_state:
    hf_embeddings = HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-L6-v2')
    st.session_state.embed_model = LangchainEmbedding(hf_embeddings)
end_2 = timeit.default_timer()
st.write(f"向量模型加载加载结束:{end_2}")
st.write(f"向量模型加载耗时:{end_2 - start_2}")

# --- LLM ----------------------------------------------------------------------
# HuggingFaceH4/starchat-beta via the Hugging Face Hub, wrapped for llama_index.
if "llm_predictor" not in st.session_state:
    starchat_kwargs = {
        "min_length": 100,
        "max_new_tokens": 1024,
        "do_sample": True,
        "temperature": 0.2,
        "top_k": 50,
        "top_p": 0.95,
        # NOTE(review): presumably the id of StarChat's "<|end|>" token — confirm.
        "eos_token_id": 49155,
    }
    starchat = HuggingFaceHub(repo_id="HuggingFaceH4/starchat-beta", model_kwargs=starchat_kwargs)
    st.session_state.llm_predictor = LLMPredictor(starchat)

# --- Service context ----------------------------------------------------------
# Bundles the LLM and the embedding model; later passed to index construction
# and to index loading.
if "service_context" not in st.session_state:
    st.session_state.service_context = ServiceContext.from_defaults(
        llm_predictor=st.session_state.llm_predictor,
        embed_model=st.session_state.embed_model,
    )

# Persist the index under this session's own directory (a "storage"
# sub-directory, so the index JSON files stay apart from the uploaded PDFs
# that SimpleDirectoryReader scans). A fixed path shared by all sessions
# would let concurrent users overwrite each other's index.
persist_dir = os.path.join(st.session_state.directory_path, "storage")

start_3 = timeit.default_timer()  # timer: index construction
st.write(f"向量库构建开始:{start_3}")
if "new_index" not in st.session_state:
    if not st.session_state.documents:
        # Nothing to index (upload missing or document loading failed).
        st.error("没有可用的文档,无法构建向量库/No documents loaded; cannot build the index.")
        st.stop()
    st.session_state.new_index = VectorStoreIndex.from_documents(
        st.session_state.documents,
        service_context=st.session_state.service_context,
    )
end_3 = timeit.default_timer()
st.write(f"向量库构建结束:{end_3}")
st.write(f"向量库构建耗时:{end_3 - start_3}")

if "storage_context" not in st.session_state:
    # Write the index to disk once; the index is never modified afterwards,
    # so re-persisting on every Streamlit rerun would be wasted disk I/O.
    st.session_state.new_index.storage_context.persist(persist_dir)
    st.session_state.storage_context = StorageContext.from_defaults(persist_dir=persist_dir)

start_4 = timeit.default_timer()  # timer: index reload from storage
st.write(f"向量库装载开始:{start_4}")
if "loadedindex" not in st.session_state:
    # NOTE(review): this reloads the index that was just persisted from
    # new_index; querying new_index directly would skip the round trip.
    st.session_state.loadedindex = load_index_from_storage(storage_context=st.session_state.storage_context, service_context=st.session_state.service_context)
end_4 = timeit.default_timer()
st.write(f"向量库装载结束:{end_4}")
st.write(f"向量库装载耗时:{end_4 - start_4}")

if "query_engine" not in st.session_state:
    st.session_state.query_engine = st.session_state.loadedindex.as_query_engine()
    
# Render the query box on every rerun: a Streamlit widget must be re-created
# each run to stay on the page and to report the user's current input.
st.session_state.user_question = st.text_input("Enter your query:")
question = st.session_state.user_question

# Query only for non-blank input; strip() rejects "" and whitespace-only text.
if question.strip():
    print("user question: " + question)
    with st.spinner("AI Thinking...Please wait a while to Cheers!"):
        start_5 = timeit.default_timer()  # timer: query
        st.write(f"Query Engine - AI QA开始:{start_5}")
        initial_response = st.session_state.query_engine.query(question)
        temp_ai_response = str(initial_response)
        # Keep only the text before the first "<|end|>" marker
        # (presumably the model's end-of-turn token).
        final_ai_response = temp_ai_response.partition('<|end|>')[0]
        print("AI Response:\n" + final_ai_response)
        st.write("AI Response:\n\n" + final_ai_response)