Commit 30049a9 · Parent: d4dd3c5
fix: fixing mistral answering and prompt formatting
Files changed:
- backend/controller.py  +14 -11
- explanation/interpret_captum.py  +1 -1
- explanation/interpret_shap.py  +23 -23
- model/mistral.py  +23 -12
backend/controller.py
CHANGED
@@ -59,13 +59,15 @@ def interference(
         raise RuntimeError("There was an error in the selected XAI approach.")
 
         # call the explained chat function with the model instance
-        prompt_output, history_output,
+        prompt_output, history_output, xai_interactive, xai_markup, xai_plot = (
+            explained_chat(
+                model=model,
+                xai=xai,
+                message=prompt,
+                history=history,
+                system_prompt=system_prompt,
+                knowledge=knowledge,
+            )
         )
     # if no XAI approach is selected call the vanilla chat function
     else:
@@ -78,16 +80,17 @@ def interference(
             knowledge=knowledge,
         )
         # set XAI outputs to disclaimer html/none
+        xai_interactive, xai_markup, xai_plot = (
            """
            <div style="text-align: center"><h4>Without Selected XAI Approach,
             no graphic will be displayed</h4></div>
            """,
            [("", "")],
+            None,
        )
 
    # return the outputs
-    return prompt_output, history_output,
+    return prompt_output, history_output, xai_interactive, xai_markup, xai_plot
 
 
 # simple chat function that calls the model
@@ -121,10 +124,10 @@ def explained_chat(
     prompt = model.format_prompt(message, history, system_prompt, knowledge)
 
     # generating an answer using the methods chat function
-    answer, xai_graphic, xai_markup = xai.chat_explained(model, prompt)
+    answer, xai_graphic, xai_markup, xai_plot = xai.chat_explained(model, prompt)
 
     # updating the chat history with the new answer
     history.append((message, answer))
 
     # returning the updated history, xai graphic and xai plot elements
-    return "", history, xai_graphic, xai_markup
+    return "", history, xai_graphic, xai_markup, xai_plot
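Both branches of `interference` now hand the caller the same five values (prompt box reset, chat history, interactive graphic, token markup, plot). A minimal sketch of that contract, assuming a plain-chat helper for the non-XAI branch (the `model.respond` call and the disclaimer string are illustrative, not from the commit):

# Contract sketch only; names other than those in the diff are illustrative.
def run_turn(model, xai, prompt, history, system_prompt, knowledge):
    if xai is not None:
        # XAI branch: the explainer supplies graphic, markup and (optional) plot
        answer, xai_graphic, xai_markup, xai_plot = xai.chat_explained(model, prompt)
    else:
        # vanilla branch: pad with a disclaimer, empty markup and no plot
        answer = model.respond(prompt)  # hypothetical plain chat call
        xai_graphic, xai_markup, xai_plot = "<div>no XAI selected</div>", [("", "")], None

    history.append((prompt, answer))
    return "", history, xai_graphic, xai_markup, xai_plot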
explanation/interpret_captum.py
CHANGED
@@ -52,4 +52,4 @@ def chat_explained(model, prompt):
     marked_text = markup_text(input_tokens, values, variant="captum")
 
     # return response, graphic and marked_text array
-    return response_text, graphic, marked_text
+    return response_text, graphic, marked_text, None
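The trailing `None` exists because the controller now unpacks four values from every explainer; with Captum's three real outputs a four-target assignment would raise `ValueError`. A small, self-contained illustration (the literal values are placeholders):

# Placeholder values standing in for the Captum module's outputs.
result = ("answer text", "<div>graphic</div>", [("token", 0.12)], None)

# Four-target unpacking now succeeds; the plot slot is simply unused.
answer, xai_graphic, xai_markup, xai_plot = result
assert xai_plot is None  # the Captum explainer produces no separate plot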
explanation/interpret_shap.py
CHANGED
@@ -23,29 +23,6 @@ def extract_seq_att(shap_values):
     return list(zip(shap_values.data[0], values))
 
 
-# main explain function that returns a chat with explanations
-def chat_explained(model, prompt):
-    model.set_config({})
-
-    # create the shap explainer
-    shap_explainer = PartitionExplainer(model.MODEL, model.TOKENIZER)
-
-    # get the shap values for the prompt
-    shap_values = shap_explainer([prompt])
-
-    # create the explanation graphic and marked text array
-    graphic = create_graphic(shap_values)
-    marked_text = markup_text(
-        shap_values.data[0], shap_values.values[0], variant="shap"
-    )
-
-    # create the response text
-    response_text = fmt.format_output_text(shap_values.output_names)
-
-    # return response, graphic and marked_text array
-    return response_text, graphic, marked_text
-
-
 # function used to wrap the model with a shap model
 def wrap_shap(model):
     # calling global variants
@@ -80,3 +57,26 @@ def create_graphic(shap_values):
 
     # return the html graphic as string to display in iFrame
     return str(graphic_html)
+
+
+# main explain function that returns a chat with explanations
+def chat_explained(model, prompt):
+    model.set_config({})
+
+    # create the shap explainer
+    shap_explainer = PartitionExplainer(model.MODEL, model.TOKENIZER)
+
+    # get the shap values for the prompt
+    shap_values = shap_explainer([prompt])
+
+    # create the explanation graphic and marked text array
+    graphic = create_graphic(shap_values)
+    marked_text = markup_text(
+        shap_values.data[0], shap_values.values[0], variant="shap"
+    )
+
+    # create the response text
+    response_text = fmt.format_output_text(shap_values.output_names)
+
+    # return response, graphic and marked_text array
+    return response_text, graphic, marked_text, None
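In `interpret_shap.py` the `chat_explained` function is moved below its helpers and, like the Captum module, now pads its return with `None`. Because both explainer modules expose the same four-value shape, a dispatcher can treat them interchangeably; a sketch, assuming the repository's module layout (the selection labels are illustrative):

# Sketch of uniform dispatch over the two explainer modules; only the
# four-value return shape comes from this commit.
from explanation import interpret_captum, interpret_shap

EXPLAINERS = {
    "SHAP": interpret_shap,
    "Captum": interpret_captum,
}

def explain(method: str, model, prompt: str):
    xai = EXPLAINERS[method]
    # identical unpacking regardless of the chosen explainer
    response_text, graphic, marked_text, plot = xai.chat_explained(model, prompt)
    return response_text, graphic, marked_text, plot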
model/mistral.py
CHANGED
@@ -58,8 +58,8 @@ def set_config(config_dict: dict):
 
 
 # advanced formatting function that takes into a account a conversation history
-# CREDIT:
+# CREDIT: adapated from the Mistral AI Instruct chat template
+# see https://github.com/chujiezheng/chat_templates/blob/main/chat_templates/mistral-instruct.jinja
 def format_prompt(message: str, history: list, system_prompt: str, knowledge: str = ""):
     prompt = ""
 
@@ -83,8 +83,13 @@ def format_prompt(message: str, history: list, system_prompt: str, knowledge: st
     # adds conversation history to the prompt
     for conversation in history[1:]:
         # takes all the following conversations and adds them as context
-        prompt += "".join(
+        prompt += "".join(
+            f"\n[INST] {conversation[0]} [/INST] {conversation[1]}</s>"
+        )
 
+    prompt += """\n[INST] {message} [/INST]"""
+
+    # returns full prompt
     return prompt
 
 
@@ -93,16 +98,22 @@ def format_answer(answer: str):
     # empty answer string
     formatted_answer = ""
 
-    #
-    if len(parts) >= 3:
-        # Return the text after the second occurrence of [/INST]
-        formatted_answer = parts[2].strip()
-    else:
-        # Return an empty string if there are fewer than two occurrences of [/INST]
-        formatted_answer = ""
+    # splitting answer by instruction tokens
+    segments = answer.split("[/INST]")
 
+    # checking if proper history got returned
+    if len(segments) > 1:
+        # return text after the last ['/INST'] - reponse to last message
+        formatted_answer = segments[-1].strip()
+    else:
+        # return warning and full answer if not enough [/INST] tokens found
+        gr.Warning("""
+            There was an issue with answer formatting...\n
+            returning the full answer.
+        """)
+        formatted_answer = answer
+
+    print(f"CUT:\n {answer}\nINTO:\n{formatted_answer}")
     return formatted_answer
 
 
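Taken together, `format_prompt` now emits the Mistral-Instruct turn structure and `format_answer` recovers only the reply to the latest message by splitting on `[/INST]` and keeping the last segment. A worked round-trip sketch; the f-string on the final turn is an assumption about intent, since the committed line is a plain string literal and would insert `{message}` verbatim:

# Worked example of the prompt/answer round trip implied by this commit.
history = [("system setup", "ignored"), ("Hi", "Hello!"), ("How are you?", "Fine.")]
message = "What is SHAP?"

# build the prompt the way format_prompt does (f-string assumed for the last turn)
prompt = ""
for user_msg, assistant_msg in history[1:]:
    prompt += f"\n[INST] {user_msg} [/INST] {assistant_msg}</s>"
prompt += f"\n[INST] {message} [/INST]"

# the generated text typically echoes the prompt before the new answer, so
# splitting on "[/INST]" and keeping the last segment isolates the reply to `message`
raw_answer = prompt + " SHAP attributes a model's output to its input tokens."
segments = raw_answer.split("[/INST]")
formatted_answer = segments[-1].strip() if len(segments) > 1 else raw_answer
print(formatted_answer)  # SHAP attributes a model's output to its input tokens.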