File size: 8,772 Bytes
a45605a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
906e11b
 
 
 
a45605a
 
 
 
 
7581cc1
a45605a
 
906e11b
a45605a
7581cc1
a45605a
 
 
 
 
 
 
7581cc1
 
a45605a
7581cc1
 
a45605a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7581cc1
a45605a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7581cc1
906e11b
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
"""XAI for Transformers Intent Classifier App."""

from collections import Counter
from itertools import count
from operator import itemgetter
from re import DOTALL, sub

import streamlit as st
from plotly.express import bar
from transformers import (AutoModelForSequenceClassification, AutoTokenizer,
                          pipeline)
from transformers_interpret import SequenceClassificationExplainer

# CSS injected into the page to hide Streamlit's main menu and footer.
hide_streamlit_style = """
            <style>
            #MainMenu {visibility: hidden;}
            footer {visibility: hidden;}
            </style>
            """
# Plotly config passed to st.plotly_chart to hide the chart mode bar.
hide_plotly_bar = {"displayModeBar": False}

st.markdown(hide_streamlit_style, unsafe_allow_html=True)

# Hugging Face Hub repository of the classifier and the pipeline task to build.
repo_id = "remzicam/privacy_intent"
task = "text-classification"

# Centered page heading.
title = "XAI for Intent Classification and Model Interpretation"
st.markdown(
    f"<h1 style='text-align: center; color: #0068C9;'>{title}</h1>",
    unsafe_allow_html=True,
)


@st.cache(allow_output_mutation=True, suppress_st_warning=True)
def load_models():
    """
    Build the classification pipeline and the explainer used by the app.

    The model and tokenizer are fetched from the Hugging Face Hub repository
    named by the module-level ``repo_id``. They are wrapped in a ``task``
    pipeline that returns the score of every label, and in a
    ``SequenceClassificationExplainer`` for token attributions. The function
    is decorated with ``st.cache`` so Streamlit reuses the result across reruns.

    Returns:
      A ``(privacy_intent_pipe, cls_explainer)`` tuple.
    """
    classifier_model = AutoModelForSequenceClassification.from_pretrained(
        repo_id, low_cpu_mem_usage=True
    )
    classifier_tokenizer = AutoTokenizer.from_pretrained(repo_id)

    scoring_pipe = pipeline(
        task,
        model=classifier_model,
        tokenizer=classifier_tokenizer,
        return_all_scores=True,
    )
    explainer = SequenceClassificationExplainer(classifier_model, classifier_tokenizer)

    return scoring_pipe, explainer


privacy_intent_pipe, cls_explainer = load_models()


def label_probs_figure_creater(input_text: str):
    """
    Classify ``input_text`` and chart the score the model gives each label.

    Args:
      input_text (str): The sentence to classify.

    Returns:
      A ``(figure, label)`` tuple:
        - a Plotly bar chart of every label's score, with rows in ascending score order;
        - the highest-scoring label.
    """
    # The pipeline is built with return_all_scores=True; element 0 is the list
    # of {"label", "score"} dicts for this single input.
    label_scores = privacy_intent_pipe(input_text)[0]
    ranked = sorted(label_scores, key=itemgetter("score"))
    top_label = ranked[-1]["label"]

    fig = bar(
        ranked,
        x="score",
        y="label",
        color="score",
        color_continuous_scale="rainbow",
        width=600,
        height=400,
    )

    # Axis styling: black axis lines with grid; different tick fonts per axis.
    x_axis_style = dict(
        showline=True,
        showgrid=True,
        linecolor="black",
        tickfont=dict(family="Calibri"),
    )
    y_axis_style = dict(
        showline=True,
        showgrid=True,
        linecolor="black",
        tickfont=dict(family="Times New Roman"),
    )
    # Keyword order kept as-is: title must be applied before title_x.
    fig.update_layout(
        title="Model Prediction Probabilities for Each Label",
        xaxis_title="",
        yaxis_title="",
        xaxis=x_axis_style,
        yaxis=y_axis_style,
        plot_bgcolor="white",
        title_x=0.5,
    )
    return fig, top_label


