Spaces:
Sleeping
Sleeping
chore: init demo
Browse files- .gitignore +3 -0
- app.py +248 -0
- constants.py +217 -0
- requirements.txt +3 -0
- utils.py +270 -0
.gitignore
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.gradio
|
| 2 |
+
__pycache__
|
| 3 |
+
plots
|
app.py
ADDED
|
@@ -0,0 +1,248 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import pandas as pd
|
| 3 |
+
import numpy as np
|
| 4 |
+
import pickle
|
| 5 |
+
import os
|
| 6 |
+
from sklearn.manifold import TSNE
|
| 7 |
+
import matplotlib.pyplot as plt
|
| 8 |
+
from utils import (plot_distances_tsne,
|
| 9 |
+
plot_distances_umap,
|
| 10 |
+
cluster_languages_hdbscan,
|
| 11 |
+
cluster_languages_kmeans,
|
| 12 |
+
plot_mst,
|
| 13 |
+
cluster_languages_by_families,
|
| 14 |
+
cluster_languages_by_subfamilies,
|
| 15 |
+
filter_languages_by_families)
|
| 16 |
+
from functools import partial
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
# Ordered list of language names. Other functions in this module use a
# language's position in this list as its row/column index in the matrices.
with open("../../results/languages_list.pkl", "rb") as f:
    languages = pickle.load(f)

DATASETS = ["wikimedia/wikipedia", "uonlp/CulturaX", "HuggingFaceFW/fineweb-2"]
MODELS = ["mistralai/Mistral-7B-v0.1", "google/gemma-3-4b-pt", "meta-llama/Llama-3.2-1B"]


def _load_distance_matrix(dataset, model):
    """Load the distance matrix stored for one (dataset, model) pair."""
    return np.load(os.path.join("../../results", dataset, model, "distances_matrix.npy"))


# distance_matrices[dataset][model] -> array loaded from that pair's .npy file.
distance_matrices = {
    dataset: {model: _load_distance_matrix(dataset, model) for model in MODELS}
    for dataset in DATASETS
}

# Matrix loaded from the "average" results file; used when the UI asks for
# average distances instead of a specific (dataset, model) pair.
average_distances_matrix = np.load("../../results/average_distances_matrix.npy")
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def filter_languages_nan(model, dataset, use_average):
    """
    Drops every language whose entry in the first row of the matrix is NaN.

    Parameters:
    - model: key into distance_matrices[dataset] (ignored when use_average is set).
    - dataset: key into distance_matrices (ignored when use_average is set).
    - use_average: when truthy, the average distances matrix is used instead.

    Returns:
    - updated_matrix: the matrix restricted to the kept rows and columns.
    - updated_languages: NumPy array with the names of the kept languages.
    """
    matrix = average_distances_matrix if use_average else distance_matrices[dataset][model]

    # Only the first row is inspected: a NaN there marks the column's language as unavailable.
    keep = ~np.isnan(matrix[0])

    kept_languages = np.array(languages)[keep]
    kept_matrix = matrix[keep, :][:, keep]
    return kept_matrix, kept_languages
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def get_similar_languages(model, dataset, selected_language, use_average, n):
    """
    Builds a table of the languages closest to the selected one.

    The row of the chosen matrix that belongs to `selected_language` is paired
    with the language names, sorted by ascending distance, and the selected
    language itself is removed. A fresh 0-based "index" column records the rank.

    Returns:
    - DataFrame with the first `n` rows (columns: index, Language, Distance),
      distances rounded to 4 decimals.
    """
    matrix = average_distances_matrix if use_average else distance_matrices[dataset][model]
    row = languages.index(selected_language)

    table = pd.DataFrame({"Language": languages, "Distance": matrix[row]})
    ranked = (
        table.sort_values(by="Distance")
        .drop(index=row)            # labels survive the sort, so this removes the selected language
        .reset_index(drop=True)     # renumber 0..k-1 in ranked order
        .reset_index()              # expose that rank as an "index" column
    )
    ranked["Distance"] = ranked["Distance"].round(4)
    return ranked.head(n)
|
| 67 |
+
|
| 68 |
+
def update_languages(model, dataset):
    """
    Lists the languages available for the given model and dataset.

    A language counts as available when its entry in the first row of the
    (dataset, model) distance matrix is not NaN.
    """
    first_row = distance_matrices[dataset][model][0]
    available = ~np.isnan(first_row)
    return list(np.array(languages)[available])
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def update_language_options(model, dataset, language, use_average):
    """
    Rebuilds the language dropdown for the current selection.

    With average distances every language is offered; otherwise only those
    returned by update_languages(model, dataset). If the currently selected
    language is not among the choices, the first choice is selected instead.
    """
    choices = languages if use_average else update_languages(model, dataset)
    selected = language if language in choices else choices[0]
    return gr.Dropdown(label="Language", choices=choices, value=selected)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def toggle_inputs(use_average):
    """
    Returns two component updates that hide and disable their targets when
    average distances are used, and show and enable them otherwise.
    """
    enabled = not use_average
    return (
        gr.update(interactive=enabled, visible=enabled),
        gr.update(interactive=enabled, visible=enabled),
    )
