xl2533 commited on
Commit
9143731
·
1 Parent(s): b49db7d

global faiss

Browse files
Files changed (2) hide show
  1. app.py +14 -13
  2. requirement.txt +0 -3
app.py CHANGED
@@ -17,6 +17,7 @@ from langchain.prompts.chat import (
17
  API_URL = "https://api.openai.com/v1/chat/completions"
18
  cohere_key = '5IRbILAbjTI0VcqTsktBfKsr13Lych9iBAFbLpkj'
19
  faiss_store = './indexer/{}'
 
20
 
21
 
22
  def gen_conversation(conversations):
@@ -33,9 +34,10 @@ def gen_conversation(conversations):
33
  return messages
34
 
35
 
36
- def predict(inputs, top_p, temperature, openai_api_key, enable_index, max_tokens, model, topic,
37
  chat_counter, chatbot=[], history=[]):
38
- model = model[0]
 
39
  topic = topic[0]
40
  headers = {
41
  "Content-Type": "application/json",
@@ -51,10 +53,11 @@ def predict(inputs, top_p, temperature, openai_api_key, enable_index, max_tokens
51
  if enable_index:
52
  # Faiss 检索最近的embedding
53
  store = faiss_store.format(topic)
54
- if model == 'openai':
 
55
  docsearch = FAISS.load_local(store, OpenAIEmbeddings(openai_api_key=openai_api_key))
56
  else:
57
- docsearch = FAISS.load_local(store, CohereEmbeddings(cohere_api_key=cohere_key))
58
  # 构建模板
59
  llm = ChatOpenAI(openai_api_key=openai_api_key, max_tokens=max_tokens)
60
  messages_combine = [
@@ -147,18 +150,16 @@ with gr.Blocks(css="""#col_container {width: 1000px; margin-left: auto; margin-r
147
  with gr.Row():
148
  top_p = gr.Slider(minimum=-0, maximum=1.0, value=0.9, step=0.05, interactive=True,
149
  label="Top-p (nucleus sampling)", )
150
-
151
  temperature = gr.Slider(minimum=-0, maximum=5.0, value=0.8, step=0.1, interactive=True,
152
  label="Temperature", )
153
- with gr.Row():
154
- model = gr.CheckboxGroup(["cohere", "openai", "mpnet"])
155
- max_tokens = gr.Slider(minimum=100, maximum=2000, value=200, step=100, interactive=True,
156
  label="Max Tokens", )
157
  chat_counter = gr.Number(value=0, precision=0, label='对话轮数')
 
158
  with gr.Row():
159
- enable_index = gr.Checkbox(label='是', info='开启基于文档问答模式/关闭为聊天模式')
160
- enable_search = gr.Checkbox(label='是', info='是否使用搜索结果')
161
- topic = gr.CheckboxGroup(["两会", "数字经济", "硅谷银行"])
162
 
163
  chatbot = gr.Chatbot(elem_id='chatbot')
164
  inputs = gr.Textbox(placeholder="您有什么问题可以问我", label="输入数字经济,两会,硅谷银行相关的提问")
@@ -169,11 +170,11 @@ with gr.Blocks(css="""#col_container {width: 1000px; margin-left: auto; margin-r
169
  run = gr.Button("Run")
170
 
171
  inputs.submit(predict,
172
- [inputs, top_p, temperature, openai_api_key, enable_index, max_tokens, model, topic, chat_counter, chatbot,
173
  state],
174
  [chatbot, state, chat_counter], )
175
  run.click(predict,
176
- [inputs, top_p, temperature, openai_api_key, enable_index, max_tokens, model, topic, chat_counter, chatbot,
177
  state],
178
  [chatbot, state, chat_counter], )
179
 
 
17
  API_URL = "https://api.openai.com/v1/chat/completions"
18
  cohere_key = '5IRbILAbjTI0VcqTsktBfKsr13Lych9iBAFbLpkj'
19
  faiss_store = './indexer/{}'
20
+ docsearch = None
21
 
22
 
23
  def gen_conversation(conversations):
 
34
  return messages
35
 
36
 
37
+ def predict(inputs, top_p, temperature, openai_api_key, enable_index, max_tokens, topic,
38
  chat_counter, chatbot=[], history=[]):
39
+ global docsearch
40
+
41
  topic = topic[0]
42
  headers = {
43
  "Content-Type": "application/json",
 
53
  if enable_index:
54
  # Faiss 检索最近的embedding
55
  store = faiss_store.format(topic)
56
+ if docsearch is None:
57
+ print('Loading FAISS')
58
  docsearch = FAISS.load_local(store, OpenAIEmbeddings(openai_api_key=openai_api_key))
59
  else:
60
+ print('Faiss already loaded')
61
  # 构建模板
62
  llm = ChatOpenAI(openai_api_key=openai_api_key, max_tokens=max_tokens)
63
  messages_combine = [
 
150
  with gr.Row():
151
  top_p = gr.Slider(minimum=-0, maximum=1.0, value=0.9, step=0.05, interactive=True,
152
  label="Top-p (nucleus sampling)", )
 
153
  temperature = gr.Slider(minimum=-0, maximum=5.0, value=0.8, step=0.1, interactive=True,
154
  label="Temperature", )
155
+ max_tokens = gr.Slider(minimum=100, maximum=1000, value=200, step=100, interactive=True,
 
 
156
  label="Max Tokens", )
157
  chat_counter = gr.Number(value=0, precision=0, label='对话轮数')
158
+
159
  with gr.Row():
160
+ enable_index = gr.Checkbox(label='是', info='开启文档问答模式/聊天模式')
161
+ enable_search = gr.Checkbox(label='是', info='是否使用搜索')
162
+ topic = gr.CheckboxGroup(["两会", "数字经济", "硅谷银行"], label='使用文档索引')
163
 
164
  chatbot = gr.Chatbot(elem_id='chatbot')
165
  inputs = gr.Textbox(placeholder="您有什么问题可以问我", label="输入数字经济,两会,硅谷银行相关的提问")
 
170
  run = gr.Button("Run")
171
 
172
  inputs.submit(predict,
173
+ [inputs, top_p, temperature, openai_api_key, enable_index, max_tokens, topic, chat_counter, chatbot,
174
  state],
175
  [chatbot, state, chat_counter], )
176
  run.click(predict,
177
+ [inputs, top_p, temperature, openai_api_key, enable_index, max_tokens, topic, chat_counter, chatbot,
178
  state],
179
  [chatbot, state, chat_counter], )
180
 
requirement.txt DELETED
@@ -1,3 +0,0 @@
1
- openai==0.27.2
2
- gradio==3.21.0
3
- langchain==0.0.113