Spaces:
				
			
			
	
			
			
					
		Running
		
	
	
	
			
			
	
	
	
	
		
		
					
		Running
		
	| import colorsys | |
| import json | |
| import os | |
| import random | |
| from concurrent.futures import ThreadPoolExecutor | |
| from dataclasses import dataclass, make_dataclass | |
| from datetime import datetime | |
| from io import BytesIO | |
| import aiohttp | |
| import evaluate | |
| import numpy as np | |
| import pandas as pd | |
| import plotly.graph_objects as go | |
| from huggingface_hub import hf_hub_download, list_repo_files | |
| from pydub import AudioSegment | |
| from constants import WHISPER_OPEN_AI_LINK | |
# Load the Word Error Rate (WER) metric from the evaluate library once, at
# import time; compute_average_wer reuses this module-level object.
wer_metric = evaluate.load("wer")
def compute_average_wer(results):
    """
    Compute the average Word Error Rate (WER) for a list of transcription results.

    :param results: List of dictionaries, each containing 'reference' and 'prediction' keys
    :return: WER as a percentage (0-100 scale), rounded to 2 decimal places

    All reference-prediction pairs are handed to the module-level ``wer_metric``
    in a single ``compute`` call. If there are no predictions, the worst-case
    value of 100.0 (100% WER) is returned.
    """
    references = [result["reference"] for result in results]
    predictions = [result["prediction"] for result in results]
    if not predictions:
        # Stay on the same percentage scale as the normal return value:
        # returning 1 here would read as a near-perfect 1% WER.
        return 100.0
    score = wer_metric.compute(references=references, predictions=predictions)
    return round(score * 100.0, 2)
def read_json_line_by_line(file_path):
    """
    Parse a file that holds one JSON document per line.

    :param file_path: Path to the file to read
    :return: List with the parsed object of every line that is valid JSON

    Each line is stripped of surrounding whitespace before parsing. Lines that
    fail to parse (including blank ones) are reported on stdout and skipped,
    so a single bad line never aborts the whole read.
    """
    records = []
    with open(file_path, "r") as handle:
        for raw_line in handle:
            try:
                parsed = json.loads(raw_line.strip())
            except json.JSONDecodeError:
                print(f"Skipping invalid JSON in {file_path}: {raw_line}")
                continue
            records.append(parsed)
    return records
def group_wer(group):
    """
    Calculate the Word Error Rate (WER) for a group of transcriptions.

    :param group: DataFrame (or groupby group) with 'normalized_reference' and
        'normalized_prediction' columns
    :return: The value returned by compute_average_wer for the group's rows

    The two normalized columns are renamed to the 'reference' / 'prediction'
    keys that compute_average_wer expects and converted to a list of records.
    """
    column_aliases = {
        "normalized_reference": "reference",
        "normalized_prediction": "prediction",
    }
    pairs = group[list(column_aliases)].rename(columns=column_aliases)
    return compute_average_wer(pairs.to_dict("records"))
def load_multilingual_results(csv_file):
    """
    Flatten multilingual results into a pandas DataFrame.

    :param csv_file: Value passed unchanged to ``pd.json_normalize``
    :return: The resulting DataFrame, or None if a FileNotFoundError is raised

    NOTE(review): despite the parameter name, the input goes through
    ``pd.json_normalize`` rather than a CSV reader -- confirm that callers
    pass already-loaded records (dict or list of dicts).
    """
    try:
        return pd.json_normalize(csv_file)
    except FileNotFoundError:
        return None
def download_dataset(repo_id, local_dir, remote_dir, path_includes=""):
    """
    Download benchmark result files from a specified Hugging Face repository to a local directory.

    :param repo_id: ID of the Hugging Face repository
    :param local_dir: Local directory where downloaded files will be saved
    :param remote_dir: Remote directory within the repository to download from
    :param path_includes: Substring a file path must contain to be downloaded;
        the default empty string matches every file under ``remote_dir``
    :raises Exception: The first error raised by a download worker is re-raised

    The repository (treated as a dataset repo) is listed once, the matching
    files are downloaded concurrently on a thread pool, and every download is
    forced so local copies are refreshed.
    """
    files = list_repo_files(repo_id, repo_type="dataset")
    directory_files = [
        file for file in files if file.startswith(remote_dir) and path_includes in file
    ]

    def fetch(file):
        return hf_hub_download(
            repo_id=repo_id,
            repo_type="dataset",
            filename=file,
            local_dir=local_dir,
            force_download=True,
        )

    with ThreadPoolExecutor() as executor:
        # Executor.map only re-raises a worker's exception when its result is
        # retrieved; draining the iterator keeps failed downloads from being
        # silently ignored.
        for _ in executor.map(fetch, directory_files):
            pass
