# Amlan99's picture
# Update app.py
# f163c78 verified
# raw
# history blame
# 4.49 kB
# app.py (FastAPI server to host the Jina Embedding model)
import os
# Redirect every Hugging Face cache into /tmp. This is done before the
# transformers import below, presumably because the HF libraries read these
# variables when they are imported and /tmp is writable in the hosting
# container (TODO confirm for this deployment).
# NOTE(review): recent transformers releases deprecate TRANSFORMERS_CACHE in
# favor of HF_HOME / HF_HUB_CACHE; it is kept to preserve current behavior.
os.environ.update(
    {
        "HF_HOME": "/tmp/huggingface",
        "HF_HUB_CACHE": "/tmp/huggingface/hub",
        "TRANSFORMERS_CACHE": "/tmp/huggingface/transformers",
    }
)
from fastapi import FastAPI
from pydantic import BaseModel
from typing import List, Optional
import torch
from transformers import AutoModel, AutoTokenizer
# The FastAPI application; the @app.post route decorators in this file register on it.
app = FastAPI()
# -----------------------------
# Load model once on startup
# -----------------------------
MODEL_NAME = "jinaai/jina-embeddings-v4"
device = "cuda" if torch.cuda.is_available() else "cpu"
# float16 only pays off on GPU. On CPU many kernels are slow or missing for
# half precision, so fall back to float32 there.
model_dtype = torch.float16 if device == "cuda" else torch.float32
# trust_remote_code=True executes Python code shipped in the model repo.
# NOTE(review): consider pinning `revision=` so that code cannot change
# underneath a deployed server.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModel.from_pretrained(
    MODEL_NAME, trust_remote_code=True, torch_dtype=model_dtype
).to(device)
# Inference only: switch off training-mode behavior such as dropout.
model.eval()
# -----------------------------
# Request / Response Models
# -----------------------------
class EmbedRequest(BaseModel):
    """Request body for ``POST /embed``."""

    # The text to embed.
    text: str
    # Forwarded as ``task=`` to ``model.encode_text``.
    task: str = "retrieval"
    # Forwarded as ``prompt_name=``. When unset, /embed uses "query" for the
    # pooled path and "passage" for the token-level path.
    prompt_name: Optional[str] = None
    # True: sliding-window, per-vector (multi-vector) embeddings.
    # False: a single mean-pooled vector.
    return_token_embeddings: bool = True
    # Forwarded as ``truncate_dim=`` (Matryoshka truncation).
    # NOTE(review): confirm the model applies it to multi-vector output.
    truncate_dim: Optional[int] = None
class EmbedResponse(BaseModel):
    """Response body for ``POST /embed``."""

    # One row per vector: either a single pooled row, or the concatenated
    # multi-vector rows from all sliding windows.
    embeddings: List[List[float]]
class EmbedImageRequest(BaseModel):
    """Request body for ``POST /embed_image``."""

    # Forwarded as ``images=[image]`` to ``model.encode_image``. It is
    # presumably a URL or file path (confirm against the model's encode_image
    # contract); treat it as untrusted input.
    image: str
    # Forwarded as ``task=``.
    task: str = "retrieval"
    # Forwarded as ``return_multivector=``.
    return_multivector: bool = True
    # Forwarded as ``truncate_dim=``. None presumably keeps the full
    # dimension (confirm with the model docs).
    truncate_dim: Optional[int] = None
class EmbedImageResponse(BaseModel):
    """Response body for ``POST /embed_image``."""

    # /embed_image returns a single row: the pooled image embedding.
    embeddings: List[List[float]]
class TokenizeRequest(BaseModel):
    """Request body for ``POST /tokenize``."""

    # The text to tokenize.
    text: str
class TokenizeResponse(BaseModel):
    """Response body for ``POST /tokenize``."""

    # Token ids produced without special tokens.
    input_ids: List[int]
class DecodeRequest(BaseModel):
    """Request body for ``POST /decode``."""

    # Token ids to convert back into text.
    input_ids: List[int]
class DecodeResponse(BaseModel):
    """Response body for ``POST /decode``."""

    # The text returned by ``tokenizer.decode``.
    text: str
# -----------------------------
# Embedding Endpoint (text)
# -----------------------------
def _encode_multivector(text: str, req: EmbedRequest, default_prompt: str) -> torch.Tensor:
    """Return the multi-vector embedding of ``text`` as a float32 CPU tensor.

    ``req.task`` and ``req.truncate_dim`` are forwarded to ``model.encode_text``.
    ``req.prompt_name`` takes precedence over ``default_prompt`` when it is set.
    """
    with torch.no_grad():
        outputs = model.encode_text(
            texts=[text],
            task=req.task,
            prompt_name=req.prompt_name or default_prompt,
            return_multivector=True,
            truncate_dim=req.truncate_dim,
        )
    # outputs[0] belongs to the single input text and is presumably shaped
    # (num_vectors, hidden_dim). Upcast because the model may run in float16,
    # so pooling and serialization happen in float32.
    return outputs[0].float().cpu()


