# app.py (FastAPI server to host the Jina Embedding model)

import os

# Point all Hugging Face caches at a writable location (e.g. for
# containerized deployments where only /tmp is writable).
os.environ["HF_HOME"] = "/tmp/huggingface"
os.environ["HF_HUB_CACHE"] = "/tmp/huggingface/hub"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface/transformers"

from fastapi import FastAPI
from pydantic import BaseModel
from typing import List, Optional

import torch
from transformers import AutoModel, AutoTokenizer

app = FastAPI()

# -----------------------------
# Load model once on startup
# -----------------------------
MODEL_NAME = "jinaai/jina-embeddings-v4"

device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModel.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    torch_dtype=torch.float16,
).to(device)
model.eval()

# -----------------------------
# Request / Response Models
# -----------------------------
class EmbedRequest(BaseModel):
    text: str
    task: str = "retrieval"
    prompt_name: Optional[str] = None
    return_token_embeddings: bool = True
    truncate_dim: Optional[int] = None  # for matryoshka embeddings

class EmbedResponse(BaseModel):
    embeddings: List[List[float]]

class EmbedImageRequest(BaseModel):
    image: str
    task: str = "retrieval"
    return_multivector: bool = True
    truncate_dim: Optional[int] = None

class EmbedImageResponse(BaseModel):
    embeddings: List[List[float]]

class TokenizeRequest(BaseModel):
    text: str

class TokenizeResponse(BaseModel):
    input_ids: List[int]

class DecodeRequest(BaseModel):
    input_ids: List[int]

class DecodeResponse(BaseModel):
    text: str

# -----------------------------
# Embedding Endpoint (text)
# -----------------------------
@app.post("/embed", response_model=EmbedResponse)
def embed(req: EmbedRequest):
    text = req.text

    # Case 1: Query → pooled mean of multivectors
    if not req.return_token_embeddings:
        with torch.no_grad():
            outputs = model.encode_text(
                texts=[text],
                task=req.task,
                prompt_name=req.prompt_name or "query",
                return_multivector=True,
                truncate_dim=req.truncate_dim,
            )
        # outputs[0] has shape (num_vectors, hidden_dim)
        pooled = outputs[0].mean(dim=0).cpu()
        return {"embeddings": [pooled.tolist()]}

    # Case 2: Passage → sliding window, token-level embeddings
    enc = tokenizer(text, add_special_tokens=False, return_tensors="pt")
    input_ids = enc["input_ids"].squeeze(0)
    total_tokens = input_ids.size(0)

    max_len = model.config.max_position_embeddings  # ~32k
    stride = 50  # tokens of overlap between consecutive windows

    embeddings = []
    position = 0
    while position < total_tokens:
        end = min(position + max_len, total_tokens)
        window_ids = input_ids[position:end]

        with torch.no_grad():
            outputs = model.encode_text(
                texts=[tokenizer.decode(window_ids)],
                task=req.task,
                prompt_name=req.prompt_name or "passage",
                return_multivector=True,
                truncate_dim=req.truncate_dim,
            )
        window_embeds = outputs[0].cpu()

        # Windows overlap by `stride` tokens; drop the overlapping
        # prefix of every window after the first so each token is
        # represented once in the concatenated result.
        if position > 0:
            window_embeds = window_embeds[stride:]
        embeddings.append(window_embeds)

        position += max_len - stride

    full_embeddings = torch.cat(embeddings, dim=0)
    return {"embeddings": full_embeddings.tolist()}

# -----------------------------
# Embedding Endpoint (image)
# -----------------------------
@app.post("/embed_image", response_model=EmbedImageResponse)
def embed_image(req: EmbedImageRequest):
    with torch.no_grad():
        outputs = model.encode_image(
            images=[req.image],
            task=req.task,
            return_multivector=req.return_multivector,
            truncate_dim=req.truncate_dim,
        )
    out = outputs[0].cpu()
    # Multivector output has shape (num_vectors, hidden_dim) and is
    # mean-pooled here; single-vector output is already pooled.
    if req.return_multivector:
        out = out.mean(dim=0)
    return {"embeddings": [out.tolist()]}

# -----------------------------
# Tokenize Endpoint
# -----------------------------
@app.post("/tokenize", response_model=TokenizeResponse)
def tokenize(req: TokenizeRequest):
    enc = tokenizer(req.text, add_special_tokens=False)
    return {"input_ids": enc["input_ids"]}

# -----------------------------
# Decode Endpoint
# -----------------------------
@app.post("/decode", response_model=DecodeResponse)
def decode(req: DecodeRequest):
    decoded = tokenizer.decode(req.input_ids)
    return {"text": decoded}
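
# -----------------------------
# Local entrypoint — a minimal sketch for running the server directly.
# Assumes the `uvicorn` package is installed; in production you would
# typically run `uvicorn app:app --host 0.0.0.0 --port 8000` instead.
# -----------------------------
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)

# Example request against the running server (hypothetical input text,
# assuming the default port above). Setting return_token_embeddings to
# false exercises the pooled single-vector path of /embed:
#
#   curl -X POST http://localhost:8000/embed \
#     -H "Content-Type: application/json" \
#     -d '{"text": "What is ColBERT?", "return_token_embeddings": false}'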