"""
This is a more complex example of performing clustering on a large-scale dataset.

This example finds local communities in a large set of sentences, i.e., groups of sentences that are highly
similar. You can freely configure the threshold for what is considered similar. A high threshold will
only find extremely similar sentences; a lower threshold will find more sentences that are less similar.

A second parameter is 'min_community_size': Only communities with at least a certain number of sentences will be returned.

The method for finding the communities is extremely fast: clustering 50k sentences takes only about 5 seconds (plus embedding computation).

In this example, we download a large set of questions from Quora and then find similar questions in this set.
"""
|
from sentence_transformers import SentenceTransformer, util
import os
import csv
import time
|
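# Model for computing sentence embeddings; all-MiniLM-L6-v2 is small and fast, so encoding a 50k-sentence corpus stays manageable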
model = SentenceTransformer('all-MiniLM-L6-v2')
|
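# We use questions from the Quora Duplicate Questions dataset and cap the corpus at max_corpus_size unique sentences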
url = "http://qim.fs.quoracdn.net/quora_duplicate_questions.tsv"
dataset_path = "quora_duplicate_questions.tsv"
max_corpus_size = 50000
|
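# Download the dataset if it is not present on disk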
if not os.path.exists(dataset_path):
    print("Download dataset")
    util.http_get(url, dataset_path)
|
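# Collect all unique sentences from the file; using a set removes duplicate questions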
corpus_sentences = set()
with open(dataset_path, encoding='utf8') as fIn:
    reader = csv.DictReader(fIn, delimiter='\t', quoting=csv.QUOTE_MINIMAL)
    for row in reader:
        corpus_sentences.add(row['question1'])
        corpus_sentences.add(row['question2'])
        if len(corpus_sentences) >= max_corpus_size:
            break
|
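# Compute the embeddings for all sentences; convert_to_tensor=True returns one stacked tensor rather than a list of numpy arrays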
corpus_sentences = list(corpus_sentences)
print("Encode the corpus. This might take a while")
corpus_embeddings = model.encode(corpus_sentences, batch_size=64, show_progress_bar=True, convert_to_tensor=True)
|
print("Start clustering")
start_time = time.time()
|
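# Two parameters to tune:
#   min_community_size: Only consider clusters that have at least 25 elements
#   threshold: Consider sentence pairs with a cosine similarity larger than the threshold as similar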
clusters = util.community_detection(corpus_embeddings, min_community_size=25, threshold=0.75)
|
print("Clustering done after {:.2f} sec".format(time.time() - start_time))
|
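# For each cluster, print the first three and last three sentences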
for i, cluster in enumerate(clusters):
    print("\nCluster {}, #{} Elements ".format(i + 1, len(cluster)))
    for sentence_id in cluster[0:3]:
        print("\t", corpus_sentences[sentence_id])
    print("\t", "...")
    for sentence_id in cluster[-3:]:
        print("\t", corpus_sentences[sentence_id])