LennardZuendorf committed
Commit 5d99c07 · 1 Parent(s): c5c1df2

feat: adding mistral model again

backend/controller.py CHANGED
@@ -6,6 +6,8 @@ import gradio as gr
 
 # internal imports
 from model import godel
+from model import mistral
+from utils import modelling as mdl
 from explanation import interpret_shap as shap_int, visualize as viz
 
 
@@ -17,14 +19,20 @@ def interference(
     knowledge: str,
     system_prompt: str,
     xai_selection: str,
+    model_selection: str,
 ):
     # if no proper system prompt is given, use a default one
-    if system_prompt in ('', ' '):
+    if system_prompt in ("", " "):
         system_prompt = """
         You are a helpful, respectful and honest assistant.
         Always answer as helpfully as possible, while being safe.
         """
 
+    if model_selection.lower() == "mistral":
+        model = mistral
+    else:
+        model = godel
+
     # if a XAI approach is selected, grab the XAI module instance
     if xai_selection in ("SHAP", "Attention"):
         # matching selection
@@ -44,7 +52,7 @@ def interference(
 
         # call the explained chat function with the model instance
         prompt_output, history_output, xai_graphic, xai_markup = explained_chat(
-            model=godel,
+            model=model,
             xai=xai,
             message=prompt,
             history=history,
@@ -55,7 +63,7 @@ def interference(
     else:
         # call the vanilla chat function
         prompt_output, history_output = vanilla_chat(
-            model=godel,
+            model=model,
             message=prompt,
             history=history,
             system_prompt=system_prompt,
@@ -95,6 +103,9 @@ def explained_chat(
     model, xai, message: str, history: list, system_prompt: str, knowledge: str = ""
 ):
     # formatting the prompt using the model's format_prompt function
+    message, history, system_prompt, knowledge = mdl.prompt_limiter(
+        message, history, system_prompt, knowledge
+    )
     prompt = model.format_prompt(message, history, system_prompt, knowledge)
 
     # generating an answer using the methods chat function
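
For reference, a minimal sketch of how the updated interference entry point is wired from main.py after this commit; the first two parameter names and the return shape are not shown in the hunks above, so the call below simply mirrors the positional order of the Gradio inputs list and is an illustrative assumption, not repository code.

from backend.controller import interference

outputs = interference(
    "Why is the sky blue?",  # user prompt
    [],                      # chat history
    "",                      # knowledge (Mistral ignores it, see model/mistral.py)
    "",                      # empty -> the default system prompt above is used
    "None",                  # xai selection; "SHAP" or "Attention" routes through explained_chat
    "Mistral",               # model selection; any other value falls back to godel
)
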
explanation/interpret_captum.py ADDED
@@ -0,0 +1,38 @@
+# external imports
+from captum.attr import LLMAttribution, TextTokenInput, KernelShap
+import torch
+
+# internal imports
+from utils import formatting as fmt
+from .markup import markup_text
+
+
+# main explain function that returns a chat with explanations
+def chat_explained(model, prompt):
+    model.set_config({})
+
+    # creating llm attribution class with KernelSHAP and the Mistral model, tokenizer
+    llm_attribution = LLMAttribution(KernelShap(model.MODEL), model.TOKENIZER)
+
+    # generating attribution
+    attribution_input = TextTokenInput(prompt, model.TOKENIZER)
+    attribution_result = llm_attribution.attribute(attribution_input)
+
+    # extracting values and input tokens
+    values = attribution_result.seq_attr.to(torch.device("cpu")).numpy()
+    input_tokens = fmt.format_tokens(attribution_result.input_tokens)
+
+    # raising error if mismatch occurs
+    if len(attribution_result.input_tokens) != len(values):
+        raise RuntimeError("values and input len mismatch")
+
+    # getting response text, graphic placeholder and marked text object
+    response_text = fmt.format_output_text(attribution_result.output_tokens)
+    graphic = (
+        "<div style='text-align: center; font-family:arial;'><h4>Attention"
+        " Interpretation with Captum doesn't support an interactive graphic.</h4></div>"
+    )
+    marked_text = markup_text(input_tokens, values, variant="captum")
+
+    # return response, graphic and marked_text array
+    return response_text, graphic, marked_text
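
For reference, a minimal usage sketch of the new Captum explainer, assuming the model.mistral module added in this commit is passed in as the model argument; the prompt string and the cpt_int alias are illustrative.

from model import mistral
from explanation import interpret_captum as cpt_int

response_text, graphic, marked_text = cpt_int.chat_explained(
    mistral, "<s>[INST] Why is the sky blue? [/INST]"
)
# response_text is the decoded answer, graphic is the static HTML placeholder above,
# and marked_text is the per-token markup produced via markup_text(..., variant="captum")
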
explanation/markup.py CHANGED
@@ -18,7 +18,7 @@ def markup_text(input_text: list, text_values: ndarray, variant: str):
     if variant == "shap":
         text_values = np.transpose(text_values)
         text_values = fmt.flatten_attribution(text_values)
-    else:
+    elif variant == "visualizer":
         text_values = fmt.flatten_attention(text_values)
 
     # Determine the minimum and maximum values
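
For reference, a short sketch of what the narrowed else branch changes: with the fall-through now limited to variant == "visualizer", the new "captum" variant skips both flatten helpers, so its values are expected to arrive already one-dimensional (one value per token). The token list and values below are illustrative only.

import numpy as np
from explanation.markup import markup_text

tokens = ["Why", "is", "the", "sky", "blue", "?"]
values = np.array([0.12, -0.03, 0.01, 0.40, 0.55, 0.02])  # already one value per token

marked = markup_text(tokens, values, variant="captum")  # no transpose or flattening applied
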
explanation/visualize.py CHANGED
@@ -3,7 +3,7 @@
 
 # internal imports
 from utils import formatting as fmt
-from model.godel import CONFIG
+from model.model import CONFIG
 from .markup import markup_text
 
 
main.py CHANGED
@@ -97,31 +97,40 @@ with gr.Blocks(
     """)
     # row with columns for the different settings
     with gr.Row(equal_height=True):
-        # accordion that extends if clicked
-        with gr.Accordion(label="Application Settings", open=False):
-            # column that takes up 3/4 of the row
-            with gr.Column(scale=3):
-                # textbox to enter the system prompt
-                system_prompt = gr.Textbox(
-                    label="System Prompt",
-                    info="Set the models system prompt, dictating how it answers.",
-                    # default system prompt is set to this in the backend
-                    placeholder=(
-                        "You are a helpful, respectful and honest assistant. Always"
-                        " answer as helpfully as possible, while being safe."
-                    ),
-                )
-            # column that takes up 1/4 of the row
-            with gr.Column(scale=1):
-                # checkbox group to select the xai method
-                xai_selection = gr.Radio(
-                    ["None", "SHAP", "Attention"],
-                    label="Interpretability Settings",
-                    info="Select a Interpretability Implementation to use.",
-                    value="None",
-                    interactive=True,
-                    show_label=True,
-                )
+        # column that takes up half of the row
+        with gr.Column(scale=2):
+            # textbox to enter the system prompt
+            system_prompt = gr.Textbox(
+                label="System Prompt",
+                info="Set the models system prompt, dictating how it answers.",
+                # default system prompt is set to this in the backend
+                placeholder=(
+                    "You are a helpful, respectful and honest assistant. Always"
+                    " answer as helpfully as possible, while being safe."
+                ),
+            )
+        # column that takes up 1/4 of the row
+        with gr.Column(scale=1):
+            # radio group to select the xai method
+            xai_selection = gr.Radio(
+                ["None", "SHAP", "Attention"],
+                label="Interpretability Settings",
+                info="Select a Interpretability Implementation to use.",
+                value="None",
+                interactive=True,
+                show_label=True,
+            )
+        # column that takes up 1/4 of the row
+        with gr.Column(scale=1):
+            # radio group to select the model
+            model_selection = gr.Radio(
+                ["GODEL", "Mistral"],
+                label="Model Settings",
+                info="Select a Model to use.",
+                value="GODEL",
+                interactive=True,
+                show_label=True,
+            )
 
     # calling info functions on inputs/submits for different settings
     system_prompt.submit(system_prompt_info, [system_prompt])
@@ -247,13 +256,27 @@
     ## see backend/controller.py for more information
     submit_btn.click(
         interference,
-        [user_prompt, chatbot, knowledge_input, system_prompt, xai_selection],
+        [
+            user_prompt,
+            chatbot,
+            knowledge_input,
+            system_prompt,
+            xai_selection,
+            model_selection,
+        ],
         [user_prompt, chatbot, xai_interactive, xai_text],
     )
     # function triggered by the enter key
     user_prompt.submit(
         interference,
-        [user_prompt, chatbot, knowledge_input, system_prompt, xai_selection],
+        [
+            user_prompt,
+            chatbot,
+            knowledge_input,
+            system_prompt,
+            xai_selection,
+            model_selection,
+        ],
         [user_prompt, chatbot, xai_interactive, xai_text],
     )
 
model/mistral.py ADDED
@@ -0,0 +1,105 @@
+# Mistral model module for chat interaction and model instance control
+
+# external imports
+from transformers import AutoTokenizer, AutoModelForCausalLM
+import torch
+import gradio as gr
+
+# internal imports
+from utils import modelling as mdl
+
+# global model and tokenizer instance (created on initial build)
+device = mdl.get_device()
+if device == torch.device("cuda"):
+    n_gpus, max_memory, bnb_config = mdl.gpu_loading_config()
+
+    MODEL = AutoModelForCausalLM.from_pretrained(
+        "mistralai/Mistral-7B-Instruct-v0.2",
+        quantization_config=bnb_config,
+        device_map="auto",  # dispatch the model efficiently on the available resources
+        max_memory={i: max_memory for i in range(n_gpus)},
+    )
+
+else:
+    MODEL = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
+    MODEL.to(device)
+TOKENIZER = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
+
+# default model config
+CONFIG = {"max_new_tokens": 50, "min_length": 8, "top_p": 0.9, "do_sample": True}
+
+
+# function to (re)set the config
+def set_config(config: dict):
+    global CONFIG
+
+    # if config dict is given, update it
+    if config != {}:
+        CONFIG = config
+    else:
+        # hard setting model config to default
+        # needed for shap
+        MODEL.config.max_new_tokens = 50
+        MODEL.config.min_length = 8
+        MODEL.config.top_p = 0.9
+        MODEL.config.do_sample = True
+
+
+# advanced formatting function that takes a conversation history into account
+# CREDIT: adapted from Venkata Bhanu Teja Pallakonda in Huggingface discussions
+## see https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/discussions/
+def format_prompt(message: str, history: list, system_prompt: str, knowledge: str = ""):
+    prompt = ""
+
+    if knowledge != "":
+        gr.Info("""
+            Mistral doesn't support additional knowledge, it's gonna be ignored.
+            """)
+
+    # if no history, use system prompt and example message
+    if len(history) == 0:
+        prompt = f"""<s>[INST] {system_prompt} [/INST] How can I help you today? </s>
+        [INST] {message} [/INST]"""
+    else:
+        # takes the very first exchange and the system prompt as base
+        for user_prompt, bot_response in history[0]:
+            prompt = (
+                f"<s>[INST] {system_prompt} {user_prompt} [/INST] {bot_response}</s>"
+            )
+
+        # takes all the following conversations and adds them as context
+        prompt += "".join(
+            f"[INST] {user_prompt} [/INST] {bot_response}</s>"
+            for user_prompt, bot_response in history[1:]
+        )
+
+    return prompt
+
+
+# function to extract the real answer because mistral always returns the full prompt
+def format_answer(answer: str):
+    # empty answer string
+    formatted_answer = ""
+
+    # extracting text after INST tokens
+    parts = answer.split("[/INST]")
+    if len(parts) >= 3:
+        # return the text after the second occurrence of [/INST]
+        formatted_answer = parts[2].strip()
+    else:
+        # return an empty string if there are fewer than two occurrences of [/INST]
+        formatted_answer = ""
+
+    return formatted_answer
+
+
+def respond(prompt: str):
+
+    # tokenizing inputs and configuring model
+    input_ids = TOKENIZER(f"{prompt}", return_tensors="pt")["input_ids"]
+
+    # generating text with tokenized input, returning the decoded output
+    output_ids = MODEL.generate(input_ids, max_new_tokens=50, generation_config=CONFIG)
+    output_text = TOKENIZER.batch_decode(output_ids)[0]
+
+    return format_answer(output_text)
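
For reference, a minimal sketch of how this module is expected to be driven on its own; the commented prompt string is an approximation of what format_prompt builds for an empty history, not output captured from the repository.

from model import mistral

prompt = mistral.format_prompt(
    "Why is the sky blue?",                                 # message
    [],                                                     # history
    "You are a helpful, respectful and honest assistant.",  # system prompt
)
# prompt is roughly:
# <s>[INST] You are a helpful, respectful and honest assistant. [/INST] How can I help you today? </s>
# [INST] Why is the sky blue? [/INST]

answer = mistral.respond(prompt)  # generates, decodes, and strips the echoed prompt via format_answer
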
requirements.txt CHANGED
@@ -2,6 +2,7 @@ gradio~=4.7.1
 transformers~=4.35.2
 torch~=2.1.1
 shap~=0.44.0
+captum
 bertviz~=1.4.0
 accelerate~=0.24.1
 markdown~=3.5.1
utils/modelling.py CHANGED
@@ -1,7 +1,9 @@
 # modelling util module providing formatting functions for model functionalities
 
 # external imports
+import torch
 import gradio as gr
+from transformers import BitsAndBytesConfig
 
 
 # function that limits the prompt to contain model runtime
@@ -72,3 +74,26 @@ def token_counter(tokenizer, text: str):
     tokens = tokenizer(text, return_tensors="pt").input_ids
     # return the token count
     return len(tokens[0])
+
+
+# setting the device based on the available hardware
+def get_device():
+    if torch.cuda.is_available():
+        device = torch.device("cuda")
+    else:
+        device = torch.device("cpu")
+
+    return device
+
+
+# building the 4-bit quantization config used when loading a model on gpu
+def gpu_loading_config(max_memory: str = "15000MB"):
+    n_gpus = torch.cuda.device_count()
+
+    bnb_config = BitsAndBytesConfig(
+        load_in_4bit=True,
+        bnb_4bit_use_double_quant=True,
+        bnb_4bit_quant_type="nf4",
+        bnb_4bit_compute_dtype=torch.bfloat16,
+    )
+
+    return n_gpus, max_memory, bnb_config
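
For reference, a minimal sketch of the intended call pattern for the two new helpers; the unpacking mirrors the GPU branch added in model/mistral.py, and max_memory_map is an illustrative name for the dict later passed to from_pretrained.

import torch
from utils import modelling as mdl

device = mdl.get_device()  # torch.device("cuda") or torch.device("cpu")
if device == torch.device("cuda"):
    # gpu count, the default "15000MB" per-gpu budget, and a 4-bit NF4 BitsAndBytesConfig
    n_gpus, max_memory, bnb_config = mdl.gpu_loading_config()
    max_memory_map = {i: max_memory for i in range(n_gpus)}  # per-device cap for from_pretrained
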