zhichyu commited on
Commit
08913be
·
1 Parent(s): ef2346f

Refactor embedding batch_size (#3825)

Browse files

### What problem does this PR solve?

Refactor embedding batch_size. Close #3657

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] Refactoring

api/db/services/llm_service.py CHANGED
@@ -232,13 +232,13 @@ class LLMBundle(object):
232
  self.max_length = lm.max_tokens
233
  break
234
 
235
- def encode(self, texts: list, batch_size=32):
236
- emd, used_tokens = self.mdl.encode(texts, batch_size)
237
  if not TenantLLMService.increase_usage(
238
  self.tenant_id, self.llm_type, used_tokens):
239
  logging.error(
240
  "LLMBundle.encode can't update token usage for {}/EMBEDDING used_tokens: {}".format(self.tenant_id, used_tokens))
241
- return emd, used_tokens
242
 
243
  def encode_queries(self, query: str):
244
  emd, used_tokens = self.mdl.encode_queries(query)
@@ -280,7 +280,7 @@ class LLMBundle(object):
280
  logging.error(
281
  "LLMBundle.tts can't update token usage for {}/TTS".format(self.tenant_id))
282
  return
283
- yield chunk
284
 
285
  def chat(self, system, history, gen_conf):
286
  txt, used_tokens = self.mdl.chat(system, history, gen_conf)
 
232
  self.max_length = lm.max_tokens
233
  break
234
 
235
+ def encode(self, texts: list):
236
+ embeddings, used_tokens = self.mdl.encode(texts)
237
  if not TenantLLMService.increase_usage(
238
  self.tenant_id, self.llm_type, used_tokens):
239
  logging.error(
240
  "LLMBundle.encode can't update token usage for {}/EMBEDDING used_tokens: {}".format(self.tenant_id, used_tokens))
241
+ return embeddings, used_tokens
242
 
243
  def encode_queries(self, query: str):
244
  emd, used_tokens = self.mdl.encode_queries(query)
 
280
  logging.error(
281
  "LLMBundle.tts can't update token usage for {}/TTS".format(self.tenant_id))
282
  return
283
+ yield chunk
284
 
285
  def chat(self, system, history, gen_conf):
286
  txt, used_tokens = self.mdl.chat(system, history, gen_conf)
rag/benchmark.py CHANGED
@@ -63,16 +63,13 @@ class Benchmark:
63
  run[query][c["chunk_id"]] = c["similarity"]
64
  return run
65
 
66
- def embedding(self, docs, batch_size=16):
67
- vects = []
68
- cnts = [d["content_with_weight"] for d in docs]
69
- for i in range(0, len(cnts), batch_size):
70
- vts, c = self.embd_mdl.encode(cnts[i: i + batch_size])
71
- vects.extend(vts.tolist())
72
- assert len(docs) == len(vects)
73
  vector_size = 0
74
  for i, d in enumerate(docs):
75
- v = vects[i]
76
  vector_size = len(v)
77
  d["q_%d_vec" % len(v)] = v
78
  return docs, vector_size
 
63
  run[query][c["chunk_id"]] = c["similarity"]
64
  return run
65
 
66
+ def embedding(self, docs):
67
+ texts = [d["content_with_weight"] for d in docs]
68
+ embeddings, _ = self.embd_mdl.encode(texts)
69
+ assert len(docs) == len(embeddings)
 
 
 
70
  vector_size = 0
71
  for i, d in enumerate(docs):
72
+ v = embeddings[i]
73
  vector_size = len(v)
74
  d["q_%d_vec" % len(v)] = v
75
  return docs, vector_size
rag/llm/embedding_model.py CHANGED
@@ -38,7 +38,7 @@ class Base(ABC):
38
  def __init__(self, key, model_name):
39
  pass
40
 
41
- def encode(self, texts: list, batch_size=16):
42
  raise NotImplementedError("Please implement encode method!")
43
 
44
  def encode_queries(self, text: str):
@@ -78,15 +78,16 @@ class DefaultEmbedding(Base):
78
  use_fp16=torch.cuda.is_available())
