|
import os

# Cache locations must be configured BEFORE importing the Hugging Face / ML
# libraries below: huggingface_hub, datasets and transformers read these
# variables into module-level constants when they are imported (matplotlib
# likewise resolves its cache dir from XDG_CACHE_HOME on import). Assigning
# them after the imports leaves the libraries using their default locations.
os.environ["HF_HOME"] = "/tmp/hf"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf/transformers"
os.environ["TORCH_HOME"] = "/tmp/torch"
os.environ["HF_DATASETS_CACHE"] = "/tmp/hf/datasets"
os.environ["XDG_CACHE_HOME"] = "/tmp/xdg"

import gzip  # noqa: E402
import json  # noqa: E402

from datasets import load_dataset  # noqa: E402
import requests  # noqa: E402
import streamlit as st  # noqa: E402
import pandas as pd  # noqa: E402
from sklearn.metrics.pairwise import cosine_similarity  # noqa: E402
from model2vec import StaticModel  # noqa: E402
import matplotlib.pyplot as plt  # noqa: E402
from transformers import AutoTokenizer, AutoModelForSequenceClassification  # noqa: E402
import torch  # noqa: E402
from torch.nn.functional import sigmoid  # noqa: E402
from transformers_interpret import SequenceClassificationExplainer  # noqa: E402
|
|
|
|
|
|
|
# IAB content categories. A label's position in this list is the model output
# index it corresponds to (predict_content_multilabel maps index -> name via id2label).
LABELS = [
    "inconclusive",
    "animals",
    "arts",
    "autos",
    "business",
    "career",
    "education",
    "fashion",
    "finance",
    "food",
    "government",
    "health",
    "hobbies",
    "home",
    "news",
    "realestate",
    "society",
    "sports",
    "tech",
    "travel",
]

# Lookup tables in both directions between label name and index.
label2id = {name: position for position, name in enumerate(LABELS)}
id2label = dict(enumerate(LABELS))

# Hugging Face Hub dataset repository holding the reference training data.
REPO_ID = "chidamnat2002/iab_training_dataset"
|
|
|
@st.cache_data
def load_csv_data():
    """Fetch ``train_df_simple.csv`` from the ``REPO_ID`` dataset repo as a DataFrame.

    The ``train`` split is loaded with ``datasets.load_dataset`` and converted
    with the ``pandas.DataFrame`` constructor. Results are cached by
    ``st.cache_data`` across Streamlit reruns.
    """
    return pd.DataFrame(
        load_dataset(REPO_ID, split="train", data_files="train_df_simple.csv")
    )
|
|
|
@st.cache_resource
def get_model_and_tokenizer():
    """Load the multilabel IAB classifier and its tokenizer from the Hub.

    Both come from the same checkpoint. ``st.cache_resource`` keeps the loaded
    objects around so reruns reuse them.

    Returns:
        tuple: ``(model, tokenizer)``; note the model comes first.
    """
    checkpoint = "chidamnat2002/content-multilabel-iab-classifier"
    tok = AutoTokenizer.from_pretrained(checkpoint)
    clf = AutoModelForSequenceClassification.from_pretrained(checkpoint)
    return clf, tok
|
|
|
@st.cache_resource
def get_explainer():
    """Wrap the cached classifier and tokenizer in a ``SequenceClassificationExplainer``.

    The explainer itself is cached with ``st.cache_resource`` so it is built once.
    """
    clf, tok = get_model_and_tokenizer()
    explainer = SequenceClassificationExplainer(clf, tok)
    return explainer
|
|
|
|
|
@st.cache_resource
def load_model():
    """Return the model2vec static embedding model used for similarity search.

    Loaded from ``minishlab/potion-retrieval-32M`` once and cached with
    ``st.cache_resource``.
    """
    embedder = StaticModel.from_pretrained("minishlab/potion-retrieval-32M")
    return embedder
|
|
|
|
|
@st.cache_resource
def encode_texts_cached(corpus):
    """Embed every text in ``corpus`` with the static embedding model.

    The result is cached by ``st.cache_resource``, keyed on ``corpus``.
    """
    embedder = load_model()
    embeddings = embedder.encode(corpus)
    return embeddings
|
|
|
@st.cache_data(show_spinner="Embedding reference", max_entries=50)
def encode_reference(text: str):
    """Embed one reference string and return its embedding.

    The text is encoded as a one-element batch and the first row is returned.
    Up to 50 distinct inputs are kept in the ``st.cache_data`` cache.
    """
    batch = load_model().encode([text])
    return batch[0]
|
|
|
@st.cache_resource
def get_data_and_embeddings():
    """Load the training data and embed all of its texts.

    Returns:
        tuple: ``(texts, prior_labels, X)``, where ``texts`` and ``prior_labels``
        are the ``text`` and ``label`` columns as Python lists, and ``X`` holds the
        embeddings ``encode_texts_cached`` produces for ``texts``.
    """
    frame = load_csv_data()
    corpus = frame["text"].tolist()
    labels = frame["label"].tolist()
    return corpus, labels, encode_texts_cached(corpus)
|
|
|
|
|
# Configure the page (title, wide layout) before rendering any content.
st.set_page_config(page_title="IAB Classifier App", layout="wide")
st.title("IAB Classifier App")

# Reference corpus: training texts, their labels and their embeddings.
# get_data_and_embeddings is wrapped in st.cache_resource, so the load/embedding
# work is reused across reruns.
texts, prior_labels, X = get_data_and_embeddings()

