Kevin Hu committed on
Commit
984f31c
·
1 Parent(s): 6e0d24d

Fix bugs in the rerank model with Xinference (#1481)

Browse files

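### What problem does this PR solve?

Adding a rerank model now runs a test `similarity` call in `add_llm` and reports a failure message if the model cannot be reached. `XInferenceRerank.__init__` now accepts a `key` argument, so it can be constructed with the `key=None, model_name=..., base_url=...` call used by that check; the previous signature did not accept `key`.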

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

Files changed (2) hide show
  1. api/apps/llm_app.py +11 -0
  2. rag/llm/rerank_model.py +9 -7
api/apps/llm_app.py CHANGED
@@ -165,6 +165,17 @@ def add_llm():
165
  except Exception as e:
166
  msg += f"\nFail to access model({llm['llm_name']})." + str(
167
  e)
 
 
 
 
 
 
 
 
 
 
 
168
  else:
169
  # TODO: check other type of models
170
  pass
 
165
  except Exception as e:
166
  msg += f"\nFail to access model({llm['llm_name']})." + str(
167
  e)
168
+ elif llm["model_type"] == LLMType.RERANK:
169
+ mdl = RerankModel[factory](
170
+ key=None, model_name=llm["llm_name"], base_url=llm["api_base"]
171
+ )
172
+ try:
173
+ arr, tc = mdl.similarity("Hello~ Ragflower!", ["Hi, there!"])
174
+ if len(arr) == 0 or tc == 0:
175
+ raise Exception("Not known.")
176
+ except Exception as e:
177
+ msg += f"\nFail to access model({llm['llm_name']})." + str(
178
+ e)
179
  else:
180
  # TODO: check other type of models
181
  pass
rag/llm/rerank_model.py CHANGED
@@ -136,10 +136,11 @@ class YoudaoRerank(DefaultRerank):
136
  else: res.extend(scores)
137
  return np.array(res), token_count
138
 
 
139
  class XInferenceRerank(Base):
140
- def __init__(self,model_name="",base_url=""):
141
- self.model_name=model_name
142
- self.base_url=base_url
143
  self.headers = {
144
  "Content-Type": "application/json",
145
  "accept": "application/json"
@@ -147,11 +148,12 @@ class XInferenceRerank(Base):
147
 
148
  def similarity(self, query: str, texts: list):
149
  data = {
150
- "model":self.model_name,
151
- "query":query,
152
  "return_documents": "true",
153
  "return_len": "true",
154
- "documents":texts
155
  }
156
  res = requests.post(self.base_url, headers=self.headers, json=data).json()
157
- return np.array([d["relevance_score"] for d in res["results"]]),res["tokens"]["input_tokens"]+res["tokens"]["output_tokens"]
 
 
136
  else: res.extend(scores)
137
  return np.array(res), token_count
138
 
139
+
140
  class XInferenceRerank(Base):
141
+ def __init__(self, key="xxxxxxx", model_name="", base_url=""):
142
+ self.model_name = model_name
143
+ self.base_url = base_url
144
  self.headers = {
145
  "Content-Type": "application/json",
146
  "accept": "application/json"
 
148
 
149
  def similarity(self, query: str, texts: list):
150
  data = {
151
+ "model": self.model_name,
152
+ "query": query,
153
  "return_documents": "true",
154
  "return_len": "true",
155
+ "documents": texts
156
  }
157
  res = requests.post(self.base_url, headers=self.headers, json=data).json()
158
+ return np.array([d["relevance_score"] for d in res["results"]]), res["tokens"]["input_tokens"] + res["tokens"][
159
+ "output_tokens"]