Kevin Hu committed on
Commit
1dc3f10
·
1 Parent(s): cd19d72

Refactor for total_tokens. (#4652)

Browse files

### What problem does this PR solve?

#4567
### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

rag/llm/chat_model.py CHANGED
@@ -53,7 +53,7 @@ class Base(ABC):
53
  ans += LENGTH_NOTIFICATION_CN
54
  else:
55
  ans += LENGTH_NOTIFICATION_EN
56
- return ans, response.usage.total_tokens
57
  except openai.APIError as e:
58
  return "**ERROR**: " + str(e), 0
59
 
@@ -75,15 +75,11 @@ class Base(ABC):
75
  resp.choices[0].delta.content = ""
76
  ans += resp.choices[0].delta.content
77
 
78
- if not hasattr(resp, "usage") or not resp.usage:
79
- total_tokens = (
80
- total_tokens
81
- + num_tokens_from_string(resp.choices[0].delta.content)
82
- )
83
- elif isinstance(resp.usage, dict):
84
- total_tokens = resp.usage.get("total_tokens", total_tokens)
85
  else:
86
- total_tokens = resp.usage.total_tokens
87
 
88
  if resp.choices[0].finish_reason == "length":
89
  if is_chinese(ans):
@@ -97,6 +93,17 @@ class Base(ABC):
97
 
98
  yield total_tokens
99
 
 
 
 
 
 
 
 
 
 
 
 
100
 
101
  class GptTurbo(Base):
102
  def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1"):
@@ -182,7 +189,7 @@ class BaiChuanChat(Base):
182
  ans += LENGTH_NOTIFICATION_CN
183
  else:
184
  ans += LENGTH_NOTIFICATION_EN
185
- return ans, response.usage.total_tokens
186
  except openai.APIError as e:
187
  return "**ERROR**: " + str(e), 0
188
 
@@ -212,14 +219,11 @@ class BaiChuanChat(Base):
212
  if not resp.choices[0].delta.content:
213
  resp.choices[0].delta.content = ""
214
  ans += resp.choices[0].delta.content
215
- total_tokens = (
216
- (
217
- total_tokens
218
- + num_tokens_from_string(resp.choices[0].delta.content)
219
- )
220
- if not hasattr(resp, "usage")
221
- else resp.usage["total_tokens"]
222
- )
223
  if resp.choices[0].finish_reason == "length":
224
  if is_chinese([ans]):
225
  ans += LENGTH_NOTIFICATION_CN
@@ -256,7 +260,7 @@ class QWenChat(Base):
256
  tk_count = 0
257
  if response.status_code == HTTPStatus.OK:
258
  ans += response.output.choices[0]['message']['content']
259
- tk_count += response.usage.total_tokens
260
  if response.output.choices[0].get("finish_reason", "") == "length":
261
  if is_chinese([ans]):
262
  ans += LENGTH_NOTIFICATION_CN
@@ -292,7 +296,7 @@ class QWenChat(Base):
292
  for resp in response:
293
  if resp.status_code == HTTPStatus.OK:
294
  ans = resp.output.choices[0]['message']['content']
295
- tk_count = resp.usage.total_tokens
296
  if resp.output.choices[0].get("finish_reason", "") == "length":
297
  if is_chinese(ans):
298
  ans += LENGTH_NOTIFICATION_CN
@@ -334,7 +338,7 @@ class ZhipuChat(Base):
334
  ans += LENGTH_NOTIFICATION_CN
335
  else:
336
  ans += LENGTH_NOTIFICATION_EN
337
- return ans, response.usage.total_tokens
338
  except Exception as e:
339
  return "**ERROR**: " + str(e), 0
340
 
@@ -364,9 +368,9 @@ class ZhipuChat(Base):
364
  ans += LENGTH_NOTIFICATION_CN
365
  else:
366
  ans += LENGTH_NOTIFICATION_EN
367
- tk_count = resp.usage.total_tokens
368
  if resp.choices[0].finish_reason == "stop":
369
- tk_count = resp.usage.total_tokens
370
  yield ans
371
  except Exception as e:
372
  yield ans + "\n**ERROR**: " + str(e)
@@ -569,7 +573,7 @@ class MiniMaxChat(Base):
569
  ans += LENGTH_NOTIFICATION_CN
570
  else:
571
  ans += LENGTH_NOTIFICATION_EN
572
- return ans, response["usage"]["total_tokens"]
573
  except Exception as e:
574
  return "**ERROR**: " + str(e), 0
575
 
@@ -603,11 +607,11 @@ class MiniMaxChat(Base):
603
  if "choices" in resp and "delta" in resp["choices"][0]:
604
  text = resp["choices"][0]["delta"]["content"]
605
  ans += text
606
- total_tokens = (
607
- total_tokens + num_tokens_from_string(text)
608
- if "usage" not in resp
609
- else resp["usage"]["total_tokens"]
610
- )
611
  yield ans
612
 
613
  except Exception as e:
@@ -640,7 +644,7 @@ class MistralChat(Base):
640
  ans += LENGTH_NOTIFICATION_CN
641
  else:
642
  ans += LENGTH_NOTIFICATION_EN
643
- return ans, response.usage.total_tokens
644
  except openai.APIError as e:
645
  return "**ERROR**: " + str(e), 0
646
 
@@ -838,7 +842,7 @@ class GeminiChat(Base):
838
  yield 0
839
 
840
 
841
- class GroqChat:
842
  def __init__(self, key, model_name, base_url=''):
843
  from groq import Groq
844
  self.client = Groq(api_key=key)
@@ -863,7 +867,7 @@ class GroqChat:
863
  ans += LENGTH_NOTIFICATION_CN
864
  else:
865
  ans += LENGTH_NOTIFICATION_EN
866
- return ans, response.usage.total_tokens
867
  except Exception as e:
868
  return ans + "\n**ERROR**: " + str(e), 0
869
 
@@ -1255,7 +1259,7 @@ class BaiduYiyanChat(Base):
1255
  **gen_conf
1256
  ).body
