File size: 484 Bytes
a41bdbc
 
 
 
 
 
 
 
 
31f3439
 
 
 
 
6e03e5d
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import streamlit as st
from sentence_transformers import SentenceTransformer
from .config import MODELS_ID


@st.cache(allow_output_mutation=True)
def load_model(model_name):
    """Build the SentenceTransformer model(s) registered under ``model_name``.

    ``MODELS_ID[model_name]`` is either a single model identifier (``str``),
    which yields one ``SentenceTransformer``, or an iterable of identifiers,
    which yields a list of ``SentenceTransformer`` instances in the same order.

    The result is memoized by ``st.cache``. ``allow_output_mutation=True``
    tells Streamlit not to hash-check the returned model objects between runs.

    Args:
        model_name: Key into ``MODELS_ID``.

    Returns:
        A ``SentenceTransformer``, or a list of them when the entry is an
        iterable of identifiers.

    Raises:
        ValueError: If ``model_name`` is not a key of ``MODELS_ID``.
        TypeError: If the configured entry is neither a string nor an
            iterable of identifiers.
    """
    # Explicit check instead of ``assert``: asserts are removed under -O.
    if model_name not in MODELS_ID:
        raise ValueError(
            f"Unknown model name {model_name!r}; "
            f"expected one of {list(MODELS_ID)}"
        )

    # Lazy loading: models are only constructed (and fetched by
    # SentenceTransformer if needed) when actually requested.
    model_ids = MODELS_ID[model_name]

    # The str check must come first because strings are iterable too;
    # isinstance also accepts str subclasses.
    if isinstance(model_ids, str):
        return SentenceTransformer(model_ids)
    if hasattr(model_ids, '__iter__'):
        return [SentenceTransformer(name) for name in model_ids]

    # Previously this path left ``output`` unbound (UnboundLocalError).
    raise TypeError(
        f"MODELS_ID[{model_name!r}] must be a model id string or an "
        f"iterable of model ids, got {type(model_ids).__name__}"
    )