peter-zeng commited on
Commit
c0e59d3
Β·
1 Parent(s): 75cc8bf

changed default g2v clustering to contrastive, and added filtering to ensure spans show

Browse files
utils/gram2vec_feat_utils.py CHANGED
@@ -198,10 +198,18 @@ def show_combined_spans_all(selected_feature_llm, selected_feature_g2v,
198
  )
199
  combined_html = "<div>" + "\n<hr>\n".join(html_task_authors) + "</div>"
200
 
 
 
 
 
 
 
 
 
201
  html_background_authors = create_html(
202
- texts[4:], #last three are background
203
- llm_spans_list,
204
- gram_spans_list,
205
  selected_feature_llm,
206
  selected_feature_g2v,
207
  short,
 
198
  )
199
  combined_html = "<div>" + "\n<hr>\n".join(html_task_authors) + "</div>"
200
 
201
+ # Filter background authors to those with at least one Gram2Vec span
202
+ bg_start = 4
203
+ bg_indices = list(range(bg_start, len(texts)))
204
+ kept_indices = [i for i in bg_indices if gram_spans_list[i]]
205
+ filtered_texts_bg = [texts[i] for i in kept_indices]
206
+ filtered_llm_bg = [llm_spans_list[i] for i in kept_indices]
207
+ filtered_gram_bg = [gram_spans_list[i] for i in kept_indices]
208
+
209
  html_background_authors = create_html(
210
+ filtered_texts_bg,
211
+ filtered_llm_bg,
212
+ filtered_gram_bg,
213
  selected_feature_llm,
214
  selected_feature_g2v,
215
  short,
utils/interp_space_utils.py CHANGED
@@ -528,7 +528,7 @@ def compute_clusters_g2v_representation(
528
  other_author_ids: List[Any],
529
  features_clm_name: str,
530
  top_n: int = 10,
531
- mode: str = "sharedness",
532
  sharedness_method: str = "mean_minus_alpha_std",
533
  alpha: float = 0.5
534
  ) -> List[str]:
@@ -569,14 +569,29 @@ def compute_clusters_g2v_representation(
569
  # Contrastive mode (default): compute target mean and subtract contrast mean
570
  all_g2v_values = np.array([list(x.values()) for x in selected_feats]).mean(axis=0)
571
 
572
- other_selected_feats = background_corpus_df[~selected_mask][features_clm_name].tolist()
573
- all_g2v_other_feats = list(other_selected_feats[0].keys())
574
- all_g2v_other_values = np.array([list(x.values()) for x in other_selected_feats]).mean(axis=0)
 
 
 
 
 
 
 
 
 
 
 
575
 
576
  final_g2v_feats_values = all_g2v_values - all_g2v_other_values
577
 
578
 
579
- top_g2v_feats = sorted(list(zip(all_g2v_feats, final_g2v_feats_values)), key=lambda x: -x[1])
 
 
 
 
580
 
581
  # Filter out features that are not present in any of the authors
582
  selected_authors = {'Mystery author', 'Candidate Author 1', 'Candidate Author 2', 'Candidate Author 3'}.intersection(set(author_ids))
 
528
  other_author_ids: List[Any],
529
  features_clm_name: str,
530
  top_n: int = 10,
531
+ mode: str = "contrastive",
532
  sharedness_method: str = "mean_minus_alpha_std",
533
  alpha: float = 0.5
534
  ) -> List[str]:
 
569
  # Contrastive mode (default): compute target mean and subtract contrast mean
570
  all_g2v_values = np.array([list(x.values()) for x in selected_feats]).mean(axis=0)
571
 
572
+ # If an explicit contrast set is provided, use it; otherwise use everyone outside selection
573
+ if other_author_ids:
574
+ explicit_mask = background_corpus_df['authorID'].isin(other_author_ids).to_numpy()
575
+ # Ensure contrast set is disjoint from the selected set
576
+ contrast_mask = np.logical_and(explicit_mask, ~selected_mask)
577
+ else:
578
+ contrast_mask = ~selected_mask
579
+
580
+ other_selected_feats = background_corpus_df[contrast_mask][features_clm_name].tolist()
581
+ if len(other_selected_feats) > 0:
582
+ all_g2v_other_values = np.array([list(x.values()) for x in other_selected_feats]).mean(axis=0)
583
+ else:
584
+ # No contrast docs β†’ treat contrast mean as zeros
585
+ all_g2v_other_values = np.zeros_like(all_g2v_values)
586
 
587
  final_g2v_feats_values = all_g2v_values - all_g2v_other_values
588
 
589
 
590
+ # Keep only features that have a positive contrastive score
591
+ top_g2v_feats = sorted(
592
+ [(feat, val) for feat, val in zip(all_g2v_feats, final_g2v_feats_values) if val > 0],
593
+ key=lambda x: -x[1]
594
+ )
595
 
596
  # Filter out features that are not present in any of the authors
597
  selected_authors = {'Mystery author', 'Candidate Author 1', 'Candidate Author 2', 'Candidate Author 3'}.intersection(set(author_ids))
utils/visualizations.py CHANGED
@@ -13,6 +13,7 @@ import re
13
  from utils.interp_space_utils import compute_clusters_style_representation_3, compute_clusters_g2v_representation
14
  from utils.llm_feat_utils import split_features
15
  from utils.gram2vec_feat_utils import get_shorthand, get_fullform
 
16
 
17
  import plotly.io as pio
18
 
@@ -251,9 +252,32 @@ def handle_zoom(event_json, bg_proj, bg_lbls, clustered_authors_df, task_authors
251
  features_clm_name='g2v_vector'
252
  )
253
 
254
- # Gram2vec features are already in shorthand. convert to human readable for display
255
- HR_g2v_list = []
 
 
 
 
 
 
 
 
 
 
 
256
  for feat in g2v_feats:
 
 
 
 
 
 
 
 
 
 
 
 
257
  HR_g2v = get_fullform(feat)
258
  print(f"\n\n feat: {feat} ---> Human Readable: {HR_g2v}")
259
  if HR_g2v is None:
 
13
  from utils.interp_space_utils import compute_clusters_style_representation_3, compute_clusters_g2v_representation
14
  from utils.llm_feat_utils import split_features
15
  from utils.gram2vec_feat_utils import get_shorthand, get_fullform
16
+ from gram2vec.feature_locator import find_feature_spans
17
 
18
  import plotly.io as pio
19
 
 
252
  features_clm_name='g2v_vector'
253
  )
