Update appStore/rag.py

appStore/rag.py (+16 -15)
@@ -10,11 +10,12 @@ from huggingface_hub import InferenceClient
 
 # Get openai API key
 hf_token = os.environ["HF_API_KEY"]
+
 # define a special function for putting the prompt together (as we can't use haystack)
 def get_prompt(context, label):
     base_prompt="Summarize the following context efficiently in bullet points, the less the better - but keep concrete goals. \
     Summarize only elements of the context that address vulnerability of "+label+" to climate change. \
-    If there is no mention of "+label+" in the context, return
+    If there is no mention of "+label+" in the context, return: 'No clear references to vulnerability of "+label+" found'. \
     Do not include an introduction sentence, just the bullet points as per below. \
     Formatting example: \
     - Bullet point 1 \
@@ -32,7 +33,7 @@ def get_prompt(context, label):
 # return openai.ChatCompletion.create(**kwargs)
 
 # construct query, send to HF API and process response
-def run_query(context, label):
+def run_query(context, label, model_sel_name):
     '''
     For non-streamed completion, enable the following 2 lines and comment out the code below
     '''
@@ -40,29 +41,29 @@ def run_query(context, label):
     messages = [{"role": "system", "content": chatbot_role},{"role": "user", "content": get_prompt(context, label)}]
 
     # Initialize the client, pointing it to one of the available models
-    client = InferenceClient(
-
+    client = InferenceClient(model_sel_name, token = hf_token)
+
+    # instantiate ChatCompletion as a generator object (stream is set to True)
     chat_completion = client.chat.completions.create(
         messages=messages,
         stream=True
     )
+    # chat_completion = completion_with_backoff(messages=messages, stream=True)
 
     # iterate through the streamed output
     report = []
     res_box = st.empty()
     for chunk in chat_completion:
         # extract the object containing the text (totally different structure when streaming)
-
-
-
-
-
-
-
-
-
-        if chunk.choices[0].finish_reason != None:
-            break
+        if chunk.choices is not None: # sometimes returns None - probably the prompt needs work
+            chunk_message = chunk.choices[0].delta
+            # test to make sure there is text in the object (some don't have)
+            if 'content' in chunk_message:
+                report.append(chunk_message['content']) # extract the message
+                # add the latest text and merge it with all previous
+                result = "".join(report).strip()
+                # res_box.success(result) # output to response text box
+                res_box.success(result)
 
 
 
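For reference, the streaming pattern this commit adopts reduces to a minimal, self-contained sketch, assuming a recent huggingface_hub release and a valid token in HF_API_KEY; the model id, role text, and prompt below are placeholders, not values from this repository:

import os
from huggingface_hub import InferenceClient

# Placeholder model id; in the app, the user's choice arrives via model_sel_name.
client = InferenceClient("HuggingFaceH4/zephyr-7b-beta", token=os.environ["HF_API_KEY"])

messages = [
    {"role": "system", "content": "You summarize reports in concise bullet points."},
    {"role": "user", "content": "Summarize: ..."},
]

# stream=True makes the call return a generator of incremental chunks.
report = []
for chunk in client.chat.completions.create(messages=messages, stream=True):
    delta = chunk.choices[0].delta
    # A chunk's delta content can be None (e.g. on the final chunk), so guard before appending.
    if delta.content:
        report.append(delta.content)

print("".join(report).strip())

Attribute access (delta.content) is the documented shape of the streamed output; the commit's `'content' in chunk_message` check instead relies on huggingface_hub's generated response types also behaving like dicts, which they appear to do for backward compatibility.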
|
10 |
|
11 |
# Get openai API key
|
12 |
hf_token = os.environ["HF_API_KEY"]
|
13 |
+
|
14 |
# define a special function for putting the prompt together (as we can't use haystack)
|
15 |
def get_prompt(context, label):
|
16 |
base_prompt="Summarize the following context efficiently in bullet points, the less the better - but keep concrete goals. \
|
17 |
Summarize only elements of the context that address vulnerability of "+label+" to climate change. \
|
18 |
+
If there is no mention of "+label+" in the context, return: 'No clear references to vulnerability of "+label+" found'. \
|
19 |
Do not include an introduction sentence, just the bullet points as per below. \
|
20 |
Formatting example: \
|
21 |
- Bullet point 1 \
|
|
|
33 |
# return openai.ChatCompletion.create(**kwargs)
|
34 |
|
35 |
# construct query, send to HF API and process response
|
36 |
+
def run_query(context, label, model_sel_name):
|
37 |
'''
|
38 |
For non-streamed completion, enable the following 2 lines and comment out the code below
|
39 |
'''
|
|
|
41 |
messages = [{"role": "system", "content": chatbot_role},{"role": "user", "content": get_prompt(context, label)}]
|
42 |
|
43 |
# Initialize the client, pointing it to one of the available models
|
44 |
+
client = InferenceClient(model_sel_name, token = hf_token)
|
45 |
+
|
46 |
+
# instantiate ChatCompletion as a generator object (stream is set to True)
|
47 |
chat_completion = client.chat.completions.create(
|
48 |
messages=messages,
|
49 |
stream=True
|
50 |
)
|
51 |
+
# chat_completion = completion_with_backoff(messages=messages, stream=True)
|
52 |
|
53 |
# iterate through the streamed output
|
54 |
report = []
|
55 |
res_box = st.empty()
|
56 |
for chunk in chat_completion:
|
57 |
# extract the object containing the text (totally different structure when streaming)
|
58 |
+
if chunk.choices is not None: # sometimes returns None - probably the prompt needs work
|
59 |
+
chunk_message = chunk.choices[0].delta
|
60 |
+
# test to make sure there is text in the object (some don't have)
|
61 |
+
if 'content' in chunk_message:
|
62 |
+
report.append(chunk_message['content']) # extract the message
|
63 |
+
# add the latest text and merge it with all previous
|
64 |
+
result = "".join(report).strip()
|
65 |
+
# res_box.success(result) # output to response text box
|
66 |
+
res_box.success(result)
|
|
|
|
|
67 |
|
68 |
|
69 |
|
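The widened signature means callers now choose the backing model per query. A hypothetical call site, where every argument value is illustrative and only the parameter names (context, label, model_sel_name) come from the diff above:

# Illustrative only; retrieved_passages stands in for whatever text the app's
# retrieval step assembles, and the model id is an arbitrary example.
run_query(
    context=retrieved_passages,
    label="agriculture",
    model_sel_name="meta-llama/Meta-Llama-3-8B-Instruct",
)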