"""Streamlit app for exploring an IAB content classifier.

The app embeds a labelled reference corpus, retrieves the examples most
similar to a user-supplied sentence, and compares the label distribution of
those neighbours with the predictions (and token attributions) of a
fine-tuned multi-label classifier.
"""
import os

# Force all caches to writable directories.
# These must be exported BEFORE the Hugging Face libraries are imported:
# they resolve their cache locations at import time, so assigning the
# variables after the imports leaves the default locations in effect.
os.environ["HF_HOME"] = "/tmp/hf"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf/transformers"
os.environ["TORCH_HOME"] = "/tmp/torch"
os.environ["HF_DATASETS_CACHE"] = "/tmp/hf/datasets"
os.environ["XDG_CACHE_HOME"] = "/tmp/xdg"  # For anything else using XDG

import gzip
import json

import matplotlib.pyplot as plt
import pandas as pd
import requests
import streamlit as st
import torch
from datasets import load_dataset
from model2vec import StaticModel
from sklearn.metrics.pairwise import cosine_similarity
from torch.nn.functional import sigmoid
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers_interpret import SequenceClassificationExplainer

# -- SETTINGS --
LABELS = [
    'inconclusive', 'animals', 'arts', 'autos', 'business', 'career',
    'education', 'fashion', 'finance', 'food', 'government', 'health',
    'hobbies', 'home', 'news', 'realestate', 'society', 'sports', 'tech',
    'travel',
]
label2id = {label: idx for idx, label in enumerate(LABELS)}
id2label = {idx: label for label, idx in label2id.items()}

REPO_ID = "chidamnat2002/iab_training_dataset"
CLASSIFIER_ID = "chidamnat2002/content-multilabel-iab-classifier"
EMBEDDER_ID = "minishlab/potion-retrieval-32M"

# Column layout of the "similar examples" table shown to the user.
SIMILAR_COLUMNS = ['text', 'similarity_score', 'label']


@st.cache_data
def load_csv_data():
    """Load the training CSV from the dataset repo as a DataFrame."""
    dataset = load_dataset(REPO_ID, split="train", data_files="train_df_simple.csv")
    return pd.DataFrame(dataset)


@st.cache_resource
def get_model_and_tokenizer():
    """Return the ``(model, tokenizer)`` pair of the IAB classifier."""
    tokenizer = AutoTokenizer.from_pretrained(CLASSIFIER_ID)
    model = AutoModelForSequenceClassification.from_pretrained(CLASSIFIER_ID)
    return model, tokenizer


@st.cache_resource
def get_explainer():
    """Return an attribution explainer wrapping the cached classifier."""
    model, tokenizer = get_model_and_tokenizer()
    return SequenceClassificationExplainer(model, tokenizer)


# -- LOAD MODEL & EMBEDDINGS --
@st.cache_resource
def load_model():
    """Return the static embedding model used for similarity search."""
    return StaticModel.from_pretrained(EMBEDDER_ID)


@st.cache_resource
def encode_texts_cached(corpus):
    """Embed every text in ``corpus`` with the cached embedding model."""
    model = load_model()  # use cached model
    return model.encode(corpus)


@st.cache_data(show_spinner="Embedding reference", max_entries=50)
def encode_reference(text: str):
    """Embed a single reference sentence and return its vector."""
    model = load_model()
    return model.encode([text])[0]


@st.cache_resource
def get_data_and_embeddings():
    """Return ``(texts, prior_labels, X)`` for the reference corpus.

    ``texts`` and ``prior_labels`` are the ``text`` and ``label`` columns of
    the training CSV; ``X`` holds the embeddings of ``texts`` in the same
    order.
    """
    df = load_csv_data()
    texts = df["text"].to_list()
    prior_labels = df['label'].to_list()
    X = encode_texts_cached(texts)
    return texts, prior_labels, X


