|
|
|
""" |
|
Gradio app for A/B testing different translation configs |
|
""" |
|
|
|
import difflib
import html
import json
import random
import re
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple, Optional

import boto3
import gradio as gr
from datatrove.pipeline.readers.jsonl import JsonlReader
|
|
|
|
|
# Where the experiment outputs live: every listing and load below is rooted
# at s3://BUCKET_NAME/BASE_PREFIX.
BUCKET_NAME = "fineweb-multilingual-v1"
BASE_PREFIX = "experiments/translations/vibe-checks-exps/"

# Single S3 client, created once at import time and reused by the listing
# helpers in this module.
s3_client = boto3.client("s3")
|
|
|
|
|
# Module-level session state, read and mutated by the Gradio callbacks below.
app_state = {
    # (left_doc, right_doc, metadata) tuples for the loaded experiment/language.
    "current_samples": [],
    # Position of the sample currently on screen within ``current_samples``.
    "current_index": 0,
    # Vote tallies; ``config_a`` / ``config_b`` refer to ``config_names[0]`` / ``[1]``.
    "results": {
        "config_a": 0,
        "config_b": 0,
        "cant_tell": 0,
    },
    # Names of the two configs being compared.
    "config_names": [],
    # Currently loaded experiment and language (empty until data is loaded).
    "experiment": "",
    "language": "",
    # Number of paired samples in the loaded session.
    "total_samples": 0,
    # Whether differing translations are rendered as a highlighted diff.
    "diff_mode": True,
    # Whether the config name behind each side is revealed.
    "show_model_names": False,
}
|
|
|
def list_experiments_from_s3() -> List[str]:
    """List the experiment folder names found directly under ``BASE_PREFIX``.

    The listing goes through the ``list_objects_v2`` paginator because a
    single response is capped by S3 (1,000 entries), so a plain call would
    silently drop experiments beyond the first page.

    Returns:
        The experiment names, sorted. An empty list is returned (and the
        error printed) if the listing fails for any reason.
    """
    try:
        paginator = s3_client.get_paginator('list_objects_v2')
        pages = paginator.paginate(
            Bucket=BUCKET_NAME,
            Prefix=BASE_PREFIX,
            Delimiter='/',
        )

        experiments = []
        for page in pages:
            for entry in page.get('CommonPrefixes', []):
                # Strip only the leading prefix; str.replace would also remove
                # any later occurrence of the same text inside the name.
                experiments.append(entry['Prefix'][len(BASE_PREFIX):].rstrip('/'))

        return sorted(experiments)
    except Exception as e:
        # Best effort: the UI simply shows an empty dropdown on failure.
        print(f"Error listing experiments: {e}")
        return []
|
|
|
def list_languages_from_s3(experiment: str) -> List[str]:
    """List the language folder names available for one experiment.

    Args:
        experiment: Experiment folder name under ``BASE_PREFIX``. A falsy
            value short-circuits to an empty list without touching S3.

    Returns:
        The language names, sorted. An empty list is returned (and the error
        printed) if the listing fails for any reason.
    """
    if not experiment:
        return []

    try:
        experiment_prefix = f"{BASE_PREFIX}{experiment}/"
        # Paginate: one list_objects_v2 response is capped by S3, so a single
        # call could silently miss folders.
        paginator = s3_client.get_paginator('list_objects_v2')
        pages = paginator.paginate(
            Bucket=BUCKET_NAME,
            Prefix=experiment_prefix,
            Delimiter='/',
        )

        languages = []
        for page in pages:
            for entry in page.get('CommonPrefixes', []):
                # Remove the leading prefix only (not every occurrence).
                languages.append(entry['Prefix'][len(experiment_prefix):].rstrip('/'))

        return sorted(languages)
    except Exception as e:
        # Best effort: the UI simply shows an empty dropdown on failure.
        print(f"Error listing languages for experiment {experiment}: {e}")
        return []