|
| 93 |
+
|
| 94 |
+
# Running counter that gives every exported PDF a distinct file name.
i = 0


def plot_distances(model, dataset, use_average, cluster_method, cluster_method_param, plot_fn):
    """
    Clusters the available languages and draws them with the given plot function.

    Parameters:
    - model, dataset, use_average: select the distance matrix (see filter_languages_nan).
    - cluster_method: one of "HDBSCAN", "KMeans", "Family" or "Subfamily".
    - cluster_method_param: minimum cluster size for HDBSCAN, number of clusters
      for KMeans; unused for the other methods.
    - plot_fn: callable invoked as
      plot_fn(model, dataset, use_average, matrix, languages, clusters, legends)
      that returns the figure to display.

    Returns:
    - fig: the figure produced by plot_fn. A copy is also written to
      plots/plot_<counter>.pdf.

    Raises:
    - ValueError: if cluster_method is not one of the supported names.
    """
    global i

    updated_matrix, updated_languages = filter_languages_nan(model, dataset, use_average)

    if cluster_method == "HDBSCAN":
        # The value comes from a UI slider; scikit-learn requires an integer here.
        filtered_matrix, filtered_languages, clusters = cluster_languages_hdbscan(
            updated_matrix, updated_languages, min_cluster_size=int(cluster_method_param)
        )
        legends = None
    elif cluster_method == "KMeans":
        filtered_matrix, filtered_languages, clusters = cluster_languages_kmeans(
            updated_matrix, updated_languages, n_clusters=int(cluster_method_param)
        )
        legends = None
    elif cluster_method == "Family":
        clusters, legends = cluster_languages_by_families(updated_languages)
        filtered_matrix = updated_matrix
        filtered_languages = updated_languages
    elif cluster_method == "Subfamily":
        clusters, legends = cluster_languages_by_subfamilies(updated_languages)
        filtered_matrix = updated_matrix
        filtered_languages = updated_languages
    else:
        raise ValueError("Invalid cluster method")

    fig = plot_fn(model, dataset, use_average, filtered_matrix, filtered_languages, clusters, legends)
    fig.tight_layout()

    # "plots" is git-ignored, so it does not exist on a fresh checkout.
    os.makedirs("plots", exist_ok=True)
    fig.savefig(f"plots/plot_{i}.pdf", format="pdf")
    i += 1
    return fig
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
# NOTE(review): the original indentation was not recoverable; the nesting of
# rows and tabs below is reconstructed from the statement order — confirm.
with gr.Blocks() as demo:
    gr.Markdown("## Language Distance Explorer")
    average_checkbox = gr.Checkbox(label="Use Average Distances", value=False)
    with gr.Row():
        model_input = gr.Dropdown(label="Model", choices=MODELS, value=MODELS[0])
        dataset_input = gr.Dropdown(
            label="Dataset",
            choices=DATASETS,
            value=DATASETS[0]
        )

    with gr.Tab(label="Closest Languages Table"):
        with gr.Row():
            language_input = gr.Dropdown(label="Language", choices=languages, value=languages[0])
            top_n_input = gr.Slider(label="Top N", minimum=1, maximum=30, step=1, value=10)

        output_table = gr.Dataframe(label="Similar Languages")

        # Inputs shared by the handlers below, in the order their functions expect.
        option_inputs = [model_input, dataset_input, language_input, average_checkbox]
        table_inputs = [model_input, dataset_input, language_input, average_checkbox, top_n_input]

        # Changing the model or dataset refreshes the language choices.
        model_input.change(fn=update_language_options, inputs=option_inputs, outputs=language_input)
        dataset_input.change(fn=update_language_options, inputs=option_inputs, outputs=language_input)

        # Any change to the selection recomputes the table.
        language_input.change(fn=get_similar_languages, inputs=table_inputs, outputs=output_table)
        model_input.change(fn=get_similar_languages, inputs=table_inputs, outputs=output_table)
        dataset_input.change(fn=get_similar_languages, inputs=table_inputs, outputs=output_table)
        top_n_input.change(fn=get_similar_languages, inputs=table_inputs, outputs=output_table)

        # The model/dataset selectors are hidden while average distances are used.
        average_checkbox.change(
            fn=toggle_inputs,
            inputs=[average_checkbox],
            outputs=[model_input, dataset_input]
        )

        average_checkbox.change(fn=update_language_options, inputs=option_inputs, outputs=language_input)
        average_checkbox.change(fn=get_similar_languages, inputs=table_inputs, outputs=output_table)

    with gr.Tab(label="Distance Plot"):
        with gr.Row():
            cluster_method_input = gr.Dropdown(label="Cluster Method", choices=["HDBSCAN", "KMeans", "Family", "Subfamily"], value="HDBSCAN")
            clusters_input = gr.Slider(label="Minimum Elements in a Cluster", minimum=2, maximum=10, step=1, value=2)

        def update_clusters_input_option(cluster_method):
            """Relabel the slider for HDBSCAN/KMeans; hide it for any other method."""
            if cluster_method == "HDBSCAN":
                return gr.Slider(label="Minimum Elements in a Cluster", minimum=2, maximum=10, step=1, value=2, visible=True, interactive=True)
            elif cluster_method == "KMeans":
                return gr.Slider(label="Number of Clusters", minimum=2, maximum=20, step=1, value=2, visible=True, interactive=True)
            else:
                return gr.update(interactive=False, visible=False)

        cluster_method_input.change(fn=update_clusters_input_option, inputs=[cluster_method_input], outputs=clusters_input)

        with gr.Row():
            plot_tsne_button = gr.Button("Plot t-SNE")
            plot_umap_button = gr.Button("Plot UMAP")
            plot_mst_button = gr.Button("Plot MST")

        with gr.Row():
            plot_output = gr.Plot(label="Distance Plot")

        # The three buttons differ only in the plot function bound via partial.
        plot_inputs = [model_input, dataset_input, average_checkbox, cluster_method_input, clusters_input]

        plot_tsne_button.click(fn=partial(plot_distances, plot_fn=plot_distances_tsne),
                               inputs=plot_inputs,
                               outputs=plot_output)
        plot_umap_button.click(fn=partial(plot_distances, plot_fn=plot_distances_umap),
                               inputs=plot_inputs,
                               outputs=plot_output)
        plot_mst_button.click(fn=partial(plot_distances, plot_fn=plot_mst),
                              inputs=plot_inputs,
                              outputs=plot_output)

    with gr.Tab(label="Language Families Subplot"):

        checked_families_input = gr.CheckboxGroup(label="Language Families",
                                                  choices=[
                                                      'Afroasiatic',
                                                      'Austroasiatic',
                                                      'Austronesian',
                                                      'Constructed',
                                                      'Creole',
                                                      'Dravidian',
                                                      'Germanic',
                                                      'Indo-European',
                                                      'Japonic',
                                                      'Kartvelian',
                                                      'Koreanic',
                                                      'Language Isolate',
                                                      'Niger-Congo',
                                                      'Northeast Caucasian',
                                                      'Romance',
                                                      'Sino-Tibetan',
                                                      'Turkic',
                                                      'Uralic'
                                                  ],
                                                  value=["Indo-European"])
        with gr.Row():
            plot_family_button = gr.Button("Plot Families")
            plot_figsize_h_input = gr.Slider(label="Figure Height", minimum=5, maximum=30, step=1, value=15)
            plot_figsize_w_input = gr.Slider(label="Figure Width", minimum=5, maximum=30, step=1, value=15)
        plot_family_output = gr.Plot(label="Families Plot")

        def plot_families_subfamilies(families, model, dataset, use_average, figsize_h, figsize_w):
            """
            Draws the MST of the languages in the selected families, coloured by subfamily.

            The figure is returned for display and also written to
            plots/plot_<counter>.pdf.
            """
            global i

            updated_matrix, updated_languages = filter_languages_nan(model, dataset, use_average)
            updated_matrix, updated_languages = filter_languages_by_families(updated_matrix, updated_languages, families)

            clusters, legends = cluster_languages_by_subfamilies(updated_languages)
            fig = plot_mst(model, dataset, use_average, updated_matrix, updated_languages, clusters, legends, fig_size=(figsize_w, figsize_h))
            fig.tight_layout()

            # "plots" is git-ignored, so it does not exist on a fresh checkout.
            os.makedirs("plots", exist_ok=True)
            fig.savefig(f"plots/plot_{i}.pdf", format="pdf")
            i += 1
            return fig

        plot_family_button.click(fn=plot_families_subfamilies,
                                 inputs=[checked_families_input, model_input, dataset_input, average_checkbox, plot_figsize_h_input, plot_figsize_w_input],
                                 outputs=plot_family_output)