def process_file(file_path):
    """
    Parse a file of newline-delimited JSON objects.

    :param file_path: Path to the file to be processed
    :return: List with the parsed object of every valid, non-blank line

    Blank lines are ignored silently. A line that is not valid JSON is
    reported on stdout together with the decoder's message and then skipped.
    """
    records = []
    with open(file_path, "r") as handle:
        for raw_line in handle:
            text = raw_line.strip()
            if text:
                try:
                    records.append(json.loads(text))
                except json.JSONDecodeError as exc:
                    print(f"Error decoding JSON in line: {text}")
                    print(f"Error message: {str(exc)}")
    return records
def dir_to_json(root_dir, output_file):
    """
    Flatten a directory tree of benchmark result files into one JSON-lines file.

    :param root_dir: Root directory containing the benchmark result files
    :param output_file: Output file; one JSON object is written per line

    Every file under ``root_dir`` is visited except those whose path ends with
    ``.DS_Store`` or contains ``summary``. Metadata is read from fixed positions
    of the path split on ``os.sep``: index 2 is the model, 3 the device, 4 the
    OS, 5 the dataset, and 6 the file name of the form
    ``<timestamp>_<commit hash>_<commit timestamp>.json``. Because the indices
    are absolute, they depend on how many components ``root_dir`` itself has.
    """
    with open(output_file, "w") as sink:
        for current_dir, _, filenames in os.walk(root_dir):
            for filename in filenames:
                result_path = os.path.join(current_dir, filename)
                # ignore .DS_Store and summary files
                if result_path.endswith(".DS_Store") or "summary" in result_path:
                    continue

                segments = result_path.split(os.sep)
                model = segments[2].replace("_", "/")
                device = segments[3].replace("_", " ")
                os_label = segments[4].replace("_", " ")
                dataset_name = segments[5]
                stem = segments[6].replace(".json", "")
                timestamp, commit_hash, commit_timestamp = stem.split("_")
                # Turn every "-" into ":" and then restore the first two,
                # i.e. only the separators after the date part stay colons.
                display_timestamp = timestamp.replace("-", ":").replace(":", "-", 2)

                for record in process_file(result_path):
                    entry = {
                        "model": model,
                        "device": device,
                        "os": os_label,
                        "wer": record["wer"],
                        "dataset_name": dataset_name,
                        "reference_transcription": record["reference_transcription"],
                        "prediction_transcription": record["prediction_transcription"],
                        "difference_transcription": record["difference_transcription"],
                        "audio_file_url": record["audio_file_url"],
                        "timestamp": display_timestamp,
                        "commit_hash": commit_hash,
                        "commit_timestamp": commit_timestamp,
                    }
                    sink.write(json.dumps(entry) + "\n")
async def download_audio_to_ndarray(url):
    """
    Download an audio file from a URL and decode it into a NumPy array.

    :param url: The URL of the audio file to download
    :return: Tuple ``(sample_rate, samples)``; ``(None, None)`` when the HTTP
        status is not 200

    The response body is decoded as MP3 through pydub's AudioSegment. Samples
    of two-channel audio are reshaped to two columns (one per channel); any
    other channel count is returned as a flat array.
    """
    async with aiohttp.ClientSession() as session:
        async with session.get(url) as response:
            if response.status != 200:
                return None, None
            payload = BytesIO(await response.read())
            segment = AudioSegment.from_file(payload, format="mp3")
            samples = np.array(segment.get_array_of_samples())
            if segment.channels == 2:
                samples = samples.reshape((-1, 2))
            return segment.frame_rate, samples
async def play_audio(url):
    """
    Fetch audio from a URL in the shape expected by an audio player component.

    :param url: The URL of the audio file to play
    :return: Tuple ``(sample_rate, audio_data)`` on success, otherwise the
        string "Error downloading the file"

    Thin wrapper around download_audio_to_ndarray that replaces its
    ``(None, None)`` failure value with an error message.
    """
    sample_rate, audio_data = await download_audio_to_ndarray(url)
    if audio_data is not None:
        return sample_rate, audio_data
    return "Error downloading the file"