1257
  ans = response['result']
1258
- return ans, response["usage"]["total_tokens"]
1259
 
1260
  except Exception as e:
1261
  return ans + "\n**ERROR**: " + str(e), 0
@@ -1283,7 +1287,7 @@ class BaiduYiyanChat(Base):
1283
  for resp in response:
1284
  resp = resp.body
1285
  ans += resp['result']
1286
- total_tokens = resp["usage"]["total_tokens"]
1287
 
1288
  yield ans
1289
 
 
53
  ans += LENGTH_NOTIFICATION_CN
54
  else:
55
  ans += LENGTH_NOTIFICATION_EN
56
+ return ans, self.total_token_count(response)
57
  except openai.APIError as e:
58
  return "**ERROR**: " + str(e), 0
59
 
 
75
  resp.choices[0].delta.content = ""
76
  ans += resp.choices[0].delta.content
77
 
78
+ tol = self.total_token_count(resp)
79
+ if not tol:
80
+ total_tokens += num_tokens_from_string(resp.choices[0].delta.content)
 
 
 
 
81
  else:
82
+ total_tokens = tol
83
 
84
  if resp.choices[0].finish_reason == "length":
85
  if is_chinese(ans):
 
93
 
94
  yield total_tokens
95
 
96
+ def total_token_count(self, resp):
97
+ try:
98
+ return resp.usage.total_tokens
99
+ except Exception:
100
+ pass
101
+ try:
102
+ return resp["usage"]["total_tokens"]
103
+ except Exception:
104
+ pass
105
+ return 0
106
+
107
 
108
  class GptTurbo(Base):
109
  def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1"):
 
189
  ans += LENGTH_NOTIFICATION_CN
190
  else:
191
  ans += LENGTH_NOTIFICATION_EN
192
+ return ans, self.total_token_count(response)
193
  except openai.APIError as e:
194
  return "**ERROR**: " + str(e), 0
195
 
 
219
  if not resp.choices[0].delta.content:
220
  resp.choices[0].delta.content = ""
221
  ans += resp.choices[0].delta.content
222
+ tol = self.total_token_count(resp)
223
+ if not tol:
224
+ total_tokens += num_tokens_from_string(resp.choices[0].delta.content)
225
+ else:
226
+ total_tokens = tol
 
 
 
227
  if resp.choices[0].finish_reason == "length":
228
  if is_chinese([ans]):
229
  ans += LENGTH_NOTIFICATION_CN
 
260
  tk_count = 0