|
|
|
def list_configs_from_s3(experiment: str, language: str) -> List[str]:
    """List the config names available for an experiment/language pair.

    A config is any object under ``BASE_PREFIX/experiment/language/`` whose
    key ends in ``.jsonl.gz``; its name is the key with that prefix and
    suffix removed.

    Args:
        experiment: Experiment folder name.
        language: Language folder name inside the experiment.

    Returns:
        The config names, sorted. An empty list is returned when either
        argument is falsy, or (with the error printed) when the listing fails.
    """
    if not experiment or not language:
        return []

    suffix = '.jsonl.gz'
    try:
        config_prefix = f"{BASE_PREFIX}{experiment}/{language}/"
        # Paginate: one list_objects_v2 response is capped by S3, so a single
        # call could silently miss config files.
        paginator = s3_client.get_paginator('list_objects_v2')
        pages = paginator.paginate(Bucket=BUCKET_NAME, Prefix=config_prefix)

        configs = []
        for page in pages:
            for obj in page.get('Contents', []):
                key = obj['Key']
                if key.endswith(suffix):
                    # Slice off the leading prefix and trailing suffix only;
                    # str.replace would also strip matches in the middle.
                    configs.append(key[len(config_prefix):-len(suffix)])

        return sorted(configs)
    except Exception as e:
        # Best effort: the caller treats an empty list as "nothing to compare".
        print(f"Error listing configs for {experiment}/{language}: {e}")
        return []
|
|
|
def extract_translation(inference_result: dict) -> str:
    """Extract the translation from an inference result, dropping wrapper tags.

    The text between an opening ``<TRANSLATION>`` / ``<START_TRANSLATION>``
    tag and the first recognised closing tag is returned. If the output was
    cut off before a closing tag, everything after the opening tag is
    returned so the tag itself never leaks into the display. Untagged text
    is returned as is (stripped).

    Args:
        inference_result: Mapping expected to hold the model output under
            the ``'text'`` key.

    Returns:
        The stripped translation, or ``"No translation available"`` when the
        result is empty, has no ``'text'`` key, or its text is not a string.
    """
    if not inference_result or 'text' not in inference_result:
        return "No translation available"

    raw_text = inference_result['text']
    if not isinstance(raw_text, str):
        # e.g. ``None`` for a failed generation; avoid crashing on .strip().
        return "No translation available"

    text = raw_text.strip()

    # Opening tag, lazily captured body, then any of the accepted closing forms.
    pattern = r'<(?:START_)?TRANSLATION>(.*?)(?:</(?:END_)?TRANSLATION>|<(?:END_)?TRANSLATION>|</START_TRANSLATION>)'
    match = re.search(pattern, text, re.DOTALL)
    if match:
        return match.group(1).strip()

    # Truncated output: opening tag present but never closed.
    opening = re.search(r'<(?:START_)?TRANSLATION>', text)
    if opening:
        return text[opening.end():].strip()

    return text
|
|
|
def generate_diff_html(text1: str, text2: str) -> Tuple[str, str]:
    """Render two texts as a side-by-side, word-level HTML diff.

    Both texts are split on whitespace and aligned with
    ``difflib.SequenceMatcher``. Words present only in ``text1`` are wrapped
    in a red ``<span>``, words present only in ``text2`` in a green one, and
    shared words are emitted unstyled. Every word is HTML-escaped, so markup
    inside a translation is shown literally instead of being interpreted.

    Args:
        text1: Text shown on the left.
        text2: Text shown on the right.

    Returns:
        ``(html1, html2)``: one ``<div>`` per side. When the two texts are
        identical there is nothing to highlight and the inputs are returned
        unchanged.
    """
    if text1 == text2:
        return text1, text2

    removed_style = 'background-color: #ffebee; color: #c62828;'
    added_style = 'background-color: #e8f5e8; color: #2e7d32;'

    def highlight(words, style):
        return [f'<span style="{style}">{html.escape(word)}</span>' for word in words]

    words1 = text1.split()
    words2 = text2.split()
    matcher = difflib.SequenceMatcher(None, words1, words2)

    result1 = []
    result2 = []
    for tag, i1, i2, j1, j2 in matcher.get_opcodes():
        if tag == 'equal':
            shared = [html.escape(word) for word in words1[i1:i2]]
            result1.extend(shared)
            result2.extend(shared)
        else:
            # 'delete', 'insert' and 'replace' all reduce to: mark the left
            # slice as removed and the right slice as added. For a pure
            # delete/insert the other side's slice is empty.
            result1.extend(highlight(words1[i1:i2], removed_style))
            result2.extend(highlight(words2[j1:j2], added_style))

    # A side with no words at all (empty / whitespace-only) falls back to its
    # escaped raw text.
    body1 = " ".join(result1) if result1 else html.escape(text1)
    body2 = " ".join(result2) if result2 else html.escape(text2)

    html1 = f'<div style="font-family: monospace; white-space: pre-wrap; line-height: 1.5;">{body1}</div>'
    html2 = f'<div style="font-family: monospace; white-space: pre-wrap; line-height: 1.5;">{body2}</div>'

    return html1, html2
