guipenedo's picture
guipenedo HF Staff
diff & reveal
25c7dfb
#!/usr/bin/env python3
"""
Gradio app for A/B testing different translation configs
"""
import difflib
import html
import json
import random
import re
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple, Optional

import boto3
import gradio as gr
from datatrove.pipeline.readers.jsonl import JsonlReader
# S3 client plus the bucket / prefix under which experiment outputs live.
s3_client = boto3.client('s3')
BUCKET_NAME = "fineweb-multilingual-v1"
BASE_PREFIX = "experiments/translations/vibe-checks-exps/"

# Mutable state shared by all UI handlers.
# NOTE(review): this is a module-level dict, so presumably every concurrent
# user of one server process sees (and mutates) the same session.
app_state = dict(
    current_samples=[],          # list of (left_doc, right_doc, metadata) pairs
    current_index=0,             # position of the pair currently displayed
    results=dict(config_a=0, config_b=0, cant_tell=0),  # vote tallies
    config_names=[],             # [config_a_name, config_b_name]
    experiment='',
    language='',
    total_samples=0,
    diff_mode=True,              # diff view is on by default
    show_model_names=False,      # model labels hidden until explicitly revealed
)
def list_experiments_from_s3() -> List[str]:
    """List the experiment folders directly under ``BASE_PREFIX`` in the bucket.

    Results are read through a ``list_objects_v2`` paginator. A single
    ``list_objects_v2`` call returns at most one page (up to 1000 entries),
    so without pagination large listings would be silently truncated.

    Returns:
        Sorted experiment folder names, or ``[]`` if listing fails (the error
        is printed).
    """
    try:
        paginator = s3_client.get_paginator('list_objects_v2')
        experiments = []
        for page in paginator.paginate(Bucket=BUCKET_NAME, Prefix=BASE_PREFIX, Delimiter='/'):
            for prefix in page.get('CommonPrefixes', []):
                # Returned prefixes always start with BASE_PREFIX; slice it off rather
                # than str.replace, which would also strip later repeats of that text.
                experiments.append(prefix['Prefix'][len(BASE_PREFIX):].rstrip('/'))
        return sorted(experiments)
    except Exception as e:
        print(f"Error listing experiments: {e}")
        return []
def list_languages_from_s3(experiment: str) -> List[str]:
    """List the language folders under one experiment in the bucket.

    Results are read through a ``list_objects_v2`` paginator, so experiments
    with more folders than fit in a single response page are listed in full.

    Args:
        experiment: Experiment folder name; empty means "nothing selected".

    Returns:
        Sorted language folder names; ``[]`` if ``experiment`` is empty or the
        listing fails (the error is printed).
    """
    if not experiment:
        return []
    try:
        experiment_prefix = f"{BASE_PREFIX}{experiment}/"
        paginator = s3_client.get_paginator('list_objects_v2')
        languages = []
        for page in paginator.paginate(Bucket=BUCKET_NAME, Prefix=experiment_prefix, Delimiter='/'):
            for prefix in page.get('CommonPrefixes', []):
                # Returned prefixes start with experiment_prefix; slice it off rather
                # than str.replace, which would also strip later repeats of that text.
                languages.append(prefix['Prefix'][len(experiment_prefix):].rstrip('/'))
        return sorted(languages)
    except Exception as e:
        print(f"Error listing languages for experiment {experiment}: {e}")
        return []