261
  if response.status_code == HTTPStatus.OK:
262
  ans += response.output.choices[0]['message']['content']
263
+ tk_count += self.total_token_count(response)
264
  if response.output.choices[0].get("finish_reason", "") == "length":
265
  if is_chinese([ans]):
266
  ans += LENGTH_NOTIFICATION_CN
 
296
  for resp in response:
297
  if resp.status_code == HTTPStatus.OK:
298
  ans = resp.output.choices[0]['message']['content']
299
+ tk_count = self.total_token_count(resp)
300
  if resp.output.choices[0].get("finish_reason", "") == "length":
301
  if is_chinese(ans):
302
  ans += LENGTH_NOTIFICATION_CN
 
338
  ans += LENGTH_NOTIFICATION_CN
339
  else:
340
  ans += LENGTH_NOTIFICATION_EN
341
+ return ans, self.total_token_count(response)
342
  except Exception as e:
343
  return "**ERROR**: " + str(e), 0
344
 
 
368
  ans += LENGTH_NOTIFICATION_CN
369
  else:
370
  ans += LENGTH_NOTIFICATION_EN
371
+ tk_count = self.total_token_count(resp)
372
  if resp.choices[0].finish_reason == "stop":
373
+ tk_count = self.total_token_count(resp)
374
  yield ans
375
  except Exception as e:
376
  yield ans + "\n**ERROR**: " + str(e)
 
573
  ans += LENGTH_NOTIFICATION_CN
574
  else:
575
  ans += LENGTH_NOTIFICATION_EN
576
+ return ans, self.total_token_count(response)
577
  except Exception as e:
578
  return "**ERROR**: " + str(e), 0
579
 
 
607
  if "choices" in resp and "delta" in resp["choices"][0]:
608
  text = resp["choices"][0]["delta"]["content"]
609
  ans += text
610
+ tol = self.total_token_count(resp)
611
+ if not tol:
612
+ total_tokens += num_tokens_from_string(text)
613
+ else:
614
+ total_tokens = tol
615
  yield ans
616
 
617
  except Exception as e:
 
644
  ans += LENGTH_NOTIFICATION_CN
645
  else:
646
  ans += LENGTH_NOTIFICATION_EN
647
+ return ans, self.total_token_count(response)
648
  except openai.APIError as e:
649
  return "**ERROR**: " + str(e), 0
650
 
 
842
  yield 0
843
 
844
 
845
+ class GroqChat(Base):
846
  def __init__(self, key, model_name, base_url=''):
847
  from groq import Groq
848
  self.client = Groq(api_key=key)
 
867
  ans += LENGTH_NOTIFICATION_CN
868
  else:
869
  ans += LENGTH_NOTIFICATION_EN
870
+ return ans, self.total_token_count(response)
871
  except Exception as e:
872
  return ans + "\n**ERROR**: " + str(e), 0
873
 
 
1259
  **gen_conf
1260
  ).body
1261
  ans = response['result']
1262
+ return ans, self.total_token_count(response)
1263
 
1264
  except Exception as e:
1265
  return ans + "\n**ERROR**: " + str(e), 0
 
1287
  for resp in response:
1288
  resp = resp.body
1289
  ans += resp['result']
1290
+ total_tokens = self.total_token_count(resp)
1291
 
1292
  yield ans
1293
 
rag/llm/embedding_model.py CHANGED
@@ -44,11 +44,23 @@ class Base(ABC):
44
  def encode_queries(self, text: str):
45
  raise NotImplementedError("Please implement encode method!")
46
 
 
 
 
 
 
 
 
 
 
 
 
47
 
48
  class DefaultEmbedding(Base):
49
  _model = None
50
  _model_name = ""
51
  _model_lock = threading.Lock()
 
52
  def __init__(self, key, model_name, **kwargs):
