zhuhao zhuhao commited on
Commit
05dad97
·
1 Parent(s): 1090f98

Fix ragflow may encounter an OOM (Out Of Memory) when there are a lot of conversations (#1292)

Browse files

### What problem does this PR solve?

Fix ragflow may encounter an OOM (Out Of Memory) when there are a lot of
conversations.
#1288

### Type of change

- [ ] Bug Fix (non-breaking change which fixes an issue)

Co-authored-by: zhuhao <[email protected]>

Files changed (2) hide show
  1. rag/llm/embedding_model.py +16 -12
  2. rag/llm/rerank_model.py +12 -11
rag/llm/embedding_model.py CHANGED
@@ -15,6 +15,7 @@
15
  #
16
  import re
17
  from typing import Optional
 
18
  import requests
19
  from huggingface_hub import snapshot_download
20
  from zhipuai import ZhipuAI
@@ -44,7 +45,7 @@ class Base(ABC):
44
 
45
  class DefaultEmbedding(Base):
46
  _model = None
47
-
48
  def __init__(self, key, model_name, **kwargs):
49
  """
50
  If you have trouble downloading HuggingFace models, -_^ this might help!!
@@ -58,17 +59,20 @@ class DefaultEmbedding(Base):
58
 
59
  """
60
  if not DefaultEmbedding._model:
61
- try:
62
- self._model = FlagModel(os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name)),
63
- query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
64
- use_fp16=torch.cuda.is_available())
65
- except Exception as e:
66
- model_dir = snapshot_download(repo_id="BAAI/bge-large-zh-v1.5",
67
- local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name)),
68
- local_dir_use_symlinks=False)
69
- self._model = FlagModel(model_dir,
70
- query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
71
- use_fp16=torch.cuda.is_available())
 
 
 
72
 
73
  def encode(self, texts: list, batch_size=32):
74
  texts = [truncate(t, 2048) for t in texts]
 
15
  #
16
  import re
17
  from typing import Optional
18
+ import threading
19
  import requests
20
  from huggingface_hub import snapshot_download
21
  from zhipuai import ZhipuAI
 
45
 
46
  class DefaultEmbedding(Base):
47
  _model = None
48
+ _model_lock = threading.Lock()
49
  def __init__(self, key, model_name, **kwargs):
50
  """
51
  If you have trouble downloading HuggingFace models, -_^ this might help!!
 
59
 
60
  """
61
  if not DefaultEmbedding._model:
62
+ with DefaultEmbedding._model_lock:
63
+ if not DefaultEmbedding._model:
64
+ try:
65
+ DefaultEmbedding._model = FlagModel(os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name)),
66
+ query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
67
+ use_fp16=torch.cuda.is_available())
68
+ except Exception as e:
69
+ model_dir = snapshot_download(repo_id="BAAI/bge-large-zh-v1.5",
70
+ local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name)),
71
+ local_dir_use_symlinks=False)
72
+ DefaultEmbedding._model = FlagModel(model_dir,
73
+ query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
74
+ use_fp16=torch.cuda.is_available())
75
+ self._model = DefaultEmbedding._model
76
 
77
  def encode(self, texts: list, batch_size=32):
78
  texts = [truncate(t, 2048) for t in texts]
rag/llm/rerank_model.py CHANGED
@@ -14,6 +14,7 @@
14
  # limitations under the License.
15
  #
16
  import re
 
17
  import requests
18
  import torch
19
  from FlagEmbedding import FlagReranker
@@ -37,7 +38,7 @@ class Base(ABC):
37
 
38
  class DefaultRerank(Base):
39
  _model = None
40
-
41
  def __init__(self, key, model_name, **kwargs):
42
  """
43
  If you have trouble downloading HuggingFace models, -_^ this might help!!
@@ -51,16 +52,16 @@ class DefaultRerank(Base):
51
 
52
  """
53
  if not DefaultRerank._model:
54
- try:
55
- self._model = FlagReranker(os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name)),
56
- use_fp16=torch.cuda.is_available())
57
- except Exception as e:
58
- self._model = snapshot_download(repo_id=model_name,
59
- local_dir=os.path.join(get_home_cache_dir(),
60
- re.sub(r"^[a-zA-Z]+/", "", model_name)),
61
- local_dir_use_symlinks=False)
62
- self._model = FlagReranker(os.path.join(get_home_cache_dir(), model_name),
63
- use_fp16=torch.cuda.is_available())
64
 
65
  def similarity(self, query: str, texts: list):
66
  pairs = [(query,truncate(t, 2048)) for t in texts]
 
14
  # limitations under the License.
15
  #
16
  import re
17
+ import threading
18
  import requests
19
  import torch
20
  from FlagEmbedding import FlagReranker
 
38
 
39
  class DefaultRerank(Base):
40
  _model = None
41
+ _model_lock = threading.Lock()
42
  def __init__(self, key, model_name, **kwargs):
43
  """
44
  If you have trouble downloading HuggingFace models, -_^ this might help!!
 
52
 
53
  """
54
  if not DefaultRerank._model:
55
+ with DefaultRerank._model_lock:
56
+ if not DefaultRerank._model:
57
+ try:
58
+ DefaultRerank._model = FlagReranker(os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name)), use_fp16=torch.cuda.is_available())
59
+ except Exception as e:
60
+ model_dir = snapshot_download(repo_id= model_name,
61
+ local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name)),
62
+ local_dir_use_symlinks=False)
63
+ DefaultRerank._model = FlagReranker(model_dir, use_fp16=torch.cuda.is_available())
64
+ self._model = DefaultRerank._model
65
 
66
  def similarity(self, query: str, texts: list):
67
  pairs = [(query,truncate(t, 2048)) for t in texts]