# Author: mylessss — commit 5be1c1d ("update")
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "altair==5.4.1",
# "duckdb==1.1.3",
# "hdbscan==0.8.39",
# "marimo",
# "numba==0.60.0",
# "numpy==2.0.2",
# "polars==1.17.1",
# "pyarrow==18.0.0",
# "scikit-learn==1.5.2",
# "umap-learn==0.5.7",
# ]
# ///
import marimo
# marimo version string recorded in this file (presumably the version that generated it).
__generated_with = "0.9.33"
# The notebook app; every `@app.cell` function below is registered on it.
app = marimo.App(width="medium")
@app.cell
def __():
    # Import marimo once and expose it as `mo` to every cell that takes `mo`.
    import marimo as mo
    return (mo,)
@app.cell(hide_code=True)
def __(mo):
    # Title and introduction. Blank lines separate the markdown blocks so the
    # paragraph after the blockquote is not pulled into the quote, the bullet
    # list renders as a list, and the indented line belongs to the admonition.
    mo.md(
        r"""
        # Visualizing text embeddings using MotherDuck and marimo

        > Text embeddings have become a crucial tool in AI/ML applications, allowing us to convert text into numerical vectors that capture semantic meaning. These vectors are often used for semantic search, but in ~~this blog post~~ this marimo app, we'll explore how to visualize and explore text embeddings interactively using MotherDuck and marimo.

        This app lets you visualize and explore text embeddings from Hacker News posts about **databases**. You can:

        - See how different posts cluster together based on semantic similarity
        - Adjust clustering parameters in real-time
        - Explore relationships between posts through an interactive visualization

        !!! Info
            **This marimo application is based on [this blog](https://motherduck.com/blog/MotherDuck-Visualize-Embeddings-Marimo/).** We recommend looking through the blog first.
        """
    )
    return
@app.cell(hide_code=True)
def __(mo):
    # Section intro. The SQL sample needs a `;` after ATTACH: the two lines are
    # separate statements and would not run as written without it.
    mo.md(
        """
        ## Connecting to MotherDuck and Loading Sample Data

        This data has already been pre-computed, but you can fork and edit this notebook to run with your own data!

        ```sql
        ATTACH IF NOT EXISTS 'md:my_db';
        SELECT * FROM my_db.demo_with_embeddings;
        ```
        """
    )
    return
@app.cell
def __(mo):
    # Attach the MotherDuck database `my_db` (IF NOT EXISTS: skipped if it is
    # already attached).
    _df = mo.sql(
        """
ATTACH IF NOT EXISTS 'md:my_db'
"""
    )
    # NOTE(review): `my_db` is not a Python variable in this function; marimo
    # appears to treat the attached database as a definition of this cell so
    # that SQL cells referencing `my_db.` depend on it. Left as generated.
    return (my_db,)
@app.cell
def __(mo):
    # Source-table build step, kept for reference only. The SQL string below
    # consists solely of SQL comments (the table was already built for this
    # demo). Uncomment it to rebuild my_db.demo_embedding_data from the
    # Hacker News dataset on Hugging Face.
    _df = mo.sql(
        """
-- Commented out as we have already run the embeddings for showcasing purposes.
-- CREATE OR REPLACE TABLE my_db.demo_embedding_data AS
-- SELECT DISTINCT ON (url) * -- Remove duplicate URLs
-- FROM 'hf://datasets/julien040/hacker-news-posts/story.parquet'
-- WHERE contains(title, 'database') -- Filter for posts about databases
-- AND score > 5 -- Only include popular posts
-- LIMIT 50000;
"""
    )
    return
@app.cell
def __(demo_with_embeddings, mo, my_db):
    # Load the precomputed embeddings table into `embeddings`.
    # The commented-out SQL shows how that table was built: embedding(title)
    # over at most 1500 rows of my_db.demo_embedding_data. The live query puts
    # `title` and `text_embedding` first, then all remaining columns except
    # `id` and `comments`.
    # NOTE(review): `demo_with_embeddings` and `my_db` are SQL references that
    # marimo appears to track as dependencies; they are not used as Python
    # values in this cell.
    embeddings = mo.sql(
        f"""
-- Commented out as we have already run the embeddings for showcasing purposes.
-- CREATE TABLE my_db.demo_with_embeddings AS
-- SELECT *, embedding(title) as text_embedding
-- FROM my_db.demo_embedding_data
-- LIMIT 1500;
SELECT title, text_embedding, * EXCLUDE(id, title, text_embedding, comments) FROM my_db.demo_with_embeddings;
"""
    )
    return (embeddings,)
