|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import marimo |
|
|
|
# NOTE(review): presumably the marimo version that last wrote this file — confirm.
__generated_with = "0.9.33"

# Notebook application object; each cell below registers itself on it through
# the `@app.cell` decorator, and `app.run()` is called at the bottom of the file.
app = marimo.App(width="medium")
|
|
|
|
|
# Imports marimo under the alias `mo` and returns it so that the other cells
# can take `mo` as a parameter.
@app.cell
def __():
    import marimo as mo

    return (mo,)
|
|
|
|
|
# Title, introduction and feature list shown at the top of the app, followed
# by an admonition pointing to the blog post this notebook is derived from.
# The comment is kept outside the function so the cell body stays a single
# `mo.md(...)` call.
@app.cell(hide_code=True)
def __(mo):
    mo.md(
        r"""
        # Visualizing text embeddings using MotherDuck and marimo

        > Text embeddings have become a crucial tool in AI/ML applications, allowing us to convert text into numerical vectors that capture semantic meaning. These vectors are often used for semantic search, but in ~~this blog post~~ marimo app, we'll explore how to visualize and explore text embeddings interactively using MotherDuck and marimo.

        This app lets you visualize and explore text embeddings from Hacker News posts about **databases**. You can:

        - See how different posts cluster together based on semantic similarity
        - Adjust clustering parameters in real-time
        - Explore relationships between posts through an interactive visualization

        !!! Info
            **This marimo application is based on [this blog](https://motherduck.com/blog/MotherDuck-Visualize-Embeddings-Marimo/).** We recommend looking through the blog first.
        """
    )
    return
|
|
|
|
|
# Section header explaining where the sample data comes from. The SQL snippet
# is display-only documentation; the statements that actually run live in the
# SQL cells that follow. Each statement in the snippet is terminated with `;`
# so the example is valid when copied as a whole.
@app.cell(hide_code=True)
def __(mo):
    mo.md(
        """
        ## Connecting to MotherDuck and Loading Sample Data

        This data has already been pre-computed, but you can fork and edit this notebook to run with your own data!

        ```sql
        ATTACH IF NOT EXISTS 'md:my_db';
        SELECT * FROM my_db.demo_with_embeddings;
        ```
        """
    )
    return
|
|
|
|
|
# SQL cell: attaches the MotherDuck database `my_db` unless it is already
# attached. The query result is bound to the cell-local `_df` and not used.
# NOTE(review): `my_db` is never bound as a Python name in this cell; the
# `return (my_db,)` presumably records that the SQL statement defines the
# database for downstream cells (the `embeddings` cell takes `my_db` as a
# parameter) — confirm against marimo's SQL-cell behavior before calling this
# function directly.
@app.cell
def __(mo):
    _df = mo.sql(
        """
        ATTACH IF NOT EXISTS 'md:my_db'
        """
    )
    return (my_db,)
|
|
|
|
|
# SQL cell whose statement is entirely commented out: it documents how the
# `my_db.demo_embedding_data` table was originally built (database-related
# Hacker News posts with score > 5, one row per URL, at most 50000 rows).
# Only SQL comments are sent to `mo.sql`, so no table is created here.
@app.cell
def __(mo):
    _df = mo.sql(
        """
        -- Commented out as we have already run the embeddings for showcasing purposes.

        -- CREATE OR REPLACE TABLE my_db.demo_embedding_data AS
        -- SELECT DISTINCT ON (url) * -- Remove duplicate URLs
        -- FROM 'hf://datasets/julien040/hacker-news-posts/story.parquet'
        -- WHERE contains(title, 'database') -- Filter for posts about databases
        -- AND score > 5 -- Only include popular posts
        -- LIMIT 50000;
        """
    )
    return
|
|
|
|
|
# SQL cell: loads the pre-computed rows from `my_db.demo_with_embeddings` into
# `embeddings`, with `title` and `text_embedding` first and the `id` and
# `comments` columns excluded. The commented-out statement documents how the
# embeddings table was originally created.
# NOTE(review): `demo_with_embeddings` and `my_db` are parameters only because
# the SQL text references them; they are not used as Python values here —
# presumably marimo's way of ordering this cell after the ATTACH cell; confirm.
@app.cell
def __(demo_with_embeddings, mo, my_db):
    embeddings = mo.sql(
        f"""
        -- Commented out as we have already run the embeddings for showcasing purposes.
        -- CREATE TABLE my_db.demo_with_embeddings AS
        -- SELECT *, embedding(title) as text_embedding
        -- FROM my_db.demo_embedding_data
        -- LIMIT 1500;

        SELECT title, text_embedding, * EXCLUDE(id, title, text_embedding, comments) FROM my_db.demo_with_embeddings;
        """
    )
    return (embeddings,)