def list_configs_from_s3(experiment: str, language: str) -> List[str]:
    """List the config names (``*.jsonl.gz`` files) for an experiment/language.

    No delimiter is used, so files in nested folders are included; their
    names keep the sub-path (e.g. ``sub/name``). Results are read through a
    paginator so that listings larger than one response page are complete.

    Args:
        experiment: Experiment folder name; empty means "nothing selected".
        language: Language folder name; empty means "nothing selected".

    Returns:
        Sorted config names: the object key relative to the language folder,
        without the ``.jsonl.gz`` suffix. Returns ``[]`` if either argument is
        empty or the listing fails (the error is printed).
    """
    if not experiment or not language:
        return []
    suffix = '.jsonl.gz'
    try:
        config_prefix = f"{BASE_PREFIX}{experiment}/{language}/"
        paginator = s3_client.get_paginator('list_objects_v2')
        configs = []
        for page in paginator.paginate(Bucket=BUCKET_NAME, Prefix=config_prefix):
            for obj in page.get('Contents', []):
                key = obj['Key']
                if key.endswith(suffix):
                    # Strip only the leading prefix and the trailing suffix, so that
                    # appending the suffix back (as load_config_data does) yields the key.
                    configs.append(key[len(config_prefix):-len(suffix)])
        return sorted(configs)
    except Exception as e:
        print(f"Error listing configs for {experiment}/{language}: {e}")
        return []
def extract_translation(inference_result: dict) -> str:
    """Extract the translation text from one inference result.

    The result's ``'text'`` is stripped. If it contains an opening tag
    (``<START_TRANSLATION>`` or ``<TRANSLATION>``) followed by a closing tag
    (``</END_TRANSLATION>``, ``</TRANSLATION>``, ``<END_TRANSLATION>``,
    ``<TRANSLATION>`` or ``</START_TRANSLATION>``), the content between them
    is returned. If only an opening tag is present, the content after it is
    returned. Otherwise the whole stripped text is returned.

    Args:
        inference_result: Mapping expected to hold the generated text under
            ``'text'``.

    Returns:
        The extracted translation. Returns ``"No translation available"`` when
        the result is empty, has no ``'text'`` key, or its text is ``None``.
    """
    if not inference_result or 'text' not in inference_result:
        return "No translation available"
    raw_text = inference_result['text']
    if raw_text is None:
        return "No translation available"
    text = raw_text.strip()

    # Extract content between START_TRANSLATION and END_TRANSLATION tags.
    # Supports several closing tag formats: </END_TRANSLATION>, <END_TRANSLATION>, </START_TRANSLATION>
    pattern = r'<(?:START_)?TRANSLATION>(.*?)(?:</(?:END_)?TRANSLATION>|<(?:END_)?TRANSLATION>|</START_TRANSLATION>)'
    match = re.search(pattern, text, re.DOTALL)
    if match:
        return match.group(1).strip()

    # Opening tag without any closing tag (presumably a truncated generation):
    # return what follows the tag instead of showing the raw tag to the rater.
    opening = re.search(r'<(?:START_)?TRANSLATION>', text)
    if opening:
        return text[opening.end():].strip()
    return text
def generate_diff_html(text1: str, text2: str) -> Tuple[str, str]:
    """Render a word-level, side-by-side diff of two texts as HTML.

    Words present only in ``text1`` are highlighted red in the first output;
    words present only in ``text2`` are highlighted green in the second.

    Every word is HTML-escaped before being embedded, so markup-like content
    in a translation (``<``, ``&``, tags) is displayed literally rather than
    interpreted. The whitespace that followed each word in its source text,
    including newlines, is kept. The wrapper uses ``white-space: pre-wrap``,
    so line breaks survive rendering.

    Args:
        text1: Text shown on the left.
        text2: Text shown on the right.

    Returns:
        ``(html1, html2)``: one ``<div>`` fragment per input text.
    """
    # (word, following whitespace) pairs. The words are exactly the tokens
    # str.split() would produce, so the diff runs on the plain word sequence.
    tokens1 = re.findall(r'(\S+)(\s*)', text1)
    tokens2 = re.findall(r'(\S+)(\s*)', text2)
    words1 = [word for word, _ in tokens1]
    words2 = [word for word, _ in tokens2]

    removed_style = 'background-color: #ffebee; color: #c62828;'
    added_style = 'background-color: #e8f5e8; color: #2e7d32;'

    def render(tokens, start, end, style=None):
        # Escape each word, optionally wrap it in a highlight span, then re-attach
        # its original trailing whitespace (left unhighlighted).
        parts = []
        for word, space in tokens[start:end]:
            piece = html.escape(word)
            if style:
                piece = f'<span style="{style}">{piece}</span>'
            parts.append(piece + space)
        return parts

    out1 = []
    out2 = []
    matcher = difflib.SequenceMatcher(None, words1, words2)
    for tag, i1, i2, j1, j2 in matcher.get_opcodes():
        if tag == 'equal':
            out1.extend(render(tokens1, i1, i2))
            out2.extend(render(tokens2, j1, j2))
        else:
            # 'delete' has an empty j-range and 'insert' an empty i-range, so this
            # one branch handles delete, insert and replace.
            out1.extend(render(tokens1, i1, i2, removed_style))
            out2.extend(render(tokens2, j1, j2, added_style))

    wrapper = '<div style="font-family: monospace; white-space: pre-wrap; line-height: 1.5;">{}</div>'
    return wrapper.format(''.join(out1)), wrapper.format(''.join(out2))
