Chidam Gopal committed on
Commit
4749985
·
unverified ·
1 Parent(s): c8ad9d9

IAB classifier app

Files changed (3)
  1. .gitignore +166 -0
  2. requirements.txt +11 -3
  3. src/streamlit_app.py +186 -38
.gitignore ADDED
@@ -0,0 +1,166 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
+ .pdm.toml
+ .pdm-python
+ .pdm-build/
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+ eval_venv/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ #.idea/
+ .python-version
+ models/
+ data/
requirements.txt CHANGED
@@ -1,3 +1,11 @@
- altair
- pandas
- streamlit
+ transformers
+ torch
+ numpy
+ scikit-learn
+ model2vec
+ mohtml
+ streamlit
+ matplotlib
+ transformers-interpret
+ datasets
+ huggingface_hub
src/streamlit_app.py CHANGED
@@ -1,40 +1,188 @@
- import altair as alt
- import numpy as np
- import pandas as pd
  import streamlit as st
 
- """
- # Welcome to Streamlit!
-
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
- forums](https://discuss.streamlit.io).
-
- In the meantime, below is an example of what you can do with just a few lines of code:
- """
-
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
-
- indices = np.linspace(0, 1, num_points)
- theta = 2 * np.pi * num_turns * indices
- radius = indices
-
- x = radius * np.cos(theta)
- y = radius * np.sin(theta)
-
- df = pd.DataFrame({
-     "x": x,
-     "y": y,
-     "idx": indices,
-     "rand": np.random.randn(num_points),
- })
-
- st.altair_chart(alt.Chart(df, height=700, width=700)
-     .mark_point(filled=True)
-     .encode(
-         x=alt.X("x", axis=None),
-         y=alt.Y("y", axis=None),
-         color=alt.Color("idx", legend=None, scale=alt.Scale()),
-         size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
-     ))
 
+ from datasets import load_dataset
+ import requests
+ import gzip
+ import json
  import streamlit as st
+ import pandas as pd
+ from sklearn.metrics.pairwise import cosine_similarity
+ from model2vec import StaticModel
+ import matplotlib.pyplot as plt
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
+ import torch
+ from torch.nn.functional import sigmoid
+ from transformers_interpret import SequenceClassificationExplainer
+
+
+ # -- SETTINGS --
+ LABELS = [
+     'inconclusive',
+     'animals',
+     'arts',
+     'autos',
+     'business',
+     'career',
+     'education',
+     'fashion',
+     'finance',
+     'food',
+     'government',
+     'health',
+     'hobbies',
+     'home',
+     'news',
+     'realestate',
+     'society',
+     'sports',
+     'tech',
+     'travel'
+ ]
+
+ label2id = {label: idx for idx, label in enumerate(LABELS)}
+ id2label = {idx: label for label, idx in label2id.items()}
+
+ REPO_ID = "chidamnat2002/iab_training_dataset"
+
+ @st.cache_data
+ def load_csv_data():
+     dataset = load_dataset(REPO_ID, split="train", data_files="train_df_simple.csv")
+     df = pd.DataFrame(dataset)
+     return df
+
+ @st.cache_resource
+ def get_model_and_tokenizer():
+     tokenizer = AutoTokenizer.from_pretrained("chidamnat2002/content-multilabel-iab-classifier")
+     model = AutoModelForSequenceClassification.from_pretrained("chidamnat2002/content-multilabel-iab-classifier")
+     return model, tokenizer
+
+ @st.cache_resource
+ def get_explainer():
+     model, tokenizer = get_model_and_tokenizer()
+     return SequenceClassificationExplainer(model, tokenizer)
+
+ # -- LOAD MODEL & EMBEDDINGS --
+ @st.cache_resource
+ def load_model():
+     return StaticModel.from_pretrained("minishlab/potion-retrieval-32M")
+
+ # st.markdown("### ✨ Encode all examples")
+ @st.cache_resource
+ def encode_texts_cached(corpus):
+     model = load_model()  # use cached model
+     return model.encode(corpus)
+
+ @st.cache_data(show_spinner="Embedding reference", max_entries=50)
+ def encode_reference(text: str):
+     model = load_model()
+     return model.encode([text])[0]
+
+ @st.cache_resource
+ def get_data_and_embeddings():
+     df = load_csv_data()
+     texts = df["text"].to_list()
+     prior_labels = df['label'].to_list()
+     X = encode_texts_cached(texts)
+     return texts, prior_labels, X
+
+
+ st.set_page_config(page_title="IAB Classifier App", layout="wide")
+ st.title("🧠 IAB Classifier App")
+
+ # Load data
+ texts, prior_labels, X = get_data_and_embeddings()
+ st.markdown("### 🧭 Reference sentence for similarity")
+ reference = st.text_area("Type something like 'business related'")
+ prediction_choice = st.checkbox("Try the IAB model prediction for this")
+
+
+ def predict_content_multilabel(text, threshold=0.5, verbose=False):
+     model, tokenizer = get_model_and_tokenizer()
+     model.eval()
+     text = text.replace("-", " ")
+     with torch.no_grad():
+         inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=256)
+         logits = model(**inputs).logits
+         probs = sigmoid(logits).squeeze().cpu().numpy()
+
+     predicted_labels = [(id2label[i], round(float(p), 3)) for i, p in enumerate(probs) if p >= threshold]
+     probs_res = [prob for prob in probs if prob >= threshold]
+
+     if verbose:
+         st.write(f"Text: {text}")
+         st.write(f"Predicted Labels: {predicted_labels}")
+
+     return predicted_labels
+
+ if reference:
+     st.write("Labels loaded:", len(prior_labels))
+
+     query = encode_reference(reference)
+     similarity = cosine_similarity([query], X)[0]
+
+     df_emb = pd.DataFrame({
+         "text": texts,
+         "sim": similarity,
+         "label": prior_labels,
+     }).sort_values("sim", ascending=False)
+
+     top_size = st.slider("Number of similar items", 1, 100, 5)
+     top_candidates = [(row["text"], row["sim"], row["label"]) for row in df_emb.to_dict(orient="records")][:top_size]
+
+     st.markdown("### 🧪 Similar example(s)")
+     if not top_candidates:
+         st.info("No more similar examples.")
+     else:
+         st.write(pd.DataFrame(top_candidates, columns=['text', 'similarity_score', 'label']))
+     top_labelled_df = pd.DataFrame(top_candidates, columns=['text', 'similarity_score', 'label'])
+
+     preds = dict(predict_content_multilabel(reference, threshold=0.2))
+     st.write(f"preds = {preds}")
+
+     col1, col2 = st.columns(2)
+
+     # Left: What training data says
+     with col1:
+         st.markdown("#### 📚 What Training Data Says")
+         fig1, ax1 = plt.subplots()
+         top_labelled_df['label'].value_counts(normalize=True).sort_values().plot(kind='barh', ax=ax1, color="lightcoral")
+         ax1.set_title("Label Distribution")
+         ax1.set_xlabel("Proportion")
+         ax1.grid(True, axis='x', linestyle='--', alpha=0.5)
+         st.pyplot(fig1)
+
+     # Right: What model predicts
+     with col2:
+         st.markdown("#### 🤖 Model Predictions")
+         if len(preds) == 0 or not prediction_choice:
+             st.write("Model is unsure or prediction is disabled")
+         else:
+             fig2, ax2 = plt.subplots()
+             pd.Series(preds).sort_values().plot.barh(color="skyblue", ax=ax2)
+             ax2.set_title("Predicted Probabilities")
+             ax2.set_xlabel("Probability")
+             ax2.grid(True, axis='x', linestyle='--', alpha=0.5)
+             st.pyplot(fig2)
+
+     if prediction_choice and reference:
+         st.markdown("### 🔍 Model Explanation (Top Predicted Class)")
+
+         explainer = get_explainer()
+         attributions = explainer(reference)
+
+         st.markdown(f"**Predicted label:** `{explainer.predicted_class_name}`")
+
+         # Token importance bar chart
+         fig, ax = plt.subplots(figsize=(12, 1.5))
+         tokens, scores = zip(*attributions)
+         ax.bar(range(len(scores)), scores)
+         ax.set_xticks(range(len(tokens)))
+         ax.set_xticklabels(tokens, rotation=90)
+         ax.set_ylabel("Attribution Score")
+         ax.set_title("Token Attribution (Integrated Gradients)")
+         st.pyplot(fig)
+
+         # HTML Highlighted Text
+         st.markdown("#### 🔎 Highlighted Text Importance")
+         html_output = explainer.visualize().data
 
+         # Render in Streamlit
+         st.markdown(html_output, unsafe_allow_html=True)
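Taken together, the app pairs that classifier with a nearest-neighbor view over the training data: the training texts are embedded once with a model2vec static model, and the reference sentence is ranked against them by cosine similarity. A standalone sketch of that retrieval step, with a toy corpus standing in for the chidamnat2002/iab_training_dataset texts the app loads:

```python
# Standalone sketch of the app's similarity lookup: a model2vec static
# embedder plus scikit-learn cosine similarity. The toy corpus and its
# comment labels are illustrative stand-ins for the training dataset.
from model2vec import StaticModel
from sklearn.metrics.pairwise import cosine_similarity

model = StaticModel.from_pretrained("minishlab/potion-retrieval-32M")

corpus = [
    "stock markets rallied after the earnings report",  # finance
    "the team clinched the championship in overtime",   # sports
    "new sedan gets best-in-class fuel economy",         # autos
]
X = model.encode(corpus)                 # (n_docs, dim) embedding matrix
query = model.encode(["business related"])[0]

sims = cosine_similarity([query], X)[0]  # one similarity score per text
for text, sim in sorted(zip(corpus, sims), key=lambda t: -t[1]):
    print(f"{sim:.3f}  {text}")
```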