79
  self._model = DefaultEmbedding._model
80
 
81
- def encode(self, texts: list, batch_size=16):
 
82
  texts = [truncate(t, 2048) for t in texts]
83
  token_count = 0
84
  for t in texts:
85
  token_count += num_tokens_from_string(t)
86
- res = []
87
  for i in range(0, len(texts), batch_size):
88
- res.extend(self._model.encode(texts[i:i + batch_size]).tolist())
89
- return np.array(res), token_count
90
 
91
  def encode_queries(self, text: str):
92
  token_count = num_tokens_from_string(text)
@@ -101,12 +102,18 @@ class OpenAIEmbed(Base):
101
  self.client = OpenAI(api_key=key, base_url=base_url)
102
  self.model_name = model_name
103
 
104
- def encode(self, texts: list, batch_size=16):
 
 
105
  texts = [truncate(t, 8191) for t in texts]
106
- res = self.client.embeddings.create(input=texts,
107
- model=self.model_name)
108
- return np.array([d.embedding for d in res.data]
109
- ), res.usage.total_tokens
 
 
 
 
110
 
111
  def encode_queries(self, text):
112
  res = self.client.embeddings.create(input=[truncate(text, 8191)],
@@ -123,12 +130,14 @@ class LocalAIEmbed(Base):
123
  self.client = OpenAI(api_key="empty", base_url=base_url)
124
  self.model_name = model_name.split("___")[0]
125
 
126
- def encode(self, texts: list, batch_size=16):
127
- res = self.client.embeddings.create(input=texts, model=self.model_name)
128
- return (
129
- np.array([d.embedding for d in res.data]),
130
- 1024,
131
- ) # local embedding for LmStudio donot count tokens
 
 
132
 
133
  def encode_queries(self, text):
134
  embds, cnt = self.encode([text])
@@ -155,12 +164,12 @@ class BaiChuanEmbed(OpenAIEmbed):
155
 
156
  class QWenEmbed(Base):
157
  def __init__(self, key, model_name="text_embedding_v2", **kwargs):
158
- dashscope.api_key = key
159
  self.model_name = model_name
160
 
161
- def encode(self, texts: list, batch_size=10):
162
  import dashscope
163
- batch_size = min(batch_size, 4)
164
  try:
165
  res = []
166
  token_count = 0
@@ -169,6 +178,7 @@ class QWenEmbed(Base):
169
  resp = dashscope.TextEmbedding.call(
170
  model=self.model_name,
171
  input=texts[i:i + batch_size],
 
172
  text_type="document"
173
  )
174
  embds = [[] for _ in range(len(resp["output"]["embeddings"]))]
@@ -186,6 +196,7 @@ class QWenEmbed(Base):
186
  resp = dashscope.TextEmbedding.call(
187
  model=self.model_name,
188
  input=text[:2048],
 
189
  text_type="query"
190
  )
191
  return np.array(resp["output"]["embeddings"][0]
@@ -200,7 +211,7 @@ class ZhipuEmbed(Base):
200
  self.client = ZhipuAI(api_key=key)
201
  self.model_name = model_name
202
 
203
- def encode(self, texts: list, batch_size=16):
204
  arr = []
205
  tks_num = 0
206
  for txt in texts:
@@ -221,7 +232,7 @@ class OllamaEmbed(Base):
221
  self.client = Client(host=kwargs["base_url"])
222
  self.model_name = model_name
223
 
224
- def encode(self, texts: list, batch_size=16):
225
  arr = []
226
  tks_num = 0
227
  for txt in texts:
@@ -252,13 +263,13 @@ class FastEmbed(Base):
252
  from fastembed import TextEmbedding
253
  self._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
254
 
255
- def encode(self, texts: list, batch_size=16):
256
  # Using the internal tokenizer to encode the texts and get the total
257
  # number of tokens
258
  encodings = self._model.model.tokenizer.encode_batch(texts)
259
  total_tokens = sum(len(e) for e in encodings)
260
 
261
- embeddings = [e.tolist() for e in self._model.embed(texts, batch_size)]
262
 
263
  return np.array(embeddings), total_tokens
264
 
@@ -278,11 +289,15 @@ class XinferenceEmbed(Base):
278
  self.client = OpenAI(api_key=key, base_url=base_url)
279
  self.model_name = model_name
280
 
281
- def encode(self, texts: list, batch_size=16):
282
- res = self.client.embeddings.create(input=texts,
283
- model=self.model_name)
284
- return np.array([d.embedding for d in res.data]
285
- ), res.usage.total_tokens
 
 
 
 
286
 
287
  def encode_queries(self, text):
288
  res = self.client.embeddings.create(input=[text],
@@ -306,7 +321,8 @@ class YoudaoEmbed(Base):
306
  model_name_or_path=model_name.replace(
307
  "maidalun1020", "InfiniFlow"))
308
 
309
- def encode(self, texts: list, batch_size=10):
 
310
  res = []
311
  token_count = 0
312
  for t in texts:
@@ -332,15 +348,21 @@ class JinaEmbed(Base):
332
  }
333
  self.model_name = model_name
334
 
335
- def encode(self, texts: list, batch_size=None):
336
  texts = [truncate(t, 8196) for t in texts]
337
- data = {
338
- "model": self.model_name,
339
- "input": texts,
340
- 'encoding_type': 'float'
341
- }
342
- res = requests.post(self.base_url, headers=self.headers, json=data).json()
343
- return np.array([d["embedding"] for d in res["data"]]), res["usage"]["total_tokens"]
 
 
 
 
 
 
344
 
345
  def encode_queries(self, text):
346
  embds, cnt = self.encode([text])
@@ -394,12 +416,17 @@ class MistralEmbed(Base):
394
  self.client = MistralClient(api_key=key)
395
  self.model_name = model_name
396
 
397
- def encode(self, texts: list, batch_size=16):
398
  texts = [truncate(t, 8196) for t in texts]
399
- res = self.client.embeddings(input=texts,
400
- model=self.model_name)
401
- return np.array([d.embedding for d in res.data]
402
- ), res.usage.total_tokens
 
 
 
 
 
403
 
404
  def encode_queries(self, text):
405
  res = self.client.embeddings(input=[truncate(text, 8196)],
@@ -418,7 +445,7 @@ class BedrockEmbed(Base):
418
  self.client = boto3.client(service_name='bedrock-runtime', region_name=self.bedrock_region,
419
  aws_access_key_id=self.bedrock_ak, aws_secret_access_key=self.bedrock_sk)
420
 
421
- def encode(self, texts: list, batch_size=16):
422
  texts = [truncate(t, 8196) for t in texts]
423
  embeddings = []
424
  token_count = 0
@@ -436,7 +463,6 @@ class BedrockEmbed(Base):
436
  return np.array(embeddings), token_count
437
 
438
  def encode_queries(self, text):
439
-
440
  embeddings = []
441
  token_count = num_tokens_from_string(text)
442
  if self.model_name.split('.')[0] == 'amazon':
@@ -453,20 +479,26 @@ class BedrockEmbed(Base):
453
  class GeminiEmbed(Base):
454
  def __init__(self, key, model_name='models/text-embedding-004',
455
  **kwargs):
456
- genai.configure(api_key=key)
457
  self.model_name = 'models/' + model_name
458
 
459
- def encode(self, texts: list, batch_size=16):
460
  texts = [truncate(t, 2048) for t in texts]
461
  token_count = sum(num_tokens_from_string(text) for text in texts)
462
- result = genai.embed_content(
463
- model=self.model_name,
464
- content=texts,
465
- task_type="retrieval_document",
466
- title="Embedding of list of strings")
467
- return np.array(result['embedding']),token_count
 
 
 
 
 
468
 
469
  def encode_queries(self, text):
 
470
  result = genai.embed_content(
471
  model=self.model_name,
472
  content=truncate(text,2048),
@@ -495,19 +527,22 @@ class NvidiaEmbed(Base):
495
  if model_name == "snowflake/arctic-embed-l":
496
  self.base_url = "https://ai.api.nvidia.com/v1/retrieval/snowflake/arctic-embed-l/embeddings"
497
 
498
- def encode(self, texts: list, batch_size=None):
499
- payload = {
500
- "input": texts,
501
- "input_type": "query",
502
- "model": self.model_name,
503
- "encoding_format": "float",
504
- "truncate": "END",
505
- }
506
- res = requests.post(self.base_url, headers=self.headers, json=payload).json()
507
- return (
508
- np.array([d["embedding"] for d in res["data"]]),
509
- res["usage"]["total_tokens"],
510
- )
 
 
 
511
 
512
  def encode_queries(self, text):
513
  embds, cnt = self.encode([text])
@@ -541,16 +576,20 @@ class CoHereEmbed(Base):
541
  self.client = Client(api_key=key)
542
  self.model_name = model_name
543
 
544
- def encode(self, texts: list, batch_size=16):
545
- res = self.client.embed(
546
- texts=texts,
547
- model=self.model_name,
548
- input_type="search_query",
549
- embedding_types=["float"],
550
- )
551
- return np.array([d for d in res.embeddings.float]), int(
552
- res.meta.billed_units.input_tokens
553
- )
 
 
 
 
554
 
555
  def encode_queries(self, text):
556
  res = self.client.embed(
@@ -599,19 +638,23 @@ class SILICONFLOWEmbed(Base):
599
  self.base_url = base_url
600
  self.model_name = model_name
601
 
602
- def encode(self, texts: list, batch_size=16):
603
- payload = {
604
- "model": self.model_name,
605
- "input": texts,
606
- "encoding_format": "float",
607
- }
608
- res = requests.post(self.base_url, json=payload, headers=self.headers).json()
609
- if "data" not in res or not isinstance(res["data"], list) or len(res["data"])!= len(texts):
610
- raise ValueError(f"SILICONFLOWEmbed.encode got invalid response from {self.base_url}")
611
- return (
612
- np.array([d["embedding"] for d in res["data"]]),
613
- res["usage"]["total_tokens"],
614
- )
 
 
 
 
615
 
616
  def encode_queries(self, text):
617
  payload = {
@@ -632,9 +675,14 @@ class ReplicateEmbed(Base):
632
  self.model_name = model_name
633
  self.client = Client(api_token=key)
634
 
635
- def encode(self, texts: list, batch_size=16):
636
- res = self.client.run(self.model_name, input={"texts": json.dumps(texts)})
637
- return np.array(res), sum([num_tokens_from_string(text) for text in texts])
 
 
 
 
 
638
 
639
  def encode_queries(self, text):
640
  res = self.client.embed(self.model_name, input={"texts": [text]})
@@ -673,11 +721,17 @@ class VoyageEmbed(Base):
673
  self.client = voyageai.Client(api_key=key)
674
  self.model_name = model_name
675
 
676
- def encode(self, texts: list, batch_size=16):
677
- res = self.client.embed(
678
- texts=texts, model=self.model_name, input_type="document"
679
- )
680
- return np.array(res.embeddings), res.total_tokens
 
 
 
 
 
 
681
 
682
  def encode_queries(self, text):
683
  res = self.client.embed(
@@ -694,7 +748,7 @@ class HuggingFaceEmbed(Base):
694
  self.model_name = model_name
695
  self.base_url = base_url or "http://127.0.0.1:8080"
696
 
697
- def encode(self, texts: list, batch_size=16):
698
  embeddings = []
699
  for text in texts:
700
  response = requests.post(
 
38
  def __init__(self, key, model_name):
39
  pass
40
 
41
+ def encode(self, texts: list):
42
  raise NotImplementedError("Please implement encode method!")
43
 
44
  def encode_queries(self, text: str):
 
78
  use_fp16=torch.cuda.is_available())
79
  self._model = DefaultEmbedding._model
80
 
81
+ def encode(self, texts: list):
82
+ batch_size = 16
83
  texts = [truncate(t, 2048) for t in texts]
84
  token_count = 0
85
  for t in texts:
86
  token_count += num_tokens_from_string(t)
87
+ ress = []
88
  for i in range(0, len(texts), batch_size):
89
+ ress.extend(self._model.encode(texts[i:i + batch_size]).tolist())
90
+ return np.array(ress), token_count
91
 
92
  def encode_queries(self, text: str):
93
  token_count = num_tokens_from_string(text)
 
102
  self.client = OpenAI(api_key=key, base_url=base_url)
103
  self.model_name = model_name
104
 
105
+ def encode(self, texts: list):
106
+ # OpenAI requires batch size <=16
107
+ batch_size = 16
108
  texts = [truncate(t, 8191) for t in texts]
109
+ ress = []
110
+ total_tokens = 0
111
+ for i in range(0, len(texts), batch_size):
112
+ res = self.client.embeddings.create(input=texts[i:i + batch_size],
113
+ model=self.model_name)
114
+ ress.extend([d.embedding for d in res.data])
115
+ total_tokens += res.usage.total_tokens
116
+ return np.array(ress), total_tokens
117
 
118
  def encode_queries(self, text):
119
  res = self.client.embeddings.create(input=[truncate(text, 8191)],
 
130
  self.client = OpenAI(api_key="empty", base_url=base_url)
131
  self.model_name = model_name.split("___")[0]
132
 
133
+ def encode(self, texts: list):
134
+ batch_size = 16
135
+ ress = []
136
+ for i in range(0, len(texts), batch_size):
137
+ res = self.client.embeddings.create(input=texts[i:i + batch_size], model=self.model_name)
138
+ ress.extend([d.embedding for d in res.data])
139
+ # local embedding for LmStudio does not count tokens
140
+ return np.array(ress), 1024
141
 
142
  def encode_queries(self, text):
143
  embds, cnt = self.encode([text])
 
164
 
165
  class QWenEmbed(Base):
166
  def __init__(self, key, model_name="text_embedding_v2", **kwargs):
167
+ self.key = key
168
  self.model_name = model_name
169
 
170
+ def encode(self, texts: list):
171
  import dashscope
172
+ batch_size = 4
173
  try:
174
  res = []
175
  token_count = 0
 
178
  resp = dashscope.TextEmbedding.call(
179
  model=self.model_name,
180
  input=texts[i:i + batch_size],
181
+ api_key=self.key,
182
  text_type="document"
183
  )
184
  embds = [[] for _ in range(len(resp["output"]["embeddings"]))]
 
196
  resp = dashscope.TextEmbedding.call(
197
  model=self.model_name,
198
  input=text[:2048],
199
+ api_key=self.key,
200
  text_type="query"
201
  )
202
  return np.array(resp["output"]["embeddings"][0]
 
211
  self.client = ZhipuAI(api_key=key)
212
  self.model_name = model_name
213
 
214
+ def encode(self, texts: list):
215
  arr = []
216
  tks_num = 0
217
  for txt in texts:
 
232
  self.client = Client(host=kwargs["base_url"])
233
  self.model_name = model_name
234
 
235
+ def encode(self, texts: list):
236
  arr = []
237
  tks_num = 0
238
  for txt in texts:
 
263
  from fastembed import TextEmbedding
264
  self._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
265
 
266
+ def encode(self, texts: list):
267
  # Using the internal tokenizer to encode the texts and get the total
268
  # number of tokens
269
  encodings = self._model.model.tokenizer.encode_batch(texts)
270
  total_tokens = sum(len(e) for e in encodings)
271
 
272
+ embeddings = [e.tolist() for e in self._model.embed(texts, batch_size=16)]
273
 
274
  return np.array(embeddings), total_tokens
275
 
 
289
  self.client = OpenAI(api_key=key, base_url=base_url)
290
  self.model_name = model_name
291
 
292
+ def encode(self, texts: list):
293
+ batch_size = 16
294
+ ress = []
295
+ total_tokens = 0
296
+ for i in range(0, len(texts), batch_size):
297
+ res = self.client.embeddings.create(input=texts[i:i + batch_size], model=self.model_name)
298
+ ress.extend([d.embedding for d in res.data])
299
+ total_tokens += res.usage.total_tokens
300
+ return np.array(ress), total_tokens
301
 
302
  def encode_queries(self, text):
303
  res = self.client.embeddings.create(input=[text],
 
321
  model_name_or_path=model_name.replace(
322
  "maidalun1020", "InfiniFlow"))
323
 
324
+ def encode(self, texts: list):
325
+ batch_size = 10
326
  res = []
327
  token_count = 0
328
  for t in texts:
 
348
  }
349
  self.model_name = model_name
350
 
351
+ def encode(self, texts: list):
352
  texts = [truncate(t, 8196) for t in texts]
353
+ batch_size = 16
354
+ ress = []
355
+ token_count = 0
356
+ for i in range(0, len(texts), batch_size):
357
+ data = {
358
+ "model": self.model_name,
359
+ "input": texts[i:i + batch_size],
360
+ 'encoding_type': 'float'
361
+ }
362
+ res = requests.post(self.base_url, headers=self.headers, json=data).json()
363
+ ress.extend([d["embedding"] for d in res["data"]])
364
+ token_count += res["usage"]["total_tokens"]
365
+ return np.array(ress), token_count
366
 
367
  def encode_queries(self, text):
368
  embds, cnt = self.encode([text])
 
416
  self.client = MistralClient(api_key=key)
417
  self.model_name = model_name
418
 
419
+ def encode(self, texts: list):
420
  texts = [truncate(t, 8196) for t in texts]
421
+ batch_size = 16
422
+ ress = []
423
+ token_count = 0
424
+ for i in range(0, len(texts), batch_size):
425
+ res = self.client.embeddings(input=texts[i:i + batch_size],
426
+ model=self.model_name)
427
+ ress.extend([d.embedding for d in res.data])
428
+ token_count += res.usage.total_tokens
429
+ return np.array(ress), token_count
430
 
431
  def encode_queries(self, text):
432
  res = self.client.embeddings(input=[truncate(text, 8196)],
 
445
  self.client = boto3.client(service_name='bedrock-runtime', region_name=self.bedrock_region,
446
  aws_access_key_id=self.bedrock_ak, aws_secret_access_key=self.bedrock_sk)
447
 
448
+ def encode(self, texts: list):
449
  texts = [truncate(t, 8196) for t in texts]
450
  embeddings = []
451
  token_count = 0
 
463
  return np.array(embeddings), token_count
464
 
465
  def encode_queries(self, text):
 
466
  embeddings = []
467
  token_count = num_tokens_from_string(text)
468
  if self.model_name.split('.')[0] == 'amazon':
 
479
  class GeminiEmbed(Base):
480
  def __init__(self, key, model_name='models/text-embedding-004',
481
  **kwargs):
482
+ self.key = key
483
  self.model_name = 'models/' + model_name
484
 
485
+ def encode(self, texts: list):
486
  texts = [truncate(t, 2048) for t in texts]
487
  token_count = sum(num_tokens_from_string(text) for text in texts)
488
+ genai.configure(api_key=self.key)
489
+ batch_size = 16
490
+ ress = []
491
+ for i in range(0, len(texts), batch_size):
492
+ result = genai.embed_content(
493
+ model=self.model_name,
494
+ content=texts[i:i + batch_size],
495
+ task_type="retrieval_document",
496
+ title="Embedding of single string")
497
+ ress.extend(result['embedding'])
498
+ return np.array(ress),token_count
499
 
500
  def encode_queries(self, text):
501
+ genai.configure(api_key=self.key)
502
  result = genai.embed_content(
503
  model=self.model_name,
504
  content=truncate(text,2048),
 
527
  if model_name == "snowflake/arctic-embed-l":
528
  self.base_url = "https://ai.api.nvidia.com/v1/retrieval/snowflake/arctic-embed-l/embeddings"
529
 
530
+ def encode(self, texts: list):
531
+ batch_size = 16
532
+ ress = []
533
+ token_count = 0
534
+ for i in range(0, len(texts), batch_size):
535
+ payload = {
536
+ "input": texts[i : i + batch_size],
537
+ "input_type": "query",
538
+ "model": self.model_name,
539
+ "encoding_format": "float",
540
+ "truncate": "END",
541
+ }
542
+ res = requests.post(self.base_url, headers=self.headers, json=payload).json()
543
+ ress.extend([d["embedding"] for d in res["data"]])
544
+ token_count += res["usage"]["total_tokens"]
545
+ return np.array(ress), token_count
546
 
547
  def encode_queries(self, text):
548
  embds, cnt = self.encode([text])
 
576
  self.client = Client(api_key=key)
577
  self.model_name = model_name
578
 
579
+ def encode(self, texts: list):
580
+ batch_size = 16
581
+ ress = []
582
+ token_count = 0
583
+ for i in range(0, len(texts), batch_size):
584
+ res = self.client.embed(
585
+ texts=texts[i : i + batch_size],
586
+ model=self.model_name,
587
+ input_type="search_document",
588
+ embedding_types=["float"],
589
+ )
590
+ ress.extend([d for d in res.embeddings.float])
591
+ token_count += res.meta.billed_units.input_tokens
592
+ return np.array(ress), token_count
593
 
594
  def encode_queries(self, text):
595
  res = self.client.embed(
 
638
  self.base_url = base_url
639
  self.model_name = model_name
640
 
641
+ def encode(self, texts: list):
642
+ batch_size = 16
643
+ ress = []
644
+ token_count = 0
645
+ for i in range(0, len(texts), batch_size):
646
+ texts_batch = texts[i : i + batch_size]
647
+ payload = {
648
+ "model": self.model_name,
649
+ "input": texts_batch,
650
+ "encoding_format": "float",
651
+ }
652
+ res = requests.post(self.base_url, json=payload, headers=self.headers).json()
653
+ if "data" not in res or not isinstance(res["data"], list) or len(res["data"]) != len(texts_batch):
654
+ raise ValueError(f"SILICONFLOWEmbed.encode got invalid response from {self.base_url}")
655
+ ress.extend([d["embedding"] for d in res["data"]])
656
+ token_count += res["usage"]["total_tokens"]
657
+ return np.array(ress), token_count
658
 
659
  def encode_queries(self, text):
660
  payload = {
 
675
  self.model_name = model_name
676
  self.client = Client(api_token=key)
677
 
678
+ def encode(self, texts: list):
679
+ batch_size = 16
680
+ token_count = sum([num_tokens_from_string(text) for text in texts])
681
+ ress = []
682
+ for i in range(0, len(texts), batch_size):
683
+ res = self.client.run(self.model_name, input={"texts": texts[i : i + batch_size]})
684
+ ress.extend(res)
685
+ return np.array(ress), token_count
686
 
687
  def encode_queries(self, text):
688
  res = self.client.embed(self.model_name, input={"texts": [text]})
 
721
  self.client = voyageai.Client(api_key=key)
722
  self.model_name = model_name
723
 
724
+ def encode(self, texts: list):
725
+ batch_size = 16
726
+ ress = []
727
+ token_count = 0
728
+ for i in range(0, len(texts), batch_size):
729
+ res = self.client.embed(
730
+ texts=texts[i : i + batch_size], model=self.model_name, input_type="document"
731
+ )
732
+ ress.extend(res.embeddings)
733
+ token_count += res.total_tokens
734
+ return np.array(ress), token_count
735
 
736
  def encode_queries(self, text):
737
  res = self.client.embed(
 
748
  self.model_name = model_name
749
  self.base_url = base_url or "http://127.0.0.1:8080"
750
 
751
+ def encode(self, texts: list):
752
  embeddings = []
753
  for text in texts:
754
  response = requests.post(