Spaces:
Sleeping
Sleeping
import os | |
import streamlit as st | |
import logging | |
import pandas as pd | |
import numpy as np | |
import pickle | |
from cryptography.fernet import Fernet | |
from sentence_transformers import SentenceTransformer, util | |
def load_model():
    """Make sure the selected model and its embedding table are loaded.

    The model name is read from ``st.session_state['input_model_name']``.
    The SentenceTransformer is (re)loaded when none is cached yet or when
    the selected name differs from ``st.session_state['loaded_model_name']``.
    The encrypted embeddings are (re)decrypted whenever the model was just
    loaded *or* the previous decryption attempt did not succeed. This way a
    corrected ``decrypt_key`` is picked up on a later call, instead of being
    ignored until the user switches models.

    Session-state keys written: ``model``, ``loaded_model_name``, ``db`` and
    ``valid_decrypt_key``. Returns ``None``.
    """
    model_name = st.session_state['input_model_name']
    no_model = st.session_state.get('model') is None
    # .get(): on the very first call no model name has been recorded yet.
    swap_model = model_name != st.session_state.get('loaded_model_name')
    model_changed = swap_model or no_model
    if model_changed:
        _load_sentence_model(model_name)
    # Embeddings belong to a specific model, so reload them after a model
    # change. Also retry whenever the last attempt failed (e.g. wrong key).
    if model_changed or not st.session_state.get('valid_decrypt_key'):
        _load_encrypted_embeddings(model_name)


def _load_sentence_model(model_name):
    """Load the SentenceTransformer for *model_name* into session state."""
    with st.spinner('Loading the model might take a couple of seconds...'):
        prefix = model_name.lower()
        # A non-empty ``<name>_remote_path`` env var takes precedence over
        # ``<name>_path``.
        model_path = (
            os.environ.get(prefix + '_remote_path')
            or os.getenv(prefix + '_path')
        )
        # Token comes from the ``read_models`` env var. Otherwise ``True`` is
        # passed (presumably "use the locally stored HF token"; confirm
        # against the installed sentence-transformers version).
        auth_token = os.environ.get('read_models') or True
        st.session_state['model'] = SentenceTransformer(
            model_name_or_path=model_path,
            use_auth_token=auth_token,
        )
        st.session_state['loaded_model_name'] = model_name
        logging.info("Loaded %s!", model_name)


def _load_encrypted_embeddings(model_name):
    """Decrypt ``./<model_name lower-cased>.enc`` into ``st.session_state['db']``.

    Uses the Fernet key stored in ``st.session_state['decrypt_key']`` and
    records the outcome in ``st.session_state['valid_decrypt_key']``.
    A missing or unusable key is reported in the UI rather than raised.
    A missing ``.enc`` file still raises from ``open``.
    """
    with st.spinner('Loading embeddings...'):
        file_path = f"./{model_name.lower()}.enc"
        with open(file_path, 'rb') as f:
            encrypted_data = f.read()
        try:
            cipher = Fernet(st.session_state['decrypt_key'])
            decrypted_data = cipher.decrypt(encrypted_data)
            # SECURITY: pickle.loads executes code embedded in its input.
            # This is only tolerable because Fernet authenticates the payload,
            # i.e. it must have been produced by someone holding the key.
            db = pickle.loads(decrypted_data)
        except Exception:
            # UI boundary: any failure to build the cipher, decrypt or
            # unpickle is surfaced to the user as an invalid key.
            st.error(body="Error: No valid encryption key!", icon="🔑")
            logging.exception("Could not decrypt embeddings from %s", file_path)
            st.session_state['valid_decrypt_key'] = False
            return
        st.session_state['db'] = db
        st.session_state['valid_decrypt_key'] = True
def search():
    """Rank the stored item stems by cosine similarity to the current query.

    The loaded model (``st.session_state['model']``) encodes
    ``st.session_state['search_query']``. The resulting embeddings are
    averaged along axis 0 into one query vector, which is scored against
    ``st.session_state['db']['ItemStemEmbeddings']``. A DataFrame with the
    columns Match / Scale / Instrument / Reference, sorted by descending
    Match, is written to ``st.session_state['results']``.
    """
    with st.spinner('Searching...'):
        state = st.session_state
        encoder = state['model']
        # mean(axis=0) pools the per-sentence embeddings into one vector;
        # presumably search_query is a list of sentences (confirm with caller).
        pooled_query = encoder.encode(sentences=state['search_query']).mean(axis=0)
        table = state['db']
        similarity = util.cos_sim(
            a=np.array(pooled_query),
            b=table['ItemStemEmbeddings'],
        ).squeeze()
        columns = {
            'Match': similarity,
            'Scale': table['ScaleName'],
            'Instrument': table['InstrumentName'],
            'Reference': table['InstrumentApaReference'],
        }
        unsorted = pd.DataFrame(columns)
        state['results'] = unsorted.sort_values(by='Match', ascending=False)