53
  """
54
  If you have trouble downloading HuggingFace models, -_^ this might help!!
@@ -115,13 +127,13 @@ class OpenAIEmbed(Base):
115
  res = self.client.embeddings.create(input=texts[i:i + batch_size],
116
  model=self.model_name)
117
  ress.extend([d.embedding for d in res.data])
118
- total_tokens += res.usage.total_tokens
119
  return np.array(ress), total_tokens
120
 
121
  def encode_queries(self, text):
122
  res = self.client.embeddings.create(input=[truncate(text, 8191)],
123
  model=self.model_name)
124
- return np.array(res.data[0].embedding), res.usage.total_tokens
125
 
126
 
127
  class LocalAIEmbed(Base):
@@ -188,7 +200,7 @@ class QWenEmbed(Base):
188
  for e in resp["output"]["embeddings"]:
189
  embds[e["text_index"]] = e["embedding"]
190
  res.extend(embds)
191
- token_count += resp["usage"]["total_tokens"]
192
  return np.array(res), token_count
193
  except Exception as e:
194
  raise Exception("Account abnormal. Please ensure it's on good standing to use QWen's "+self.model_name)
@@ -203,7 +215,7 @@ class QWenEmbed(Base):
203
  text_type="query"
204
  )
205
  return np.array(resp["output"]["embeddings"][0]
206
- ["embedding"]), resp["usage"]["total_tokens"]
207
  except Exception:
208
  raise Exception("Account abnormal. Please ensure it's on good standing to use QWen's "+self.model_name)
209
  return np.array([]), 0
@@ -229,13 +241,13 @@ class ZhipuEmbed(Base):
229
  res = self.client.embeddings.create(input=txt,
230
  model=self.model_name)
231
  arr.append(res.data[0].embedding)
232
- tks_num += res.usage.total_tokens
233
  return np.array(arr), tks_num
234
 
235
  def encode_queries(self, text):
236
  res = self.client.embeddings.create(input=text,
237
  model=self.model_name)
238
- return np.array(res.data[0].embedding), res.usage.total_tokens
239
 
240
 
241
  class OllamaEmbed(Base):
@@ -318,13 +330,13 @@ class XinferenceEmbed(Base):
318
  for i in range(0, len(texts), batch_size):
319
  res = self.client.embeddings.create(input=texts[i:i + batch_size], model=self.model_name)
320
  ress.extend([d.embedding for d in res.data])
321
- total_tokens += res.usage.total_tokens
322
  return np.array(ress), total_tokens
323
 
324
  def encode_queries(self, text):
325
  res = self.client.embeddings.create(input=[text],
326
  model=self.model_name)
327
- return np.array(res.data[0].embedding), res.usage.total_tokens
328
 
329
 
330
  class YoudaoEmbed(Base):
@@ -383,7 +395,7 @@ class JinaEmbed(Base):
383
  }
384
  res = requests.post(self.base_url, headers=self.headers, json=data).json()
385
  ress.extend([d["embedding"] for d in res["data"]])
386
- token_count += res["usage"]["total_tokens"]
387
  return np.array(ress), token_count
388
 
389
  def encode_queries(self, text):
@@ -447,13 +459,13 @@ class MistralEmbed(Base):
447
  res = self.client.embeddings(input=texts[i:i + batch_size],
448
  model=self.model_name)
449
  ress.extend([d.embedding for d in res.data])
450
- token_count += res.usage.total_tokens
451
  return np.array(ress), token_count
452
 
453
  def encode_queries(self, text):
454
  res = self.client.embeddings(input=[truncate(text, 8196)],
455
  model=self.model_name)
456
- return np.array(res.data[0].embedding), res.usage.total_tokens
457
 
458
 
459
  class BedrockEmbed(Base):
@@ -565,7 +577,7 @@ class NvidiaEmbed(Base):
565
  }
566
  res = requests.post(self.base_url, headers=self.headers, json=payload).json()
567
  ress.extend([d["embedding"] for d in res["data"]])
568
- token_count += res["usage"]["total_tokens"]
569
  return np.array(ress), token_count
570
 
571
  def encode_queries(self, text):
@@ -677,7 +689,7 @@ class SILICONFLOWEmbed(Base):
677
  if "data" not in res or not isinstance(res["data"], list) or len(res["data"]) != len(texts_batch):
678
  raise ValueError(f"SILICONFLOWEmbed.encode got invalid response from {self.base_url}")
679
  ress.extend([d["embedding"] for d in res["data"]])
680
- token_count += res["usage"]["total_tokens"]
681
  return np.array(ress), token_count
682
 
683
  def encode_queries(self, text):
@@ -689,7 +701,7 @@ class SILICONFLOWEmbed(Base):
689
  res = requests.post(self.base_url, json=payload, headers=self.headers).json()
690
  if "data" not in res or not isinstance(res["data"], list) or len(res["data"])!= 1:
691
  raise ValueError(f"SILICONFLOWEmbed.encode_queries got invalid response from {self.base_url}")
692
- return np.array(res["data"][0]["embedding"]), res["usage"]["total_tokens"]
693
 
694
 
695
  class ReplicateEmbed(Base):
@@ -727,14 +739,14 @@ class BaiduYiyanEmbed(Base):
727
  res = self.client.do(model=self.model_name, texts=texts).body
728
  return (
729
  np.array([r["embedding"] for r in res["data"]]),
730
- res["usage"]["total_tokens"],
731
  )
732
 
733
  def encode_queries(self, text):
734
  res = self.client.do(model=self.model_name, texts=[text]).body
735
  return (
736
  np.array([r["embedding"] for r in res["data"]]),
737
- res["usage"]["total_tokens"],
738
  )
739
 
740
 
 
44
  def encode_queries(self, text: str):
45
  raise NotImplementedError("Please implement encode method!")
46
 
47
+ def total_token_count(self, resp):
48
+ try:
49
+ return resp.usage.total_tokens
50
+ except Exception:
51
+ pass
52
+ try:
53
+ return resp["usage"]["total_tokens"]
54
+ except Exception:
55
+ pass
56
+ return 0
57
+
58
 
59
  class DefaultEmbedding(Base):
60
  _model = None
61
  _model_name = ""
62
  _model_lock = threading.Lock()
63
+
64
  def __init__(self, key, model_name, **kwargs):
65
  """
