import os

from fastapi import FastAPI
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
# Directory where the Hugging Face model files should be cached.
CACHE_DIR = '/app/cache'

# NOTE: `transformers` resolves TRANSFORMERS_CACHE when it is first imported,
# and `sentence_transformers` (imported above) has already imported it, so
# setting the variable here does not by itself redirect the model download.
# It is still exported for any code that reads it later; the directory is
# also passed explicitly to SentenceTransformer below so the cache is honored.
os.environ['TRANSFORMERS_CACHE'] = CACHE_DIR

app = FastAPI(
    title="Text Embedding API",
    description="Dùng mô hình Vietnamese_Embedding từ AITeamVN",
    version="1.0",
)

# Load the embedding model once at import time; all requests share it.
model = SentenceTransformer("AITeamVN/Vietnamese_Embedding", cache_folder=CACHE_DIR)
# Maximum number of tokens per input (per the sentence-transformers docs,
# longer inputs are truncated to this length).
model.max_seq_length = 2048
# Request body schema for the embedding endpoint.
class TextInput(BaseModel):
    """Request body carrying the batch of texts to embed."""

    # Input strings; each one produces one embedding vector in the response.
    texts: list[str]
# NOTE(review): without a route decorator this handler was never exposed by
# `app`; the "/embed" path is a reviewer choice -- confirm against clients.
@app.post("/embed")
def embed_text(data: TextInput):
    """Embed every text in the request with the shared model.

    Args:
        data: Request body; ``data.texts`` holds the strings to embed.

    Returns:
        ``{"embeddings": [...]}``, one list of floats per input text, in the
        same order as ``data.texts``. An empty ``texts`` list yields an empty
        ``embeddings`` list without invoking the model.
    """
    # Trim surrounding whitespace so padding does not change the embedding.
    inputs = [t.strip() for t in data.texts]
    if not inputs:
        # Nothing to encode; avoid calling the model on an empty batch.
        return {"embeddings": []}
    embs = model.encode(inputs, convert_to_numpy=True)
    # Convert each row of the returned array to plain Python floats so the
    # response is JSON-serializable.
    return {"embeddings": [e.tolist() for e in embs]}