def get_filter_cond(df, model, device, os, dataset, timestamp=None):
    """
    Build a boolean row mask selecting one model/device/OS/dataset combination.

    :param df: DataFrame containing the transcription data
    :param model: Value the 'model' column must equal
    :param device: Value the 'device' column must equal
    :param os: Value the 'os' column must equal
    :param dataset: Value the 'dataset_name' column must equal
    :param timestamp: Optional value the 'timestamp' column must equal; any
        falsy value (None, empty string) disables this extra condition
    :return: A boolean mask for filtering the DataFrame
    """
    mask = df["model"] == model
    for column, wanted in (("device", device), ("os", os), ("dataset_name", dataset)):
        mask = mask & (df[column] == wanted)
    if timestamp:
        mask = mask & (df["timestamp"] == timestamp)
    return mask
def get_filtered_transcript(df, model, device, os, dataset, timestamp):
    """
    Return the transcription columns for one benchmark run.

    :param df: DataFrame containing the transcription data
    :param model: String representing the model name
    :param device: String representing the device name
    :param os: String representing the OS name
    :param dataset: String representing the dataset name
    :param timestamp: String representing the timestamp
    :return: DataFrame with the reference, prediction, difference and
        audio-URL columns of the rows selected by get_filter_cond
    """
    wanted_columns = [
        "reference_transcription",
        "prediction_transcription",
        "difference_transcription",
        "audio_file_url",
    ]
    row_mask = get_filter_cond(df, model, device, os, dataset, timestamp)
    return df[row_mask][wanted_columns]
def get_filtered_timestamps(df, model, device, os, dataset):
    """
    List the distinct timestamps recorded for one model/device/OS/dataset combination.

    :param df: DataFrame containing the transcription data
    :param model: String representing the model name
    :param device: String representing the device name
    :param os: String representing the OS name
    :param dataset: String representing the dataset name
    :return: Single-column DataFrame ('timestamp') with duplicates removed
    """
    row_mask = get_filter_cond(df, model, device, os, dataset)
    matching_rows = df[row_mask]
    return matching_rows[["timestamp"]].drop_duplicates()
def make_model_name_clickable_link(model):
    """
    Build an HTML anchor that points at the model's folder on Hugging Face.

    :param model: String representing the model name
    :return: HTML string; the link text is the model name as given

    The link targets the argmaxinc/whisperkit-coreml repository tree, using the
    model name with every "/" replaced by "_" as the folder, and opens in a new tab.
    """
    repo_folder = model.replace("/", "_")
    return (
        '<a style="color: #3B82F6; text-decoration: underline; text-decoration-style: dotted;" '
        f'href="https://huggingface.co/argmaxinc/whisperkit-coreml/tree/main/{repo_folder}" '
        f'target="_blank">{model}</a>'
    )
def make_dataset_wer_clickable_link(row, dataset):
    """
    Wrap a dataset's WER value in an HTML anchor.

    :param row: Row providing the 'Model' value and the column named after the dataset
    :param dataset: String representing the dataset name
    :return: HTML string whose link text is the row's value for that dataset

    The href is produced by filling WHISPER_OPEN_AI_LINK with the model name
    ("/" replaced by "_") and the dataset name, in that order.
    """
    dataset_column = f"{dataset}"
    model_slug = row["Model"].replace("/", "_")
    href = WHISPER_OPEN_AI_LINK.format(model_slug, dataset)
    wer_value = row[dataset_column]
    return (
        '<a style="color: #3B82F6; text-decoration: underline; text-decoration-style: dotted;" '
        f'href="{href}">{wer_value}</a>'
    )
def make_timestamp_clickable_link(model, dataset, timestamp):
    """
    Render a timestamp as a clickable HTML div.

    :param model: String representing the model name
    :param dataset: String representing the dataset name
    :param timestamp: Timestamp to be displayed and used in the element id
    :return: HTML string containing a div styled like a link

    Clicking the div triggers ``click()`` on the element whose id is
    "<dataset>-<model>-<timestamp>" with spaces turned into underscores and
    double quotes, single quotes and commas removed.
    """
    elem_id = f"{dataset}-{model}-{timestamp}"
    for old, new in ((" ", "_"), ('"', ""), ("'", ""), (",", "")):
        elem_id = elem_id.replace(old, new)
    click_js = f"document.getElementById('{elem_id}').click();"
    return (
        '<div style="color: #3B82F6; text-decoration: underline; text-decoration-style: dotted;" '
        f'onclick="{click_js}" href="#">{timestamp}</div>'
    )