def load_config_data(experiment: str, language: str, config_name: str) -> List[dict]:
    """Read every document of one config's ``.jsonl.gz`` file from S3.

    NOTE(review): despite the ``List[dict]`` annotation, callers in this file
    use ``.id``, ``.text`` and ``.metadata`` attributes on the returned items.

    Returns:
        The documents in file order, or ``[]`` if reading fails at any point
        (the error is printed).
    """
    try:
        s3_path = f"s3://{BUCKET_NAME}/{BASE_PREFIX}{experiment}/{language}/{config_name}.jsonl.gz"
        print(f"Loading data from: {s3_path}")
        # Calling the datatrove reader yields the documents from the file.
        return list(JsonlReader(s3_path)())
    except Exception as e:
        print(f"Error loading {config_name} data for {experiment}/{language}: {e}")
        return []
def prepare_ab_test_data(experiment: str, language: str) -> Tuple[List[Tuple[dict, dict, dict]], List[str]]:
    """Pair documents from the first two configs for a blind A/B comparison.

    The first two config names (in sorted listing order) are compared. Any
    others are ignored. Only document ids present in both configs are paired.
    For each pair a coin flip decides which config appears on the left, and
    the final list of pairs is shuffled.

    Returns:
        ``(pairs, [config_a_name, config_b_name])``. Each pair is
        ``(left_doc, right_doc, metadata)``, where metadata holds
        ``left_config``, ``right_config`` and ``original_text``.
        Returns ``([], [])`` if fewer than two configs exist.
    """
    available = list_configs_from_s3(experiment, language)
    if len(available) < 2:
        print(f"Need at least 2 configs for A/B testing, found: {available}")
        return [], []
    name_a, name_b = available[0], available[1]

    docs_a = load_config_data(experiment, language, name_a)
    docs_b = load_config_data(experiment, language, name_b)
    print(f"Loaded {len(docs_a)} samples for {name_a}")
    print(f"Loaded {len(docs_b)} samples for {name_b}")

    # Index each config's documents by id; later duplicates overwrite earlier ones.
    index_a = {doc.id: doc for doc in docs_a}
    index_b = {doc.id: doc for doc in docs_b}
    shared_ids = set(index_a.keys()) & set(index_b.keys())
    print(f"Found {len(shared_ids)} common document IDs")

    pairs = []
    for doc_id in shared_ids:
        doc_a = index_a[doc_id]
        sides = [(doc_a, name_a), (index_b[doc_id], name_b)]
        # Coin flip per pair so neither config is always shown on the same side.
        if random.random() >= 0.5:
            sides.reverse()
        (left_doc, left_name), (right_doc, right_name) = sides
        pairs.append((left_doc, right_doc, {
            'left_config': left_name,
            'right_config': right_name,
            # Source text taken from config A; assumed identical in config B.
            'original_text': doc_a.text,
        }))

    random.shuffle(pairs)
    return pairs, [name_a, name_b]
def update_languages_dropdown(experiment: str):
    """Refresh the language dropdown after the experiment selection changes.

    Returns a ``gr.update`` with the experiment's languages as choices and the
    first one preselected. When no experiment is selected, choices are empty
    and nothing is preselected.
    """
    languages = list_languages_from_s3(experiment) if experiment else []
    first_language = languages[0] if languages else None
    return gr.update(choices=languages, value=first_language)
