File size: 1,726 Bytes
2359bda |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 |
"""
This is a simple application for sentence embeddings: clustering
Sentences are mapped to sentence embeddings and then agglomerative clustering with a threshold is applied.
"""
from sentence_transformers import SentenceTransformer
from sklearn.cluster import AgglomerativeClustering
import numpy as np
# Identifier of the sentence-embedding checkpoint handed to SentenceTransformer.
model_name = 'all-MiniLM-L6-v2'
# Model object whose encode() method turns sentences into embedding vectors.
embedder = SentenceTransformer(model_name)
# Example sentences to cluster. Consecutive entries describe similar scenes
# (eating, carrying a baby, riding a horse, playing drums, a cheetah hunting).
corpus = [
    'A man is eating food.',
    'A man is eating a piece of bread.',
    'A man is eating pasta.',
    'The girl is carrying a baby.',
    'The baby is carried by the woman',
    'A man is riding a horse.',
    'A man is riding a white horse on an enclosed ground.',
    'A monkey is playing drums.',
    'Someone in a gorilla costume is playing a set of drums.',
    'A cheetah is running behind its prey.',
    'A cheetah chases prey on across a field.',
]
# Embed every sentence of the corpus.
# NOTE(review): assumes encode() returns a 2-D array with one row per
# sentence, which the axis=1 reduction below relies on.
corpus_embeddings = embedder.encode(corpus)
# Scale each row to unit length: divide by its own Euclidean norm, kept as a
# column (keepdims=True) so the division broadcasts row by row.
row_norms = np.linalg.norm(corpus_embeddings, axis=1, keepdims=True)
corpus_embeddings = corpus_embeddings / row_norms
# Perform agglomerative (hierarchical) clustering on the normalized embeddings.
# n_clusters=None leaves the number of clusters open; merging is instead
# governed by distance_threshold. The remaining settings are library defaults.
# Alternative configuration noted by the original author (not active):
# affinity='cosine', linkage='average', distance_threshold=0.4.
clustering_model = AgglomerativeClustering(
    n_clusters=None,
    distance_threshold=1.5,
)
# fit() returns the fitted estimator, so the labels can be read right away;
# labels_ holds one cluster id per input row.
cluster_assignment = clustering_model.fit(corpus_embeddings).labels_
# Group the sentences by assigned cluster: cluster id -> list of sentences,
# in corpus order. Position k in cluster_assignment labels corpus[k].
clustered_sentences = {}
for position, label in enumerate(cluster_assignment):
    # setdefault creates the empty list the first time a label is seen.
    clustered_sentences.setdefault(label, []).append(corpus[position])
# Report each cluster: a 1-based header, the member sentences, then an empty
# separator line.
for label, members in clustered_sentences.items():
    print("Cluster ", label + 1)
    print(members)
    print("")
|