KevinHuSh committed
Commit c87ddd7 · 1 Parent(s): 8f65b41

truncate text to fit in embedding model (#692)


### What problem does this PR solve?

Embedding inputs were clipped by character count (e.g. `t[:2000]`), which does not bound the number of tokens, so long texts could still exceed an embedding model's context limit. This change truncates by token count instead, via a new `truncate` helper in `rag/utils`.
### Type of change

- [x] Refactoring
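
To illustrate why character slicing can overflow a token limit, here is a minimal sketch of token-based truncation. It assumes a tiktoken encoder; `cl100k_base` is an illustrative choice, and the repo's shared `encoder` in `rag/utils` is assumed to behave similarly:

```python
# Minimal sketch: character slicing does not bound tokens, token slicing does.
# Assumption: a tiktoken encoder ("cl100k_base" chosen for illustration).
import tiktoken

encoder = tiktoken.get_encoding("cl100k_base")

def truncate(string: str, max_len: int) -> str:
    """Encode, keep at most max_len tokens, decode back to text."""
    return encoder.decode(encoder.encode(string)[:max_len])

text = "🦉" * 4000                       # 4000 characters, several tokens each
print(len(encoder.encode(text[:2000])))           # can be far above 2000 tokens
print(len(encoder.encode(truncate(text, 2048))))  # stays around or below 2048
```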

rag/llm/embedding_model.py CHANGED
@@ -27,8 +27,7 @@ import torch
 import numpy as np
 
 from api.utils.file_utils import get_project_base_directory, get_home_cache_dir
-from rag.utils import num_tokens_from_string
-
+from rag.utils import num_tokens_from_string, truncate
 
 try:
     flag_model = FlagModel(os.path.join(get_home_cache_dir(), "bge-large-zh-v1.5"),
@@ -70,7 +69,7 @@ class DefaultEmbedding(Base):
         self.model = flag_model
 
     def encode(self, texts: list, batch_size=32):
-        texts = [t[:2000] for t in texts]
+        texts = [truncate(t, 2048) for t in texts]
         token_count = 0
         for t in texts:
             token_count += num_tokens_from_string(t)
@@ -93,12 +92,14 @@ class OpenAIEmbed(Base):
         self.model_name = model_name
 
     def encode(self, texts: list, batch_size=32):
+        texts = [truncate(t, 8196) for t in texts]
         res = self.client.embeddings.create(input=texts,
                                             model=self.model_name)
-        return np.array([d.embedding for d in res.data]), res.usage.total_tokens
+        return np.array([d.embedding for d in res.data]
+                        ), res.usage.total_tokens
 
     def encode_queries(self, text):
-        res = self.client.embeddings.create(input=[text],
+        res = self.client.embeddings.create(input=[truncate(text, 8196)],
                                             model=self.model_name)
         return np.array(res.data[0].embedding), res.usage.total_tokens
 
@@ -112,7 +113,7 @@ class QWenEmbed(Base):
         import dashscope
         res = []
         token_count = 0
-        texts = [txt[:2048] for txt in texts]
+        texts = [truncate(t, 2048) for t in texts]
        for i in range(0, len(texts), batch_size):
            resp = dashscope.TextEmbedding.call(
                model=self.model_name,
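
All three `encode` paths now share the same shape: clamp every input to the model's token budget, then embed batch by batch. Here is a sketch of that pattern, assuming the repo is importable; `embed_batch` is a stub standing in for the real provider call (OpenAI, DashScope, or the local FlagModel), not any provider's actual API:

```python
# Sketch of the shared pattern after this PR, not a real provider API.
from typing import List, Tuple
from rag.utils import truncate  # the helper added in this PR

def embed_batch(batch: List[str]) -> Tuple[List[List[float]], int]:
    """Stub provider call: one dummy vector per text, plus token usage."""
    return [[0.0] * 1024 for _ in batch], sum(len(t.split()) for t in batch)

def encode(texts: List[str], max_tokens: int, batch_size: int = 32):
    texts = [truncate(t, max_tokens) for t in texts]  # 2048 or 8196 in the diff
    vectors, token_count = [], 0
    for i in range(0, len(texts), batch_size):
        vecs, used = embed_batch(texts[i:i + batch_size])
        vectors.extend(vecs)
        token_count += used
    return vectors, token_count
```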
rag/utils/__init__.py CHANGED
@@ -63,3 +63,7 @@ def num_tokens_from_string(string: str) -> int:
     num_tokens = len(encoder.encode(string))
     return num_tokens
 
+
+def truncate(string: str, max_len: int) -> str:
+    """Returns truncated text if the length of the text exceeds max_len."""
+    return encoder.decode(encoder.encode(string)[:max_len])
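
With the repo on the import path, the new helper can be sanity-checked directly; re-encoding the decoded slice should land at or just under the budget:

```python
# Quick sanity check of the new helper (run from the repo root).
from rag.utils import truncate, num_tokens_from_string

s = "token-based truncation " * 5000
t = truncate(s, 2048)
print(num_tokens_from_string(t))  # expected: at most ~2048
print(len(t) < len(s))            # True: the text was actually shortened
```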