| import gradio as gr | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| from sklearn.cluster import MeanShift, estimate_bandwidth | |
| from sklearn.datasets import make_blobs | |
def get_clusters_plot(n_blobs, cluster_std):
    """Cluster synthetic blob data with MeanShift and return a scatter plot.

    Generates 10,000 points from ``n_blobs`` Gaussian blobs with standard
    deviation ``cluster_std``. It then fits ``MeanShift`` with a bandwidth
    estimated from the data, and plots every discovered cluster together
    with its center.

    Args:
        n_blobs: Number of blob centers to generate. Cast to ``int`` because
            a UI slider may deliver a float such as ``3.0``, which
            ``make_blobs`` does not accept as a center count.
        cluster_std: Standard deviation of each generated blob.

    Returns:
        A ``matplotlib.figure.Figure`` with one color/marker style per
        cluster and an enlarged marker at each cluster center.
    """
    # Build the figure without pyplot. pyplot keeps a global registry of open
    # figures (one leaked per call unless closed) and a shared "current
    # figure" that concurrent requests would otherwise draw into.
    from matplotlib.figure import Figure

    X, _ = make_blobs(
        n_samples=10000, cluster_std=cluster_std, centers=int(n_blobs)
    )
    bandwidth = estimate_bandwidth(X, quantile=0.2, n_samples=500)
    ms = MeanShift(bandwidth=bandwidth, bin_seeding=True)
    ms.fit(X)
    labels = ms.labels_

    # The first three colors/markers match the original styling. The extra
    # entries ensure clusters beyond the third are still drawn instead of
    # being dropped. Nine colors paired with four markers give 36 distinct
    # combinations; past that, styles repeat via the modulo below.
    colors = [
        "#dede00",
        "#377eb8",
        "#f781bf",
        "#ff7f00",
        "#4daf4a",
        "#a65628",
        "#984ea3",
        "#999999",
        "#e41a1c",
    ]
    markers = ["x", "o", "^", "s"]

    fig = Figure()
    ax = fig.add_subplot()
    # Iterate over the fitted centers themselves, so every center is drawn
    # even if no sample was assigned to it.
    for k, cluster_center in enumerate(ms.cluster_centers_):
        col = colors[k % len(colors)]
        marker = markers[k % len(markers)]
        my_members = labels == k
        ax.plot(X[my_members, 0], X[my_members, 1], marker, color=col)
        ax.plot(
            cluster_center[0],
            cluster_center[1],
            marker,
            markerfacecolor=col,
            markeredgecolor="k",
            markersize=14,
        )
    return fig
# Gradio UI: two sliders (number of generated blobs, blob spread) feed
# get_clusters_plot, and the resulting matplotlib figure is shown in a Plot
# output. Flagging is disabled.
demo = gr.Interface(
    fn=get_clusters_plot,
    inputs=[
        gr.Slider(
            label="Number of clusters in data",
            minimum=2,
            maximum=10,
            step=1,
            value=3,
        ),
        gr.Slider(
            label="Cluster standard deviation",
            minimum=0.1,
            maximum=1,
            step=0.1,
            value=0.6,
        ),
    ],
    outputs=gr.Plot(),
    allow_flagging="never",
)

if __name__ == "__main__":
    # Launch the server only when executed as a script, not when imported.
    demo.launch()