Anisha Bhatnagar commited on
Commit
74947b9
·
1 Parent(s): 8367823

show span in background authors issue fixed

Browse files
utils/gram2vec_feat_utils.py CHANGED
@@ -126,7 +126,7 @@ def highlight_both_spans(text, llm_spans, gram_spans):
126
 
127
 
128
  def show_combined_spans_all(selected_feature_llm, selected_feature_g2v,
129
- llm_style_feats_analysis, background_authors_embeddings_df, task_authors_embeddings_df, visible_authors, predicted_author=None, ground_truth_author=None, max_num_authors=4):
130
  """
131
  For mystery + 3 candidates:
132
  1. get llm spans via your existing cache+API
@@ -152,9 +152,11 @@ def show_combined_spans_all(selected_feature_llm, selected_feature_g2v,
152
 
153
  if selected_feature_llm and selected_feature_llm != "None":
154
  # print(llm_style_feats_analysis)
 
155
  author_list = list(llm_style_feats_analysis['spans'].values())
156
  llm_spans_list = []
157
  for i, (_, txt) in enumerate(texts):
 
158
  author_spans_list = []
159
  for txt_span in author_list[i][selected_feature_llm]:
160
  author_spans_list.append(Span(txt.find(txt_span), txt.find(txt_span) + len(txt_span)))
@@ -167,6 +169,8 @@ def show_combined_spans_all(selected_feature_llm, selected_feature_g2v,
167
  if selected_feature_g2v and selected_feature_g2v != "None":
168
  # get gram2vec spans
169
  gram_spans_list = []
 
 
170
  print(f"Selected Gram2Vec feature: {selected_feature_g2v}")
171
  short = get_shorthand(selected_feature_g2v)
172
  print(f"short hand: {short}")
@@ -199,14 +203,19 @@ def show_combined_spans_all(selected_feature_llm, selected_feature_g2v,
199
  )
200
  combined_html = "<div>" + "\n<hr>\n".join(html_task_authors) + "</div>"
201
 
 
 
202
  # Filter background authors to those with at least one Gram2Vec span
203
  bg_start = 4
204
  bg_indices = list(range(bg_start, len(texts)))
205
  kept_indices = [i for i in bg_indices if gram_spans_list[i]]
 
206
  filtered_texts_bg = [texts[i] for i in kept_indices]
207
  filtered_llm_bg = [llm_spans_list[i] for i in kept_indices]
208
  filtered_gram_bg = [gram_spans_list[i] for i in kept_indices]
209
 
 
 
210
  html_background_authors = create_html(
211
  filtered_texts_bg,
212
  filtered_llm_bg,
@@ -219,6 +228,7 @@ def show_combined_spans_all(selected_feature_llm, selected_feature_g2v,
219
  ground_truth_author=ground_truth_author
220
  )
221
  background_html = "<div>" + "\n<hr>\n".join(html_background_authors) + "</div>"
 
222
  return combined_html, background_html
223
 
224
  def get_label(label: str, predicted_author=None, ground_truth_author=None, bg_id: int=0) -> str:
@@ -230,26 +240,27 @@ def get_label(label: str, predicted_author=None, ground_truth_author=None, bg_id
230
  return "Mystery Author"
231
  elif label.startswith("a0_author") or label.startswith("a1_author") or label.startswith("a2_author") or label.startswith("Candidate"):
232
  if label.startswith("Candidate"):
233
- id = int(label.split(" ")[2]) # Get the number after 'Candidate Author'
234
  else:
235
  id = label.split("_")[0][-1] # Get the last character of the first part (a0, a1, a2)
236
  if predicted_author is not None and ground_truth_author is not None:
237
  if int(id) == predicted_author and int(id) == ground_truth_author:
238
- return f"Candidate {int(id)} (Predicted & Ground Truth)"
239
  elif int(id) == predicted_author:
240
- return f"Candidate {int(id)} (Predicted)"
241
  elif int(id) == ground_truth_author:
242
- return f"Candidate {int(id)} (Ground Truth)"
243
  else:
244
- return f"Candidate {int(id)}"
245
  else:
246
- return f"Candidate {int(id)}"
247
  else:
248
  return f"Background Author {bg_id+1}"
249
 
250
  def create_html(texts, llm_spans_list, gram_spans_list, selected_feature_llm, selected_feature_g2v, short=None, background = False, predicted_author=None, ground_truth_author=None):
251
  html = []
252
  for i, (label, txt) in enumerate(texts):
 
253
  label = get_label(label, predicted_author, ground_truth_author, i) if background else get_label(label, predicted_author, ground_truth_author)
254
  combined = highlight_both_spans(txt, llm_spans_list[i], gram_spans_list[i])
255
  notice = ""
 
126
 
127
 
128
  def show_combined_spans_all(selected_feature_llm, selected_feature_g2v,
129
+ llm_style_feats_analysis, background_authors_embeddings_df, task_authors_embeddings_df, visible_authors, predicted_author=None, ground_truth_author=None, max_num_authors=7):
130
  """