|
|
|
|
|
# Section header introducing the two processing steps (dimensionality
# reduction and clustering). Marked `hide_code=True` to match the other
# markdown-only cells in this notebook.
@app.cell(hide_code=True)
def __(mo):
    mo.md(
        """
        ## Making Sense of High-Dimensional Data

        Text embeddings typically have hundreds of dimensions (512 in our case), making them impossible to visualize directly. We'll use two techniques to make them interpretable:

        1. **Dimensionality Reduction**: Convert our 512D vectors into 2D points while preserving relationships between texts
        2. **Clustering**: Group similar texts together into clusters
        """
    )
    return
|
|
|
|
|
@app.cell(hide_code=True)
def __(cluster_points, mo, reduce_dimensions):
    def md_help(cls):
        """Return a help string for a callable: its name, signature and raw docstring."""
        import inspect

        name = cls.__name__
        signature = inspect.signature(cls)
        doc = cls.__doc__
        return f"def {name} {signature}:\n {doc}"

    # One collapsible help entry per processing function; the accordion is the
    # last expression of the cell and therefore its output.
    _help_entries = {}
    _help_entries["`reduce_dimensions`"] = md_help(reduce_dimensions)
    _help_entries["`cluster_points`"] = md_help(cluster_points)
    mo.accordion(_help_entries)
    return (md_help,)
|
|
|
|
|
@app.cell
def __(np):
    def reduce_dimensions(np_array, metric="cosine"):
        """
        Reduce the dimensions of embeddings to a 2D space.

        Here we use the UMAP algorithm. UMAP preserves both local and
        global structure of the high-dimensional data.

        Args:
            np_array: Embeddings to project, one row per text.
            metric: Distance metric passed to UMAP.

        Returns:
            The result of UMAP's `fit_transform` with `n_components=2`.
        """
        import umap

        reducer = umap.UMAP(
            n_components=2,
            metric=metric,
            n_neighbors=80,
            min_dist=0.1,
        )
        return reducer.fit_transform(np_array)

    def cluster_points(np_array, min_cluster_size=4, max_cluster_size=50):
        """
        Cluster the embeddings.

        Here we use the HDBSCAN algorithm. We first reduce dimensionality to 50D
        (or fewer, for small inputs) with PCA to speed up clustering, while still
        preserving most of the important information.

        Args:
            np_array: Embeddings to cluster, one row per text.
            min_cluster_size: Passed to HDBSCAN as `min_cluster_size`.
            max_cluster_size: Passed to HDBSCAN as `max_cluster_size`.

        Returns:
            Array of string labels, one per row: "cluster_<label>" for
            clustered rows and "outlier" for rows HDBSCAN labels -1.
        """
        import hdbscan
        from sklearn.decomposition import PCA

        # PCA cannot extract more components than min(n_samples, n_features),
        # so cap the 50D target instead of raising on small inputs.
        n_components = min(50, *np.shape(np_array))
        np_array = PCA(n_components=n_components).fit_transform(np_array)

        hdb = hdbscan.HDBSCAN(
            min_samples=3,
            min_cluster_size=min_cluster_size,
            max_cluster_size=max_cluster_size,
        ).fit(np_array)

        labels = hdb.labels_
        # `np.char.add` concatenates element-wise on every NumPy version; the
        # bare `"cluster_" + array` form relies on string ufuncs that only
        # exist from NumPy 2.0 and raises a UFuncTypeError on older releases.
        cluster_names = np.char.add("cluster_", labels.astype(str))
        return np.where(labels == -1, "outlier", cluster_names)

    return cluster_points, reduce_dimensions
|
|
|
|
|
@app.cell
def __(mo):
    # Interactive controls read by the processing cell: a (min, max) range for
    # the cluster size and the distance metric handed to `reduce_dimensions`.
    cluster_size_slider = mo.ui.range_slider(
        label="Cluster Size (min, max)",
        start=1,
        stop=80,
        step=1,
        value=(4, 50),
        debounce=True,
        show_value=True,
    )
    metric_dropdown = mo.ui.dropdown(
        options=["cosine", "euclidean", "manhattan"],
        label="Distance Metric",
        value="cosine",
    )
    return cluster_size_slider, metric_dropdown
|
|
|
|
|
# Section header for the processing step, linking to the blog post for
# details. Marked `hide_code=True` to match the other markdown-only cells in
# this notebook.
@app.cell(hide_code=True)
def __(mo):
    mo.md(
        r"""
        ## Processing the Data

        Now we'll transform our high-dimensional embeddings into something we can visualize, using `reduce_dimensions` and `cluster_points`. More details on this step [in the blog](https://motherduck.com/blog/MotherDuck-Visualize-Embeddings-Marimo/).
        """
    )
    return