def make_multilingual_model_clickable_link(model):
    """
    Render a multilingual model name as a clickable HTML div.

    :param model: String representing the model name
    :return: HTML string containing a div styled like a link

    Clicking the div triggers ``click()`` on the element whose id is the model
    name with spaces turned into underscores and double quotes, single quotes
    and commas removed; the handler also logs 'hello' to the browser console.
    """
    elem_id = f"{model}"
    for old, new in ((" ", "_"), ('"', ""), ("'", ""), (",", "")):
        elem_id = elem_id.replace(old, new)
    click_js = f"document.getElementById('{elem_id}').click();console.log('hello');"
    return (
        '<div style="color: #3B82F6; text-decoration: underline; text-decoration-style: dotted;" '
        f'onclick="{click_js}" href="#">{model}</div>'
    )
def plot_metric(
    df, y_axis_col, y_axis_title, fig_title, filter_input=None, exclude_input=None
):
    """
    Plot a metric over release commits, one line per model-device-OS group.

    :param df: DataFrame containing the benchmark data
    :param y_axis_col: DataFrame column to use as the y-axis
    :param y_axis_title: Display name for the y-axis
    :param fig_title: Display title for the figure
    :param filter_input: Optional semicolon-separated substrings; a group is
        kept only if its label contains at least one of them
    :param exclude_input: Optional semicolon-separated substrings; a group is
        dropped if its label contains any of them
    :return: A Plotly figure object

    Only rows whose ``commit_hash`` appears under "releases" in
    ``dashboard_data/version.json`` are plotted. Filter and exclude terms are
    matched case-insensitively against the "<model>-<device>-<os>" label.
    """

    def group_label(group):
        # Label taken from the first row; all rows of a group share these values.
        return f"{group['model'].iloc[0]}-{group['device'].iloc[0]}-{group['os'].iloc[0]}"

    def mentions_any(group, terms):
        lowered = group_label(group).lower()
        return any(term in lowered for term in terms)

    def to_display_time(raw):
        return datetime.strptime(raw, "%Y-%m-%dT%H%M%S").strftime("%Y-%m-%d %H:%M:%S")

    with open("dashboard_data/version.json", "r") as f:
        release_hashes = set(json.load(f)["releases"])

    released = df[df["commit_hash"].isin(release_hashes)]
    traces = [
        group.sort_values("commit_timestamp")
        for _, group in released.groupby(["model", "device", "os"])
    ]

    if filter_input:
        wanted = [term.strip().lower() for term in filter_input.split(";")]
        traces = [group for group in traces if mentions_any(group, wanted)]

    if exclude_input:
        unwanted = [term.strip().lower() for term in exclude_input.split(";")]
        traces = [group for group in traces if not mentions_any(group, unwanted)]

    base_colors = ["#4542f4", "#0e0c06", "#ccf0a7", "#ff7f4e", "#ffd15a"]
    palette = generate_random_colors(base_colors, len(traces))

    fig = go.Figure()
    for position, group in enumerate(traces):
        label = group_label(group)
        # Wrap around in case fewer colors than groups were generated.
        color = palette[position % len(palette)]
        hover = (
            f"<b>{label}</b><br>"
            "Timestamp: %{x}<br>"
            f"{y_axis_title}: %{{y:.2f}}<br>"
            "<extra></extra>"
        )
        fig.add_trace(
            go.Scatter(
                x=group["commit_timestamp"].apply(to_display_time),
                y=group[y_axis_col],
                mode="lines+markers",
                name=label,
                line=dict(color=color),
                marker=dict(color=color),
                hovertemplate=hover,
            )
        )

    fig.update_layout(
        title=fig_title,
        xaxis_title="Commit Timestamp",
        yaxis_title=y_axis_title,
        legend_title="Model-Device-OS",
        width=1100,
        height=600,
        plot_bgcolor="rgb(250,249,244)",
    )
    return fig