254
 
255
+ # ── Span-existence filter on task authors in the zoom ───────────────────
256
+ # Keep only features that have at least one detected span in any of the
257
+ # visible task authors' texts
258
+ visible_task_authors = task_authors_df[task_authors_df['authorID'].isin(visible_authors)]
259
+ if visible_task_authors.empty:
260
+ visible_task_authors = task_authors_df
261
+
262
+ def _to_text(x):
263
+ return '\n\n =========== \n\n'.join(x) if isinstance(x, list) else x
264
+
265
+ task_texts = [_to_text(x) for x in visible_task_authors['fullText'].tolist()]
266
+
267
+ filtered_g2v_feats = []
268
  for feat in g2v_feats:
269
+ try:
270
+ # `feat` is shorthand already (e.g., 'pos_bigrams:NOUN PROPN')
271
+ if any(find_feature_spans(txt, feat) for txt in task_texts):
272
+ filtered_g2v_feats.append(feat)
273
+ else:
274
+ print(f"[INFO] Dropping G2V feature with no spans in task texts: {feat}")
275
+ except Exception as e:
276
+ print(f"[WARN] Error while checking spans for {feat}: {e}")
277
+
278
+ # Convert to human readable for display
279
+ HR_g2v_list = []
280
+ for feat in filtered_g2v_feats:
281
  HR_g2v = get_fullform(feat)
282
  print(f"\n\n feat: {feat} ---> Human Readable: {HR_g2v}")
283
  if HR_g2v is None: