Commit 5d99c07 (parent: c5c1df2)
feat: adding mistral model again

Changed files:
- backend/controller.py +14 -3
- explanation/interpret_captum.py +38 -0
- explanation/markup.py +1 -1
- explanation/visualize.py +1 -1
- main.py +50 -27
- model/mistral.py +105 -0
- requirements.txt +1 -0
- utils/modelling.py +25 -0
backend/controller.py
CHANGED
@@ -6,6 +6,8 @@ import gradio as gr
 
 # internal imports
 from model import godel
+from model import mistral
+from utils import modelling as mdl
 from explanation import interpret_shap as shap_int, visualize as viz
 
 
@@ -17,14 +19,20 @@ def interference(
     knowledge: str,
     system_prompt: str,
     xai_selection: str,
+    model_selection: str,
 ):
     # if no proper system prompt is given, use a default one
-    if system_prompt in (
+    if system_prompt in ("", " "):
         system_prompt = """
             You are a helpful, respectful and honest assistant.
             Always answer as helpfully as possible, while being safe.
             """
 
+    if model_selection.lower() == "mistral":
+        model = mistral
+    else:
+        model = godel
+
     # if a XAI approach is selected, grab the XAI module instance
     if xai_selection in ("SHAP", "Attention"):
         # matching selection
@@ -44,7 +52,7 @@ def interference(
 
     # call the explained chat function with the model instance
     prompt_output, history_output, xai_graphic, xai_markup = explained_chat(
-        model=
+        model=model,
         xai=xai,
         message=prompt,
         history=history,
@@ -55,7 +63,7 @@ def interference(
     else:
         # call the vanilla chat function
         prompt_output, history_output = vanilla_chat(
-            model=
+            model=model,
             message=prompt,
             history=history,
             system_prompt=system_prompt,
@@ -95,6 +103,9 @@ def explained_chat(
     model, xai, message: str, history: list, system_prompt: str, knowledge: str = ""
 ):
     # formatting the prompt using the model's format_prompt function
+    message, history, system_prompt, knowledge = mdl.prompt_limiter(
+        message, history, system_prompt, knowledge
+    )
     prompt = model.format_prompt(message, history, system_prompt, knowledge)
 
     # generating an answer using the methods chat function
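The dispatch added to interference() is small enough to illustrate in isolation. Below is a minimal sketch of that selection logic using stand-in objects so it runs outside the repo; the helper name select_model is hypothetical (in the controller the branch is inline), and in the app the selected object is the godel or mistral module from model/, both expected to expose the same interface (format_prompt, set_config, respond).

# minimal sketch of the model dispatch added to interference(); the
# SimpleNamespace objects stand in for the real `godel` and `mistral` modules
from types import SimpleNamespace

godel = SimpleNamespace(name="GODEL")
mistral = SimpleNamespace(name="Mistral")


def select_model(model_selection: str):
    # lower-casing the radio value so "Mistral", "mistral" and "MISTRAL" all match
    if model_selection.lower() == "mistral":
        return mistral
    return godel


assert select_model("Mistral") is mistral
assert select_model("GODEL") is godel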
explanation/interpret_captum.py
ADDED
@@ -0,0 +1,38 @@
+# external imports
+from captum.attr import LLMAttribution, TextTokenInput, KernelShap
+import torch
+
+# internal imports
+from utils import formatting as fmt
+from .markup import markup_text
+
+
+# main explain function that returns a chat with explanations
+def chat_explained(model, prompt):
+    model.set_config({})
+
+    # creating llm attribution class with KernelSHAP and Mistral model, tokenizer
+    llm_attribution = LLMAttribution(KernelShap(model.MODEL), model.TOKENIZER)
+
+    # generation attribution
+    attribution_input = TextTokenInput(prompt, model.TOKENIZER)
+    attribution_result = llm_attribution.attribute(attribution_input)
+
+    # extracting values and input tokens
+    values = attribution_result.seq_attr.to(torch.device("cpu")).numpy()
+    input_tokens = fmt.format_tokens(attribution_result.input_tokens)
+
+    # raising error if mismatch occurs
+    if len(attribution_result.input_tokens) != len(values):
+        raise RuntimeError("values and input len mismatch")
+
+    # getting response text, graphic placeholder and marked text object
+    response_text = fmt.format_output_text(attribution_result.output_tokens)
+    graphic = (
+        "<div style='text-align: center; font-family:arial;'><h4>Attention"
+        " Interpretation with Captum doesn't support an interactive graphic.</h4></div>"
+    )
+    marked_text = markup_text(input_tokens, values, variant="captum")
+
+    # return response, graphic and marked_text array
+    return response_text, graphic, marked_text
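The new module follows Captum's LLM attribution workflow. A rough, self-contained sketch of that workflow is shown below, assuming captum >= 0.7; a small GPT-2 model stands in for the Mistral instance that chat_explained receives via model.MODEL / model.TOKENIZER.

# sketch of the Captum LLM-attribution pattern used above, on a small model
import torch
from captum.attr import KernelShap, LLMAttribution, TextTokenInput
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# wrap a perturbation-based attribution method for use with a generative LM
llm_attribution = LLMAttribution(KernelShap(model), tokenizer)

# attribute the generated continuation back to the input tokens
attribution_input = TextTokenInput("The capital of France is", tokenizer)
attribution_result = llm_attribution.attribute(attribution_input)

# one attribution value per input token, as in chat_explained above
values = attribution_result.seq_attr.to(torch.device("cpu")).numpy()
print(list(zip(attribution_result.input_tokens, values)))
print(attribution_result.output_tokens)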
explanation/markup.py
CHANGED
@@ -18,7 +18,7 @@ def markup_text(input_text: list, text_values: ndarray, variant: str):
     if variant == "shap":
         text_values = np.transpose(text_values)
         text_values = fmt.flatten_attribution(text_values)
-
+    elif variant == "visualizer":
         text_values = fmt.flatten_attention(text_values)
 
     # Determine the minimum and maximum values
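For context, the variant argument now distinguishes three attribution sources (SHAP, the attention visualizer, and Captum). The stand-in sketch below shows the dispatch only; the real flatteners live in utils/formatting.py and their exact behaviour is assumed here.

import numpy as np

# stand-in sketch of the per-variant preprocessing in markup_text
def preprocess_values(text_values: np.ndarray, variant: str) -> np.ndarray:
    if variant == "shap":
        # assumed: SHAP values arrive as output tokens x input tokens and are
        # transposed, then collapsed to one value per input token
        text_values = np.transpose(text_values).sum(axis=1)
    elif variant == "visualizer":
        # assumed: attention weights are flattened to one value per input token
        text_values = np.asarray(text_values).flatten()
    # the "captum" variant already delivers one value per input token
    return text_values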
explanation/visualize.py
CHANGED
@@ -3,7 +3,7 @@
 
 # internal imports
 from utils import formatting as fmt
-from model.
+from model.model import CONFIG
 from .markup import markup_text
 
 
main.py
CHANGED
@@ -97,31 +97,40 @@ with gr.Blocks(
     """)
     # row with columns for the different settings
     with gr.Row(equal_height=True):
-        … (25 removed lines: the previous settings layout; content truncated in the source)
+        # column that takes up 3/4 of the row
+        with gr.Column(scale=2):
+            # textbox to enter the system prompt
+            system_prompt = gr.Textbox(
+                label="System Prompt",
+                info="Set the model's system prompt, dictating how it answers.",
+                # default system prompt is set to this in the backend
+                placeholder=(
+                    "You are a helpful, respectful and honest assistant. Always"
+                    " answer as helpfully as possible, while being safe."
+                ),
+            )
+        # column that takes up 1/4 of the row
+        with gr.Column(scale=1):
+            # radio group to select the xai method
+            xai_selection = gr.Radio(
+                ["None", "SHAP", "Attention"],
+                label="Interpretability Settings",
+                info="Select an Interpretability Implementation to use.",
+                value="None",
+                interactive=True,
+                show_label=True,
+            )
+        # column that takes up 1/4 of the row
+        with gr.Column(scale=1):
+            # radio group to select the model
+            model_selection = gr.Radio(
+                ["GODEL", "Mistral"],
+                label="Model Settings",
+                info="Select a Model to use.",
+                value="GODEL",
+                interactive=True,
+                show_label=True,
+            )
 
     # calling info functions on inputs/submits for different settings
     system_prompt.submit(system_prompt_info, [system_prompt])
@@ -247,13 +256,27 @@ with gr.Blocks(
     ## see backend/controller.py for more information
     submit_btn.click(
         interference,
-        [
+        [
+            user_prompt,
+            chatbot,
+            knowledge_input,
+            system_prompt,
+            xai_selection,
+            model_selection,
+        ],
         [user_prompt, chatbot, xai_interactive, xai_text],
     )
     # function triggered by the enter key
     user_prompt.submit(
         interference,
-        [
+        [
+            user_prompt,
+            chatbot,
+            knowledge_input,
+            system_prompt,
+            xai_selection,
+            model_selection,
+        ],
         [user_prompt, chatbot, xai_interactive, xai_text],
    )
 
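The wiring of the two new Radio groups into interference can be reproduced in a standalone sketch: stand-in components and a dummy interference function are used here, while the real callback lives in backend/controller.py and the real layout defines more settings.

# standalone sketch of the new model_selection wiring (Gradio 4.x assumed)
import gradio as gr


# dummy stand-in for backend.controller.interference
def interference(prompt, history, knowledge, system_prompt, xai_selection, model_selection):
    history = (history or []) + [[prompt, f"[{model_selection} / {xai_selection}] echo: {prompt}"]]
    return "", history, None, []


with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    user_prompt = gr.Textbox(label="Prompt")
    knowledge_input = gr.Textbox(label="Knowledge")
    system_prompt = gr.Textbox(label="System Prompt")
    xai_selection = gr.Radio(["None", "SHAP", "Attention"], value="None", label="Interpretability Settings")
    model_selection = gr.Radio(["GODEL", "Mistral"], value="GODEL", label="Model Settings")
    xai_interactive = gr.HTML()
    xai_text = gr.HighlightedText(value=[])
    submit_btn = gr.Button("Submit")

    # same input/output lists as the click/submit handlers above
    submit_btn.click(
        interference,
        [user_prompt, chatbot, knowledge_input, system_prompt, xai_selection, model_selection],
        [user_prompt, chatbot, xai_interactive, xai_text],
    )

demo.launch()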
model/mistral.py
ADDED
@@ -0,0 +1,105 @@
+# Mistral model module for chat interaction and model instance control
+
+# external imports
+from transformers import AutoTokenizer, AutoModelForCausalLM
+import torch
+import gradio as gr
+
+# internal imports
+from utils import modelling as mdl
+
+# global model and tokenizer instance (created on initial build)
+device = mdl.get_device()
+if device == torch.device("cuda"):
+    n_gpus, max_memory, bnb_config = mdl.gpu_loading_config()
+
+    MODEL = AutoModelForCausalLM.from_pretrained(
+        "mistralai/Mistral-7B-Instruct-v0.2",
+        quantization_config=bnb_config,
+        device_map="auto",  # dispatch the model efficiently on the available resources
+        max_memory={i: max_memory for i in range(n_gpus)},
+    )
+
+else:
+    MODEL = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
+    MODEL.to(device)
+TOKENIZER = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
+
+# default model config
+CONFIG = {"max_new_tokens": 50, "min_length": 8, "top_p": 0.9, "do_sample": True}
+
+
+# function to (re)set the config
+def set_config(config: dict):
+    global CONFIG
+
+    # if a config dict is given, update it
+    if config != {}:
+        CONFIG = config
+    else:
+        # hard-setting the model config to defaults
+        # needed for SHAP
+        MODEL.config.max_new_tokens = 50
+        MODEL.config.min_length = 8
+        MODEL.config.top_p = 0.9
+        MODEL.config.do_sample = True
+
+
+# advanced formatting function that takes a conversation history into account
+# CREDIT: adapted from Venkata Bhanu Teja Pallakonda in Huggingface discussions
+## see https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/discussions/
+def format_prompt(message: str, history: list, system_prompt: str, knowledge: str = ""):
+    prompt = ""
+
+    if knowledge != "":
+        gr.Info("""
+            Mistral doesn't support additional knowledge; it will be ignored.
+            """)
+
+    # if no history, use system prompt and example message
+    if len(history) == 0:
+        prompt = f"""<s>[INST] {system_prompt} [/INST] How can I help you today? </s>
+        [INST] {message} [/INST]"""
+    else:
+        # takes the very first exchange and the system prompt as base
+        for user_prompt, bot_response in history[0]:
+            prompt = (
+                f"<s>[INST] {system_prompt} {user_prompt} [/INST] {bot_response}</s>"
+            )
+
+        # takes all the following conversations and adds them as context
+        prompt += "".join(
+            f"[INST] {user_prompt} [/INST] {bot_response}</s>"
+            for user_prompt, bot_response in history[1:]
+        )
+
+    return prompt
+
+
+# function to extract the real answer because Mistral always returns the full prompt
+def format_answer(answer: str):
+    # empty answer string
+    formatted_answer = ""
+
+    # extracting text after the [/INST] tokens
+    parts = answer.split("[/INST]")
+    if len(parts) >= 3:
+        # return the text after the second occurrence of [/INST]
+        formatted_answer = parts[2].strip()
+    else:
+        # return an empty string if there are fewer than two occurrences of [/INST]
+        formatted_answer = ""
+
+    return formatted_answer
+
+
+def respond(prompt: str):
+
+    # tokenizing inputs and configuring the model
+    input_ids = TOKENIZER(f"{prompt}", return_tensors="pt")["input_ids"]
+
+    # generating text with tokenized input, returning the decoded output
+    output_ids = MODEL.generate(input_ids, max_new_tokens=50, generation_config=CONFIG)
+    output_text = TOKENIZER.batch_decode(output_ids)[0]
+
+    return format_answer(output_text)
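Taken together, the module is used roughly as follows. This is a hypothetical usage sketch; downloading and running Mistral-7B-Instruct-v0.2 requires the hardware paths handled in utils/modelling.py.

# hypothetical usage sketch of the new mistral module
from model import mistral

# passing an empty dict falls back to the hard-coded default generation settings
mistral.set_config({})

# build a Mistral instruction prompt from the chat state
prompt = mistral.format_prompt(
    message="What does SHAP stand for?",
    history=[],
    system_prompt="You are a helpful, respectful and honest assistant.",
)
# prompt now looks like:
# <s>[INST] ...system prompt... [/INST] How can I help you today? </s>
#     [INST] What does SHAP stand for? [/INST]

# generate, decode, and strip everything before the answer (text after the
# second [/INST] marker, as handled by format_answer)
answer = mistral.respond(prompt)
print(answer)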
requirements.txt
CHANGED
@@ -2,6 +2,7 @@ gradio~=4.7.1
 transformers~=4.35.2
 torch~=2.1.1
 shap~=0.44.0
+captum
 bertviz~=1.4.0
 accelerate~=0.24.1
 markdown~=3.5.1
utils/modelling.py
CHANGED
@@ -1,7 +1,9 @@
 # modelling util module providing formatting functions for model functionalities
 
 # external imports
+import torch
 import gradio as gr
+from transformers import BitsAndBytesConfig
 
 
 # function that limits the prompt to contain model runtime
@@ -72,3 +74,26 @@ def token_counter(tokenizer, text: str):
     tokens = tokenizer(text, return_tensors="pt").input_ids
     # return the token count
     return len(tokens[0])
+
+
+def get_device():
+    if torch.cuda.is_available():
+        device = torch.device("cuda")
+    else:
+        device = torch.device("cpu")
+
+    return device
+
+
+# building the 4-bit quantization config for loading the model on GPU
+def gpu_loading_config(max_memory: str = "15000MB"):
+    n_gpus = torch.cuda.device_count()
+
+    bnb_config = BitsAndBytesConfig(
+        load_in_4bit=True,
+        bnb_4bit_use_double_quant=True,
+        bnb_4bit_quant_type="nf4",
+        bnb_4bit_compute_dtype=torch.bfloat16,
+    )
+
+    return n_gpus, max_memory, bnb_config
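These two helpers are consumed by the model modules at import time. The sketch below mirrors the loading branch of model/mistral.py from this commit (model name and the 15000MB memory budget are taken from the diff above):

import torch
from transformers import AutoModelForCausalLM
from utils import modelling as mdl

device = mdl.get_device()
if device == torch.device("cuda"):
    # 4-bit NF4 quantization config plus a per-GPU memory budget
    n_gpus, max_memory, bnb_config = mdl.gpu_loading_config(max_memory="15000MB")
    model = AutoModelForCausalLM.from_pretrained(
        "mistralai/Mistral-7B-Instruct-v0.2",
        quantization_config=bnb_config,
        device_map="auto",
        max_memory={i: max_memory for i in range(n_gpus)},
    )
else:
    # CPU fallback without quantization
    model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
    model.to(device)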