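"""Gradio demo for NarrativeFactScore.

Pipeline: pick a script (MovieSum / MENSA / custom scenes), build a knowledge
graph from it, generate a chunk-wise summary with GPT, score the summary's
factuality against the graph, and refine the low-scoring parts.
"""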
import gradio as gr
import json
from tqdm import tqdm
import numpy as np
import random
import torch
import ast
from difflib import HtmlDiff

from src.kg.main import script2kg
from src.summary.summarizer import Summarizer
from src.summary.utils import preprocess_script, chunk_script_gpt
from src.summary.prompt import build_summarizer_prompt
from src.fact.narrativefactscore import NarrativeFactScore
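
# Assumes the NarrativeFactScore repository layout: src/kg, src/summary, and
# src/fact modules, prompt templates under ./templates/, and data files at
# dataset/<name>/<split>.jsonl.
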
def _set_seed(seed):
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def parse_scenes(scene_text):
    try:
        return json.loads(scene_text)
    except json.JSONDecodeError:
        return ast.literal_eval(scene_text)

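# parse_scenes accepts JSON or a Python literal, e.g.
# parse_scenes('["scene 1", "scene 2"]') or parse_scenes("['scene 1', 'scene 2']").
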
def set_name_list(dataset, data_type):
    if dataset == "MovieSum":
        if data_type == "train":
            return ['8MM_1999', 'The Iron Lady_2011', 'Adventureland_2009', 'Napoleon_2023',
                    'Kubo and the Two Strings_2016', 'The Woman King_2022', 'What They Had_2018',
                    'Synecdoche, New York_2008', 'Black Christmas_2006', 'Superbad_2007']
        elif data_type == "validation":
            return ['The Boondock Saints_1999', 'The House with a Clock in Its Walls_2018',
                    'The Unbelievable Truth_1989', 'Insidious_2010', 'If Beale Street Could Talk_2018',
                    'The Battle of Shaker Heights_2003', '20th Century Women_2016',
                    'Captain Phillips_2013', 'Conspiracy Theory_1997', 'Domino_2005']
        elif data_type == "test":
            # Return test dataset names (shortened for brevity)
            return ['A Nightmare on Elm Street 3: Dream Warriors_1987', 'Van Helsing_2004',
                    'Oppenheimer_2023', 'Armored_2009', 'The Martian_2015']
    elif dataset == "MENSA":
        if data_type == "train":
            return ['The_Ides_of_March_(film)', 'An_American_Werewolf_in_Paris',
                    'Batman_&_Robin_(film)', 'Airplane_II:_The_Sequel', 'Krull_(film)']
        elif data_type == "validation":
            return ['Pleasantville_(film)', 'V_for_Vendetta_(film)',
                    'Mary_Shelleys_Frankenstein_(film)', 'Rapture_(1965_film)', 'Get_Out']
        elif data_type == "test":
            return ['Knives_Out', 'Black_Panther', 'Pet_Sematary_(film)',
                    'Panic_Room', 'The_Village_(2004_film)']
    return []

def update_name_list_interface(dataset, data_type):
    if dataset in ["MovieSum", "MENSA"]:
        return (
            gr.update(choices=set_name_list(dataset, data_type), value=None, visible=True),
            gr.update(visible=False),
            gr.update(value="")
        )
    else:
        return (
            gr.update(visible=False),
            gr.update(visible=True),
            gr.update(value="Proceed to the 'Knowledge Graph' tab to continue")
        )

def read_data(dataset, data_type):
    file_path = f"dataset/{dataset}/{data_type}.jsonl"
    try:
        with open(file_path, 'r', encoding='utf8') as f:
            return [json.loads(line) for line in f]
    except FileNotFoundError:
        return []

def find_work_index(data, work_name):
    for idx, entry in enumerate(data):
        if entry.get("name") == work_name:
            return idx, entry
    return None, "Work not found in the selected dataset."

def get_narrative_content(dataset, data_type, work):
    data = read_data(dataset, data_type)
    for entry in data:
        if entry.get("name") == work:
            return entry['scenes']
    return "Work not found in the selected dataset."

def get_narrative_content_with_index(dataset, data_type, work):
    data = read_data(dataset, data_type)
    for idx, entry in enumerate(data):
        if entry.get("name") == work:
            # For MovieSum and MENSA datasets, only return scenes
            if dataset in ["MovieSum", "MENSA"]:
                return "\n".join(entry['scenes']), idx, data
            # For other datasets or custom input, return full content
            return entry, idx, data
    return "Work not found in the selected dataset.", None, None

def show_diff(original, revised):
    d = HtmlDiff()
    original_lines = original.splitlines(keepends=True)
    revised_lines = revised.splitlines(keepends=True)
    return d.make_table(original_lines, revised_lines, fromdesc='Original Summary',
                        todesc='Refined Summary', context=True, numlines=2)

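# show_diff returns an HTML <table> highlighting line-level changes, e.g.
# show_diff("The hero wins.", "The hero loses.").
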
def extract_initial_summary(summary_result):
    return summary_result['summary_agg']['summaries']

def extract_factuality_score_and_details(fact_score_result):
    factuality_score = fact_score_result['fact_score']
    feedback_list = []
    for i, feedback_data in enumerate(fact_score_result['summary_feedback_pairs']):
        feedbacks = [fb for fb in feedback_data['feedbacks'] if fb.strip()]
        if feedbacks:
            feedback_list.append(f"In chunk {i + 1}: {'; '.join(feedbacks)}")
    incorrect_details = "\n".join(feedback_list)
    return factuality_score, incorrect_details

def build_kg(script, idx, api_key, model_id):
    return script2kg(script['scenes'], idx, script['name'], api_key, model_id)


def build_kg_custom(scenes, idx, api_key, model_id):
    return script2kg(scenes, idx, "custom", api_key, model_id)

def build_kg_with_data(data, work_index, custom_scenes, api_key, model_id):
    if data and work_index is not None:  # Dataset mode
        script = data[int(work_index)]
        try:
            kg = script2kg(script['scenes'], int(work_index), script['name'], api_key, model_id)
            return kg, "Knowledge Graph built successfully!"
        except Exception as e:
            return None, f"Error building knowledge graph: {str(e)}"
    elif custom_scenes:  # Custom script mode
        try:
            scenes = parse_scenes(custom_scenes)
            if not isinstance(scenes, list):
                return None, "Invalid format. Please provide scenes as a list."
            kg = build_kg_custom(scenes, 0, api_key, model_id)
            return kg, "Knowledge Graph built successfully!"
        except (json.JSONDecodeError, SyntaxError, ValueError) as e:
            return None, f"Invalid format. Error: {str(e)}"
        except Exception as e:
            return None, f"Error building knowledge graph: {str(e)}"
    return None, "Please select a work or input custom scenes."

def generate_summary(script, idx, api_key, model_id):
    _set_seed(42)
    scripty_summarizer = Summarizer(
        inference_mode="org",
        model_id=model_id,
        api_key=api_key,
        dtype="float16",
        seed=42
    )
    # Prefix each scene with a marker, then join scenes into one script string.
    scenes = [f"s#{i}\n{s}" for i, s in enumerate(script['scenes'])]
    script_text = "\n\n".join(scenes)
    script_chunks = chunk_script_gpt(script=script_text, model_id=model_id, chunk_size=2048)
    script_summaries = []
    for chunk in tqdm(script_chunks):
        chunk = preprocess_script(chunk)
        prompt = build_summarizer_prompt(
            prompt_template="./templates/external_summary.txt",
            input_text_list=[chunk]
        )
        script_summ = scripty_summarizer.inference_with_gpt(prompt=prompt)
        script_summaries.append(script_summ.strip())
    elem_dict_list = []
    agg_dict = {
        'script': ' '.join(script_chunks),
        'summaries': ' '.join(script_summaries)
    }
    for i, (chunk, summary) in enumerate(zip(script_chunks, script_summaries)):
        elem_dict = {
            "chunk_index": i,
            "chunk": chunk.strip(),
            "summary": summary.strip()
        }
        elem_dict_list.append(elem_dict)
    processed_dataset = {
        "script": script_text,
        "scenes": scenes,
        "script_chunks": script_chunks,
        "script_summaries": script_summaries,
    }
    return {"summary_sep": elem_dict_list, "summary_agg": agg_dict, "processed_dataset": processed_dataset}

def generate_summary_with_data(data, work_index, custom_scenes, api_key, model_id):
    if data and work_index is not None:  # Dataset mode
        script = data[int(work_index)]
        try:
            summary = generate_summary(script, int(work_index), api_key, model_id)
            return summary, extract_initial_summary(summary)
        except Exception as e:
            return None, f"Error generating summary: {str(e)}"
    elif custom_scenes:  # Custom script mode
        try:
            scenes = parse_scenes(custom_scenes)
            if not isinstance(scenes, list):
                return None, "Invalid format. Please provide scenes as a list."
            script = {"name": "custom", "scenes": scenes}
            summary = generate_summary(script, 0, api_key, model_id)
            return summary, extract_initial_summary(summary)
        except (json.JSONDecodeError, SyntaxError, ValueError) as e:
            return None, f"Invalid format. Error: {str(e)}"
        except Exception as e:
            return None, f"Error generating summary: {str(e)}"
    return None, "Please select a work or input custom scenes."

def calculate_narrative_fact_score(summary, kg_raw, api_key, model_id):
    _set_seed(42)
    factscorer = NarrativeFactScore(split_type='gpt', model='gptscore', api_key=api_key, model_id=model_id)
    processed = summary['processed_dataset']
    chunks, summaries = processed['script_chunks'], processed['script_summaries']
    total_output = {'fact_score': 0, 'summary_feedback_pairs': []}
    partial_output = {'fact_score': 0, 'summary_feedback_pairs': []}
    total_score = 0
    # Flatten KG triples into plain strings (drop the object when it repeats the subject)
    kg = []
    for elem in kg_raw:
        if elem['subject'] == elem['object']:
            kg.append(f"{elem['subject']} {elem['predicate']}")
        else:
            kg.append(f"{elem['subject']} {elem['predicate']} {elem['object']}")
    scores, scores_per_sent, relevant_scenes, summary_chunks, feedbacks = factscorer.score_src_hyp_long(chunks, summaries, kg)
    for i, score in enumerate(scores):
        output_elem = {
            'src': chunks[i],
            'summary': summaries[i],
            'score': score,
            'scores_per_sent': scores_per_sent[i],
            'relevant_scenes': relevant_scenes[i],
            'summary_chunks': summary_chunks[i],
            'feedbacks': feedbacks[i],
        }
        output_elem_part = {
            'scores_per_sent': scores_per_sent[i],
            'summary_chunks': summary_chunks[i],
            'feedbacks': feedbacks[i],
        }
        total_output['summary_feedback_pairs'].append(output_elem)
        partial_output['summary_feedback_pairs'].append(output_elem_part)
        total_score += score
    total_output['fact_score'] = float(total_score / len(scores))
    partial_output['fact_score'] = float(total_score / len(scores))
    return total_output, partial_output

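# Chunks scoring at or above the threshold are kept verbatim; lower-scoring chunks
# are rewritten with the self-correction prompt, using the zero-scored sentences
# and their feedback as revision targets.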
def refine_summary(summary, fact_score, api_key, model_id):
    _set_seed(42)
    threshold = 0.9
    summarizer = Summarizer(
        inference_mode="org",
        model_id=model_id,
        api_key=api_key,
        dtype="float16",
        seed=42
    )
    processed_dataset = {
        "script": summary["script"],
        "scenes": summary["scenes"],
        "script_chunks": [],
        "script_summaries": []
    }
    elem_dict_list = []
    agg_dict = {}
    for factscore_chunk in tqdm(fact_score['summary_feedback_pairs']):
        src_chunk = factscore_chunk['src']
        original_summary = factscore_chunk['summary']
        if factscore_chunk['score'] >= threshold:
            processed_dataset["script_chunks"].append(src_chunk)
            processed_dataset["script_summaries"].append(original_summary.strip())
            continue
        hallu_idxs = np.where(np.array(factscore_chunk['scores_per_sent']) == 0)[0]
        hallu_summary_parts = np.array(factscore_chunk['summary_chunks'])[hallu_idxs]
        feedbacks = np.array(factscore_chunk['feedbacks'])[hallu_idxs]
        prompt = build_summarizer_prompt(
            prompt_template="./templates/self_correction.txt",
            input_text_list=[src_chunk, original_summary]
        )
        for j, (hallu_summ, feedback) in enumerate(zip(hallu_summary_parts, feedbacks)):
            prompt += f"\n- Statement to Revise {j + 1}: {hallu_summ} (Reason for Revision: {feedback})"
        prompt += "\n- Revised Summary: "
        revised_summary = summarizer.inference_with_gpt(prompt=prompt)
        if len(revised_summary.strip()) == 0:
            revised_summary = original_summary
        processed_dataset["script_chunks"].append(src_chunk)
        processed_dataset["script_summaries"].append(revised_summary)
        elem_dict = {
            "chunk_index": len(processed_dataset["script_chunks"]) - 1,
            "chunk": src_chunk.strip(),
            "summary": revised_summary.strip(),
            "org_summary": original_summary.strip(),
            "hallu_in_summary": list(hallu_summary_parts),
            "feedbacks": list(feedbacks),
        }
        elem_dict_list.append(elem_dict)
    agg_dict['script'] = summary['script']
    agg_dict['summaries'] = ' '.join(processed_dataset["script_summaries"])
    return {
        "summary_sep": elem_dict_list,
        "summary_agg": agg_dict,
        "processed_dataset": processed_dataset
    }

def refine_summary_and_return_diff(summary, fact_score, api_key, model_id):
    refined_summary = refine_summary(summary['processed_dataset'], fact_score, api_key, model_id)
    diff = HtmlDiff().make_file(
        summary['summary_agg']['summaries'].splitlines(),
        refined_summary['summary_agg']['summaries'].splitlines(),
        context=True
    )
    return diff

def open_kg(kg_data):
    if kg_data is None:
        return "Please build the knowledge graph first."
    try:
        with open('refined_kg.html', 'r', encoding='utf-8') as f:
            html_content = f.read()
        # Escape double quotes so the document can be embedded in the srcdoc attribute
        return f'''
        <iframe
            srcdoc="{html_content.replace('"', '&quot;')}"
            style="width: 100%; height: 500px; border: none;"
        ></iframe>
        '''
    except Exception as e:
        return f'<div style="color: red;">Error reading KG file: {str(e)}</div>'

def format_fact_score_output(fact_score_result):
    if not fact_score_result:
        return "No factuality analysis available"
    formatted_output = []
    # Overall score
    formatted_output.append(f"Overall Factuality Score: {fact_score_result['fact_score'] * 100:.1f}%\n")
    # Individual chunk analysis
    for i, chunk in enumerate(fact_score_result['summary_feedback_pairs'], 1):
        formatted_output.append(f"\nChunk {i} Analysis:")
        formatted_output.append("Original Text:")
        formatted_output.append(f"{' '.join(chunk['summary_chunks'])}\n")
        if chunk['feedbacks']:
            formatted_output.append("Feedback:")
            feedbacks = [f"• {feedback}" for feedback in chunk['feedbacks'] if feedback.strip()]
            formatted_output.extend(feedbacks)
        formatted_output.append("-" * 80)
    return "\n".join(formatted_output)

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # NarrativeFactScore: Script Factuality Evaluation
        Evaluate and refine script summaries using narrative factuality scoring.
        """
    )
    with gr.Accordion("Model Settings", open=True):
        with gr.Row():
            api_key_input = gr.Textbox(
                label="GPT API Key",
                placeholder="Enter your GPT API key",
                type="password",
                scale=2
            )
            model_selector = gr.Dropdown(
                choices=[
                    "gpt-4o-mini",
                    "gpt-4o",
                    "gpt-4-turbo",
                    "gpt-3.5-turbo-0125"
                ],
                value="gpt-4o",
                label="Model Selection",
                scale=1
            )
    with gr.Tabs():
        with gr.TabItem("Dataset Selection"):
            with gr.Row():
                dataset_selector = gr.Radio(
                    choices=["MovieSum", "MENSA", "Custom"],
                    label="Dataset",
                    info="Choose a dataset or enter a custom script"
                )
                data_type_selector = gr.Radio(
                    choices=["train", "validation", "test"],
                    label="Split Type",
                    info="Select data split",
                    visible=True
                )
            name_list = gr.Dropdown(
                choices=[],
                label="Select Script",
                info="Choose a script to analyze",
                visible=True
            )
            custom_input = gr.Textbox(
                label="Custom Script Input",
                info='Enter scenes as a list: ["scene1", "scene2", ...] (JSON or a Python literal)',
                lines=10,
                visible=False
            )
            narrative_output = gr.Textbox(
                label="Script Content",
                interactive=False,
                lines=10
            )
        with gr.TabItem("Knowledge Graph"):
            with gr.Row():
                generate_kg_button = gr.Button(
                    "Generate Knowledge Graph",
                    variant="primary"
                )
                open_kg_button = gr.Button("View Graph")
            kg_status = gr.Textbox(
                label="Status",
                interactive=False
            )
            kg_viewer = gr.HTML(label="Knowledge Graph Visualization")
        with gr.TabItem("Summary Generation"):
            generate_summary_button = gr.Button(
                "Generate Initial Summary",
                variant="primary"
            )
            summary_output = gr.Textbox(
                label="Generated Summary",
                interactive=False,
                lines=5
            )
            calculate_score_button = gr.Button("Calculate Factuality Score")
            fact_score_display = gr.Textbox(
                label="Factuality Analysis",
                interactive=False,
                lines=10
            )
        with gr.TabItem("Summary Refinement"):
            refine_button = gr.Button(
                "Refine Summary",
                variant="primary"
            )
            refined_output = gr.HTML(label="Refined Summary with Changes")

    # Hidden states
    work_index = gr.State()
    data_state = gr.State()
    kg_output = gr.State()
    summary_state = gr.State()
    fact_score_state = gr.State()
    # Event handlers
    dataset_selector.change(
        fn=lambda x: gr.update(visible=x in ["MovieSum", "MENSA"]),
        inputs=[dataset_selector],
        outputs=data_type_selector
    )
    dataset_selector.change(
        fn=update_name_list_interface,
        inputs=[dataset_selector, data_type_selector],
        outputs=[name_list, custom_input, narrative_output]
    )
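    # Also refresh the script list when the split changes, so the dropdown stays
    # in sync with the selected train/validation/test split.
    data_type_selector.change(
        fn=update_name_list_interface,
        inputs=[dataset_selector, data_type_selector],
        outputs=[name_list, custom_input, narrative_output]
    )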
    name_list.change(
        fn=get_narrative_content_with_index,
        inputs=[dataset_selector, data_type_selector, name_list],
        outputs=[narrative_output, work_index, data_state]
    )
    generate_kg_button.click(
        fn=build_kg_with_data,
        inputs=[
            data_state,     # data
            work_index,     # work_index
            custom_input,   # custom_scenes
            api_key_input,  # api_key
            model_selector  # model_id
        ],
        outputs=[kg_output, kg_status]
    )
    open_kg_button.click(
        fn=open_kg,
        inputs=[kg_output],
        outputs=kg_viewer
    )
    generate_summary_button.click(
        fn=generate_summary_with_data,
        inputs=[data_state, work_index, custom_input, api_key_input, model_selector],
        outputs=[summary_state, summary_output]
    )
    def calculate_score_and_format(summary, kg, api_key, model):
        # Score once, then reuse the full result for both the stored state
        # (needed by refinement) and the formatted display.
        total_output, _ = calculate_narrative_fact_score(summary, kg, api_key, model)
        return total_output, format_fact_score_output(total_output)

    calculate_score_button.click(
        fn=calculate_score_and_format,
        inputs=[summary_state, kg_output, api_key_input, model_selector],
        outputs=[fact_score_state, fact_score_display]
    )
    refine_button.click(
        fn=refine_summary_and_return_diff,
        inputs=[summary_state, fact_score_state, api_key_input, model_selector],
        outputs=refined_output
    )
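
# Run locally with `python app.py`; requires the repo's src/ modules plus
# gradio, torch, numpy, and tqdm.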
if __name__ == "__main__":
    demo.launch()