@app.cell
def __(mo):
    # Section intro. The blank line before the numbered list is required for
    # markdown to render it as a list rather than as paragraph text.
    mo.md(
        """
        ## Making Sense of High-Dimensional Data

        Text embeddings typically have hundreds of dimensions (512 in our case), making them impossible to visualize directly. We'll use two techniques to make them interpretable:

        1. **Dimensionality Reduction**: Convert our 512D vectors into 2D points while preserving relationships between texts
        2. **Clustering**: Group similar texts together into clusters
        """
    )
    return
@app.cell(hide_code=True)
def __(cluster_points, mo, reduce_dimensions):
    def md_help(cls):
        """Return markdown help text for a callable: its signature and docstring.

        Despite the parameter name, ``cls`` receives plain functions here; any
        callable with a ``__name__`` works.
        """
        import inspect

        # inspect.getdoc dedents the docstring and returns None when there is
        # none, so fall back to "" instead of rendering the text "None".
        doc = inspect.getdoc(cls) or ""
        signature = f"def {cls.__name__}{inspect.signature(cls)}"
        # Fence the signature so it renders as code; the docstring follows as
        # ordinary markdown instead of an accidental indented code block.
        return f"```python\n{signature}\n```\n\n{doc}"

    # Collapsible help for the two processing helpers.
    mo.accordion(
        {
            "`reduce_dimensions`": md_help(reduce_dimensions),
            "`cluster_points`": md_help(cluster_points),
        }
    )
    return (md_help,)
@app.cell
def __(np):
    def reduce_dimensions(
        np_array,
        metric="cosine",
        n_neighbors=80,
        min_dist=0.1,
        random_state=None,
    ):
        """
        Reduce the dimensions of embeddings to a 2D space.

        Here we use the UMAP algorithm. UMAP preserves both local and
        global structure of the high-dimensional data.

        - `metric`: distance used by UMAP (cosine similarity suits text embeddings).
        - `n_neighbors`: higher values = more global structure.
        - `min_dist`: controls how tightly points cluster.
        - `random_state`: pass an int for a reproducible layout; `None`
          (the default) leaves the layout non-deterministic across reruns.
        """
        import umap

        reducer = umap.UMAP(
            n_components=2,  # Reduce to 2D for visualization
            metric=metric,
            n_neighbors=n_neighbors,
            min_dist=min_dist,
            random_state=random_state,
        )
        return reducer.fit_transform(np_array)

    def cluster_points(
        np_array,
        min_cluster_size=4,
        max_cluster_size=50,
        min_samples=3,
        n_pca_components=50,
    ):
        """
        Cluster the embeddings.

        Here we use the HDBSCAN algorithm. We first reduce dimensionality to
        (at most) 50D with PCA to speed up clustering, while still preserving
        most of the important information.

        Returns one label per input row: `"cluster_<k>"` for rows HDBSCAN put
        in cluster k, and `"outlier"` for rows it labelled as noise (-1).

        - `min_cluster_size`: minimum size of a cluster (HDBSCAN requires >= 2).
        - `max_cluster_size`: maximum size of a cluster.
        - `min_samples`: minimum points to form a dense region.
        - `n_pca_components`: PCA target dimensionality, clamped to
          `min(n_samples, n_features)`, the most PCA can produce.
        """
        import hdbscan
        from sklearn.decomposition import PCA

        # PCA raises if asked for more components than min(n_samples,
        # n_features), which happens with small or low-dimensional inputs.
        n_components = min(n_pca_components, *np.shape(np_array))
        reduced = PCA(n_components=n_components).fit_transform(np_array)

        hdb = hdbscan.HDBSCAN(
            min_samples=min_samples,
            min_cluster_size=min_cluster_size,
            max_cluster_size=max_cluster_size,
        ).fit(reduced)

        labels = hdb.labels_
        # np.char.add concatenates string arrays element-wise on NumPy 1.x and
        # 2.x alike; `"str" + ndarray` only works on NumPy >= 2.
        named = np.char.add("cluster_", labels.astype(str))
        return np.where(labels == -1, "outlier", named)

    return cluster_points, reduce_dimensions
@app.cell
def __(mo):
    # (min, max) cluster size for HDBSCAN. HDBSCAN rejects a minimum cluster
    # size of 1, so the slider starts at 2 to keep every selectable range valid.
    cluster_size_slider = mo.ui.range_slider(
        start=2,
        stop=80,
        value=(4, 50),
        step=1,
        show_value=True,
        debounce=True,
        label="Cluster Size (min, max)",
    )
    # Distance metric handed to reduce_dimensions (UMAP) for the 2D projection.
    metric_dropdown = mo.ui.dropdown(
        ["cosine", "euclidean", "manhattan"],
        value="cosine",
        label="Distance Metric",
    )
    return cluster_size_slider, metric_dropdown
