LennardZuendorf committed on
Commit 30049a9 · 1 Parent(s): d4dd3c5

fix: fixing mistral answering and prompt formatting

backend/controller.py CHANGED
@@ -59,13 +59,15 @@ def interference(
     raise RuntimeError("There was an error in the selected XAI approach.")

     # call the explained chat function with the model instance
-    prompt_output, history_output, xai_graphic, xai_markup = explained_chat(
-        model=model,
-        xai=xai,
-        message=prompt,
-        history=history,
-        system_prompt=system_prompt,
-        knowledge=knowledge,
+    prompt_output, history_output, xai_interactive, xai_markup, xai_plot = (
+        explained_chat(
+            model=model,
+            xai=xai,
+            message=prompt,
+            history=history,
+            system_prompt=system_prompt,
+            knowledge=knowledge,
+        )
     )
     # if no XAI approach is selected call the vanilla chat function
     else:
@@ -78,16 +80,17 @@ def interference(
             knowledge=knowledge,
         )
         # set XAI outputs to disclaimer html/none
-        xai_graphic, xai_markup = (
+        xai_interactive, xai_markup, xai_plot = (
             """
             <div style="text-align: center"><h4>Without Selected XAI Approach,
             no graphic will be displayed</h4></div>
             """,
             [("", "")],
+            None,
         )

     # return the outputs
-    return prompt_output, history_output, xai_graphic, xai_markup
+    return prompt_output, history_output, xai_interactive, xai_markup, xai_plot


 # simple chat function that calls the model
@@ -121,10 +124,10 @@ def explained_chat(
     prompt = model.format_prompt(message, history, system_prompt, knowledge)

     # generating an answer using the methods chat function
-    answer, xai_graphic, xai_markup = xai.chat_explained(model, prompt)
+    answer, xai_graphic, xai_markup, xai_plot = xai.chat_explained(model, prompt)

     # updating the chat history with the new answer
     history.append((message, answer))

     # returning the updated history, xai graphic and xai plot elements
-    return "", history, xai_graphic, xai_markup
+    return "", history, xai_graphic, xai_markup, xai_plot
explanation/interpret_captum.py CHANGED
@@ -52,4 +52,4 @@ def chat_explained(model, prompt):
     marked_text = markup_text(input_tokens, values, variant="captum")

     # return response, graphic and marked_text array
-    return response_text, graphic, marked_text
+    return response_text, graphic, marked_text, None
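
With this one-line pad, the Captum explainer matches the shape that backend/controller.py now unpacks, so interpret_captum and interpret_shap stay interchangeable. A sketch of the implied contract (the alias below is an assumption; the repo declares no such type):

from typing import Any, Optional

# assumed shared return shape of both chat_explained implementations:
# (response_text, graphic_html, marked_text_pairs, plot_or_None)
ChatExplainedResult = tuple[str, str, list[tuple[str, Any]], Optional[Any]]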
explanation/interpret_shap.py CHANGED
@@ -23,29 +23,6 @@ def extract_seq_att(shap_values):
     return list(zip(shap_values.data[0], values))


-# main explain function that returns a chat with explanations
-def chat_explained(model, prompt):
-    model.set_config({})
-
-    # create the shap explainer
-    shap_explainer = PartitionExplainer(model.MODEL, model.TOKENIZER)
-
-    # get the shap values for the prompt
-    shap_values = shap_explainer([prompt])
-
-    # create the explanation graphic and marked text array
-    graphic = create_graphic(shap_values)
-    marked_text = markup_text(
-        shap_values.data[0], shap_values.values[0], variant="shap"
-    )
-
-    # create the response text
-    response_text = fmt.format_output_text(shap_values.output_names)
-
-    # return response, graphic and marked_text array
-    return response_text, graphic, marked_text
-
-
 # function used to wrap the model with a shap model
 def wrap_shap(model):
     # calling global variants
@@ -80,3 +57,26 @@ def create_graphic(shap_values):

     # return the html graphic as string to display in iFrame
     return str(graphic_html)
+
+
+# main explain function that returns a chat with explanations
+def chat_explained(model, prompt):
+    model.set_config({})
+
+    # create the shap explainer
+    shap_explainer = PartitionExplainer(model.MODEL, model.TOKENIZER)
+
+    # get the shap values for the prompt
+    shap_values = shap_explainer([prompt])
+
+    # create the explanation graphic and marked text array
+    graphic = create_graphic(shap_values)
+    marked_text = markup_text(
+        shap_values.data[0], shap_values.values[0], variant="shap"
+    )
+
+    # create the response text
+    response_text = fmt.format_output_text(shap_values.output_names)
+
+    # return response, graphic and marked_text array
+    return response_text, graphic, marked_text, None
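
Aside from the relocation below wrap_shap and create_graphic, the SHAP chat_explained only gains the trailing None pad. The move itself is harmless: Python resolves names inside a function body at call time, not at definition time, so chat_explained can sit after the helpers it uses. A tiny self-contained demonstration of that rule:

# names used inside a function body are looked up when the function
# runs, so chat_explained may be defined after create_graphic/markup_text
def caller():
    return helper()  # resolved at call time

def helper():
    return "ok"

print(caller())  # prints "ok" although helper is defined later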
model/mistral.py CHANGED
@@ -58,8 +58,8 @@ def set_config(config_dict: dict):


 # advanced formatting function that takes into account a conversation history
-# CREDIT: adapted from Venkata Bhanu Teja Pallakonda in Huggingface discussions
-## see https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/discussions/
+# CREDIT: adapted from the Mistral AI Instruct chat template
+# see https://github.com/chujiezheng/chat_templates/blob/main/chat_templates/mistral-instruct.jinja
 def format_prompt(message: str, history: list, system_prompt: str, knowledge: str = ""):
     prompt = ""

@@ -83,8 +83,13 @@ def format_prompt(message: str, history: list, system_prompt: str, knowledge: st
     # adds conversation history to the prompt
     for conversation in history[1:]:
         # takes all the following conversations and adds them as context
-        prompt += "".join(f"[INST] {conversation[0]} [/INST] {conversation[1]}</s>")
+        prompt += "".join(
+            f"\n[INST] {conversation[0]} [/INST] {conversation[1]}</s>"
+        )

+    prompt += f"\n[INST] {message} [/INST]"
+
+    # returns full prompt
     return prompt


@@ -93,16 +98,22 @@ def format_answer(answer: str):
     # empty answer string
     formatted_answer = ""

-    # extracting text after INST tokens
-    parts = answer.split("[/INST]")
-    if len(parts) >= 3:
-        # Return the text after the second occurrence of [/INST]
-        formatted_answer = parts[2].strip()
-    else:
-        # Return an empty string if there are fewer than two occurrences of [/INST]
-        formatted_answer = ""
+    # splitting answer by instruction tokens
+    segments = answer.split("[/INST]")

-    print(f"Cut {answer} into {formatted_answer}.")
+    # checking if proper history got returned
+    if len(segments) > 1:
+        # return text after the last [/INST] - response to the latest message
+        formatted_answer = segments[-1].strip()
+    else:
+        # return warning and full answer if not enough [/INST] tokens found
+        gr.Warning("""
+            There was an issue with answer formatting...\n
+            returning the full answer.
+        """)
+        formatted_answer = answer
+
+    print(f"CUT:\n {answer}\nINTO:\n{formatted_answer}")
     return formatted_answer
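
The reworked format_answer no longer demands exactly two [/INST] occurrences (which previously discarded valid answers as empty strings); it now keeps whatever follows the last [/INST] and, failing that, warns and returns the full generation. A quick standalone check of the splitting logic, on an invented generation string:

# invented example: a generation echoing the instruct-formatted history
raw = "[INST] Hi [/INST] Hello!</s>\n[INST] How are you? [/INST] I am fine."

segments = raw.split("[/INST]")
print(segments[-1].strip())  # "I am fine." - only the newest reply survives

# a generation without any [/INST] yields a single segment, so the
# function falls into the warning branch and returns the text unchanged
assert len("no tokens here".split("[/INST]")) == 1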