# app.py (FastAPI server to host the Jina Embedding model)
import os
os.environ["HF_HOME"] = "/tmp/huggingface"
os.environ["HF_HUB_CACHE"] = "/tmp/huggingface/hub"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface/transformers"

from fastapi import FastAPI
from pydantic import BaseModel
from typing import List, Optional
import torch
from transformers import AutoModel, AutoTokenizer

app = FastAPI()

# -----------------------------
# Load model once on startup
# -----------------------------
MODEL_NAME = "jinaai/jina-embeddings-v4"
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
# float16 halves memory on GPU, but many fp16 ops are slow or unsupported
# on CPU, so fall back to float32 there.
model = AutoModel.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
).to(device)
model.eval()


# -----------------------------
# Request / Response Models
# -----------------------------
class EmbedRequest(BaseModel):
    text: str
    task: str = "retrieval"
    prompt_name: Optional[str] = None
    return_token_embeddings: bool = True
    truncate_dim: Optional[int] = None  # for matryoshka embeddings


class EmbedResponse(BaseModel):
    embeddings: List[List[float]]


class EmbedImageRequest(BaseModel):
    image: str  # image URL or local path, passed straight to encode_image
    task: str = "retrieval"
    return_multivector: bool = True
    truncate_dim: Optional[int] = None


class EmbedImageResponse(BaseModel):
    embeddings: List[List[float]]


class TokenizeRequest(BaseModel):
    text: str


class TokenizeResponse(BaseModel):
    input_ids: List[int]


class DecodeRequest(BaseModel):
    input_ids: List[int]


class DecodeResponse(BaseModel):
    text: str


# -----------------------------
# Embedding Endpoint (text)
# -----------------------------
@app.post("/embed", response_model=EmbedResponse)
def embed(req: EmbedRequest):
    text = req.text

    # Case 1: Query → pooled mean of multivectors
    if not req.return_token_embeddings:
        with torch.no_grad():
            outputs = model.encode_text(
                texts=[text],
                task=req.task,
                prompt_name=req.prompt_name or "query",
                return_multivector=True,
                truncate_dim=req.truncate_dim,
            )
        # outputs[0] has shape (num_vectors, hidden_dim); mean-pool into a
        # single query vector, and convert to a plain list so it matches the
        # List[List[float]] response model (a tensor is not JSON-serializable).
        pooled = outputs[0].mean(dim=0).cpu()
        return {"embeddings": [pooled.tolist()]}

    # Case 2: Passage → sliding window, token-level embeddings
    enc = tokenizer(text, add_special_tokens=False, return_tensors="pt")
    input_ids = enc["input_ids"].squeeze(0).to(device)
    total_tokens = input_ids.size(0)

    # Window length: the model's maximum context (~32k tokens for v4); fall
    # back to 32768 in case the remote-code config does not expose the field.
    max_len = getattr(model.config, "max_position_embeddings", 32768)
    stride = 50  # token overlap between consecutive windows
    embeddings = []
    position = 0

    while position < total_tokens:
        end = min(position + max_len, total_tokens)
        window_ids = input_ids[position:end].unsqueeze(0).to(device)

        with torch.no_grad():
            outputs = model.encode_text(
                texts=[tokenizer.decode(window_ids[0])],
                task=req.task,
                prompt_name=req.prompt_name or "passage",
                return_multivector=True,
                truncate_dim=req.truncate_dim,
            )

        window_embeds = outputs[0].cpu()

        # Windows overlap by `stride` tokens; drop the first `stride` vectors
        # of every window after the first so overlapping tokens are not
        # counted twice (assumes roughly one output vector per input token).
        if position > 0:
            window_embeds = window_embeds[stride:]

        embeddings.append(window_embeds)

        # Stop once the window has reached the end of the text; advancing
        # again would re-encode (and duplicate) the tail of the passage.
        if end == total_tokens:
            break
        position += max_len - stride

    full_embeddings = torch.cat(embeddings, dim=0)
    # Tensors are not JSON-serializable; convert to nested Python lists.
    return {"embeddings": full_embeddings.tolist()}

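# Example payloads for POST /embed (illustrative only; field names follow
# the EmbedRequest model above):
#   {"text": "What is Jina v4?", "return_token_embeddings": false}
#       -> one mean-pooled query vector
#   {"text": "<long passage>", "return_token_embeddings": true}
#       -> one vector per token, stitched across the sliding windows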

# -----------------------------
# Embedding Endpoint (image)
# -----------------------------
@app.post("/embed_image", response_model=EmbedImageResponse)
def embed_image(req: EmbedImageRequest):
    with torch.no_grad():
        outputs = model.encode_image(
            images=[req.image],
            task=req.task,
            return_multivector=req.return_multivector,
            truncate_dim=req.truncate_dim,
        )
    out = outputs[0]
    # Mean-pool the (num_vectors, dim) multivector output; with
    # return_multivector=False the output is already a single 1-D vector,
    # and mean(dim=0) would wrongly collapse it to a scalar.
    pooled = out.mean(dim=0) if out.dim() > 1 else out
    return {"embeddings": [pooled.cpu().tolist()]}

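# Example payload for POST /embed_image (illustrative; the `image` string is
# forwarded to encode_image, which is assumed here to accept image URLs or
# local paths, as the v4 remote code does for its image inputs):
#   {"image": "https://example.com/cat.jpg", "return_multivector": true}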

# -----------------------------
# Tokenize Endpoint
# -----------------------------
@app.post("/tokenize", response_model=TokenizeResponse)
def tokenize(req: TokenizeRequest):
    enc = tokenizer(req.text, add_special_tokens=False)
    return {"input_ids": enc["input_ids"]}

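# Note: /tokenize and /decode below are approximate inverses; because
# add_special_tokens=False is used, decode(tokenize(t)) may differ from t
# in whitespace normalization, depending on the tokenizer.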

# -----------------------------
# Decode Endpoint
# -----------------------------
@app.post("/decode", response_model=DecodeResponse)
def decode(req: DecodeRequest):
    decoded = tokenizer.decode(req.input_ids)
    return {"text": decoded}