"""Streamlit helpers: load a sentence-transformer model plus its encrypted
embedding table into the session state, and run a cosine-similarity search."""

import logging
import os
import pickle

import numpy as np
import pandas as pd
import streamlit as st
from cryptography.fernet import Fernet
from sentence_transformers import SentenceTransformer, util


def load_model():
    """Load the selected model and its embedding table into the session state.

    Does nothing when a model is already stored under ``'model'`` and
    ``'input_model_name'`` equals ``'loaded_model_name'``.

    Reads from ``st.session_state``: ``'model'``, ``'input_model_name'``,
    ``'loaded_model_name'``, ``'decrypt_key'``.
    Writes to ``st.session_state``: ``'model'``, ``'loaded_model_name'``,
    ``'db'``, ``'valid_decrypt_key'``.

    The model location is taken from the environment variable
    ``<model name>_remote_path`` if it is set and non-empty, otherwise from
    ``<model name>_path`` (model name lower-cased). The embeddings are read
    from ``./<model name>.enc``, decrypted with Fernet and unpickled.

    If decryption or unpickling fails, an error is shown in the UI, the
    exception is logged, and ``'db'`` / ``'valid_decrypt_key'`` are left
    untouched.

    Raises:
        OSError: if the ``.enc`` file cannot be opened (not caught here).
    """
    state = st.session_state

    no_model = state.get('model') is None
    # Use .get() for both sides: indexing 'loaded_model_name' directly raised
    # KeyError on a session where that key had not been initialised yet.
    swap_model = state.get('input_model_name') != state.get('loaded_model_name')

    if not (swap_model or no_model):
        return

    model_key = state['input_model_name'].lower()

    with st.spinner('Loading the model might take a couple of seconds...'):
        # An unset or empty remote path falls back to the local one.
        model_path = (
            os.environ.get(model_key + '_remote_path')
            or os.getenv(model_key + '_path')
        )
        # Falls back to True when no 'read_models' token is configured.
        auth_token = os.environ.get('read_models') or True

        state['model'] = SentenceTransformer(
            model_name_or_path=model_path,
            use_auth_token=auth_token,
        )
        state['loaded_model_name'] = state['input_model_name']
        logging.info("Loaded %s!", state['input_model_name'])

    with st.spinner('Loading embeddings...'):
        file_path = f"./{model_key}.enc"

        with open(file_path, 'rb') as f:
            encrypted_data = f.read()

        try:
            cipher = Fernet(state['decrypt_key'])
            decrypted_df = cipher.decrypt(encrypted_data)
            # SECURITY: pickle.loads executes arbitrary code embedded in the
            # payload. Fernet authenticates the ciphertext, so this is only
            # as trustworthy as whoever holds the encryption key.
            state['db'] = pickle.loads(decrypted_df)
            state['valid_decrypt_key'] = True
        except Exception:
            st.error(body="Error: No valid encryption key!", icon="🔑")
            # logging.exception keeps the traceback; Fernet's InvalidToken
            # has no message, so logging the bare exception gave a blank line.
            logging.exception(
                "Could not decrypt or unpickle embeddings from %s", file_path
            )
            # NOTE(review): 'loaded_model_name' was already updated above, so
            # a later call with the same model name returns early and does not
            # retry the embeddings; any previously stored 'db' stays in place.
            # Confirm the caller handles this.
            return

        # st.session_state['db'] = pd.read_parquet(
        #     path=f"./{st.session_state['input_model_name'].lower()}.parquet"
        # )
        #decrypt_key


def search():
    """Score the current query against the embedding table and store results.

    Reads ``'model'``, ``'search_query'`` and ``'db'`` from
    ``st.session_state`` and writes a DataFrame with the columns ``Match``,
    ``Scale``, ``Instrument`` and ``Reference``, sorted by ``Match`` in
    descending order, to ``st.session_state['results']``.
    """
    state = st.session_state

    with st.spinner('Searching...'):
        db = state['db']

        # The encoder output is averaged along axis 0.
        # NOTE(review): presumably 'search_query' is a list of sentences that
        # is collapsed into a single query vector here - confirm with caller.
        encoded = state['model'].encode(sentences=state['search_query'])
        query_embeddings = encoded.mean(axis=0)

        query_scores = util.cos_sim(
            a=np.array(query_embeddings),
            b=db['ItemStemEmbeddings'],
        ).squeeze()

        results = pd.DataFrame({
            'Match': query_scores,
            'Scale': db['ScaleName'],
            'Instrument': db['InstrumentName'],
            'Reference': db['InstrumentApaReference'],
        })
        state['results'] = results.sort_values(by='Match', ascending=False)