from typing import Optional
from fastapi import APIRouter
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModel
import torch
from datetime import datetime
from config import TEST_MODE, device, log
router = APIRouter()
class SentenceEmbeddingsInput(BaseModel):
    inputs: list[str]
    model: str
    parameters: dict


class SentenceEmbeddingsOutput(BaseModel):
    embeddings: Optional[list[list[float]]] = None
    error: Optional[str] = None

@router.post('/sentence-embeddings')
def sentence_embeddings(inputs: SentenceEmbeddingsInput):
    start_time = datetime.now()
    fn = sentence_embeddings_mapping.get(inputs.model)
    if not fn:
        return SentenceEmbeddingsOutput(
            error=f'No sentence embeddings model found for {inputs.model}'
        )
    try:
        embeddings = fn(inputs.inputs, inputs.parameters)
        # Record the request, its outputs, and its latency via the shared logger.
        log({
            "task": "sentence_embeddings",
            "model": inputs.model,
            "start_time": start_time.isoformat(),
            "time_taken": (datetime.now() - start_time).total_seconds(),
            "inputs": inputs.inputs,
            "outputs": embeddings,
            "parameters": inputs.parameters,
        })
        # Track when this model was last used.
        loaded_models_last_updated[inputs.model] = datetime.now()
        return SentenceEmbeddingsOutput(
            embeddings=embeddings
        )
    except Exception as e:
        return SentenceEmbeddingsOutput(
            error=str(e)
        )

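# Example request body for the endpoint above (illustrative only; the exact URL
# prefix depends on how this router is mounted in the main FastAPI app):
#
#   POST /sentence-embeddings
#   {
#     "inputs": ["What is BGE?", "BGE is a family of text embedding models."],
#     "model": "BAAI/bge-base-en-v1.5",
#     "parameters": {}
#   }
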
def generic_sentence_embeddings(model_name: str):
    global loaded_models

    def process_texts(texts: list[str], parameters: dict):
        if TEST_MODE:
            # Return fixed dummy vectors so tests don't need to download models.
            return [[0.1, 0.2]] * len(texts)
        if model_name in loaded_models:
            tokenizer, model = loaded_models[model_name]
        else:
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            model = AutoModel.from_pretrained(model_name).to(device)
            # Cache under the model name (not the model object) so later
            # lookups by name hit the cache.
            loaded_models[model_name] = (tokenizer, model)
        # Tokenize sentences
        encoded_input = tokenizer(texts, padding=True, truncation=True, return_tensors='pt').to(device)
        with torch.no_grad():
            model_output = model(**encoded_input)
        # CLS pooling: use the first token's hidden state as the sentence embedding
        sentence_embeddings = model_output[0][:, 0]
        # Normalize embeddings to unit length
        sentence_embeddings = torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1)
        return sentence_embeddings.tolist()

    return process_texts

# Polling every X minutes to
loaded_models = {}
loaded_models_last_updated = {}
sentence_embeddings_mapping = {
    'BAAI/bge-base-en-v1.5': generic_sentence_embeddings('BAAI/bge-base-en-v1.5'),
    'BAAI/bge-large-en-v1.5': generic_sentence_embeddings('BAAI/bge-large-en-v1.5'),
}
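
if __name__ == '__main__':
    # Manual smoke test, a minimal sketch only: assumes the BGE weights can be
    # downloaded and that `device` from config.py points at an available device.
    embed = sentence_embeddings_mapping['BAAI/bge-base-en-v1.5']
    vectors = embed(['hello world'], {})
    print(len(vectors), len(vectors[0]))  # 1 vector, 768 dimensions for bge-base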