|
|
|
def load_config_data(experiment: str, language: str, config_name: str) -> List[dict]:
    """Read every document of one config file from S3.

    The file read is
    ``s3://BUCKET_NAME/BASE_PREFIX<experiment>/<language>/<config_name>.jsonl.gz``,
    consumed through datatrove's ``JsonlReader``.

    Returns:
        All documents yielded by the reader, in order. On any error the
        problem is printed and an empty list is returned instead.
    """
    try:
        location = f"s3://{BUCKET_NAME}/{BASE_PREFIX}{experiment}/{language}/{config_name}.jsonl.gz"
        print(f"Loading data from: {location}")

        # Calling the reader yields documents; materialise them all at once.
        read_documents = JsonlReader(location)
        return list(read_documents())
    except Exception as e:
        print(f"Error loading {config_name} data for {experiment}/{language}: {e}")
        return []
|
|
|
def prepare_ab_test_data(experiment: str, language: str) -> Tuple[List[Tuple[dict, dict, dict]], List[str]]:
    """Build the shuffled list of left/right document pairs for a session.

    Only the first two configs (in sorted order) of the experiment/language
    are compared. Documents are matched on their ``id``; for each match the
    side on which each config appears is chosen at random.

    Returns:
        ``(pairs, [config_a_name, config_b_name])`` where each pair is
        ``(left_doc, right_doc, metadata)`` and ``metadata`` records
        ``left_config``, ``right_config`` and ``original_text`` (the ``text``
        of config A's document). ``([], [])`` when fewer than two configs
        are available.
    """
    available = list_configs_from_s3(experiment, language)
    if len(available) < 2:
        print(f"Need at least 2 configs for A/B testing, found: {available}")
        return [], []

    config_a_name, config_b_name = available[0], available[1]

    docs_a = load_config_data(experiment, language, config_a_name)
    docs_b = load_config_data(experiment, language, config_b_name)
    print(f"Loaded {len(docs_a)} samples for {config_a_name}")
    print(f"Loaded {len(docs_b)} samples for {config_b_name}")

    # Index both sides by document id so they can be paired up.
    a_by_id = {doc.id: doc for doc in docs_a}
    b_by_id = {doc.id: doc for doc in docs_b}

    shared_ids = set(a_by_id.keys()) & set(b_by_id.keys())
    print(f"Found {len(shared_ids)} common document IDs")

    pairs = []
    for shared_id in shared_ids:
        doc_a = a_by_id[shared_id]
        doc_b = b_by_id[shared_id]

        # Coin flip decides which config is displayed on the left.
        a_goes_left = random.random() < 0.5
        if a_goes_left:
            sides = (doc_a, doc_b)
            labels = (config_a_name, config_b_name)
        else:
            sides = (doc_b, doc_a)
            labels = (config_b_name, config_a_name)

        info = {
            'left_config': labels[0],
            'right_config': labels[1],
            'original_text': doc_a.text,
        }
        pairs.append((sides[0], sides[1], info))

    # Randomise presentation order as well.
    random.shuffle(pairs)

    return pairs, [config_a_name, config_b_name]
|
|
|
def update_languages_dropdown(experiment: str):
    """Refresh the language dropdown after the experiment selection changes.

    Returns a ``gr.update`` whose choices are the experiment's languages and
    whose value is the first of them (``None`` when there are none, or when
    no experiment is selected).
    """
    available = list_languages_from_s3(experiment) if experiment else []
    preselected = available[0] if available else None
    return gr.update(choices=available, value=preselected)
