import os
import logging
import pickle

import numpy as np
import pandas as pd
import streamlit as st
from bertopic import BERTopic
from umap import UMAP
from cryptography.fernet import Fernet
from sentence_transformers import SentenceTransformer, util


def load_db():
    """Decrypt and unpickle the pre-computed embedding database into session state."""

    with st.spinner('Loading pre-computed embeddings...'):

        file_path = "./surveybot3000.enc"
        logging.info(f"Loading data from {file_path}!")

        with open(file_path, 'rb') as f:
            encrypted_data = f.read()

        # Fernet expects the key as bytes; the session may store it as a string.
        encryption_key = st.session_state['encryption_key']
        if isinstance(encryption_key, str):
            encryption_key = encryption_key.encode('utf-8')

        cipher = Fernet(encryption_key)
        decrypted_bytes = cipher.decrypt(encrypted_data)
        st.session_state['db'] = pickle.loads(decrypted_bytes)

        # A successful decryption implies the supplied key was valid.
        st.session_state['is_authenticated'] = True
        logging.info(f"Loaded {file_path}!")


def load_model():
    """Load the SurveyBot3000 sentence-transformer model into session state."""

    with st.spinner('Loading the model...'):

        # Prefer the remote (Hub) path; fall back to a local path if it is not set.
        model_path = os.environ.get('surveybot3000_remote_path') or os.getenv('surveybot3000_path')
        logging.info(f"Loading model from {model_path}!")

        # Use an explicit read token if provided; otherwise fall back to the
        # locally cached Hugging Face credentials (token=True).
        auth_token = os.environ.get('read_models') or True
        logging.info(f"Using auth token: {'***' if auth_token else 'None'}")

        st.session_state['model'] = SentenceTransformer(
            model_name_or_path=model_path,
            token=auth_token,
            trust_remote_code=True,
            tokenizer_kwargs={
                'use_fast': False,
                'legacy': True
            }
        )

        logging.info("Loaded model!")


def search():
    """Rank every item stem in the database by cosine similarity to the search query."""

    with st.spinner('Searching the synthetic net...'):
        # Encode the query sentence(s) and average them into a single query vector.
        query_embeddings = st.session_state['model'].encode(
            sentences=st.session_state['search_query']).mean(axis=0)

        # Stack the pre-computed item-stem embeddings into one matrix and
        # score each item by cosine similarity to the query vector.
        item_embeddings = np.vstack(
            st.session_state['db']['ItemStemEmbeddings'])
        query_scores = util.cos_sim(
            a=np.array(query_embeddings),
            b=item_embeddings
        ).squeeze()

        # Collect the scores with their scale, instrument, and reference
        # metadata, sorted from best to worst match.
        st.session_state['search_results'] = pd.DataFrame({
            'Match': query_scores,
            'Scale': st.session_state['db']['ScaleName'],
            'Instrument': st.session_state['db']['InstrumentName'],
            'Reference': st.session_state['db']['psyctest_doi'],
        }).sort_values(by='Match', ascending=False)
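

# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the original module): one plausible way to
# wire load_db(), load_model(), and search() into a Streamlit page. The widget
# labels, the one-sentence-per-line query convention, and the __main__ guard
# are illustrative assumptions, not the app's actual page layout.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    st.title('SurveyBot3000 Search')

    # Ask for the Fernet key once, then decrypt the database and load the model.
    if not st.session_state.get('is_authenticated', False):
        st.session_state['encryption_key'] = st.text_input(
            'Encryption key', type='password')
        if st.session_state['encryption_key']:
            load_db()
            load_model()

    # Once authenticated, accept a query (one sentence per line) and search.
    if st.session_state.get('is_authenticated', False):
        query_text = st.text_area('Describe the construct you are looking for')
        if st.button('Search') and query_text:
            st.session_state['search_query'] = query_text.splitlines()
            search()
            st.dataframe(st.session_state['search_results'])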