twright8 committed
Commit 05acb8c · verified · 1 Parent(s): 001f6d0

Push model using huggingface_hub.

1_Pooling/config.json ADDED
@@ -0,0 +1,10 @@
+ {
+     "word_embedding_dimension": 1024,
+     "pooling_mode_cls_token": false,
+     "pooling_mode_mean_tokens": true,
+     "pooling_mode_max_tokens": false,
+     "pooling_mode_mean_sqrt_len_tokens": false,
+     "pooling_mode_weightedmean_tokens": false,
+     "pooling_mode_lasttoken": false,
+     "include_prompt": true
+ }
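The pooling configuration above selects mean pooling over token embeddings (CLS, max, weighted-mean, and last-token pooling are all disabled). As an illustration of what `pooling_mode_mean_tokens: true` amounts to, here is a minimal, hypothetical sketch of attention-mask-aware mean pooling; the function and variable names are illustrative and not taken from this repository:

```python
import torch

def mean_pool(token_embeddings: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Average the 1024-dim token embeddings, ignoring padding positions."""
    # attention_mask: (batch, seq_len) with 1 for real tokens and 0 for padding
    mask = attention_mask.unsqueeze(-1).to(token_embeddings.dtype)  # (batch, seq_len, 1)
    summed = (token_embeddings * mask).sum(dim=1)                   # (batch, 1024)
    counts = mask.sum(dim=1).clamp(min=1e-9)                        # avoid division by zero
    return summed / counts
```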
2_Dense/config.json ADDED
@@ -0,0 +1 @@
+ {"in_features": 1024, "out_features": 1024, "bias": true, "activation_function": "torch.nn.modules.linear.Identity"}
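The `2_Dense` module defined above is a 1024 → 1024 linear projection with a bias and an identity activation, applied to the pooled sentence embedding. The snippet below is a simplified stand-in (not the sentence-transformers `Dense` class itself) to show the shape of the computation:

```python
import torch
from torch import nn

# Simplified equivalent of the 2_Dense module: Linear(1024 -> 1024) with bias,
# followed by an identity activation (i.e. no non-linearity).
dense = nn.Sequential(nn.Linear(1024, 1024, bias=True), nn.Identity())

pooled = torch.randn(4, 1024)        # mean-pooled sentence embeddings for a batch of 4 texts
sentence_embeddings = dense(pooled)  # still (4, 1024)
```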
2_Dense/model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1c6266520f51386deef585fb6daf6948585419e685f3a887079b563cc4031010
+ size 4198560
README.md ADDED
@@ -0,0 +1,227 @@
1
+ ---
2
+ base_model: dunzhang/stella_en_400M_v5
3
+ library_name: setfit
4
+ metrics:
5
+ - accuracy
6
+ pipeline_tag: text-classification
7
+ tags:
8
+ - setfit
9
+ - sentence-transformers
10
+ - text-classification
11
+ - generated_from_setfit_trainer
12
+ widget:
13
+ - text: Record-Breaking Heatwave Grips Europe PARIS - Europe is sweltering under an
14
+ unprecedented heatwave, with temperatures soaring above 40degC (104degF) in multiple
15
+ countries. French authorities have issued red alerts for several regions, while
16
+ wildfires rage in Spain and Greece. Experts link the extreme weather to climate
17
+ change.
18
+ - text: Global Coffee Prices Surge Amid Brazilian Drought Coffee futures hit a five-year
19
+ high today as severe drought continues to ravage Brazil's coffee-growing regions.
20
+ Experts warn consumers may see significant price increases in coming months.
21
+ - text: Pharmaceutical Giant Accused of Bribing Doctors to Overprescribe NEW YORK
22
+ - In a shocking turn of events, pharmaceutical behemoth PharmaCore is facing allegations
23
+ of orchestrating a widespread bribery scheme to encourage doctors to overprescribe
24
+ its blockbuster painkiller, OxyContin Plus. An investigation by the DEA uncovered
25
+ evidence of lavish "consulting fees," all-expenses-paid vacations, and other kickbacks
26
+ provided to physicians who met certain prescription quotas. The scheme allegedly
27
+ resulted in thousands of unnecessary prescriptions, potentially fueling the ongoing
28
+ opioid crisis. PharmaCore's stock plummeted 30% following the news. CEO Miranda
29
+ Feltz issued a statement denying any wrongdoing and pledging full cooperation
30
+ with authorities.
31
+ - text: Mental Health Clinic Director Embezzled Millions, Patients Left Without Care
32
+ PORTLAND, OR - The director of New Horizons Mental Health Clinic, Dr. Sarah Jennings,
33
+ has been indicted on charges of embezzling over $3 million intended for patient
34
+ care and facility improvements. The funds were allegedly used to finance a lavish
35
+ lifestyle, including luxury cars and a vacation home in the Bahamas
36
+ - text: When doctors, nurses or health professionals siphon funds or medicines meant
37
+ for patient care, lives are at stake. It's not just about missing money--it's
38
+ about missing medications, outdated equipment, and overworked staff. As a physician,
39
+ I've witnessed firsthand the consequences of corruption in our healthcare system.
40
+ We must demand transparency and accountability at every level, from hospital boards
41
+ to government agencies. Our health depends on it.
42
+ inference: true
43
+ model-index:
44
+ - name: SetFit with dunzhang/stella_en_400M_v5
45
+ results:
46
+ - task:
47
+ type: text-classification
48
+ name: Text Classification
49
+ dataset:
50
+ name: Unknown
51
+ type: unknown
52
+ split: test
53
+ metrics:
54
+ - type: accuracy
55
+ value: 0.6666666666666666
56
+ name: Accuracy
57
+ ---
58
+
59
+ # SetFit with dunzhang/stella_en_400M_v5
60
+
61
+ This is a [SetFit](https://github.com/huggingface/setfit) model that can be used for text classification. It uses [dunzhang/stella_en_400M_v5](https://huggingface.co/dunzhang/stella_en_400M_v5) as the Sentence Transformer embedding model, with a [SetFitHead](https://huggingface.co/docs/setfit/reference/main#setfit.SetFitHead) instance as the classification head.
62
+
63
+ The model has been trained using an efficient few-shot learning technique that involves:
64
+
65
+ 1. Fine-tuning a [Sentence Transformer](https://www.sbert.net) with contrastive learning.
66
+ 2. Training a classification head with features from the fine-tuned Sentence Transformer.
67
+
68
+ ## Model Details
69
+
70
+ ### Model Description
71
+ - **Model Type:** SetFit
72
+ - **Sentence Transformer body:** [dunzhang/stella_en_400M_v5](https://huggingface.co/dunzhang/stella_en_400M_v5)
73
+ - **Classification head:** a [SetFitHead](https://huggingface.co/docs/setfit/reference/main#setfit.SetFitHead) instance
74
+ - **Maximum Sequence Length:** 512 tokens
75
+ - **Number of Classes:** 2 classes
76
+ <!-- - **Training Dataset:** [Unknown](https://huggingface.co/datasets/unknown) -->
77
+ <!-- - **Language:** Unknown -->
78
+ <!-- - **License:** Unknown -->
79
+
80
+ ### Model Sources
81
+
82
+ - **Repository:** [SetFit on GitHub](https://github.com/huggingface/setfit)
83
+ - **Paper:** [Efficient Few-Shot Learning Without Prompts](https://arxiv.org/abs/2209.11055)
84
+ - **Blogpost:** [SetFit: Efficient Few-Shot Learning Without Prompts](https://huggingface.co/blog/setfit)
85
+
86
+ ### Model Labels
87
+ | Label | Examples |
88
+ |:------|:----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
89
+ | 1 | <ul><li>'Lucknow: Deputy CM Brajesh Pathak recommends dismissal of 17 govt doctors for absenteeism LUCKNOW: State govt has recommended the dismissal of 17 medical officers after they were found absent from duty for several months. In addition to this, disciplinary action has been ordered against three medical officers.The order was issued by deputy CM Brajesh Pathak who also holds the charge of health and medical education departments, said a govt spokesman on Thursday. In his order, Pathak stated: "No doctor or health worker who is negligent in medical services will be forgiven." tnn \'Committed to high-level health services\'Strict action will be taken against them. The state is committed to providing high-level health services to the people and no laxity on the count will be tolerated," Pathak stated. Three doctors who will face disciplinary action are Dr Mukul Mishra, orthopedic specialist of District Hospital, Jhansi; Dr Madhavi Singh, ophthalmologist posted at Community Health Centre, Fatehpur, Barabanki and Dr Pramod Kumar Sharma under Chief Medical Officer, Bareilly.'</li><li>"Kerala model therapy: Govt gives 56 absentee doctors 'show-cause pill' Thiruvananthapuram: The state health and family welfare department has issued show-cause notice to 56 doctors who have been on unauthorised absence in various medical colleges and pharmacy colleges in Kerala. In the notice issued by Rajan Khobragade, additional chief secretary, health and family welfare department, the doctors have been directed to report for duty before the ACS at the secretariat within 15 days."</li><li>'42% of Nigerian Doctors, Nurse Demand Bribes Before Attending to Patients - NBS Reports The National Bureau of Statistics (NBS) recently published a report titled "NBS Corruption in Nigeria: Patterns and Trend" for 2023, revealing concerning statistics about corruption in the healthcare sector. According to the report, two-thirds of Nigerian doctors, nurses, and midwives demand bribes from patients before providing treatment. Additionally, 42 percent of these health workers accept bribes to expedite procedures, while 15 percent take bribes to ensure the completion of medical procedures. It, however, added that 11 per cent were paid bribes as a "sign of appreciation," which still reflects the purpose of gratification for the healthcare service they received. "As for doctors, nurses and midwives, 11 per cent of bribes were paid as a sign of appreciation, possibly reflecting gratitude for the care received," it stated. The report comes as Nigerians have continued to raise concerns over poor quality health services in the country. With these concerns, a shortage of health workers continues to plague the health system even as practitioners travel abroad to seek better welfare with the "japa syndrome." The NBS report, in collaboration with the United Nations Office on Drugs and Crimes (UNODC), also revealed how Nigerian public officials received nothing less than N721 billion as bribes in 2023'</li></ul> |
90
+ | 0 | <ul><li>'Malta\'s former prime minister charged with corruption over hospital scandal Malta\'s former prime minister Joseph Muscat has been charged with corruption in a hospital privatisation scandal that was once investigated by the murdered investigative journalist Daphne Caruana Galizia. Muscat has been charged with accepting bribes, corruption in public office and money laundering, according to documents seen by AFP. He has described the allegations as "fantasies and lies" and said he was the victim of a political vendetta. Chris Fearne, Malta\'s deputy prime minister, who is tipped to become Malta\'s next European commissioner, and the country\'s former finance minister Edward Scicluna, who is now the governor of Malta\'s central bank, were charged with fraud, misappropriation and fraudulent gain.'</li><li>"US Supreme Court gives pharma companies a chance to thwart terrorism-funding lawsuit 21 pharmaceutical and medical equipment companies, including AstraZeneca, Pfizer, GE Healthcare USA, Johnson & Johnson, and F. Hoffmann-La Roche, are accused of illegally helping to fund terrorism in Iraq by providing corrupt payments to the Hezbollah-sponsored militia group Jaysh al-Mahdi to obtain medical supply contracts from Iraq's health ministry. The lawsuit seeks unspecified damages under the Anti-Terrorism Act."</li><li>'Health Ministry Official Arrested in Procurement Scandal JAKARTA - Indonesian authorities have arrested a high-ranking Health Ministry official on suspicion of corruption in medical equipment procurement. Agus Sutiyo, 52, Director of Medical Supplies, is accused of accepting bribes totaling $1.2 million from suppliers in exchange for awarding inflated contracts. The Corruption Eradication Commission (KPK) alleges that Sutiyo manipulated tender processes, favoring companies that offered kickbacks. The scheme reportedly cost the government an estimated $10 million in overpayments. KPK spokesperson Febri Diansyah stated, "This case undermines public trust and diverts crucial resources from healthcare services." Sutiyo faces up to 20 years in prison if convicted.'</li></ul> |
91
+
92
+ ## Evaluation
93
+
94
+ ### Metrics
95
+ | Label | Accuracy |
96
+ |:--------|:---------|
97
+ | **all** | 0.6667 |
98
+
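The accuracy above was measured on a small held-out split. To run a comparable check on your own labelled articles, a minimal sketch (assuming the SetFit 1.x `predict` API and scikit-learn; the evaluation texts below are placeholders):

```python
from setfit import SetFitModel
from sklearn.metrics import accuracy_score

model = SetFitModel.from_pretrained("twright8/news_cats_2")

# Placeholder data: substitute your own held-out articles with 0/1 labels.
eval_texts = ["first held-out news article ...", "second held-out news article ..."]
eval_labels = [0, 1]

preds = model.predict(eval_texts)
print("accuracy:", accuracy_score(eval_labels, [int(p) for p in preds]))
```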
99
+ ## Uses
100
+
101
+ ### Direct Use for Inference
102
+
103
+ First install the SetFit library:
104
+
105
+ ```bash
106
+ pip install setfit
107
+ ```
108
+
109
+ Then you can load this model and run inference.
110
+
111
+ ```python
112
+ from setfit import SetFitModel
113
+
114
+ # Download from the 🤗 Hub
115
+ model = SetFitModel.from_pretrained("twright8/news_cats_2")
116
+ # Run inference
117
+ preds = model("Global Coffee Prices Surge Amid Brazilian Drought Coffee futures hit a five-year high today as severe drought continues to ravage Brazil's coffee-growing regions. Experts warn consumers may see significant price increases in coming months.")
118
+ ```
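Besides single-text inference as shown above, the trained `SetFitHead` can also return class probabilities. A small follow-up sketch, assuming the SetFit 1.x `predict` and `predict_proba` methods (the example texts are placeholders):

```python
from setfit import SetFitModel

model = SetFitModel.from_pretrained("twright8/news_cats_2")

texts = [
    "Hospital director accused of diverting funds meant for patient care.",
    "Coffee futures hit a five-year high amid drought in Brazil.",
]
preds = model.predict(texts)         # one hard label (0 or 1) per text
probs = model.predict_proba(texts)   # per-class probabilities from the differentiable head
print(preds)
print(probs)
```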
119
+
120
+ <!--
121
+ ### Downstream Use
122
+
123
+ *List how someone could finetune this model on their own dataset.*
124
+ -->
125
+
126
+ <!--
127
+ ### Out-of-Scope Use
128
+
129
+ *List how the model may foreseeably be misused and address what users ought not to do with the model.*
130
+ -->
131
+
132
+ <!--
133
+ ## Bias, Risks and Limitations
134
+
135
+ *What are the known or foreseeable issues stemming from this model? You could also flag here known failure cases or weaknesses of the model.*
136
+ -->
137
+
138
+ <!--
139
+ ### Recommendations
140
+
141
+ *What are recommendations with respect to the foreseeable issues? For example, filtering explicit content.*
142
+ -->
143
+
144
+ ## Training Details
145
+
146
+ ### Training Set Metrics
147
+ | Training set | Min | Median | Max |
148
+ |:-------------|:----|:---------|:----|
149
+ | Word count | 55 | 153.8462 | 290 |
150
+
151
+ | Label | Training Sample Count |
152
+ |:------|:----------------------|
153
+ | 0 | 13 |
154
+ | 1 | 13 |
155
+
156
+ ### Training Hyperparameters
157
+ - batch_size: (8, 1)
158
+ - num_epochs: (3, 17)
159
+ - max_steps: -1
160
+ - sampling_strategy: oversampling
161
+ - body_learning_rate: (9.629116538858926e-05, 2.651259436793277e-05)
162
+ - head_learning_rate: 0.02145586669240117
163
+ - loss: CoSENTLoss
164
+ - distance_metric: cosine_distance
165
+ - margin: 0.25
166
+ - end_to_end: True
167
+ - use_amp: True
168
+ - warmup_proportion: 0.1
169
+ - max_length: 512
170
+ - seed: 42
171
+ - eval_max_steps: -1
172
+ - load_best_model_at_end: True
173
+
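The hyperparameters above correspond to SetFit's `TrainingArguments` (tuples give the embedding-phase and classifier-phase values). The sketch below is an approximate reconstruction under the SetFit 1.x API, using a placeholder two-example dataset and omitting some of the listed options; it is not the exact training script used for this model:

```python
from datasets import Dataset
from sentence_transformers.losses import CoSENTLoss
from setfit import SetFitModel, Trainer, TrainingArguments

# Placeholder few-shot data (the original run used 13 labelled examples per class).
train_dataset = Dataset.from_dict({
    "text": ["doctors suspended for demanding bribes ...", "coffee prices surge after drought ..."],
    "label": [1, 0],
})

# Note: the stella base model ships custom modeling code, so trusting remote code
# may be required depending on your setfit / sentence-transformers versions.
model = SetFitModel.from_pretrained("dunzhang/stella_en_400M_v5")

args = TrainingArguments(
    batch_size=(8, 1),                        # (embedding batch size, classifier batch size)
    num_epochs=(3, 17),                       # (embedding epochs, classifier epochs)
    body_learning_rate=(9.63e-05, 2.65e-05),
    head_learning_rate=0.0215,
    loss=CoSENTLoss,
    sampling_strategy="oversampling",
    end_to_end=True,
    use_amp=True,
    warmup_proportion=0.1,
    max_length=512,
    seed=42,
)

trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
trainer.train()
```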
174
+ ### Training Results
175
+ | Epoch | Step | Training Loss | Validation Loss |
176
+ |:----------:|:-------:|:-------------:|:---------------:|
177
+ | 0.0217 | 1 | 1.6572 | - |
178
+ | 0.4348 | 20 | 0.0 | 21.1162 |
179
+ | 0.8696 | 40 | 0.0 | 17.9189 |
180
+ | 1.3043 | 60 | 0.0 | 14.0343 |
181
+ | 1.7391 | 80 | 0.0 | 13.6029 |
182
+ | 2.1739 | 100 | 0.0 | 13.8074 |
183
+ | **2.6087** | **120** | **0.0** | **13.1309** |
184
+
185
+ * The bold row denotes the saved checkpoint.
186
+ ### Framework Versions
187
+ - Python: 3.10.13
188
+ - SetFit: 1.0.3
189
+ - Sentence Transformers: 3.0.1
190
+ - Transformers: 4.39.0
191
+ - PyTorch: 2.3.0+cu121
192
+ - Datasets: 2.20.0
193
+ - Tokenizers: 0.15.2
194
+
195
+ ## Citation
196
+
197
+ ### BibTeX
198
+ ```bibtex
199
+ @article{https://doi.org/10.48550/arxiv.2209.11055,
200
+ doi = {10.48550/ARXIV.2209.11055},
201
+ url = {https://arxiv.org/abs/2209.11055},
202
+ author = {Tunstall, Lewis and Reimers, Nils and Jo, Unso Eun Seo and Bates, Luke and Korat, Daniel and Wasserblat, Moshe and Pereg, Oren},
203
+ keywords = {Computation and Language (cs.CL), FOS: Computer and information sciences, FOS: Computer and information sciences},
204
+ title = {Efficient Few-Shot Learning Without Prompts},
205
+ publisher = {arXiv},
206
+ year = {2022},
207
+ copyright = {Creative Commons Attribution 4.0 International}
208
+ }
209
+ ```
210
+
211
+ <!--
212
+ ## Glossary
213
+
214
+ *Clearly define terms in order to be accessible across audiences.*
215
+ -->
216
+
217
+ <!--
218
+ ## Model Card Authors
219
+
220
+ *Lists the people who create the model card, providing recognition and accountability for the detailed work that goes into its construction.*
221
+ -->
222
+
223
+ <!--
224
+ ## Model Card Contact
225
+
226
+ *Provides a way for people who have updates to the Model Card, suggestions, or questions, to contact the Model Card authors.*
227
+ -->
config.json ADDED
@@ -0,0 +1,39 @@
+ {
+   "_name_or_path": "checkpoints/step_120",
+   "architectures": [
+     "NewModel"
+   ],
+   "attention_probs_dropout_prob": 0.0,
+   "auto_map": {
+     "AutoConfig": "configuration.NewConfig",
+     "AutoModel": "modeling.NewModel"
+   },
+   "classifier_dropout": null,
+   "hidden_act": "gelu",
+   "hidden_dropout_prob": 0.1,
+   "hidden_size": 1024,
+   "initializer_range": 0.02,
+   "intermediate_size": 4096,
+   "layer_norm_eps": 1e-12,
+   "layer_norm_type": "layer_norm",
+   "logn_attention_clip1": false,
+   "logn_attention_scale": false,
+   "max_position_embeddings": 8192,
+   "model_type": "new",
+   "num_attention_heads": 16,
+   "num_hidden_layers": 24,
+   "pack_qkv": true,
+   "pad_token_id": 0,
+   "position_embedding_type": "rope",
+   "rope_scaling": {
+     "factor": 2.0,
+     "type": "ntk"
+   },
+   "rope_theta": 160000,
+   "torch_dtype": "float32",
+   "transformers_version": "4.39.0",
+   "type_vocab_size": 2,
+   "unpad_inputs": true,
+   "use_memory_efficient_attention": true,
+   "vocab_size": 30528
+ }
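The `auto_map` entries above point `AutoConfig`/`AutoModel` at the custom `configuration.py` and `modeling.py` shipped in this repository, so loading the transformer body directly through 🤗 Transformers requires trusting that remote code. A minimal sketch using standard `transformers` calls (note that the shipped config enables xformers-based memory-efficient attention and input unpadding, so `xformers` should be installed for this path):

```python
from transformers import AutoModel, AutoTokenizer

repo = "twright8/news_cats_2"
tokenizer = AutoTokenizer.from_pretrained(repo)
model = AutoModel.from_pretrained(repo, trust_remote_code=True)  # loads the custom NewModel class

inputs = tokenizer("Health ministry official charged with bribery.", return_tensors="pt")
outputs = model(**inputs)
print(outputs.last_hidden_state.shape)  # (1, seq_len, 1024)
```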
config_sentence_transformers.json ADDED
@@ -0,0 +1,13 @@
+ {
+   "__version__": {
+     "sentence_transformers": "3.0.1",
+     "transformers": "4.39.0",
+     "pytorch": "2.3.0+cu121"
+   },
+   "prompts": {
+     "s2p_query": "Instruct: Given a web search query, retrieve relevant passages that answer the query.\nQuery: ",
+     "s2s_query": "Instruct: Retrieve semantically similar text.\nQuery: "
+   },
+   "default_prompt_name": null,
+   "similarity_fn_name": "cosine"
+ }
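The `prompts` entries are inherited from the stella base model and can be selected by name when encoding with Sentence Transformers directly. A small sketch, assuming the `prompt_name` argument of `SentenceTransformer.encode` available in recent sentence-transformers releases (the example strings are placeholders):

```python
from sentence_transformers import SentenceTransformer

st_model = SentenceTransformer("twright8/news_cats_2", trust_remote_code=True)

# Encode passages as-is, and a query with the "s2p_query" instruction prepended.
passage_emb = st_model.encode(["Clinic director indicted for embezzling patient-care funds."])
query_emb = st_model.encode(["health sector corruption news"], prompt_name="s2p_query")
print(passage_emb.shape, query_emb.shape)  # both (1, 1024) after the Dense projection
```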
config_setfit.json ADDED
@@ -0,0 +1,4 @@
+ {
+   "normalize_embeddings": false,
+   "labels": null
+ }
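`labels` is `null` here, so predictions come back as the raw integer classes `0` and `1` from the model card's label table. If you prefer human-readable outputs, SetFit 1.x allows attaching label names at load time; the names below are placeholders loosely based on the label table, not metadata stored in this repository:

```python
from setfit import SetFitModel

# Placeholder label names -- the repository itself stores labels as null.
model = SetFitModel.from_pretrained(
    "twright8/news_cats_2",
    labels=["other_corruption_news", "health_worker_misconduct"],
)
print(model.predict(["Kerala issues show-cause notices to 56 absentee doctors."]))
```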
configuration.py ADDED
@@ -0,0 +1,145 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The GTE Team Authors and Alibaba Group.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """ NEW model configuration"""
17
+ from transformers.configuration_utils import PretrainedConfig
18
+ from transformers.utils import logging
19
+
20
+ logger = logging.get_logger(__name__)
21
+
22
+
23
+ class NewConfig(PretrainedConfig):
24
+ r"""
25
+ This is the configuration class to store the configuration of a [`NewModel`] or a [`TFNewModel`]. It is used to
26
+ instantiate a NEW model according to the specified arguments, defining the model architecture. Instantiating a
27
+ configuration with the defaults will yield a similar configuration to that of the NEW
28
+ [izhx/new-base-en](https://huggingface.co/izhx/new-base-en) architecture.
29
+
30
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
31
+ documentation from [`PretrainedConfig`] for more information.
32
+
33
+
34
+ Args:
35
+ vocab_size (`int`, *optional*, defaults to 30522):
36
+ Vocabulary size of the NEW model. Defines the number of different tokens that can be represented by the
37
+ `inputs_ids` passed when calling [`NewModel`] or [`TFNewModel`].
38
+ hidden_size (`int`, *optional*, defaults to 768):
39
+ Dimensionality of the encoder layers and the pooler layer.
40
+ num_hidden_layers (`int`, *optional*, defaults to 12):
41
+ Number of hidden layers in the Transformer encoder.
42
+ num_attention_heads (`int`, *optional*, defaults to 12):
43
+ Number of attention heads for each attention layer in the Transformer encoder.
44
+ intermediate_size (`int`, *optional*, defaults to 3072):
45
+ Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
46
+ hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`):
47
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
48
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
49
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
50
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
51
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
52
+ The dropout ratio for the attention probabilities.
53
+ max_position_embeddings (`int`, *optional*, defaults to 512):
54
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
55
+ just in case (e.g., 512 or 1024 or 2048).
56
+ type_vocab_size (`int`, *optional*, defaults to 2):
57
+ The vocabulary size of the `token_type_ids` passed when calling [`NewModel`] or [`TFNewModel`].
58
+ initializer_range (`float`, *optional*, defaults to 0.02):
59
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
60
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
61
+ The epsilon used by the layer normalization layers.
62
+ position_embedding_type (`str`, *optional*, defaults to `"rope"`):
63
+ Type of position embedding. Choose one of `"absolute"`, `"rope"`.
64
+ rope_theta (`float`, *optional*, defaults to 10000.0):
65
+ The base period of the RoPE embeddings.
66
+ rope_scaling (`Dict`, *optional*):
67
+ Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
68
+ strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
69
+ `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
70
+ `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
71
+ these scaling strategies behave:
72
+ https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
73
+ experimental feature, subject to breaking API changes in future versions.
74
+ classifier_dropout (`float`, *optional*):
75
+ The dropout ratio for the classification head.
76
+
77
+ Examples:
78
+
79
+ ```python
80
+ >>> from transformers import NewConfig, NewModel
81
+
82
+ >>> # Initializing a NEW izhx/new-base-en style configuration
83
+ >>> configuration = NewConfig()
84
+
85
+ >>> # Initializing a model (with random weights) from the izhx/new-base-en style configuration
86
+ >>> model = NewModel(configuration)
87
+
88
+ >>> # Accessing the model configuration
89
+ >>> configuration = model.config
90
+ ```"""
91
+
92
+ model_type = "new"
93
+
94
+ def __init__(
95
+ self,
96
+ vocab_size=30528,
97
+ hidden_size=768,
98
+ num_hidden_layers=12,
99
+ num_attention_heads=12,
100
+ intermediate_size=3072,
101
+ hidden_act="gelu",
102
+ hidden_dropout_prob=0.1,
103
+ attention_probs_dropout_prob=0.0,
104
+ max_position_embeddings=2048,
105
+ type_vocab_size=1,
106
+ initializer_range=0.02,
107
+ layer_norm_type='layer_norm',
108
+ layer_norm_eps=1e-12,
109
+ # pad_token_id=0,
110
+ position_embedding_type="rope",
111
+ rope_theta=10000.0,
112
+ rope_scaling=None,
113
+ classifier_dropout=None,
114
+ pack_qkv=True,
115
+ unpad_inputs=False,
116
+ use_memory_efficient_attention=False,
117
+ logn_attention_scale=False,
118
+ logn_attention_clip1=False,
119
+ **kwargs,
120
+ ):
121
+ super().__init__(**kwargs)
122
+
123
+ self.vocab_size = vocab_size
124
+ self.hidden_size = hidden_size
125
+ self.num_hidden_layers = num_hidden_layers
126
+ self.num_attention_heads = num_attention_heads
127
+ self.hidden_act = hidden_act
128
+ self.intermediate_size = intermediate_size
129
+ self.hidden_dropout_prob = hidden_dropout_prob
130
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
131
+ self.max_position_embeddings = max_position_embeddings
132
+ self.type_vocab_size = type_vocab_size
133
+ self.initializer_range = initializer_range
134
+ self.layer_norm_type = layer_norm_type
135
+ self.layer_norm_eps = layer_norm_eps
136
+ self.position_embedding_type = position_embedding_type
137
+ self.rope_theta = rope_theta
138
+ self.rope_scaling = rope_scaling
139
+ self.classifier_dropout = classifier_dropout
140
+
141
+ self.pack_qkv = pack_qkv
142
+ self.unpad_inputs = unpad_inputs
143
+ self.use_memory_efficient_attention = use_memory_efficient_attention
144
+ self.logn_attention_scale = logn_attention_scale
145
+ self.logn_attention_clip1 = logn_attention_clip1
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:80b53fabb609ac8718e88fcdacafc1c51212b59267a72655440de6d6cb43de05
+ size 1736585680
model_head.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0f2baf494c8bb9560ff5c7a3e07ca66a61326efdf9ecfae5d5ac4fed8417c29c
+ size 9754
modeling.py ADDED
@@ -0,0 +1,1387 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The GTE Team Authors and Alibaba Group.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """PyTorch NEW model."""
17
+
18
+ import math
19
+ from typing import List, Optional, Tuple, Union
20
+
21
+ import torch
22
+ import torch.utils.checkpoint
23
+ from torch import nn
24
+
25
+ from transformers.activations import ACT2FN
26
+ from transformers.modeling_outputs import (
27
+ BaseModelOutput,
28
+ BaseModelOutputWithPooling,
29
+ MaskedLMOutput,
30
+ MultipleChoiceModelOutput,
31
+ QuestionAnsweringModelOutput,
32
+ SequenceClassifierOutput,
33
+ TokenClassifierOutput,
34
+ )
35
+ from transformers.modeling_utils import PreTrainedModel
36
+ from transformers.utils import logging
37
+
38
+ try:
39
+ import xformers.ops as xops
40
+ except ImportError as e:
41
+ xops = None
42
+
43
+ from .configuration import NewConfig
44
+
45
+
46
+ logger = logging.get_logger(__name__)
47
+
48
+
49
+ # Adapted from https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py
50
+ # Which was adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py
51
+ class IndexFirstAxis(torch.autograd.Function):
52
+ @staticmethod
53
+ def forward(ctx, input, indices):
54
+ ctx.save_for_backward(indices)
55
+ assert input.ndim >= 2
56
+ ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
57
+ second_dim = other_shape.numel()
58
+ # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
59
+ # return input[indices]
60
+ # return torch.gather(
61
+ # rearrange(input, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim)
62
+ # ).reshape(-1, *other_shape)
63
+ return torch.gather(
64
+ input.view(ctx.first_axis_dim, second_dim),
65
+ 0,
66
+ indices.unsqueeze(-1).expand(indices.size(0), second_dim)
67
+ ).reshape(-1, *other_shape)
68
+
69
+ @staticmethod
70
+ def backward(ctx, grad_output):
71
+ (indices,) = ctx.saved_tensors
72
+ assert grad_output.ndim >= 2
73
+ other_shape = grad_output.shape[1:]
74
+ # grad_output = rearrange(grad_output, "b ... -> b (...)")
75
+ grad_output = grad_output.view(grad_output.size(0), other_shape.numel())
76
+ grad_input = torch.zeros(
77
+ [ctx.first_axis_dim, grad_output.shape[1]],
78
+ device=grad_output.device,
79
+ dtype=grad_output.dtype,
80
+ )
81
+ # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
82
+ # grad_input[indices] = grad_output
83
+ # grad_input.scatter_(0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output)
84
+ grad_input.scatter_(
85
+ 0, indices.unsqueeze(-1).expand(indices.size(0), grad_output.size(1)), grad_output
86
+ )
87
+ return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
88
+
89
+
90
+ index_first_axis = IndexFirstAxis.apply
91
+
92
+
93
+ def unpad_input(hidden_states, attention_mask=None, indices=None):
94
+ """
95
+ Arguments:
96
+ hidden_states: (batch, seqlen, ...)
97
+ attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
98
+ indices: (total_nnz), the indices of non-masked tokens from the flattened input sequence.
99
+ Return:
100
+ hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
101
+ """
102
+ if indices is None:
103
+ assert attention_mask is not None
104
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
105
+
106
+ # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
107
+ # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
108
+ # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
109
+ # index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
110
+ # so we write custom forward and backward to make it a bit faster.
111
+ hidden_states = hidden_states.view(-1, *hidden_states.shape[2:])
112
+ return index_first_axis(hidden_states, indices)
113
+
114
+
115
+ class IndexPutFirstAxis(torch.autograd.Function):
116
+ @staticmethod
117
+ def forward(
118
+ ctx,
119
+ values: torch.Tensor,
120
+ indices: torch.Tensor,
121
+ first_axis_dim
122
+ ) -> torch.Tensor:
123
+ ctx.save_for_backward(indices)
124
+ assert indices.ndim == 1
125
+ assert values.ndim >= 2
126
+ output = torch.zeros(
127
+ first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype
128
+ )
129
+ output[indices] = values
130
+ return output
131
+
132
+ @staticmethod
133
+ def backward(ctx, grad_output: torch.Tensor) -> Tuple[torch.Tensor, None, None]:
134
+ indices, = ctx.saved_tensors
135
+ grad_values = grad_output[indices]
136
+ return grad_values, None, None
137
+
138
+
139
+ index_put_first_axis = IndexPutFirstAxis.apply
140
+
141
+
142
+ def pad_input(inputs: torch.Tensor, indices: torch.Tensor, batch: int, seqlen: int) -> torch.Tensor:
143
+ """Add padding to sequences.
144
+
145
+ Arguments:
146
+ inputs: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
147
+ indices: (total_nnz), `indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()`
148
+ batch: int batch_size
149
+ seqlen: int max sequence length
150
+
151
+ Returns:
152
+ inputs: (batch, seqlen, ...)
153
+ """
154
+ output = index_put_first_axis(inputs, indices, batch * seqlen)
155
+ return output.view(batch, seqlen, *inputs.shape[1:])
156
+
157
+
158
+ def rotate_half(x):
159
+ """Rotates half the hidden dims of the input."""
160
+ x1 = x[..., : x.shape[-1] // 2]
161
+ x2 = x[..., x.shape[-1] // 2 :]
162
+ return torch.cat((-x2, x1), dim=-1)
163
+
164
+
165
+ def apply_rotary_pos_emb(q, k, cos, sin):
166
+ """Applies Rotary Position Embedding to the query and key tensors.
167
+
168
+ Args:
169
+ q (`torch.Tensor`): The query tensor.
170
+ k (`torch.Tensor`): The key tensor.
171
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
172
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
173
+ Returns:
174
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
175
+ """
176
+ cos, sin = cos.to(q.dtype), sin.to(q.dtype)
177
+ q_embed = (q * cos) + (rotate_half(q) * sin)
178
+ k_embed = (k * cos) + (rotate_half(k) * sin)
179
+ return q_embed, k_embed
180
+
181
+
182
+ class RotaryEmbedding(torch.nn.Module):
183
+ def __init__(self, dim, max_position_embeddings=512, base=10000.0, device=None):
184
+ super().__init__()
185
+
186
+ self.dim = dim
187
+ self.max_position_embeddings = max_position_embeddings
188
+ self.base = base
189
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
190
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
191
+
192
+ # Build here to make `torch.jit.trace` work.
193
+ self._set_cos_sin_cache(
194
+ seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
195
+ )
196
+
197
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
198
+ self.max_seq_len_cached = seq_len
199
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.float32)
200
+
201
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
202
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
203
+ emb = torch.cat((freqs, freqs), dim=-1)
204
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
205
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
206
+
207
+ def forward(self, x, seq_len=None):
208
+ # x: [bs, num_attention_heads, seq_len, head_size]
209
+ if seq_len > self.max_seq_len_cached:
210
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
211
+
212
+ return (
213
+ self.cos_cached[:seq_len, ...].to(dtype=x.dtype),
214
+ self.sin_cached[:seq_len, ...].to(dtype=x.dtype),
215
+ )
216
+
217
+
218
+ class NTKScalingRotaryEmbedding(RotaryEmbedding):
219
+ """RotaryEmbedding extended with fixed and mixed NTK scaling. https://kexue.fm/archives/9706 """
220
+
221
+ def __init__(self, dim, max_position_embeddings=512, base=10000, device=None, scaling_factor=1.0, mixed_b=None):
222
+ self.scaling_factor = scaling_factor
223
+ self.mixed_b = mixed_b
224
+ super().__init__(dim, max_position_embeddings, base, device)
225
+ max_position_embeddings = max_position_embeddings * self.scaling_factor
226
+ self._set_cos_sin_cache(max_position_embeddings, self.inv_freq.device, torch.get_default_dtype())
227
+
228
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
229
+ self.max_seq_len_cached = seq_len
230
+
231
+ if seq_len > self.max_position_embeddings:
232
+ base = self.base * (self.scaling_factor if self.mixed_b is None else 1)
233
+ inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
234
+
235
+ if self.mixed_b is None:
236
+ inv_freq = inv_freq / self.scaling_factor ** (2 / self.dim) # (6)
237
+ else:
238
+ a = torch.tensor(self.scaling_factor).log() / (self.dim / 2) ** self.mixed_b # (13)
239
+ lambda_1_m = (a * torch.arange(1, self.dim // 2 + 1).float().to(device) ** self.mixed_b).exp() # (12)
240
+ inv_freq = inv_freq / lambda_1_m # (10)
241
+
242
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
243
+
244
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.float32)
245
+
246
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
247
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
248
+ emb = torch.cat((freqs, freqs), dim=-1)
249
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
250
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
251
+
252
+
253
+ class RMSNorm(nn.Module):
254
+ def __init__(self, hidden_size, eps=1e-6):
255
+ """
256
+ RMSNorm is equivalent to T5LayerNorm
257
+ """
258
+ super().__init__()
259
+ self.weight = nn.Parameter(torch.ones(hidden_size))
260
+ self.variance_epsilon = eps
261
+
262
+ def forward(self, hidden_states):
263
+ input_dtype = hidden_states.dtype
264
+ hidden_states = hidden_states.to(torch.float32)
265
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
266
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
267
+ return self.weight * hidden_states.to(input_dtype)
268
+
269
+
270
+ LAYER_NORM = {
271
+ 'layer_norm': nn.LayerNorm,
272
+ 'rms_norm': RMSNorm
273
+ }
274
+
275
+
276
+ class NewEmbeddings(nn.Module):
277
+ """
278
+ Embedding and Unpadding.
279
+ """
280
+
281
+ def __init__(self, config: NewConfig):
282
+ super().__init__()
283
+ self.padding_idx = config.pad_token_id
284
+ self.word_embeddings = nn.Embedding(
285
+ config.vocab_size, config.hidden_size, padding_idx=self.padding_idx
286
+ )
287
+
288
+ self.position_embedding_type = config.position_embedding_type
289
+ if self.position_embedding_type == 'absolute':
290
+ self.position_embeddings = nn.Embedding(
291
+ config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
292
+ )
293
+ elif self.position_embedding_type == 'rope':
294
+ self._init_rope(config)
295
+ else:
296
+ raise ValueError
297
+
298
+ self.type_vocab_size = config.type_vocab_size
299
+ if self.type_vocab_size > 0:
300
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
301
+
302
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
303
+ # any TensorFlow checkpoint file
304
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
305
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
306
+ # position_ids is contiguous in memory and excluded when serialized
307
+ self.register_buffer(
308
+ "position_ids", torch.arange(config.max_position_embeddings), persistent=False
309
+ )
310
+
311
+ def _init_rope(self, config):
312
+ kwargs = dict(
313
+ dim=int(config.hidden_size / config.num_attention_heads),
314
+ max_position_embeddings=config.max_position_embeddings,
315
+ base=config.rope_theta
316
+ )
317
+ if config.rope_scaling is None:
318
+ self.rotary_emb = RotaryEmbedding(**kwargs)
319
+ else:
320
+ kwargs.update(scaling_factor=config.rope_scaling["factor"])
321
+ scaling_type = config.rope_scaling["type"]
322
+ if scaling_type == 'ntk':
323
+ kwargs.update(mixed_b=config.rope_scaling.get('mixed_b', None))
324
+ self.rotary_emb = NTKScalingRotaryEmbedding(**kwargs)
325
+ # elif scaling_type == "linear":
326
+ # self.rotary_emb = LinearScalingRotaryEmbedding(**kwargs)
327
+ # elif scaling_type == "dynamic":
328
+ # self.rotary_emb = DynamicNTKScalingRotaryEmbedding(**kwargs)
329
+ else:
330
+ raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
331
+
332
+ def forward(
333
+ self,
334
+ unpad_inputs: bool,
335
+ input_ids: Optional[torch.Tensor] = None,
336
+ attention_mask: Optional[torch.Tensor] = None,
337
+ length: Optional[List[int]] = None,
338
+ token_type_ids: Optional[torch.Tensor] = None,
339
+ position_ids: Optional[torch.Tensor] = None,
340
+ inputs_embeds: Optional[torch.Tensor] = None,
341
+ ) -> Tuple[torch.Tensor, torch.Tensor, Optional[Tuple], Optional[List[int]]]:
342
+ """
343
+ """
344
+ if inputs_embeds is None:
345
+ device, input_shape = input_ids.device, input_ids.shape
346
+ else:
347
+ device, input_shape = inputs_embeds.device, inputs_embeds.shape[:2]
348
+ batch_size, seq_length = input_shape
349
+
350
+ # Set attention_mask if it's None
351
+ if attention_mask is None:
352
+ attention_mask = torch.ones(input_shape, device=device)
353
+ if length is not None:
354
+ for i, l in enumerate(length):
355
+ attention_mask[i, l:] = 0
356
+
357
+ # Set attention_mask_bool for unpadding
358
+ if unpad_inputs:
359
+ attention_mask_bool = attention_mask.bool()
360
+ if length is None:
361
+ length = attention_mask.sum(-1).tolist()
362
+
363
+ # Get word embeddings
364
+ if inputs_embeds is None:
365
+ if unpad_inputs:
366
+ input_ids = input_ids[attention_mask_bool].unsqueeze(0)
367
+ inputs_embeds = self.word_embeddings(input_ids)
368
+ else:
369
+ if unpad_inputs:
370
+ inputs_embeds = inputs_embeds[attention_mask_bool].unsqueeze(0)
371
+ embeddings = inputs_embeds
372
+
373
+ # Set and unpad position_ids
374
+ if position_ids is None:
375
+ if seq_length > self.position_ids.size(0):
376
+ self.register_buffer(
377
+ "position_ids", torch.arange(seq_length), persistent=False
378
+ )
379
+ if unpad_inputs:
380
+ # [1, cumsum_seq_len]
381
+ position_ids = torch.cat([self.position_ids[:l] for l in length]).unsqueeze(0)
382
+ else:
383
+ # [bs, seq_len]
384
+ position_ids = self.position_ids[:seq_length].expand(batch_size, -1)
385
+ elif unpad_inputs:
386
+ position_ids = position_ids[attention_mask_bool].unsqueeze(0) # [1, cumsum_seq_len]
387
+
388
+ # Compute rotary embedding
389
+ if self.position_embedding_type == 'rope':
390
+ rope_cos, rope_sin = self.rotary_emb(inputs_embeds, seq_len=seq_length)
391
+ rope_cos = rope_cos[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim]
392
+ rope_sin = rope_sin[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim]
393
+ rope_embeds = rope_cos, rope_sin
394
+ else:
395
+ rope_embeds = None
396
+
397
+ if self.type_vocab_size > 0:
398
+ if token_type_ids is None:
399
+ token_type_ids = position_ids.mul(0)
400
+ elif unpad_inputs:
401
+ token_type_ids = token_type_ids[attention_mask_bool].unsqueeze(0)
402
+
403
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
404
+ embeddings += token_type_embeddings
405
+
406
+ # BERT position
407
+ if self.position_embedding_type == "absolute":
408
+ position_embeddings = self.position_embeddings(position_ids)
409
+ embeddings += position_embeddings
410
+
411
+ embeddings = self.LayerNorm(embeddings)
412
+ embeddings = self.dropout(embeddings)
413
+
414
+ return embeddings, attention_mask, rope_embeds, length
415
+
416
+
417
+ class NewAttention(nn.Module):
418
+ def __init__(self, config: NewConfig, pack_qkv=None, use_memory_efficient_attention=None):
419
+ super().__init__()
420
+ self.config = config
421
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
422
+ raise ValueError(
423
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
424
+ f"heads ({config.num_attention_heads})"
425
+ )
426
+
427
+ self.hidden_size = config.hidden_size
428
+ self.num_attention_heads = config.num_attention_heads
429
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
430
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
431
+
432
+ if pack_qkv is None:
433
+ pack_qkv = config.pack_qkv
434
+ self.pack_qkv = pack_qkv
435
+
436
+ if self.pack_qkv:
437
+ self.qkv_proj = nn.Linear(config.hidden_size, self.all_head_size * 3, bias=True)
438
+ else:
439
+ self.q_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)
440
+ self.k_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)
441
+ self.v_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)
442
+
443
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
444
+ self.o_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=True)
445
+
446
+ if use_memory_efficient_attention is None:
447
+ use_memory_efficient_attention = self.config.use_memory_efficient_attention
448
+ self.use_memory_efficient_attention = use_memory_efficient_attention
449
+ self.memory_efficient_attention = None if xops is None else xops.memory_efficient_attention
450
+ if self.use_memory_efficient_attention:
451
+ assert self.memory_efficient_attention is not None, 'please install xformers'
452
+ if self.config.unpad_inputs:
453
+ assert self.config.use_memory_efficient_attention, 'unpad only with xformers'
454
+
455
+ def forward(
456
+ self,
457
+ hidden_states: torch.Tensor,
458
+ attention_bias: torch.FloatTensor,
459
+ rope_embeds: Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] = None,
460
+ attention_scale: Optional[torch.FloatTensor] = None,
461
+ head_mask: Optional[torch.FloatTensor] = None,
462
+ output_attentions: Optional[bool] = False,
463
+ qkv_inputs: Optional[Tuple] = None, # For RetroMAE
464
+ padding_inputs: Optional[Tuple] = None, # indices, batch, seqlen
465
+ ) -> Tuple[torch.Tensor, ...]:
466
+ shape_hd = (self.num_attention_heads, self.attention_head_size)
467
+ # qkv
468
+ if self.pack_qkv and qkv_inputs is None:
469
+ qkv_pack = self.qkv_proj(hidden_states).split(self.all_head_size, dim=-1)
470
+ else:
471
+ if qkv_inputs is None:
472
+ qkv_inputs = (hidden_states, hidden_states, hidden_states)
473
+ qkv_pack = [
474
+ getattr(self, n + '_proj')(s) for s, n in zip(qkv_inputs, 'qkv')
475
+ ]
476
+ query_states, key_states, value_states = [t.view(t.shape[:-1] + shape_hd) for t in qkv_pack]
477
+
478
+ if self.config.position_embedding_type == 'rope':
479
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, *rope_embeds)
480
+
481
+ dtype = query_states.dtype
482
+
483
+ if self.config.logn_attention_scale and attention_scale is not None:
484
+ # https://kexue.fm/archives/8823
485
+ query_states = query_states * attention_scale.to(dtype)
486
+
487
+ if padding_inputs is not None:
488
+ query_states = pad_input(query_states.squeeze(), *padding_inputs)
489
+ key_states = pad_input(key_states.squeeze(), *padding_inputs)
490
+ value_states = pad_input(value_states.squeeze(), *padding_inputs)
491
+
492
+ if self.use_memory_efficient_attention:
493
+ assert self.memory_efficient_attention is not None, "xformers is not loaded"
494
+ assert output_attentions is False, "memory_efficient_attention do not output attentions"
495
+ assert head_mask is None, "Not support yet"
496
+ attention_probs = None
497
+ if torch.is_tensor(attention_bias):
498
+ attention_bias = attention_bias.to(dtype)
499
+ context_layer = self.memory_efficient_attention(
500
+ query_states,
501
+ key_states,
502
+ value_states,
503
+ attn_bias=attention_bias,
504
+ p=self.dropout.p
505
+ )
506
+ else:
507
+ context_layer = self._attention(query_states, key_states, value_states, attention_bias, head_mask)
508
+
509
+ if padding_inputs is not None:
510
+ context_layer = unpad_input(context_layer, indices=padding_inputs[0])
511
+
512
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
513
+ context_layer = context_layer.view(new_context_layer_shape)
514
+
515
+ # output proj
516
+ attn_output = self.o_proj(context_layer)
517
+
518
+ # add attentions if we output them
519
+ outputs = (attn_output, attention_probs) if output_attentions else (attn_output,)
520
+ return outputs
521
+
522
+ def _attention(self, query_states, key_states, value_states, attention_bias, head_mask):
523
+ """
524
+ Args:
525
+ q/k/v: (B, L, n_head, head_dim),
526
+ Returns:
527
+ attn_output: (B L, n_head, head_dim)
528
+ """
529
+ query_states = query_states.transpose(1, 2)
530
+ key_states = key_states.transpose(1, 2)
531
+ value_states = value_states.transpose(1, 2)
532
+ # Take the dot product between "query" and "key" to get the raw attention scores.
533
+ attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2))
534
+
535
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
536
+ if attention_bias is not None:
537
+ # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
538
+ attention_scores = attention_scores + attention_bias
539
+
540
+ # Normalize the attention scores to probabilities.
541
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
542
+
543
+ # This is actually dropping out entire tokens to attend to, which might
544
+ # seem a bit unusual, but is taken from the original Transformer paper.
545
+ attention_probs = self.dropout(attention_probs)
546
+
547
+ # Mask heads if we want to
548
+ if head_mask is not None:
549
+ attention_probs = attention_probs * head_mask
550
+
551
+ context_layer = torch.matmul(attention_probs, value_states)
552
+
553
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
554
+ return context_layer
555
+
556
+
557
+ class NewSdpaAttention(NewAttention):
558
+ """
559
+ New attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
560
+ `NewAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
561
+ SDPA API.
562
+ """
563
+ def __init__(self, config: NewConfig, **kwargs):
564
+ super().__init__(config, **kwargs)
565
+ torch.backends.cuda.enable_mem_efficient_sdp(False)
566
+ logger.warning(
567
+ "Disable memory efficient attention kernel for `NewSdpaAttention`, you can set "
568
+ "`use_memory_efficient_attention=True` if it expected to use."
569
+ )
570
+
571
+ def _attention(self, query_states, key_states, value_states, attention_bias, head_mask):
572
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
573
+ query_states.transpose(1, 2),
574
+ key_states.transpose(1, 2),
575
+ value_states.transpose(1, 2),
576
+ attn_mask=attention_bias,
577
+ dropout_p=self.dropout.p if self.training else 0.0,
578
+ )
579
+ attn_output = attn_output.permute(0, 2, 1, 3).contiguous()
580
+ return attn_output
581
+
582
+
583
+ NEW_ATTENTION_CLASSES = {
584
+ "eager": NewAttention,
585
+ # "flash_attention_2": , # TODO: xformers will dispatch to flash_attn
586
+ "sdpa": NewSdpaAttention,
587
+ }
588
+
589
+
590
+ class NewGatedMLP(nn.Module):
591
+ """
592
+ GLU Variants Improve Transformer.
593
+ """
594
+
595
+ def __init__(self, config: NewConfig):
596
+ super().__init__()
597
+ self.intermediate_size = config.intermediate_size
598
+ self.up_gate_proj = nn.Linear(config.hidden_size, self.intermediate_size * 2, bias=False)
599
+ self.down_proj = nn.Linear(self.intermediate_size, config.hidden_size, bias=True)
600
+ self.act_fn = ACT2FN[config.hidden_act]
601
+ if config.hidden_dropout_prob > 0:
602
+ self.hidden_dropout = nn.Dropout(config.hidden_dropout_prob)
603
+ else:
604
+ self.hidden_dropout = None
605
+
606
+ def forward(self, hidden_states):
607
+ up_gate = self.up_gate_proj(hidden_states)
608
+ up_states, gate = torch.split(up_gate, self.intermediate_size, dim=-1)
609
+ gate = self.act_fn(gate)
610
+ gated_states = gate * up_states
611
+ if self.hidden_dropout is not None:
612
+ gated_states = self.hidden_dropout(gated_states)
613
+ down_states = self.down_proj(gated_states)
614
+ return down_states
615
+
616
+
617
+ class NewLayer(nn.Module):
618
+ def __init__(
619
+ self,
620
+ config: NewConfig,
621
+ pack_qkv=None,
622
+ use_memory_efficient_attention=None,
623
+ attn_implementation=None
624
+ ):
625
+ super().__init__()
626
+ if attn_implementation is None:
627
+ attn_implementation = config._attn_implementation
628
+ if attn_implementation != 'eager':
629
+ use_memory_efficient_attention = False
630
+ self.attention = NEW_ATTENTION_CLASSES[attn_implementation](
631
+ config, pack_qkv=pack_qkv, use_memory_efficient_attention=use_memory_efficient_attention
632
+ )
633
+ self.mlp = NewGatedMLP(config)
634
+
635
+ ln_class = LAYER_NORM[config.layer_norm_type]
636
+ self.attn_ln = ln_class(config.hidden_size, eps=config.layer_norm_eps)
637
+ self.mlp_ln = ln_class(config.hidden_size, eps=config.layer_norm_eps)
638
+
639
+ if config.hidden_dropout_prob > 0:
640
+ self.hidden_dropout = nn.Dropout(config.hidden_dropout_prob)
641
+ else:
642
+ self.hidden_dropout = None
643
+
644
+ def forward(
645
+ self,
646
+ hidden_states: torch.Tensor,
647
+ attention_bias: torch.FloatTensor,
648
+ rope_embeds: Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] = None,
649
+ attention_scale: Optional[torch.FloatTensor] = None,
650
+ subset_indices: Optional[torch.LongTensor] = None,
651
+ head_mask: Optional[torch.FloatTensor] = None,
652
+ output_attentions: Optional[bool] = False,
653
+ qkv_inputs: Optional[Tuple] = None, # For RetroMAE
654
+ padding_inputs: Optional[Tuple] = None,
655
+ ) -> Tuple[torch.Tensor, ...]:
656
+ # Multi head self attention
657
+ residual = hidden_states if qkv_inputs is None else qkv_inputs[0]
658
+ attention_outputs = self.attention(
659
+ hidden_states,
660
+ attention_bias,
661
+ rope_embeds,
662
+ attention_scale,
663
+ head_mask,
664
+ output_attentions=output_attentions,
665
+ qkv_inputs=qkv_inputs,
666
+ padding_inputs=padding_inputs,
667
+ )
668
+ hidden_states = attention_outputs[0]
669
+ if self.hidden_dropout is not None:
670
+ hidden_states = self.hidden_dropout(hidden_states)
671
+ hidden_states = residual + hidden_states
672
+
673
+ # During pretraining, only the masked tokens are needed after the last layer's attention.
674
+ if subset_indices is not None:
675
+ hidden_states = hidden_states[subset_indices]
676
+
677
+ hidden_states = self.attn_ln(hidden_states)
678
+
679
+ # Fully Connected
680
+ residual = hidden_states
681
+ hidden_states = self.mlp(hidden_states)
682
+ if self.hidden_dropout is not None:
683
+ hidden_states = self.hidden_dropout(hidden_states)
684
+ hidden_states = residual + hidden_states
685
+ hidden_states = self.mlp_ln(hidden_states)
686
+
687
+ # add self attentions if we output attention weights
688
+ outputs = (hidden_states,) + attention_outputs[1:]
689
+ return outputs
690
+
691
+
692
+ class NewEncoder(nn.Module):
693
+ def __init__(self, config):
694
+ super().__init__()
695
+ self.config = config
696
+ self.layer = nn.ModuleList([NewLayer(config) for _ in range(config.num_hidden_layers)])
697
+ self.gradient_checkpointing = False
698
+
699
+ def forward(
700
+ self,
701
+ hidden_states: torch.Tensor,
702
+ attention_bias: Optional[torch.FloatTensor] = None,
703
+ rope_embeds: Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] = None,
704
+ attention_scale: Optional[torch.FloatTensor] = None,
705
+ subset_indices: Optional[torch.LongTensor] = None,
706
+ head_mask: Optional[torch.FloatTensor] = None,
707
+ output_attentions: Optional[bool] = False,
708
+ output_hidden_states: Optional[bool] = False,
709
+ return_dict: Optional[bool] = True,
710
+ ) -> Union[Tuple[torch.Tensor], BaseModelOutput]:
711
+ all_hidden_states = () if output_hidden_states else None
712
+ all_self_attentions = () if output_attentions else None
713
+
714
+ for i, layer_module in enumerate(self.layer):
715
+ if output_hidden_states:
716
+ all_hidden_states = all_hidden_states + (hidden_states,)
717
+
718
+ if i >= len(self.layer) - 1:
719
+ layer_subset_indices = subset_indices
720
+ else:
721
+ layer_subset_indices = None
722
+
723
+ layer_head_mask = head_mask[i] if head_mask is not None else None
724
+
725
+ if self.gradient_checkpointing and self.training:
726
+ layer_outputs = self._gradient_checkpointing_func(
727
+ layer_module.__call__,
728
+ hidden_states,
729
+ attention_bias,
730
+ rope_embeds,
731
+ attention_scale,
732
+ layer_subset_indices,
733
+ layer_head_mask,
734
+ )
735
+ else:
736
+ layer_outputs = layer_module(
737
+ hidden_states,
738
+ attention_bias,
739
+ rope_embeds,
740
+ attention_scale,
741
+ layer_subset_indices,
742
+ layer_head_mask,
743
+ output_attentions,
744
+ )
745
+
746
+ hidden_states = layer_outputs[0]
747
+ if output_attentions:
748
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
749
+
750
+ if output_hidden_states:
751
+ all_hidden_states = all_hidden_states + (hidden_states,)
752
+
753
+ if not return_dict:
754
+ return tuple(
755
+ v
756
+ for v in [
757
+ hidden_states,
758
+ all_hidden_states,
759
+ all_self_attentions,
760
+ ]
761
+ if v is not None
762
+ )
763
+ return BaseModelOutput(
764
+ last_hidden_state=hidden_states,
765
+ hidden_states=all_hidden_states,
766
+ attentions=all_self_attentions,
767
+ )
768
+
769
+
770
+ # Copied from transformers.models.bert.modeling_bert.BertPooler with Bert->New
771
+ class NewPooler(nn.Module):
772
+ def __init__(self, config):
773
+ super().__init__()
774
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
775
+ self.activation = nn.Tanh()
776
+
777
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
778
+ # We "pool" the model by simply taking the hidden state corresponding
779
+ # to the first token.
780
+ first_token_tensor = hidden_states[:, 0]
781
+ pooled_output = self.dense(first_token_tensor)
782
+ pooled_output = self.activation(pooled_output)
783
+ return pooled_output
784
+
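`NewPooler` implements CLS-style pooling: the first token's hidden state is fed through a dense layer and a tanh. A small sketch, not part of the commit, contrasting that with masked mean pooling over the sequence:

```python
import torch

hidden_states = torch.randn(2, 5, 4)              # (batch, seq_len, hidden)
cls_pooled = hidden_states[:, 0]                   # what NewPooler feeds its dense layer

# Mean pooling for contrast: average only over non-padding positions.
mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]]).unsqueeze(-1)
mean_pooled = (hidden_states * mask).sum(1) / mask.sum(1)
print(cls_pooled.shape, mean_pooled.shape)         # both (2, 4)
```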
785
+
786
+ class NewPreTrainedModel(PreTrainedModel):
787
+ """
788
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
789
+ models.
790
+ """
791
+
792
+ config_class = NewConfig
793
+ base_model_prefix = "new"
794
+ supports_gradient_checkpointing = True
795
+
796
+ def _init_weights(self, module):
797
+ """Initialize the weights"""
798
+ if isinstance(module, nn.Linear):
799
+ # Slightly different from the TF version which uses truncated_normal for initialization
800
+ # cf https://github.com/pytorch/pytorch/pull/5617
801
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
802
+ if module.bias is not None:
803
+ module.bias.data.zero_()
804
+ elif isinstance(module, nn.Embedding):
805
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
806
+ if module.padding_idx is not None:
807
+ module.weight.data[module.padding_idx].zero_()
808
+ elif isinstance(module, nn.LayerNorm):
809
+ module.bias.data.zero_()
810
+ module.weight.data.fill_(1.0)
811
+
812
+
813
+ class NewModel(NewPreTrainedModel):
814
+ """
815
+ The bare New Model transformer outputting raw hidden-states without any specific head on top.
816
+ """
817
+
818
+ def __init__(self, config: NewConfig, add_pooling_layer=False):
819
+ super().__init__(config)
820
+ self.config = config
821
+
822
+ self.embeddings = NewEmbeddings(config)
823
+ self.encoder = NewEncoder(config)
824
+
825
+ self.pooler = NewPooler(config) if add_pooling_layer else None
826
+
827
+ # Initialize weights and apply final processing
828
+ self.post_init()
829
+
830
+ def get_input_embeddings(self):
831
+ return self.embeddings.word_embeddings
832
+
833
+ def set_input_embeddings(self, value):
834
+ self.embeddings.word_embeddings = value
835
+
836
+ def forward(
837
+ self,
838
+ input_ids: Optional[torch.Tensor] = None,
839
+ attention_mask: Optional[torch.Tensor] = None,
840
+ length: Optional[List[int]] = None,
841
+ subset_indices: Optional[torch.LongTensor] = None,
842
+ token_type_ids: Optional[torch.Tensor] = None,
843
+ position_ids: Optional[torch.Tensor] = None,
844
+ head_mask: Optional[torch.Tensor] = None,
845
+ inputs_embeds: Optional[torch.Tensor] = None,
846
+ output_attentions: Optional[bool] = None,
847
+ output_hidden_states: Optional[bool] = None,
848
+ return_dict: Optional[bool] = None,
849
+ unpad_inputs: Optional[bool] = None,
850
+ ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPooling]:
851
+ r"""
852
+ length (`list` of length `batch_size`, *optional*):
853
+ If `None`, the padded `last_hidden_state` is returned; otherwise it gives the true length of each sequence for unpadding.
854
+ subset_indices (`torch.LongTensor`, *optional*):
855
+ Indices selecting a subset of token positions (e.g. the masked tokens during pretraining) to keep after the last layer.
856
+ unpad_inputs (`bool`, *optional*):
857
+ Whether to strip padding and process the batch as packed variable-length sequences. Defaults to `config.unpad_inputs`.
858
+ """
859
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
860
+ output_hidden_states = (
861
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
862
+ )
863
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
864
+ unpad_inputs = unpad_inputs if unpad_inputs is not None else self.config.unpad_inputs
865
+ output_padded = length is None
866
+
867
+ if input_ids is not None and inputs_embeds is not None:
868
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
869
+ elif input_ids is not None:
870
+ self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
871
+ input_shape = input_ids.size()
872
+ elif inputs_embeds is not None:
873
+ input_shape = inputs_embeds.size()[:-1]
874
+ else:
875
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
876
+
877
+ # TODO: not used
878
+ # # Prepare head mask if needed
879
+ # # 1.0 in head_mask indicate we keep the head
880
+ # # attention_probs has shape bsz x n_heads x N x N
881
+ # # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
882
+ # # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
883
+ # head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
884
+
885
+ # Get embeddings, may unpad them
886
+ (embedding_output, attention_mask, rope_embeds, length) = self.embeddings(
887
+ unpad_inputs,
888
+ input_ids=input_ids,
889
+ attention_mask=attention_mask,
890
+ length=length,
891
+ token_type_ids=token_type_ids,
892
+ position_ids=position_ids,
893
+ inputs_embeds=inputs_embeds
894
+ )
895
+
896
+ batch_size, seq_length = input_shape
897
+
898
+ if unpad_inputs:
899
+ assert self.config.use_memory_efficient_attention
900
+ attention_bias = xops.fmha.attn_bias.BlockDiagonalMask.from_seqlens(length)
901
+ else:
902
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
903
+ # ourselves in which case we just need to make it broadcastable to all heads.
904
+ attention_bias = self.get_extended_attention_mask(attention_mask, input_shape)
905
+ if self.config.use_memory_efficient_attention:
906
+ # Memory-efficient attention expects a full (batch, num_heads, seq_len, seq_len) bias, so expand the broadcastable (batch, 1, 1, seq_len) mask.
907
+ attention_bias = attention_bias.expand(-1, self.config.num_attention_heads, seq_length, -1)
908
+
909
+ if self.config.logn_attention_scale:
910
+ # logn attention scale: log(input_len) / log(max_position_embeddings)
911
+ attention_scale = attention_mask.sum(1).log() / torch.tensor(self.config.max_position_embeddings).log()
912
+ # at inference time, the logn scale is clipped to a minimum of 1
913
+ if self.config.logn_attention_clip1:
914
+ attention_scale.clip_(1)
915
+ attention_scale = attention_scale[:, None, None, None]
916
+ else:
917
+ attention_scale = None
918
+
919
+ encoder_outputs = self.encoder(
920
+ embedding_output,
921
+ attention_bias=attention_bias,
922
+ rope_embeds=rope_embeds,
923
+ attention_scale=attention_scale,
924
+ subset_indices=subset_indices,
925
+ head_mask=head_mask,
926
+ output_attentions=output_attentions,
927
+ output_hidden_states=output_hidden_states,
928
+ return_dict=return_dict,
929
+ )
930
+ sequence_output = encoder_outputs[0]
931
+ if unpad_inputs and output_padded:
932
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
933
+ sequence_output = pad_input(
934
+ sequence_output.squeeze(), indices, batch_size, seq_length
935
+ )
936
+
937
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
938
+
939
+ if not return_dict:
940
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
941
+
942
+ return BaseModelOutputWithPooling(
943
+ last_hidden_state=sequence_output,
944
+ pooler_output=pooled_output,
945
+ hidden_states=encoder_outputs.hidden_states,
946
+ attentions=encoder_outputs.attentions,
947
+ )
948
+
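Assuming the repository's `config.json` maps its auto classes to this modeling file, `NewModel` would typically be loaded through `AutoModel` with `trust_remote_code=True`. A hedged sketch with a placeholder repo id:

```python
from transformers import AutoModel, AutoTokenizer

repo_id = "your-org/your-model"  # hypothetical; any repo shipping this modeling file via auto_map
tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)

inputs = tokenizer(["a short test sentence"], return_tensors="pt")
outputs = model(**inputs)
print(outputs.last_hidden_state.shape)  # (batch, seq_len, hidden_size)
```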
949
+
950
+ class NewLMPredictionHead(nn.Module):
951
+ def __init__(self, config):
952
+ super().__init__()
953
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
954
+ self.transform_act_fn = ACT2FN[config.hidden_act]
955
+ self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
956
+
957
+ # The output weights are the same as the input embeddings, but there is
958
+ # an output-only bias for each token.
959
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size)
960
+
961
+ def forward(self, hidden_states):
962
+ hidden_states = self.dense(hidden_states)
963
+ hidden_states = self.transform_act_fn(hidden_states)
964
+ hidden_states = self.norm(hidden_states)
965
+ hidden_states = self.decoder(hidden_states)
966
+ return hidden_states
967
+
968
+
969
+ class NewForMaskedLM(NewPreTrainedModel):
970
+ _tied_weights_keys = ["lm_head.decoder.bias", "lm_head.decoder.weight"]
971
+
972
+ def __init__(self, config: NewConfig):
973
+ super().__init__(config)
974
+ self.new = NewModel(config, add_pooling_layer=False)
975
+ self.lm_head = NewLMPredictionHead(config)
976
+ self.loss_fct = nn.CrossEntropyLoss()
977
+
978
+ # Initialize weights and apply final processing
979
+ self.post_init()
980
+
981
+ def get_output_embeddings(self):
982
+ return self.lm_head.decoder
983
+
984
+ def set_output_embeddings(self, new_embeddings):
985
+ self.lm_head.decoder = new_embeddings
986
+
987
+ def forward(
988
+ self,
989
+ input_ids: Optional[torch.Tensor] = None,
990
+ attention_mask: Optional[torch.Tensor] = None,
991
+ token_type_ids: Optional[torch.Tensor] = None,
992
+ position_ids: Optional[torch.Tensor] = None,
993
+ head_mask: Optional[torch.Tensor] = None,
994
+ inputs_embeds: Optional[torch.Tensor] = None,
995
+ labels: Optional[torch.Tensor] = None,
996
+ output_attentions: Optional[bool] = None,
997
+ output_hidden_states: Optional[bool] = None,
998
+ return_dict: Optional[bool] = None,
999
+ unpad_inputs: Optional[bool] = None,
1000
+ ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
1001
+ r"""
1002
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1003
+ Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
1004
+ config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
1005
+ loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
1006
+ """
1007
+
1008
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1009
+
1010
+ if labels is None or not self.new.config.unpad_inputs:
1011
+ length = None
1012
+ subset_indices = None
1013
+ else:
1014
+ length = attention_mask.sum(-1).tolist()
1015
+ labels = labels[attention_mask.bool()].unsqueeze(0)
1016
+ subset_indices = labels > -100
1017
+
1018
+ outputs = self.new(
1019
+ input_ids,
1020
+ attention_mask=attention_mask,
1021
+ length=length,
1022
+ subset_indices=subset_indices,
1023
+ token_type_ids=token_type_ids,
1024
+ position_ids=position_ids,
1025
+ head_mask=head_mask,
1026
+ inputs_embeds=inputs_embeds,
1027
+ output_attentions=output_attentions,
1028
+ output_hidden_states=output_hidden_states,
1029
+ return_dict=return_dict,
1030
+ unpad_inputs=unpad_inputs,
1031
+ )
1032
+
1033
+ sequence_output = outputs[0]
1034
+ prediction_scores = self.lm_head(sequence_output)
1035
+
1036
+ masked_lm_loss = None
1037
+ if labels is not None:
1038
+ if subset_indices is None:
1039
+ mask = attention_mask.bool()
1040
+ prediction_scores = prediction_scores[mask]
1041
+ labels = labels[mask]
1042
+ else:
1043
+ labels = labels[subset_indices]
1044
+ masked_lm_loss = self.loss_fct(prediction_scores, labels)
1045
+
1046
+ if not return_dict:
1047
+ output = (prediction_scores,) + outputs[2:]
1048
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
1049
+
1050
+ return MaskedLMOutput(
1051
+ loss=masked_lm_loss,
1052
+ logits=prediction_scores,
1053
+ hidden_states=outputs.hidden_states,
1054
+ attentions=outputs.attentions,
1055
+ )
1056
+
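The unpadded masked-LM path above gathers labels with the attention mask and then keeps only positions whose label is not `-100`. A tiny standalone illustration of that indexing (values are made up):

```python
import torch

labels = torch.tensor([[-100, 5, -100, -100],
                       [-100, -100, 3, -100]])
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 1, 1]])

flat_labels = labels[attention_mask.bool()].unsqueeze(0)  # drop padded positions
subset_indices = flat_labels > -100                       # keep masked-token positions only
print(flat_labels[subset_indices])                        # tensor([5, 3])
```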
1057
+
1058
+ class NewForSequenceClassification(NewPreTrainedModel):
1059
+ def __init__(self, config):
1060
+ super().__init__(config)
1061
+ self.num_labels = config.num_labels
1062
+ self.config = config
1063
+
1064
+ self.new = NewModel(config, add_pooling_layer=True)
1065
+ classifier_dropout = (
1066
+ config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
1067
+ )
1068
+ self.dropout = nn.Dropout(classifier_dropout)
1069
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1070
+
1071
+ # Initialize weights and apply final processing
1072
+ self.post_init()
1073
+
1074
+ def forward(
1075
+ self,
1076
+ input_ids: Optional[torch.Tensor] = None,
1077
+ attention_mask: Optional[torch.Tensor] = None,
1078
+ token_type_ids: Optional[torch.Tensor] = None,
1079
+ position_ids: Optional[torch.Tensor] = None,
1080
+ head_mask: Optional[torch.Tensor] = None,
1081
+ inputs_embeds: Optional[torch.Tensor] = None,
1082
+ labels: Optional[torch.Tensor] = None,
1083
+ output_attentions: Optional[bool] = None,
1084
+ output_hidden_states: Optional[bool] = None,
1085
+ return_dict: Optional[bool] = None,
1086
+ unpad_inputs: Optional[bool] = None,
1087
+ ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
1088
+ r"""
1089
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1090
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1091
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1092
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1093
+ """
1094
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1095
+
1096
+ outputs = self.new(
1097
+ input_ids,
1098
+ attention_mask=attention_mask,
1099
+ token_type_ids=token_type_ids,
1100
+ position_ids=position_ids,
1101
+ head_mask=head_mask,
1102
+ inputs_embeds=inputs_embeds,
1103
+ output_attentions=output_attentions,
1104
+ output_hidden_states=output_hidden_states,
1105
+ return_dict=return_dict,
1106
+ unpad_inputs=unpad_inputs,
1107
+ )
1108
+
1109
+ pooled_output = outputs[1]
1110
+
1111
+ pooled_output = self.dropout(pooled_output)
1112
+ logits = self.classifier(pooled_output)
1113
+
1114
+ loss = None
1115
+ if labels is not None:
1116
+ if self.config.problem_type is None:
1117
+ if self.num_labels == 1:
1118
+ self.config.problem_type = "regression"
1119
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1120
+ self.config.problem_type = "single_label_classification"
1121
+ else:
1122
+ self.config.problem_type = "multi_label_classification"
1123
+
1124
+ if self.config.problem_type == "regression":
1125
+ loss_fct = nn.MSELoss()
1126
+ if self.num_labels == 1:
1127
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
1128
+ else:
1129
+ loss = loss_fct(logits, labels)
1130
+ elif self.config.problem_type == "single_label_classification":
1131
+ loss_fct = nn.CrossEntropyLoss()
1132
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1133
+ elif self.config.problem_type == "multi_label_classification":
1134
+ loss_fct = nn.BCEWithLogitsLoss()
1135
+ loss = loss_fct(logits, labels)
1136
+
1137
+ if not return_dict:
1138
+ output = (logits,) + outputs[2:]
1139
+ return ((loss,) + output) if loss is not None else output
1140
+
1141
+ return SequenceClassifierOutput(
1142
+ loss=loss,
1143
+ logits=logits,
1144
+ hidden_states=outputs.hidden_states,
1145
+ attentions=outputs.attentions,
1146
+ )
1147
+
1148
+
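The `problem_type` branch above picks the loss from the label dtype and `num_labels`. A standalone sketch of the three cases with dummy tensors:

```python
import torch
import torch.nn as nn

logits = torch.randn(4, 3)

# Integer class ids -> single_label_classification (cross-entropy).
class_labels = torch.tensor([0, 2, 1, 1])
ce_loss = nn.CrossEntropyLoss()(logits.view(-1, 3), class_labels.view(-1))

# Multi-hot float targets -> multi_label_classification (BCE with logits).
multi_hot = torch.randint(0, 2, (4, 3)).float()
bce_loss = nn.BCEWithLogitsLoss()(logits, multi_hot)

# A single output column -> regression (MSE).
reg_logits, reg_targets = torch.randn(4, 1), torch.randn(4)
mse_loss = nn.MSELoss()(reg_logits.squeeze(), reg_targets.squeeze())
```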
1149
+ class NewForMultipleChoice(NewPreTrainedModel):
1150
+ def __init__(self, config):
1151
+ super().__init__(config)
1152
+
1153
+ self.new = NewModel(config, add_pooling_layer=True)
1154
+ classifier_dropout = (
1155
+ config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
1156
+ )
1157
+ self.dropout = nn.Dropout(classifier_dropout)
1158
+ self.classifier = nn.Linear(config.hidden_size, 1)
1159
+
1160
+ # Initialize weights and apply final processing
1161
+ self.post_init()
1162
+
1163
+ def forward(
1164
+ self,
1165
+ input_ids: Optional[torch.Tensor] = None,
1166
+ attention_mask: Optional[torch.Tensor] = None,
1167
+ token_type_ids: Optional[torch.Tensor] = None,
1168
+ position_ids: Optional[torch.Tensor] = None,
1169
+ head_mask: Optional[torch.Tensor] = None,
1170
+ inputs_embeds: Optional[torch.Tensor] = None,
1171
+ labels: Optional[torch.Tensor] = None,
1172
+ output_attentions: Optional[bool] = None,
1173
+ output_hidden_states: Optional[bool] = None,
1174
+ return_dict: Optional[bool] = None,
1175
+ unpad_inputs: Optional[bool] = None,
1176
+ ) -> Union[Tuple[torch.Tensor], MultipleChoiceModelOutput]:
1177
+ r"""
1178
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1179
+ Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
1180
+ num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
1181
+ `input_ids` above)
1182
+ """
1183
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1184
+ num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
1185
+
1186
+ input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
1187
+ attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
1188
+ token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
1189
+ position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
1190
+ inputs_embeds = (
1191
+ inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
1192
+ if inputs_embeds is not None
1193
+ else None
1194
+ )
1195
+
1196
+ outputs = self.new(
1197
+ input_ids,
1198
+ attention_mask=attention_mask,
1199
+ token_type_ids=token_type_ids,
1200
+ position_ids=position_ids,
1201
+ head_mask=head_mask,
1202
+ inputs_embeds=inputs_embeds,
1203
+ output_attentions=output_attentions,
1204
+ output_hidden_states=output_hidden_states,
1205
+ return_dict=return_dict,
1206
+ unpad_inputs=unpad_inputs,
1207
+ )
1208
+
1209
+ pooled_output = outputs[1]
1210
+
1211
+ pooled_output = self.dropout(pooled_output)
1212
+ logits = self.classifier(pooled_output)
1213
+ reshaped_logits = logits.view(-1, num_choices)
1214
+
1215
+ loss = None
1216
+ if labels is not None:
1217
+ loss_fct = nn.CrossEntropyLoss()
1218
+ loss = loss_fct(reshaped_logits, labels)
1219
+
1220
+ if not return_dict:
1221
+ output = (reshaped_logits,) + outputs[2:]
1222
+ return ((loss,) + output) if loss is not None else output
1223
+
1224
+ return MultipleChoiceModelOutput(
1225
+ loss=loss,
1226
+ logits=reshaped_logits,
1227
+ hidden_states=outputs.hidden_states,
1228
+ attentions=outputs.attentions,
1229
+ )
1230
+
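For the multiple-choice head above, inputs of shape `(batch, num_choices, seq_len)` are flattened for the encoder and the per-choice scores are reshaped back before the cross-entropy. A small illustration with dummy tensors:

```python
import torch

batch_size, num_choices, seq_len = 2, 4, 16
input_ids = torch.randint(0, 1000, (batch_size, num_choices, seq_len))

flat_input_ids = input_ids.view(-1, input_ids.size(-1))   # (8, 16): what the encoder sees
choice_scores = torch.randn(batch_size * num_choices, 1)  # stand-in for classifier(pooled_output)
reshaped_logits = choice_scores.view(-1, num_choices)     # (2, 4): one row per example

labels = torch.tensor([1, 3])
loss = torch.nn.CrossEntropyLoss()(reshaped_logits, labels)
```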
1231
+
1232
+ class NewForTokenClassification(NewPreTrainedModel):
1233
+ def __init__(self, config):
1234
+ super().__init__(config)
1235
+ self.num_labels = config.num_labels
1236
+
1237
+ self.new = NewModel(config, add_pooling_layer=False)
1238
+ classifier_dropout = (
1239
+ config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
1240
+ )
1241
+ self.dropout = nn.Dropout(classifier_dropout)
1242
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1243
+
1244
+ # Initialize weights and apply final processing
1245
+ self.post_init()
1246
+
1247
+ def forward(
1248
+ self,
1249
+ input_ids: Optional[torch.Tensor] = None,
1250
+ attention_mask: Optional[torch.Tensor] = None,
1251
+ token_type_ids: Optional[torch.Tensor] = None,
1252
+ position_ids: Optional[torch.Tensor] = None,
1253
+ head_mask: Optional[torch.Tensor] = None,
1254
+ inputs_embeds: Optional[torch.Tensor] = None,
1255
+ labels: Optional[torch.Tensor] = None,
1256
+ output_attentions: Optional[bool] = None,
1257
+ output_hidden_states: Optional[bool] = None,
1258
+ return_dict: Optional[bool] = None,
1259
+ unpad_inputs: Optional[bool] = None,
1260
+ ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
1261
+ r"""
1262
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1263
+ Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
1264
+ """
1265
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1266
+
1267
+ outputs = self.new(
1268
+ input_ids,
1269
+ attention_mask=attention_mask,
1270
+ token_type_ids=token_type_ids,
1271
+ position_ids=position_ids,
1272
+ head_mask=head_mask,
1273
+ inputs_embeds=inputs_embeds,
1274
+ output_attentions=output_attentions,
1275
+ output_hidden_states=output_hidden_states,
1276
+ return_dict=return_dict,
1277
+ unpad_inputs=unpad_inputs,
1278
+ )
1279
+
1280
+ sequence_output = outputs[0]
1281
+
1282
+ sequence_output = self.dropout(sequence_output)
1283
+ logits = self.classifier(sequence_output)
1284
+
1285
+ loss = None
1286
+ if labels is not None:
1287
+ loss_fct = nn.CrossEntropyLoss()
1288
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1289
+
1290
+ if not return_dict:
1291
+ output = (logits,) + outputs[2:]
1292
+ return ((loss,) + output) if loss is not None else output
1293
+
1294
+ return TokenClassifierOutput(
1295
+ loss=loss,
1296
+ logits=logits,
1297
+ hidden_states=outputs.hidden_states,
1298
+ attentions=outputs.attentions,
1299
+ )
1300
+
1301
+
1302
+ class NewForQuestionAnswering(NewPreTrainedModel):
1303
+ def __init__(self, config):
1304
+ super().__init__(config)
1305
+ self.num_labels = config.num_labels
1306
+
1307
+ self.new = NewModel(config, add_pooling_layer=False)
1308
+ self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
1309
+
1310
+ # Initialize weights and apply final processing
1311
+ self.post_init()
1312
+
1313
+ def forward(
1314
+ self,
1315
+ input_ids: Optional[torch.Tensor] = None,
1316
+ attention_mask: Optional[torch.Tensor] = None,
1317
+ token_type_ids: Optional[torch.Tensor] = None,
1318
+ position_ids: Optional[torch.Tensor] = None,
1319
+ head_mask: Optional[torch.Tensor] = None,
1320
+ inputs_embeds: Optional[torch.Tensor] = None,
1321
+ start_positions: Optional[torch.Tensor] = None,
1322
+ end_positions: Optional[torch.Tensor] = None,
1323
+ output_attentions: Optional[bool] = None,
1324
+ output_hidden_states: Optional[bool] = None,
1325
+ return_dict: Optional[bool] = None,
1326
+ unpad_inputs: Optional[bool] = None,
1327
+ ) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
1328
+ r"""
1329
+ start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1330
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
1331
+ Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
1332
+ are not taken into account for computing the loss.
1333
+ end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1334
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
1335
+ Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
1336
+ are not taken into account for computing the loss.
1337
+ """
1338
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1339
+
1340
+ outputs = self.new(
1341
+ input_ids,
1342
+ attention_mask=attention_mask,
1343
+ token_type_ids=token_type_ids,
1344
+ position_ids=position_ids,
1345
+ head_mask=head_mask,
1346
+ inputs_embeds=inputs_embeds,
1347
+ output_attentions=output_attentions,
1348
+ output_hidden_states=output_hidden_states,
1349
+ return_dict=return_dict,
1350
+ unpad_inputs=unpad_inputs,
1351
+ )
1352
+
1353
+ sequence_output = outputs[0]
1354
+
1355
+ logits = self.qa_outputs(sequence_output)
1356
+ start_logits, end_logits = logits.split(1, dim=-1)
1357
+ start_logits = start_logits.squeeze(-1).contiguous()
1358
+ end_logits = end_logits.squeeze(-1).contiguous()
1359
+
1360
+ total_loss = None
1361
+ if start_positions is not None and end_positions is not None:
1362
+ # On multi-GPU, the position tensors may carry an extra dimension; squeeze it away
1363
+ if len(start_positions.size()) > 1:
1364
+ start_positions = start_positions.squeeze(-1)
1365
+ if len(end_positions.size()) > 1:
1366
+ end_positions = end_positions.squeeze(-1)
1367
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
1368
+ ignored_index = start_logits.size(1)
1369
+ start_positions = start_positions.clamp(0, ignored_index)
1370
+ end_positions = end_positions.clamp(0, ignored_index)
1371
+
1372
+ loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index)
1373
+ start_loss = loss_fct(start_logits, start_positions)
1374
+ end_loss = loss_fct(end_logits, end_positions)
1375
+ total_loss = (start_loss + end_loss) / 2
1376
+
1377
+ if not return_dict:
1378
+ output = (start_logits, end_logits) + outputs[2:]
1379
+ return ((total_loss,) + output) if total_loss is not None else output
1380
+
1381
+ return QuestionAnsweringModelOutput(
1382
+ loss=total_loss,
1383
+ start_logits=start_logits,
1384
+ end_logits=end_logits,
1385
+ hidden_states=outputs.hidden_states,
1386
+ attentions=outputs.attentions,
1387
+ )
modules.json ADDED
@@ -0,0 +1,20 @@
1
+ [
2
+ {
3
+ "idx": 0,
4
+ "name": "0",
5
+ "path": "",
6
+ "type": "sentence_transformers.models.Transformer"
7
+ },
8
+ {
9
+ "idx": 1,
10
+ "name": "1",
11
+ "path": "1_Pooling",
12
+ "type": "sentence_transformers.models.Pooling"
13
+ },
14
+ {
15
+ "idx": 2,
16
+ "name": "2",
17
+ "path": "2_Dense",
18
+ "type": "sentence_transformers.models.Dense"
19
+ }
20
+ ]
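`modules.json` wires a three-stage sentence-transformers pipeline: `Transformer` -> `Pooling` (`1_Pooling`) -> `Dense` (`2_Dense`). A hedged usage sketch with a placeholder repo id; `trust_remote_code=True` is assumed to be needed for the custom modeling code and requires a recent sentence-transformers release:

```python
from sentence_transformers import SentenceTransformer

# Hypothetical repo id; substitute the actual model repository.
model = SentenceTransformer("your-org/your-setfit-embedding-model", trust_remote_code=True)
embeddings = model.encode(["a short test sentence"])
print(embeddings.shape)  # (1, output dimension of the Dense module)
```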
sentence_bert_config.json ADDED
@@ -0,0 +1,4 @@
1
+ {
2
+ "max_seq_length": 512,
3
+ "do_lower_case": false
4
+ }
special_tokens_map.json ADDED
@@ -0,0 +1,37 @@
1
+ {
2
+ "cls_token": {
3
+ "content": "[CLS]",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "mask_token": {
10
+ "content": "[MASK]",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": {
17
+ "content": "[PAD]",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ },
23
+ "sep_token": {
24
+ "content": "[SEP]",
25
+ "lstrip": false,
26
+ "normalized": false,
27
+ "rstrip": false,
28
+ "single_word": false
29
+ },
30
+ "unk_token": {
31
+ "content": "[UNK]",
32
+ "lstrip": false,
33
+ "normalized": false,
34
+ "rstrip": false,
35
+ "single_word": false
36
+ }
37
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,62 @@
1
+ {
2
+ "added_tokens_decoder": {
3
+ "0": {
4
+ "content": "[PAD]",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "100": {
12
+ "content": "[UNK]",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ },
19
+ "101": {
20
+ "content": "[CLS]",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false,
25
+ "special": true
26
+ },
27
+ "102": {
28
+ "content": "[SEP]",
29
+ "lstrip": false,
30
+ "normalized": false,
31
+ "rstrip": false,
32
+ "single_word": false,
33
+ "special": true
34
+ },
35
+ "103": {
36
+ "content": "[MASK]",
37
+ "lstrip": false,
38
+ "normalized": false,
39
+ "rstrip": false,
40
+ "single_word": false,
41
+ "special": true
42
+ }
43
+ },
44
+ "clean_up_tokenization_spaces": true,
45
+ "cls_token": "[CLS]",
46
+ "do_lower_case": true,
47
+ "mask_token": "[MASK]",
48
+ "max_length": 8000,
49
+ "model_max_length": 512,
50
+ "pad_to_multiple_of": null,
51
+ "pad_token": "[PAD]",
52
+ "pad_token_type_id": 0,
53
+ "padding_side": "right",
54
+ "sep_token": "[SEP]",
55
+ "stride": 0,
56
+ "strip_accents": null,
57
+ "tokenize_chinese_chars": true,
58
+ "tokenizer_class": "BertTokenizer",
59
+ "truncation_side": "right",
60
+ "truncation_strategy": "longest_first",
61
+ "unk_token": "[UNK]"
62
+ }
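The tokenizer is a standard `BertTokenizer` with `model_max_length` of 512, so longer inputs should be truncated explicitly. A short sketch with a placeholder repo id:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("your-org/your-setfit-embedding-model")  # hypothetical id
encoded = tokenizer("a long document ...", truncation=True, max_length=512)
print(len(encoded["input_ids"]))  # at most 512
```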
vocab.txt ADDED
The diff for this file is too large to render. See raw diff