File size: 5,238 Bytes
00a6112
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import gradio as gr
import lime
from lime.lime_text import LimeTextExplainer
import numpy as np
from datasets import load_dataset
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline
import shap
import matplotlib.pyplot as plt
import io
from PIL import Image
import pandas as pd


# ---------------------------------------------------------------------------
# One-time setup, executed at import/startup: load data, fit the TF-IDF
# vectorizer and logistic-regression classifier, and build the LIME/SHAP
# explainers that explain_text() uses per request.
# NOTE(review): this runs at module import and downloads the full IMDB
# dataset — startup will be slow on first run.
# ---------------------------------------------------------------------------

# Load the IMDB dataset using Hugging Face datasets
dataset = load_dataset('imdb')

# Extract the training and test sets
text_train = [review['text'] for review in dataset['train']]
y_train = [review['label'] for review in dataset['train']]
# NOTE(review): text_test / y_test are vectorized below but never evaluated
# anywhere in this file — presumably kept for future evaluation; confirm.
text_test = [review['text'] for review in dataset['test']]
y_test = [review['label'] for review in dataset['test']]

# Convert the text data into a TF-IDF representation
# (vocabulary capped at 5,000 terms, English stop words removed)
vectorizer = TfidfVectorizer(stop_words='english', max_features=5000)
X_train = vectorizer.fit_transform(text_train)
X_test = vectorizer.transform(text_test)

# Split the training data into train and validation sets
# NOTE(review): X_val / y_val are never used after this split — confirm
# whether a validation step was intended.
X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.2, random_state=42)

# Train a logistic regression model
model = LogisticRegression(max_iter=1000)
model.fit(X_train, y_train)

# Initialize LIME explainer (label 0 = Negative, 1 = Positive)
lime_explainer = LimeTextExplainer(class_names=['Negative', 'Positive'])

# Create a SHAP explainer object; X_train serves as the background dataset
shap_explainer = shap.LinearExplainer(model, X_train)

def explain_text(input_text):
    """Predict the sentiment of *input_text* and build LIME + SHAP explanations.

    Parameters
    ----------
    input_text : str
        Raw review text entered in the Gradio textbox.

    Returns
    -------
    tuple
        ``(label_name, lime_img, shap_fig, lime_html, shap_highlighted_text)``:
        predicted class name ('Positive'/'Negative'), LIME chart as a NumPy
        image, SHAP chart as a NumPy image, LIME's standalone HTML
        explanation, and the input text with the top-|SHAP| words wrapped in
        coloured ``<span>`` tags.
    """
    # Vectorize once; reuse the vector for both prediction and SHAP.
    input_vector = vectorizer.transform([input_text])
    predicted_label = model.predict(input_vector)[0]
    label_name = 'Positive' if predicted_label == 1 else 'Negative'

    # LIME perturbs raw strings, so it needs a proba function over text.
    def predict_proba_for_lime(texts):
        return model.predict_proba(vectorizer.transform(texts))

    lime_exp = lime_explainer.explain_instance(input_text, predict_proba_for_lime, num_features=10)
    lime_fig = lime_exp.as_pyplot_figure()
    lime_img = fig_to_nparray(lime_fig)
    # Fix: close the figure once rendered; matplotlib keeps figures alive in
    # a global registry, so the app otherwise leaks one figure per request.
    plt.close(lime_fig)

    # Get the complete HTML for LIME explanation
    lime_html = lime_exp.as_html()

    # SHAP values for this single instance (row 0 of the returned matrix).
    shap_values = shap_explainer.shap_values(input_vector)[0]
    feature_names = vectorizer.get_feature_names_out()

    # Create a SHAP explanation object for the selected instance
    shap_explanation = shap.Explanation(
        values=shap_values,
        base_values=shap_explainer.expected_value,
        feature_names=feature_names,
        data=input_vector.toarray()[0]
    )

    # Maximum number of features shown in the chart and highlighted text.
    max_num_features = 10

    # Create a DataFrame for SHAP values
    shap_df = pd.DataFrame({
        'Feature': shap_explanation.feature_names,
        'SHAP Value': shap_explanation.values
    })
    # Fix: rank by |SHAP value| so strong *negative* contributions are shown
    # too. Sorting by the raw value descending only ever surfaced positive
    # features, contradicting the "Top 10 Feature Importance" title.
    order = shap_df['SHAP Value'].abs().sort_values(ascending=False).index
    shap_df = shap_df.loc[order].head(max_num_features)

    # Plot the SHAP values (red = pushes positive, blue = pushes negative).
    fig = plt.figure(figsize=(10, 6))
    plt.barh(shap_df['Feature'], shap_df['SHAP Value'],
             color=['red' if val > 0 else 'blue' for val in shap_df['SHAP Value']])
    plt.xlabel('SHAP Value')
    plt.title('Top 10 Feature Importance')
    plt.tight_layout()
    shap_fig = fig_to_nparray(fig)
    plt.close(fig)  # same per-request figure-leak fix as above

    # Highlight the text based on SHAP values
    shap_highlighted_text = _highlight_text_shap(input_text, shap_values, feature_names, max_num_features)

    return label_name, lime_img, shap_fig, lime_html, shap_highlighted_text