|
|
|
def toggle_diff_mode(diff_enabled: bool):
    """Store the diff-view checkbox state and re-render the current sample."""
    app_state.update(diff_mode=diff_enabled)
    return show_current_sample()
|
|
|
def toggle_model_names():
    """Flip model-name visibility and re-render the current sample."""
    currently_shown = app_state['show_model_names']
    app_state['show_model_names'] = not currently_shown
    return show_current_sample()
|
|
|
def _empty_view(message: str):
    """Build the 11-value "nothing to show" response carrying ``message``."""
    return (
        message,                                # original_text
        "",                                     # left_translation
        "",                                     # right_translation
        "",                                     # status_text
        "0 / 0",                                # progress_text
        gr.update(visible=False),               # left_btn
        gr.update(visible=False),               # right_btn
        gr.update(visible=False),               # cant_tell_btn
        gr.update(visible=False),               # stop_btn
        gr.update(value="", visible=False),     # left_model_name
        gr.update(value="", visible=False),     # right_model_name
    )


def load_language_data(experiment: str, language: str):
    """Load the paired samples for a selection and show the first one.

    On success the session state in ``app_state`` is reset (samples, index,
    vote tallies, config names, experiment, language, total) and the first
    sample is rendered via ``show_current_sample``.

    Every return path yields 11 values, matching the 11 output components
    this handler is wired to. The early-exit paths previously returned only
    9, which does not match the number of outputs.

    Args:
        experiment: Selected experiment name.
        language: Selected language name.

    Returns:
        The 11-tuple of component values/updates for the main view.
    """
    if not experiment:
        return _empty_view("Please select an experiment")

    if not language:
        return _empty_view("Please select a language")

    print(f"Loading data for experiment: {experiment}, language: {language}")

    samples, config_names = prepare_ab_test_data(experiment, language)

    if not samples:
        return _empty_view(f"No data found for experiment '{experiment}' and language '{language}'")

    # Start a fresh session for the newly loaded data.
    app_state['current_samples'] = samples
    app_state['current_index'] = 0
    app_state['results'] = {'config_a': 0, 'config_b': 0, 'cant_tell': 0}
    app_state['config_names'] = config_names
    app_state['experiment'] = experiment
    app_state['language'] = language
    app_state['total_samples'] = len(samples)

    return show_current_sample()
|
|
|
def _summarize_votes() -> str:
    """Build the final-results text from the tallies in ``app_state``."""
    results = app_state['results']
    total_votes = results['config_a'] + results['config_b'] + results['cant_tell']
    if total_votes == 0:
        return "No votes recorded."

    names = app_state.get('config_names') or []
    config_a_name = names[0] if names else 'Config A'
    config_b_name = names[1] if len(names) > 1 else 'Config B'

    config_a_pct = (results['config_a'] / total_votes) * 100
    config_b_pct = (results['config_b'] / total_votes) * 100
    cant_tell_pct = (results['cant_tell'] / total_votes) * 100

    experiment = app_state['experiment']
    language = app_state['language']

    return f"""
## Final Results for {experiment} - {language}

**{config_a_name}**: {results['config_a']} votes ({config_a_pct:.1f}%)
**{config_b_name}**: {results['config_b']} votes ({config_b_pct:.1f}%)
**Can't tell**: {results['cant_tell']} votes ({cant_tell_pct:.1f}%)

Total comparisons: {total_votes}
"""


def _as_plain_html(text: str) -> str:
    """Wrap plain text for an HTML component: escaped, line breaks preserved."""
    return (
        '<div style="font-family: monospace; white-space: pre-wrap; line-height: 1.5;">'
        f'{html.escape(text)}</div>'
    )


