Commit
Β·
c0e59d3
1
Parent(s):
75cc8bf
changed default g2v clustering to contrastive, and added filtering to ensure spans show
Browse files- utils/gram2vec_feat_utils.py +11 -3
- utils/interp_space_utils.py +20 -5
- utils/visualizations.py +26 -2
utils/gram2vec_feat_utils.py
CHANGED
@@ -198,10 +198,18 @@ def show_combined_spans_all(selected_feature_llm, selected_feature_g2v,
|
|
198 |
)
|
199 |
combined_html = "<div>" + "\n<hr>\n".join(html_task_authors) + "</div>"
|
200 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
201 |
html_background_authors = create_html(
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
selected_feature_llm,
|
206 |
selected_feature_g2v,
|
207 |
short,
|
|
|
198 |
)
|
199 |
combined_html = "<div>" + "\n<hr>\n".join(html_task_authors) + "</div>"
|
200 |
|
201 |
+
# Filter background authors to those with at least one Gram2Vec span
|
202 |
+
bg_start = 4
|
203 |
+
bg_indices = list(range(bg_start, len(texts)))
|
204 |
+
kept_indices = [i for i in bg_indices if gram_spans_list[i]]
|
205 |
+
filtered_texts_bg = [texts[i] for i in kept_indices]
|
206 |
+
filtered_llm_bg = [llm_spans_list[i] for i in kept_indices]
|
207 |
+
filtered_gram_bg = [gram_spans_list[i] for i in kept_indices]
|
208 |
+
|
209 |
html_background_authors = create_html(
|
210 |
+
filtered_texts_bg,
|
211 |
+
filtered_llm_bg,
|
212 |
+
filtered_gram_bg,
|
213 |
selected_feature_llm,
|
214 |
selected_feature_g2v,
|
215 |
short,
|
utils/interp_space_utils.py
CHANGED
@@ -528,7 +528,7 @@ def compute_clusters_g2v_representation(
|
|
528 |
other_author_ids: List[Any],
|
529 |
features_clm_name: str,
|
530 |
top_n: int = 10,
|
531 |
-
mode: str = "
|
532 |
sharedness_method: str = "mean_minus_alpha_std",
|
533 |
alpha: float = 0.5
|
534 |
) -> List[str]:
|
@@ -569,14 +569,29 @@ def compute_clusters_g2v_representation(
|
|
569 |
# Contrastive mode (default): compute target mean and subtract contrast mean
|
570 |
all_g2v_values = np.array([list(x.values()) for x in selected_feats]).mean(axis=0)
|
571 |
|
572 |
-
|
573 |
-
|
574 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
575 |
|
576 |
final_g2v_feats_values = all_g2v_values - all_g2v_other_values
|
577 |
|
578 |
|
579 |
-
|
|
|
|
|
|
|
|
|
580 |
|
581 |
# Filter out features that are not present in any of the authors
|
582 |
selected_authors = {'Mystery author', 'Candidate Author 1', 'Candidate Author 2', 'Candidate Author 3'}.intersection(set(author_ids))
|
|
|
528 |
other_author_ids: List[Any],
|
529 |
features_clm_name: str,
|
530 |
top_n: int = 10,
|
531 |
+
mode: str = "contrastive",
|
532 |
sharedness_method: str = "mean_minus_alpha_std",
|
533 |
alpha: float = 0.5
|
534 |
) -> List[str]:
|
|
|
569 |
# Contrastive mode (default): compute target mean and subtract contrast mean
|
570 |
all_g2v_values = np.array([list(x.values()) for x in selected_feats]).mean(axis=0)
|
571 |
|
572 |
+
# If an explicit contrast set is provided, use it; otherwise use everyone outside selection
|
573 |
+
if other_author_ids:
|
574 |
+
explicit_mask = background_corpus_df['authorID'].isin(other_author_ids).to_numpy()
|
575 |
+
# Ensure contrast set is disjoint from the selected set
|
576 |
+
contrast_mask = np.logical_and(explicit_mask, ~selected_mask)
|
577 |
+
else:
|
578 |
+
contrast_mask = ~selected_mask
|
579 |
+
|
580 |
+
other_selected_feats = background_corpus_df[contrast_mask][features_clm_name].tolist()
|
581 |
+
if len(other_selected_feats) > 0:
|
582 |
+
all_g2v_other_values = np.array([list(x.values()) for x in other_selected_feats]).mean(axis=0)
|
583 |
+
else:
|
584 |
+
# No contrast docs β treat contrast mean as zeros
|
585 |
+
all_g2v_other_values = np.zeros_like(all_g2v_values)
|
586 |
|
587 |
final_g2v_feats_values = all_g2v_values - all_g2v_other_values
|
588 |
|
589 |
|
590 |
+
# Keep only features that have a positive contrastive score
|
591 |
+
top_g2v_feats = sorted(
|
592 |
+
[(feat, val) for feat, val in zip(all_g2v_feats, final_g2v_feats_values) if val > 0],
|
593 |
+
key=lambda x: -x[1]
|
594 |
+
)
|
595 |
|
596 |
# Filter out features that are not present in any of the authors
|
597 |
selected_authors = {'Mystery author', 'Candidate Author 1', 'Candidate Author 2', 'Candidate Author 3'}.intersection(set(author_ids))
|
utils/visualizations.py
CHANGED
@@ -13,6 +13,7 @@ import re
|
|
13 |
from utils.interp_space_utils import compute_clusters_style_representation_3, compute_clusters_g2v_representation
|
14 |
from utils.llm_feat_utils import split_features
|
15 |
from utils.gram2vec_feat_utils import get_shorthand, get_fullform
|
|
|
16 |
|
17 |
import plotly.io as pio
|
18 |
|
@@ -251,9 +252,32 @@ def handle_zoom(event_json, bg_proj, bg_lbls, clustered_authors_df, task_authors
|
|
251 |
features_clm_name='g2v_vector'
|
252 |
)
|
253 |
|
254 |
-
#
|
255 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
256 |
for feat in g2v_feats:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
257 |
HR_g2v = get_fullform(feat)
|
258 |
print(f"\n\n feat: {feat} ---> Human Readable: {HR_g2v}")
|
259 |
if HR_g2v is None:
|
|
|
13 |
from utils.interp_space_utils import compute_clusters_style_representation_3, compute_clusters_g2v_representation
|
14 |
from utils.llm_feat_utils import split_features
|
15 |
from utils.gram2vec_feat_utils import get_shorthand, get_fullform
|
16 |
+
from gram2vec.feature_locator import find_feature_spans
|
17 |
|
18 |
import plotly.io as pio
|
19 |
|
|
|
252 |
features_clm_name='g2v_vector'
|
253 |
)
|
254 |
|
255 |
+
# ββ Span-existence filter on task authors in the zoom βββββββββββββββββββ
|
256 |
+
# Keep only features that have at least one detected span in any of the
|
257 |
+
# visible task authors' texts
|
258 |
+
visible_task_authors = task_authors_df[task_authors_df['authorID'].isin(visible_authors)]
|
259 |
+
if visible_task_authors.empty:
|
260 |
+
visible_task_authors = task_authors_df
|
261 |
+
|
262 |
+
def _to_text(x):
|
263 |
+
return '\n\n =========== \n\n'.join(x) if isinstance(x, list) else x
|
264 |
+
|
265 |
+
task_texts = [_to_text(x) for x in visible_task_authors['fullText'].tolist()]
|
266 |
+
|
267 |
+
filtered_g2v_feats = []
|
268 |
for feat in g2v_feats:
|
269 |
+
try:
|
270 |
+
# `feat` is shorthand already (e.g., 'pos_bigrams:NOUN PROPN')
|
271 |
+
if any(find_feature_spans(txt, feat) for txt in task_texts):
|
272 |
+
filtered_g2v_feats.append(feat)
|
273 |
+
else:
|
274 |
+
print(f"[INFO] Dropping G2V feature with no spans in task texts: {feat}")
|
275 |
+
except Exception as e:
|
276 |
+
print(f"[WARN] Error while checking spans for {feat}: {e}")
|
277 |
+
|
278 |
+
# Convert to human readable for display
|
279 |
+
HR_g2v_list = []
|
280 |
+
for feat in filtered_g2v_feats:
|
281 |
HR_g2v = get_fullform(feat)
|
282 |
print(f"\n\n feat: {feat} ---> Human Readable: {HR_g2v}")
|
283 |
if HR_g2v is None:
|