@app.cell
def __(mo):
    # Section intro for the processing step; rendered as this cell's output.
    mo.md(
        r"""
## Processing the Data
Now we'll transform our high-dimensional embeddings into something we can visualize, using `reduce_dimensions` and `cluster_points`. More details on this step [in the blog](https://motherduck.com/blog/MotherDuck-Visualize-Embeddings-Marimo/).
"""
    )
    return
@app.cell
def __(
    cluster_points,
    cluster_size_slider,
    embeddings,
    metric_dropdown,
    mo,
    reduce_dimensions,
):
    # Cluster the raw embeddings and project them to 2D, showing progress in a
    # spinner. This cell reads both UI controls, so it (including the UMAP
    # step) presumably reruns when either the slider or the dropdown changes.
    with mo.status.spinner("Clustering points...") as _s:
        # NOTE(review): numba is not otherwise used here; presumably imported
        # so umap/hdbscan's numba dependency is loaded/tracked — confirm
        # before removing.
        import numba

        # Assumes `text_embedding` is a fixed-size array column so to_numpy()
        # yields a 2D (n_rows, dim) float array — TODO confirm.
        embeddings_array = embeddings["text_embedding"].to_numpy()
        # The range slider's value is a (min, max) pair.
        hdb_labels = cluster_points(
            embeddings_array,
            min_cluster_size=cluster_size_slider.value[0],
            max_cluster_size=cluster_size_slider.value[1],
        )
        _s.update("Reducing dimensionality...")
        embeddings_2d = reduce_dimensions(
            embeddings_array, metric=metric_dropdown.value
        )
    mo.show_code()
    return embeddings_2d, embeddings_array, hdb_labels, numba
@app.cell
def __(cluster_size_slider, metric_dropdown, mo):
    # Lay the two processing controls out side by side as this cell's output.
    _controls = [cluster_size_slider, metric_dropdown]
    mo.hstack(_controls)
    return
@app.cell
def __(embeddings, embeddings_2d, hdb_labels, pl):
    # Attach the 2D coordinates and cluster labels (computed from `embeddings`
    # in the same row order), then de-duplicate, slim down, and drop noise.
    data = (
        embeddings.lazy()  # Lazy evaluation for performance
        .with_columns(
            text_embedding_2d_1=embeddings_2d[:, 0],
            text_embedding_2d_2=embeddings_2d[:, 1],
            cluster=hdb_labels,
        )
        .unique(subset=["url"], maintain_order=True)  # Remove duplicate URLs
        .drop(["text_embedding"])  # Raw embedding no longer needed
        .filter(pl.col("cluster") != "outlier")  # Filter out outliers
        .collect()  # Materialize the lazy plan
    )
    return (data,)
@app.cell
def __(data):
    # Tabular preview of the columns the chart below is built from.
    data.select(
        ["title", "cluster", "text_embedding_2d_1", "text_embedding_2d_2", "score"]
    )
    return
@app.cell
def __(alt, data, mo):
    # Scatter plot of the 2D projection, colored by cluster. `zero=False` fits
    # each axis to the data instead of forcing it to include the origin.
    _points = alt.Chart(data).mark_point().encode(
        x=alt.X("text_embedding_2d_1").scale(zero=False),
        y=alt.Y("text_embedding_2d_2").scale(zero=False),
        color="cluster",
        tooltip=["title", "score", "cluster"],
    )
    # Wrap as a marimo UI element so points can be selected; the selection is
    # read back through `chart.value`.
    chart = mo.ui.altair_chart(_points)
    mo.show_code()
    return (chart,)
@app.cell(hide_code=True)
def __(mo):
    # Section intro for the interactive chart; rendered as this cell's output.
    mo.md(
        r"""
## Creating an Interactive Visualization
We will plot the 2D representation of the text embeddings, colored by the clusters identified by HDBSCAN. You can select points on the chart to explore the text embeddings further. 👇
"""
    )
    return
@app.cell
def __(chart):
    # Display the interactive chart as this cell's output.
    chart
    return
@app.cell
def __(chart):
    # Display `chart.value`: the data currently selected on the chart.
    chart.value
    return
@app.cell
def __(mo):
    # Empty space for the table
    # A fixed 400px-tall blank div; presumably keeps room at the bottom of the
    # page so the selection table has space to grow — confirm intent.
    mo.Html("<div style='height: 400px;'></div>")
    return
@app.cell
def __():
    # Data manipulation and database connections
    import polars as pl
    import duckdb
    import pyarrow
    # NOTE(review): duckdb and pyarrow are not taken as parameters by any cell
    # shown here; presumably imported so mo.sql / dataframe conversion has its
    # backends available — confirm before removing.

    # Visualization
    import altair as alt

    # ML tools for dimensionality reduction and clustering
    import numpy as np
    return alt, duckdb, np, pl, pyarrow
# Allow running the notebook as a plain Python script.
if __name__ == "__main__":
    app.run()