def fields(raw_class):
    """
    Collect the non-dunder attribute values of a class.

    :param raw_class: The class to inspect
    :return: List of the values stored in ``raw_class.__dict__`` under names
        that neither start nor end with a double underscore

    For a dataclass whose fields carry defaults, these values are the
    defaults themselves, in definition order.
    """
    collected = []
    for attr_name, attr_value in raw_class.__dict__.items():
        if attr_name.startswith("__") or attr_name.endswith("__"):
            continue
        collected.append(attr_value)
    return collected
def get_os_name_and_version(os_string):
    """
    Reduce an "<name> <version>" string to the OS name and major version.

    :param os_string: String with exactly two whitespace-separated parts:
        the OS name and its version
    :return: "<name> <major>", where major is the version text before the first "."
    :raises ValueError: If ``os_string`` does not split into exactly two parts
    """
    name, full_version = os_string.split()
    major = full_version.split(".", 1)[0]
    return f"{name} {major}"
def create_initial_quality_column_dict():
    """
    Build the base column specification of the quality table.

    :return: List of ``[field_name, ColumnContent, default]`` entries for the
        model, average WER and QoI columns

    All three columns are of type "html" and displayed by default; the model
    column is additionally marked as never hidden.
    """
    model_column = ColumnContent("Model", "html", True, never_hidden=True)
    average_wer_column = ColumnContent("Average WER", "html", True)
    qoi_column = ColumnContent("QoI", "html", True)
    return [
        ["model", ColumnContent, model_column],
        ["average_wer", ColumnContent, average_wer_column],
        ["qoi", ColumnContent, qoi_column],
    ]
def calculate_parity(m2_ultra_wer, row):
    """
    Compute how far a row's average WER is from the M2 Ultra value for the same model.

    :param m2_ultra_wer: Mapping indexed by model name (exposing ``.index``)
        that holds the M2 Ultra WER values
    :param row: Row with 'Model' and 'Average WER' entries
    :return: ``m2_ultra_wer[model] - row["Average WER"]`` rounded to 2 decimal
        places, or None when the model has no M2 Ultra entry
    """
    model = row["Model"]
    if model not in m2_ultra_wer.index:
        return None
    return round(m2_ultra_wer[model] - row["Average WER"], 2)
def create_initial_performance_column_dict():
    """
    Build the base column specification of the performance table.

    :return: List of ``[field_name, ColumnContent, default]`` entries

    The columns are, in order: model, device and OS (never hidden), English
    WER and multilingual WER (displayed by default), then QoI, speed and
    tokens per second (not displayed by default). Every column is of type
    "html" except multilingual WER, which is "str".
    """

    def column(field_name, label, kind, shown, **flags):
        return [field_name, ColumnContent, ColumnContent(label, kind, shown, **flags)]

    return [
        column("model", "Model", "html", True, never_hidden=True),
        column("device", "Device", "html", True, never_hidden=True),
        column("os", "OS", "html", True, never_hidden=True),
        column("english_wer", "English WER", "html", True),
        column("multilingual_wer", "Multilingual WER", "str", True),
        column("qoi", "QoI", "html", False),
        column("speed", "Speed", "html", False),
        column("toks", "Tok / s", "html", False),
    ]
def add_datasets_to_quality_columns(column_dict, datasets):
    """
    Extend the quality table specification with one column per dataset.

    :param column_dict: The initial column specification (left unmodified)
    :param datasets: List of dataset names to add
    :return: Dictionary with the extended specification ("column_dict"), the
        generated frozen dataclass ("AutoEvalColumn") and the derived name
        lists "COLS", "TYPES", "ALWAYS_HERE_COLS", "TOGGLE_COLS" and
        "SELECTED_COLS"

    Each dataset becomes an "html" column displayed by default; its field
    name is the dataset name with "-" removed so it can be used as an
    attribute of the generated dataclass.
    """
    updated_column_dict = column_dict.copy()
    updated_column_dict.extend(
        [dataset.replace("-", ""), ColumnContent, ColumnContent(dataset, "html", True)]
        for dataset in datasets
    )

    AutoEvalColumn = make_dataclass("AutoEvalColumn", updated_column_dict, frozen=True)
    columns = fields(AutoEvalColumn)
    visible = [c for c in columns if not c.hidden]
    toggleable = [c for c in columns if not c.never_hidden]

    return {
        "column_dict": updated_column_dict,
        "AutoEvalColumn": AutoEvalColumn,
        "COLS": [c.name for c in visible],
        "TYPES": [c.type for c in visible],
        "ALWAYS_HERE_COLS": [c.name for c in columns if c.never_hidden],
        "TOGGLE_COLS": [c.name for c in toggleable],
        "SELECTED_COLS": [c.name for c in toggleable if c.displayed_by_default],
    }