def show_current_sample():
    """Render the sample at ``app_state['current_index']``, or the final results.

    Returns an 11-tuple, in the order the event handlers list their outputs:
    original text, left HTML, right HTML, status line, progress, four button
    visibility updates (left, right, can't tell, stop) and the two
    model-name updates.

    When no samples are loaded or the index is past the end, the vote
    summary is shown instead and all voting controls are hidden.

    Translations are displayed in HTML components, so plain (non-diff) text
    is escaped and wrapped in a ``pre-wrap`` block; otherwise markup in a
    translation would be interpreted and its line breaks collapsed.
    """
    samples = app_state['current_samples']
    index = app_state['current_index']

    if not samples or index >= len(samples):
        return (
            _summarize_votes(),
            "Testing complete!",
            "Testing complete!",
            "Click 'Load Data' to start over",
            f"{index} / {app_state['total_samples']}",
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(value="", visible=False),
            gr.update(value="", visible=False),
        )

    left_doc, right_doc, metadata = samples[index]

    # Each document carries its model output(s) in metadata; the first
    # inference result is the one displayed.
    left_results = left_doc.metadata.get('inference_results', [])
    right_results = right_doc.metadata.get('inference_results', [])
    left_translation = extract_translation(left_results[0] if left_results else {})
    right_translation = extract_translation(right_results[0] if right_results else {})

    if app_state['diff_mode'] and left_translation != right_translation:
        left_display, right_display = generate_diff_html(left_translation, right_translation)
    else:
        left_display = _as_plain_html(left_translation)
        right_display = _as_plain_html(right_translation)

    progress = f"{index + 1} / {app_state['total_samples']}"
    status = (
        f"Experiment: {app_state['experiment']} | "
        f"Language: {app_state['language']} | Progress: {progress}"
    )

    # Model names are always filled in but only revealed on request.
    names_visible = app_state['show_model_names']
    left_model_content = gr.update(value=f"**Model:** {metadata['left_config']}", visible=names_visible)
    right_model_content = gr.update(value=f"**Model:** {metadata['right_config']}", visible=names_visible)

    return (
        metadata['original_text'],
        left_display,
        right_display,
        status,
        progress,
        gr.update(visible=True),
        gr.update(visible=True),
        gr.update(visible=True),
        gr.update(visible=True),
        left_model_content,
        right_model_content,
    )
|
|
|
def vote_left():
    """Count a vote for the translation shown on the left and advance.

    The left side is mapped back to its config through the sample metadata;
    the vote goes to ``config_a`` when that config is ``config_names[0]``,
    otherwise to ``config_b``. Model names are hidden again for the next
    sample. With no active sample nothing is recorded.
    """
    app_state['show_model_names'] = False

    samples = app_state['current_samples']
    position = app_state['current_index']
    if samples and position < len(samples):
        _, _, metadata = samples[position]
        chosen_config = metadata['left_config']

        names = app_state.get('config_names')
        first_config = names[0] if names else None

        tally_key = 'config_a' if chosen_config == first_config else 'config_b'
        app_state['results'][tally_key] += 1
        app_state['current_index'] += 1

    return show_current_sample()
|
|
|
def vote_right():
    """Count a vote for the translation shown on the right and advance.

    The right side is mapped back to its config through the sample metadata;
    the vote goes to ``config_a`` when that config is ``config_names[0]``,
    otherwise to ``config_b``. Model names are hidden again for the next
    sample. With no active sample nothing is recorded.
    """
    app_state['show_model_names'] = False

    samples = app_state['current_samples']
    position = app_state['current_index']
    if samples and position < len(samples):
        _, _, metadata = samples[position]
        chosen_config = metadata['right_config']

        names = app_state.get('config_names')
        first_config = names[0] if names else None

        tally_key = 'config_a' if chosen_config == first_config else 'config_b'
        app_state['results'][tally_key] += 1
        app_state['current_index'] += 1

    return show_current_sample()
|
|
|
def vote_cant_tell():
    """Count a "can't tell" vote for the current sample and advance.

    Model names are hidden again for the next sample. With no active sample
    nothing is recorded.
    """
    app_state['show_model_names'] = False

    samples = app_state['current_samples']
    has_active_sample = bool(samples) and app_state['current_index'] < len(samples)
    if has_active_sample:
        app_state['results']['cant_tell'] += 1
        app_state['current_index'] += 1

    return show_current_sample()
|
|
|
def stop_session():
    """End the session early and display the results gathered so far.

    Jumps the current index past the last sample, so the next render takes
    the "testing complete" path. Model names are hidden.
    """
    app_state['show_model_names'] = False

    loaded = app_state['current_samples']
    if loaded:
        app_state['current_index'] = len(loaded)

    return show_current_sample()
