# iab_content_classifier/src/streamlit_app.py
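"""Streamlit app for IAB content classification.

Given a reference sentence, the app retrieves the most similar examples from the
training dataset (model2vec embeddings + cosine similarity) and, optionally, runs a
multilabel IAB classifier and shows token-level attributions via transformers-interpret.
"""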
import os

# Force all caches to writable directories.
# These must be set before importing transformers/datasets, which read them at import time.
os.environ["HF_HOME"] = "/tmp/hf"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf/transformers"
os.environ["TORCH_HOME"] = "/tmp/torch"
os.environ["HF_DATASETS_CACHE"] = "/tmp/hf/datasets"
os.environ["XDG_CACHE_HOME"] = "/tmp/xdg"  # For anything else using XDG

from datasets import load_dataset
import requests
import gzip
import json
import streamlit as st
import pandas as pd
from sklearn.metrics.pairwise import cosine_similarity
from model2vec import StaticModel
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
from torch.nn.functional import sigmoid
from transformers_interpret import SequenceClassificationExplainer

# -- SETTINGS --
LABELS = [
    'inconclusive',
    'animals',
    'arts',
    'autos',
    'business',
    'career',
    'education',
    'fashion',
    'finance',
    'food',
    'government',
    'health',
    'hobbies',
    'home',
    'news',
    'realestate',
    'society',
    'sports',
    'tech',
    'travel',
]
label2id = {label: idx for idx, label in enumerate(LABELS)}
id2label = {idx: label for label, idx in label2id.items()}
REPO_ID = "chidamnat2002/iab_training_dataset"
@st.cache_data
def load_csv_data():
    dataset = load_dataset(REPO_ID, split="train", data_files="train_df_simple.csv")
    df = pd.DataFrame(dataset)
    return df

@st.cache_resource
def get_model_and_tokenizer():
    tokenizer = AutoTokenizer.from_pretrained("chidamnat2002/content-multilabel-iab-classifier")
    model = AutoModelForSequenceClassification.from_pretrained("chidamnat2002/content-multilabel-iab-classifier")
    return model, tokenizer

@st.cache_resource
def get_explainer():
    model, tokenizer = get_model_and_tokenizer()
    return SequenceClassificationExplainer(model, tokenizer)

# -- LOAD MODEL & EMBEDDINGS --
@st.cache_resource
def load_model():
    return StaticModel.from_pretrained("minishlab/potion-retrieval-32M")

# st.markdown("### ✨ Encode all examples")
@st.cache_resource
def encode_texts_cached(corpus):
    model = load_model()  # use cached model
    return model.encode(corpus)

@st.cache_data(show_spinner="Embedding reference", max_entries=50)
def encode_reference(text: str):
    model = load_model()
    return model.encode([text])[0]

@st.cache_resource
def get_data_and_embeddings():
    df = load_csv_data()
    texts = df["text"].to_list()
    prior_labels = df['label'].to_list()
    X = encode_texts_cached(texts)
    return texts, prior_labels, X

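# -- STREAMLIT UI --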
st.set_page_config(page_title="IAB Classifier App", layout="wide")
st.title("IAB Classifier App")
# Load data
texts, prior_labels, X = get_data_and_embeddings()
st.markdown("### Reference sentence for similarity")
reference = st.text_area("Type something like 'business related'")
prediction_choice = st.checkbox("Try the IAB model prediction for this text")
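
# Multilabel prediction: apply a sigmoid to the classifier logits and keep every
# label whose probability clears the threshold (rather than a single argmax class).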
def predict_content_multilabel(text, threshold=0.5, verbose=False):
    model, tokenizer = get_model_and_tokenizer()
    model.eval()
    text = text.replace("-", " ")
    with torch.no_grad():
        inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=256)
        logits = model(**inputs).logits
        probs = sigmoid(logits).squeeze().cpu().numpy()
    predicted_labels = [(id2label[i], round(float(p), 3)) for i, p in enumerate(probs) if p >= threshold]
    if verbose:
        st.write(f"Text: {text}")
        st.write(f"Predicted Labels: {predicted_labels}")
    return predicted_labels

if reference:
    st.write("Labels loaded:", len(prior_labels))
    query = encode_reference(reference)
    similarity = cosine_similarity([query], X)[0]
    df_emb = pd.DataFrame({
        "text": texts,
        "sim": similarity,
        "label": prior_labels,
    }).sort_values("sim", ascending=False)

    top_size = st.slider("Number of similar items", 1, 100, 5)
    top_candidates = [(row["text"], row["sim"], row["label"]) for row in df_emb.to_dict(orient="records")][:top_size]

    st.markdown("### Similar example(s)")
    if not top_candidates:
        st.info("No similar examples found.")
    else:
        top_labelled_df = pd.DataFrame(top_candidates, columns=['text', 'similarity_score', 'label'])
        st.write(top_labelled_df)

    preds = dict(predict_content_multilabel(reference, threshold=0.2))
    st.write(f"preds = {preds}")
    col1, col2 = st.columns(2)

    # Left: What training data says
    with col1:
        st.markdown("#### What Training Data Says")
        fig1, ax1 = plt.subplots()
        top_labelled_df['label'].value_counts(normalize=True).sort_values().plot(kind='barh', ax=ax1, color="lightcoral")
        ax1.set_title("Label Distribution")
        ax1.set_xlabel("Proportion")
        ax1.grid(True, axis='x', linestyle='--', alpha=0.5)
        st.pyplot(fig1)

    # Right: What model predicts
    with col2:
        st.markdown("#### Model Predictions")
        if len(preds) == 0 or not prediction_choice:
            st.write("Model is unsure")
        else:
            fig2, ax2 = plt.subplots()
            pd.Series(preds).sort_values().plot.barh(color="skyblue", ax=ax2)
            ax2.set_title("Predicted Probabilities")
            ax2.set_xlabel("Probability")
            ax2.grid(True, axis='x', linestyle='--', alpha=0.5)
            st.pyplot(fig2)
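
# -- EXPLAINABILITY --
# Token-level attributions (integrated gradients via transformers-interpret) for the
# explainer's top predicted class.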
if prediction_choice and reference:
    st.markdown("### Model Explanation (Top Predicted Class)")
    explainer = get_explainer()
    attributions = explainer(reference)
    st.markdown(f"**Predicted label:** `{explainer.predicted_class_name}`")

    # Token importance bar chart
    fig, ax = plt.subplots(figsize=(12, 1.5))
    tokens, scores = zip(*attributions)
    ax.bar(range(len(scores)), scores)
    ax.set_xticks(range(len(tokens)))
    ax.set_xticklabels(tokens, rotation=90)
    ax.set_ylabel("Attribution Score")
    ax.set_title("Token Attribution (Integrated Gradients)")
    st.pyplot(fig)

    # HTML Highlighted Text
    st.markdown("#### Highlighted Text Importance")
    html_output = explainer.visualize().data
    # Render in Streamlit
    st.markdown(html_output, unsafe_allow_html=True)