Spaces:
Runtime error
Runtime error
File size: 1,098 Bytes
6ae27e8 31f3439 6ae27e8 5cd1ac6 a41bdbc 6ae27e8 a41bdbc 6ae27e8 5cd1ac6 31f3439 5cd1ac6 6ae27e8 31f3439 fa5d8a4 31f3439 fa5d8a4 6ae27e8 fa5d8a4 a41bdbc 6ae27e8 fa5d8a4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 |
import logging
from typing import List, Union

import jax.numpy as jnp
import pandas as pd

from backend.config import MODELS_ID
from backend.utils import load_model

# Defining cosine similarity using JAX.
def cos_sim(a, b):
    """Return the pairwise cosine similarity between the rows of ``a`` and ``b``.

    Each row is L2-normalised on its own before the dot product, so entry
    ``[i, j]`` of the result is the cosine of the angle between ``a[i]`` and
    ``b[j]``. A 1-D input is treated as a single vector.

    Args:
        a: Array of shape ``(d,)`` or ``(m, d)``.
        b: Array of shape ``(d,)`` or ``(n, d)``.

    Returns:
        A JAX array of shape ``()``, ``(m,)``, ``(n,)`` or ``(m, n)``,
        depending on the ranks of ``a`` and ``b``.
    """
    # Lower bound on a row norm: a zero row then yields similarity 0
    # instead of 0/0 = NaN.
    eps = 1e-12
    a = jnp.asarray(a)
    b = jnp.asarray(b)
    # Normalise per row (axis=-1). A norm without `axis` would be the
    # Frobenius norm of the whole matrix, which is only correct for a
    # single pair of vectors.
    a_unit = a / jnp.maximum(jnp.linalg.norm(a, axis=-1, keepdims=True), eps)
    b_unit = b / jnp.maximum(jnp.linalg.norm(b, axis=-1, keepdims=True), eps)
    return jnp.matmul(a_unit, jnp.transpose(b_unit))
# Compute similarity scores between an anchor text and a list of input texts.
def text_similarity(anchor: str, inputs: List[str], model_name: str, model_dict: dict):
    """Score each string in ``inputs`` by its cosine similarity to ``anchor``.

    The model comes from ``load_model(model_name, model_dict)``. It is either
    a single object with an ``encode`` method, used for both texts, or a
    two-element container where ``model[0]`` encodes the anchor and
    ``model[1]`` encodes the inputs.

    Args:
        anchor: Reference text that every input is compared against.
        inputs: Texts to score.
        model_name: Model identifier forwarded to ``load_model``.
        model_dict: Mapping forwarded to ``load_model``.

    Returns:
        A ``pandas.DataFrame`` with columns ``inputs`` (the given strings)
        and ``score`` (similarity rounded to 3 decimals, as Python floats),
        one row per input, in input order.

    Raises:
        ValueError: If the loaded model has no ``encode`` method and does not
            have exactly two components.
    """
    logging.getLogger(__name__).info("text_similarity using model %s", model_name)
    model = load_model(model_name, model_dict)

    # Creating embeddings. `[None, :]` adds a leading axis so the anchor is a
    # one-row matrix (assumes encode(str) returns a 1-D vector).
    if hasattr(model, 'encode'):
        anchor_emb = model.encode(anchor)[None, :]
        inputs_emb = model.encode(inputs)
    else:
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if len(model) != 2:
            raise ValueError(
                f"Model {model_name!r} has no `encode` method and is not a pair "
                f"of encoders (got {len(model)} components)."
            )
        anchor_emb = model[0].encode(anchor)[None, :]
        inputs_emb = model[1].encode(inputs)

    # `ravel` (not `squeeze`) keeps the result 1-D even for a single input;
    # `squeeze` would give a 0-d array that cannot be iterated. `.tolist()`
    # turns the JAX values into plain Python floats for the DataFrame.
    similarity = jnp.ravel(cos_sim(anchor_emb, inputs_emb)).tolist()

    # Returning a pandas DataFrame.
    d = {'inputs': inputs,
         'score': [round(score, 3) for score in similarity]}
    return pd.DataFrame(d, columns=['inputs', 'score'])
|