# app.py (FastAPI server to host the Jina Embedding model)
import os
os.environ["HF_HOME"] = "/tmp/huggingface"
os.environ["HF_HUB_CACHE"] = "/tmp/huggingface/hub"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface/transformers"
from fastapi import FastAPI
from pydantic import BaseModel
from typing import List, Optional
import torch
from transformers import AutoModel, AutoTokenizer
app = FastAPI()
# -----------------------------
# Load model once on startup
# -----------------------------
MODEL_NAME = "jinaai/jina-embeddings-v4"
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModel.from_pretrained(
    MODEL_NAME, trust_remote_code=True, torch_dtype=torch.float16
).to(device)
model.eval()
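# NOTE (added sketch, not in the original file): float16 is a CUDA-oriented
# dtype and can be slow or unsupported for some ops on CPU. A common fallback
# is to pick the dtype from the detected device, e.g.:
#   dtype = torch.float16 if device == "cuda" else torch.float32
#   model = AutoModel.from_pretrained(
#       MODEL_NAME, trust_remote_code=True, torch_dtype=dtype
#   ).to(device)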
# -----------------------------
# Request / Response Models
# -----------------------------
class EmbedRequest(BaseModel):
    text: str
    task: str = "retrieval"
    prompt_name: Optional[str] = None
    return_token_embeddings: bool = True
    truncate_dim: Optional[int] = None  # for matryoshka embeddings

class EmbedResponse(BaseModel):
    embeddings: List[List[float]]

class EmbedImageRequest(BaseModel):
    image: str
    task: str = "retrieval"
    return_multivector: bool = True
    truncate_dim: Optional[int] = None

class EmbedImageResponse(BaseModel):
    embeddings: List[List[float]]

class TokenizeRequest(BaseModel):
    text: str

class TokenizeResponse(BaseModel):
    input_ids: List[int]

class DecodeRequest(BaseModel):
    input_ids: List[int]

class DecodeResponse(BaseModel):
    text: str
# -----------------------------
# Embedding Endpoint (text)
# -----------------------------
@app.post("/embed", response_model=EmbedResponse)
def embed(req: EmbedRequest):
text = req.text
# Case 1: Query → pooled mean of multivectors
if not req.return_token_embeddings:
with torch.no_grad():
outputs = model.encode_text(
texts=[text],
task=req.task,
prompt_name=req.prompt_name or "query",
return_multivector=True,
truncate_dim=req.truncate_dim,
)
# outputs[0] = (num_vectors, hidden_dim)
pooled = outputs[0].mean(dim=0).cpu()
return {"embeddings": [pooled]}
# Case 2: Passage → sliding window, token-level embeddings
enc = tokenizer(text, add_special_tokens=False, return_tensors="pt")
input_ids = enc["input_ids"].squeeze(0).to(device)
total_tokens = input_ids.size(0)
max_len = model.config.max_position_embeddings # ~32k
stride = 50
embeddings = []
position = 0
while position < total_tokens:
end = min(position + max_len, total_tokens)
window_ids = input_ids[position:end].unsqueeze(0).to(device)
with torch.no_grad():
outputs = model.encode_text(
texts=[tokenizer.decode(window_ids[0])],
task=req.task,
prompt_name=req.prompt_name or "passage",
return_multivector=True,
truncate_dim=req.truncate_dim,
)
window_embeds = outputs[0].cpu()
if position > 0:
window_embeds = window_embeds[stride:]
embeddings.append(window_embeds)
position += max_len - stride
full_embeddings = torch.cat(embeddings, dim=0)
return {"embeddings": full_embeddings}
# -----------------------------
# Embedding Endpoint (image)
# -----------------------------
@app.post("/embed_image", response_model=EmbedImageResponse)
def embed_image(req: EmbedImageRequest):
with torch.no_grad():
outputs = model.encode_image(
images=[req.image],
task=req.task,
return_multivector=req.return_multivector,
truncate_dim=req.truncate_dim,
)
pooled = outputs[0].mean(dim=0).cpu()
return {"embeddings": [pooled]}
# -----------------------------
# Tokenize Endpoint
# -----------------------------
@app.post("/tokenize", response_model=TokenizeResponse)
def tokenize(req: TokenizeRequest):
enc = tokenizer(req.text, add_special_tokens=False)
return {"input_ids": enc["input_ids"]}
# -----------------------------
# Decode Endpoint
# -----------------------------
@app.post("/decode", response_model=DecodeResponse)
def decode(req: DecodeRequest):
decoded = tokenizer.decode(req.input_ids)
return {"text": decoded}