import logging
import os
import pickle

import numpy as np
import pandas as pd
import streamlit as st
from cryptography.fernet import Fernet
from sentence_transformers import SentenceTransformer, util


def load_db():
    """Decrypt and unpickle the pre-computed embedding database into session state."""
    with st.spinner('Loading pre-computed embeddings...'):
        file_path = "./surveybot3000.enc"
        logging.info(f"Loading data from {file_path}!")

        with open(file_path, 'rb') as f:
            encrypted_data = f.read()

        # Fernet requires the key as bytes; session state may hold it as str.
        encryption_key = st.session_state['encryption_key']
        if isinstance(encryption_key, str):
            encryption_key = encryption_key.encode('utf-8')

        cipher = Fernet(encryption_key)
        decrypted_data = cipher.decrypt(encrypted_data)

        st.session_state['db'] = pickle.loads(decrypted_data)
        st.session_state['is_authenticated'] = True
        logging.info(f"Loaded {file_path}!")


def load_model():
    """Load the sentence-transformer model from a remote or local path."""
    with st.spinner('Loading the model...'):
        # Prefer the remote path; fall back to the local one.
        model_path = (os.environ.get('surveybot3000_remote_path')
                      or os.environ.get('surveybot3000_path'))
        logging.info(f"Loading model from {model_path}!")

        # Use an explicit token if set; True falls back to the locally
        # stored Hugging Face credentials.
        auth_token = os.environ.get('read_models') or True
        logging.info(f"Using auth token: {'***' if auth_token else 'None'}")

        st.session_state['model'] = SentenceTransformer(
            model_name_or_path=model_path,
            token=auth_token,
            trust_remote_code=True,
            tokenizer_kwargs={
                'use_fast': False,
                'legacy': True
            }
        )
        logging.info("Loaded model!")


def search():
    """Embed the query, score it against all item embeddings, and store ranked results."""
    with st.spinner('Searching the synthetic net...'):
        # Average the embeddings of the query sentences into a single vector.
        query_embedding = st.session_state['model'].encode(
            sentences=st.session_state['search_query']).mean(axis=0)

        item_embeddings = np.vstack(
            st.session_state['db']['ItemStemEmbeddings'])

        # Cosine similarity between the query vector and every item embedding.
        query_scores = util.cos_sim(
            a=np.array(query_embedding),
            b=item_embeddings
        ).squeeze()

        st.session_state['search_results'] = pd.DataFrame({
            'Match': query_scores,
            'Scale': st.session_state['db']['ScaleName'],
            'Instrument': st.session_state['db']['InstrumentName'],
            'Reference': st.session_state['db']['psyctest_doi'],
        }).sort_values(by='Match', ascending=False)