Spaces:
Runtime error
Runtime error
Trent
commited on
Commit
·
a41bdbc
1
Parent(s):
49438d6
Multi model select and local model loading
Browse files- __init__.py +0 -0
- app.py +12 -30
- backend/__init__.py +0 -0
- backend/config.py +1 -0
- backend/inference.py +9 -20
- backend/main.py +0 -19
- backend/utils.py +11 -0
- requirements.txt +1 -1
__init__.py
ADDED
File without changes
|
app.py
CHANGED
@@ -1,7 +1,8 @@
|
|
1 |
import streamlit as st
|
2 |
import pandas as pd
|
3 |
-
|
4 |
-
import
|
|
|
5 |
|
6 |
st.title('Demo using Flax-Sentence-Tranformers')
|
7 |
|
@@ -20,12 +21,12 @@ For more cool information on sentence embeddings, see the [sBert project](https:
|
|
20 |
Please enjoy!!
|
21 |
''')
|
22 |
|
23 |
-
|
24 |
anchor = st.text_input(
|
25 |
'Please enter here the main text you want to compare:'
|
26 |
)
|
27 |
|
28 |
if anchor:
|
|
|
29 |
n_texts = st.sidebar.number_input(
|
30 |
f'''How many texts you want to compare with: '{anchor}'?''',
|
31 |
value=2,
|
@@ -34,40 +35,21 @@ if anchor:
|
|
34 |
inputs = []
|
35 |
|
36 |
for i in range(n_texts):
|
37 |
-
|
38 |
-
input = st.sidebar.text_input(f'Text {i+1}:')
|
39 |
|
40 |
inputs.append(input)
|
41 |
|
42 |
-
|
43 |
-
|
44 |
-
api_base_url = 'http://127.0.0.1:8000/similarity'
|
45 |
-
|
46 |
if anchor:
|
47 |
if st.sidebar.button('Tell me the similarity.'):
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
inputs = inputs,
|
53 |
-
model = 'mpnet'))
|
54 |
-
res_minilm_l6 = requests.get(url = api_base_url, params = dict(anchor = anchor,
|
55 |
-
inputs = inputs,
|
56 |
-
model = 'minilm_l6'))
|
57 |
-
|
58 |
-
d_distilroberta = res_distilroberta.json()['dataframe']
|
59 |
-
d_mpnet = res_mpnet.json()['dataframe']
|
60 |
-
d_minilm_l6 = res_minilm_l6.json()['dataframe']
|
61 |
-
|
62 |
-
index = list(d_distilroberta['inputs'].values())
|
63 |
df_total = pd.DataFrame(index=index)
|
64 |
-
|
65 |
-
|
66 |
-
df_total['minilm_l6'] = list(d_minilm_l6['score'].values())
|
67 |
|
68 |
-
st.write('Here are the results for
|
69 |
st.write(df_total)
|
70 |
st.write('Visualize the results of each model:')
|
71 |
st.area_chart(df_total)
|
72 |
-
|
73 |
-
|
|
|
1 |
import streamlit as st
|
2 |
import pandas as pd
|
3 |
+
|
4 |
+
from backend import inference
|
5 |
+
from backend.config import MODELS_ID
|
6 |
|
7 |
st.title('Demo using Flax-Sentence-Tranformers')
|
8 |
|
|
|
21 |
Please enjoy!!
|
22 |
''')
|
23 |
|
|
|
24 |
anchor = st.text_input(
|
25 |
'Please enter here the main text you want to compare:'
|
26 |
)
|
27 |
|
28 |
if anchor:
|
29 |
+
select_models = st.sidebar.multiselect("Choose models", options=MODELS_ID.keys())
|
30 |
n_texts = st.sidebar.number_input(
|
31 |
f'''How many texts you want to compare with: '{anchor}'?''',
|
32 |
value=2,
|
|
|
35 |
inputs = []
|
36 |
|
37 |
for i in range(n_texts):
|
38 |
+
input = st.sidebar.text_input(f'Text {i + 1}:')
|
|
|
39 |
|
40 |
inputs.append(input)
|
41 |
|
|
|
|
|
|
|
|
|
42 |
if anchor:
|
43 |
if st.sidebar.button('Tell me the similarity.'):
|
44 |
+
results = {model: inference.text_similarity(anchor, inputs, model) for model in select_models}
|
45 |
+
df_results = {model: results[model] for model in results}
|
46 |
+
|
47 |
+
index = inputs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
48 |
df_total = pd.DataFrame(index=index)
|
49 |
+
for key, value in df_results.items():
|
50 |
+
df_total[key] = list(value['score'].values)
|
|
|
51 |
|
52 |
+
st.write('Here are the results for selected models:')
|
53 |
st.write(df_total)
|
54 |
st.write('Visualize the results of each model:')
|
55 |
st.area_chart(df_total)
|
|
|
|
backend/__init__.py
ADDED
File without changes
|
backend/config.py
CHANGED
@@ -1,3 +1,4 @@
|
|
1 |
MODELS_ID = dict(distilroberta = 'flax-sentence-embeddings/st-codesearch-distilroberta-base',
|
2 |
mpnet = 'flax-sentence-embeddings/all_datasets_v3_mpnet-base',
|
|
|
3 |
minilm_l6 = 'flax-sentence-embeddings/all_datasets_v3_MiniLM-L6')
|
|
|
1 |
MODELS_ID = dict(distilroberta = 'flax-sentence-embeddings/st-codesearch-distilroberta-base',
|
2 |
mpnet = 'flax-sentence-embeddings/all_datasets_v3_mpnet-base',
|
3 |
+
mpnet_qa = 'flax-sentence-embeddings/mpnet_stackexchange_v1',
|
4 |
minilm_l6 = 'flax-sentence-embeddings/all_datasets_v3_MiniLM-L6')
|
backend/inference.py
CHANGED
@@ -1,41 +1,30 @@
|
|
1 |
-
from sentence_transformers import SentenceTransformer
|
2 |
import pandas as pd
|
3 |
import jax.numpy as jnp
|
4 |
|
5 |
from typing import List
|
6 |
-
import config
|
7 |
-
|
8 |
-
# We download the models we will be using.
|
9 |
-
# If you do not want to use all, you can comment the unused ones.
|
10 |
-
distilroberta_model = SentenceTransformer(config.MODELS_ID['distilroberta'])
|
11 |
-
mpnet_model = SentenceTransformer(config.MODELS_ID['mpnet'])
|
12 |
-
minilm_l6_model = SentenceTransformer(config.MODELS_ID['minilm_l6'])
|
13 |
|
14 |
# Defining cosine similarity using flax.
|
|
|
|
|
|
|
15 |
def cos_sim(a, b):
|
16 |
-
return jnp.matmul(a, jnp.transpose(b))/(jnp.linalg.norm(a)*jnp.linalg.norm(b))
|
17 |
|
18 |
|
19 |
# We get similarity between embeddings.
|
20 |
-
def text_similarity(anchor: str, inputs: List[str],
|
|
|
21 |
|
22 |
# Creating embeddings
|
23 |
-
|
24 |
-
|
25 |
-
inputs_emb = distilroberta_model.encode([input for input in inputs])
|
26 |
-
elif model == 'mpnet':
|
27 |
-
anchor_emb = mpnet_model.encode(anchor)[None, :]
|
28 |
-
inputs_emb = mpnet_model.encode([input for input in inputs])
|
29 |
-
elif model == 'minilm_l6':
|
30 |
-
anchor_emb = minilm_l6_model.encode(anchor)[None, :]
|
31 |
-
inputs_emb = minilm_l6_model.encode([input for input in inputs])
|
32 |
|
33 |
# Obtaining similarity
|
34 |
similarity = list(jnp.squeeze(cos_sim(anchor_emb, inputs_emb)))
|
35 |
|
36 |
# Returning a Pandas' dataframe
|
37 |
d = {'inputs': [input for input in inputs],
|
38 |
-
'score': [round(similarity[i],3) for i in range(len(similarity))]}
|
39 |
df = pd.DataFrame(d, columns=['inputs', 'score'])
|
40 |
|
41 |
return df.sort_values('score', ascending=False)
|
|
|
|
|
1 |
import pandas as pd
|
2 |
import jax.numpy as jnp
|
3 |
|
4 |
from typing import List
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
|
6 |
# Defining cosine similarity using flax.
|
7 |
+
from backend.utils import load_model
|
8 |
+
|
9 |
+
|
10 |
def cos_sim(a, b):
|
11 |
+
return jnp.matmul(a, jnp.transpose(b)) / (jnp.linalg.norm(a) * jnp.linalg.norm(b))
|
12 |
|
13 |
|
14 |
# We get similarity between embeddings.
|
15 |
+
def text_similarity(anchor: str, inputs: List[str], model_name: str):
|
16 |
+
model = load_model(model_name)
|
17 |
|
18 |
# Creating embeddings
|
19 |
+
anchor_emb = model.encode(anchor)[None, :]
|
20 |
+
inputs_emb = model.encode([input for input in inputs])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
|
22 |
# Obtaining similarity
|
23 |
similarity = list(jnp.squeeze(cos_sim(anchor_emb, inputs_emb)))
|
24 |
|
25 |
# Returning a Pandas' dataframe
|
26 |
d = {'inputs': [input for input in inputs],
|
27 |
+
'score': [round(similarity[i], 3) for i in range(len(similarity))]}
|
28 |
df = pd.DataFrame(d, columns=['inputs', 'score'])
|
29 |
|
30 |
return df.sort_values('score', ascending=False)
|
backend/main.py
DELETED
@@ -1,19 +0,0 @@
|
|
1 |
-
from fastapi import Query, FastAPI
|
2 |
-
|
3 |
-
import config
|
4 |
-
import inference
|
5 |
-
from typing import List
|
6 |
-
|
7 |
-
app = FastAPI()
|
8 |
-
|
9 |
-
@app.get("/")
|
10 |
-
def read_root():
|
11 |
-
return {"message": "Welcome to the API of flax-sentence-embeddings."}
|
12 |
-
|
13 |
-
@app.get('/similarity')
|
14 |
-
def get_similarity(anchor: str, inputs: List[str] = Query([]), model: str = 'distilroberta'):
|
15 |
-
return {'dataframe': inference.text_similarity(anchor, inputs, model)}
|
16 |
-
|
17 |
-
|
18 |
-
#if __name__ == "__main__":
|
19 |
-
# uvicorn.run("main:app", host="0.0.0.0", port=8080)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
backend/utils.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
from sentence_transformers import SentenceTransformer
|
3 |
+
from .config import MODELS_ID
|
4 |
+
|
5 |
+
|
6 |
+
@st.cache(allow_output_mutation=True)
|
7 |
+
def load_model(model_name):
|
8 |
+
assert model_name in MODELS_ID.keys()
|
9 |
+
# Lazy downloading
|
10 |
+
model = SentenceTransformer(MODELS_ID[model_name])
|
11 |
+
return model
|
requirements.txt
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
-
fastapi
|
2 |
sentence_transformers
|
3 |
pandas
|
4 |
jax
|
|
|
5 |
streamlit
|
|
|
|
|
1 |
sentence_transformers
|
2 |
pandas
|
3 |
jax
|
4 |
+
jaxlib
|
5 |
streamlit
|