demo.launch(share=True)
|
constants.py
ADDED
|
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Maps each language name to the subfamily label shown in plot legends.
# Keys must match the names used in language_families.
language_subfamilies = {
    "Afrikaans": "West Germanic",
    "Albanian": "Albanian",
    "Arabic": "Semitic",
    "Egyptian Arabic": "Semitic",
    "Aragonese": "Romance",
    "Armenian": "Armenian",
    "Asturian": "Romance",
    "Azerbaijani": "Oghuz",
    "Bashkir": "Kypchak",
    "Basque": "Language Isolate",
    "Bavarian": "Austro-Bavarian",
    "Belarusian": "East Slavic",
    "Bengali": "Eastern Indo-Aryan",
    "Bishnupriya Manipuri": "Eastern Indo-Aryan",
    "Bosnian": "South Slavic",
    "Breton": "Brythonic",
    "Bulgarian": "South Slavic",
    "Burmese": "Burmish",
    "Catalan": "Romance",
    "Cebuano": "Central Philippine",
    "Chechen": "Nakh-Daghestanian",
    "Chinese (Simplified)": "Sinitic",
    "Chinese (Traditional)": "Sinitic",
    "Min Nan Chinese": "Sinitic",
    "Chuvash": "Oghur",
    "Croatian": "South Slavic",
    "Czech": "West Slavic",
    "Danish": "North Germanic",
    "Dutch": "West Germanic",
    "English": "West Germanic",
    "Estonian": "Finnic",
    "Finnish": "Finnic",
    "French": "Gallo-Romance",
    # Galician belongs with Portuguese and Spanish, not with French and Occitan.
    "Galician": "Iberian Romance",
    "Georgian": "Kartvelian",
    "German": "West Germanic",
    "Greek": "Hellenic",
    "Gujarati": "Gujarati",
    "Haitian": "French-based Creole",
    "Hebrew": "Semitic",
    "Hindi": "Central Indo-Aryan",
    "Hungarian": "Ugric",
    "Icelandic": "North Germanic",
    "Ido": "Constructed",
    "Indonesian": "Malayic",
    "Irish": "Goidelic",
    "Italian": "Italo-Dalmatian",
    "Japanese": "Japonic",
    "Javanese": "Javanic",
    "Kannada": "Southern Dravidian",
    "Kazakh": "Kypchak",
    "Kirghiz": "Kypchak",
    "Korean": "Koreanic",
    "Latin": "Italic",
    "Latvian": "Baltic",
    "Lithuanian": "Baltic",
    "Lombard": "Gallo-Italic",
    "Low Saxon": "West Germanic",
    "Luxembourgish": "West Germanic",
    "Macedonian": "South Slavic",
    # Malagasy is a Barito language; it is not part of the Malayic group.
    "Malagasy": "Barito",
    "Malay": "Malayic",
    "Malayalam": "Southern Dravidian",
    "Marathi": "Central Indo-Aryan",
    "Minangkabau": "Malayic",
    "Nepali": "Eastern Indo-Aryan",
    "Newar": "Newaric",
    "Norwegian (Bokmal)": "North Germanic",
    "Norwegian (Nynorsk)": "North Germanic",
    "Occitan": "Gallo-Romance",
    "Persian (Farsi)": "Iranian",
    "Piedmontese": "Gallo-Italic",
    "Polish": "West Slavic",
    "Portuguese": "Iberian Romance",
    "Punjabi": "Punjabi",
    "Romanian": "Eastern Romance",
    "Russian": "East Slavic",
    "Scots": "West Germanic",
    "Serbian": "South Slavic",
    "Serbo-Croatian": "South Slavic",
    "Sicilian": "Italo-Dalmatian",
    "Slovak": "West Slavic",
    "Slovenian": "South Slavic",
    "South Azerbaijani": "Oghuz",
    "Spanish": "Iberian Romance",
    "Sundanese": "Sundic",
    "Swahili": "Bantu",
    "Swedish": "North Germanic",
    "Tagalog": "Central Philippine",
    "Tajik": "Iranian",
    "Tamil": "Southern Dravidian",
    "Tatar": "Kypchak",
    "Telugu": "Southern Dravidian",
    "Turkish": "Oghuz",
    "Ukrainian": "East Slavic",
    "Urdu": "Central Indo-Aryan",
    "Uzbek": "Karluk",
    "Vietnamese": "Vietic",
    "Volapük": "Constructed",
    "Waray-Waray": "Central Philippine",
    "Welsh": "Brythonic",
    "West Frisian": "West Germanic",
    "Western Punjabi": "Punjabi",
    "Yoruba": "Yoruboid",
    "Esperanto": "Constructed",
    "Crimean Tatar": "Kypchak"
}
|
| 109 |
+
|
| 110 |
+
# Maps each language name to the family label used for grouping and filtering.
# The labels are not all at the same genealogical level: "Germanic" and
# "Romance" are kept apart from the remaining "Indo-European" languages.
language_families = {
    "Afrikaans": "Germanic",
    "Albanian": "Indo-European",
    "Arabic": "Afroasiatic",
    "Egyptian Arabic": "Afroasiatic",
    "Aragonese": "Romance",
    "Armenian": "Indo-European",
    "Asturian": "Romance",
    "Azerbaijani": "Turkic",
    "Bashkir": "Turkic",
    "Basque": "Language Isolate",
    "Bavarian": "Germanic",
    "Belarusian": "Indo-European",
    "Bengali": "Indo-European",
    "Bishnupriya Manipuri": "Indo-European",
    "Bosnian": "Indo-European",
    "Breton": "Indo-European",
    "Bulgarian": "Indo-European",
    "Burmese": "Sino-Tibetan",
    "Catalan": "Romance",
    "Cebuano": "Austronesian",
    "Chechen": "Northeast Caucasian",
    "Chinese (Simplified)": "Sino-Tibetan",
    "Chinese (Traditional)": "Sino-Tibetan",
    "Min Nan Chinese": "Sino-Tibetan",
    "Chuvash": "Turkic",
    "Croatian": "Indo-European",
    "Czech": "Indo-European",
    "Danish": "Germanic",
    "Dutch": "Germanic",
    "English": "Germanic",
    "Estonian": "Uralic",
    "Finnish": "Uralic",
    "French": "Romance",
    "Galician": "Romance",
    "Georgian": "Kartvelian",
    "German": "Germanic",
    "Greek": "Indo-European",
    "Gujarati": "Indo-European",
    "Haitian": "Creole",
    "Hebrew": "Afroasiatic",
    "Hindi": "Indo-European",
    "Hungarian": "Uralic",
    "Icelandic": "Germanic",
    "Ido": "Constructed",
    "Indonesian": "Austronesian",
    "Irish": "Indo-European",
    "Italian": "Romance",
    "Japanese": "Japonic",
    "Javanese": "Austronesian",
    "Kannada": "Dravidian",
    "Kazakh": "Turkic",
    "Kirghiz": "Turkic",
    "Korean": "Koreanic",
    "Latin": "Indo-European",
    "Latvian": "Indo-European",
    "Lithuanian": "Indo-European",
    "Lombard": "Romance",
    "Low Saxon": "Germanic",
    "Luxembourgish": "Germanic",
    "Macedonian": "Indo-European",
    "Malagasy": "Austronesian",
    "Malay": "Austronesian",
    "Malayalam": "Dravidian",
    "Marathi": "Indo-European",
    "Minangkabau": "Austronesian",
    "Nepali": "Indo-European",
    "Newar": "Sino-Tibetan",
    "Norwegian (Bokmal)": "Germanic",
    "Norwegian (Nynorsk)": "Germanic",
    "Occitan": "Romance",
    "Persian (Farsi)": "Indo-European",
    "Piedmontese": "Romance",
    "Polish": "Indo-European",
    "Portuguese": "Romance",
    "Punjabi": "Indo-European",
    "Romanian": "Romance",
    "Russian": "Indo-European",
    "Scots": "Germanic",
    "Serbian": "Indo-European",
    "Serbo-Croatian": "Indo-European",
    "Sicilian": "Romance",
    "Slovak": "Indo-European",
    "Slovenian": "Indo-European",
    "South Azerbaijani": "Turkic",
    "Spanish": "Romance",
    "Sundanese": "Austronesian",
    "Swahili": "Niger-Congo",
    "Swedish": "Germanic",
    "Tagalog": "Austronesian",
    "Tajik": "Indo-European",
    "Tamil": "Dravidian",
    "Tatar": "Turkic",
    "Telugu": "Dravidian",
    "Turkish": "Turkic",
    "Ukrainian": "Indo-European",
    "Urdu": "Indo-European",
    "Uzbek": "Turkic",
    "Vietnamese": "Austroasiatic",
    "Volapük": "Constructed",
    "Waray-Waray": "Austronesian",
    "Welsh": "Indo-European",
    "West Frisian": "Germanic",
    "Western Punjabi": "Indo-European",
    "Yoruba": "Niger-Congo",
    "Esperanto": "Constructed",
    "Crimean Tatar": "Turkic"
}
|
requirements.txt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gradio==5.23.3
|
| 2 |
+
networkx==3.4.2
|
| 3 |
+
umap-learn==0.5.7
|
utils.py
ADDED
|
@@ -0,0 +1,270 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import networkx as nx
|
| 2 |
+
from sklearn.cluster import HDBSCAN
|
| 3 |
+
import matplotlib.pyplot as plt
|
| 4 |
+
import numpy as np
|
| 5 |
+
from sklearn.manifold import TSNE
|
| 6 |
+
import umap
|
| 7 |
+
from sklearn.cluster import KMeans
|
| 8 |
+
from scipy.spatial import KDTree
|
| 9 |
+
from adjustText import adjust_text
|
| 10 |
+
from constants import language_families, language_subfamilies
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def filter_languages_by_families(matrix, languages, families):
    """
    Keeps only the languages whose family is among the requested ones.

    Parameters:
    - matrix: 2D NumPy array indexed in the same order as `languages`.
    - languages: sequence of language names (keys of language_families).
    - families: collection of family names to keep.

    Returns:
    - filtered_matrix: `matrix` restricted to the rows and columns of the kept languages.
    - filtered_languages: list of the kept language names, in their original order.
    """
    kept_indices = []
    kept_languages = []
    for position, language in enumerate(languages):
        if language_families[language] in families:
            kept_indices.append(position)
            kept_languages.append(language)

    # np.ix_ selects the same indices on both axes.
    return matrix[np.ix_(kept_indices, kept_indices)], kept_languages
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def get_dynamic_color_map(n_colors):
    """
    Builds a list of `n_colors` colours sampled from a matplotlib colormap.

    Parameters:
    - n_colors: int, how many colours to produce.

    Returns:
    - list with one colour per index, sampled at i / n_colors from "tab20"
      (up to 20 colours) or "hsv" (more than 20).
    """
    palette_name = "tab20" if n_colors <= 20 else "hsv"
    palette = plt.get_cmap(palette_name)

    colors = []
    for index in range(n_colors):
        colors.append(palette(index / n_colors))
    return colors
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def cluster_languages_by_families(languages):
    """
    Assigns each language a cluster id derived from its family.

    Returns:
    - clusters: list with, for each language, the position of its family in `legend`.
    - legend: alphabetically sorted list of the distinct families present.
    """
    families = [language_families[language] for language in languages]
    legend = sorted(set(families))
    position_of = {name: position for position, name in enumerate(legend)}
    return [position_of[family] for family in families], legend
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def cluster_languages_by_subfamilies(languages):
    """
    Assigns each language a cluster id derived from its "Family (Subfamily)" label.

    Returns:
    - clusters: list with, for each language, the position of its label in `legend`.
    - legend: alphabetically sorted list of the distinct labels present.
    """
    labels = []
    for language in languages:
        labels.append(language_families[language] + f" ({language_subfamilies[language]})")

    legend = sorted(set(labels))
    position_of = {label: position for position, label in enumerate(legend)}
    return [position_of[label] for label in labels], legend
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def plot_mst(model, dataset, use_average, matrix, languages, clusters, legend=None, fig_size=(20,20)):
    """
    Plots the Minimum Spanning Tree (MST) of a distance matrix, colouring nodes by cluster.

    Parameters:
    - model, dataset, use_average: not used for drawing; accepted so that all
      plot functions share one signature.
    - matrix: 2D NumPy array (N x N) with the pairwise distances between languages.
    - languages: sequence of length N with the label of each node.
    - clusters: sequence of length N with the cluster id of each node.
    - legend: optional lookup from cluster id to legend text (indexed as
      legend[cluster]); when None, the cluster id itself is shown.
    - fig_size: (width, height) passed to plt.subplots.

    Returns:
    - fig: the matplotlib figure containing the drawing.
    """
    # Create an empty undirected graph
    G = nx.Graph()

    # Number of nodes
    N = len(languages)

    # Register every node up front: with a single language no edge is added
    # below, and the node would otherwise be missing from the layout.
    G.add_nodes_from(range(N))

    # Add edges to the graph from the distance matrix.
    # Only iterate over the upper triangle of the matrix (source < target)
    for source in range(N):
        for target in range(source + 1, N):
            G.add_edge(source, target, weight=matrix[source, target])

    # Compute the Minimum Spanning Tree using NetworkX's built-in function.
    mst = nx.minimum_spanning_tree(G)

    # Choose a layout for the MST. Here we use Kamada-Kawai layout which considers edge weights.
    pos = nx.kamada_kawai_layout(mst, weight='weight')

    # Map each cluster to a color
    unique_clusters = sorted(set(clusters))
    cmap = get_dynamic_color_map(len(unique_clusters))
    cluster_colors = {cluster: cmap[index] for index, cluster in enumerate(unique_clusters)}

    node_colors = [cluster_colors.get(cluster) for cluster in clusters]

    # Create a figure for plotting.
    fig, ax = plt.subplots(figsize=fig_size)

    # Draw the MST edges.
    nx.draw_networkx_edges(mst, pos, edge_color='gray', ax=ax)

    # Draw the nodes with colors corresponding to their clusters.
    nx.draw_networkx_nodes(mst, pos, node_color=node_colors, node_size=100, ax=ax, alpha=0.7)

    # Instead of directly drawing labels, we create text objects to adjust them later
    texts = []
    for node, label in enumerate(languages):
        x, y = pos[node]
        texts.append(ax.text(x, y, label, fontsize=10))

    # Adjust text labels to minimize overlap.
    adjust_text(texts, expand_text=(1.05, 1.2))

    # Add a legend for clusters
    if legend is None:
        legend = {cluster: str(cluster) for cluster in unique_clusters}
    legend_handles = [
        plt.Line2D([0], [0], marker='o', color='w', markerfacecolor=cluster_colors[cluster], markersize=10, alpha=0.7, label=legend[cluster])
        for cluster in unique_clusters
    ]
    ax.legend(handles=legend_handles, title="Clusters", loc="best")

    # Remove axis for clarity.
    ax.axis('off')

    return fig
