binqiangliu committed on
Commit
794580c
·
1 Parent(s): 932638a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -48
app.py CHANGED
@@ -41,8 +41,11 @@ def generate_random_string(length):
41
  #if "pdf_files" not in st.session_state:
42
  #st.session_state.pdf_files = None
43
 
44
- if "documents" not in st.session_state:
45
- st.session_state.documents = None
 
 
 
46
 
47
  with st.sidebar:
48
  st.subheader("Upload your Documents Here: ")
@@ -64,57 +67,45 @@ with st.sidebar:
64
  try:
65
  start_1 = timeit.default_timer() # Start timer
66
  st.write(f"QA文档加载开始:{start_1}")
67
- st.session_state.documents = SimpleDirectoryReader(uploadedfile_path).load_data()
 
68
  end_1 = timeit.default_timer() # Start timer
69
  st.write(f"QA文档加载结束:{end_1}")
70
  st.write(f"QA文档加载耗时:{end_1 - start_1}")
71
  except Exception as e:
72
  print("文档加载出现问题/Waiting for path creation.")
73
-
74
- start_2 = timeit.default_timer() # Start timer
75
- st.write(f"向量模型加载开始:{start_2}")
76
- if "embed_model" not in st.session_state:
77
- st.session_state.embed_model = LangchainEmbedding(HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-L6-v2'))
78
- end_2 = timeit.default_timer() # Start timer
79
- st.write(f"向量模型加载加载结束:{end_2}")
80
- st.write(f"向量模型加载耗时:{end_2 - start_2}")
81
-
82
- if "llm_predictor" not in st.session_state:
83
- st.session_state.llm_predictor = LLMPredictor(HuggingFaceHub(repo_id="HuggingFaceH4/starchat-beta", model_kwargs={"min_length":100, "max_new_tokens":1024, "do_sample":True, "temperature":0.2,"top_k":50, "top_p":0.95, "eos_token_id":49155}))
84
-
85
- if "service_context" not in st.session_state:
86
- st.session_state.service_context = ServiceContext.from_defaults(llm_predictor=st.session_state.llm_predictor, embed_model=st.session_state.embed_model)
87
-
88
- start_3 = timeit.default_timer() # Start timer
89
- st.write(f"向量库构建开始:{start_3}")
90
- if "new_index" not in st.session_state:
91
- st.session_state.new_index = VectorStoreIndex.from_documents(
92
- st.session_state.documents,
93
- service_context=st.session_state.service_context,
94
- )
95
- end_3 = timeit.default_timer() # Start timer
96
- st.write(f"向量库构建结束:{end_3}")
97
- st.write(f"向量库构建耗时:{end_3 - start_3}")
98
-
99
- if "directory_path" not in st.session_state:
100
- st.session_state.directory_path = generate_random_string(20)
101
- os.makedirs(st.session_state.directory_path)
102
-
103
- st.session_state.new_index.storage_context.persist("st.session_state.directory_path")
104
-
105
- if "storage_context" not in st.session_state:
106
- st.session_state.storage_context = StorageContext.from_defaults(persist_dir="st.session_state.directory_path")
107
-
108
- start_4 = timeit.default_timer() # Start timer
109
- st.write(f"向量库装载开始:{start_4}")
110
- if "loadedindex" not in st.session_state:
111
- st.session_state.loadedindex = load_index_from_storage(storage_context=st.session_state.storage_context, service_context=st.session_state.service_context)
112
- end_4 = timeit.default_timer() # Start timer
113
- st.write(f"向量库装载结束:{end_4}")
114
- st.write(f"向量库装载耗时:{end_4 - start_4}")
115
-
116
- if "query_engine" not in st.session_state:
117
- st.session_state.query_engine = st.session_state.loadedindex.as_query_engine()
118
 
119
  user_question = st.text_input("Enter your query:")
120
  if user_question !="" and not user_question.strip().isspace() and not user_question == "" and not user_question.strip() == "" and not user_question.isspace():
 
41
  #if "pdf_files" not in st.session_state:
42
  #st.session_state.pdf_files = None
43
 
44
+ #if "documents" not in st.session_state:
45
+ # st.session_state.documents = None
46
+
47
+ if "query_engine" not in st.session_state:
48
+ st.session_state.query_engine = None
49
 
50
  with st.sidebar:
51
  st.subheader("Upload your Documents Here: ")
 
67
  try:
68
  start_1 = timeit.default_timer() # Start timer
69
  st.write(f"QA文档加载开始:{start_1}")
70
+ #st.session_state.documents = SimpleDirectoryReader(uploadedfile_path).load_data()
71
+ documents = SimpleDirectoryReader(uploadedfile_path).load_data()
72
  end_1 = timeit.default_timer() # Start timer
73
  st.write(f"QA文档加载结束:{end_1}")
74
  st.write(f"QA文档加载耗时:{end_1 - start_1}")
75
  except Exception as e:
76
  print("文档加载出现问题/Waiting for path creation.")
77
+ start_2 = timeit.default_timer() # Start timer
78
+ st.write(f"向量模型加载开始:{start_2}")
79
+ if "embed_model" not in st.session_state:
80
+ st.session_state.embed_model = LangchainEmbedding(HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-L6-v2'))
81
+ end_2 = timeit.default_timer() # Start timer
82
+ st.write(f"向量模型加载结束:{end_2}")
83
+ st.write(f"向量模型加载耗时:{end_2 - start_2}")
84
+ if "llm_predictor" not in st.session_state:
85
+ st.session_state.llm_predictor = LLMPredictor(HuggingFaceHub(repo_id="HuggingFaceH4/starchat-beta", model_kwargs={"min_length":100, "max_new_tokens":1024, "do_sample":True, "temperature":0.2,"top_k":50, "top_p":0.95, "eos_token_id":49155}))
86
+ if "service_context" not in st.session_state:
87
+ st.session_state.service_context = ServiceContext.from_defaults(llm_predictor=st.session_state.llm_predictor, embed_model=st.session_state.embed_model)
88
+ start_3 = timeit.default_timer() # Start timer
89
+ st.write(f"向量库构建开始:{start_3}")
90
+ if "new_index" not in st.session_state:
91
+ st.session_state.new_index = VectorStoreIndex.from_documents(st.session_state.documents, service_context=st.session_state.service_context)
92
+ end_3 = timeit.default_timer() # Start timer
93
+ st.write(f"向量库构建结束:{end_3}")
94
+ st.write(f"向量库构建耗时:{end_3 - start_3}")
95
+ if "directory_path" not in st.session_state:
96
+ st.session_state.directory_path = generate_random_string(20)
97
+ os.makedirs(st.session_state.directory_path)
98
+ st.session_state.new_index.storage_context.persist("st.session_state.directory_path")
99
+ if "storage_context" not in st.session_state:
100
+ st.session_state.storage_context = StorageContext.from_defaults(persist_dir="st.session_state.directory_path")
101
+ start_4 = timeit.default_timer() # Start timer
102
+ st.write(f"向量库装载开始:{start_4}")
103
+ if "loadedindex" not in st.session_state:
104
+ st.session_state.loadedindex = load_index_from_storage(storage_context=st.session_state.storage_context, service_context=st.session_state.service_context)
105
+ end_4 = timeit.default_timer() # Start timer
106
+ st.write(f"向量库装载结束:{end_4}")
107
+ st.write(f"向量库装载耗时:{end_4 - start_4}")
108
+ st.session_state.query_engine = st.session_state.loadedindex.as_query_engine()
 
 
 
 
 
 
 
 
 
 
 
 
 
109
 
110
  user_question = st.text_input("Enter your query:")
111
  if user_question !="" and not user_question.strip().isspace() and not user_question == "" and not user_question.strip() == "" and not user_question.isspace():