66
  If you have trouble downloading HuggingFace models, -_^ this might help!!
 
127
  res = self.client.embeddings.create(input=texts[i:i + batch_size],
128
  model=self.model_name)
129
  ress.extend([d.embedding for d in res.data])
130
+ total_tokens += self.total_token_count(res)
131
  return np.array(ress), total_tokens
132
 
133
  def encode_queries(self, text):
134
  res = self.client.embeddings.create(input=[truncate(text, 8191)],
135
  model=self.model_name)
136
+ return np.array(res.data[0].embedding), self.total_token_count(res)
137
 
138
 
139
  class LocalAIEmbed(Base):
 
200
  for e in resp["output"]["embeddings"]:
201
  embds[e["text_index"]] = e["embedding"]
202
  res.extend(embds)
203
+ token_count += self.total_token_count(resp)
204
  return np.array(res), token_count
205
  except Exception as e:
206
  raise Exception("Account abnormal. Please ensure it's on good standing to use QWen's "+self.model_name)
 
215
  text_type="query"
216
  )
217
  return np.array(resp["output"]["embeddings"][0]
218
+ ["embedding"]), self.total_token_count(resp)
219
  except Exception:
220
  raise Exception("Account abnormal. Please ensure it's on good standing to use QWen's "+self.model_name)
221
  return np.array([]), 0
 
241
  res = self.client.embeddings.create(input=txt,
242
  model=self.model_name)
243
  arr.append(res.data[0].embedding)
244
+ tks_num += self.total_token_count(res)
245
  return np.array(arr), tks_num
246
 
247
  def encode_queries(self, text):
248
  res = self.client.embeddings.create(input=text,
249
  model=self.model_name)
250
+ return np.array(res.data[0].embedding), self.total_token_count(res)
251
 
252
 
253
  class OllamaEmbed(Base):
 
330
  for i in range(0, len(texts), batch_size):
331
  res = self.client.embeddings.create(input=texts[i:i + batch_size], model=self.model_name)
332
  ress.extend([d.embedding for d in res.data])
333
+ total_tokens += self.total_token_count(res)
334
  return np.array(ress), total_tokens
335
 
336
  def encode_queries(self, text):
337
  res = self.client.embeddings.create(input=[text],
338
  model=self.model_name)
339
+ return np.array(res.data[0].embedding), self.total_token_count(res)
340
 
341
 
342
  class YoudaoEmbed(Base):
 
395
  }
396
  res = requests.post(self.base_url, headers=self.headers, json=data).json()
397
  ress.extend([d["embedding"] for d in res["data"]])
398
+ token_count += self.total_token_count(res)
399
  return np.array(ress), token_count
