Commit 226ad46
Parent(s): b324c38
fix: fixing various bugs

Files changed:
- backend/controller.py +2 -2
- explanation/attention.py +51 -0
- explanation/interpret_captum.py +3 -2
- explanation/markup.py +6 -6
- explanation/visualize_att.py +0 -0
backend/controller.py
CHANGED
@@ -8,9 +8,9 @@ import gradio as gr
 from model import godel
 from model import mistral
 from explanation import (
+    attention as attention_viz,
     interpret_shap as shap_int,
     interpret_captum as cpt_int,
-    visualize_att as viz,
 )


@@ -48,7 +48,7 @@ def interference(
             else:
                 xai = shap_int
         case "attention":
-            xai =
+            xai = attention_viz
         case _:
             # use Gradio warning to display error message
             gr.Warning(f"""
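For context, this is roughly how the dispatch patched by the second hunk is typically used. Only the case "attention": xai = attention_viz branch and the gr.Warning call come from the diff; the function name run_explained_chat, the xai_selection argument, and the final chat_explained call are assumptions about the surrounding controller code, not part of this commit.

import gradio as gr

from explanation import (
    attention as attention_viz,
    interpret_shap as shap_int,
)


def run_explained_chat(model, prompt, xai_selection):
    # hypothetical helper: pick the explanation module based on the user's selection
    match xai_selection.lower():
        case "shap":
            xai = shap_int
        case "attention":
            xai = attention_viz
        case _:
            # use Gradio warning to display error message
            gr.Warning(f"Unknown explanation method: {xai_selection}")
            return None
    # each module is assumed to expose the same chat_explained(model, prompt) interface
    return xai.chat_explained(model, prompt)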
explanation/attention.py
ADDED
@@ -0,0 +1,51 @@
+# visualization module that creates an attention visualization
+
+
+# internal imports
+from utils import formatting as fmt
+from .markup import markup_text
+
+
+# chat function that returns an answer
+# and marked text based on attention
+def chat_explained(model, prompt):
+
+    # get encoded input
+    encoder_input_ids = model.TOKENIZER(
+        prompt, return_tensors="pt", add_special_tokens=True
+    ).input_ids
+    # generate output together with attentions of the model
+    decoder_input_ids = model.MODEL.generate(
+        encoder_input_ids, output_attentions=True, **model.CONFIG
+    )
+
+    # get input and output text as list of strings
+    encoder_text = fmt.format_tokens(
+        model.TOKENIZER.convert_ids_to_tokens(encoder_input_ids[0])
+    )
+    decoder_text = fmt.format_tokens(
+        model.TOKENIZER.convert_ids_to_tokens(decoder_input_ids[0])
+    )
+
+    # get attention values for the input and output vectors
+    # using already generated input and output
+    attention_output = model.MODEL(
+        input_ids=encoder_input_ids,
+        decoder_input_ids=decoder_input_ids,
+        output_attentions=True,
+    )
+
+    # averaging attention across layers
+    averaged_attention = fmt.avg_attention(attention_output)
+
+    # format response text for clean output
+    response_text = fmt.format_output_text(decoder_text)
+    # setting placeholder for iFrame graphic
+    graphic = (
+        "<div style='text-align: center; font-family:arial;'><h4>Attention"
+        " Visualization doesn't support an interactive graphic.</h4></div>"
+    )
+    # creating marked text using markup_text function and attention
+    marked_text = markup_text(encoder_text, averaged_attention, variant="visualizer")
+
+    # returning response, graphic and marked text array
+    return response_text, graphic, marked_text, None
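The new module relies on fmt.avg_attention, which is not part of this commit. Assuming an encoder-decoder model whose forward pass returns cross_attentions as a tuple of (batch, heads, decoder_len, encoder_len) tensors, a minimal sketch of such a helper could look like the following; the tensor layout and the reduction order are assumptions, not the project's actual implementation.

import torch


def avg_attention(attention_output):
    # stack the per-layer cross-attention tensors: (layers, batch, heads, dec_len, enc_len)
    stacked = torch.stack(attention_output.cross_attentions)
    # average over layers, batch and heads: (dec_len, enc_len)
    averaged = stacked.mean(dim=(0, 1, 2))
    # collapse the decoder axis to get one score per input (encoder) token
    return averaged.mean(dim=0).detach().numpy()

Under that assumption, markup_text receives one value per encoder token, matching the encoder_text list built in chat_explained above.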
explanation/interpret_captum.py
CHANGED
@@ -46,8 +46,9 @@ def chat_explained(model, prompt):
     # getting response text, graphic placeholder and marked text object
     response_text = fmt.format_output_text(attribution_result.output_tokens)
     graphic = (
-        "<div style='text-align: center; font-family:arial;'><h4>
-
+        """<div style='text-align: center; font-family:arial;'><h4>
+        Intepretation with Captum doesn't support an interactive graphic.</h4></div>
+        """
     )
     marked_text = markup_text(input_tokens, values, variant="captum")

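The placeholder fix swaps a plain string literal for a triple-quoted one: only the latter may span multiple source lines inside the parentheses. A quick illustration with shortened placeholder strings (not the project's actual text):

# a plain "..." literal must end on the same line it starts on;
# adjacent literals are instead concatenated at compile time
graphic = (
    "<div><h4>Placeholder"
    " text</h4></div>"
)

# a triple-quoted literal may span lines directly, which is what the fix uses
graphic = (
    """<div><h4>Placeholder
    text</h4></div>
    """
)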
explanation/markup.py
CHANGED
@@ -71,11 +71,11 @@ def color_codes():
         "-4": "#68a1fd",
         "-3": "#96b7fe",
         "-2": "#bcceff",
-        "-1
+        "-1": "#dee6ff",
         "0": "#ffffff",
-        "1": "#ffd9d9",
-        "2": "#ffb3b5",
-        "3": "#ff8b92",
-        "4": "#ff5c71",
-        "5": "#ff0051",
+        "+1": "#ffd9d9",
+        "+2": "#ffb3b5",
+        "+3": "#ff8b92",
+        "+4": "#ff5c71",
+        "+5": "#ff0051",
     }
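The renamed keys ("+1" through "+5") imply that the lookup key carries an explicit sign. A hypothetical bucketing helper that produces such keys from a normalized attention or attribution value might look like this; the function, its name, and the scaling are assumptions, not part of the commit.

def value_to_color_key(value, max_abs):
    # hypothetical: scale a value in [-max_abs, max_abs] into integer buckets -5 ... +5
    bucket = max(-5, min(5, round(5 * value / max_abs))) if max_abs else 0
    # positive buckets need the explicit "+" to match the dictionary keys above
    return f"{bucket:+d}" if bucket != 0 else "0"

For example, value_to_color_key(0.42, 1.0) yields "+2", which color_codes() maps to "#ffb3b5".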
explanation/visualize_att.py
DELETED
File without changes