def xai_attributions_html(input_text: str):
    """
    Explain the model's prediction for ``input_text`` token by token.

    Steps:
      1. Run the module-level ``cls_explainer`` on the text to get
         (token, attribution score) pairs.
      2. Drop the first and last pairs (the special tokens around the
         sequence) and every token that is one character long or shorter.
      3. Take the explainer's HTML visualization and trim it:
         - delete the "#s" / "#/s" markers;
         - remove the first four ``<th>...</th>`` cells;
         - remove the first four ``<td>...</td>`` cells.

    The HTML comes from the explainer itself, so the filtering in step 2
    does not affect it.

    Args:
      input_text (str): The text you want to explain.

    Returns:
      A ``(word_attributions, html)`` tuple:
        - the filtered list of (token, score) pairs;
        - the trimmed HTML string, followed by ``<br>``.
    """
    word_attributions = cls_explainer(input_text)
    # Remove the special tokens at the start and end of the sequence.
    word_attributions = word_attributions[1:-1]
    # Keep only tokens longer than one character; this drops single-character
    # tokens such as punctuation.
    word_attributions = [pair for pair in word_attributions if len(pair[0]) > 1]

    html = cls_explainer.visualize().data
    # NOTE(review): "#s" / "#/s" look like the rendered sequence start/end
    # tokens — confirm for the tokenizer in use.
    html = html.replace("#s", "").replace("#/s", "")
    # Strip the leading table cells, presumably leaving only the word-importance
    # column — confirm against the installed transformers_interpret version.
    # count/flags are passed by keyword: passing them positionally to re.sub
    # is deprecated since Python 3.13.
    html = sub(r"<th.*?/th>", "", html, count=4, flags=DOTALL)
    html = sub(r"<td.*?/td>", "", html, count=4, flags=DOTALL)
    return word_attributions, html + "<br>"


def explanation_intro(prediction_label: str):
    """
    Return the HTML banner that introduces the token-level explanation.

    Args:
      prediction_label (str): The label the model predicted. It is shown,
        quoted and in bold, inside the banner.

    Returns:
      An HTML string (a light-blue ``<div>`` followed by ``<br>``), rendered by
      the app with ``st.markdown(..., unsafe_allow_html=True)``.
    """
    banner_lines = (
        '<div style="background-color: lightblue;',
        "  color: rgb(0, 66, 128);\">The model predicted the given sentence as "
        f"<span style=\"color: black\"><strong>'{prediction_label}'</strong></span>.",
        "    The figure below shows the contribution of each token to this decision.",
        '    <span style="color: darkgreen"><strong> Green </strong></span> tokens indicate a <strong>positive </strong> contribution, while <span style="color: red"><strong> red </strong></span> tokens indicate a <strong>negative</strong> contribution.',
        "    The <strong>bolder</strong> the color, the greater the value.</div><br>",
    )
    return "\n".join(banner_lines)


def explanation_viz(prediction_label: str, word_attributions):
    """
    Build a one-line markdown summary naming the most influential token.

    Args:
      prediction_label (str): The label the model predicted.
      word_attributions: A sequence of (word, attribution score) tuples.

    Returns:
      A markdown string naming the token with the highest attribution score
      and the predicted label. When several tokens share the top score, the
      first one is named.

    Raises:
      ValueError: if ``word_attributions`` is empty (from ``max``).
    """
    strongest_pair = max(word_attributions, key=lambda pair: pair[1])
    driver_token = strongest_pair[0]
    return (
        f"The token **_'{driver_token}'_** is the biggest driver for the decision "
        f"of the model as **'{prediction_label}'**"
    )