|
|
|
|
|
def create_interface():
    """Build the Gradio ``Blocks`` UI and wire its event handlers.

    The experiment list (and the languages of the first experiment) are
    fetched from S3 up front to populate the dropdowns. Every handler that
    re-renders the main view writes to the same 11 output components.

    Returns:
        The constructed ``gr.Blocks`` app (not yet launched).
    """
    experiments = list_experiments_from_s3()
    first_experiment = experiments[0] if experiments else None

    # Pre-populate languages for the experiment selected by default.
    initial_languages = list_languages_from_s3(experiments[0]) if experiments else []
    first_language = initial_languages[0] if initial_languages else None

    with gr.Blocks(title="Translation A/B Testing", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# Translation Model A/B Testing")
        gr.Markdown("Compare translations from different model configurations. Choose the better translation for each sample.")

        # --- Data selection -------------------------------------------------
        with gr.Row():
            experiment_dropdown = gr.Dropdown(
                choices=experiments,
                label="Select Experiment",
                value=first_experiment,
            )
            language_dropdown = gr.Dropdown(
                choices=initial_languages,
                label="Select Language",
                value=first_language,
            )
            load_btn = gr.Button("Load Data", variant="primary")

        with gr.Row():
            diff_checkbox = gr.Checkbox(
                label="Enable Diff View",
                value=True,
                info="Show differences between translations like GitHub",
            )

        status_text = gr.Markdown("")

        with gr.Row():
            progress_text = gr.Markdown("")
            stop_btn = gr.Button("Stop Session", variant="secondary", visible=False, size="sm")

        # --- Source text ----------------------------------------------------
        gr.Markdown("## Original Text")
        original_text = gr.Textbox(label="Text to Translate", lines=3, interactive=False)

        # --- Voting controls (hidden until data is loaded) -------------------
        gr.Markdown("## Choose the Better Translation")

        with gr.Row():
            left_btn = gr.Button("Choose Left", variant="secondary", visible=False)
            cant_tell_btn = gr.Button("Can't Tell", variant="secondary", visible=False)
            right_btn = gr.Button("Choose Right", variant="secondary", visible=False)

        gr.Markdown("*Choose 'Can't Tell' if both translations are equally good/bad or you can't decide*")

        # --- The two candidate translations ---------------------------------
        with gr.Row():
            with gr.Column():
                left_translation = gr.HTML(label="Translation A")
                left_model_name = gr.Markdown("", visible=False)

            with gr.Column():
                right_translation = gr.HTML(label="Translation B")
                right_model_name = gr.Markdown("", visible=False)

        with gr.Row():
            show_models_btn = gr.Button("Show Model Names", variant="secondary", size="sm")

        # --- Event wiring ---------------------------------------------------
        # Components refreshed by every handler that re-renders the main view;
        # the order must match the tuples those handlers return.
        view_outputs = [
            original_text,
            left_translation,
            right_translation,
            status_text,
            progress_text,
            left_btn,
            right_btn,
            cant_tell_btn,
            stop_btn,
            left_model_name,
            right_model_name,
        ]

        experiment_dropdown.change(
            fn=update_languages_dropdown,
            inputs=[experiment_dropdown],
            outputs=[language_dropdown],
        )

        load_btn.click(
            fn=load_language_data,
            inputs=[experiment_dropdown, language_dropdown],
            outputs=view_outputs,
        )

        # Input-less buttons that record a vote or stop, then re-render.
        button_handlers = (
            (left_btn, vote_left),
            (right_btn, vote_right),
            (cant_tell_btn, vote_cant_tell),
            (stop_btn, stop_session),
        )
        for button, handler in button_handlers:
            button.click(fn=handler, inputs=[], outputs=view_outputs)

        diff_checkbox.change(
            fn=toggle_diff_mode,
            inputs=[diff_checkbox],
            outputs=view_outputs,
        )

        show_models_btn.click(
            fn=toggle_model_names,
            inputs=[],
            outputs=view_outputs,
        )

    return demo
|
|
|
if __name__ == "__main__":
    # Script entry point: build the UI and serve it on all interfaces, port 7860.
    demo = create_interface()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
    )
|
|