KevinHuSh commited on
Commit
bfb0635
·
1 Parent(s): 05dad97

fix mem leak for local reranker (#1295)

Browse files

### What problem does this PR solve?

#1288
### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

Files changed (1) hide show
  1. rag/llm/rerank_model.py +15 -9
rag/llm/rerank_model.py CHANGED
@@ -39,6 +39,7 @@ class Base(ABC):
39
  class DefaultRerank(Base):
40
  _model = None
41
  _model_lock = threading.Lock()
 
42
  def __init__(self, key, model_name, **kwargs):
43
  """
44
  If you have trouble downloading HuggingFace models, -_^ this might help!!
@@ -102,19 +103,24 @@ class JinaRerank(Base):
102
 
103
  class YoudaoRerank(DefaultRerank):
104
  _model = None
 
105
 
106
  def __init__(self, key=None, model_name="maidalun1020/bce-reranker-base_v1", **kwargs):
107
  from BCEmbedding import RerankerModel
108
  if not YoudaoRerank._model:
109
- try:
110
- print("LOADING BCE...")
111
- YoudaoRerank._model = RerankerModel(model_name_or_path=os.path.join(
112
- get_home_cache_dir(),
113
- re.sub(r"^[a-zA-Z]+/", "", model_name)))
114
- except Exception as e:
115
- YoudaoRerank._model = RerankerModel(
116
- model_name_or_path=model_name.replace(
117
- "maidalun1020", "InfiniFlow"))
 
 
 
 
118
 
119
  def similarity(self, query: str, texts: list):
120
  pairs = [(query, truncate(t, self._model.max_length)) for t in texts]
 
39
  class DefaultRerank(Base):
40
  _model = None
41
  _model_lock = threading.Lock()
42
+
43
  def __init__(self, key, model_name, **kwargs):
44
  """
45
  If you have trouble downloading HuggingFace models, -_^ this might help!!
 
103
 
104
  class YoudaoRerank(DefaultRerank):
105
  _model = None
106
+ _model_lock = threading.Lock()
107
 
108
  def __init__(self, key=None, model_name="maidalun1020/bce-reranker-base_v1", **kwargs):
109
  from BCEmbedding import RerankerModel
110
  if not YoudaoRerank._model:
111
+ with YoudaoRerank._model_lock:
112
+ if not YoudaoRerank._model:
113
+ try:
114
+ print("LOADING BCE...")
115
+ YoudaoRerank._model = RerankerModel(model_name_or_path=os.path.join(
116
+ get_home_cache_dir(),
117
+ re.sub(r"^[a-zA-Z]+/", "", model_name)))
118
+ except Exception as e:
119
+ YoudaoRerank._model = RerankerModel(
120
+ model_name_or_path=model_name.replace(
121
+ "maidalun1020", "InfiniFlow"))
122
+
123
+ self._model = YoudaoRerank._model
124
 
125
  def similarity(self, query: str, texts: list):
126
  pairs = [(query, truncate(t, self._model.max_length)) for t in texts]