|
| 127 |
+
|
| 128 |
+
def cluster_languages_kmeans(dist_matrix, languages, n_clusters=5):
    """
    Clusters languages with KMeans and discards clusters that have a single member.

    Each row of the distance matrix is passed to KMeans as that language's
    feature vector.

    Parameters:
    - dist_matrix: 2D NumPy array (N x N) with the pairwise distances between languages.
    - languages: sequence of length N with the language names.
    - n_clusters: int, the number of clusters to form.

    Returns:
    - filtered_matrix: distance matrix restricted to the kept languages.
    - filtered_languages: NumPy array with the kept language names.
    - filtered_clusters: NumPy array with the cluster id of each kept language.
    """
    labels = KMeans(n_clusters=n_clusters, random_state=23).fit_predict(dist_matrix)

    # Size of the cluster each language landed in; singletons are dropped.
    cluster_sizes = np.bincount(labels)
    keep = cluster_sizes[labels] > 1

    filtered_matrix = dist_matrix[np.ix_(keep, keep)]
    filtered_languages = np.array(languages)[keep]
    filtered_clusters = np.array(labels)[keep]
    return filtered_matrix, filtered_languages, filtered_clusters
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def cluster_languages_hdbscan(dist_matrix, languages, min_cluster_size=2):
    """
    Clusters languages with HDBSCAN on a precomputed distance matrix and drops noise points.

    Languages that HDBSCAN labels -1 (its marker for points assigned to no
    cluster) are removed from all three returned values.

    Parameters:
    - dist_matrix: 2D array-like (N x N) representing the pairwise distances between languages.
    - languages: sequence of N language labels, in the same order as the rows of dist_matrix.
    - min_cluster_size: int, the minimum size of clusters.

    Returns:
    - filtered_matrix: 2D NumPy array, the distance matrix restricted to the kept languages.
    - filtered_languages: 1D NumPy array with the labels of the kept languages.
    - filtered_clusters: 1D NumPy array with the cluster id of each kept language.
    """
    # Coerce once: the np.ix_ indexing below only works on a real ndarray
    # (not a list of lists or a DataFrame).
    dist_matrix = np.asarray(dist_matrix)

    # Perform clustering using HDBSCAN with the precomputed distance matrix
    clustering_model = HDBSCAN(
        metric='precomputed', min_cluster_size=min_cluster_size
    )
    clusters = clustering_model.fit_predict(dist_matrix)

    # Keep every point that was assigned to a real cluster (label != -1)
    valid_indices = np.where(clusters != -1)[0]

    # Apply the same selection to rows and columns so the matrix stays square
    filtered_matrix = dist_matrix[np.ix_(valid_indices, valid_indices)]
    filtered_languages = np.asarray(languages)[valid_indices]
    filtered_clusters = np.asarray(clusters)[valid_indices]
    return filtered_matrix, filtered_languages, filtered_clusters
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def plot_distances_tsne(model, dataset, use_average, matrix, languages, clusters, legend=None):
    """
    Plots all languages from the distances matrix using t-SNE and colors them by clusters.

    Parameters:
    - model: shown in the plot title when use_average is falsy.
    - dataset: shown in the plot title when use_average is falsy.
    - use_average: if truthy, the title reads "Average" instead of the model/dataset pair.
    - matrix: precomputed pairwise distance matrix (N x N) that is embedded with t-SNE.
    - languages: N text labels, one per row of matrix, drawn next to the points.
    - clusters: N cluster ids, one per row of matrix, used to pick the point colors.
    - legend: optional lookup (indexed as legend[cluster]) giving the legend label
      for each cluster id; defaults to the string form of the id.

    Returns:
    - fig: the matplotlib Figure containing the scatter plot.
    """
    # scikit-learn requires perplexity < n_samples. With the default of 30 a
    # small (e.g. heavily filtered) matrix would raise, so clamp it; inputs
    # with more than 30 rows keep the default value.
    n_samples = len(matrix)
    perplexity = min(30.0, float(max(n_samples - 1, 1)))

    tsne = TSNE(
        n_components=2,
        random_state=23,
        metric="precomputed",
        init="random",  # the estimator is configured for distances, not feature vectors
        perplexity=perplexity,
    )
    tsne_results = tsne.fit_transform(matrix)

    # Map each cluster to a color
    unique_clusters = sorted(set(clusters))
    cmap = get_dynamic_color_map(len(unique_clusters))
    cluster_colors = {cluster: cmap[i] for i, cluster in enumerate(unique_clusters)}

    fig, ax = plt.subplots(figsize=(16, 12))
    ax.scatter(
        tsne_results[:, 0],
        tsne_results[:, 1],
        c=[cluster_colors[cluster] for cluster in clusters],
        alpha=0.7,
    )

    # Instead of directly drawing labels, we create text objects to adjust them later
    texts = [
        ax.text(tsne_results[i, 0], tsne_results[i, 1], label, fontsize=10)
        for i, label in enumerate(languages)
    ]

    # Adjust text labels to minimize overlap.
    # The arrowprops argument can draw arrows from labels to nodes if desired.
    adjust_text(texts, expand_text=(1.05, 1.2))

    # Add a legend for clusters
    if legend is None:
        legend = {cluster: str(cluster) for cluster in unique_clusters}
    legend_handles = [
        plt.Line2D([0], [0], marker='o', color='w', markerfacecolor=cluster_colors[cluster], markersize=10, label=legend[cluster])
        for cluster in unique_clusters
    ]
    ax.legend(handles=legend_handles, title="Clusters", loc="best")

    ax.set_title(f"t-SNE Visualization of Language Distances ({'Average' if use_average else f'{model}, {dataset}'})")
    ax.set_xlabel("t-SNE Dimension 1")
    ax.set_ylabel("t-SNE Dimension 2")
    return fig
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
def plot_distances_umap(model, dataset, use_average, matrix, languages, clusters, legend=None):
    """
    Draws a 2D UMAP embedding of the language distance matrix, colored by cluster.

    Parameters:
    - model: shown in the plot title when use_average is falsy.
    - dataset: shown in the plot title when use_average is falsy.
    - use_average: if truthy, the title reads "Average" instead of the model/dataset pair.
    - matrix: precomputed pairwise distance matrix that is embedded with UMAP.
    - languages: text labels, one per row of matrix, drawn next to the points.
    - clusters: cluster ids, one per row of matrix, used to pick the point colors.
    - legend: optional lookup (indexed as legend[cluster]) giving the legend label
      for each cluster id; defaults to the string form of the id.

    Returns:
    - fig: the matplotlib Figure containing the scatter plot.
    """
    # Project the precomputed distances down to two dimensions (fixed seed).
    reducer = umap.UMAP(metric="precomputed", random_state=23)
    umap_results = reducer.fit_transform(matrix)
    xs = umap_results[:, 0]
    ys = umap_results[:, 1]

    # One color per distinct cluster id, assigned in sorted id order.
    unique_clusters = sorted(set(clusters))
    palette = get_dynamic_color_map(len(unique_clusters))
    color_of = {cid: palette[idx] for idx, cid in enumerate(unique_clusters)}
    point_colors = [color_of[cid] for cid in clusters]

    fig, ax = plt.subplots(figsize=(16, 12))
    ax.scatter(xs, ys, c=point_colors, alpha=0.7)

    # Create the labels as text objects first so they can be repositioned
    # afterwards to reduce overlap.
    texts = [
        ax.text(umap_results[i, 0], umap_results[i, 1], label, fontsize=10)
        for i, label in enumerate(languages)
    ]
    adjust_text(texts, expand_text=(1.05, 1.2))

    # Legend: one marker per cluster, labelled by the cluster id unless the
    # caller supplied its own labels.
    if legend is None:
        legend = {cid: str(cid) for cid in unique_clusters}
    legend_handles = []
    for cid in unique_clusters:
        handle = plt.Line2D(
            [0],
            [0],
            marker='o',
            color='w',
            markerfacecolor=color_of[cid],
            markersize=10,
            label=legend[cid],
        )
        legend_handles.append(handle)
    ax.legend(handles=legend_handles, title="Clusters", loc="best")

    subject = 'Average' if use_average else f'{model}, {dataset}'
    ax.set_title(f"UMAP Visualization of Language Distances ({subject})")
    ax.set_xlabel("UMAP Dimension 1")
    ax.set_ylabel("UMAP Dimension 2")
    return fig
|