Kaguya-19 committed
Commit dc0f82b · verified · 1 parent: 3cc1148

Update README.md

Files changed (1):
  1. README.md +6 -3
README.md CHANGED
@@ -344,7 +344,6 @@ When running evaluation on BEIR and C-MTEB/Retrieval, we use instructions in `in
 
 ```
 transformers==4.37.2
-flash-attn>2.3.5
 ```
 
 ### 示例脚本 Demo
@@ -358,7 +357,9 @@ import torch.nn.functional as F
 
 model_name = "openbmb/MiniCPM-Embedding"
 tokenizer = AutoTokenizer.from_pretrained(model_name)
-model = AutoModel.from_pretrained(model_name, trust_remote_code=True, attn_implementation="flash_attention_2", torch_dtype=torch.float16).to("cuda")
+model = AutoModel.from_pretrained(model_name, trust_remote_code=True, torch_dtype=torch.float16).to("cuda")
+# You can also use the following line to enable the Flash Attention 2 implementation
+# model = AutoModel.from_pretrained(model_name, trust_remote_code=True, attn_implementation="flash_attention_2", torch_dtype=torch.float16).to("cuda")
 model.eval()
 
 # 由于在 `model.forward` 中缩放了最终隐层表示,此处的 mean pooling 实际上起到了 weighted mean pooling 的作用
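For context, the Transformers-based snippet can be exercised roughly as follows after this change, with the Flash Attention 2 path left as an opt-in. This is a minimal sketch rather than the README's exact code: the `encode` helper, the pooling details, and the use of `last_hidden_state` from the `trust_remote_code` model are illustrative assumptions, and a CUDA device with fp16 support is assumed.

```python
# Minimal sketch (assumptions noted above); the default load path no longer requires flash-attn.
import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer

model_name = "openbmb/MiniCPM-Embedding"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name, trust_remote_code=True, torch_dtype=torch.float16).to("cuda")
model.eval()

@torch.no_grad()
def encode(texts):
    # Illustrative helper (not from the README): tokenize, run the model, pool, normalize.
    batch = tokenizer(texts, padding=True, truncation=True, return_tensors="pt").to("cuda")
    outputs = model(**batch)  # assumes the custom model exposes last_hidden_state
    # Because the final hidden states are already scaled inside model.forward,
    # plain mean pooling over non-padding tokens effectively acts as weighted mean pooling.
    mask = batch["attention_mask"].unsqueeze(-1).float()
    emb = (outputs.last_hidden_state * mask).sum(dim=1) / mask.sum(dim=1)
    return F.normalize(emb, p=2, dim=1)

# The README mentions using task instructions for retrieval evaluation; omitted here for brevity.
query_emb = encode(["中国的首都是哪里?"])
doc_emb = encode(["beijing", "shanghai"])
print(query_emb @ doc_emb.T)  # similarity scores between the query and each passage
```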
@@ -402,7 +403,9 @@ import torch
 from sentence_transformers import SentenceTransformer
 
 model_name = "openbmb/MiniCPM-Embedding"
-model = SentenceTransformer(model_name, trust_remote_code=True, model_kwargs={"attn_implementation": "flash_attention_2", "torch_dtype": torch.float16})
+model = SentenceTransformer(model_name, trust_remote_code=True, model_kwargs={"torch_dtype": torch.float16})
+# You can also use the following line to enable the Flash Attention 2 implementation
+# model = SentenceTransformer(model_name, trust_remote_code=True, model_kwargs={"attn_implementation": "flash_attention_2", "torch_dtype": torch.float16})
 
 queries = ["中国的首都是哪里?"]
 passages = ["beijing", "shanghai"]
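Likewise, a minimal sketch of the sentence-transformers path after this commit, with Flash Attention 2 kept as the opt-in variant via `model_kwargs`. The `encode` calls and `normalize_embeddings` flag are standard sentence-transformers API; the similarity computation shown is illustrative rather than taken from the README.

```python
# Minimal sketch (see caveats above); Flash Attention 2 is now opt-in through model_kwargs.
import torch
from sentence_transformers import SentenceTransformer

model_name = "openbmb/MiniCPM-Embedding"
model = SentenceTransformer(model_name, trust_remote_code=True,
                            model_kwargs={"torch_dtype": torch.float16})
# Opt-in variant (assumes flash-attn is installed and a compatible GPU is available):
# model = SentenceTransformer(model_name, trust_remote_code=True,
#                             model_kwargs={"attn_implementation": "flash_attention_2",
#                                           "torch_dtype": torch.float16})

queries = ["中国的首都是哪里?"]
passages = ["beijing", "shanghai"]

query_emb = model.encode(queries, normalize_embeddings=True)
passage_emb = model.encode(passages, normalize_embeddings=True)
print(query_emb @ passage_emb.T)  # cosine similarities, since embeddings are L2-normalized
```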