KevinHuSh
committed on
Commit
·
bfb0635
1
Parent(s):
05dad97
fix mem leak for local reranker (#1295)
Browse files
### What problem does this PR solve?
#1288
### Type of change
- [x] Bug Fix (non-breaking change which fixes an issue)
- rag/llm/rerank_model.py +15 -9
rag/llm/rerank_model.py
CHANGED
@@ -39,6 +39,7 @@ class Base(ABC):
|
|
39 |
class DefaultRerank(Base):
|
40 |
_model = None
|
41 |
_model_lock = threading.Lock()
|
|
|
42 |
def __init__(self, key, model_name, **kwargs):
|
43 |
"""
|
44 |
If you have trouble downloading HuggingFace models, -_^ this might help!!
|
@@ -102,19 +103,24 @@ class JinaRerank(Base):
|
|
102 |
|
103 |
class YoudaoRerank(DefaultRerank):
|
104 |
_model = None
|
|
|
105 |
|
106 |
def __init__(self, key=None, model_name="maidalun1020/bce-reranker-base_v1", **kwargs):
|
107 |
from BCEmbedding import RerankerModel
|
108 |
if not YoudaoRerank._model:
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
|
|
|
|
|
|
|
|
118 |
|
119 |
def similarity(self, query: str, texts: list):
|
120 |
pairs = [(query, truncate(t, self._model.max_length)) for t in texts]
|
|
|
39 |
class DefaultRerank(Base):
|
40 |
_model = None
|
41 |
_model_lock = threading.Lock()
|
42 |
+
|
43 |
def __init__(self, key, model_name, **kwargs):
|
44 |
"""
|
45 |
If you have trouble downloading HuggingFace models, -_^ this might help!!
|
|
|
103 |
|
104 |
class YoudaoRerank(DefaultRerank):
|
105 |
_model = None
|
106 |
+
_model_lock = threading.Lock()
|
107 |
|
108 |
def __init__(self, key=None, model_name="maidalun1020/bce-reranker-base_v1", **kwargs):
|
109 |
from BCEmbedding import RerankerModel
|
110 |
if not YoudaoRerank._model:
|
111 |
+
with YoudaoRerank._model_lock:
|
112 |
+
if not YoudaoRerank._model:
|
113 |
+
try:
|
114 |
+
print("LOADING BCE...")
|
115 |
+
YoudaoRerank._model = RerankerModel(model_name_or_path=os.path.join(
|
116 |
+
get_home_cache_dir(),
|
117 |
+
re.sub(r"^[a-zA-Z]+/", "", model_name)))
|
118 |
+
except Exception as e:
|
119 |
+
YoudaoRerank._model = RerankerModel(
|
120 |
+
model_name_or_path=model_name.replace(
|
121 |
+
"maidalun1020", "InfiniFlow"))
|
122 |
+
|
123 |
+
self._model = YoudaoRerank._model
|
124 |
|
125 |
def similarity(self, query: str, texts: list):
|
126 |
pairs = [(query, truncate(t, self._model.max_length)) for t in texts]
|