Kevin Hu
committed on
Commit
·
984f31c
1
Parent(s):
6e0d24d
Fix bugs in the rerank model with Xinference (#1481)
Browse files
### What problem does this PR solve?
### Type of change
- [x] Bug Fix (non-breaking change which fixes an issue)
- api/apps/llm_app.py +11 -0
- rag/llm/rerank_model.py +9 -7
api/apps/llm_app.py
CHANGED
@@ -165,6 +165,17 @@ def add_llm():
|
|
165 |
except Exception as e:
|
166 |
msg += f"\nFail to access model({llm['llm_name']})." + str(
|
167 |
e)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
168 |
else:
|
169 |
# TODO: check other type of models
|
170 |
pass
|
|
|
165 |
except Exception as e:
|
166 |
msg += f"\nFail to access model({llm['llm_name']})." + str(
|
167 |
e)
|
168 |
+
elif llm["model_type"] == LLMType.RERANK:
|
169 |
+
mdl = RerankModel[factory](
|
170 |
+
key=None, model_name=llm["llm_name"], base_url=llm["api_base"]
|
171 |
+
)
|
172 |
+
try:
|
173 |
+
arr, tc = mdl.similarity("Hello~ Ragflower!", ["Hi, there!"])
|
174 |
+
if len(arr) == 0 or tc == 0:
|
175 |
+
raise Exception("Not known.")
|
176 |
+
except Exception as e:
|
177 |
+
msg += f"\nFail to access model({llm['llm_name']})." + str(
|
178 |
+
e)
|
179 |
else:
|
180 |
# TODO: check other type of models
|
181 |
pass
|
rag/llm/rerank_model.py
CHANGED
@@ -136,10 +136,11 @@ class YoudaoRerank(DefaultRerank):
|
|
136 |
else: res.extend(scores)
|
137 |
return np.array(res), token_count
|
138 |
|
|
|
139 |
class XInferenceRerank(Base):
|
140 |
-
def __init__(self,model_name="",base_url=""):
|
141 |
-
self.model_name=model_name
|
142 |
-
self.base_url=base_url
|
143 |
self.headers = {
|
144 |
"Content-Type": "application/json",
|
145 |
"accept": "application/json"
|
@@ -147,11 +148,12 @@ class XInferenceRerank(Base):
|
|
147 |
|
148 |
def similarity(self, query: str, texts: list):
|
149 |
data = {
|
150 |
-
"model":self.model_name,
|
151 |
-
"query":query,
|
152 |
"return_documents": "true",
|
153 |
"return_len": "true",
|
154 |
-
"documents":texts
|
155 |
}
|
156 |
res = requests.post(self.base_url, headers=self.headers, json=data).json()
|
157 |
-
return np.array([d["relevance_score"] for d in res["results"]]),res["tokens"]["input_tokens"]+res["tokens"][
|
|
|
|
136 |
else: res.extend(scores)
|
137 |
return np.array(res), token_count
|
138 |
|
139 |
+
|
140 |
class XInferenceRerank(Base):
|
141 |
+
def __init__(self, key="xxxxxxx", model_name="", base_url=""):
|
142 |
+
self.model_name = model_name
|
143 |
+
self.base_url = base_url
|
144 |
self.headers = {
|
145 |
"Content-Type": "application/json",
|
146 |
"accept": "application/json"
|
|
|
148 |
|
149 |
def similarity(self, query: str, texts: list):
|
150 |
data = {
|
151 |
+
"model": self.model_name,
|
152 |
+
"query": query,
|
153 |
"return_documents": "true",
|
154 |
"return_len": "true",
|
155 |
+
"documents": texts
|
156 |
}
|
157 |
res = requests.post(self.base_url, headers=self.headers, json=data).json()
|
158 |
+
return np.array([d["relevance_score"] for d in res["results"]]), res["tokens"]["input_tokens"] + res["tokens"][
|
159 |
+
"output_tokens"]
|