roc king 王志鹏 Kevin Hu committed on
Commit
3256beb
·
1 Parent(s): 31f09e1

extract model dir from model's full name (#3368)

Browse files

### What problem does this PR solve?

When a model's group name contains digits (0-9), we can't find the downloaded
model, because we do not correctly extract the model dir's name from the model's
full name.

### Type of change

- [ ] Bug Fix (non-breaking change which fixes an issue)

Co-authored-by: 王志鹏 <[email protected]>
Co-authored-by: Kevin Hu <[email protected]>

rag/llm/embedding_model.py CHANGED
@@ -66,12 +66,12 @@ class DefaultEmbedding(Base):
66
  import torch
67
  if not DefaultEmbedding._model:
68
  try:
69
- DefaultEmbedding._model = FlagModel(os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name)),
70
  query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
71
  use_fp16=torch.cuda.is_available())
72
  except Exception:
73
  model_dir = snapshot_download(repo_id="BAAI/bge-large-zh-v1.5",
74
- local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name)),
75
  local_dir_use_symlinks=False)
76
  DefaultEmbedding._model = FlagModel(model_dir,
77
  query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
 
66
  import torch
67
  if not DefaultEmbedding._model:
68
  try:
69
+ DefaultEmbedding._model = FlagModel(os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)),
70
  query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
71
  use_fp16=torch.cuda.is_available())
72
  except Exception:
73
  model_dir = snapshot_download(repo_id="BAAI/bge-large-zh-v1.5",
74
+ local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)),
75
  local_dir_use_symlinks=False)
76
  DefaultEmbedding._model = FlagModel(model_dir,
77
  query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
rag/llm/rerank_model.py CHANGED
@@ -65,12 +65,12 @@ class DefaultRerank(Base):
65
  if not DefaultRerank._model:
66
  try:
67
  DefaultRerank._model = FlagReranker(
68
- os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name)),
69
  use_fp16=torch.cuda.is_available())
70
  except Exception:
71
  model_dir = snapshot_download(repo_id=model_name,
72
  local_dir=os.path.join(get_home_cache_dir(),
73
- re.sub(r"^[a-zA-Z]+/", "", model_name)),
74
  local_dir_use_symlinks=False)
75
  DefaultRerank._model = FlagReranker(model_dir, use_fp16=torch.cuda.is_available())
76
  self._model = DefaultRerank._model
@@ -130,7 +130,7 @@ class YoudaoRerank(DefaultRerank):
130
  logger.info("LOADING BCE...")
131
  YoudaoRerank._model = RerankerModel(model_name_or_path=os.path.join(
132
  get_home_cache_dir(),
133
- re.sub(r"^[a-zA-Z]+/", "", model_name)))
134
  except Exception:
135
  YoudaoRerank._model = RerankerModel(
136
  model_name_or_path=model_name.replace(
 
65
  if not DefaultRerank._model:
66
  try:
67
  DefaultRerank._model = FlagReranker(
68
+ os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)),
69
  use_fp16=torch.cuda.is_available())
70
  except Exception:
71
  model_dir = snapshot_download(repo_id=model_name,
72
  local_dir=os.path.join(get_home_cache_dir(),
73
+ re.sub(r"^[a-zA-Z0-9]+/", "", model_name)),
74
  local_dir_use_symlinks=False)
75
  DefaultRerank._model = FlagReranker(model_dir, use_fp16=torch.cuda.is_available())
76
  self._model = DefaultRerank._model
 
130
  logger.info("LOADING BCE...")
131
  YoudaoRerank._model = RerankerModel(model_name_or_path=os.path.join(
132
  get_home_cache_dir(),
133
+ re.sub(r"^[a-zA-Z0-9]+/", "", model_name)))
134
  except Exception:
135
  YoudaoRerank._model = RerankerModel(
136
  model_name_or_path=model_name.replace(