# fastapi_ai_endpoints/tasks/sentence_embeddings.py
# Author: jxtan — last commit b805057 ("Added Translation Endpoint")
from datetime import datetime
from typing import Optional

import torch
from fastapi import APIRouter
from pydantic import BaseModel, Field
from transformers import AutoTokenizer, AutoModel

from config import TEST_MODE, device, log
# Router carrying this module's sentence-embeddings endpoint; presumably
# included into the main FastAPI app elsewhere — confirm.
router = APIRouter()
class SentenceEmbeddingsInput(BaseModel):
    # Request body for POST /sentence-embeddings.

    # Texts to embed; each produces one vector in the response.
    inputs: list[str]
    # Key into `sentence_embeddings_mapping`, e.g. 'BAAI/bge-base-en-v1.5'.
    model: str
    # Extra per-request options forwarded to the embedding function. Optional:
    # the registered models ignore it, so clients should not be forced to send
    # an empty `{}`. `default_factory` gives each request its own fresh dict.
    parameters: dict = Field(default_factory=dict)
class SentenceEmbeddingsOutput(BaseModel):
    # Response body for POST /sentence-embeddings. The endpoint fills exactly
    # one of the two fields: `embeddings` on success, `error` on failure.

    # Embedding vectors returned by the model's embedding function (success only).
    embeddings: Optional[list[list[float]]] = None
    # Failure message, e.g. unknown model name or the text of a raised exception.
    error: Optional[str] = None
@router.post('/sentence-embeddings')
def sentence_embeddings(inputs: SentenceEmbeddingsInput):
    # Embed `inputs.inputs` with the model named by `inputs.model`.
    # Errors never surface as an HTTP error status: an unknown model or any
    # exception raised while embedding/logging is reported in the `error`
    # field of the response body instead.
    requested_at = datetime.now()

    embed = sentence_embeddings_mapping.get(inputs.model)
    if not embed:
        return SentenceEmbeddingsOutput(
            error=f'No sentence embeddings model found for {inputs.model}'
        )

    try:
        vectors = embed(inputs.inputs, inputs.parameters)
        elapsed = datetime.now() - requested_at
        record = dict(
            task="sentence_embeddings",
            model=inputs.model,
            start_time=requested_at.isoformat(),
            time_taken=elapsed.total_seconds(),
            inputs=inputs.inputs,
            outputs=vectors,
            parameters=inputs.parameters,
        )
        log(record)
        # Record when this model last served a request successfully.
        loaded_models_last_updated[inputs.model] = datetime.now()
        response = SentenceEmbeddingsOutput(embeddings=vectors)
    except Exception as exc:
        # Top-level request boundary: convert any failure into an error payload.
        response = SentenceEmbeddingsOutput(error=str(exc))
    return response
def _load_tokenizer_and_model(model_name: str):
    """Return the cached ``(tokenizer, model)`` pair for *model_name*.

    On first use, loads both via ``from_pretrained``, moves the model to
    ``device``, and stores the pair in ``loaded_models`` under *model_name*
    so later calls reuse it.
    """
    cached = loaded_models.get(model_name)
    if cached is None:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModel.from_pretrained(model_name).to(device)
        cached = (tokenizer, model)
        # Store under the name, which is the key the lookup above uses.
        # Keying by the model object meant the cache never hit: every request
        # reloaded the weights and kept another copy alive in the dict.
        loaded_models[model_name] = cached
    return cached


def generic_sentence_embeddings(model_name: str):
    """Build an embedding function for the Hugging Face encoder *model_name*.

    The returned ``process_texts(texts, parameters)``:

    * returns ``[]`` for an empty *texts* list;
    * in ``TEST_MODE``, returns the fixed placeholder ``[0.1, 0.2]`` for each
      text without loading any model;
    * otherwise lazily loads and caches the tokenizer/model, encodes *texts*
      as one padded batch, takes each text's first-token ([CLS]) vector,
      L2-normalizes it, and returns one list of floats per text.

    *parameters* is accepted so every mapped function has the same signature,
    but it is currently ignored.
    """

    def process_texts(texts: list[str], parameters: dict):
        if not texts:
            # Nothing to encode: skip model loading and the tokenizer call.
            return []
        if TEST_MODE:
            # A fresh list per text; ``[[...]] * n`` would alias a single list.
            return [[0.1, 0.2] for _ in texts]
        tokenizer, model = _load_tokenizer_and_model(model_name)
        # Pad to the longest text in the batch; truncate texts longer than the
        # tokenizer's maximum length.
        encoded_input = tokenizer(
            texts, padding=True, truncation=True, return_tensors='pt'
        ).to(device)
        with torch.no_grad():  # inference only: no autograd graph needed
            model_output = model(**encoded_input)
        # CLS pooling: first token of the model's first output (presumably the
        # last hidden state), then unit-normalize so dot product == cosine.
        cls_vectors = model_output[0][:, 0]
        cls_vectors = torch.nn.functional.normalize(cls_vectors, p=2, dim=1)
        return cls_vectors.tolist()

    return process_texts
# Process-wide model state:
#   loaded_models              - cache of (tokenizer, model) pairs, looked up by
#                                model name and filled lazily on first request.
#   loaded_models_last_updated - model name -> datetime of its last successful
#                                request, written by the endpoint.
# NOTE(review): the original comment here was cut off ("Polling every X minutes
# to"); presumably a poller was meant to unload models idle past some threshold.
# Nothing in this file does that — confirm whether it exists elsewhere.
loaded_models = {}
loaded_models_last_updated = {}

# Model name -> embedding function. The endpoint rejects any name not listed here.
sentence_embeddings_mapping = {
    name: generic_sentence_embeddings(name)
    for name in ('BAAI/bge-base-en-v1.5', 'BAAI/bge-large-en-v1.5')
}