st.markdown("### Reference sentence for similarity")
# Free-text query; while it is empty the similarity/prediction sections are skipped.
reference = st.text_area("Type something like 'business related'")
# Opt-in toggle for running the IAB classifier (and explainer) on the reference text.
prediction_choice = st.checkbox("try our iab model prediction for this")
|
|
|
|
|
def predict_content_multilabel(text, threshold=0.5, verbose=False):
    """Classify ``text`` with the multilabel IAB model and return confident labels.

    Each class gets an independent sigmoid probability, so any number of labels,
    including none, can pass the threshold.

    Args:
        text: Input text. Hyphens are replaced by spaces before tokenization.
        threshold: Minimum probability (inclusive) for a label to be returned.
        verbose: If True, also write the preprocessed text and the predicted
            labels to the Streamlit page.

    Returns:
        list[tuple[str, float]]: ``(label, probability)`` pairs in class-index
        order, with each probability rounded to 3 decimals.
    """
    model, tokenizer = get_model_and_tokenizer()
    model.eval()

    text = text.replace("-", " ")

    with torch.no_grad():
        inputs = tokenizer(
            text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=256,
        )
        logits = model(**inputs).logits
        # A single string is tokenized as a batch of one, so take row 0
        # explicitly. squeeze() would collapse a single-label output to a 0-d
        # array that cannot be enumerated.
        probs = sigmoid(logits)[0].cpu().numpy()

    predicted_labels = [
        (id2label[i], round(float(p), 3))
        for i, p in enumerate(probs)
        if p >= threshold
    ]

    if verbose:
        st.write(f"Text: {text}")
        st.write("Predicted Labels:")
        st.write(predicted_labels)

    return predicted_labels
|
|
|
if reference:
    st.write("Labels loaded:", len(prior_labels))

    # Score every training example against the reference sentence.
    query = encode_reference(reference)
    similarity = cosine_similarity([query], X)[0]

    df_emb = pd.DataFrame({
        "text": texts,
        "sim": similarity,
        "label": prior_labels,
    }).sort_values("sim", ascending=False)

    top_size = st.slider("number of similar items", 1, 100, 5)

    # Slice the top rows directly instead of converting the whole sorted corpus
    # to a list of dicts. rename/reset_index give the displayed table the
    # column names and 0..k-1 index the user sees.
    top_labelled_df = (
        df_emb.head(top_size)
        .rename(columns={"sim": "similarity_score"})
        .reset_index(drop=True)
    )

    st.markdown("### Similar example(s)")
    if top_labelled_df.empty:
        st.info("No more similar examples.")
    else:
        st.write(top_labelled_df)

    # Run the classifier only when the user opted in via the checkbox; an
    # empty dict means "no prediction requested / nothing above threshold".
    preds = {}
    if prediction_choice:
        preds = dict(predict_content_multilabel(reference, threshold=0.2))
        st.write(f"preds = {preds}")

    col1, col2 = st.columns(2)

    with col1:
        st.markdown("#### What Training Data Says")
        if top_labelled_df.empty:
            st.write("No similar examples to summarise.")
        else:
            fig1, ax1 = plt.subplots()
            label_share = (
                top_labelled_df["label"].value_counts(normalize=True).sort_values()
            )
            label_share.plot(kind="barh", ax=ax1, color="lightcoral")
            ax1.set_title("Label Distribution")
            ax1.set_xlabel("Proportion")
            ax1.grid(True, axis="x", linestyle="--", alpha=0.5)
            st.pyplot(fig1)
            # Release the pyplot-managed figure so reruns don't accumulate figures.
            plt.close(fig1)

    with col2:
        st.markdown("#### Model Predictions")
        if not prediction_choice:
            st.write("Enable model prediction above to see results.")
        elif not preds:
            st.write("Model is unsure")
        else:
            fig2, ax2 = plt.subplots()
            pd.Series(preds).sort_values().plot.barh(color="skyblue", ax=ax2)
            ax2.set_title("Predicted Probabilities")
            ax2.set_xlabel("Probability")
            ax2.grid(True, axis="x", linestyle="--", alpha=0.5)
            st.pyplot(fig2)
            plt.close(fig2)
|
|
|
if prediction_choice and reference:
    st.markdown("### Model Explanation (Top Predicted Class)")

    explainer = get_explainer()
    # Explain the same normalised text that predict_content_multilabel classifies
    # (hyphens -> spaces), so the explanation matches the charted prediction.
    attributions = explainer(reference.replace("-", " "))

    st.markdown(f"**Predicted label:** `{explainer.predicted_class_name}`")

    # attributions is iterated as (token, score) pairs; split into two tuples.
    tokens, scores = zip(*attributions)

    fig, ax = plt.subplots(figsize=(12, 1.5))
    ax.bar(range(len(scores)), scores)
    ax.set_xticks(range(len(tokens)))
    ax.set_xticklabels(tokens, rotation=90)
    ax.set_ylabel("Attribution Score")
    ax.set_title("Token Attribution (Integrated Gradients)")
    st.pyplot(fig)
    # Release the pyplot-managed figure so reruns don't accumulate figures.
    plt.close(fig)

    st.markdown("#### Highlighted Text Importance")
    html_output = explainer.visualize().data

    # NOTE(review): this HTML is built from tokens of user-typed text and is
    # rendered with unsafe_allow_html=True. Confirm the visualizer escapes token
    # text before exposing this app to untrusted users.
    st.markdown(html_output, unsafe_allow_html=True)
|
|