zhuhao committed
Commit 40bbe34 · 1 Parent(s): 17cd183

feat: support xinference rerank model (#1466)


### What problem does this PR solve?

Support the Xinference rerank model (issue #1455).

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

api/db/init_data.py CHANGED
@@ -109,7 +109,7 @@ factory_infos = [{
     "name": "Ollama",
     "logo": "",
     "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
-    "status": "1",
+    "status": "1",
 }, {
     "name": "Moonshot",
     "logo": "",
@@ -123,8 +123,8 @@ factory_infos = [{
 }, {
     "name": "Xinference",
     "logo": "",
-    "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
-    "status": "1",
+    "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION,TEXT RE-RANK",
+    "status": "1",
 },{
     "name": "Youdao",
     "logo": "",
rag/llm/__init__.py CHANGED
@@ -68,4 +68,5 @@ RerankModel = {
     "BAAI": DefaultRerank,
     "Jina": JinaRerank,
     "Youdao": YoudaoRerank,
+    "Xinference": XInferenceRerank
 }
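
With the registry entry in place, callers can resolve the Xinference implementation from `RerankModel` by factory name. A minimal sketch of that lookup, assuming the keys are the factory names shown above; the model name and base URL below are placeholders, not values from this PR:

```python
from rag.llm import RerankModel

# Resolve the rerank class by factory name, as registered in rag/llm/__init__.py.
factory = "Xinference"

# model_name and base_url are illustrative; in practice they would come from the
# tenant's model configuration rather than being hard-coded.
rerank_model = RerankModel[factory](
    model_name="bge-reranker-base",              # assumed model served by Xinference
    base_url="http://localhost:9997/v1/rerank",  # assumed Xinference rerank endpoint
)
```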
rag/llm/rerank_model.py CHANGED
@@ -136,4 +136,22 @@ class YoudaoRerank(DefaultRerank):
         else: res.extend(scores)
         return np.array(res), token_count
 
-
+class XInferenceRerank(Base):
+    def __init__(self, model_name="", base_url=""):
+        self.model_name = model_name
+        self.base_url = base_url
+        self.headers = {
+            "Content-Type": "application/json",
+            "accept": "application/json"
+        }
+
+    def similarity(self, query: str, texts: list):
+        data = {
+            "model": self.model_name,
+            "query": query,
+            "return_documents": "true",
+            "return_len": "true",
+            "documents": texts
+        }
+        res = requests.post(self.base_url, headers=self.headers, json=data).json()
+        return np.array([d["relevance_score"] for d in res["results"]]), res["tokens"]["input_tokens"] + res["tokens"]["output_tokens"]
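
`similarity` posts the query and the candidate texts to the configured endpoint and reads per-document relevance scores plus token counts from the JSON reply. A minimal usage sketch, assuming a local Xinference server whose rerank route is at `/v1/rerank` and a deployed `bge-reranker-base` model (both are assumptions, not part of this PR):

```python
from rag.llm.rerank_model import XInferenceRerank

# Endpoint URL and model name are assumptions for illustration only.
rerank = XInferenceRerank(
    model_name="bge-reranker-base",
    base_url="http://localhost:9997/v1/rerank",
)

scores, token_count = rerank.similarity(
    "What is Xinference?",
    [
        "Xorbits Inference (Xinference) serves LLM, embedding and rerank models.",
        "The weather in Hangzhou is mild in spring.",
    ],
)

# `scores` is a numpy array with one relevance score per document, in the order submitted;
# `token_count` is the sum of input and output tokens reported by the server.
print(scores, token_count)
```

Note that `base_url` is used as the full request URL (`requests.post(self.base_url, ...)`), so it must point at the rerank endpoint itself rather than the server root.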
web/src/pages/user-setting/setting-model/ollama-modal/index.tsx CHANGED
@@ -74,6 +74,7 @@ const OllamaModal = ({
         <Select placeholder={t('modelTypeMessage')}>
           <Option value="chat">chat</Option>
           <Option value="embedding">embedding</Option>
+          <Option value="rerank">rerank</Option>
         </Select>
       </Form.Item>
       <Form.Item<FieldType>