def _highlight_text_shap(text, word_importances, feature_names, max_num_features):
    """Wrap the top-|SHAP| words of *text* in coloured ``<span>`` tags.

    red = contributes toward 'Positive', blue = toward 'Negative'.
    """
    words = text.split()
    # Fix: the original tested `word in text.lower()` — a *substring* match —
    # so vocabulary terms like "act" matched inside "actor" and crowded the
    # top-k with features that never actually appear. Match whole cleaned
    # tokens instead, using the same cleaning as the highlight pass below.
    present = {''.join(filter(str.isalnum, w)).lower() for w in words}
    word_to_importance = {
        feat: word_importances[idx]
        for idx, feat in enumerate(feature_names)
        if feat in present
    }

    top_words = dict(sorted(word_to_importance.items(),
                            key=lambda kv: abs(kv[1]),
                            reverse=True)[:max_num_features])

    highlighted_text = []
    for word in words:
        cleaned_word = ''.join(filter(str.isalnum, word)).lower()
        if cleaned_word in top_words:
            color = 'red' if top_words[cleaned_word] > 0 else 'blue'
            highlighted_text.append(f'<span style="color:{color}">{word}</span>')
        else:
            highlighted_text.append(word)

    return ' '.join(highlighted_text)

def fig_to_nparray(fig):
    """Render a matplotlib figure to a PNG and return it as a NumPy array.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        The figure to rasterize. The figure is NOT closed here; callers
        own its lifetime.

    Returns
    -------
    numpy.ndarray
        The decoded image pixels (H x W x channels, uint8).
    """
    # Fix: the buffer was never closed and PIL's Image.open is lazy — the
    # decode only succeeded because the buffer happened to stay referenced.
    # Force the full decode+copy via np.array() while the buffer is open,
    # then let the context manager release it.
    with io.BytesIO() as buf:
        fig.savefig(buf, format='png')
        buf.seek(0)
        arr = np.array(Image.open(buf))
    return arr


# ---------------------------------------------------------------------------
# Gradio UI: one textbox in, five outputs out. The output list order must
# match the tuple returned by explain_text():
#   (label_name, lime_img, shap_fig, lime_html, shap_highlighted_text)
# ---------------------------------------------------------------------------

# Create Gradio interface
iface = gr.Interface(
    fn=explain_text,
    inputs=gr.Textbox(lines=2, placeholder="Enter your text here..."),
    outputs=[
        gr.Label(label="Predicted Label"),
        gr.Image(type="numpy", label="LIME Explanation"),
        gr.Image(type="numpy", label="SHAP Explanation"),
        gr.HTML(label="LIME Highlighted Text Explanation"),
        gr.HTML(label="SHAP Highlighted Text Explanation"),
    ],
    title="LIME and SHAP Explanations",
    description="Enter a text sample to see its prediction and explanations using LIME and SHAP."
)

# Launch the interface (blocks here and serves the app)
# NOTE(review): consider guarding with `if __name__ == "__main__":` so the
# module can be imported without starting a server.
iface.launch()