Trent committed on
Commit
31f3439
·
1 Parent(s): 6e03e5d

List model loading support

Browse files
Files changed (4) hide show
  1. app.py +3 -3
  2. backend/config.py +2 -0
  3. backend/inference.py +9 -4
  4. backend/utils.py +5 -5
app.py CHANGED
@@ -36,7 +36,7 @@ if menu == "Sentence Similarity":
36
 
37
  inputs = []
38
 
39
- for i in range(n_texts):
40
  input = st.text_input(f'Text {i + 1}:')
41
 
42
  inputs.append(input)
@@ -45,7 +45,7 @@ if menu == "Sentence Similarity":
45
  results = {model: inference.text_similarity(anchor, inputs, model) for model in select_models}
46
  df_results = {model: results[model] for model in results}
47
 
48
- index = inputs
49
  df_total = pd.DataFrame(index=index)
50
  for key, value in df_results.items():
51
  df_total[key] = list(value['score'].values)
@@ -53,7 +53,7 @@ if menu == "Sentence Similarity":
53
  st.write('Here are the results for selected models:')
54
  st.write(df_total)
55
  st.write('Visualize the results of each model:')
56
- st.area_chart(df_total)
57
  elif menu == "Search":
58
  select_models = st.multiselect("Choose models", options=list(MODELS_ID), default=list(MODELS_ID)[0])
59
 
 
36
 
37
  inputs = []
38
 
39
+ for i in range(int(n_texts)):
40
  input = st.text_input(f'Text {i + 1}:')
41
 
42
  inputs.append(input)
 
45
  results = {model: inference.text_similarity(anchor, inputs, model) for model in select_models}
46
  df_results = {model: results[model] for model in results}
47
 
48
+ index = [f"{idx}:{input[:min(15, len(input))]}..." for idx, input in enumerate(inputs)]
49
  df_total = pd.DataFrame(index=index)
50
  for key, value in df_results.items():
51
  df_total[key] = list(value['score'].values)
 
53
  st.write('Here are the results for selected models:')
54
  st.write(df_total)
55
  st.write('Visualize the results of each model:')
56
+ st.line_chart(df_total)
57
  elif menu == "Search":
58
  select_models = st.multiselect("Choose models", options=list(MODELS_ID), default=list(MODELS_ID)[0])
59
 
backend/config.py CHANGED
@@ -1,6 +1,8 @@
1
  MODELS_ID = dict(distilroberta = 'flax-sentence-embeddings/st-codesearch-distilroberta-base',
2
  mpnet = 'flax-sentence-embeddings/all_datasets_v3_mpnet-base',
3
  mpnet_qa = 'flax-sentence-embeddings/mpnet_stackexchange_v1',
 
 
4
  minilm_l6 = 'flax-sentence-embeddings/all_datasets_v3_MiniLM-L6')
5
 
6
  QA_MODELS_ID = dict(
 
1
  MODELS_ID = dict(distilroberta = 'flax-sentence-embeddings/st-codesearch-distilroberta-base',
2
  mpnet = 'flax-sentence-embeddings/all_datasets_v3_mpnet-base',
3
  mpnet_qa = 'flax-sentence-embeddings/mpnet_stackexchange_v1',
4
+ mpnet_asymmetric_qa = ['flax-sentence-embeddings/multi-QA_v1-mpnet-asymmetric-Q',
5
+ 'flax-sentence-embeddings/multi-QA_v1-mpnet-asymmetric-A'],
6
  minilm_l6 = 'flax-sentence-embeddings/all_datasets_v3_MiniLM-L6')
7
 
8
  QA_MODELS_ID = dict(
backend/inference.py CHANGED
@@ -1,7 +1,7 @@
1
  import pandas as pd
2
  import jax.numpy as jnp
3
 
4
- from typing import List
5
 
6
  # Defining cosine similarity using flax.
7
  from backend.utils import load_model
@@ -13,12 +13,17 @@ def cos_sim(a, b):
13
 
14
  # We get similarity between embeddings.
15
  def text_similarity(anchor: str, inputs: List[str], model_name: str):
 
16
  model = load_model(model_name)
17
- assert hasattr(model, 'encode') # multiple models is not supported for similarity
18
 
19
  # Creating embeddings
20
- anchor_emb = model.encode(anchor)[None, :]
21
- inputs_emb = model.encode([input for input in inputs])
 
 
 
 
 
22
 
23
  # Obtaining similarity
24
  similarity = list(jnp.squeeze(cos_sim(anchor_emb, inputs_emb)))
 
1
  import pandas as pd
2
  import jax.numpy as jnp
3
 
4
+ from typing import List, Union
5
 
6
  # Defining cosine similarity using flax.
7
  from backend.utils import load_model
 
13
 
14
  # We get similarity between embeddings.
15
  def text_similarity(anchor: str, inputs: List[str], model_name: str):
16
+ print(model_name)
17
  model = load_model(model_name)
 
18
 
19
  # Creating embeddings
20
+ if hasattr(model, 'encode'):
21
+ anchor_emb = model.encode(anchor)[None, :]
22
+ inputs_emb = model.encode([input for input in inputs])
23
+ else:
24
+ assert len(model) == 2
25
+ anchor_emb = model[0].encode(anchor)[None, :]
26
+ inputs_emb = model[1].encode([input for input in inputs])
27
 
28
  # Obtaining similarity
29
  similarity = list(jnp.squeeze(cos_sim(anchor_emb, inputs_emb)))
backend/utils.py CHANGED
@@ -7,10 +7,10 @@ from .config import MODELS_ID
7
  def load_model(model_name):
8
  assert model_name in MODELS_ID.keys()
9
  # Lazy downloading
10
- models = MODELS_ID[model_name]
11
- if models is str:
12
- output = SentenceTransformer(models)
13
- elif hasattr(models, '__iter__') :
14
- output = [SentenceTransformer(model) for model in models]
15
 
16
  return output
 
7
  def load_model(model_name):
8
  assert model_name in MODELS_ID.keys()
9
  # Lazy downloading
10
+ model_ids = MODELS_ID[model_name]
11
+ if type(model_ids) == str:
12
+ output = SentenceTransformer(model_ids)
13
+ elif hasattr(model_ids, '__iter__'):
14
+ output = [SentenceTransformer(name) for name in model_ids]
15
 
16
  return output