|
import os |
|
import streamlit as st |
|
import logging |
|
import pandas as pd |
|
import numpy as np |
|
import pickle |
|
from cryptography.fernet import Fernet |
|
from sentence_transformers import SentenceTransformer, util |
|
|
|
def load_model():
    """Ensure the selected model and its encrypted embeddings are loaded into session state.

    Reads from ``st.session_state``:
        input_model_name: name of the model the user selected.
        loaded_model_name: name of the model currently held in ``model``; may be
            absent before the first successful load.
        decrypt_key: Fernet key for ``./<input_model_name lowercased>.enc``.

    Writes to ``st.session_state``:
        model, loaded_model_name: set whenever the model is (re)loaded.
        db: the unpickled embeddings object, after a successful decryption.
        valid_decrypt_key: True after a successful decryption, False after a failure.

    Embeddings are (re)loaded when the model changed on this call, and are also
    retried while no valid embeddings are present (e.g. the key was missing or wrong
    on an earlier run), so a late-entered key still takes effect.

    Decryption/unpickling failures are shown via ``st.error`` and logged, not raised.

    Raises:
        OSError: (e.g. FileNotFoundError) if the encrypted embeddings file cannot be opened.
    """
    model_changed = _ensure_model_loaded()

    embeddings_missing = (
        not st.session_state.get('valid_decrypt_key')
        or 'db' not in st.session_state
    )
    if model_changed or embeddings_missing:
        _load_embeddings()


def _ensure_model_loaded():
    """Load the selected SentenceTransformer if none is loaded or the selection changed.

    The model location comes from the environment variable
    ``<name>_remote_path`` when it is set to a non-empty value, otherwise from
    ``<name>_path``.

    Returns:
        True if a model was (re)loaded on this call, otherwise False.
    """
    no_model = st.session_state.get('model') is None
    # .get: 'loaded_model_name' does not exist until the first load has happened.
    swap_model = (
        st.session_state.get('input_model_name')
        != st.session_state.get('loaded_model_name')
    )
    if not (swap_model or no_model):
        return False

    with st.spinner('Loading the model might take a couple of seconds...'):
        model_name = st.session_state['input_model_name']
        env_local = model_name.lower() + '_path'
        env_remote = model_name.lower() + '_remote_path'

        # Prefer the remote path whenever it is set to a non-empty value.
        model_path = os.environ.get(env_remote) or os.getenv(env_local)

        # Use the token from 'read_models' if provided; True asks the Hugging Face
        # client to fall back to its locally stored token.
        auth_token = os.environ.get('read_models') or True

        st.session_state['model'] = SentenceTransformer(
            model_name_or_path=model_path,
            use_auth_token=auth_token,
        )
        st.session_state['loaded_model_name'] = model_name
        logging.info("Loaded %s!", model_name)

    return True


def _load_embeddings():
    """Decrypt and unpickle ``./<input_model_name lowercased>.enc`` into ``st.session_state['db']``.

    On success sets ``valid_decrypt_key`` to True; on any decryption/unpickling
    failure shows an error, logs it, and sets ``valid_decrypt_key`` to False.
    """
    with st.spinner('Loading embeddings...'):
        file_path = f"./{st.session_state['input_model_name'].lower()}.enc"

        with open(file_path, 'rb') as f:
            encrypted_data = f.read()

        try:
            cipher = Fernet(st.session_state['decrypt_key'])
            decrypted_data = cipher.decrypt(encrypted_data)
            # SECURITY: pickle.loads can execute arbitrary code. This is tolerable only
            # because Fernet authenticates the ciphertext, so only a holder of the key
            # can produce a payload that reaches this call. Never feed it untrusted files.
            st.session_state['db'] = pickle.loads(decrypted_data)
        except Exception as e:
            # Broad on purpose: a bad/missing key surfaces as KeyError, TypeError,
            # ValueError or cryptography's InvalidToken, and all should reach the user.
            st.error(body="Error: No valid encryption key!", icon="🔑")
            logging.error(e)
            st.session_state['valid_decrypt_key'] = False
        else:
            st.session_state['valid_decrypt_key'] = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def search():
    """Rank every item in the loaded embeddings database against the search query.

    Reads ``st.session_state['model']``, ``['search_query']`` and ``['db']`` and
    stores a DataFrame in ``st.session_state['results']`` with columns
    ``Match`` (cosine similarity), ``Scale``, ``Instrument`` and ``Reference``,
    sorted by ``Match`` in descending order.

    ``search_query`` may be a single string or a list of strings. Multiple query
    sentences are averaged into one query embedding before scoring.
    """
    with st.spinner('Searching...'):
        query = st.session_state['search_query']
        # encode() on a bare string returns a 1-D vector, which mean(axis=0) would
        # collapse to a scalar. Wrap it so the result is always (n_sentences, dim).
        if isinstance(query, str):
            query = [query]

        query_embedding = st.session_state['model'].encode(sentences=query).mean(axis=0)

        db = st.session_state['db']

        # cos_sim returns a (1, n_items) tensor for a 1-D query. Take row 0 instead of
        # squeeze() so a single-item database still yields a 1-D score vector, then hand
        # pandas a plain numpy array rather than a torch tensor.
        query_scores = util.cos_sim(
            a=np.array(query_embedding),
            b=db['ItemStemEmbeddings'],
        )[0].cpu().numpy()

        st.session_state['results'] = pd.DataFrame({
            'Match': query_scores,
            'Scale': db['ScaleName'],
            'Instrument': db['InstrumentName'],
            'Reference': db['InstrumentApaReference'],
        }).sort_values(by='Match', ascending=False)