def add_datasets_to_performance_columns(column_dict, datasets):
    """
    Extend the performance table specification with speed and tok/s columns per dataset.

    :param column_dict: The initial column specification (left unmodified)
    :param datasets: List of dataset names to add
    :return: Dictionary with the extended specification ("column_dict"), the
        generated frozen dataclass ("AutoEvalColumn") and the derived name
        lists "COLS", "TYPES", "ALWAYS_HERE_COLS", "TOGGLE_COLS" and
        "SELECTED_COLS"

    For each dataset two "html" columns displayed by default are appended,
    ``<field>_speed`` and ``<field>_toks``, where ``<field>`` is the dataset
    name with "-" removed. Their labels start with "Short-Form" for the
    'librispeech-10mins' dataset and with "Long-Form" for every other one.
    """
    updated_column_dict = column_dict.copy()
    for dataset in datasets:
        field_name = dataset.replace("-", "")
        form = "Short-Form" if dataset == "librispeech-10mins" else "Long-Form"
        for suffix, unit in (("speed", "Speed"), ("toks", "Tok/s")):
            updated_column_dict.append(
                [
                    f"{field_name}_{suffix}",
                    ColumnContent,
                    ColumnContent(f"{form} {unit}", "html", True),
                ]
            )

    AutoEvalColumn = make_dataclass("AutoEvalColumn", updated_column_dict, frozen=True)
    columns = fields(AutoEvalColumn)
    visible = [c for c in columns if not c.hidden]
    toggleable = [c for c in columns if not c.never_hidden]

    return {
        "column_dict": updated_column_dict,
        "AutoEvalColumn": AutoEvalColumn,
        "COLS": [c.name for c in visible],
        "TYPES": [c.type for c in visible],
        "ALWAYS_HERE_COLS": [c.name for c in columns if c.never_hidden],
        "TOGGLE_COLS": [c.name for c in toggleable],
        "SELECTED_COLS": [c.name for c in toggleable if c.displayed_by_default],
    }
def create_confusion_matrix_plot(matrix, labels, is_forced):
    """
    Draw a language-detection confusion matrix as a heatmap.

    :param matrix: Confusion matrix values, used as the heatmap's z data
    :param labels: Language labels, used for both axes
    :param is_forced: Truthy when a language hint was used; only affects the title
    :return: A Plotly figure object (600x600) with predicted languages on the
        x-axis and true languages on the y-axis
    """
    colorscale = [
        [0, "rgb(250,249,244)"],
        [0.5, "rgb(69,66,244)"],
        [1.0, "rgb(14,12,6)"],
    ]
    heatmap = go.Heatmap(
        z=matrix,
        x=labels,
        y=labels,
        colorscale=colorscale,
        hoverongaps=False,
        hovertemplate="True: %{y}<br>Predicted: %{x}<br>Value: %{z}<extra></extra>",
    )
    fig = go.Figure(data=heatmap)

    mode = "Language Hint" if is_forced else "Language Prediction by Model"
    fig.update_layout(
        title=f"Language Detection Confusion Matrix with {mode}",
        xaxis_title="Predicted Language",
        yaxis_title="True Language",
        xaxis=dict(tickangle=-45),
        width=600,
        height=600,
        margin=dict(l=50, r=50, t=50, b=50),
    )
    return fig
def hex_to_rgb(hex_color):
    """
    Convert a hexadecimal color code to RGB values.

    :param hex_color: Six-digit hex color, with or without leading "#"
    :return: Tuple of three integers (red, green, blue), each parsed from one
        pair of hex digits
    """
    digits = hex_color.lstrip("#")
    channels = []
    for start in range(0, 6, 2):
        channels.append(int(digits[start : start + 2], 16))
    return tuple(channels)
def rgb_to_hex(rgb):
    """
    Convert RGB values to a hexadecimal color code.

    :param rgb: Sequence of three integers (red, green, blue)
    :return: String of the form "#rrggbb" with lowercase, zero-padded hex digits
    """
    channel_format = "{:02x}"
    return "#" + (channel_format * 3).format(*rgb)
