synth-net / modeling.py
bjorn-hommel's picture
init commit
818f654
raw
history blame
2.78 kB
import os
import streamlit as st
import logging
import pandas as pd
import numpy as np
import pickle
from cryptography.fernet import Fernet
from sentence_transformers import SentenceTransformer, util
def _resolve_model_path(model_name):
    """Return the configured location of *model_name*, or a falsy value if unset.

    The environment variable ``<name>_remote_path`` takes precedence over
    ``<name>_path``, where ``<name>`` is the lower-cased model name.
    """
    prefix = model_name.lower()
    return os.environ.get(f"{prefix}_remote_path") or os.environ.get(f"{prefix}_path")


def _load_embeddings(model_name):
    """Decrypt ``./<name>.enc`` and store the result in ``st.session_state['db']``.

    On success ``st.session_state['valid_decrypt_key']`` is set to ``True``.
    If decryption or unpickling fails, an error is shown in the UI, the
    exception is logged and the session state is left untouched.

    Raises:
        OSError: if the encrypted file cannot be opened or read.
    """
    file_path = f"./{model_name.lower()}.enc"
    with open(file_path, 'rb') as f:
        encrypted_data = f.read()

    # Broad catch on purpose: this is a UI boundary and a missing, malformed
    # or wrong key (as well as a corrupt payload) should all end up as the
    # same user-facing message instead of a crash.
    try:
        cipher = Fernet(st.session_state['decrypt_key'])
        decrypted_payload = cipher.decrypt(encrypted_data)
        # SECURITY: pickle.loads can execute arbitrary code. This is only
        # acceptable because Fernet authenticates the payload, so it must
        # have been produced by someone holding the key. Never point this at
        # data from an untrusted source.
        st.session_state['db'] = pickle.loads(decrypted_payload)
    except Exception:
        st.error(body="Error: No valid encryption key!", icon="🔑")
        logging.exception("Could not decrypt embeddings from %s", file_path)
        return

    st.session_state['valid_decrypt_key'] = True


def load_model():
    """Load the selected sentence-transformer model and its embedding database.

    Reads ``input_model_name`` from ``st.session_state``. The model is
    (re)loaded when no model is present yet or when the requested name differs
    from ``loaded_model_name``; the result is stored under ``model`` and
    ``loaded_model_name`` is updated. The embeddings are (re)loaded whenever
    the model was (re)loaded, and also when ``db`` is still missing, so that a
    previously failed decryption is retried once a new key has been entered.

    Problems (no model path configured, invalid decryption key) are reported
    through ``st.error`` and logged; the function then returns early.

    Raises:
        KeyError: if ``input_model_name`` is not in the session state when a
            load is required.
        OSError: if the encrypted embeddings file cannot be read.
    """
    requested_name = st.session_state.get('input_model_name')
    # .get() so that a fresh session without 'loaded_model_name' simply counts
    # as "nothing loaded yet" instead of raising KeyError.
    loaded_name = st.session_state.get('loaded_model_name')

    needs_model = st.session_state.get('model') is None or requested_name != loaded_name
    needs_embeddings = needs_model or st.session_state.get('db') is None
    if not needs_embeddings:
        return

    model_name = st.session_state['input_model_name']

    if needs_model:
        with st.spinner('Loading the model might take a couple of seconds...'):
            model_path = _resolve_model_path(model_name)
            if not model_path:
                # Without this guard None would be handed to
                # SentenceTransformer and the failure would only surface
                # later, far away from its cause.
                st.error(body="Error: No model path configured!")
                logging.error(
                    "Neither %s_remote_path nor %s_path is set in the environment",
                    model_name.lower(),
                    model_name.lower(),
                )
                return

            # Fall back to True when no explicit token is configured.
            auth_token = os.environ.get('read_models') or True
            st.session_state['model'] = SentenceTransformer(
                model_name_or_path=model_path,
                use_auth_token=auth_token,
            )
            st.session_state['loaded_model_name'] = model_name
            logging.info("Loaded %s!", model_name)

    with st.spinner('Loading embeddings...'):
        _load_embeddings(model_name)
def search():
    """Rank the database entries by cosine similarity to the search query.

    Reads ``model``, ``search_query`` and ``db`` from ``st.session_state``.
    The query is encoded with the model, the sentence embeddings are averaged
    into a single query vector, and that vector is compared against
    ``db['ItemStemEmbeddings']``. The outcome is stored in
    ``st.session_state['results']`` as a DataFrame with the columns ``Match``,
    ``Scale``, ``Instrument`` and ``Reference``, sorted by ``Match`` in
    descending order.
    """
    with st.spinner('Searching...'):
        sentence_embeddings = st.session_state['model'].encode(
            sentences=st.session_state['search_query']
        )
        # encode() returns a 1-D vector for a single string and a 2-D matrix
        # for a list of sentences. Promote to 2-D first so that averaging over
        # axis 0 always yields one embedding vector, never a scalar.
        query_embedding = np.atleast_2d(sentence_embeddings).mean(axis=0)

        db = st.session_state['db']
        query_scores = util.cos_sim(
            a=query_embedding,
            b=db['ItemStemEmbeddings'],
        ).squeeze()

        results = pd.DataFrame({
            'Match': query_scores,
            'Scale': db['ScaleName'],
            'Instrument': db['InstrumentName'],
            'Reference': db['InstrumentApaReference'],
        })
        st.session_state['results'] = results.sort_values(by='Match', ascending=False)