zhuhao
Commit 05dad97 · Parent(s): 1090f98
Fix ragflow may encounter an OOM (Out Of Memory) when there are a lot of conversations (#1292)
### What problem does this PR solve?
Fix an issue where ragflow may run out of memory (OOM) when there are a lot of conversations.
#1288
### Type of change
- [ ] Bug Fix (non-breaking change which fixes an issue)
Co-authored-by: zhuhao <[email protected]>
- rag/llm/embedding_model.py +16 -12
- rag/llm/rerank_model.py +12 -11
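
The fix visible in both files is the same: the heavyweight FlagModel / FlagReranker becomes a class-level singleton. A `threading.Lock` plus a double-checked `if not ..._model` guard ensures the model is loaded at most once, and every instance then points `self._model` at the shared `DefaultEmbedding._model` / `DefaultRerank._model`, so many conversations no longer pile up separate copies of the model in memory. A minimal sketch of that pattern (hypothetical names, not code from this repository):

```python
import threading


class HeavyModelHolder:
    """Hypothetical stand-in for DefaultEmbedding / DefaultRerank:
    the expensive model object lives on the class, not on each instance."""

    _model = None
    _model_lock = threading.Lock()

    def __init__(self):
        # Double-checked locking: the unlocked check keeps the common path cheap;
        # the re-check inside the lock stops two threads that both saw None
        # from loading the model twice.
        if not HeavyModelHolder._model:
            with HeavyModelHolder._model_lock:
                if not HeavyModelHolder._model:
                    HeavyModelHolder._model = self._load_model()
        # Every instance shares the single class-level model.
        self._model = HeavyModelHolder._model

    @staticmethod
    def _load_model():
        print("loading model (should print once)")
        return object()  # placeholder for the real FlagModel / FlagReranker


# Many "conversations" constructing holders concurrently still load the model once.
threads = [threading.Thread(target=HeavyModelHolder) for _ in range(8)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```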
rag/llm/embedding_model.py
CHANGED
@@ -15,6 +15,7 @@
 #
 import re
 from typing import Optional
+import threading
 import requests
 from huggingface_hub import snapshot_download
 from zhipuai import ZhipuAI

@@ -44,7 +45,7 @@ class Base(ABC):
 
 class DefaultEmbedding(Base):
     _model = None
-
+    _model_lock = threading.Lock()
     def __init__(self, key, model_name, **kwargs):
         """
         If you have trouble downloading HuggingFace models, -_^ this might help!!

@@ -58,17 +59,20 @@ class DefaultEmbedding(Base):
 
         """
         if not DefaultEmbedding._model:
+            with DefaultEmbedding._model_lock:
+                if not DefaultEmbedding._model:
+                    try:
+                        DefaultEmbedding._model = FlagModel(os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name)),
+                                                            query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
+                                                            use_fp16=torch.cuda.is_available())
+                    except Exception as e:
+                        model_dir = snapshot_download(repo_id="BAAI/bge-large-zh-v1.5",
+                                                      local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name)),
+                                                      local_dir_use_symlinks=False)
+                        DefaultEmbedding._model = FlagModel(model_dir,
+                                                            query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
+                                                            use_fp16=torch.cuda.is_available())
+        self._model = DefaultEmbedding._model
 
     def encode(self, texts: list, batch_size=32):
         texts = [truncate(t, 2048) for t in texts]
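As a quick sanity check of the embedding change, a hypothetical usage sketch (the key/model_name values are placeholders, and it assumes the BGE weights are available locally or can be downloaded):

```python
from rag.llm.embedding_model import DefaultEmbedding

# Two "conversations" each build their own wrapper object...
emb_a = DefaultEmbedding(key="", model_name="BAAI/bge-large-zh-v1.5")
emb_b = DefaultEmbedding(key="", model_name="BAAI/bge-large-zh-v1.5")

# ...but with this change they share one FlagModel instead of loading two copies.
assert emb_a._model is emb_b._model
```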
rag/llm/rerank_model.py
CHANGED
@@ -14,6 +14,7 @@
 # limitations under the License.
 #
 import re
+import threading
 import requests
 import torch
 from FlagEmbedding import FlagReranker

@@ -37,7 +38,7 @@ class Base(ABC):
 
 class DefaultRerank(Base):
     _model = None
-
+    _model_lock = threading.Lock()
     def __init__(self, key, model_name, **kwargs):
         """
         If you have trouble downloading HuggingFace models, -_^ this might help!!

@@ -51,16 +52,16 @@ class DefaultRerank(Base):
 
         """
         if not DefaultRerank._model:
+            with DefaultRerank._model_lock:
+                if not DefaultRerank._model:
+                    try:
+                        DefaultRerank._model = FlagReranker(os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name)), use_fp16=torch.cuda.is_available())
+                    except Exception as e:
+                        model_dir = snapshot_download(repo_id= model_name,
+                                                      local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name)),
+                                                      local_dir_use_symlinks=False)
+                        DefaultRerank._model = FlagReranker(model_dir, use_fp16=torch.cuda.is_available())
+        self._model = DefaultRerank._model
 
     def similarity(self, query: str, texts: list):
         pairs = [(query,truncate(t, 2048)) for t in texts]
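The reranker change behaves the same way under concurrency; a hypothetical sketch (model_name is a placeholder, and loading the reranker weights is assumed to succeed):

```python
import threading

from rag.llm.rerank_model import DefaultRerank

instances = []

def build():
    # Each thread plays the role of one conversation creating a reranker.
    instances.append(DefaultRerank(key="", model_name="BAAI/bge-reranker-large"))

threads = [threading.Thread(target=build) for _ in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()

# The lock guarantees FlagReranker was constructed once; all wrappers share it.
assert all(r._model is instances[0]._model for r in instances)
```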