def toggle_diff_mode(diff_enabled: bool):
    """Store the diff-view checkbox state and re-render the current sample."""
    app_state.update(diff_mode=diff_enabled)
    return show_current_sample()
def toggle_model_names():
    """Flip whether model names are shown, then re-render the current sample."""
    currently_shown = app_state['show_model_names']
    app_state['show_model_names'] = not currently_shown
    return show_current_sample()
def _load_error_outputs(message: str) -> tuple:
    """Build the 11 UI outputs for a load that produced nothing to show.

    ``message`` goes in the original-text box. The voting/stop buttons and the
    model-name labels are hidden. The tuple length matches the ``outputs``
    list wired to ``load_btn`` in ``create_interface``.
    """
    return (
        message,
        "",
        "",
        "",
        "0 / 0",
        gr.update(visible=False),  # left_btn
        gr.update(visible=False),  # right_btn
        gr.update(visible=False),  # cant_tell_btn
        gr.update(visible=False),  # stop_btn
        gr.update(value="", visible=False),  # left_model_name
        gr.update(value="", visible=False),  # right_model_name
    )


def load_language_data(experiment: str, language: str):
    """Load the A/B pairs for an experiment/language and show the first one.

    On success, resets ``app_state`` for a new session (samples, index,
    tallies, config names, selection) and returns ``show_current_sample()``.
    Model names are hidden again so the new session starts blind. If nothing
    is selected or no pairs are found, returns an error-message output tuple
    with the same shape.
    """
    if not experiment:
        return _load_error_outputs("Please select an experiment")
    if not language:
        return _load_error_outputs("Please select a language")
    print(f"Loading data for experiment: {experiment}, language: {language}")

    # Prepare A/B test data
    samples, config_names = prepare_ab_test_data(experiment, language)
    if not samples:
        return _load_error_outputs(f"No data found for experiment '{experiment}' and language '{language}'")

    # Start a fresh session in the global state.
    app_state['current_samples'] = samples
    app_state['current_index'] = 0
    app_state['results'] = {'config_a': 0, 'config_b': 0, 'cant_tell': 0}
    app_state['config_names'] = config_names
    app_state['experiment'] = experiment
    app_state['language'] = language
    app_state['total_samples'] = len(samples)
    # A reveal from the previous session must not carry over to the new one.
    app_state['show_model_names'] = False

    # Show first sample
    return show_current_sample()
def _plain_text_html(text: str) -> str:
    """Wrap ``text`` for a gr.HTML component: HTML-escaped, line breaks preserved."""
    escaped = html.escape(text)
    return f'<div style="white-space: pre-wrap; line-height: 1.5;">{escaped}</div>'


