binqiangliu commited on
Commit
a59a206
·
1 Parent(s): 05f0071

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -23
app.py CHANGED
@@ -32,16 +32,52 @@ documents=[]
32
  def generate_random_string(length):
33
  letters = string.ascii_lowercase
34
  return ''.join(random.choice(letters) for i in range(length))
35
- random_string = generate_random_string(20)
36
- directory_path=random_string
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
  with st.sidebar:
39
  st.subheader("Upload your Documents Here: ")
40
- pdf_files = st.file_uploader("Choose your PDF Files and Press OK", type=['pdf'], accept_multiple_files=True)
41
- if pdf_files:
42
- os.makedirs(directory_path)
43
- for pdf_file in pdf_files:
44
- file_path = os.path.join(directory_path, pdf_file.name)
 
 
45
  with open(file_path, 'wb') as f:
46
  f.write(pdf_file.read())
47
  st.success(f"File '{pdf_file.name}' saved successfully.")
@@ -49,7 +85,7 @@ with st.sidebar:
49
  try:
50
  start_1 = timeit.default_timer() # Start timer
51
  st.write(f"QA文档加载开始:{start_1}")
52
- documents = SimpleDirectoryReader(directory_path).load_data()
53
  end_1 = timeit.default_timer() # Start timer
54
  st.write(f"QA文档加载结束:{end_1}")
55
  st.write(f"QA文档加载耗时:{end_1 - start_1}")
@@ -61,45 +97,45 @@ except Exception as e:
61
 
62
  start_2 = timeit.default_timer() # Start timer
63
  st.write(f"向量模型加载开始:{start_2}")
64
- embed_model = LangchainEmbedding(HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-L6-v2'))
65
  end_2 = timeit.default_timer() # Start timer
66
  st.write(f"向量模型加载加载结束:{end_2}")
67
  st.write(f"向量模型加载耗时:{end_2 - start_2}")
68
 
69
- llm_predictor = LLMPredictor(HuggingFaceHub(repo_id="HuggingFaceH4/starchat-beta", model_kwargs={"min_length":100, "max_new_tokens":1024, "do_sample":True, "temperature":0.2,"top_k":50, "top_p":0.95, "eos_token_id":49155}))
70
 
71
- service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor, embed_model=embed_model)
72
 
73
  start_3 = timeit.default_timer() # Start timer
74
  st.write(f"向量库构建开始:{start_3}")
75
- new_index = VectorStoreIndex.from_documents(
76
- documents,
77
- service_context=service_context,
78
  )
79
  end_3 = timeit.default_timer() # Start timer
80
  st.write(f"向量库构建结束:{end_3}")
81
  st.write(f"向量库构建耗时:{end_3 - start_3}")
82
 
83
- new_index.storage_context.persist("directory_path")
84
 
85
- storage_context = StorageContext.from_defaults(persist_dir="directory_path")
86
 
87
  start_4 = timeit.default_timer() # Start timer
88
  st.write(f"向量库装载开始:{start_4}")
89
- loadedindex = load_index_from_storage(storage_context=storage_context, service_context=service_context)
90
  end_4 = timeit.default_timer() # Start timer
91
  st.write(f"向量库装载结束:{end_4}")
92
  st.write(f"向量库装载耗时:{end_4 - start_4}")
93
 
94
- query_engine = loadedindex.as_query_engine()
95
-
96
- user_question = st.text_input("Enter your query here:")
97
- if user_question !="" and not user_question.strip().isspace() and not user_question == "" and not user_question.strip() == "" and not user_question.isspace():
98
- print("user question: "+user_question)
99
  with st.spinner("AI Thinking...Please wait a while to Cheers!"):
100
  start_5 = timeit.default_timer() # Start timer
101
  st.write(f"Query Engine - AI QA开始:{start_5}")
102
- initial_response = query_engine.query(user_question)
103
  temp_ai_response=str(initial_response)
104
  final_ai_response=temp_ai_response.partition('<|end|>')[0]
105
  print("AI Response:\n"+final_ai_response)
 
32
  def generate_random_string(length):
33
  letters = string.ascii_lowercase
34
  return ''.join(random.choice(letters) for i in range(length))
35
+
36
+ #random_string = generate_random_string(20)
37
+ #directory_path=random_string
38
+
39
+ if "directory_path" not in st.session_state:
40
+ st.session_state.directory_path = generate_random_string(20)
41
+
42
+ if "pdf_files" not in st.session_state:
43
+ st.session_state.pdf_files = None
44
+
45
+ if "documents" not in st.session_state:
46
+ st.session_state.documents = None
47
+
48
+ if "embed_model" not in st.session_state:
49
+ st.session_state.embed_model = None
50
+
51
+ if "llm_predictor" not in st.session_state:
52
+ st.session_state.llm_predictor = None
53
+
54
+ if "service_context" not in st.session_state:
55
+ st.session_state.service_context = None
56
+
57
+ if "new_index" not in st.session_state:
58
+ st.session_state.new_index = None
59
+
60
+ if "storage_context" not in st.session_state:
61
+ st.session_state.storage_context = None
62
+
63
+ if "loadedindex" not in st.session_state:
64
+ st.session_state.loadedindex = None
65
+
66
+ if "query_engine" not in st.session_state:
67
+ st.session_state.query_engine = None
68
+
69
+ if "user_question " not in st.session_state:
70
+ st.session_state.user_question = ""
71
 
