xl2533 committed
Commit 7e69c1f · 1 Parent(s): 8720494

add topic, qa answer

app.py CHANGED
@@ -14,9 +14,10 @@ from langchain.prompts.chat import (
 )
 
 # Streaming endpoint
-API_URL = "https://api.openai.com/v1/chat/completions" # os.getenv("API_URL") + "/generate_stream"
+API_URL = "https://api.openai.com/v1/chat/completions"
 cohere_key = '5IRbILAbjTI0VcqTsktBfKsr13Lych9iBAFbLpkj'
-faiss_store = './indexer'
+faiss_store = './indexer/{}'
+
 
 def gen_conversation(conversations):
     messages = []
@@ -32,22 +33,28 @@ def gen_conversation(conversations):
     return messages
 
 
-def predict(inputs, top_p, temperature, openai_api_key, enable_index, max_tokens, model,
+def predict(inputs, top_p, temperature, openai_api_key, enable_index, max_tokens, model, topic,
             chat_counter, chatbot=[], history=[]):
     model = model[0]
+    topic = topic[0]
     headers = {
         "Content-Type": "application/json",
        "Authorization": f"Bearer {openai_api_key}"
     }
 
     print(f"chat_counter - {chat_counter}")
-    #Debugging
+    print(f'History - {history}')  # History: original inputs and outputs in a flat list
+    print(f'chatbot - {chatbot}')  # Chatbot: previous turns as [[user, AI]]
+
+    history.append(inputs)
+    # Debugging
     if enable_index:
         # FAISS: retrieve the nearest embeddings
-        if model =='openai':
-            docsearch = FAISS.load_local(faiss_store, OpenAIEmbeddings(openai_api_key=openai_api_key))
+        store = faiss_store.format(topic)
+        if model == 'openai':
+            docsearch = FAISS.load_local(store, OpenAIEmbeddings(openai_api_key=openai_api_key))
         else:
-            docsearch = FAISS.load_local(faiss_store, CohereEmbeddings(cohere_api_key=cohere_key))
+            docsearch = FAISS.load_local(store, CohereEmbeddings(cohere_api_key=cohere_key))
         # Build the prompt template
         llm = ChatOpenAI(openai_api_key=openai_api_key, max_tokens=max_tokens)
         messages_combine = [
@@ -67,66 +74,62 @@ def predict(inputs, top_p, temperature, openai_api_key, enable_index, max_tokens
         )
         result = chain({"query": inputs})
         print(result)
-
-    if chat_counter == 0:
-        messages = [{"role": "user", "content": f"{inputs}"}]
-
+        result = result['result']
+        # Build the returned reply
+        history.append(result)
+        chat = [(history[i], history[i + 1]) for i in range(0, len(history) - 1, 2)]
+        chat_counter += 1
+        yield chat, history, chat_counter
     else:
-        # If there is prior conversation, stitch it into the context
-        messages = gen_conversation(chatbot)
-
-    temp3 = {}
-    temp3["role"] = "user"
-    temp3["content"] = inputs
-    messages.append(temp3)
-
-    # messages
-    payload = {
-        "model": "gpt-3.5-turbo",
-        "messages": messages,  # [{"role": "user", "content": f"{inputs}"}],
-        "temperature": temperature,  # 1.0,
-        "top_p": top_p,  # 1.0,
-        "n": 1,
-        "stream": True,
-        "presence_penalty": 0,
-        "frequency_penalty": 0,
-    }
-
-    chat_counter += 1
-
-    # History: Original Input and Output
-    history.append(inputs)
-    print(f"payload is - {payload}")
-    # Previous turns as [[user, AI]]
-    print(f'chatbot - {chatbot}')
-    print(f'History - {history}')
-    # Call the OpenAI API
-    response = requests.post(API_URL, headers=headers, json=payload, stream=True)
-    token_counter = 0
-    partial_words = ""
-
-    # Stream the reply back token by token
-    counter = 0
-    for chunk in response.iter_lines():
-        if counter == 0:
+        if chat_counter == 0:
+            messages = [{"role": "user", "content": f"{inputs}"}]
+        else:
+            # If there is prior conversation, stitch it into the context
+            messages = gen_conversation(chatbot)
+        messages.append({'role': 'user', 'content': inputs})
+        # messages
+        payload = {
+            "model": "gpt-3.5-turbo",
+            "messages": messages,  # [{"role": "user", "content": f"{inputs}"}],
+            "temperature": temperature,  # 1.0,
+            "top_p": top_p,  # 1.0,
+            "n": 1,
+            "stream": True,
+            "presence_penalty": 0,
+            "frequency_penalty": 0,
+        }
+        print(f"payload is - {payload}")
+
+        chat_counter += 1
+
+        # Call the OpenAI API
+        response = requests.post(API_URL, headers=headers, json=payload, stream=True)
+        token_counter = 0
+        partial_words = ""
+
+        # Stream the reply back token by token
+        counter = 0
+        for chunk in response.iter_lines():
+            if counter == 0:
+                counter += 1
+                continue
             counter += 1
-            continue
-        counter += 1
-        # check whether each line is non-empty
-        if chunk:
-            # decode each line as response data is in bytes
-            delta = json.loads(chunk.decode()[6:])['choices'][0]["delta"]
-            if len(delta) == 0:
-                break
-            partial_words += delta["content"]
-            # Keep updating history
-            if token_counter == 0:
-                history.append(" " + partial_words)
-            else:
-                history[-1] = partial_words
-            chat = [(history[i], history[i + 1]) for i in range(0, len(history) - 1, 2)]  # convert to tuples of list
-            token_counter += 1
-            yield chat, history, chat_counter
+            # check whether each line is non-empty
+            if chunk:
+                # decode each line as response data is in bytes
+                delta = json.loads(chunk.decode()[6:])['choices'][0]["delta"]
+                if len(delta) == 0:
+                    break
+                partial_words += delta["content"]
+                # Keep updating history
+                if token_counter == 0:
+                    history.append(" " + partial_words)
+                else:
+                    history[-1] = partial_words
+                chat = [(history[i], history[i + 1]) for i in
+                        range(0, len(history) - 1, 2)]  # convert to tuples of list
+                token_counter += 1
+                yield chat, history, chat_counter
 
 
 def reset_textbox():
@@ -138,30 +141,40 @@ with gr.Blocks(css="""#col_container {width: 1000px; margin-left: auto; margin-r
     gr.HTML("""<h1 align="center">🚀Finance ChatBot🚀</h1>""")
     with gr.Column(elem_id="col_container"):
         openai_api_key = gr.Textbox(type='password', label="输入OPEN API Key")
+
+        # inputs, top_p, temperature, top_k, repetition_penalty
+        with gr.Accordion("Parameters", open=True):
+            with gr.Row():
+                top_p = gr.Slider(minimum=-0, maximum=1.0, value=0.9, step=0.05, interactive=True,
+                                  label="Top-p (nucleus sampling)", )
+
+                temperature = gr.Slider(minimum=-0, maximum=5.0, value=0.8, step=0.1, interactive=True,
+                                        label="Temperature", )
+            with gr.Row():
+                model = gr.CheckboxGroup(["cohere", "openai", "mpnet"])
+                max_tokens = gr.Slider(minimum=100, maximum=2000, value=200, step=100, interactive=True,
+                                       label="Max Tokens", )
+                chat_counter = gr.Number(value=0, precision=0, label='对话轮数')
+            with gr.Row():
+                enable_index = gr.Checkbox(label='是', info='开启基于文档问答模式/关闭为聊天模式')
+                enable_search = gr.Checkbox(label='是', info='是否使用搜索结果')
+                topic = gr.CheckboxGroup(["两会", "数字经济", "硅谷银行"])
+
         chatbot = gr.Chatbot(elem_id='chatbot')
         inputs = gr.Textbox(placeholder="您有什么问题可以问我", label="输入数字经济,两会,硅谷银行相关的提问")
         state = gr.State([])
 
-        clear = gr.Button("Clear Conversation")
-        run = gr.Button("Run")
+        with gr.Row():
+            clear = gr.Button("Clear Conversation")
+            run = gr.Button("Run")
 
-        # inputs, top_p, temperature, top_k, repetition_penalty
-        with gr.Accordion("Parameters", open=True):
-            top_p = gr.Slider(minimum=-0, maximum=1.0, value=1.0, step=0.05, interactive=True,
-                              label="Top-p (nucleus sampling)", )
-            max_tokens = gr.Slider(minimum=512, maximum=3000, value=3000, step=100, interactive=True,
-                                   label="Max Tokens", )
-            temperature = gr.Slider(minimum=-0, maximum=5.0, value=1.0, step=0.1, interactive=True,
-                                    label="Temperature", )
-            model = gr.CheckboxGroup(["cohere", "openai", "mpnet"])
-            chat_counter = gr.Number(value=0, precision=0)
-            enable_index = gr.Checkbox(label='是', info='是否使用研报等金融数据')
-            # Consider adding web search results later
-            enable_search = gr.Checkbox(label='是', info='是否使用搜索结果')
-
-    inputs.submit(predict, [inputs, top_p, temperature, openai_api_key, enable_index, max_tokens, model, chat_counter, chatbot, state],
+    inputs.submit(predict,
+                  [inputs, top_p, temperature, openai_api_key, enable_index, max_tokens, model, topic, chat_counter, chatbot,
+                   state],
                  [chatbot, state, chat_counter], )
-    run.click(predict, [inputs, top_p, temperature, openai_api_key, enable_index, max_tokens, model, chat_counter, chatbot, state],
+    run.click(predict,
+              [inputs, top_p, temperature, openai_api_key, enable_index, max_tokens, model, topic, chat_counter, chatbot,
+               state],
              [chatbot, state, chat_counter], )
 
     # Reset the input after every conversation turn
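
Editor's note on the new QA path: when the document-QA mode (enable_index) is on, predict now loads a per-topic FAISS index from ./indexer/<topic>, embeds the query with either OpenAI or Cohere, runs the retrieval chain, and yields the answer directly instead of falling through to the streaming chat call. The sketch below is a minimal, self-contained version of that flow, not the code in this commit: the chain construction (messages_combine and chain) sits outside the hunks shown here, so RetrievalQA.from_chain_type and the helper name answer_from_index are illustrative assumptions.

```python
# Minimal sketch of the per-topic retrieval path (assumes the langchain 0.0.x
# APIs that app.py already imports). RetrievalQA.from_chain_type stands in for
# the unshown messages_combine/chain construction.
from langchain.embeddings import OpenAIEmbeddings, CohereEmbeddings
from langchain.vectorstores import FAISS
from langchain.chat_models import ChatOpenAI
from langchain.chains import RetrievalQA

faiss_store = './indexer/{}'  # one FAISS index directory per topic


def answer_from_index(query, topic, model, openai_api_key, cohere_key, max_tokens=200):
    # Pick the topic-specific index, e.g. './indexer/数字经济'
    store = faiss_store.format(topic)
    if model == 'openai':
        embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)
    else:
        embeddings = CohereEmbeddings(cohere_api_key=cohere_key)
    docsearch = FAISS.load_local(store, embeddings)

    # Answer the query from the retrieved documents
    llm = ChatOpenAI(openai_api_key=openai_api_key, max_tokens=max_tokens)
    chain = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff",
                                        retriever=docsearch.as_retriever())
    result = chain({"query": query})
    return result['result']  # matches the result['result'] access in predict
```

The embedding class must match the one used to build the index on disk, which is why the model checkbox ("openai" vs "cohere") selects the embeddings as well as the index loader.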
prompts/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (545 Bytes).
 
prompts/chat_combine_prompt.txt CHANGED
@@ -1,4 +1,4 @@
-You are a DocsGPT, friendly and helpful AI assistant by Arc53 that provides help with documents. You give thorough answers with code examples if possible.
-Use the following pieces of context to help answer the users question.
+You are DocsGPT, a friendly and helpful AI assistant by TianHong Asset Management that provides help with documents and financial news. You give thorough answers with detailed numbers and illustrated examples if possible.
+Use the following pieces of context to help answer the user's question, and always answer in Chinese.
 ----------------
 {summaries}
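
The rewritten system prompt keeps DocsGPT's {summaries} placeholder, which the combine step fills with the retrieved document chunks. Below is a hypothetical sketch of how a prompt file like this can be turned into the chat prompt that app.py's messages_combine refers to; the actual wiring is not part of this commit's hunks, and the {question} human message is an assumption.

```python
# Hypothetical assembly of the combine prompt (assumed wiring, not shown in this diff).
from langchain.prompts.chat import (
    ChatPromptTemplate,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
)

with open('prompts/chat_combine_prompt.txt', encoding='utf-8') as f:
    combine_template = f.read()  # contains the {summaries} placeholder

messages_combine = [
    SystemMessagePromptTemplate.from_template(combine_template),
    HumanMessagePromptTemplate.from_template("{question}"),  # assumed variable name
]
chat_combine_prompt = ChatPromptTemplate.from_messages(messages_combine)

# chat_combine_prompt.format_messages(summaries=..., question=...) fills the
# retrieved context and the user question before the LLM call.
```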