400
 
401
  def encode_queries(self, text):
 
459
  res = self.client.embeddings(input=texts[i:i + batch_size],
460
  model=self.model_name)
461
  ress.extend([d.embedding for d in res.data])
462
+ token_count += self.total_token_count(res)
463
  return np.array(ress), token_count
464
 
465
  def encode_queries(self, text):
466
  res = self.client.embeddings(input=[truncate(text, 8196)],
467
  model=self.model_name)
468
+ return np.array(res.data[0].embedding), self.total_token_count(res)
469
 
470
 
471
  class BedrockEmbed(Base):
 
577
  }
578
  res = requests.post(self.base_url, headers=self.headers, json=payload).json()
579
  ress.extend([d["embedding"] for d in res["data"]])
580
+ token_count += self.total_token_count(res)
581
  return np.array(ress), token_count
582
 
583
  def encode_queries(self, text):
 
689
  if "data" not in res or not isinstance(res["data"], list) or len(res["data"]) != len(texts_batch):
690
  raise ValueError(f"SILICONFLOWEmbed.encode got invalid response from {self.base_url}")
691
  ress.extend([d["embedding"] for d in res["data"]])
692
+ token_count += self.total_token_count(res)
693
  return np.array(ress), token_count
694
 
695
  def encode_queries(self, text):
 
701
  res = requests.post(self.base_url, json=payload, headers=self.headers).json()
702
  if "data" not in res or not isinstance(res["data"], list) or len(res["data"])!= 1:
703
  raise ValueError(f"SILICONFLOWEmbed.encode_queries got invalid response from {self.base_url}")
704
+ return np.array(res["data"][0]["embedding"]), self.total_token_count(res)
705
 
706
 
707
  class ReplicateEmbed(Base):
 
739
  res = self.client.do(model=self.model_name, texts=texts).body
740
  return (
741
  np.array([r["embedding"] for r in res["data"]]),
742
+ self.total_token_count(res),
743
  )
744
 
745
  def encode_queries(self, text):
746
  res = self.client.do(model=self.model_name, texts=[text]).body
747
  return (
748
  np.array([r["embedding"] for r in res["data"]]),
749
+ self.total_token_count(res),
750
  )
751
 
752
 
rag/llm/rerank_model.py CHANGED
@@ -42,6 +42,17 @@ class Base(ABC):
42
  def similarity(self, query: str, texts: list):
43
  raise NotImplementedError("Please implement encode method!")
44
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
  class DefaultRerank(Base):
47
  _model = None
@@ -115,7 +126,7 @@ class JinaRerank(Base):
115
  rank = np.zeros(len(texts), dtype=float)
116
  for d in res["results"]:
117
  rank[d["index"]] = d["relevance_score"]
118
- return rank, res["usage"]["total_tokens"]
119
 
120
 
121
  class YoudaoRerank(DefaultRerank):
@@ -417,7 +428,7 @@ class BaiduYiyanRerank(Base):
417
  rank = np.zeros(len(texts), dtype=float)
418
  for d in res["results"]:
419
  rank[d["index"]] = d["relevance_score"]
420
- return rank, res["usage"]["total_tokens"]
421
 
422
 
423
  class VoyageRerank(Base):
 
42
  def similarity(self, query: str, texts: list):
43
  raise NotImplementedError("Please implement encode method!")
44
 
45
+ def total_token_count(self, resp):
46
+ try:
47
+ return resp.usage.total_tokens
48
+ except Exception:
49
+ pass
50
+ try:
51
+ return resp["usage"]["total_tokens"]
52
+ except Exception:
53
+ pass
54
+ return 0
55
+
56
 
57
  class DefaultRerank(Base):
58
  _model = None
 
126
  rank = np.zeros(len(texts), dtype=float)
127
  for d in res["results"]:
128
  rank[d["index"]] = d["relevance_score"]
129
+ return rank, self.total_token_count(res)
130
 
131
 
132
  class YoudaoRerank(DefaultRerank):
 
428
  rank = np.zeros(len(texts), dtype=float)
429
  for d in res["results"]:
430
  rank[d["index"]] = d["relevance_score"]
431
+ return rank, self.total_token_count(res)
432
 
433
 
434
  class VoyageRerank(Base):