roc king
王志鹏
Kevin Hu
commited on
Commit
·
3256beb
1
Parent(s):
31f09e1
exstract model dir from model‘s full name (#3368)
Browse files### What problem does this PR solve?
When model’s group name contains 0-9,we can't find downloaded
model,because we do not correctly exstract model dir's name from model‘s
full name
### Type of change
- [ ] Bug Fix (non-breaking change which fixes an issue)
Co-authored-by: 王志鹏 <[email protected]>
Co-authored-by: Kevin Hu <[email protected]>
- rag/llm/embedding_model.py +2 -2
- rag/llm/rerank_model.py +3 -3
rag/llm/embedding_model.py
CHANGED
@@ -66,12 +66,12 @@ class DefaultEmbedding(Base):
|
|
66 |
import torch
|
67 |
if not DefaultEmbedding._model:
|
68 |
try:
|
69 |
-
DefaultEmbedding._model = FlagModel(os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-
|
70 |
query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
|
71 |
use_fp16=torch.cuda.is_available())
|
72 |
except Exception:
|
73 |
model_dir = snapshot_download(repo_id="BAAI/bge-large-zh-v1.5",
|
74 |
-
local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-
|
75 |
local_dir_use_symlinks=False)
|
76 |
DefaultEmbedding._model = FlagModel(model_dir,
|
77 |
query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
|
|
|
66 |
import torch
|
67 |
if not DefaultEmbedding._model:
|
68 |
try:
|
69 |
+
DefaultEmbedding._model = FlagModel(os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)),
|
70 |
query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
|
71 |
use_fp16=torch.cuda.is_available())
|
72 |
except Exception:
|
73 |
model_dir = snapshot_download(repo_id="BAAI/bge-large-zh-v1.5",
|
74 |
+
local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)),
|
75 |
local_dir_use_symlinks=False)
|
76 |
DefaultEmbedding._model = FlagModel(model_dir,
|
77 |
query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
|
rag/llm/rerank_model.py
CHANGED
@@ -65,12 +65,12 @@ class DefaultRerank(Base):
|
|
65 |
if not DefaultRerank._model:
|
66 |
try:
|
67 |
DefaultRerank._model = FlagReranker(
|
68 |
-
os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-
|
69 |
use_fp16=torch.cuda.is_available())
|
70 |
except Exception:
|
71 |
model_dir = snapshot_download(repo_id=model_name,
|
72 |
local_dir=os.path.join(get_home_cache_dir(),
|
73 |
-
re.sub(r"^[a-zA-
|
74 |
local_dir_use_symlinks=False)
|
75 |
DefaultRerank._model = FlagReranker(model_dir, use_fp16=torch.cuda.is_available())
|
76 |
self._model = DefaultRerank._model
|
@@ -130,7 +130,7 @@ class YoudaoRerank(DefaultRerank):
|
|
130 |
logger.info("LOADING BCE...")
|
131 |
YoudaoRerank._model = RerankerModel(model_name_or_path=os.path.join(
|
132 |
get_home_cache_dir(),
|
133 |
-
re.sub(r"^[a-zA-
|
134 |
except Exception:
|
135 |
YoudaoRerank._model = RerankerModel(
|
136 |
model_name_or_path=model_name.replace(
|
|
|
65 |
if not DefaultRerank._model:
|
66 |
try:
|
67 |
DefaultRerank._model = FlagReranker(
|
68 |
+
os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)),
|
69 |
use_fp16=torch.cuda.is_available())
|
70 |
except Exception:
|
71 |
model_dir = snapshot_download(repo_id=model_name,
|
72 |
local_dir=os.path.join(get_home_cache_dir(),
|
73 |
+
re.sub(r"^[a-zA-Z0-9]+/", "", model_name)),
|
74 |
local_dir_use_symlinks=False)
|
75 |
DefaultRerank._model = FlagReranker(model_dir, use_fp16=torch.cuda.is_available())
|
76 |
self._model = DefaultRerank._model
|
|
|
130 |
logger.info("LOADING BCE...")
|
131 |
YoudaoRerank._model = RerankerModel(model_name_or_path=os.path.join(
|
132 |
get_home_cache_dir(),
|
133 |
+
re.sub(r"^[a-zA-Z0-9]+/", "", model_name)))
|
134 |
except Exception:
|
135 |
YoudaoRerank._model = RerankerModel(
|
136 |
model_name_or_path=model_name.replace(
|