Spaces:
Runtime error
Runtime error
Trent
committed on
Commit
·
883e41e
1
Parent(s):
75c3a89
Clustering function
Browse files- app.py +5 -1
- backend/inference.py +68 -1
- requirements.txt +2 -0
app.py
CHANGED
@@ -118,4 +118,8 @@ For more cool information on sentence embeddings, see the [sBert project](https:
|
|
118 |
|
119 |
if st.button('Give me my search.'):
|
120 |
results = {model: inference.text_search(anchor, n_texts, model, QA_MODELS_ID) for model in select_models}
|
121 |
-
st.table(pd.DataFrame(results[select_models[0]]).T)
|
|
|
|
|
|
|
|
|
|
118 |
|
119 |
if st.button('Give me my search.'):
|
120 |
results = {model: inference.text_search(anchor, n_texts, model, QA_MODELS_ID) for model in select_models}
|
121 |
+
st.table(pd.DataFrame(results[select_models[0]]).T)
|
122 |
+
|
123 |
+
if st.button('3D Clustering of search result (new window)'):
|
124 |
+
fig = inference.text_cluster(anchor, 1000, select_models[0], QA_MODELS_ID)
|
125 |
+
fig.show()
|
backend/inference.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
import gzip
|
2 |
import json
|
|
|
3 |
|
4 |
import pandas as pd
|
5 |
import numpy as np
|
@@ -11,7 +12,7 @@ from typing import List, Union
|
|
11 |
import torch
|
12 |
|
13 |
from backend.utils import load_model, filter_questions, load_embeddings
|
14 |
-
|
15 |
|
16 |
def cos_sim(a, b):
    """Return the product of ``a`` with ``b`` transposed, scaled by their norms.

    Both norms are computed with ``jnp.linalg.norm`` without an ``axis``
    argument, i.e. one scalar norm per input array.
    """
    scaled_by = jnp.linalg.norm(a) * jnp.linalg.norm(b)
    products = jnp.matmul(a, jnp.transpose(b))
    return products / scaled_by
|
@@ -71,3 +72,69 @@ def text_search(anchor: str, n_answers: int, model_name: str, model_dict: dict):
|
|
71 |
urls.append(f"https://stackoverflow.com/q/{post['id']}")
|
72 |
|
73 |
return hits_titles, hits_scores, urls
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import gzip
|
2 |
import json
|
3 |
+
from collections import Counter
|
4 |
|
5 |
import pandas as pd
|
6 |
import numpy as np
|
|
|
12 |
import torch
|
13 |
|
14 |
from backend.utils import load_model, filter_questions, load_embeddings
|
15 |
+
from MulticoreTSNE import MulticoreTSNE as TSNE
|
16 |
|
17 |
def cos_sim(a, b):
    """Return the product of ``a`` with ``b`` transposed, scaled by their norms.

    Both norms are computed with ``jnp.linalg.norm`` without an ``axis``
    argument, i.e. one scalar norm per input array.
    """
    scaled_by = jnp.linalg.norm(a) * jnp.linalg.norm(b)
    products = jnp.matmul(a, jnp.transpose(b))
    return products / scaled_by
|
|
|
72 |
urls.append(f"https://stackoverflow.com/q/{post['id']}")
|
73 |
|
74 |
return hits_titles, hits_scores, urls
|
75 |
+
|
76 |
+
|
77 |
+
def _label_by_common_tags(tag_lists, skip=1, top=4, fallback='others'):
    """Label each tag collection with the most frequent shared tag it contains.

    Tags are ranked by how often they occur across ``tag_lists`` (ties keep
    first-seen order). The ``skip`` most frequent tags are ignored and the
    next ``top`` tags become the candidate labels. Each entry gets the first
    candidate it contains, or ``fallback`` when it contains none (including
    when there are no candidates at all).
    """
    counter = Counter()
    for tag_list in tag_lists:
        counter.update(tag_list)

    ranked = [tag for tag, _ in counter.most_common()]
    common_tags = ranked[skip:skip + top]

    return [
        next((tag for tag in common_tags if tag in tag_list), fallback)
        for tag_list in tag_lists
    ]


def text_cluster(anchor: str, n_answers: int, model_name: str, model_dict: dict):
    """Build a 3D t-SNE scatter plot of the search hits for ``anchor``.

    The query is encoded, the ``n_answers`` best corpus hits are retrieved
    with a dot-product semantic search, and the hit embeddings plus the query
    embedding are reduced to three dimensions with t-SNE. Every point is
    coloured by one of its frequent tags (or ``'others'``); the query itself
    is appended as the last point, labelled ``'QUERY'`` and drawn larger.

    Args:
        anchor: Query text to search for and to highlight in the plot.
        n_answers: Number of hits requested from the semantic search.
        model_name: Model identifier; only ``"mpnet_qa"`` is accepted.
        model_dict: Model lookup passed through to ``load_model``.

    Returns:
        The figure produced by ``plotly.express.scatter_3d``.

    Raises:
        AssertionError: If ``model_name`` is not ``"mpnet_qa"``.
    """
    # Proceeding with model
    print(model_name)
    # NOTE(review): assert is stripped under ``python -O``; kept so callers
    # still see the same exception type.
    assert model_name == "mpnet_qa"
    model = load_model(model_name, model_dict)

    # Creating embeddings; [None, :] adds a leading axis so the query can be
    # concatenated with the stacked hit embeddings below.
    query_emb = model.encode(anchor, convert_to_tensor=True)[None, :]

    print("loading embeddings")
    corpus_emb = load_embeddings()

    # Getting hits
    hits = util.semantic_search(query_emb, corpus_emb, score_function=util.dot_score, top_k=n_answers)[0]

    filtered_posts = filter_questions("python")

    # The query is appended as a synthetic last "post" so it is plotted too.
    hits_dict = [filtered_posts[hit['corpus_id']] for hit in hits]
    hits_dict.append(dict(id='1', title=anchor, tags=['']))

    hits_emb = torch.stack([corpus_emb[hit['corpus_id']] for hit in hits])
    hits_emb = torch.cat((hits_emb, query_emb))

    # Dimensionality reduction with t-SNE
    tsne = TSNE(n_components=3, verbose=1, perplexity=15, n_iter=1000)
    tsne_results = tsne.fit_transform(hits_emb.cpu())

    df = pd.DataFrame(hits_dict)

    # The single most frequent tag is skipped - presumably "python", which the
    # filtered corpus shares everywhere (TODO confirm) - and the next four
    # tags are used as colour labels.
    labels = _label_by_common_tags(list(df['tags']), skip=1, top=4)

    # Making the query (last row) bigger than the rest of the observations.
    # Done on plain lists before the columns exist, which avoids chained
    # assignment on the DataFrame.
    sizes = [2] * len(df)
    sizes[-1] = 10
    labels[-1] = 'QUERY'

    df['title'] = [post['title'] for post in hits_dict]
    df['labels'] = labels
    df['tsne_x'] = tsne_results[:, 0]
    df['tsne_y'] = tsne_results[:, 1]
    df['tsne_z'] = tsne_results[:, 2]
    df['size'] = sizes

    import plotly.express as px

    fig = px.scatter_3d(df, x='tsne_x', y='tsne_y', z='tsne_z', color='labels', size='size',
                        color_discrete_sequence=px.colors.qualitative.D3, hover_data=[df.title])
    return fig
|
requirements.txt
CHANGED
@@ -5,3 +5,5 @@ jaxlib
|
|
5 |
streamlit
|
6 |
numpy
|
7 |
torch
|
|
|
|
|
|
5 |
streamlit
|
6 |
numpy
|
7 |
torch
|
8 |
+
MulticoreTSNE
|
9 |
+
plotly
|