xl2533 committed on
Commit
988418a
·
1 Parent(s): 48121bc
app.py CHANGED
@@ -3,30 +3,33 @@ import os
3
  import json
4
  import requests
5
  from langchain import FAISS
6
- from langchain.prompts import PromptTemplate
7
- from langchain.embeddings import CohereEmbeddings, HuggingFaceInstructEmbeddings
 
 
 
 
 
 
 
8
 
9
  # Streaming endpoint
10
  API_URL = "https://api.openai.com/v1/chat/completions" # os.getenv("API_URL") + "/generate_stream"
11
- embedding_key = '5IRbILAbjTI0VcqTsktBfKsr13Lych9iBAFbLpkj'
12
- faiss_store = './output/数字经济'
13
-
14
-
15
- # load prompt template
16
- with open("prompts/combine_prompt.txt", "r") as f:
17
- template = f.read()
18
-
19
- with open("prompts/combine_prompt_hist.txt", "r") as f:
20
- template_hist = f.read()
21
-
22
- with open("prompts/question_prompt.txt", "r") as f:
23
- template_quest = f.read()
24
-
25
- with open("prompts/chat_combine_prompt.txt", "r") as f:
26
- chat_combine_template = f.read()
27
-
28
- with open("prompts/chat_reduce_prompt.txt", "r") as f:
29
- chat_reduce_template = f.read()
30
 
31
 