|
|
|
|
|
# Processing cell: under a status spinner, converts the `text_embedding`
# column to a NumPy array, clusters it with `cluster_points` using the
# (min, max) pair from `cluster_size_slider`, then projects the same array
# with `reduce_dimensions` using the metric chosen in `metric_dropdown`.
# Because it reads both UI values, both steps belong to the same cell.
# The cell ends with `mo.show_code()`, so its code is left free of inline
# comments here.
# NOTE(review): `numba` is imported and returned but not used in this cell —
# presumably imported for its side effects or dependency detection; confirm.
@app.cell
def __(
    cluster_points,
    cluster_size_slider,
    embeddings,
    metric_dropdown,
    mo,
    reduce_dimensions,
):
    with mo.status.spinner("Clustering points...") as _s:
        import numba

        embeddings_array = embeddings["text_embedding"].to_numpy()
        hdb_labels = cluster_points(
            embeddings_array,
            min_cluster_size=cluster_size_slider.value[0],
            max_cluster_size=cluster_size_slider.value[1],
        )
        _s.update("Reducing dimensionality...")
        embeddings_2d = reduce_dimensions(
            embeddings_array, metric=metric_dropdown.value
        )
    mo.show_code()
    return embeddings_2d, embeddings_array, hdb_labels, numba
|
|
|
|
|
@app.cell
def __(cluster_size_slider, metric_dropdown, mo):
    # Place the two controls in one horizontal stack; the stack is the cell's
    # final expression.
    _controls = [cluster_size_slider, metric_dropdown]
    mo.hstack(_controls)
    return
|
|
|
|
|
@app.cell
def __(embeddings, embeddings_2d, hdb_labels, pl):
    # Build the plotting table in a single lazy pipeline: attach the 2D
    # coordinates and cluster labels, keep the first row per URL, drop the raw
    # embedding column, discard rows labelled "outlier", then materialize.
    data = (
        embeddings.lazy()
        .with_columns(
            text_embedding_2d_1=embeddings_2d[:, 0],
            text_embedding_2d_2=embeddings_2d[:, 1],
            cluster=hdb_labels,
        )
        .unique(subset=["url"], maintain_order=True)
        .drop(["text_embedding"])
        .filter(pl.col("cluster") != "outlier")
        .collect()
    )
    return (data,)
|
|
|
|
|
@app.cell
def __(data):
    # Preview of the processed table restricted to the columns used for
    # plotting and tooltips; the selection is the cell's final expression.
    _preview_columns = [
        "title",
        "cluster",
        "text_embedding_2d_1",
        "text_embedding_2d_2",
        "score",
    ]
    data.select(_preview_columns)
    return
|
|
|
|
|
# Builds the scatter plot of the 2D embedding coordinates (axes not forced to
# include zero), colored by cluster, with title/score/cluster tooltips, and
# wraps it in `mo.ui.altair_chart` so later cells can read `chart.value`.
# The cell ends with `mo.show_code()`, so its code is left free of inline
# comments here; the chart itself is displayed by a later cell.
@app.cell
def __(alt, data, mo):
    chart = (
        alt.Chart(data)
        .mark_point()
        .encode(
            x=alt.X("text_embedding_2d_1").scale(zero=False),
            y=alt.Y("text_embedding_2d_2").scale(zero=False),
            color="cluster",
            tooltip=["title", "score", "cluster"],
        )
    )
    chart = mo.ui.altair_chart(chart)
    mo.show_code()
    return (chart,)
|
|
|
|
|
# Section header introducing the interactive chart displayed by the next cell.
# The comment is kept outside the function so the cell body stays a single
# `mo.md(...)` call.
@app.cell(hide_code=True)
def __(mo):
    mo.md(
        r"""
        ## Creating an Interactive Visualization

        We will plot the 2D representation of the text embeddings, colored by the clusters identified by HDBSCAN. You can select points on the chart to explore the text embeddings further. 👇
        """
    )
    return
|
|
|
|
|
# Displays the interactive chart: the bare `chart` expression is the cell's
# final expression.
@app.cell
def __(chart):
    chart
    return
|
|
|
|
|
# Displays `chart.value`.
# NOTE(review): presumably the rows currently selected in the chart above, as
# the section text invites the user to select points — confirm against
# marimo's `altair_chart` documentation.
@app.cell
def __(chart):
    chart.value
    return
|
|
|
|
|
@app.cell
def __(mo):
    # Empty 400px-tall block rendered as raw HTML.
    # NOTE(review): presumably a spacer that leaves room below the selection
    # output — confirm intent.
    _spacer_markup = "<div style='height: 400px;'></div>"
    mo.Html(_spacer_markup)
    return
|
|
|
|
|
@app.cell
def __():
    # Third-party imports shared with the other cells through the return tuple.
    # `duckdb` and `pyarrow` are not referenced by name elsewhere in this file.
    # NOTE(review): presumably they are imported because the SQL cells depend
    # on them at runtime — confirm.
    import altair as alt
    import duckdb
    import numpy as np
    import polars as pl
    import pyarrow

    return alt, duckdb, np, pl, pyarrow
|
|
|
|
|
# Run the app when this file is executed directly as a script.
if __name__ == "__main__":
    app.run()
|
|