Refactor embedding batch_size (#3825)
Browse files

### What problem does this PR solve?
Refactor embedding batch_size. Close #3657
### Type of change
- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] Refactoring
- api/db/services/llm_service.py +4 -4
- rag/benchmark.py +5 -8
- rag/llm/embedding_model.py +151 -97
api/db/services/llm_service.py
CHANGED
@@ -232,13 +232,13 @@ class LLMBundle(object):
|
|
232 |
self.max_length = lm.max_tokens
|
233 |
break
|
234 |
|
235 |
-
def encode(self, texts: list
|
236 |
-
|
237 |
if not TenantLLMService.increase_usage(
|
238 |
self.tenant_id, self.llm_type, used_tokens):
|
239 |
logging.error(
|
240 |
"LLMBundle.encode can't update token usage for {}/EMBEDDING used_tokens: {}".format(self.tenant_id, used_tokens))
|
241 |
-
return
|
242 |
|
243 |
def encode_queries(self, query: str):
|
244 |
emd, used_tokens = self.mdl.encode_queries(query)
|
@@ -280,7 +280,7 @@ class LLMBundle(object):
|
|
280 |
logging.error(
|
281 |
"LLMBundle.tts can't update token usage for {}/TTS".format(self.tenant_id))
|
282 |
return
|
283 |
-
yield chunk
|
284 |
|
285 |
def chat(self, system, history, gen_conf):
|
286 |
txt, used_tokens = self.mdl.chat(system, history, gen_conf)
|
|
|
232 |
self.max_length = lm.max_tokens
|
233 |
break
|
234 |
|
235 |
+
def encode(self, texts: list):
|
236 |
+
embeddings, used_tokens = self.mdl.encode(texts)
|
237 |
if not TenantLLMService.increase_usage(
|
238 |
self.tenant_id, self.llm_type, used_tokens):
|
239 |
logging.error(
|
240 |
"LLMBundle.encode can't update token usage for {}/EMBEDDING used_tokens: {}".format(self.tenant_id, used_tokens))
|
241 |
+
return embeddings, used_tokens
|
242 |
|
243 |
def encode_queries(self, query: str):
|
244 |
emd, used_tokens = self.mdl.encode_queries(query)
|
|
|
280 |
logging.error(
|
281 |
"LLMBundle.tts can't update token usage for {}/TTS".format(self.tenant_id))
|
282 |
return
|
283 |
+
yield chunk
|
284 |
|
285 |
def chat(self, system, history, gen_conf):
|
286 |
txt, used_tokens = self.mdl.chat(system, history, gen_conf)
|
rag/benchmark.py
CHANGED
@@ -63,16 +63,13 @@ class Benchmark:
|
|
63 |
run[query][c["chunk_id"]] = c["similarity"]
|
64 |
return run
|
65 |
|
66 |
-
def embedding(self, docs
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
vts, c = self.embd_mdl.encode(cnts[i: i + batch_size])
|
71 |
-
vects.extend(vts.tolist())
|
72 |
-
assert len(docs) == len(vects)
|
73 |
vector_size = 0
|
74 |
for i, d in enumerate(docs):
|
75 |
-
v =
|
76 |
vector_size = len(v)
|
77 |
d["q_%d_vec" % len(v)] = v
|
78 |
return docs, vector_size
|
|
|
63 |
run[query][c["chunk_id"]] = c["similarity"]
|
64 |
return run
|
65 |
|
66 |
+
def embedding(self, docs):
|
67 |
+
texts = [d["content_with_weight"] for d in docs]
|
68 |
+
embeddings, _ = self.embd_mdl.encode(texts)
|
69 |
+
assert len(docs) == len(embeddings)
|
|
|
|
|
|
|
70 |
vector_size = 0
|
71 |
for i, d in enumerate(docs):
|
72 |
+
v = embeddings[i]
|
73 |
vector_size = len(v)
|
74 |
d["q_%d_vec" % len(v)] = v
|
75 |
return docs, vector_size
|
rag/llm/embedding_model.py
CHANGED
@@ -38,7 +38,7 @@ class Base(ABC):
|
|
38 |
def __init__(self, key, model_name):
|
39 |
pass
|
40 |
|
41 |
-
def encode(self, texts: list
|
42 |
raise NotImplementedError("Please implement encode method!")
|
43 |
|
44 |
def encode_queries(self, text: str):
|
@@ -78,15 +78,16 @@ class DefaultEmbedding(Base):
|
|
78 |
use_fp16=torch.cuda.is_available())
|
79 |
self._model = DefaultEmbedding._model
|
80 |
|
81 |
-
def encode(self, texts: list
|
|
|
82 |
texts = [truncate(t, 2048) for t in texts]
|
83 |
token_count = 0
|
84 |
for t in texts:
|
85 |
token_count += num_tokens_from_string(t)
|
86 |
-
|
87 |
for i in range(0, len(texts), batch_size):
|
88 |
-
|
89 |
-
return np.array(
|
90 |
|
91 |
def encode_queries(self, text: str):
|
92 |
token_count = num_tokens_from_string(text)
|
@@ -101,12 +102,18 @@ class OpenAIEmbed(Base):
|
|
101 |
self.client = OpenAI(api_key=key, base_url=base_url)
|
102 |
self.model_name = model_name
|
103 |
|
104 |
-
def encode(self, texts: list
|
|
|
|
|
105 |
texts = [truncate(t, 8191) for t in texts]
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
|
|
|
|
|
|
|
|
110 |
|
111 |
def encode_queries(self, text):
|
112 |
res = self.client.embeddings.create(input=[truncate(text, 8191)],
|
@@ -123,12 +130,14 @@ class LocalAIEmbed(Base):
|
|
123 |
self.client = OpenAI(api_key="empty", base_url=base_url)
|
124 |
self.model_name = model_name.split("___")[0]
|
125 |
|
126 |
-
def encode(self, texts: list
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
|
|
|
|
132 |
|
133 |
def encode_queries(self, text):
|
134 |
embds, cnt = self.encode([text])
|
@@ -155,12 +164,12 @@ class BaiChuanEmbed(OpenAIEmbed):
|
|
155 |
|
156 |
class QWenEmbed(Base):
|
157 |
def __init__(self, key, model_name="text_embedding_v2", **kwargs):
|
158 |
-
|
159 |
self.model_name = model_name
|
160 |
|
161 |
-
def encode(self, texts: list
|
162 |
import dashscope
|
163 |
-
batch_size =
|
164 |
try:
|
165 |
res = []
|
166 |
token_count = 0
|
@@ -169,6 +178,7 @@ class QWenEmbed(Base):
|
|
169 |
resp = dashscope.TextEmbedding.call(
|
170 |
model=self.model_name,
|
171 |
input=texts[i:i + batch_size],
|
|
|
172 |
text_type="document"
|
173 |
)
|
174 |
embds = [[] for _ in range(len(resp["output"]["embeddings"]))]
|
@@ -186,6 +196,7 @@ class QWenEmbed(Base):
|
|
186 |
resp = dashscope.TextEmbedding.call(
|
187 |
model=self.model_name,
|
188 |
input=text[:2048],
|
|
|
189 |
text_type="query"
|
190 |
)
|
191 |
return np.array(resp["output"]["embeddings"][0]
|
@@ -200,7 +211,7 @@ class ZhipuEmbed(Base):
|
|
200 |
self.client = ZhipuAI(api_key=key)
|
201 |
self.model_name = model_name
|
202 |
|
203 |
-
def encode(self, texts: list
|
204 |
arr = []
|
205 |
tks_num = 0
|
206 |
for txt in texts:
|
@@ -221,7 +232,7 @@ class OllamaEmbed(Base):
|
|
221 |
self.client = Client(host=kwargs["base_url"])
|
222 |
self.model_name = model_name
|
223 |
|
224 |
-
def encode(self, texts: list
|
225 |
arr = []
|
226 |
tks_num = 0
|
227 |
for txt in texts:
|
@@ -252,13 +263,13 @@ class FastEmbed(Base):
|
|
252 |
from fastembed import TextEmbedding
|
253 |
self._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
|
254 |
|
255 |
-
def encode(self, texts: list
|
256 |
# Using the internal tokenizer to encode the texts and get the total
|
257 |
# number of tokens
|
258 |
encodings = self._model.model.tokenizer.encode_batch(texts)
|
259 |
total_tokens = sum(len(e) for e in encodings)
|
260 |
|
261 |
-
embeddings = [e.tolist() for e in self._model.embed(texts, batch_size)]
|
262 |
|
263 |
return np.array(embeddings), total_tokens
|
264 |
|
@@ -278,11 +289,15 @@ class XinferenceEmbed(Base):
|
|
278 |
self.client = OpenAI(api_key=key, base_url=base_url)
|
279 |
self.model_name = model_name
|
280 |
|
281 |
-
def encode(self, texts: list
|
282 |
-
|
283 |
-
|
284 |
-
|
285 |
-
|
|
|
|
|
|
|
|
|
286 |
|
287 |
def encode_queries(self, text):
|
288 |
res = self.client.embeddings.create(input=[text],
|
@@ -306,7 +321,8 @@ class YoudaoEmbed(Base):
|
|
306 |
model_name_or_path=model_name.replace(
|
307 |
"maidalun1020", "InfiniFlow"))
|
308 |
|
309 |
-
def encode(self, texts: list
|
|
|
310 |
res = []
|
311 |
token_count = 0
|
312 |
for t in texts:
|
@@ -332,15 +348,21 @@ class JinaEmbed(Base):
|
|
332 |
}
|
333 |
self.model_name = model_name
|
334 |
|
335 |
-
def encode(self, texts: list
|
336 |
texts = [truncate(t, 8196) for t in texts]
|
337 |
-
|
338 |
-
|
339 |
-
|
340 |
-
|
341 |
-
|
342 |
-
|
343 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
344 |
|
345 |
def encode_queries(self, text):
|
346 |
embds, cnt = self.encode([text])
|
@@ -394,12 +416,17 @@ class MistralEmbed(Base):
|
|
394 |
self.client = MistralClient(api_key=key)
|
395 |
self.model_name = model_name
|
396 |
|
397 |
-
def encode(self, texts: list
|
398 |
texts = [truncate(t, 8196) for t in texts]
|
399 |
-
|
400 |
-
|
401 |
-
|
402 |
-
|
|
|
|
|
|
|
|
|
|
|
403 |
|
404 |
def encode_queries(self, text):
|
405 |
res = self.client.embeddings(input=[truncate(text, 8196)],
|
@@ -418,7 +445,7 @@ class BedrockEmbed(Base):
|
|
418 |
self.client = boto3.client(service_name='bedrock-runtime', region_name=self.bedrock_region,
|
419 |
aws_access_key_id=self.bedrock_ak, aws_secret_access_key=self.bedrock_sk)
|
420 |
|
421 |
-
def encode(self, texts: list
|
422 |
texts = [truncate(t, 8196) for t in texts]
|
423 |
embeddings = []
|
424 |
token_count = 0
|
@@ -436,7 +463,6 @@ class BedrockEmbed(Base):
|
|
436 |
return np.array(embeddings), token_count
|
437 |
|
438 |
def encode_queries(self, text):
|
439 |
-
|
440 |
embeddings = []
|
441 |
token_count = num_tokens_from_string(text)
|
442 |
if self.model_name.split('.')[0] == 'amazon':
|
@@ -453,20 +479,26 @@ class BedrockEmbed(Base):
|
|
453 |
class GeminiEmbed(Base):
|
454 |
def __init__(self, key, model_name='models/text-embedding-004',
|
455 |
**kwargs):
|
456 |
-
|
457 |
self.model_name = 'models/' + model_name
|
458 |
|
459 |
-
def encode(self, texts: list
|
460 |
texts = [truncate(t, 2048) for t in texts]
|
461 |
token_count = sum(num_tokens_from_string(text) for text in texts)
|
462 |
-
|
463 |
-
|
464 |
-
|
465 |
-
|
466 |
-
|
467 |
-
|
|
|
|
|
|
|
|
|
|
|
468 |
|
469 |
def encode_queries(self, text):
|
|
|
470 |
result = genai.embed_content(
|
471 |
model=self.model_name,
|
472 |
content=truncate(text,2048),
|
@@ -495,19 +527,22 @@ class NvidiaEmbed(Base):
|
|
495 |
if model_name == "snowflake/arctic-embed-l":
|
496 |
self.base_url = "https://ai.api.nvidia.com/v1/retrieval/snowflake/arctic-embed-l/embeddings"
|
497 |
|
498 |
-
def encode(self, texts: list
|
499 |
-
|
500 |
-
|
501 |
-
|
502 |
-
|
503 |
-
|
504 |
-
|
505 |
-
|
506 |
-
|
507 |
-
|
508 |
-
|
509 |
-
|
510 |
-
|
|
|
|
|
|
|
511 |
|
512 |
def encode_queries(self, text):
|
513 |
embds, cnt = self.encode([text])
|
@@ -541,16 +576,20 @@ class CoHereEmbed(Base):
|
|
541 |
self.client = Client(api_key=key)
|
542 |
self.model_name = model_name
|
543 |
|
544 |
-
def encode(self, texts: list
|
545 |
-
|
546 |
-
|
547 |
-
|
548 |
-
|
549 |
-
|
550 |
-
|
551 |
-
|
552 |
-
|
553 |
-
|
|
|
|
|
|
|
|
|
554 |
|
555 |
def encode_queries(self, text):
|
556 |
res = self.client.embed(
|
@@ -599,19 +638,23 @@ class SILICONFLOWEmbed(Base):
|
|
599 |
self.base_url = base_url
|
600 |
self.model_name = model_name
|
601 |
|
602 |
-
def encode(self, texts: list
|
603 |
-
|
604 |
-
|
605 |
-
|
606 |
-
|
607 |
-
|
608 |
-
|
609 |
-
|
610 |
-
|
611 |
-
|
612 |
-
|
613 |
-
res
|
614 |
-
|
|
|
|
|
|
|
|
|
615 |
|
616 |
def encode_queries(self, text):
|
617 |
payload = {
|
@@ -632,9 +675,14 @@ class ReplicateEmbed(Base):
|
|
632 |
self.model_name = model_name
|
633 |
self.client = Client(api_token=key)
|
634 |
|
635 |
-
def encode(self, texts: list
|
636 |
-
|
637 |
-
|
|
|
|
|
|
|
|
|
|
|
638 |
|
639 |
def encode_queries(self, text):
|
640 |
res = self.client.embed(self.model_name, input={"texts": [text]})
|
@@ -673,11 +721,17 @@ class VoyageEmbed(Base):
|
|
673 |
self.client = voyageai.Client(api_key=key)
|
674 |
self.model_name = model_name
|
675 |
|
676 |
-
def encode(self, texts: list
|
677 |
-
|
678 |
-
|
679 |
-
|
680 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
681 |
|
682 |
def encode_queries(self, text):
|
683 |
res = self.client.embed(
|
@@ -694,7 +748,7 @@ class HuggingFaceEmbed(Base):
|
|
694 |
self.model_name = model_name
|
695 |
self.base_url = base_url or "http://127.0.0.1:8080"
|
696 |
|
697 |
-
def encode(self, texts: list
|
698 |
embeddings = []
|
699 |
for text in texts:
|
700 |
response = requests.post(
|
|
|
38 |
def __init__(self, key, model_name):
|
39 |
pass
|
40 |
|
41 |
+
def encode(self, texts: list):
|
42 |
raise NotImplementedError("Please implement encode method!")
|
43 |
|
44 |
def encode_queries(self, text: str):
|
|
|
78 |
use_fp16=torch.cuda.is_available())
|
79 |
self._model = DefaultEmbedding._model
|
80 |
|
81 |
+
def encode(self, texts: list):
|
82 |
+
batch_size = 16
|
83 |
texts = [truncate(t, 2048) for t in texts]
|
84 |
token_count = 0
|
85 |
for t in texts:
|
86 |
token_count += num_tokens_from_string(t)
|
87 |
+
ress = []
|
88 |
for i in range(0, len(texts), batch_size):
|
89 |
+
ress.extend(self._model.encode(texts[i:i + batch_size]).tolist())
|
90 |
+
return np.array(ress), token_count
|
91 |
|
92 |
def encode_queries(self, text: str):
|
93 |
token_count = num_tokens_from_string(text)
|
|
|
102 |
self.client = OpenAI(api_key=key, base_url=base_url)
|
103 |
self.model_name = model_name
|
104 |
|
105 |
+
def encode(self, texts: list):
|
106 |
+
# OpenAI requires batch size <=16
|
107 |
+
batch_size = 16
|
108 |
texts = [truncate(t, 8191) for t in texts]
|
109 |
+
ress = []
|
110 |
+
total_tokens = 0
|
111 |
+
for i in range(0, len(texts), batch_size):
|
112 |
+
res = self.client.embeddings.create(input=texts[i:i + batch_size],
|
113 |
+
model=self.model_name)
|
114 |
+
ress.extend([d.embedding for d in res.data])
|
115 |
+
total_tokens += res.usage.total_tokens
|
116 |
+
return np.array(ress), total_tokens
|
117 |
|
118 |
def encode_queries(self, text):
|
119 |
res = self.client.embeddings.create(input=[truncate(text, 8191)],
|
|
|
130 |
self.client = OpenAI(api_key="empty", base_url=base_url)
|
131 |
self.model_name = model_name.split("___")[0]
|
132 |
|
133 |
+
def encode(self, texts: list):
|
134 |
+
batch_size = 16
|
135 |
+
ress = []
|
136 |
+
for i in range(0, len(texts), batch_size):
|
137 |
+
res = self.client.embeddings.create(input=texts[i:i + batch_size], model=self.model_name)
|
138 |
+
ress.extend([d.embedding for d in res.data])
|
139 |
+
# local embedding for LmStudio donot count tokens
|
140 |
+
return np.array(ress), 1024
|
141 |
|
142 |
def encode_queries(self, text):
|
143 |
embds, cnt = self.encode([text])
|
|
|
164 |
|
165 |
class QWenEmbed(Base):
|
166 |
def __init__(self, key, model_name="text_embedding_v2", **kwargs):
|
167 |
+
self.key = key
|
168 |
self.model_name = model_name
|
169 |
|
170 |
+
def encode(self, texts: list):
|
171 |
import dashscope
|
172 |
+
batch_size = 4
|
173 |
try:
|
174 |
res = []
|
175 |
token_count = 0
|
|
|
178 |
resp = dashscope.TextEmbedding.call(
|
179 |
model=self.model_name,
|
180 |
input=texts[i:i + batch_size],
|
181 |
+
api_key=self.key,
|
182 |
text_type="document"
|
183 |
)
|
184 |
embds = [[] for _ in range(len(resp["output"]["embeddings"]))]
|
|
|
196 |
resp = dashscope.TextEmbedding.call(
|
197 |
model=self.model_name,
|
198 |
input=text[:2048],
|
199 |
+
api_key=self.key,
|
200 |
text_type="query"
|
201 |
)
|
202 |
return np.array(resp["output"]["embeddings"][0]
|
|
|
211 |
self.client = ZhipuAI(api_key=key)
|
212 |
self.model_name = model_name
|
213 |
|
214 |
+
def encode(self, texts: list):
|
215 |
arr = []
|
216 |
tks_num = 0
|
217 |
for txt in texts:
|
|
|
232 |
self.client = Client(host=kwargs["base_url"])
|
233 |
self.model_name = model_name
|
234 |
|
235 |
+
def encode(self, texts: list):
|
236 |
arr = []
|
237 |
tks_num = 0
|
238 |
for txt in texts:
|
|
|
263 |
from fastembed import TextEmbedding
|
264 |
self._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
|
265 |
|
266 |
+
def encode(self, texts: list):
|
267 |
# Using the internal tokenizer to encode the texts and get the total
|
268 |
# number of tokens
|
269 |
encodings = self._model.model.tokenizer.encode_batch(texts)
|
270 |
total_tokens = sum(len(e) for e in encodings)
|
271 |
|
272 |
+
embeddings = [e.tolist() for e in self._model.embed(texts, batch_size=16)]
|
273 |
|
274 |
return np.array(embeddings), total_tokens
|
275 |
|
|
|
289 |
self.client = OpenAI(api_key=key, base_url=base_url)
|
290 |
self.model_name = model_name
|
291 |
|
292 |
+
def encode(self, texts: list):
|
293 |
+
batch_size = 16
|
294 |
+
ress = []
|
295 |
+
total_tokens = 0
|
296 |
+
for i in range(0, len(texts), batch_size):
|
297 |
+
res = self.client.embeddings.create(input=texts[i:i + batch_size], model=self.model_name)
|
298 |
+
ress.extend([d.embedding for d in res.data])
|
299 |
+
total_tokens += res.usage.total_tokens
|
300 |
+
return np.array(ress), total_tokens
|
301 |
|
302 |
def encode_queries(self, text):
|
303 |
res = self.client.embeddings.create(input=[text],
|
|
|
321 |
model_name_or_path=model_name.replace(
|
322 |
"maidalun1020", "InfiniFlow"))
|
323 |
|
324 |
+
def encode(self, texts: list):
|
325 |
+
batch_size = 10
|
326 |
res = []
|
327 |
token_count = 0
|
328 |
for t in texts:
|
|
|
348 |
}
|
349 |
self.model_name = model_name
|
350 |
|
351 |
+
def encode(self, texts: list):
|
352 |
texts = [truncate(t, 8196) for t in texts]
|
353 |
+
batch_size = 16
|
354 |
+
ress = []
|
355 |
+
token_count = 0
|
356 |
+
for i in range(0, len(texts), batch_size):
|
357 |
+
data = {
|
358 |
+
"model": self.model_name,
|
359 |
+
"input": texts[i:i + batch_size],
|
360 |
+
'encoding_type': 'float'
|
361 |
+
}
|
362 |
+
res = requests.post(self.base_url, headers=self.headers, json=data).json()
|
363 |
+
ress.extend([d["embedding"] for d in res["data"]])
|
364 |
+
token_count += res["usage"]["total_tokens"]
|
365 |
+
return np.array(ress), token_count
|
366 |
|
367 |
def encode_queries(self, text):
|
368 |
embds, cnt = self.encode([text])
|
|
|
416 |
self.client = MistralClient(api_key=key)
|
417 |
self.model_name = model_name
|
418 |
|
419 |
+
def encode(self, texts: list):
|
420 |
texts = [truncate(t, 8196) for t in texts]
|
421 |
+
batch_size = 16
|
422 |
+
ress = []
|
423 |
+
token_count = 0
|
424 |
+
for i in range(0, len(texts), batch_size):
|
425 |
+
res = self.client.embeddings(input=texts[i:i + batch_size],
|
426 |
+
model=self.model_name)
|
427 |
+
ress.extend([d.embedding for d in res.data])
|
428 |
+
token_count += res.usage.total_tokens
|
429 |
+
return np.array(ress), token_count
|
430 |
|
431 |
def encode_queries(self, text):
|
432 |
res = self.client.embeddings(input=[truncate(text, 8196)],
|
|
|
445 |
self.client = boto3.client(service_name='bedrock-runtime', region_name=self.bedrock_region,
|
446 |
aws_access_key_id=self.bedrock_ak, aws_secret_access_key=self.bedrock_sk)
|
447 |
|
448 |
+
def encode(self, texts: list):
|
449 |
texts = [truncate(t, 8196) for t in texts]
|
450 |
embeddings = []
|
451 |
token_count = 0
|
|
|
463 |
return np.array(embeddings), token_count
|
464 |
|
465 |
def encode_queries(self, text):
|
|
|
466 |
embeddings = []
|
467 |
token_count = num_tokens_from_string(text)
|
468 |
if self.model_name.split('.')[0] == 'amazon':
|
|
|
479 |
class GeminiEmbed(Base):
|
480 |
def __init__(self, key, model_name='models/text-embedding-004',
|
481 |
**kwargs):
|
482 |
+
self.key = key
|
483 |
self.model_name = 'models/' + model_name
|
484 |
|
485 |
+
def encode(self, texts: list):
|
486 |
texts = [truncate(t, 2048) for t in texts]
|
487 |
token_count = sum(num_tokens_from_string(text) for text in texts)
|
488 |
+
genai.configure(api_key=self.key)
|
489 |
+
batch_size = 16
|
490 |
+
ress = []
|
491 |
+
for i in range(0, len(texts), batch_size):
|
492 |
+
result = genai.embed_content(
|
493 |
+
model=self.model_name,
|
494 |
+
content=texts[i, i + batch_size],
|
495 |
+
task_type="retrieval_document",
|
496 |
+
title="Embedding of single string")
|
497 |
+
ress.extend(result['embedding'])
|
498 |
+
return np.array(ress),token_count
|
499 |
|
500 |
def encode_queries(self, text):
|
501 |
+
genai.configure(api_key=self.key)
|
502 |
result = genai.embed_content(
|
503 |
model=self.model_name,
|
504 |
content=truncate(text,2048),
|
|
|
527 |
if model_name == "snowflake/arctic-embed-l":
|
528 |
self.base_url = "https://ai.api.nvidia.com/v1/retrieval/snowflake/arctic-embed-l/embeddings"
|
529 |
|
530 |
+
def encode(self, texts: list):
|
531 |
+
batch_size = 16
|
532 |
+
ress = []
|
533 |
+
token_count = 0
|
534 |
+
for i in range(0, len(texts), batch_size):
|
535 |
+
payload = {
|
536 |
+
"input": texts[i : i + batch_size],
|
537 |
+
"input_type": "query",
|
538 |
+
"model": self.model_name,
|
539 |
+
"encoding_format": "float",
|
540 |
+
"truncate": "END",
|
541 |
+
}
|
542 |
+
res = requests.post(self.base_url, headers=self.headers, json=payload).json()
|
543 |
+
ress.extend([d["embedding"] for d in res["data"]])
|
544 |
+
token_count += res["usage"]["total_tokens"]
|
545 |
+
return np.array(ress), token_count
|
546 |
|
547 |
def encode_queries(self, text):
|
548 |
embds, cnt = self.encode([text])
|
|
|
576 |
self.client = Client(api_key=key)
|
577 |
self.model_name = model_name
|
578 |
|
579 |
+
def encode(self, texts: list):
|
580 |
+
batch_size = 16
|
581 |
+
ress = []
|
582 |
+
token_count = 0
|
583 |
+
for i in range(0, len(texts), batch_size):
|
584 |
+
res = self.client.embed(
|
585 |
+
texts=texts[i : i + batch_size],
|
586 |
+
model=self.model_name,
|
587 |
+
input_type="search_document",
|
588 |
+
embedding_types=["float"],
|
589 |
+
)
|
590 |
+
ress.extend([d for d in res.embeddings.float])
|
591 |
+
token_count += res.meta.billed_units.input_tokens
|
592 |
+
return np.array(ress), token_count
|
593 |
|
594 |
def encode_queries(self, text):
|
595 |
res = self.client.embed(
|
|
|
638 |
self.base_url = base_url
|
639 |
self.model_name = model_name
|
640 |
|
641 |
+
def encode(self, texts: list):
|
642 |
+
batch_size = 16
|
643 |
+
ress = []
|
644 |
+
token_count = 0
|
645 |
+
for i in range(0, len(texts), batch_size):
|
646 |
+
texts_batch = texts[i : i + batch_size]
|
647 |
+
payload = {
|
648 |
+
"model": self.model_name,
|
649 |
+
"input": texts_batch,
|
650 |
+
"encoding_format": "float",
|
651 |
+
}
|
652 |
+
res = requests.post(self.base_url, json=payload, headers=self.headers).json()
|
653 |
+
if "data" not in res or not isinstance(res["data"], list) or len(res["data"]) != len(texts_batch):
|
654 |
+
raise ValueError(f"SILICONFLOWEmbed.encode got invalid response from {self.base_url}")
|
655 |
+
ress.extend([d["embedding"] for d in res["data"]])
|
656 |
+
token_count += res["usage"]["total_tokens"]
|
657 |
+
return np.array(ress), token_count
|
658 |
|
659 |
def encode_queries(self, text):
|
660 |
payload = {
|
|
|
675 |
self.model_name = model_name
|
676 |
self.client = Client(api_token=key)
|
677 |
|
678 |
+
def encode(self, texts: list):
|
679 |
+
batch_size = 16
|
680 |
+
token_count = sum([num_tokens_from_string(text) for text in texts])
|
681 |
+
ress = []
|
682 |
+
for i in range(0, len(texts), batch_size):
|
683 |
+
res = self.client.run(self.model_name, input={"texts": texts[i : i + batch_size]})
|
684 |
+
ress.extend(res)
|
685 |
+
return np.array(ress), token_count
|
686 |
|
687 |
def encode_queries(self, text):
|
688 |
res = self.client.embed(self.model_name, input={"texts": [text]})
|
|
|
721 |
self.client = voyageai.Client(api_key=key)
|
722 |
self.model_name = model_name
|
723 |
|
724 |
+
def encode(self, texts: list):
|
725 |
+
batch_size = 16
|
726 |
+
ress = []
|
727 |
+
token_count = 0
|
728 |
+
for i in range(0, len(texts), batch_size):
|
729 |
+
res = self.client.embed(
|
730 |
+
texts=texts[i : i + batch_size], model=self.model_name, input_type="document"
|
731 |
+
)
|
732 |
+
ress.extend(res.embeddings)
|
733 |
+
token_count += res.total_tokens
|
734 |
+
return np.array(ress), token_count
|
735 |
|
736 |
def encode_queries(self, text):
|
737 |
res = self.client.embed(
|
|
|
748 |
self.model_name = model_name
|
749 |
self.base_url = base_url or "http://127.0.0.1:8080"
|
750 |
|
751 |
+
def encode(self, texts: list):
|
752 |
embeddings = []
|
753 |
for text in texts:
|
754 |
response = requests.post(
|