def predict_content_multilabel(text, threshold=0.5, verbose=False):
    """Predict the IAB labels of ``text`` with the multi-label classifier.

    Args:
        text: The sentence to classify.
        threshold: Minimum sigmoid probability for a label to be reported.
        verbose: When true, also write the input text and the predicted
            labels to the Streamlit page.

    Returns:
        A list of ``(label, probability)`` tuples, in label-index order, for
        every label whose probability is at least ``threshold``;
        probabilities are rounded to three decimals.
    """
    model, tokenizer = get_model_and_tokenizer()
    model.eval()
    # NOTE(review): hyphens are replaced by spaces before tokenization;
    # presumably this mirrors the training-time preprocessing -- confirm.
    text = text.replace("-", " ")
    inputs = tokenizer(
        text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=256,
    )
    with torch.no_grad():
        logits = model(**inputs).logits
    # A single string is tokenized, so the batch holds exactly one row;
    # indexing it (instead of squeeze()) keeps the result one-dimensional
    # regardless of the number of labels.
    probs = sigmoid(logits)[0].cpu().numpy()
    predicted_labels = [
        (id2label[i], round(float(p), 3))
        for i, p in enumerate(probs)
        if p >= threshold
    ]
    if verbose:
        st.write(f"Text: {text}")
        st.write("Predicted Labels:")
        st.write(predicted_labels)
    return predicted_labels


st.set_page_config(page_title="IAB Classifier App", layout="wide")
st.title("IAB Classifier App")

# Load data
texts, prior_labels, X = get_data_and_embeddings()

st.markdown("### Reference sentence for similarity")
reference = st.text_area("Type something like 'business related'")
prediction_choice = st.checkbox("try our iab model prediction for this")

if reference:
    st.write("Labels loaded:", len(prior_labels))
    query = encode_reference(reference)
    similarity = cosine_similarity([query], X)[0]

    df_emb = pd.DataFrame({
        "text": texts,
        "sim": similarity,
        "label": prior_labels,
    }).sort_values("sim", ascending=False)

    top_size = st.slider("number of similar items", 1, 100, 5)
    # Take the top rows directly instead of converting the whole ranked
    # frame to records first; the table is built once and reused below.
    top_labelled_df = (
        df_emb.head(top_size)
        .rename(columns={"sim": "similarity_score"})
        .reset_index(drop=True)[SIMILAR_COLUMNS]
    )

    st.markdown("### Similar example(s)")
    if top_labelled_df.empty:
        st.info("No more similar examples.")
    else:
        st.write(top_labelled_df)

        preds = dict(predict_content_multilabel(reference, threshold=0.2))
        st.write(f"preds = {preds}")

        col1, col2 = st.columns(2)

        # Left: What training data says
        with col1:
            st.markdown("#### What Training Data Says")
            fig1, ax1 = plt.subplots()
            label_share = top_labelled_df['label'].value_counts(normalize=True)
            label_share.sort_values().plot(kind='barh', ax=ax1, color="lightcoral")
            ax1.set_title("Label Distribution")
            ax1.set_xlabel("Proportion")
            ax1.grid(True, axis='x', linestyle='--', alpha=0.5)
            st.pyplot(fig1)
            # pyplot keeps every figure alive until it is closed; close it
            # so that script reruns do not accumulate figures.
            plt.close(fig1)

        # Right: What model predicts
        with col2:
            st.markdown("#### Model Predictions")
            if not preds or not prediction_choice:
                st.write("Model is unsure")
            else:
                fig2, ax2 = plt.subplots()
                pd.Series(preds).sort_values().plot.barh(color="skyblue", ax=ax2)
                ax2.set_title("Predicted Probabilities")
                ax2.set_xlabel("Probability")
                ax2.grid(True, axis='x', linestyle='--', alpha=0.5)
                st.pyplot(fig2)
                plt.close(fig2)

if prediction_choice and reference:
    st.markdown("### Model Explanation (Top Predicted Class)")
    explainer = get_explainer()
    attributions = explainer(reference)
    st.markdown(f"**Predicted label:** `{explainer.predicted_class_name}`")

    # Token importance bar chart
    fig, ax = plt.subplots(figsize=(12, 1.5))
    tokens, scores = zip(*attributions)
    ax.bar(range(len(scores)), scores)
    ax.set_xticks(range(len(tokens)))
    ax.set_xticklabels(tokens, rotation=90)
    ax.set_ylabel("Attribution Score")
    ax.set_title("Token Attribution (Integrated Gradients)")
    st.pyplot(fig)
    plt.close(fig)

    # HTML Highlighted Text
    st.markdown("#### Highlighted Text Importance")
    html_output = explainer.visualize().data

    # Render in Streamlit
    st.markdown(html_output, unsafe_allow_html=True)