KevinHuSh commited on
Commit
c60dccb
·
1 Parent(s): d42f535

fix #994 (#1006)

Browse files

### What problem does this PR solve?

#994

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

Files changed (1) hide show
  1. rag/llm/embedding_model.py +29 -21
rag/llm/embedding_model.py CHANGED
@@ -123,30 +123,38 @@ class QWenEmbed(Base):
123
 
124
  def encode(self, texts: list, batch_size=10):
125
  import dashscope
126
- res = []
127
- token_count = 0
128
- texts = [truncate(t, 2048) for t in texts]
129
- for i in range(0, len(texts), batch_size):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
  resp = dashscope.TextEmbedding.call(
131
  model=self.model_name,
132
- input=texts[i:i + batch_size],
133
- text_type="document"
134
  )
135
- embds = [[] for _ in range(len(resp["output"]["embeddings"]))]
136
- for e in resp["output"]["embeddings"]:
137
- embds[e["text_index"]] = e["embedding"]
138
- res.extend(embds)
139
- token_count += resp["usage"]["total_tokens"]
140
- return np.array(res), token_count
141
-
142
- def encode_queries(self, text):
143
- resp = dashscope.TextEmbedding.call(
144
- model=self.model_name,
145
- input=text[:2048],
146
- text_type="query"
147
- )
148
- return np.array(resp["output"]["embeddings"][0]
149
- ["embedding"]), resp["usage"]["total_tokens"]
150
 
151
 
152
  class ZhipuEmbed(Base):
 
123
 
124
  def encode(self, texts: list, batch_size=10):
125
  import dashscope
126
+ try:
127
+ res = []
128
+ token_count = 0
129
+ texts = [truncate(t, 2048) for t in texts]
130
+ for i in range(0, len(texts), batch_size):
131
+ resp = dashscope.TextEmbedding.call(
132
+ model=self.model_name,
133
+ input=texts[i:i + batch_size],
134
+ text_type="document"
135
+ )
136
+ embds = [[] for _ in range(len(resp["output"]["embeddings"]))]
137
+ for e in resp["output"]["embeddings"]:
138
+ embds[e["text_index"]] = e["embedding"]
139
+ res.extend(embds)
140
+ token_count += resp["usage"]["total_tokens"]
141
+ return np.array(res), token_count
142
+ except Exception as e:
143
+ raise Exception("Account abnormal. Please ensure it's on good standing.")
144
+ return np.array([]), 0
145
+
146
+ def encode_queries(self, text):
147
+ try:
148
  resp = dashscope.TextEmbedding.call(
149
  model=self.model_name,
150
+ input=text[:2048],
151
+ text_type="query"
152
  )
153
+ return np.array(resp["output"]["embeddings"][0]
154
+ ["embedding"]), resp["usage"]["total_tokens"]
155
+ except Exception as e:
156
+ raise Exception("Account abnormal. Please ensure it's on good standing.")
157
+ return np.array([]), 0
 
 
 
 
 
 
 
 
 
 
158
 
159
 
160
  class ZhipuEmbed(Base):