def _window_spans(total_tokens: int, max_len: int, stride: int):
    """Yield ``(start, end)`` token spans covering ``[0, total_tokens)``.

    Consecutive spans overlap by ``stride`` tokens. Iteration stops at the
    first span that reaches ``total_tokens``, so no trailing window re-encodes
    tokens that are already covered. Yields nothing when ``total_tokens`` is 0.

    Raises:
        ValueError: if ``max_len <= stride``, because the window would never
            advance.
    """
    if max_len <= stride:
        raise ValueError("max_len must be larger than stride")
    start = 0
    while start < total_tokens:
        end = min(start + max_len, total_tokens)
        yield start, end
        if end == total_tokens:
            break
        start = end - stride


@app.post("/embed", response_model=EmbedResponse)
def embed(req: EmbedRequest):
    """Embed ``req.text`` and return its vectors as nested float lists.

    * ``return_token_embeddings=False`` (query): encode once, with the
      "query" prompt unless ``prompt_name`` overrides it. Return a single row
      that is the mean of the multi-vector output.
    * ``return_token_embeddings=True`` (passage): tokenize the text and encode
      overlapping token windows, with the "passage" prompt unless overridden.
      For every window after the first, drop its first ``stride`` vectors;
      then concatenate all remaining rows. Empty text yields no rows.

    Tensors are converted with ``tolist()`` so the payload matches the
    declared ``List[List[float]]`` response schema.
    """
    text = req.text

    # Case 1: query -> mean-pooled multi-vector embedding.
    if not req.return_token_embeddings:
        pooled = _encode_multivector(text, req, "query").mean(dim=0)
        return {"embeddings": [pooled.tolist()]}

    # Case 2: passage -> sliding-window, per-vector embeddings.
    # The ids stay on CPU: they are only sliced and decoded back to text.
    enc = tokenizer(text, add_special_tokens=False, return_tensors="pt")
    input_ids = enc["input_ids"].squeeze(0)
    total_tokens = input_ids.size(0)
    if total_tokens == 0:
        # Nothing to encode; torch.cat would raise on an empty list.
        return {"embeddings": []}

    # NOTE(review): confirm this config value does not exceed the length that
    # encode_text accepts. If encode_text truncates to a shorter maximum,
    # the tail of each window would be silently dropped.
    max_len = model.config.max_position_embeddings
    stride = 50  # tokens of overlap between consecutive windows

    pieces = []
    for start, end in _window_spans(total_tokens, max_len, stride):
        window_text = tokenizer.decode(input_ids[start:end])
        window_embeds = _encode_multivector(window_text, req, "passage")
        if start > 0:
            # Drop the overlap already covered by the previous window.
            # NOTE(review): this assumes one output vector per input token;
            # prompt-prefix vectors or re-tokenization drift would shift it.
            window_embeds = window_embeds[stride:]
        pieces.append(window_embeds)

    return {"embeddings": torch.cat(pieces, dim=0).tolist()}
# -----------------------------
# Embedding Endpoint (image)
# -----------------------------
@app.post("/embed_image", response_model=EmbedImageResponse)
def embed_image(req: EmbedImageRequest):
    """Embed one image and return it as a single row of floats.

    ``req.task``, ``req.return_multivector`` and ``req.truncate_dim`` are
    forwarded to ``model.encode_image``.

    * Multi-vector (2-D) output is mean-pooled into one row.
    * Single-vector (1-D) output is already one row and is returned unchanged.
      Pooling it would collapse it to a scalar that does not fit the
      ``List[List[float]]`` response schema.

    SECURITY NOTE(review): ``req.image`` is handed to the model as-is. If
    encode_image resolves strings as URLs or local paths, untrusted callers
    could make the server fetch arbitrary URLs or read local files. Validate
    or allow-list it before exposing this endpoint publicly.
    """
    with torch.no_grad():
        outputs = model.encode_image(
            images=[req.image],
            task=req.task,
            return_multivector=req.return_multivector,
            truncate_dim=req.truncate_dim,
        )
    # Upcast because the model may run in float16.
    vectors = outputs[0].float().cpu()
    if vectors.dim() > 1:
        vectors = vectors.mean(dim=0)
    return {"embeddings": [vectors.tolist()]}
# -----------------------------
# Tokenize Endpoint
# -----------------------------
@app.post("/tokenize", response_model=TokenizeResponse)
def tokenize(req: TokenizeRequest):
    """Return the token ids for ``req.text``, without special tokens."""
    token_ids = tokenizer(req.text, add_special_tokens=False)["input_ids"]
    return {"input_ids": token_ids}
# -----------------------------
# Decode Endpoint
# -----------------------------
@app.post("/decode", response_model=DecodeResponse)
def decode(req: DecodeRequest):
    """Convert ``req.input_ids`` back into text with the model's tokenizer."""
    return {"text": tokenizer.decode(req.input_ids)}