131
  For mystery + 3 candidates:
132
  1. get llm spans via your existing cache+API
 
152
 
153
  if selected_feature_llm and selected_feature_llm != "None":
154
  # print(llm_style_feats_analysis)
155
+ print(f"{len(llm_style_feats_analysis['spans'].values())}")
156
  author_list = list(llm_style_feats_analysis['spans'].values())
157
  llm_spans_list = []
158
  for i, (_, txt) in enumerate(texts):
159
+ print(f"{i}/{len(texts)}")
160
  author_spans_list = []
161
  for txt_span in author_list[i][selected_feature_llm]:
162
  author_spans_list.append(Span(txt.find(txt_span), txt.find(txt_span) + len(txt_span)))
 
169
  if selected_feature_g2v and selected_feature_g2v != "None":
170
  # get gram2vec spans
171
  gram_spans_list = []
172
+ # clean the display string and get the feature name without the zscore
173
+ selected_feature_g2v = selected_feature_g2v.split(" | [Z=")[0].strip()
174
  print(f"Selected Gram2Vec feature: {selected_feature_g2v}")
175
  short = get_shorthand(selected_feature_g2v)
176
  print(f"short hand: {short}")
 
203
  )
204
  combined_html = "<div>" + "\n<hr>\n".join(html_task_authors) + "</div>"
205
 
206
+ # print(f"\n\n\n\n{texts[4:]}")
207
+
208
  # Filter background authors to those with at least one Gram2Vec span
209
  bg_start = 4
210
  bg_indices = list(range(bg_start, len(texts)))
211
  kept_indices = [i for i in bg_indices if gram_spans_list[i]]
212
+ print(f"\n---> {kept_indices}")
213
  filtered_texts_bg = [texts[i] for i in kept_indices]
214
  filtered_llm_bg = [llm_spans_list[i] for i in kept_indices]
215
  filtered_gram_bg = [gram_spans_list[i] for i in kept_indices]
216
 
217
+ print(filtered_texts_bg)
218
+
219
  html_background_authors = create_html(
220
  filtered_texts_bg,
221
  filtered_llm_bg,
 
228
  ground_truth_author=ground_truth_author
229
  )
230
  background_html = "<div>" + "\n<hr>\n".join(html_background_authors) + "</div>"
231
+ # print(f"Background HTML: {background_html}")
232
  return combined_html, background_html
233
 
234
  def get_label(label: str, predicted_author=None, ground_truth_author=None, bg_id: int=0) -> str:
 
240
  return "Mystery Author"
241
  elif label.startswith("a0_author") or label.startswith("a1_author") or label.startswith("a2_author") or label.startswith("Candidate"):
242
  if label.startswith("Candidate"):
243
+ id = int(label.split(" ")[2])-1 # Get the number after 'Candidate Author'; convert to 0 index
244
  else:
245
  id = label.split("_")[0][-1] # Get the last character of the first part (a0, a1, a2)
246
  if predicted_author is not None and ground_truth_author is not None:
247
  if int(id) == predicted_author and int(id) == ground_truth_author:
248
+ return f"Candidate {int(id)+1} (Predicted & Ground Truth)"
249
  elif int(id) == predicted_author:
250
+ return f"Candidate {int(id)+1} (Predicted)"
251
  elif int(id) == ground_truth_author:
252
+ return f"Candidate {int(id)+1} (Ground Truth)"
253
  else:
254
+ return f"Candidate {int(id)+1}"
255
  else:
256
+ return f"Candidate {int(id)+1}"
257
  else:
258
  return f"Background Author {bg_id+1}"
259
 
260
  def create_html(texts, llm_spans_list, gram_spans_list, selected_feature_llm, selected_feature_g2v, short=None, background = False, predicted_author=None, ground_truth_author=None):
261
  html = []
262
  for i, (label, txt) in enumerate(texts):
263
+ print(i, label, txt[:30])
264
  label = get_label(label, predicted_author, ground_truth_author, i) if background else get_label(label, predicted_author, ground_truth_author)
