KevinHuSh
committed on
Commit
·
c60dccb
1
Parent(s):
d42f535
fix #994 (#1006)
Browse files
### What problem does this PR solve?
#994
### Type of change
- [x] Bug Fix (non-breaking change which fixes an issue)
- rag/llm/embedding_model.py +29 -21
rag/llm/embedding_model.py
CHANGED
@@ -123,30 +123,38 @@ class QWenEmbed(Base):
|
|
123 |
|
124 |
def encode(self, texts: list, batch_size=10):
|
125 |
import dashscope
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
130 |
resp = dashscope.TextEmbedding.call(
|
131 |
model=self.model_name,
|
132 |
-
input=
|
133 |
-
text_type="
|
134 |
)
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
return np.array(res), token_count
|
141 |
-
|
142 |
-
def encode_queries(self, text):
|
143 |
-
resp = dashscope.TextEmbedding.call(
|
144 |
-
model=self.model_name,
|
145 |
-
input=text[:2048],
|
146 |
-
text_type="query"
|
147 |
-
)
|
148 |
-
return np.array(resp["output"]["embeddings"][0]
|
149 |
-
["embedding"]), resp["usage"]["total_tokens"]
|
150 |
|
151 |
|
152 |
class ZhipuEmbed(Base):
|
|
|
123 |
|
124 |
def encode(self, texts: list, batch_size=10):
|
125 |
import dashscope
|
126 |
+
try:
|
127 |
+
res = []
|
128 |
+
token_count = 0
|
129 |
+
texts = [truncate(t, 2048) for t in texts]
|
130 |
+
for i in range(0, len(texts), batch_size):
|
131 |
+
resp = dashscope.TextEmbedding.call(
|
132 |
+
model=self.model_name,
|
133 |
+
input=texts[i:i + batch_size],
|
134 |
+
text_type="document"
|
135 |
+
)
|
136 |
+
embds = [[] for _ in range(len(resp["output"]["embeddings"]))]
|
137 |
+
for e in resp["output"]["embeddings"]:
|
138 |
+
embds[e["text_index"]] = e["embedding"]
|
139 |
+
res.extend(embds)
|
140 |
+
token_count += resp["usage"]["total_tokens"]
|
141 |
+
return np.array(res), token_count
|
142 |
+
except Exception as e:
|
143 |
+
raise Exception("Account abnormal. Please ensure it's on good standing.")
|
144 |
+
return np.array([]), 0
|
145 |
+
|
146 |
+
def encode_queries(self, text):
|
147 |
+
try:
|
148 |
resp = dashscope.TextEmbedding.call(
|
149 |
model=self.model_name,
|
150 |
+
input=text[:2048],
|
151 |
+
text_type="query"
|
152 |
)
|
153 |
+
return np.array(resp["output"]["embeddings"][0]
|
154 |
+
["embedding"]), resp["usage"]["total_tokens"]
|
155 |
+
except Exception as e:
|
156 |
+
raise Exception("Account abnormal. Please ensure it's on good standing.")
|
157 |
+
return np.array([]), 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
158 |
|
159 |
|
160 |
class ZhipuEmbed(Base):
|