def word_attributions_dict_creater(word_attributions):
    """
    Turn (word, score) pairs into the column dict used by the attribution bar chart.

    The pairs are processed as follows:
      - They are taken in reverse order, without modifying the caller's list.
      - Each score is coloured "red" if negative and "lightgreen" otherwise.
      - The position of the largest score is coloured "darkgreen".
      - Words occurring more than once get a running "_1", "_2", ... suffix,
        so every bar label is unique.

    Args:
      word_attributions: A sequence of (word, attribution score) tuples, such
        as the list returned by xai_attributions_html.

    Returns:
      A dict with three keys, all in the same (reversed) order:
        - "word": list of str;
        - "score": tuple of scores;
        - "colors": list of colour names.
      Each value is empty when ``word_attributions`` is empty.
    """
    if not word_attributions:
        # zip(*[]) cannot be unpacked into two names; return empty columns instead.
        return {"word": [], "score": (), "colors": []}

    # Reversed copy: the previous list.reverse() mutated the caller's list.
    reversed_pairs = word_attributions[::-1]
    words, scores = zip(*reversed_pairs)

    # Colourise positive and negative scores.
    colors = ["red" if score < 0 else "lightgreen" for score in scores]
    # The largest score gets a darker green (even when it is negative).
    colors[scores.index(max(scores))] = "darkgreen"

    # Number duplicated words so each bar gets its own label.
    occurrences = Counter(words)
    counters = {word: count(1) for word, n in occurrences.items() if n > 1}
    unique_words = [
        word + "_" + str(next(counters[word])) if word in counters else word
        for word in words
    ]

    # plotly.express accepts a dict of columns.
    return {
        "word": unique_words,
        "score": scores,
        "colors": colors,
    }


def attention_score_figure_creater(word_attributions_dict):
    """
    Draw a bar chart of per-word attribution scores with per-bar colours.

    Args:
      word_attributions_dict: A dict with keys "word", "score" and "colors",
        such as the one returned by word_attributions_dict_creater.

    Returns:
      A Plotly figure with scores on the x axis and words on the y axis.
    """
    figure = bar(word_attributions_dict, x="score", y="word", width=400, height=500)
    # One colour per bar, in the same order as the "word"/"score" entries.
    figure.update_traces(marker_color=word_attributions_dict["colors"])

    # Settings shared by both axes; only the tick font differs.
    axis_lines = dict(showline=True, showgrid=True, linecolor="black")
    figure.update_layout(
        title="Word-Attention Score",
        xaxis_title="",
        yaxis_title="",
        xaxis=dict(axis_lines, tickfont=dict(family="Calibri")),
        yaxis=dict(axis_lines, tickfont=dict(family="Times New Roman")),
        plot_bgcolor="white",
        title_x=0.5,
    )
    return figure


# Input form: a text area pre-filled with an example sentence and a submit button.
form = st.form(key="intent-form")
input_text = form.text_area(
    label="Text",
    value="At any time during your use of the Services, you may decide to share some information or content publicly or privately.",
)
submit = form.form_submit_button("Submit")

# Blank input is rejected up front instead of being sent to the model.
if submit and not input_text.strip():
    st.warning("Please enter some text to analyze.")
elif submit:
    label_probs_figure, prediction_label = label_probs_figure_creater(input_text)
    st.plotly_chart(label_probs_figure, config=hide_plotly_bar)
    explanation_general = explanation_intro(prediction_label)
    st.markdown(explanation_general, unsafe_allow_html=True)
    with st.spinner():
        word_attributions, html = xai_attributions_html(input_text)
        st.markdown(html, unsafe_allow_html=True)
        # The filtered list can be empty: xai_attributions_html drops the
        # first/last entries and every token of length <= 1. Two helpers fail
        # on an empty list:
        #   - explanation_viz, whose max() raises on it;
        #   - the dict builder, which unpacks zip(*...).
        if word_attributions:
            explanation_specific = explanation_viz(prediction_label, word_attributions)
            st.info(explanation_specific)
            word_attributions_dict = word_attributions_dict_creater(word_attributions)
            attention_score_figure = attention_score_figure_creater(word_attributions_dict)
            st.plotly_chart(attention_score_figure, config=hide_plotly_bar)
        else:
            st.warning(
                "No tokens longer than one character were found, so there is "
                "no per-token explanation to show."
            )