def interpolate_colors(color1, color2, factor):
    """
    Interpolate between two colors in HSV space.

    :param color1: First color in hexadecimal format
    :param color2: Second color in hexadecimal format
    :param factor: Float between 0 and 1; 0 yields ``color1``, 1 yields ``color2``
    :return: Interpolated color in hexadecimal format

    Hue, saturation and value are each interpolated linearly between the two
    colors, and the result is converted back to an RGB hex code.
    """
    hsv1 = colorsys.rgb_to_hsv(*(channel / 255.0 for channel in hex_to_rgb(color1)))
    hsv2 = colorsys.rgb_to_hsv(*(channel / 255.0 for channel in hex_to_rgb(color2)))
    h = (hsv1[0] + factor * (hsv2[0] - hsv1[0])) % 1.0
    s = hsv1[1] + factor * (hsv2[1] - hsv1[1])
    v = hsv1[2] + factor * (hsv2[2] - hsv1[2])
    rgb = colorsys.hsv_to_rgb(h, s, v)
    # Round to the nearest channel value. Truncating with int() would turn a
    # float such as 68.99999 (round-off from the HSV round trip) into 68, so
    # even factor=0 could fail to reproduce color1.
    return rgb_to_hex(tuple(round(x * 255) for x in rgb))
def color_distance(color1, color2):
    """
    Measure how far apart two colors are in RGB space.

    :param color1: First color in hexadecimal format
    :param color2: Second color in hexadecimal format
    :return: Euclidean distance between the two RGB triples (0 for equal colors)
    """
    first = hex_to_rgb(color1)
    second = hex_to_rgb(color2)
    squared_total = 0
    for a, b in zip(first, second):
        squared_total += (a - b) ** 2
    return squared_total ** 0.5
def generate_random_colors(base_colors, num_colors, min_distance=30):
    """
    Generate random colors by blending pairs of base colors.

    :param base_colors: List of at least two base colors in hexadecimal format
    :param num_colors: Number of colors to generate
    :param min_distance: Minimum RGB distance a new color should keep from
        every color generated so far (default: 30)
    :return: List of generated colors in hexadecimal format

    Each candidate is an interpolation between two randomly sampled base
    colors at a random factor. A candidate closer than ``min_distance`` to an
    existing color is rejected; once more than 100 candidates in a row have
    been rejected, each further rejected candidate is accepted anyway with
    probability 0.1 so the loop keeps making progress. The loop also stops
    after 1000 consecutive rejections, in which case fewer than
    ``num_colors`` colors are returned.
    """
    palette = []
    stalled = 0  # consecutive rejected candidates
    max_attempts = 1000
    while len(palette) < num_colors and stalled < max_attempts:
        start_color, end_color = random.sample(base_colors, 2)
        candidate = interpolate_colors(start_color, end_color, random.random())
        too_close = any(
            color_distance(candidate, existing) < min_distance for existing in palette
        )
        if too_close:
            stalled += 1
        else:
            palette.append(candidate)
            stalled = 0
        # Escape hatch: occasionally accept a too-similar color rather than stall.
        if stalled > 100 and random.random() < 0.1:
            palette.append(candidate)
            stalled = 0
    return palette
@dataclass
class Task:
    """
    Dataclass representing a benchmark task.

    :param benchmark: String representing the benchmark name
    :param metric: String representing the metric used for evaluation
    :param col_name: String representing the column name in the results DataFrame
    """

    # The @dataclass decorator generates __init__/__repr__/__eq__ from these
    # annotations; without it the class could not be given any field values.
    benchmark: str
    metric: str
    col_name: str
@dataclass(frozen=True)
class ColumnContent:
    """
    Dataclass representing a column in the results table.

    :param name: String representing the column name
    :param type: String representing the data type of the column
    :param displayed_by_default: Boolean indicating if the column should be displayed by default
    :param hidden: Boolean indicating if the column should be hidden (default: False)
    :param never_hidden: Boolean indicating if the column should never be hidden (default: False)
    :param dummy: Boolean indicating if this is a dummy column (default: False)
    """

    # frozen=True makes instances immutable and hashable. Instances are used
    # as field defaults in make_dataclass, which rejects unhashable defaults
    # on recent Python versions.
    name: str
    type: str
    displayed_by_default: bool
    hidden: bool = False
    never_hidden: bool = False
    dummy: bool = False
