shizzgar
Kevin Hu
committed on
Commit
·
9640d9a
1
Parent(s):
98a13e9
Added LocalAI support for rerank models (#3446)
Browse files
### What problem does this PR solve?
Hi there!
LocalAI has added support for rerank models:
https://localai.io/features/reranker/
I've implemented a LocalAIRerank class (essentially copied from the
OpenAI_APIRerank class).
Also, the LocalAI model responds with a 500 error code if the length of
"documents" is less than 2 in the similarity check.
So I've added a second "document" to the RERANK model connection check in
`api/apps/llm_app.py`.
### Type of change
- [x] New Feature (non-breaking change which adds functionality)
Co-authored-by: Kevin Hu <[email protected]>
- api/apps/llm_app.py +1 -1
- rag/llm/__init__.py +1 -0
- rag/llm/rerank_model.py +37 -2
api/apps/llm_app.py
CHANGED
@@ -238,7 +238,7 @@ def add_llm():
|
|
238 |
base_url=llm["api_base"]
|
239 |
)
|
240 |
try:
|
241 |
-
arr, tc = mdl.similarity("Hello~ Ragflower!", ["Hi, there!"])
|
242 |
if len(arr) == 0 or tc == 0:
|
243 |
raise Exception("Not known.")
|
244 |
except Exception as e:
|
|
|
238 |
base_url=llm["api_base"]
|
239 |
)
|
240 |
try:
|
241 |
+
arr, tc = mdl.similarity("Hello~ Ragflower!", ["Hi, there!", "Ohh, my friend!"])
|
242 |
if len(arr) == 0 or tc == 0:
|
243 |
raise Exception("Not known.")
|
244 |
except Exception as e:
|
rag/llm/__init__.py
CHANGED
@@ -110,6 +110,7 @@ ChatModel = {
|
|
110 |
}
|
111 |
|
112 |
RerankModel = {
|
|
|
113 |
"BAAI": DefaultRerank,
|
114 |
"Jina": JinaRerank,
|
115 |
"Youdao": YoudaoRerank,
|
|
|
110 |
}
|
111 |
|
112 |
RerankModel = {
|
113 |
+
"LocalAI":LocalAIRerank,
|
114 |
"BAAI": DefaultRerank,
|
115 |
"Jina": JinaRerank,
|
116 |
"Youdao": YoudaoRerank,
|
rag/llm/rerank_model.py
CHANGED
@@ -185,11 +185,46 @@ class XInferenceRerank(Base):
|
|
185 |
|
186 |
class LocalAIRerank(Base):
|
187 |
def __init__(self, key, model_name, base_url):
|
188 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
189 |
|
190 |
def similarity(self, query: str, texts: list):
|
191 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
192 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
193 |
|
194 |
class NvidiaRerank(Base):
|
195 |
def __init__(
|
|
|
185 |
|
186 |
class LocalAIRerank(Base):
|
187 |
def __init__(self, key, model_name, base_url):
|
188 |
+
if base_url.find("/rerank") == -1:
|
189 |
+
self.base_url = urljoin(base_url, "/rerank")
|
190 |
+
else:
|
191 |
+
self.base_url = base_url
|
192 |
+
self.headers = {
|
193 |
+
"Content-Type": "application/json",
|
194 |
+
"Authorization": f"Bearer {key}"
|
195 |
+
}
|
196 |
+
self.model_name = model_name.replace("___LocalAI","")
|
197 |
|
198 |
def similarity(self, query: str, texts: list):
|
199 |
+
# noway to config Ragflow , use fix setting
|
200 |
+
texts = [truncate(t, 500) for t in texts]
|
201 |
+
data = {
|
202 |
+
"model": self.model_name,
|
203 |
+
"query": query,
|
204 |
+
"documents": texts,
|
205 |
+
"top_n": len(texts),
|
206 |
+
}
|
207 |
+
token_count = 0
|
208 |
+
for t in texts:
|
209 |
+
token_count += num_tokens_from_string(t)
|
210 |
+
res = requests.post(self.base_url, headers=self.headers, json=data).json()
|
211 |
+
rank = np.zeros(len(texts), dtype=float)
|
212 |
+
if 'results' not in res:
|
213 |
+
raise ValueError("response not contains results\n" + str(res))
|
214 |
+
for d in res["results"]:
|
215 |
+
rank[d["index"]] = d["relevance_score"]
|
216 |
+
|
217 |
+
# Normalize the rank values to the range 0 to 1
|
218 |
+
min_rank = np.min(rank)
|
219 |
+
max_rank = np.max(rank)
|
220 |
|
221 |
+
# Avoid division by zero if all ranks are identical
|
222 |
+
if max_rank - min_rank != 0:
|
223 |
+
rank = (rank - min_rank) / (max_rank - min_rank)
|
224 |
+
else:
|
225 |
+
rank = np.zeros_like(rank)
|
226 |
+
|
227 |
+
return rank, token_count
|
228 |
|
229 |
class NvidiaRerank(Base):
|
230 |
def __init__(
|