def show_current_sample():
    """Render the current A/B pair, or the final tally once the session is over.

    Reads ``app_state`` but does not modify it. The returned 11-tuple lines up
    with the ``outputs`` list used by the handlers in ``create_interface``:
    original text, left display, right display, status line, progress,
    visibility updates for the left / right / can't-tell / stop buttons, and
    the left / right model-name updates.

    Translations are rendered as HTML. When diff mode is on and they differ,
    they are shown as a word diff. Otherwise they are shown as escaped plain
    text, so markup in model output appears literally.
    """
    if not app_state['current_samples'] or app_state['current_index'] >= len(app_state['current_samples']):
        # Show final results
        total_votes = app_state['results']['config_a'] + app_state['results']['config_b'] + app_state['results']['cant_tell']
        if total_votes == 0:
            results_text = "No votes recorded."
        else:
            config_a_pct = (app_state['results']['config_a'] / total_votes) * 100
            config_b_pct = (app_state['results']['config_b'] / total_votes) * 100
            cant_tell_pct = (app_state['results']['cant_tell'] / total_votes) * 100
            config_a_name = app_state['config_names'][0] if app_state.get('config_names') else 'Config A'
            config_b_name = app_state['config_names'][1] if app_state.get('config_names') and len(app_state['config_names']) > 1 else 'Config B'
            results_text = f"""
## Final Results for {app_state['experiment']} - {app_state['language']}
**{config_a_name}**: {app_state['results']['config_a']} votes ({config_a_pct:.1f}%)
**{config_b_name}**: {app_state['results']['config_b']} votes ({config_b_pct:.1f}%)
**Can't tell**: {app_state['results']['cant_tell']} votes ({cant_tell_pct:.1f}%)
Total comparisons: {total_votes}
"""
        return (
            results_text,
            "Testing complete!",
            "Testing complete!",
            "Click 'Load Data' to start over",
            f"{app_state['current_index']} / {app_state['total_samples']}",
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
            "",  # left_model_name
            ""   # right_model_name
        )

    left_doc, right_doc, metadata = app_state['current_samples'][app_state['current_index']]

    # Extract translations (first inference result of each document).
    left_results = left_doc.metadata.get('inference_results', [])
    right_results = right_doc.metadata.get('inference_results', [])
    left_translation = extract_translation(left_results[0] if left_results else {})
    right_translation = extract_translation(right_results[0] if right_results else {})

    if app_state['diff_mode'] and left_translation != right_translation:
        left_display, right_display = generate_diff_html(left_translation, right_translation)
    else:
        # These go into gr.HTML components: escape model output so it is shown
        # literally, and keep its line breaks.
        left_display = _plain_text_html(left_translation)
        right_display = _plain_text_html(right_translation)

    left_model = metadata['left_config']
    right_model = metadata['right_config']
    progress = f"{app_state['current_index'] + 1} / {app_state['total_samples']}"

    # Model labels always carry the name; visibility follows the reveal toggle.
    left_model_content = gr.update(value=f"**Model:** {left_model}", visible=app_state['show_model_names'])
    right_model_content = gr.update(value=f"**Model:** {right_model}", visible=app_state['show_model_names'])

    return (
        metadata['original_text'],
        left_display,
        right_display,
        f"Experiment: {app_state['experiment']} | Language: {app_state['language']} | Progress: {progress}",
        progress,
        gr.update(visible=True),
        gr.update(visible=True),
        gr.update(visible=True),
        gr.update(visible=True),
        left_model_content,  # left_model_name
        right_model_content  # right_model_name
    )
def vote_left():
    """Count a vote for the config shown on the left, then advance.

    Model names are hidden again before rendering. With no active sample
    (nothing loaded, or the session already finished), no vote is recorded
    and the current view is re-rendered.
    """
    app_state['show_model_names'] = False
    samples = app_state['current_samples']
    position = app_state['current_index']
    if samples and position < len(samples):
        _, _, meta = samples[position]
        names = app_state.get('config_names')
        # Anything other than the first config name is tallied as config B.
        is_config_a = meta['left_config'] == (names[0] if names else None)
        app_state['results']['config_a' if is_config_a else 'config_b'] += 1
        app_state['current_index'] += 1
    return show_current_sample()
def vote_right():
    """Count a vote for the config shown on the right, then advance.

    Model names are hidden again before rendering. With no active sample
    (nothing loaded, or the session already finished), no vote is recorded
    and the current view is re-rendered.
    """
    app_state['show_model_names'] = False
    samples = app_state['current_samples']
    position = app_state['current_index']
    if samples and position < len(samples):
        _, _, meta = samples[position]
        names = app_state.get('config_names')
        # Anything other than the first config name is tallied as config B.
        is_config_a = meta['right_config'] == (names[0] if names else None)
        app_state['results']['config_a' if is_config_a else 'config_b'] += 1
        app_state['current_index'] += 1
    return show_current_sample()
def vote_cant_tell():
    """Record a "can't tell" vote for the current pair, then advance.

    Model names are hidden again before rendering. With no active sample,
    nothing is recorded and the current view is re-rendered.
    """
    app_state['show_model_names'] = False
    samples = app_state['current_samples']
    has_active_sample = bool(samples) and app_state['current_index'] < len(samples)
    if has_active_sample:
        app_state['results']['cant_tell'] += 1
        app_state['current_index'] += 1
    return show_current_sample()
