Update README.md
README.md CHANGED
@@ -2684,7 +2684,7 @@ print(embeddings)
 
 ### Transformers
 
-```
+```diff
 import torch
 import torch.nn.functional as F
 from transformers import AutoTokenizer, AutoModel
@@ -2702,10 +2702,13 @@ model.eval()
 
 encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
 
++ matryoshka_dim = 512
++
 with torch.no_grad():
     model_output = model(**encoded_input)
 
 embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
++ embeddings = embeddings[:, :matryoshka_dim]
 embeddings = F.normalize(embeddings, p=2, dim=1)
 print(embeddings)
 ```
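For reference, here is a minimal end-to-end sketch of the updated snippet. The `mean_pooling` helper, the input `sentences`, and the checkpoint name are defined outside this hunk, so the versions below (a standard masked mean pooling, a placeholder model name, and a sample sentence) are assumptions, not the README's exact code.

```python
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel


def mean_pooling(model_output, attention_mask):
    # Assumed helper: mean of the token embeddings with padding positions masked out.
    token_embeddings = model_output[0]
    mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)


model_name = "your-embedding-model"            # placeholder: the checkpoint is set earlier in the README
sentences = ["This is an example sentence."]   # hypothetical input

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)
model.eval()

encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')

matryoshka_dim = 512  # keep only the first 512 embedding dimensions

with torch.no_grad():
    model_output = model(**encoded_input)

embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
embeddings = embeddings[:, :matryoshka_dim]       # truncate to the Matryoshka dimension
embeddings = F.normalize(embeddings, p=2, dim=1)  # re-normalize after truncation
print(embeddings)
```

The order of the last steps matters: slicing happens before `F.normalize`, so the truncated vectors are re-normalized to unit length before being used for similarity search.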