32
  def predict(inputs, top_p, temperature, openai_api_key, enable_index,
@@ -37,71 +40,67 @@ def predict(inputs, top_p, temperature, openai_api_key, enable_index,
37
  }
38
 
39
  print(f"chat_counter - {chat_counter}")
40
- # 如果有历史对话,把对话拼接进入上下文
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  if chat_counter == 0:
42
- payload = {
43
- "model": "gpt-3.5-turbo",
44
- "messages": [{"role": "user", "content": f"{inputs}"}],
45
- "temperature": 1.0,
46
- "top_p": 1.0,
47
- "n": 1,
48
- "stream": True,
49
- "presence_penalty": 0,
50
- "frequency_penalty": 0,
51
- }
52
- else:
53
- messages = []
54
- if enable_index:
55
- pass
56
- # history = json.loads(history)
57
- # template_temp = template_hist.replace("{historyquestion}", history[0]).replace("{historyanswer}",
58
- # history[1])
59
- # c_prompt = PromptTemplate(input_variables=["summaries", "question"], template=template_temp,
60
- # template_format="jinja2")
61
- else:
62
- for data in chatbot:
63
- temp1 = {}
64
- temp1["role"] = "user"
65
- temp1["content"] = data[0]
66
- temp2 = {}
67
- temp2["role"] = "assistant"
68
- temp2["content"] = data[1]
69
- messages.append(temp1)
70
- messages.append(temp2)
71
 
72
- # Faiss 检索最近的embedding
73
- if enable_index:
74
- docsearch = FAISS.load_local(faiss_store, CohereEmbeddings(cohere_api_key=embedding_key))
75
- else:
76
- temp3 = {}
77
- temp3["role"] = "user"
78
- temp3["content"] = inputs
79
- messages.append(temp3)
80
-
81
- # messages
82
- payload = {
83
- "model": "gpt-3.5-turbo",
84
- "messages": messages, # [{"role": "user", "content": f"{inputs}"}],
85
- "temperature": temperature, # 1.0,
86
- "top_p": top_p, # 1.0,
87
- "n": 1,
88
- "stream": True,
89
- "presence_penalty": 0,
90
- "frequency_penalty": 0,
91
- }
92
 
93
  chat_counter += 1
94
 
95
- # list of user input
96
  history.append(inputs)
97
  print(f"payload is - {payload}")
 
98
  print(f'chatbot - {chatbot}')
99
- print(f'chatbot - {chatbot}')
100
-
101
  response = requests.post(API_URL, headers=headers, json=payload, stream=True)
102
  token_counter = 0
103
  partial_words = ""
104
 
 
105
  counter = 0
106
  for chunk in response.iter_lines():
107
  if counter == 0:
@@ -111,9 +110,11 @@ def predict(inputs, top_p, temperature, openai_api_key, enable_index,
111
  # check whether each line is non-empty
112
  if chunk:
113
  # decode each line as response data is in bytes
114
- if len(json.loads(chunk.decode()[6:])['choices'][0]["delta"]) == 0:
 
115
  break
116
- partial_words = partial_words + json.loads(chunk.decode()[6:])['choices'][0]["delta"]["content"]
 
117
  if token_counter == 0:
118
  history.append(" " + partial_words)
119
  else:
@@ -136,7 +137,7 @@ with gr.Blocks(css="""#col_container {width: 1000px; margin-left: auto; margin-r
136
  inputs = gr.Textbox(placeholder="您有什么问题可以问我", label="输入数字经济,两会,硅谷银行相关的提问")
137
  state = gr.State([])
138
 
139
- clear = gr.Button("Clear")
140
  run = gr.Button("Run")
141
 
142
  # inputs, top_p, temperature, top_k, repetition_penalty
@@ -147,15 +148,17 @@ with gr.Blocks(css="""#col_container {width: 1000px; margin-left: auto; margin-r
147
  label="Temperature", )
148
  # top_k = gr.Slider( minimum=1, maximum=50, value=4, step=1, interactive=True, label="Top-k",)
149
  # repetition_penalty = gr.Slider( minimum=0.1, maximum=3.0, value=1.03, step=0.01, interactive=True, label="Repetition Penalty", )
150
- chat_counter = gr.Number(value=0, visible=False, precision=0)
151
  enable_index = gr.Checkbox(label='是', info='是否使用研报等金融数据')
 
 
152
 
153
  inputs.submit(predict, [inputs, top_p, temperature, openai_api_key, enable_index, chat_counter, chatbot, state],
154
  [chatbot, state, chat_counter], )
155
  run.click(predict, [inputs, top_p, temperature, openai_api_key, enable_index, chat_counter, chatbot, state],
156
  [chatbot, state, chat_counter], )
157
 
158
- # 每次对话结束都重置对话框
159
  clear.click(reset_textbox, [], [inputs], queue=False)
160
  inputs.submit(reset_textbox, [], [inputs])
161
 
 
3
  import json
4
  import requests
5
  from langchain import FAISS
6
+ from langchain.embeddings import CohereEmbeddings
7
+ from langchain import VectorDBQA
8
+ from langchain.chat_models import ChatOpenAI
9
+ from prompts import MyTemplate
10
+ from langchain.prompts.chat import (
11
+ ChatPromptTemplate,
12
+ SystemMessagePromptTemplate,
13
+ HumanMessagePromptTemplate,
14
+ )
15
 
16
  # Streaming endpoint
17
  API_URL = "https://api.openai.com/v1/chat/completions" # os.getenv("API_URL") + "/generate_stream"
18
+ embeddings_key = '5IRbILAbjTI0VcqTsktBfKsr13Lych9iBAFbLpkj'
19
+ faiss_store = './indexer'
20
+
21
+ def gen_conversation(conversations):
22
+ messages = []
23
+ for data in conversations:
24
+ temp1 = {}
25
+ temp1["role"] = "user"
26
+ temp1["content"] = data[0]
27
+ temp2 = {}
28
+ temp2["role"] = "assistant"
29
+ temp2["content"] = data[1]
30
+ messages.append(temp1)
31
+ messages.append(temp2)
32
+ return messages
 
 
 
 
33
 
34
 
35
  def predict(inputs, top_p, temperature, openai_api_key, enable_index,
 
40
  }
41
 
42
  print(f"chat_counter - {chat_counter}")
43
+ #Debugging
44
+ if enable_index:
45
+ # Faiss 检索最近的embedding
46
+ docsearch = FAISS.load_local(faiss_store, CohereEmbeddings(cohere_api_key=embeddings_key))
47
+ llm = ChatOpenAI(openai_api_key=openai_api_key)
48
+ messages_combine = [
49
+ SystemMessagePromptTemplate.from_template(MyTemplate['chat_combine_template']),
50
+ HumanMessagePromptTemplate.from_template("{question}")
51
+ ]
52
+ p_chat_combine = ChatPromptTemplate.from_messages(messages_combine)
53
+ messages_reduce = [
54
+ SystemMessagePromptTemplate.from_template(MyTemplate['chat_reduce_template']),
55
+ HumanMessagePromptTemplate.from_template("{question}")
56
+ ]
57
+ p_chat_reduce = ChatPromptTemplate.from_messages(messages_reduce)
58
+ chain = VectorDBQA.from_chain_type(llm=llm, chain_type="map_reduce", vectorstore=docsearch,
59
+ k=4,
60
+ chain_type_kwargs={"question_prompt": p_chat_reduce,
61
+ "combine_prompt": p_chat_combine}
62
+ )
63
+ result = chain({"query": inputs})
64
+ print(result)
65
+
66
  if chat_counter == 0:
67
+ messages = [{"role": "user", "content": f"{inputs}"}]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
+ else:
70
+ # 如果有历史对话,把对话拼接进入上下文
71
+ messages = gen_conversation(chatbot)
72
+
73
+ temp3 = {}
74
+ temp3["role"] = "user"
75
+ temp3["content"] = inputs
76
+ messages.append(temp3)
77
+
78
+ # messages
79
+ payload = {
80
+ "model": "gpt-3.5-turbo",
81
+ "messages": messages, # [{"role": "user", "content": f"{inputs}"}],
82
+ "temperature": temperature, # 1.0,
83
+ "top_p": top_p, # 1.0,
84
+ "n": 1,
85
+ "stream": True,
86
+ "presence_penalty": 0,
87
+ "frequency_penalty": 0,
88
+ }
89
 
90
  chat_counter += 1
91
 
92
+ # History: Original Input and Output
93
  history.append(inputs)
94
  print(f"payload is - {payload}")
95
+ #上一轮回复的[[user, AI]]
96
  print(f'chatbot - {chatbot}')
97
+ print(f'Histroy - {history}')
98
+ # 请求OpenAI
99
  response = requests.post(API_URL, headers=headers, json=payload, stream=True)
100
  token_counter = 0
101
  partial_words = ""
102
 
103
+ # 逐字返回
104
  counter = 0
105
  for chunk in response.iter_lines():
106
  if counter == 0:
 
110
  # check whether each line is non-empty
111
  if chunk:
112
  # decode each line as response data is in bytes
113
+ delta = json.loads(chunk.decode()[6:])['choices'][0]["delta"]
114
+ if len(delta) == 0:
115
  break
116
+ partial_words += delta["content"]
117
+ # Keep Updating history
118
  if token_counter == 0:
119
  history.append(" " + partial_words)
120
  else:
 
137
  inputs = gr.Textbox(placeholder="您有什么问题可以问我", label="输入数字经济,两会,硅谷银行相关的提问")
138
  state = gr.State([])
139
 
140
+ clear = gr.Button("Clear Conversation")
141
  run = gr.Button("Run")
142
 
143
  # inputs, top_p, temperature, top_k, repetition_penalty
 
148
  label="Temperature", )
149
  # top_k = gr.Slider( minimum=1, maximum=50, value=4, step=1, interactive=True, label="Top-k",)
150
  # repetition_penalty = gr.Slider( minimum=0.1, maximum=3.0, value=1.03, step=0.01, interactive=True, label="Repetition Penalty", )
151
+ chat_counter = gr.Number(value=0, precision=0)
152
  enable_index = gr.Checkbox(label='是', info='是否使用研报等金融数据')
153
+ # 后续考虑加入搜索结果
154
+ enable_search = gr.Checkbox(label='是', info='是否使用搜索结果')
155
 
156
  inputs.submit(predict, [inputs, top_p, temperature, openai_api_key, enable_index, chat_counter, chatbot, state],
157
  [chatbot, state, chat_counter], )
158
  run.click(predict, [inputs, top_p, temperature, openai_api_key, enable_index, chat_counter, chatbot, state],
159
  [chatbot, state, chat_counter], )
160
 
161
+ # 每次对话结束都重置对话
162
  clear.click(reset_textbox, [], [inputs], queue=False)
163
  inputs.submit(reset_textbox, [], [inputs])
164
 
indexer/index.faiss ADDED
Binary file (32.8 kB). View file
 
indexer/index.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0e6006bd06e30b017ce77e5859f1c7b0abcad6c69ac81c9067e1adeb448ac273
3
+ size 17619
prompts/__init__.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*-coding:utf-8 -*-
2
+
3
+ # load prompt template
4
+ with open("prompts/combine_prompt.txt", "r") as f:
5
+ template = f.read()
6
+
7
+ with open("prompts/combine_prompt_hist.txt", "r") as f:
8
+ template_hist = f.read()
9
+
10
+ with open("prompts/chat_combine_prompt.txt", "r") as f:
11
+ chat_combine_template = f.read()
12
+
13
+ with open("prompts/chat_reduce_prompt.txt", "r") as f:
14
+ chat_reduce_template = f.read()
15
+
16
+
17
+ MyTemplate ={
18
+ 'chat_reduce_template': chat_reduce_template,
19
+ 'chat_combine_template': chat_combine_template,
20
+ 'template_hist': template_hist,
21
+ 'template':template
22
+ }
prompts/chat_reduce_prompt.txt CHANGED
@@ -1,3 +1,3 @@
1
  Use the following portion of a long document to see if any of the text is relevant to answer the question.
2
  {context}
3
- Provide all relevant text to the question verbatim. Summarize if needed. If nothing relevant return "-".
 
1
  Use the following portion of a long document to see if any of the text is relevant to answer the question.
2
  {context}
3
+ Provide all relevant text to the question verbatim. Summarize if needed, Answer in Chinese. If nothing relevant return "-".