def stop_session():
    """End the session early and render the results gathered so far."""
    app_state['show_model_names'] = False
    loaded = app_state['current_samples']
    if loaded:
        # An index equal to the sample count makes show_current_sample render the tally.
        app_state['current_index'] = len(loaded)
    return show_current_sample()
# Create Gradio interface
def create_interface():
    """Build the A/B-testing Blocks app and attach its event handlers.

    The experiment dropdown is pre-filled from S3, defaulting to the first
    experiment. The language dropdown is pre-filled with that experiment's
    languages. Voting and stop buttons start hidden. Every handler except the
    experiment -> language refresh writes to the same eleven components,
    collected in ``sample_outputs``.

    Returns:
        The (not yet launched) ``gr.Blocks`` app.
    """
    experiments = list_experiments_from_s3()
    default_experiment = experiments[0] if experiments else None
    initial_languages = list_languages_from_s3(default_experiment) if experiments else []

    with gr.Blocks(title="Translation A/B Testing", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# Translation Model A/B Testing")
        gr.Markdown("Compare translations from different model configurations. Choose the better translation for each sample.")

        # Data selection.
        with gr.Row():
            experiment_dropdown = gr.Dropdown(
                choices=experiments,
                label="Select Experiment",
                value=default_experiment,
            )
            language_dropdown = gr.Dropdown(
                choices=initial_languages,
                label="Select Language",
                value=initial_languages[0] if initial_languages else None,
            )
            load_btn = gr.Button("Load Data", variant="primary")

        with gr.Row():
            diff_checkbox = gr.Checkbox(
                label="Enable Diff View",
                value=True,
                info="Show differences between translations like GitHub",
            )
        status_text = gr.Markdown("")
        with gr.Row():
            progress_text = gr.Markdown("")
            stop_btn = gr.Button("Stop Session", variant="secondary", visible=False, size="sm")

        gr.Markdown("## Original Text")
        original_text = gr.Textbox(label="Text to Translate", lines=3, interactive=False)

        gr.Markdown("## Choose the Better Translation")
        with gr.Row():
            left_btn = gr.Button("Choose Left", variant="secondary", visible=False)
            cant_tell_btn = gr.Button("Can't Tell", variant="secondary", visible=False)
            right_btn = gr.Button("Choose Right", variant="secondary", visible=False)
        gr.Markdown("*Choose 'Can't Tell' if both translations are equally good/bad or you can't decide*")

        # Side-by-side translations, each with an initially hidden model label.
        with gr.Row():
            with gr.Column():
                left_translation = gr.HTML(label="Translation A")
                left_model_name = gr.Markdown("", visible=False)
            with gr.Column():
                right_translation = gr.HTML(label="Translation B")
                right_model_name = gr.Markdown("", visible=False)

        with gr.Row():
            show_models_btn = gr.Button("Show Model Names", variant="secondary", size="sm")

        # Order must match the tuples returned by show_current_sample / load_language_data.
        sample_outputs = [
            original_text, left_translation, right_translation, status_text,
            progress_text, left_btn, right_btn, cant_tell_btn, stop_btn,
            left_model_name, right_model_name,
        ]

        experiment_dropdown.change(fn=update_languages_dropdown, inputs=[experiment_dropdown], outputs=[language_dropdown])
        load_btn.click(fn=load_language_data, inputs=[experiment_dropdown, language_dropdown], outputs=sample_outputs)
        argless_handlers = (
            (left_btn, vote_left),
            (right_btn, vote_right),
            (cant_tell_btn, vote_cant_tell),
            (stop_btn, stop_session),
        )
        for button, handler in argless_handlers:
            button.click(fn=handler, inputs=[], outputs=sample_outputs)
        diff_checkbox.change(fn=toggle_diff_mode, inputs=[diff_checkbox], outputs=sample_outputs)
        show_models_btn.click(fn=toggle_model_names, inputs=[], outputs=sample_outputs)

    return demo
if __name__ == "__main__":
    # Listen on all network interfaces, port 7860.
    demo = create_interface()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
    )