72
  with st.sidebar:
73
  st.subheader("Upload your Documents Here: ")
74
+ #if "pdf_files" not in st.session_state:
75
+ st.session_state.pdf_files = st.file_uploader("Choose your PDF Files and Press OK", type=['pdf'], accept_multiple_files=True)
76
+ #pdf_files = st.file_uploader("Choose your PDF Files and Press OK", type=['pdf'], accept_multiple_files=True)
77
+ if st.session_state.pdf_files:
78
+ os.makedirs(st.session_state.directory_path)
79
+ for pdf_file in st.session_state.pdf_files:
80
+ file_path = os.path.join(st.session_state.directory_path, pdf_file.name)
81
  with open(file_path, 'wb') as f:
82
  f.write(pdf_file.read())
83
  st.success(f"File '{pdf_file.name}' saved successfully.")
 
85
  try:
86
  start_1 = timeit.default_timer() # Start timer
87
  st.write(f"QA文档加载开始:{start_1}")
88
+ st.session_state.documents = SimpleDirectoryReader(st.session_state.directory_path).load_data()
89
  end_1 = timeit.default_timer() # Start timer
90
  st.write(f"QA文档加载结束:{end_1}")
91
  st.write(f"QA文档加载耗时:{end_1 - start_1}")
 
97
 
98
  start_2 = timeit.default_timer() # Start timer
99
  st.write(f"向量模型加载开始:{start_2}")
100
+ st.session_state.embed_model = LangchainEmbedding(HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-L6-v2'))
101
  end_2 = timeit.default_timer() # Start timer
102
  st.write(f"向量模型加载加载结束:{end_2}")
103
  st.write(f"向量模型加载耗时:{end_2 - start_2}")
104
 
105
+ st.session_state.llm_predictor = LLMPredictor(HuggingFaceHub(repo_id="HuggingFaceH4/starchat-beta", model_kwargs={"min_length":100, "max_new_tokens":1024, "do_sample":True, "temperature":0.2,"top_k":50, "top_p":0.95, "eos_token_id":49155}))
106
 
107
+ st.session_state.service_context = ServiceContext.from_defaults(llm_predictor=st.session_state.llm_predictor, embed_model=st.session_state.embed_model)
108
 
109
  start_3 = timeit.default_timer() # Start timer
110
  st.write(f"向量库构建开始:{start_3}")
111
+ st.session_state.new_index = VectorStoreIndex.from_documents(
112
+ st.session_state.documents,
113
+ service_context=st.session_state.service_context,
114
  )
115
  end_3 = timeit.default_timer() # Start timer
116
  st.write(f"向量库构建结束:{end_3}")
117
  st.write(f"向量库构建耗时:{end_3 - start_3}")
118
 
119
+ st.session_state.new_index.storage_context.persist("st.session_state.directory_path")
120
 
121
+ st.session_state.storage_context = StorageContext.from_defaults(persist_dir="st.session_state.directory_path")
122
 
123
  start_4 = timeit.default_timer() # Start timer
124
  st.write(f"向量库装载开始:{start_4}")
125
+ st.session_state.loadedindex = load_index_from_storage(storage_context=st.session_state.storage_context, service_context=st.session_state.service_context)
126
  end_4 = timeit.default_timer() # Start timer
127
  st.write(f"向量库装载结束:{end_4}")
128
  st.write(f"向量库装载耗时:{end_4 - start_4}")
129
 
130
+ st.session_state.query_engine = st.session_state.loadedindex.as_query_engine()
131
+
132
+ st.session_state.user_question=st.text_input("Enter your query:")
133
+ if st.session_state.user_question !="" and not st.session_state.user_question.strip().isspace() and not st.session_state.user_question == "" and not st.session_state.user_question.strip() == "" and not st.session_state.user_question.isspace():
134
+ print("user question: "+st.session_state.user_question)
135
  with st.spinner("AI Thinking...Please wait a while to Cheers!"):
136
  start_5 = timeit.default_timer() # Start timer
137
  st.write(f"Query Engine - AI QA开始:{start_5}")
138
+ initial_response = st.session_state.query_engine.query(st.session_state.user_question)
139
  temp_ai_response=str(initial_response)
140
  final_ai_response=temp_ai_response.partition('<|end|>')[0]
141
  print("AI Response:\n"+final_ai_response)