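# FastAPI service exposing jinaai/jina-embeddings-v4 for text and image
# embeddings, plus tokenize/decode helpers. Long texts are embedded in
# overlapping windows so per-token embeddings can be returned for inputs
# longer than the model's context window.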
import os

# The cache locations must be set before transformers is imported, or the
# default ~/.cache paths are used instead. (TRANSFORMERS_CACHE is deprecated
# in recent transformers releases but still honored.)
os.environ["HF_HOME"] = "/tmp/huggingface"
os.environ["HF_HUB_CACHE"] = "/tmp/huggingface/hub"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface/transformers"

from fastapi import FastAPI
from pydantic import BaseModel
from typing import List, Optional

import torch
from transformers import AutoModel, AutoTokenizer

app = FastAPI()

MODEL_NAME = "jinaai/jina-embeddings-v4"
device = "cuda" if torch.cuda.is_available() else "cpu"

# Jina v4 ships custom modeling code, so trust_remote_code is required.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModel.from_pretrained(
    MODEL_NAME, trust_remote_code=True, torch_dtype=torch.float16
).to(device)
model.eval()
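
# Note: float16 on CPU is slow and some ops lack half-precision kernels; for
# a CPU-only deployment, torch.float32 is the safer dtype here.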


class EmbedRequest(BaseModel):
    text: str
    task: str = "retrieval"
    prompt_name: Optional[str] = None
    # True: return one embedding per token, chunking long inputs into
    # overlapping windows. False: return a single mean-pooled vector.
    return_token_embeddings: bool = True
    truncate_dim: Optional[int] = None


class EmbedResponse(BaseModel):
    embeddings: List[List[float]]


class EmbedImageRequest(BaseModel):
    image: str
    task: str = "retrieval"
    return_multivector: bool = True
    truncate_dim: Optional[int] = None


class EmbedImageResponse(BaseModel):
    embeddings: List[List[float]]


class TokenizeRequest(BaseModel):
    text: str


class TokenizeResponse(BaseModel):
    input_ids: List[int]


class DecodeRequest(BaseModel):
    input_ids: List[int]


class DecodeResponse(BaseModel):
    text: str


@app.post("/embed", response_model=EmbedResponse)
def embed(req: EmbedRequest):
    text = req.text

    # Single-vector mode: encode once, mean-pool the token embeddings, and
    # wrap the vector so it fits the List[List[float]] response model.
    if not req.return_token_embeddings:
        with torch.no_grad():
            outputs = model.encode_text(
                texts=[text],
                task=req.task,
                prompt_name=req.prompt_name or "query",
                return_multivector=True,
                truncate_dim=req.truncate_dim,
            )
        pooled = outputs[0].mean(dim=0).cpu()
        return {"embeddings": [pooled.tolist()]}

    # Token-embedding mode: split long inputs into windows of at most max_len
    # tokens with a `stride`-token overlap between consecutive windows.
    enc = tokenizer(text, add_special_tokens=False, return_tensors="pt")
    input_ids = enc["input_ids"].squeeze(0)
    total_tokens = input_ids.size(0)
    if total_tokens == 0:
        return {"embeddings": []}

    max_len = model.config.max_position_embeddings  # max tokens per window
    stride = 50  # overlap between consecutive windows, in tokens
    embeddings = []
    position = 0

    while position < total_tokens:
        end = min(position + max_len, total_tokens)
        window_ids = input_ids[position:end]

        with torch.no_grad():
            outputs = model.encode_text(
                texts=[tokenizer.decode(window_ids)],
                task=req.task,
                prompt_name=req.prompt_name or "passage",
                return_multivector=True,
                truncate_dim=req.truncate_dim,
            )
        window_embeds = outputs[0].cpu()

        # Drop embeddings for the overlapping prefix so each token appears
        # once. This assumes encode_text yields exactly one embedding per
        # input token; prompt or special tokens added internally can shift
        # the alignment, so the overlap trim is approximate.
        if position > 0:
            window_embeds = window_embeds[stride:]

        embeddings.append(window_embeds)
        position += max_len - stride

    full_embeddings = torch.cat(embeddings, dim=0)
    return {"embeddings": full_embeddings.tolist()}
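
# Example request, assuming the server is running locally on port 8000:
#   curl -X POST http://localhost:8000/embed \
#        -H "Content-Type: application/json" \
#        -d '{"text": "hello world", "return_token_embeddings": false}'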


@app.post("/embed_image", response_model=EmbedImageResponse)
def embed_image(req: EmbedImageRequest):
    with torch.no_grad():
        outputs = model.encode_image(
            images=[req.image],
            task=req.task,
            return_multivector=req.return_multivector,
            truncate_dim=req.truncate_dim,
        )
    if req.return_multivector:
        # One embedding per visual token.
        return {"embeddings": outputs[0].cpu().tolist()}
    # A single vector; wrap it to match the List[List[float]] response model.
    return {"embeddings": [outputs[0].cpu().tolist()]}
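
# Per the Jina v4 model card, encode_image accepts image URLs or local file
# paths, so a body like {"image": "https://example.com/photo.jpg"} should
# work as-is.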


@app.post("/tokenize", response_model=TokenizeResponse)
def tokenize(req: TokenizeRequest):
    enc = tokenizer(req.text, add_special_tokens=False)
    return {"input_ids": enc["input_ids"]}


@app.post("/decode", response_model=DecodeResponse)
def decode(req: DecodeRequest):
    decoded = tokenizer.decode(req.input_ids)
    return {"text": decoded}
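

# Minimal local entry point, assuming uvicorn is installed (it is not
# imported or required by anything above):
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)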