# Stylesheet for the dashboard UI.
# The @import rules come first: CSS requires @import to precede every other
# rule (apart from @charset/@layer), otherwise browsers ignore it and the
# imported fonts are never loaded. Sora is requested over its variable weight
# range, which covers the 300 and 400 weights used below.
css = """
@import url('https://fonts.googleapis.com/css2?family=Noto+Color+Emoji&display=swap');
@import url('https://fonts.googleapis.com/css2?family=Sora:wght@100..800&display=swap');

@font-face {
    font-family: 'Zwizz Regular';
    font-style: normal;
    font-weight: normal;
    src: local('Zwizz Regular'), url('static/Zwizz-Regular.woff') format('woff');
}
@font-face {
    font-family: 'Zwizz Medium';
    font-style: normal;
    font-weight: normal;
    src: local('Zwizz Medium'), url('static/Zwizz-Medium.woff') format('woff');
}
@font-face {
    font-family: 'Zwizz SemiBold';
    font-style: normal;
    font-weight: normal;
    src: local('Zwizz SemiBold'), url('static/Zwizz-SemiBold.woff') format('woff');
}

/* Typography Scale */
h1, .h1 {
    font-family: 'Sora', sans-serif;
    font-weight: 300;
    font-size: 2em;
    letter-spacing: -0.05em;
}
h2, .h2 {
    font-family: 'Sora', sans-serif;
    font-weight: 400;
    letter-spacing: -0.05em;
}
h3, h4, h5, .h3, .h4, .h5 {
    font-family: 'Sora', sans-serif;
    font-weight: 400;
    letter-spacing: -0.05em;
}
h6, .h6, pre, code, .monospace {
    font-family: 'IBM Plex Mono', monospace;
    font-weight: 400;
    letter-spacing: 0.01em;
}

/* Add strong tag styling */
strong, b {
    font-family: 'Zwizz SemiBold', -apple-system, BlinkMacSystemFont, system-ui, sans-serif;
    letter-spacing: -0.02em;
}

/* Global Zwizz styles */
:root {
    --zwizz-spacing: -0.02em;
}

/* All Gradio elements should have Zwizz spacing */
.gradio-container * {
    letter-spacing: var(--zwizz-spacing);
    line-height: 1.7;
}

/* UI Elements */
.tab-buttons button, #models-to-add-text, .gradio-button {
    font-family: 'Sora', sans-serif;
    font-weight: 400;
    letter-spacing: -0.05em;
}

/* Specific Table Styling */
table, .table, th, td {
    font-family: 'IBM Plex Mono', 'Noto Color Emoji', sans-serif, monospace !important;
    font-weight: 400;
    letter-spacing: 0.01em;
}

/* Technical/Code Elements */
.code-block, .technical-text {
    font-family: 'IBM Plex Mono', monospace;
    font-weight: 400;
    letter-spacing: 0.01em;
}

/* Additional Elements */
#methodology-text p, #methodology-text li, .markdown-text {
    font-family: 'Zwizz Regular', -apple-system, BlinkMacSystemFont, system-ui, sans-serif;
    font-size: 16px !important;
    letter-spacing: var(--zwizz-spacing);
    line-height: 1.7;
}

/* Font weight utilities */
.zwizz-medium {
    font-family: 'Zwizz Medium', -apple-system, BlinkMacSystemFont, system-ui, sans-serif;
}
.zwizz-semibold {
    font-family: 'Zwizz SemiBold', -apple-system, BlinkMacSystemFont, system-ui, sans-serif;
}

/* Maintaining Original Layout Rules */
.gradio-container {
    max-width: 95% !important;
}

/* Table Layouts */
.large-table,
.large-table .table-wrap,
#multilingual-model-table .table-wrap,
#lookup-table .table-wrap {
    height: 35em !important;
    overflow-y: scroll !important;
}

/* SVG Container Rules */
.svg-container,
.main-svg {
    width: 100% !important;
}
.left-side-table .table-wrap {
    height: 15em !important;
    overflow-y: scroll !important;
}
#average-wer-table .table-wrap {
    height: 8em !important;
    overflow-y: scroll !important;
}
#general-wer-table .table-wrap {
    height: 35em !important;
    overflow-y: scroll !important;
}
"""
