Dejiao Z committed · Commit 8942c26 · Parent(s): 5315eb6

initial commit

Files changed:
- .ipynb_checkpoints/modules-checkpoint.json +20 -0
- .ipynb_checkpoints/sentence_bert_config-checkpoint.json +4 -0
- 1_Pooling/.ipynb_checkpoints/config-checkpoint.json +7 -0
- 1_Pooling/config.json +7 -0
- README.md +113 -3
- config.json +26 -0
- config_codext.py +51 -0
- modeling_codext.py +425 -0
- modules.json +20 -0
- pytorch_model.bin +3 -0
- sentence_bert_config.json +4 -0
- tokenization_codext.py +337 -0
- tokenizer_config.json +12 -0
.ipynb_checkpoints/modules-checkpoint.json
ADDED
@@ -0,0 +1,20 @@
[
  {
    "idx": 0,
    "name": "0",
    "path": "",
    "type": "sentence_transformers.models.Transformer"
  },
  {
    "idx": 1,
    "name": "1",
    "path": "1_Pooling",
    "type": "sentence_transformers.models.Pooling"
  },
  {
    "idx": 2,
    "name": "2",
    "path": "2_Normalize",
    "type": "sentence_transformers.models.Normalize"
  }
]
.ipynb_checkpoints/sentence_bert_config-checkpoint.json
ADDED
@@ -0,0 +1,4 @@
{
  "max_seq_length": 1024,
  "do_lower_case": false
}
1_Pooling/.ipynb_checkpoints/config-checkpoint.json
ADDED
@@ -0,0 +1,7 @@
{
  "word_embedding_dimension": 1024,
  "pooling_mode_cls_token": false,
  "pooling_mode_mean_tokens": true,
  "pooling_mode_max_tokens": false,
  "pooling_mode_mean_sqrt_len_tokens": false
}
1_Pooling/config.json
ADDED
@@ -0,0 +1,7 @@
{
  "word_embedding_dimension": 1536,
  "pooling_mode_cls_token": false,
  "pooling_mode_mean_tokens": true,
  "pooling_mode_max_tokens": false,
  "pooling_mode_mean_sqrt_len_tokens": false
}
README.md
CHANGED
@@ -1,3 +1,113 @@
---
license: apache-2.0
datasets:
- bigcode/the-stack-v2
- tiiuae/falcon-refinedweb
library_name: transformers
language:
- code
- en
---

## SageLite-l

### Model Description
SageLite is a new family of open embedding models with an encoder architecture that supports a wide range of tasks in both code and text. SageLite went through three stages of training:
1. **MLM Pretraining**: Standard masked language model (MLM) pretraining on mixed code and text data ([The-Stack-v2](https://huggingface.co/datasets/bigcode/the-stack-v2) and [Falcon-refinedweb](https://huggingface.co/datasets/tiiuae/falcon-refinedweb)).
2. **Contrastive Pre-Finetuning**: Learning from a large amount of positive pairs mined from web data and GitHub.
3. **Contrastive Fine-Tuning**: Fine-tuning on a small amount of synthetic data.

---

### **Code Retrieval Performance**

#### 1. Code2Code Search

| Model Name | # Params | Embd Dim | Python | Java | JS | TS | C# | C | Ruby | PHP | GO | AVG |
|---------------------|----------|----------|--------|-------|-------|--------|--------|--------|--------|--------|--------|--------|
| OpenAI-Code-01 | NA | 3072 | 21.92 | 8.90 | 4.90 | 5.70 | 3.15 | 11.58 | 26.25 | 16.60 | 9.40 | 12.04 |
| OpenAI-Text-3-Small | NA | 1536 | 25.18 | 12.61 | 8.00 | 9.44 | 5.46 | 15.86 | 30.70 | 23.33 | 11.20 | 15.57 |
| OpenAI-Text-3-Large | NA | 3072 | 40.57 | 25.33 | 20.09 | 22.00 | 11.84 | 31.90 | 42.54 | 41.84 | 21.75 | 28.65 |
| CodeSage-v2-Small | 130M | 1024 | 45.60 | 33.65 | 39.96 | 47.78 | 19.19 | 30.55 | 40.12 | 55.39 | 30.96 | 38.13 |
| CodeSage-v2-Base | 356M | 1024 | 55.86 | 42.89 | 45.29 | 54.58 | 23.90 | 38.52 | 56.02 | 64.56 | 42.88 | 47.17 |
| CodeSage-v2-Large | 1.3B | 2048 | 61.11 | 47.09 | 51.18 | 60.67 | 28.04 | 43.40 | 60.74 | 67.87 | 43.86 | 51.55 |
| SageLite-s | 80M | 768 | 47.93 | 30.83 | 35.15 | 37.64 | 18.14 | 30.53 | 42.89 | 50.70 | 21.69 | 35.06 |
| SageLite-l | 850M | 1536 | 64.46 | 45.53 | 50.80 | 54.71 | 30.66 | 47.46 | 61.01 | 68.68 | 39.25 | 51.40 |

#### 2. NL2Code Search

| Model Name | # Params | CoSQA | AdvTest | Python | Java | JS | PHP | GO | Ruby | Avg |
|---------------------|----------|-------|---------|--------|-------|-------|--------|--------|--------|--------|
| OpenAI-Code-01 | NA | 52.20 | 36.03 | 63.13 | 67.85 | 62.30 | 57.47 | 85.22 | 69.28 | 61.69 |
| OpenAI-Text-3-Small | NA | 52.48 | 34.10 | 62.62 | 65.87 | 60.28 | 54.85 | 81.96 | 67.57 | 59.97 |
| OpenAI-Text-3-Large | NA | 55.21 | 46.83 | 70.81 | 72.89 | 68.12 | 59.58 | 87.60 | 75.22 | 67.03 |
| CodeSage-v2-Small | 130M | 52.39 | 47.28 | 68.79 | 68.13 | 65.77 | 60.20 | 80.26 | 72.46 | 64.41 |
| CodeSage-v2-Base | 356M | 50.74 | 52.00 | 70.46 | 70.89 | 69.61 | 62.81 | 82.37 | 73.71 | 66.57 |
| CodeSage-v2-Large | 1.3B | 53.18 | 56.31 | 74.18 | 72.33 | 72.49 | 65.26 | 84.67 | 76.61 | 69.38 |
| SageLite-s | 80M | 56.49 | 42.32 | 67.59 | 66.62 | 62.32 | 58.87 | 79.36 | 70.75 | 63.04 |
| SageLite-l | 850M | 59.76 | 55.55 | 74.25 | 71.76 | 69.35 | 61.62 | 84.09 | 77.14 | 69.19 |

---

### **Text Retrieval Performance ([MTEB Retrieval](https://huggingface.co/spaces/mteb/leaderboard))**

| Metric | SageLite-s | SageLite-l |
|-------------------------------|------------|------------|
| ArguAna | 57.75 | 60.706 |
| CQADupstackWordpressRetrieval | 32.42 | 38.625 |
| FiQA2018 | 34.85 | 46.729 |
| NFCorpus | 29.97 | 33.698 |
| QuoraRetrieval | 85.35 | 87.497 |
| SCIDOCS | 18.99 | 21.379 |
| SciFact | 68.43 | 69.050 |
| Touche2020 | 24.41 | 21.425 |
| TRECCOVID | 70.88 | 76.078 |
| FEVER | 71.72 | 73.644 |
| HotpotQA | 58.81 | 62.955 |
| NQ | 48.26 | 54.478 |
| DBPedia | 34.83 | 40.689 |
| ClimateFEVER | 25.69 | 26.198 |
| MSMARCO | 35.01 | 36.546 |
| Average | 46.49 | 49.980 |

---

### **Training Data**
This checkpoint was trained on both [The-Stack-v2](https://huggingface.co/datasets/bigcode/the-stack-v2) and [Falcon-refinedweb](https://huggingface.co/datasets/tiiuae/falcon-refinedweb). Supported languages are: English, C, C#, Go, Java, JavaScript, TypeScript, PHP, Python, and Ruby.

---

### **Training Procedure**
This checkpoint was trained using the following procedure:
1. **MLM Pretraining**: Masked language modeling on code data.
2. **Contrastive Pre-Finetuning**: Using large-scale positive pairs mined from web and GitHub data.
3. **Contrastive Fine-Tuning**: Using a small amount of synthetic data.

---

### **How to Use**
This checkpoint consists of an 850M-parameter encoder that extracts code embeddings of 1536 dimensions. It can be loaded using the Hugging Face Transformers library and employs the [Starcoder Tokenizer](https://arxiv.org/pdf/2305.06161.pdf).

#### Pre-requisite
Please install OpenAI tiktoken for the tokenizer.

```
pip install "tiktoken>=0.4.0"
```

```python
from transformers import AutoModel, AutoTokenizer

# Specify the checkpoint
checkpoint = "SageLite/SageLite-l"
device = "cuda"  # Use "cpu" if GPU is unavailable

# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True, add_eos_token=True)
model = AutoModel.from_pretrained(checkpoint, trust_remote_code=True).to(device)

# Example usage
code_snippet = "def print_hello_world():\tprint('Hello World!')"
inputs = tokenizer.encode(code_snippet, return_tensors="pt").to(device)
embedding = model(inputs)[0]  # token-level embeddings (last hidden states)
```
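If a single fixed-size vector per snippet is needed, the following is a hedged continuation of the README snippet above: it mean-pools over non-padding tokens and L2-normalizes, mirroring the `1_Pooling` (mean) and `2_Normalize` modules shipped in this commit. The batched call and variable names are illustrative, not part of the original card.

```python
import torch
import torch.nn.functional as F

# Illustrative batch; reuses `tokenizer`, `model`, and `device` from the README snippet above.
snippets = [
    "def print_hello_world():\tprint('Hello World!')",
    "def add(a, b):\n    return a + b",
]
batch = tokenizer(snippets, return_tensors="pt", padding=True).to(device)

with torch.no_grad():
    token_states = model(batch["input_ids"], attention_mask=batch["attention_mask"])[0]

# Mean-pool over real tokens only, then L2-normalize (as the Pooling/Normalize modules would).
mask = batch["attention_mask"].unsqueeze(-1).float()
sentence_embeddings = (token_states * mask).sum(dim=1) / mask.sum(dim=1)
sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
print(sentence_embeddings.shape)  # (2, 1536)
```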
config.json
ADDED
@@ -0,0 +1,26 @@
{
  "_name_or_path": "SageLite/SageLite-l",
  "architectures": [
    "SageLite"
  ],
  "auto_map": {
    "AutoConfig": "config_sagelite.SageLiteConfig",
    "AutoTokenizer": "tokenization_sagelite.SageLiteTokenizer",
    "AutoModel": "modeling_sagelite.SageLiteModel",
    "AutoModelForMaskedLM": "modeling_sagelite.SageLiteForMaskedLM",
    "AutoModelForSequenceClassification": "modeling_sagelite.SageLiteForSequenceClassification"
  },
  "activation_function": "gelu_new",
  "attention_dropout_prob": 0.1,
  "embedding_dropout_prob": 0.1,
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "hidden_size": 1536,
  "num_attention_heads": 12,
  "num_hidden_layers": 24,
  "intermediate_size": 6144,
  "max_position_embeddings": 2048,
  "residual_dropout_prob": 0.1,
  "vocab_size": 100318
}
config_codext.py
ADDED
@@ -0,0 +1,51 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.

from transformers.configuration_utils import PretrainedConfig

CODESAGE_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "SageLite/SageLite-s": "https://huggingface.co/SageLite/SageLite-s/resolve/main/config.json",
    "SageLite/SageLite-l": "https://huggingface.co/SageLite/SageLite-l/resolve/main/config.json",
}


class SageLiteConfig(PretrainedConfig):
    model_type = "SageLite"

    def __init__(
        self,
        vocab_size=100318,
        max_position_embeddings=2048,
        hidden_size=1536,
        num_hidden_layers=24,
        num_attention_heads=12,
        intermediate_size=6144,
        activation_function="gelu_new",
        residual_dropout_prob=0.1,
        embedding_dropout_prob=0.1,
        attention_dropout_prob=0.1,
        layer_norm_epsilon=1e-5,
        initializer_range=0.02,
        position_embedding_type='absolute',
        bos_token_id=100257,
        eos_token_id=100257,
        pad_token_id=100317,
        **kwargs
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        assert 'gelu' in activation_function
        self.activation_function = activation_function
        self.residual_dropout_prob = residual_dropout_prob
        self.embedding_dropout_prob = embedding_dropout_prob
        self.attention_dropout_prob = attention_dropout_prob
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range
        self.position_embedding_type = position_embedding_type

        super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
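As a quick sanity check of how this configuration class maps onto the committed config.json, a minimal sketch; the hub repo id and the resolution of the custom class through `auto_map` with `trust_remote_code` are assumptions carried over from the README.

```python
from transformers import AutoConfig

# Assumes the repo id from the README and that auto_map resolves to SageLiteConfig.
config = AutoConfig.from_pretrained("SageLite/SageLite-l", trust_remote_code=True)

# 1536 hidden units over 12 heads gives a 128-dimensional head,
# satisfying the divisibility check in the attention module.
print(config.hidden_size, config.num_attention_heads,
      config.hidden_size // config.num_attention_heads)        # 1536 12 128
print(config.vocab_size, config.max_position_embeddings)       # 100318 2048
```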
modeling_codext.py
ADDED
@@ -0,0 +1,425 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.

import math
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss
from transformers.activations import ACT2FN
from transformers.modeling_utils import Conv1D, PreTrainedModel
from transformers.utils import logging
from .config_sagelite import SageLiteConfig
from transformers.modeling_outputs import (
    BaseModelOutputWithPooling,
    MaskedLMOutput,
    SequenceClassifierOutput
)

logger = logging.get_logger(__name__)

SAGELITE_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "SageLite/SageLite-s",
    "SageLite/SageLite-l",
    # See all SageLite models at https://huggingface.co/models?filter=SageLite
]


class SageLiteAttention(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = config.hidden_size // self.num_heads
        if self.head_dim * self.num_heads != config.hidden_size:
            raise ValueError(
                f"`hidden_size` must be divisible by num_heads "
                f"(got `hidden_size`: {config.hidden_size} and `num_heads`: {self.num_heads})."
            )

        self.c_attn = Conv1D(3 * self.hidden_size, self.hidden_size)
        self.c_proj = Conv1D(self.hidden_size, self.hidden_size)

        self.attention_dropout = nn.Dropout(config.attention_dropout_prob)
        self.residual_dropout = nn.Dropout(config.residual_dropout_prob)

    def attn(self, query, key, value, attention_mask=None, head_mask=None):
        attn_weights = torch.matmul(query, key.transpose(-1, -2))
        attn_weights = attn_weights / math.sqrt(self.head_dim)
        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask

        attn_weights = nn.Softmax(dim=-1)(attn_weights)
        attn_weights = self.attention_dropout(attn_weights)
        if head_mask is not None:
            attn_weights = attn_weights * head_mask

        attn_output = torch.matmul(attn_weights, value)
        return attn_output, attn_weights

    def split_heads(self, tensor, num_heads, attn_head_size):
        """
        Splits hidden_size dim into attn_head_size and num_heads
        """
        new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
        tensor = tensor.view(*new_shape)
        return tensor.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)

    def merge_heads(self, tensor, num_heads, attn_head_size):
        """
        Merges attn_head_size dim and num_attn_heads dim into hidden_size
        """
        tensor = tensor.permute(0, 2, 1, 3).contiguous()
        new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
        return tensor.view(new_shape)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        output_attentions=False,
    ):
        query, key, value = self.c_attn(hidden_states).split(self.hidden_size, dim=2)
        query = self.split_heads(query, self.num_heads, self.head_dim)
        key = self.split_heads(key, self.num_heads, self.head_dim)
        value = self.split_heads(value, self.num_heads, self.head_dim)

        attn_output, attn_weights = self.attn(query, key, value, attention_mask, head_mask)

        attn_output = self.merge_heads(attn_output, self.num_heads, self.head_dim)
        attn_output = self.c_proj(attn_output)
        attn_output = self.residual_dropout(attn_output)

        outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
        return outputs  # attn_output, (attn_weights)


class SageLiteMLP(nn.Module):
    def __init__(self, intermediate_size, config):
        super().__init__()

        self.c_fc = Conv1D(intermediate_size, config.hidden_size)
        self.act = ACT2FN[config.activation_function]
        self.c_proj = Conv1D(config.hidden_size, intermediate_size)
        self.dropout = nn.Dropout(config.residual_dropout_prob)

    def forward(self, hidden_states):
        hidden_states = self.c_fc(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.c_proj(hidden_states)
        hidden_states = self.dropout(hidden_states)
        return hidden_states


class SageLiteBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        hidden_size = config.hidden_size
        inner_dim = config.intermediate_size if config.intermediate_size is not None else 4 * hidden_size
        self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.attn = SageLiteAttention(config)
        self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.mlp = SageLiteMLP(inner_dim, config)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        output_attentions=False,
    ):
        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)
        attn_outputs = self.attn(
            hidden_states,
            attention_mask=attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions
        )
        attn_output = attn_outputs[0]
        outputs = attn_outputs[1:]  # (attn_weights,) when output_attentions else ()
        hidden_states = attn_output + residual

        residual = hidden_states
        hidden_states = self.ln_2(hidden_states)
        feed_forward_hidden_states = self.mlp(hidden_states)
        hidden_states = residual + feed_forward_hidden_states

        # Keep the attention weights (when requested) behind the hidden states.
        outputs = (hidden_states,) + outputs
        return outputs  # hidden_states, (attn_weights)


class SageLitePreTrainedModel(PreTrainedModel):
    config_class = SageLiteConfig
    base_model_prefix = "transformer"

    def _init_weights(self, module):
        """Initialize the weights."""
        if isinstance(module, (nn.Linear, Conv1D)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


class SageLiteModel(SageLitePreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.wte = nn.Embedding(config.vocab_size, config.hidden_size)
        self.wpe = nn.Embedding(config.max_position_embeddings, config.hidden_size)

        self.drop = nn.Dropout(config.embedding_dropout_prob)
        self.h = nn.ModuleList([SageLiteBlock(config) for _ in range(config.num_hidden_layers)])
        self.ln_f = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)

        self.init_weights()

    def get_input_embeddings(self):
        return self.wte

    def set_input_embeddings(self, new_embeddings: torch.Tensor):
        self.wte = new_embeddings

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if input_ids is not None:
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = input_ids.device if input_ids is not None else inputs_embeds.device
        if position_ids is None:
            position_ids = torch.arange(input_shape[-1], dtype=torch.long, device=device)
            position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
        else:
            position_ids = position_ids.view(-1, input_shape[-1])

        extended_attention_mask = None
        if attention_mask is not None:
            assert attention_mask.dim() == 2
            extended_attention_mask = attention_mask[:, None, None, :]
            extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
            extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
        if inputs_embeds is None:
            inputs_embeds = self.wte(input_ids)

        position_embeds = self.wpe(position_ids)
        hidden_states = inputs_embeds + position_embeds

        hidden_states = self.drop(hidden_states)
        output_shape = input_shape + (hidden_states.size(-1),)

        all_self_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None
        for i, block in enumerate(self.h):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            outputs = block(
                hidden_states,
                attention_mask=extended_attention_mask,
                head_mask=head_mask[i],
                output_attentions=output_attentions,
            )

            hidden_states = outputs[0]
            if output_attentions:
                all_self_attentions = all_self_attentions + (outputs[1],)

        hidden_states = self.ln_f(hidden_states)
        hidden_states = hidden_states.view(*output_shape)
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        pooled_output = None  # mean-pooled output over non-padding tokens
        if attention_mask is not None:
            pooled_output = (hidden_states * attention_mask[:, :, None]).sum(1) / attention_mask.sum(1)[:, None]

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, pooled_output, all_hidden_states, all_self_attentions]
                if v is not None
            )

        return BaseModelOutputWithPooling(
            last_hidden_state=hidden_states,
            pooler_output=pooled_output,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions
        )


class SageLiteForMaskedLM(SageLitePreTrainedModel):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.transformer = SageLiteModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.init_weights()

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict
        )
        hidden_states = transformer_outputs[0]
        lm_logits = self.lm_head(hidden_states)

        masked_lm_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            masked_lm_loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))

        if not return_dict:
            output = (lm_logits,) + transformer_outputs[1:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=lm_logits,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )


class SageLiteForSequenceClassification(SageLitePreTrainedModel):

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        self.transformer = SageLiteModel(config)
        classifier_dropout = (
            config.classifier_dropout
            if hasattr(config, 'classifier_dropout') and config.classifier_dropout is not None
            else config.residual_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        assert attention_mask is not None, "attention_mask is needed to perform mean-pooling"

        outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = outputs[1]
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
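To make the pooling path in the model's forward pass concrete: when an attention_mask is supplied, the model also returns a mean-pooled sentence embedding as pooler_output. A minimal sketch; the repo id and remote-code loading are assumptions carried over from the README.

```python
import torch
from transformers import AutoModel, AutoTokenizer

checkpoint = "SageLite/SageLite-l"  # assumed repo id from the README
tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True, add_eos_token=True)
model = AutoModel.from_pretrained(checkpoint, trust_remote_code=True).eval()

batch = tokenizer(["def add(a, b): return a + b"], return_tensors="pt", padding=True)
with torch.no_grad():
    out = model(batch["input_ids"], attention_mask=batch["attention_mask"])

print(out.last_hidden_state.shape)  # (1, seq_len, 1536) token-level states
print(out.pooler_output.shape)      # (1, 1536) masked mean over tokens
```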
modules.json
ADDED
@@ -0,0 +1,20 @@
[
  {
    "idx": 0,
    "name": "0",
    "path": "",
    "type": "sentence_transformers.models.Transformer"
  },
  {
    "idx": 1,
    "name": "1",
    "path": "1_Pooling",
    "type": "sentence_transformers.models.Pooling"
  },
  {
    "idx": 2,
    "name": "2",
    "path": "2_Normalize",
    "type": "sentence_transformers.models.Normalize"
  }
]
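Because modules.json declares a Transformer → Pooling → Normalize pipeline (with sentence_bert_config.json capping the sequence length at 1024), the repository is also laid out for sentence-transformers. A hedged sketch, assuming a sentence-transformers version that accepts `trust_remote_code` and the repo id from the README:

```python
from sentence_transformers import SentenceTransformer

# trust_remote_code is needed because the underlying transformer and tokenizer use custom code.
model = SentenceTransformer("SageLite/SageLite-l", trust_remote_code=True)

embeddings = model.encode([
    "def print_hello_world():\tprint('Hello World!')",
    "how to print a greeting in python",
])
print(embeddings.shape)  # (2, 1536); vectors are L2-normalized by the Normalize module
```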
pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:66d18ff9e119be7ed2c0f39d041e4ff06744b371a88b33032f070dc0d06f0ed9
size 1674472633
sentence_bert_config.json
ADDED
@@ -0,0 +1,4 @@
{
  "max_seq_length": 1024,
  "do_lower_case": false
}
|
tokenization_codext.py
ADDED
@@ -0,0 +1,337 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from functools import lru_cache
|
2 |
+
from typing import Any, Dict, List, Optional, Tuple
|
3 |
+
|
4 |
+
from transformers import PreTrainedTokenizer
|
5 |
+
import tiktoken
|
6 |
+
|
7 |
+
|
8 |
+
# Taken from
|
9 |
+
# https://github.com/huggingface/transformers/blob/8aca43bdb3cb9a5020f6d57589d85679dc873b1c/src/transformers/models/gpt2/tokenization_gpt2.py#L62-L84
|
10 |
+
@lru_cache()
|
11 |
+
def bytes_to_unicode():
|
12 |
+
"""Returns list of utf-8 byte and a mapping to unicode strings.
|
13 |
+
We specifically avoids mapping to whitespace/control characters the bpe code
|
14 |
+
barfs on.
|
15 |
+
The reversible bpe codes work on unicode strings. This means you need a
|
16 |
+
large # of unicode characters in your vocab if you want to avoid UNKs. When
|
17 |
+
you're at something like a 10B token dataset you end up needing around 5K
|
18 |
+
for decent coverage. This is a significant percentage of your normal, say,
|
19 |
+
32K bpe vocab. To avoid that, we want lookup tables between utf-8 bytes and
|
20 |
+
unicode strings.
|
21 |
+
"""
|
22 |
+
bs = (list(range(ord('!'),
|
23 |
+
ord('~') + 1)) + list(range(ord('¡'),
|
24 |
+
ord('¬') + 1)) +
|
25 |
+
list(range(ord('®'),
|
26 |
+
ord('ÿ') + 1)))
|
27 |
+
cs = bs[:]
|
28 |
+
n = 0
|
29 |
+
for b in range(2**8):
|
30 |
+
if b not in bs:
|
31 |
+
bs.append(b)
|
32 |
+
cs.append(2**8 + n)
|
33 |
+
n += 1
|
34 |
+
cs = [chr(n) for n in cs]
|
35 |
+
return dict(zip(bs, cs))
|
36 |
+
|
37 |
+
|
38 |
+
def add_special_tokens_to_tiktoken(base="cl100k_base", eos_token=None, pad_token=None):
|
39 |
+
def include_dobf_tokens():
|
40 |
+
dobf_tokens = [f"<dobf_special_{i}>" for i in range(18)]
|
41 |
+
return dobf_tokens
|
42 |
+
|
43 |
+
def include_vector_tokens():
|
44 |
+
tokens = []
|
45 |
+
tokens.append("<sep>")
|
46 |
+
tokens.append("<mask>")
|
47 |
+
tokens += [f"<dummy_{i}>" for i in reversed(range(20))]
|
48 |
+
return tokens
|
49 |
+
|
50 |
+
dobf_tokens = include_dobf_tokens()
|
51 |
+
vector_tokens = include_vector_tokens()
|
52 |
+
|
53 |
+
tokenizer = tiktoken.get_encoding(base)
|
54 |
+
idx = tokenizer.n_vocab
|
55 |
+
bpe_ranks = tokenizer._mergeable_ranks
|
56 |
+
special_tokens = dict()
|
57 |
+
|
58 |
+
# print(f"INIT TOKEN SIZE: {idx}, EOS TOKEN: {tokenizer._special_tokens[eos_token]}")
|
59 |
+
if eos_token and eos_token not in tokenizer._special_tokens and eos_token not in special_tokens:
|
60 |
+
special_tokens[eos_token] = idx
|
61 |
+
idx += 1
|
62 |
+
|
63 |
+
for sp in dobf_tokens:
|
64 |
+
special_tokens[sp] = idx
|
65 |
+
idx += 1
|
66 |
+
for sp in vector_tokens:
|
67 |
+
special_tokens[sp] = idx
|
68 |
+
idx += 1
|
69 |
+
|
70 |
+
if pad_token and pad_token not in tokenizer._special_tokens and pad_token not in special_tokens:
|
71 |
+
special_tokens[pad_token] = idx
|
72 |
+
idx += 1
|
73 |
+
# print(f"PAD TOKEN ADDED: {pad_token}")
|
74 |
+
# In production, load the arguments directly instead of accessing private attributes
|
75 |
+
# See openai_public.py for examples of arguments for specific encodings
|
76 |
+
enc = tiktoken.Encoding(
|
77 |
+
# If you're changing the set of special tokens, make sure to use a different name
|
78 |
+
# It should be clear from the name what behaviour to expect.
|
79 |
+
name=base.replace("base", "im"),
|
80 |
+
pat_str=tokenizer._pat_str,
|
81 |
+
mergeable_ranks=bpe_ranks,
|
82 |
+
special_tokens={
|
83 |
+
**tokenizer._special_tokens,
|
84 |
+
**special_tokens
|
85 |
+
}
|
86 |
+
)
|
87 |
+
return enc
|
88 |
+
|
89 |
+
|
90 |
+
class SageLiteTokenizer(PreTrainedTokenizer):
|
91 |
+
"""A thin wrapper around tiktoken to make it compatible with Hugging Face.
|
92 |
+
tokenizers.
|
93 |
+
See HuggingFace for further documentation on general tokenizer methods.
|
94 |
+
"""
|
95 |
+
|
96 |
+
model_input_names = ['input_ids', 'attention_mask']
|
97 |
+
|
98 |
+
def __init__(self,
|
99 |
+
model_name: Optional[str] = None,
|
100 |
+
encoding_name: Optional[str] = "cl100k_base",
|
101 |
+
add_bos_token: bool = False,
|
102 |
+
add_eos_token: bool = False,
|
103 |
+
unk_token: Optional[str] = '<|endoftext|>',
|
104 |
+
eos_token: Optional[str] = '<|endoftext|>',
|
105 |
+
bos_token: Optional[str] = '<|endoftext|>',
|
106 |
+
pad_token: Optional[str] = '<pad>',
|
107 |
+
errors: str = 'replace',
|
108 |
+
**kwargs: Any):
|
109 |
+
"""Constructor creates a tiktoken tokenizer to use as the underlying.
|
110 |
+
tokenizer.
|
111 |
+
Args:
|
112 |
+
model_name (Optional[str], optional): The name of the model to load from tiktoken. Defaults to None.
|
113 |
+
Either model_name or encoding_name must be set, but not both.
|
114 |
+
encoding_name (Optional[str], optional): The name of the encoding to load from tiktoken. Defaults to None.
|
115 |
+
Either model_name or encoding_name must be set, but not both.
|
116 |
+
add_bos_token (bool, optional): Whether to add bos tokens. Defaults to False.
|
117 |
+
add_eos_token (bool, optional): Whether to add eos tokens. Defaults to False.
|
118 |
+
use_default_system_prompt (bool, optional): Use the default system prompt or not. Defaults to False.
|
119 |
+
unk_token (Optional[str], optional): The unk token. Defaults to '<|endoftext|>'.
|
120 |
+
eos_token (Optional[str], optional): The eos token. Defaults to '<|endoftext|>'.
|
121 |
+
bos_token (Optional[str], optional): The bos token. Defaults to '<|endoftext|>'.
|
122 |
+
pad_token (Optional[str], optional): The pad token. Defaults to None.
|
123 |
+
errors (str, optional): Paradigm to follow when decoding bytes to UTF-8. See
|
124 |
+
[bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
|
125 |
+
Defaults to `"replace"`.
|
126 |
+
"""
|
127 |
+
try:
|
128 |
+
import tiktoken
|
129 |
+
except:
|
130 |
+
raise ImportError(
|
131 |
+
'You need to install tiktoken to use TiktokenTokenizerWrapper.')
|
132 |
+
|
133 |
+
# Workaround to make tiktokenizer picklable.
|
134 |
+
# https://github.com/huggingface/datasets/issues/5536#issuecomment-1682309347
|
135 |
+
# There is an open PR from HF to add this to tiktoken: https://github.com/openai/tiktoken/pull/181
|
136 |
+
import copyreg
|
137 |
+
import functools
|
138 |
+
|
139 |
+
from tiktoken import Encoding # type: ignore (thirdParty)
|
140 |
+
|
141 |
+
def pickle_Encoding(enc: Encoding):
|
142 |
+
return (functools.partial(Encoding,
|
143 |
+
enc.name,
|
144 |
+
pat_str=enc._pat_str,
|
145 |
+
mergeable_ranks=enc._mergeable_ranks,
|
146 |
+
special_tokens=enc._special_tokens), ())
|
147 |
+
|
148 |
+
copyreg.pickle(Encoding, pickle_Encoding)
|
149 |
+
|
150 |
+
|
151 |
+
self.encoding = add_special_tokens_to_tiktoken(base=encoding_name, eos_token=eos_token, pad_token=pad_token)
|
152 |
+
|
153 |
+
self.add_bos_token = add_bos_token
|
154 |
+
self.add_eos_token = add_eos_token
|
155 |
+
|
156 |
+
self.byte_encoder = bytes_to_unicode()
|
157 |
+
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
|
158 |
+
self.errors = errors
|
159 |
+
|
160 |
+
self.decoder: Dict[int, str] = {}
|
161 |
+
for i in range(self.encoding.n_vocab):
|
162 |
+
try:
|
163 |
+
self.encoding.decode_single_token_bytes(i)
|
164 |
+
except KeyError:
|
165 |
+
continue
|
166 |
+
# Taken from
|
167 |
+
# https://gist.github.com/xenova/a452a6474428de0182b17605a98631ee
|
168 |
+
decoding = ''.join([
|
169 |
+
bytes_to_unicode()[ord(char)] for char in
|
170 |
+
self.encoding.decode_single_token_bytes(i).decode('latin-1')
|
171 |
+
])
|
172 |
+
self.decoder[i] = decoding
|
173 |
+
|
174 |
+
self.encoder: Dict[str, int] = {}
|
175 |
+
for i in range(self.encoding.n_vocab):
|
176 |
+
if i in self.decoder:
|
177 |
+
self.encoder[self.decoder[i]] = i
|
178 |
+
|
179 |
+
super().__init__(model_name=model_name,
|
180 |
+
encoding_name=encoding_name,
|
181 |
+
add_bos_token=add_bos_token,
|
182 |
+
add_eos_token=add_eos_token,
|
183 |
+
unk_token=unk_token,
|
184 |
+
eos_token=eos_token,
|
185 |
+
bos_token=bos_token,
|
186 |
+
pad_token=pad_token,
|
187 |
+
errors=errors,
|
188 |
+
**kwargs)
|
189 |
+
|
190 |
+
@property
|
191 |
+
def vocab_size(self) -> int:
|
192 |
+
"""Returns vocab size."""
|
193 |
+
return self.encoding.n_vocab
|
194 |
+
|
195 |
+
@property
|
196 |
+
def is_fast(self) -> bool:
|
197 |
+
return False
|
198 |
+
|
199 |
+
def get_vocab(self) -> Dict[str, int]:
|
200 |
+
"""Returns vocab as a dict."""
|
201 |
+
# As far as I can tell, we don't require get_vocab to completely work,
|
202 |
+
# but when using additional_special_tokens, Hugging Face determines the next
|
203 |
+
# token index to add with len(self.get_vocab()) so we need the _size_ of this dictionary to be correct.
|
204 |
+
vocab_clone = self.encoder.copy()
|
205 |
+
extra_id_index = 0
|
206 |
+
candidate_extra_id = f'<extra_id_{extra_id_index}>'
|
207 |
+
indices_to_fill_in = {i for i in range(self.vocab_size)} - set(
|
208 |
+
vocab_clone.values())
|
209 |
+
|
210 |
+
# Add enough indices to make get_vocab() the right length
|
211 |
+
for index_to_add in indices_to_fill_in:
|
212 |
+
# Make sure we don't overwrite a token that already exists
|
213 |
+
while candidate_extra_id in vocab_clone:
|
214 |
+
extra_id_index += 1
|
215 |
+
candidate_extra_id = f'<extra_id_{extra_id_index}>'
|
216 |
+
|
217 |
+
# Get an index to add and add the item
|
218 |
+
vocab_clone[candidate_extra_id] = index_to_add
|
219 |
+
|
220 |
+
return vocab_clone
|
221 |
+
|
222 |
+
def _tokenize(self, text: str) -> List[str]:
|
223 |
+
"""Returns a tokenized string."""
|
224 |
+
if not isinstance(text, str):
|
225 |
+
raise ValueError(
|
226 |
+
f'Expected a string input to _tokenize but got {type(text)}.')
|
227 |
+
|
228 |
+
tokens = [
|
229 |
+
self.decoder[t]
|
230 |
+
for t in self.encoding.encode(text, allowed_special='all')
|
231 |
+
]
|
232 |
+
|
233 |
+
return tokens
|
234 |
+
|
235 |
+
def _convert_token_to_id(self, token: str) -> Optional[int]:
|
236 |
+
"""Converts a token (str) in an id using the vocab."""
|
237 |
+
return self.encoder.get(token, self.encoder.get(self.unk_token))
|
238 |
+
|
239 |
+
def _convert_id_to_token(self, index: int) -> Optional[str]:
|
240 |
+
"""Converts an index (integer) in a token (str) using the vocab."""
|
241 |
+
# For tokens in either the gap in ids in the tokenizer, or beyond the range of the tokenizer,
|
242 |
+
# we return empty string. This matches the behavior of Hugging Face fast tokenizers,
|
243 |
+
# but not slow tokenizers.
|
244 |
+
return self.decoder.get(index, '')
|
245 |
+
|
246 |
+
def convert_tokens_to_string(self, tokens: List[str]) -> str:
|
247 |
+
"""Converts a sequence of tokens (string) in a single string."""
|
248 |
+
text = ''.join(tokens)
|
249 |
+
text = bytearray([self.byte_decoder[c] for c in text
|
250 |
+
]).decode('utf-8', errors=self.errors)
|
251 |
+
return text
|
252 |
+
|
253 |
+
def build_inputs_with_special_tokens(
|
254 |
+
self,
|
255 |
+
token_ids_0: List[int],
|
256 |
+
token_ids_1: Optional[List[int]] = None) -> List[int]:
|
257 |
+
bos_token_id = [self.bos_token_id] if self.add_bos_token else []
|
258 |
+
eos_token_id = [self.eos_token_id] if self.add_eos_token else []
|
259 |
+
|
260 |
+
output = bos_token_id + token_ids_0 + eos_token_id
|
261 |
+
|
262 |
+
if token_ids_1 is not None:
|
263 |
+
output = output + bos_token_id + token_ids_1 + eos_token_id
|
264 |
+
|
265 |
+
return output
|
266 |
+
|
267 |
+
def get_special_tokens_mask(
|
268 |
+
self,
|
269 |
+
token_ids_0: List[int],
|
270 |
+
token_ids_1: Optional[List[int]] = None,
|
271 |
+
already_has_special_tokens: bool = False) -> List[int]:
|
272 |
+
"""Retrieves sequence ids from a token list that has no special tokens.
|
273 |
+
Function copied from
|
274 |
+
https://github.com/huggingface/transformers/blob/e3a4bd2bee212a2d0fd9f03b27fe7bfc1debe42d/src/transformers/models/gpt2/tokenization_gpt2.py#L265-L295
|
275 |
+
added. This method is called when adding special tokens using the
|
276 |
+
tokenizer `prepare_for_model` or `encode_plus` methods.
|
277 |
+
Args:
|
278 |
+
token_ids_0 (`List[int]`):
|
279 |
+
List of IDs.
|
280 |
+
token_ids_1 (`List[int]`, *optional*):
|
281 |
+
Optional second list of IDs for sequence pairs.
|
282 |
+
already_has_special_tokens (`bool`, *optional*, defaults to `False`):
|
283 |
+
Whether or not the token list is already formatted with special tokens for the model.
|
284 |
+
Returns:
|
285 |
+
`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
|
286 |
+
"""
|
287 |
+
if already_has_special_tokens:
|
288 |
+
return super().get_special_tokens_mask(
|
289 |
+
token_ids_0=token_ids_0,
|
290 |
+
token_ids_1=token_ids_1,
|
291 |
+
already_has_special_tokens=True)
|
292 |
+
|
293 |
+
bos_token_id = [1] if self.add_bos_token else []
|
294 |
+
eos_token_id = [1] if self.add_eos_token else []
|
295 |
+
|
296 |
+
if token_ids_1 is None:
|
297 |
+
return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id
|
298 |
+
return (bos_token_id + ([0] * len(token_ids_0)) + eos_token_id +
|
299 |
+
bos_token_id + ([0] * len(token_ids_1)) + eos_token_id)
|
300 |
+
|
301 |
+
def create_token_type_ids_from_sequences(
|
302 |
+
self,
|
303 |
+
token_ids_0: List[int],
|
304 |
+
token_ids_1: Optional[List[int]] = None) -> List[int]:
|
305 |
+
sep = [self.sep_token_id]
|
306 |
+
|
307 |
+
if token_ids_1 is None:
|
308 |
+
return len(token_ids_0 + sep) * [0]
|
309 |
+
return len(token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
|
310 |
+
|
311 |
+
def save_vocabulary(self,
|
312 |
+
save_directory: str,
|
313 |
+
filename_prefix: Optional[str] = None) -> Tuple[str]:
|
314 |
+
|
315 |
+
# ignore the below type to keep the original signature
|
316 |
+
# we are knowingly breaking the signature here, although not 100% certain
|
317 |
+
# it doesn't have side effects
|
318 |
+
# There is some code in huggingface that calls this function to get the vocab files,
|
319 |
+
# but it doesn't seem to access them (or at least checks for their existence
|
320 |
+
# before accessing them)
|
321 |
+
return (None, None) # type: ignore
|
322 |
+
|
323 |
+
def sanitize_special_tokens(self) -> int:
|
324 |
+
"""Make sure that all the special tokens attributes of the tokenizer.
|
325 |
+
(`tokenizer.mask_token`, `tokenizer.cls_token`, etc.) are in the
|
326 |
+
vocabulary.
|
327 |
+
Add the missing ones to the vocabulary if needed.
|
328 |
+
Return:
|
329 |
+
`int`: The number of tokens added in the vocabulary during the operation.
|
330 |
+
"""
|
331 |
+
actual_new_tokens = []
|
332 |
+
for token in self.all_special_tokens_extended:
|
333 |
+
encoded = self.encoding.encode(token, allowed_special='all')
|
334 |
+
if len(encoded) > 1:
|
335 |
+
actual_new_tokens.append(token)
|
336 |
+
|
337 |
+
return self.add_tokens(actual_new_tokens, special_tokens=True)
|
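A short sketch of how the tokenizer above behaves once loaded through `auto_map` (the repo id is assumed from the README): it shows the EOS token appended by `add_eos_token=True` and the `<pad>` token that `add_special_tokens_to_tiktoken` registers after the base cl100k_base vocabulary.

```python
from transformers import AutoTokenizer

# Assumes the repo id from the README and that auto_map resolves to SageLiteTokenizer.
tok = AutoTokenizer.from_pretrained("SageLite/SageLite-l", trust_remote_code=True, add_eos_token=True)

ids = tok.encode("def add(a, b): return a + b")
print(ids[-1] == tok.eos_token_id)       # True: <|endoftext|> is appended at the end
print(tok.pad_token, tok.pad_token_id)   # '<pad>' sits past the base cl100k_base vocab
print(tok.decode(ids, skip_special_tokens=True))
```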
tokenizer_config.json
ADDED
@@ -0,0 +1,12 @@
{
  "add_eos_token": true,
  "add_special_tokens": true,
  "clean_up_tokenization_spaces": true,
  "eos_token": "<|endoftext|>",
  "model_max_length": 1000000000000000019884624838656,
  "pad_token": "<pad>",
  "tokenizer_class": "SageLiteTokenizer",
  "auto_map": {
    "AutoTokenizer": ["tokenization_sagelite.SageLiteTokenizer", null]
  }
}