Anisha Bhatnagar
commited on
Commit
·
74947b9
1
Parent(s):
8367823
show span in background authors issue fixed
Browse files- utils/gram2vec_feat_utils.py +18 -7
- utils/interp_space_utils.py +4 -3
- utils/llm_feat_utils.py +3 -0
- utils/visualizations.py +1 -1
utils/gram2vec_feat_utils.py
CHANGED
@@ -126,7 +126,7 @@ def highlight_both_spans(text, llm_spans, gram_spans):
|
|
126 |
|
127 |
|
128 |
def show_combined_spans_all(selected_feature_llm, selected_feature_g2v,
|
129 |
-
llm_style_feats_analysis, background_authors_embeddings_df, task_authors_embeddings_df, visible_authors, predicted_author=None, ground_truth_author=None, max_num_authors=
|
130 |
"""
|
131 |
For mystery + 3 candidates:
|
132 |
1. get llm spans via your existing cache+API
|
@@ -152,9 +152,11 @@ def show_combined_spans_all(selected_feature_llm, selected_feature_g2v,
|
|
152 |
|
153 |
if selected_feature_llm and selected_feature_llm != "None":
|
154 |
# print(llm_style_feats_analysis)
|
|
|
155 |
author_list = list(llm_style_feats_analysis['spans'].values())
|
156 |
llm_spans_list = []
|
157 |
for i, (_, txt) in enumerate(texts):
|
|
|
158 |
author_spans_list = []
|
159 |
for txt_span in author_list[i][selected_feature_llm]:
|
160 |
author_spans_list.append(Span(txt.find(txt_span), txt.find(txt_span) + len(txt_span)))
|
@@ -167,6 +169,8 @@ def show_combined_spans_all(selected_feature_llm, selected_feature_g2v,
|
|
167 |
if selected_feature_g2v and selected_feature_g2v != "None":
|
168 |
# get gram2vec spans
|
169 |
gram_spans_list = []
|
|
|
|
|
170 |
print(f"Selected Gram2Vec feature: {selected_feature_g2v}")
|
171 |
short = get_shorthand(selected_feature_g2v)
|
172 |
print(f"short hand: {short}")
|
@@ -199,14 +203,19 @@ def show_combined_spans_all(selected_feature_llm, selected_feature_g2v,
|
|
199 |
)
|
200 |
combined_html = "<div>" + "\n<hr>\n".join(html_task_authors) + "</div>"
|
201 |
|
|
|
|
|
202 |
# Filter background authors to those with at least one Gram2Vec span
|
203 |
bg_start = 4
|
204 |
bg_indices = list(range(bg_start, len(texts)))
|
205 |
kept_indices = [i for i in bg_indices if gram_spans_list[i]]
|
|
|
206 |
filtered_texts_bg = [texts[i] for i in kept_indices]
|
207 |
filtered_llm_bg = [llm_spans_list[i] for i in kept_indices]
|
208 |
filtered_gram_bg = [gram_spans_list[i] for i in kept_indices]
|
209 |
|
|
|
|
|
210 |
html_background_authors = create_html(
|
211 |
filtered_texts_bg,
|
212 |
filtered_llm_bg,
|
@@ -219,6 +228,7 @@ def show_combined_spans_all(selected_feature_llm, selected_feature_g2v,
|
|
219 |
ground_truth_author=ground_truth_author
|
220 |
)
|
221 |
background_html = "<div>" + "\n<hr>\n".join(html_background_authors) + "</div>"
|
|
|
222 |
return combined_html, background_html
|
223 |
|
224 |
def get_label(label: str, predicted_author=None, ground_truth_author=None, bg_id: int=0) -> str:
|
@@ -230,26 +240,27 @@ def get_label(label: str, predicted_author=None, ground_truth_author=None, bg_id
|
|
230 |
return "Mystery Author"
|
231 |
elif label.startswith("a0_author") or label.startswith("a1_author") or label.startswith("a2_author") or label.startswith("Candidate"):
|
232 |
if label.startswith("Candidate"):
|
233 |
-
id = int(label.split(" ")[2]) # Get the number after 'Candidate Author'
|
234 |
else:
|
235 |
id = label.split("_")[0][-1] # Get the last character of the first part (a0, a1, a2)
|
236 |
if predicted_author is not None and ground_truth_author is not None:
|
237 |
if int(id) == predicted_author and int(id) == ground_truth_author:
|
238 |
-
return f"Candidate {int(id)} (Predicted & Ground Truth)"
|
239 |
elif int(id) == predicted_author:
|
240 |
-
return f"Candidate {int(id)} (Predicted)"
|
241 |
elif int(id) == ground_truth_author:
|
242 |
-
return f"Candidate {int(id)} (Ground Truth)"
|
243 |
else:
|
244 |
-
return f"Candidate {int(id)}"
|
245 |
else:
|
246 |
-
return f"Candidate {int(id)}"
|
247 |
else:
|
248 |
return f"Background Author {bg_id+1}"
|
249 |
|
250 |
def create_html(texts, llm_spans_list, gram_spans_list, selected_feature_llm, selected_feature_g2v, short=None, background = False, predicted_author=None, ground_truth_author=None):
|
251 |
html = []
|
252 |
for i, (label, txt) in enumerate(texts):
|
|
|
253 |
label = get_label(label, predicted_author, ground_truth_author, i) if background else get_label(label, predicted_author, ground_truth_author)
|
254 |
combined = highlight_both_spans(txt, llm_spans_list[i], gram_spans_list[i])
|
255 |
notice = ""
|
|
|
126 |
|
127 |
|
128 |
def show_combined_spans_all(selected_feature_llm, selected_feature_g2v,
|
129 |
+
llm_style_feats_analysis, background_authors_embeddings_df, task_authors_embeddings_df, visible_authors, predicted_author=None, ground_truth_author=None, max_num_authors=7):
|
130 |
"""
|
131 |
For mystery + 3 candidates:
|
132 |
1. get llm spans via your existing cache+API
|
|
|
152 |
|
153 |
if selected_feature_llm and selected_feature_llm != "None":
|
154 |
# print(llm_style_feats_analysis)
|
155 |
+
print(f"{len(llm_style_feats_analysis['spans'].values())}")
|
156 |
author_list = list(llm_style_feats_analysis['spans'].values())
|
157 |
llm_spans_list = []
|
158 |
for i, (_, txt) in enumerate(texts):
|
159 |
+
print(f"{i}/{len(texts)}")
|
160 |
author_spans_list = []
|
161 |
for txt_span in author_list[i][selected_feature_llm]:
|
162 |
author_spans_list.append(Span(txt.find(txt_span), txt.find(txt_span) + len(txt_span)))
|
|
|
169 |
if selected_feature_g2v and selected_feature_g2v != "None":
|
170 |
# get gram2vec spans
|
171 |
gram_spans_list = []
|
172 |
+
# clean the display string and get the feature name without the zscore
|
173 |
+
selected_feature_g2v = selected_feature_g2v.split(" | [Z=")[0].strip()
|
174 |
print(f"Selected Gram2Vec feature: {selected_feature_g2v}")
|
175 |
short = get_shorthand(selected_feature_g2v)
|
176 |
print(f"short hand: {short}")
|
|
|
203 |
)
|
204 |
combined_html = "<div>" + "\n<hr>\n".join(html_task_authors) + "</div>"
|
205 |
|
206 |
+
# print(f"\n\n\n\n{texts[4:]}")
|
207 |
+
|
208 |
# Filter background authors to those with at least one Gram2Vec span
|
209 |
bg_start = 4
|
210 |
bg_indices = list(range(bg_start, len(texts)))
|
211 |
kept_indices = [i for i in bg_indices if gram_spans_list[i]]
|
212 |
+
print(f"\n---> {kept_indices}")
|
213 |
filtered_texts_bg = [texts[i] for i in kept_indices]
|
214 |
filtered_llm_bg = [llm_spans_list[i] for i in kept_indices]
|
215 |
filtered_gram_bg = [gram_spans_list[i] for i in kept_indices]
|
216 |
|
217 |
+
print(filtered_texts_bg)
|
218 |
+
|
219 |
html_background_authors = create_html(
|
220 |
filtered_texts_bg,
|
221 |
filtered_llm_bg,
|
|
|
228 |
ground_truth_author=ground_truth_author
|
229 |
)
|
230 |
background_html = "<div>" + "\n<hr>\n".join(html_background_authors) + "</div>"
|
231 |
+
# print(f"Background HTML: {background_html}")
|
232 |
return combined_html, background_html
|
233 |
|
234 |
def get_label(label: str, predicted_author=None, ground_truth_author=None, bg_id: int=0) -> str:
|
|
|
240 |
return "Mystery Author"
|
241 |
elif label.startswith("a0_author") or label.startswith("a1_author") or label.startswith("a2_author") or label.startswith("Candidate"):
|
242 |
if label.startswith("Candidate"):
|
243 |
+
id = int(label.split(" ")[2])-1 # Get the number after 'Candidate Author'; convert to 0 index
|
244 |
else:
|
245 |
id = label.split("_")[0][-1] # Get the last character of the first part (a0, a1, a2)
|
246 |
if predicted_author is not None and ground_truth_author is not None:
|
247 |
if int(id) == predicted_author and int(id) == ground_truth_author:
|
248 |
+
return f"Candidate {int(id)+1} (Predicted & Ground Truth)"
|
249 |
elif int(id) == predicted_author:
|
250 |
+
return f"Candidate {int(id)+1} (Predicted)"
|
251 |
elif int(id) == ground_truth_author:
|
252 |
+
return f"Candidate {int(id)+1} (Ground Truth)"
|
253 |
else:
|
254 |
+
return f"Candidate {int(id)+1}"
|
255 |
else:
|
256 |
+
return f"Candidate {int(id)+1}"
|
257 |
else:
|
258 |
return f"Background Author {bg_id+1}"
|
259 |
|
260 |
def create_html(texts, llm_spans_list, gram_spans_list, selected_feature_llm, selected_feature_g2v, short=None, background = False, predicted_author=None, ground_truth_author=None):
|
261 |
html = []
|
262 |
for i, (label, txt) in enumerate(texts):
|
263 |
+
print(i, label, txt[:30])
|
264 |
label = get_label(label, predicted_author, ground_truth_author, i) if background else get_label(label, predicted_author, ground_truth_author)
|
265 |
combined = highlight_both_spans(txt, llm_spans_list[i], gram_spans_list[i])
|
266 |
notice = ""
|
utils/interp_space_utils.py
CHANGED
@@ -521,7 +521,8 @@ def compute_clusters_style_representation_3(
|
|
521 |
cluster_label_clm_name: str = 'authorID',
|
522 |
max_num_feats: int = 10,
|
523 |
max_num_documents_per_author=3,
|
524 |
-
max_num_authors=5
|
|
|
525 |
):
|
526 |
|
527 |
print(f"Computing style representation for visible clusters: {len(cluster_ids)}")
|
@@ -537,8 +538,8 @@ def compute_clusters_style_representation_3(
|
|
537 |
features = identify_style_features(author_texts, author_names, max_num_feats=max_num_feats)
|
538 |
|
539 |
# STEP 2: Prepare author pool for span extraction
|
540 |
-
span_df = background_corpus_df.iloc[:
|
541 |
-
author_names = span_df[cluster_label_clm_name].tolist()[:
|
542 |
print(f"Number of authors for span detection : {len(span_df)}")
|
543 |
print(author_names)
|
544 |
spans_by_author = extract_all_spans(span_df, features, cluster_label_clm_name)
|
|
|
521 |
cluster_label_clm_name: str = 'authorID',
|
522 |
max_num_feats: int = 10,
|
523 |
max_num_documents_per_author=3,
|
524 |
+
max_num_authors=5,
|
525 |
+
max_authors_for_span_extraction=7
|
526 |
):
|
527 |
|
528 |
print(f"Computing style representation for visible clusters: {len(cluster_ids)}")
|
|
|
538 |
features = identify_style_features(author_texts, author_names, max_num_feats=max_num_feats)
|
539 |
|
540 |
# STEP 2: Prepare author pool for span extraction
|
541 |
+
span_df = background_corpus_df.iloc[:max_authors_for_span_extraction]
|
542 |
+
author_names = span_df[cluster_label_clm_name].tolist()[:max_authors_for_span_extraction]
|
543 |
print(f"Number of authors for span detection : {len(span_df)}")
|
544 |
print(author_names)
|
545 |
spans_by_author = extract_all_spans(span_df, features, cluster_label_clm_name)
|
utils/llm_feat_utils.py
CHANGED
@@ -3,6 +3,7 @@ import os
|
|
3 |
import hashlib
|
4 |
import time
|
5 |
from json import JSONDecodeError
|
|
|
6 |
|
7 |
CACHE_DIR = "datasets/feature_spans_cache"
|
8 |
os.makedirs(CACHE_DIR, exist_ok=True)
|
@@ -66,10 +67,12 @@ def generate_feature_spans_with_retries(client, text: str, features: list[str])
|
|
66 |
for attempt in range(MAX_ATTEMPTS):
|
67 |
try:
|
68 |
response_str = generate_feature_spans(client, text, features)
|
|
|
69 |
result = json.loads(response_str)
|
70 |
return result
|
71 |
except (JSONDecodeError, ValueError) as e:
|
72 |
print(f"Attempt {attempt+1} failed: {e}")
|
|
|
73 |
if attempt < MAX_ATTEMPTS - 1:
|
74 |
wait_sec = WAIT_SECONDS * (2 ** attempt)
|
75 |
print(f"Retrying after {wait_sec} seconds...")
|
|
|
3 |
import hashlib
|
4 |
import time
|
5 |
from json import JSONDecodeError
|
6 |
+
import traceback
|
7 |
|
8 |
CACHE_DIR = "datasets/feature_spans_cache"
|
9 |
os.makedirs(CACHE_DIR, exist_ok=True)
|
|
|
67 |
for attempt in range(MAX_ATTEMPTS):
|
68 |
try:
|
69 |
response_str = generate_feature_spans(client, text, features)
|
70 |
+
print(response_str)
|
71 |
result = json.loads(response_str)
|
72 |
return result
|
73 |
except (JSONDecodeError, ValueError) as e:
|
74 |
print(f"Attempt {attempt+1} failed: {e}")
|
75 |
+
traceback.print_exc()
|
76 |
if attempt < MAX_ATTEMPTS - 1:
|
77 |
wait_sec = WAIT_SECONDS * (2 ** attempt)
|
78 |
print(f"Retrying after {wait_sec} seconds...")
|
utils/visualizations.py
CHANGED
@@ -225,7 +225,7 @@ def format_g2v_features_for_display(g2v_features_with_scores):
|
|
225 |
z_score = float(z_score)
|
226 |
|
227 |
# Create display string with z-score
|
228 |
-
display_string = f"{feature_name} | Z={z_score:.2f}]"
|
229 |
display_choices.append(display_string)
|
230 |
original_values.append(feature_name)
|
231 |
else:
|
|
|
225 |
z_score = float(z_score)
|
226 |
|
227 |
# Create display string with z-score
|
228 |
+
display_string = f"{feature_name} | [Z={z_score:.2f}]"
|
229 |
display_choices.append(display_string)
|
230 |
original_values.append(feature_name)
|
231 |
else:
|