265
  combined = highlight_both_spans(txt, llm_spans_list[i], gram_spans_list[i])
266
  notice = ""
utils/interp_space_utils.py CHANGED
@@ -521,7 +521,8 @@ def compute_clusters_style_representation_3(
521
  cluster_label_clm_name: str = 'authorID',
522
  max_num_feats: int = 10,
523
  max_num_documents_per_author=3,
524
- max_num_authors=5
 
525
  ):
526
 
527
  print(f"Computing style representation for visible clusters: {len(cluster_ids)}")
@@ -537,8 +538,8 @@ def compute_clusters_style_representation_3(
537
  features = identify_style_features(author_texts, author_names, max_num_feats=max_num_feats)
538
 
539
  # STEP 2: Prepare author pool for span extraction
540
- span_df = background_corpus_df.iloc[:4]
541
- author_names = span_df[cluster_label_clm_name].tolist()[:4]
542
  print(f"Number of authors for span detection : {len(span_df)}")
543
  print(author_names)
544
  spans_by_author = extract_all_spans(span_df, features, cluster_label_clm_name)
 
521
  cluster_label_clm_name: str = 'authorID',
522
  max_num_feats: int = 10,
523
  max_num_documents_per_author=3,
524
+ max_num_authors=5,
525
+ max_authors_for_span_extraction=7
526
  ):
527
 
528
  print(f"Computing style representation for visible clusters: {len(cluster_ids)}")
 
538
  features = identify_style_features(author_texts, author_names, max_num_feats=max_num_feats)
539
 
540
  # STEP 2: Prepare author pool for span extraction
541
+ span_df = background_corpus_df.iloc[:max_authors_for_span_extraction]
542
+ author_names = span_df[cluster_label_clm_name].tolist()[:max_authors_for_span_extraction]
543
  print(f"Number of authors for span detection : {len(span_df)}")
544
  print(author_names)
545
  spans_by_author = extract_all_spans(span_df, features, cluster_label_clm_name)
utils/llm_feat_utils.py CHANGED
@@ -3,6 +3,7 @@ import os
3
  import hashlib
4
  import time
5
  from json import JSONDecodeError
 
6
 
7
  CACHE_DIR = "datasets/feature_spans_cache"
8
  os.makedirs(CACHE_DIR, exist_ok=True)
@@ -66,10 +67,12 @@ def generate_feature_spans_with_retries(client, text: str, features: list[str])
66
  for attempt in range(MAX_ATTEMPTS):
67
  try:
68
  response_str = generate_feature_spans(client, text, features)
 
69
  result = json.loads(response_str)
70
  return result
71
  except (JSONDecodeError, ValueError) as e:
72
  print(f"Attempt {attempt+1} failed: {e}")
 
73
  if attempt < MAX_ATTEMPTS - 1:
74
  wait_sec = WAIT_SECONDS * (2 ** attempt)
75
  print(f"Retrying after {wait_sec} seconds...")
 
3
  import hashlib
4
  import time
5
  from json import JSONDecodeError
6
+ import traceback
7
 
8
  CACHE_DIR = "datasets/feature_spans_cache"
9
  os.makedirs(CACHE_DIR, exist_ok=True)
 
67
  for attempt in range(MAX_ATTEMPTS):
68
  try:
69
  response_str = generate_feature_spans(client, text, features)
70
+ print(response_str)
71
  result = json.loads(response_str)
72
  return result
73
  except (JSONDecodeError, ValueError) as e:
74
  print(f"Attempt {attempt+1} failed: {e}")
75
+ traceback.print_exc()
76
  if attempt < MAX_ATTEMPTS - 1:
77
  wait_sec = WAIT_SECONDS * (2 ** attempt)
78
  print(f"Retrying after {wait_sec} seconds...")
utils/visualizations.py CHANGED
@@ -225,7 +225,7 @@ def format_g2v_features_for_display(g2v_features_with_scores):
225
  z_score = float(z_score)
226
 
227
  # Create display string with z-score
228
- display_string = f"{feature_name} | Z={z_score:.2f}]"
229
  display_choices.append(display_string)
230
  original_values.append(feature_name)
231
  else:
 
225
  z_score = float(z_score)
226
 
227
  # Create display string with z-score
228
+ display_string = f"{feature_name} | [Z={z_score:.2f}]"
229
  display_choices.append(display_string)
230
  original_values.append(feature_name)
231
  else: