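# Note: this Space is assumed to ship with a requirements.txt listing the packages
# imported below (gradio, lime, shap, scikit-learn, datasets, matplotlib, pandas,
# numpy, Pillow); exact version pins are not specified here.
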
import gradio as gr
from lime.lime_text import LimeTextExplainer
import numpy as np
from datasets import load_dataset
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
import shap
import matplotlib.pyplot as plt
import io
from PIL import Image
import pandas as pd

# Load the IMDB dataset using Hugging Face datasets
dataset = load_dataset('imdb')

# Extract the training and test sets
text_train = [review['text'] for review in dataset['train']]
y_train = [review['label'] for review in dataset['train']]
text_test = [review['text'] for review in dataset['test']]
y_test = [review['label'] for review in dataset['test']]

# Convert the text data into a TF-IDF representation
vectorizer = TfidfVectorizer(stop_words='english', max_features=5000)
X_train = vectorizer.fit_transform(text_train)
X_test = vectorizer.transform(text_test)

# Split the training data into train and validation sets
X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.2, random_state=42)
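# Note: the validation split (X_val, y_val) is not used further below; it could be
# used to sanity-check the classifier, e.g. model.score(X_val, y_val).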

# Train a logistic regression model
model = LogisticRegression(max_iter=1000)
model.fit(X_train, y_train)

# Initialize LIME explainer
lime_explainer = LimeTextExplainer(class_names=['Negative', 'Positive'])

# Create a SHAP explainer object for the linear model, using the training data as background
shap_explainer = shap.LinearExplainer(model, X_train)
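# Note: on newer shap releases the background data may need to be wrapped in a masker,
# e.g. shap.LinearExplainer(model, shap.maskers.Independent(X_train)); passing X_train
# directly is the older calling convention.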


def explain_text(input_text):
    # Predict label
    input_vector = vectorizer.transform([input_text])
    predicted_label = model.predict(input_vector)[0]
    label_name = 'Positive' if predicted_label == 1 else 'Negative'

    # LIME explanation
    def predict_proba_for_lime(texts):
        return model.predict_proba(vectorizer.transform(texts))

    lime_exp = lime_explainer.explain_instance(input_text, predict_proba_for_lime, num_features=10)
    lime_fig = lime_exp.as_pyplot_figure()
    lime_img = fig_to_nparray(lime_fig)
    plt.close(lime_fig)  # free the figure once it has been rasterized

    # Get the complete HTML for LIME explanation
    lime_html = lime_exp.as_html()

    # SHAP explanation (values for the single input row)
    shap_values = shap_explainer.shap_values(input_vector)[0]
    feature_names = vectorizer.get_feature_names_out()

    # Create a SHAP explanation object for the selected instance
    shap_explanation = shap.Explanation(
        values=shap_values,
        base_values=shap_explainer.expected_value,
        feature_names=feature_names,
        data=input_vector.toarray()[0]
    )

    # Function to highlight text based on SHAP values
    def highlight_text_shap(text, word_importances, feature_names, max_num_features):
        words = text.split()
        # Map each vocabulary feature that actually appears as a token in the text
        # to its SHAP value (exact token match, not substring match)
        tokens_in_text = {''.join(filter(str.isalnum, w)).lower() for w in words}
        word_to_importance = {}
        for idx, word in enumerate(feature_names):
            if word in tokens_in_text:
                word_to_importance[word] = word_importances[idx]
        sorted_word_importance = sorted(word_to_importance.items(), key=lambda x: abs(x[1]), reverse=True)[:max_num_features]
        top_words = dict(sorted_word_importance)

        highlighted_text = []
        for word in words:
            cleaned_word = ''.join(filter(str.isalnum, word)).lower()
            if cleaned_word in top_words:
                importance = top_words[cleaned_word]
                color = 'red' if importance > 0 else 'blue'
                highlighted_text.append(f'<span style="color:{color}">{word}</span>')
            else:
                highlighted_text.append(word)
        return ' '.join(highlighted_text)

    # Set the maximum number of features to display
    max_num_features = 10

    # Create a DataFrame for SHAP values and keep the most influential features
    # by absolute value, so both positive (red) and negative (blue) contributions appear
    shap_df = pd.DataFrame({
        'Feature': shap_explanation.feature_names,
        'SHAP Value': shap_explanation.values
    })
    order = shap_df['SHAP Value'].abs().sort_values(ascending=False).index
    shap_df = shap_df.loc[order].head(max_num_features)

    # Plot the SHAP values
    plt.figure(figsize=(10, 6))
    plt.barh(shap_df['Feature'], shap_df['SHAP Value'], color=['red' if val > 0 else 'blue' for val in shap_df['SHAP Value']])
    plt.xlabel('SHAP Value')
    plt.title(f'Top {max_num_features} Feature Importance')
    plt.tight_layout()
    shap_fig = fig_to_nparray(plt.gcf())
    plt.close()

    # Highlight the text based on SHAP values
    shap_highlighted_text = highlight_text_shap(input_text, shap_values, feature_names, max_num_features)

    return label_name, lime_img, shap_fig, lime_html, shap_highlighted_text


def fig_to_nparray(fig):
    """Convert a matplotlib figure to a NumPy array."""
    buf = io.BytesIO()
    fig.savefig(buf, format='png')
    buf.seek(0)
    img = Image.open(buf)
    return np.array(img)


# Create Gradio interface
iface = gr.Interface(
    fn=explain_text,
    inputs=gr.Textbox(lines=2, placeholder="Enter your text here..."),
    outputs=[
        gr.Label(label="Predicted Label"),
        gr.Image(type="numpy", label="LIME Explanation"),
        gr.Image(type="numpy", label="SHAP Explanation"),
        gr.HTML(label="LIME Highlighted Text Explanation"),
        gr.HTML(label="SHAP Highlighted Text Explanation"),
    ],
    title="LIME and SHAP Explanations",
    description="Enter a text sample to see its prediction and explanations using LIME and SHAP."
)

# Launch the interface
iface.launch()