# app.py (FastAPI server to host the Jina Embedding model)
import os

# Cache directories must be set before transformers is imported.
os.environ["HF_HOME"] = "/tmp/huggingface"
os.environ["HF_HUB_CACHE"] = "/tmp/huggingface/hub"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface/transformers"

from fastapi import FastAPI
from pydantic import BaseModel
from typing import List, Optional
import torch
from transformers import AutoModel, AutoTokenizer

app = FastAPI()
# -----------------------------
# Load model once on startup
# -----------------------------
MODEL_NAME = "jinaai/jina-embeddings-v4"
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
# float16 halves memory; on a CPU-only Space, torch.float32 may be safer.
model = AutoModel.from_pretrained(
    MODEL_NAME, trust_remote_code=True, torch_dtype=torch.float16
).to(device)
model.eval()
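# A quick sanity check (a sketch with illustrative values; safe to delete).
# With return_multivector=True, encode_text is expected to return one
# (num_vectors, hidden_dim) tensor per input text, which is the shape the
# pooling logic in the endpoints below relies on:
#
#   with torch.no_grad():
#       probe = model.encode_text(
#           texts=["hello"], task="retrieval",
#           prompt_name="query", return_multivector=True,
#       )
#   print(probe[0].shape)  # -> torch.Size([num_vectors, hidden_dim])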
# -----------------------------
# Request / Response Models
# -----------------------------
class EmbedRequest(BaseModel):
    text: str
    task: str = "retrieval"
    prompt_name: Optional[str] = None
    return_token_embeddings: bool = True
    truncate_dim: Optional[int] = None  # for matryoshka embeddings

class EmbedResponse(BaseModel):
    embeddings: List[List[float]]

class EmbedImageRequest(BaseModel):
    image: str
    task: str = "retrieval"
    return_multivector: bool = True
    truncate_dim: Optional[int] = None

class EmbedImageResponse(BaseModel):
    embeddings: List[List[float]]

class TokenizeRequest(BaseModel):
    text: str

class TokenizeResponse(BaseModel):
    input_ids: List[int]

class DecodeRequest(BaseModel):
    input_ids: List[int]

class DecodeResponse(BaseModel):
    text: str
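# Example request bodies for the endpoints below (illustrative values only):
#
#   POST /embed       {"text": "What is DNA?", "return_token_embeddings": false}
#   POST /embed_image {"image": "<image string>", "return_multivector": true}
#   POST /tokenize    {"text": "What is DNA?"}
#   POST /decode      {"input_ids": [1, 2, 3]}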
# -----------------------------
# Embedding Endpoint (text)
# -----------------------------
@app.post("/embed", response_model=EmbedResponse)
def embed(req: EmbedRequest):
    text = req.text

    # Case 1: Query -> pooled mean of multivectors
    if not req.return_token_embeddings:
        with torch.no_grad():
            outputs = model.encode_text(
                texts=[text],
                task=req.task,
                prompt_name=req.prompt_name or "query",
                return_multivector=True,
                truncate_dim=req.truncate_dim,
            )
        # outputs[0] = (num_vectors, hidden_dim)
        pooled = outputs[0].mean(dim=0).cpu()
        return {"embeddings": [pooled.tolist()]}
    # Case 2: Passage -> sliding window, token-level embeddings
    enc = tokenizer(text, add_special_tokens=False, return_tensors="pt")
    input_ids = enc["input_ids"].squeeze(0).to(device)
    total_tokens = input_ids.size(0)

    max_len = model.config.max_position_embeddings  # ~32k
    stride = 50  # tokens of overlap between consecutive windows
    embeddings = []
    position = 0

    while position < total_tokens:
        end = min(position + max_len, total_tokens)
        window_ids = input_ids[position:end].unsqueeze(0)
        with torch.no_grad():
            outputs = model.encode_text(
                texts=[tokenizer.decode(window_ids[0])],
                task=req.task,
                prompt_name=req.prompt_name or "passage",
                return_multivector=True,
                truncate_dim=req.truncate_dim,
            )
        window_embeds = outputs[0].cpu()
        # Drop the vectors for the overlapping tokens so they are not duplicated.
        # NOTE: this assumes one output vector per input token; the decode /
        # re-encode round trip may not preserve that count exactly.
        if position > 0:
            window_embeds = window_embeds[stride:]
        embeddings.append(window_embeds)
        position += max_len - stride

    full_embeddings = torch.cat(embeddings, dim=0)
    return {"embeddings": full_embeddings.tolist()}
# -----------------------------
# Embedding Endpoint (image)
# -----------------------------
@app.post("/embed_image", response_model=EmbedImageResponse)
def embed_image(req: EmbedImageRequest):
    with torch.no_grad():
        outputs = model.encode_image(
            images=[req.image],
            task=req.task,
            return_multivector=req.return_multivector,
            truncate_dim=req.truncate_dim,
        )
    # Multivector output is (num_vectors, hidden_dim): pool it to one vector.
    # A single-vector output is already (hidden_dim,) and is used as-is.
    if req.return_multivector:
        pooled = outputs[0].mean(dim=0).cpu()
    else:
        pooled = outputs[0].cpu()
    return {"embeddings": [pooled.tolist()]}
# -----------------------------
# Tokenize Endpoint
# -----------------------------
@app.post("/tokenize", response_model=TokenizeResponse)
def tokenize(req: TokenizeRequest):
    enc = tokenizer(req.text, add_special_tokens=False)
    return {"input_ids": enc["input_ids"]}
# -----------------------------
# Decode Endpoint
# -----------------------------
@app.post("/decode", response_model=DecodeResponse)
def decode(req: DecodeRequest):
    decoded = tokenizer.decode(req.input_ids)
    return {"text": decoded}