Add files using upload-large-folder tool
This view is limited to 50 files because the commit contains too many changes; see the raw diff for the complete list.
- .gitattributes +1 -0
- fairseq/examples/hubert/tests/6313-76958-0021.flac +3 -0
- fairseq/examples/pointer_generator/README.xsum.md +180 -0
- fairseq/examples/pointer_generator/pointer_generator_src/__init__.py +6 -0
- fairseq/examples/pointer_generator/pointer_generator_src/transformer_pg.py +518 -0
- fairseq/examples/pointer_generator/postprocess.py +96 -0
- fairseq/examples/quant_noise/transformer_quantization_config.yaml +33 -0
- fairseq/examples/roberta/README.custom_classification.md +168 -0
- fairseq/examples/roberta/README.glue.md +64 -0
- fairseq/examples/roberta/README.md +296 -0
- fairseq/examples/roberta/README.pretraining.md +84 -0
- fairseq/examples/roberta/README.race.md +68 -0
- fairseq/examples/roberta/commonsense_qa/README.md +99 -0
- fairseq/examples/roberta/commonsense_qa/__init__.py +6 -0
- fairseq/examples/roberta/commonsense_qa/commonsense_qa_task.py +190 -0
- fairseq/examples/roberta/commonsense_qa/download_cqa_data.sh +14 -0
- fairseq/examples/roberta/config/finetuning/cola.yaml +59 -0
- fairseq/examples/roberta/config/finetuning/mnli.yaml +59 -0
- fairseq/examples/roberta/config/finetuning/mrpc.yaml +59 -0
- fairseq/examples/roberta/config/finetuning/qnli.yaml +59 -0
- fairseq/examples/roberta/config/finetuning/qqp.yaml +59 -0
- fairseq/examples/roberta/config/finetuning/rte.yaml +59 -0
- fairseq/examples/roberta/config/finetuning/run_config/local.yaml +15 -0
- fairseq/examples/roberta/config/finetuning/run_config/slurm_1g.yaml +28 -0
- fairseq/examples/roberta/config/finetuning/run_config/slurm_1g_aws.yaml +25 -0
- fairseq/examples/roberta/config/finetuning/sst_2.yaml +59 -0
- fairseq/examples/roberta/config/finetuning/sts_b.yaml +58 -0
- fairseq/examples/roberta/config/pretraining/base.yaml +42 -0
- fairseq/examples/roberta/config/pretraining/run_config/local.yaml +15 -0
- fairseq/examples/roberta/config/pretraining/run_config/slurm_2.yaml +37 -0
- fairseq/examples/roberta/config/pretraining/run_config/slurm_2_aws.yaml +39 -0
- fairseq/examples/roberta/config/pretraining/run_config/slurm_3.yaml +36 -0
- fairseq/examples/roberta/config/pretraining/run_config/slurm_4.yaml +36 -0
- fairseq/examples/roberta/fb_multilingual/README.multilingual.pretraining.md +26 -0
- fairseq/examples/roberta/multiprocessing_bpe_encoder.py +130 -0
- fairseq/examples/roberta/preprocess_GLUE_tasks.sh +185 -0
- fairseq/examples/roberta/preprocess_RACE.py +102 -0
- fairseq/examples/roberta/preprocess_RACE.sh +59 -0
- fairseq/examples/roberta/wsc/README.md +125 -0
- fairseq/examples/roberta/wsc/__init__.py +7 -0
- fairseq/examples/roberta/wsc/wsc_criterion.py +167 -0
- fairseq/examples/roberta/wsc/wsc_task.py +401 -0
- fairseq/examples/roberta/wsc/wsc_utils.py +241 -0
- fairseq/examples/rxf/README.md +52 -0
- fairseq/examples/rxf/__init__.py +6 -0
- fairseq/examples/rxf/rxf_src/__init__.py +6 -0
- fairseq/examples/rxf/rxf_src/label_smoothed_cross_entropy_r3f.py +158 -0
- fairseq/examples/rxf/rxf_src/sentence_prediction_r3f.py +171 -0
- fairseq/examples/scaling_nmt/README.md +114 -0
- fairseq/examples/shuffled_word_order/README.finetuning.md +135 -0
.gitattributes
CHANGED
@@ -38,3 +38,4 @@ fairseq/examples/MMPT/videoclip.png filter=lfs diff=lfs merge=lfs -text
 fairseq/alignment_train_cuda_binding.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
 fairseq/alignment_train_cpu_binding.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
 fairseq/docs/fairseq.gif filter=lfs diff=lfs merge=lfs -text
+fairseq/examples/hubert/tests/6313-76958-0021.flac filter=lfs diff=lfs merge=lfs -text
fairseq/examples/hubert/tests/6313-76958-0021.flac
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0dd89e2bf8c60e05264bf2539f0ec35bc65efb55cd952b3510fa243de9cc16ff
+size 223912
fairseq/examples/pointer_generator/README.xsum.md
ADDED
@@ -0,0 +1,180 @@
## Training a pointer-generator model on the Extreme Summarization dataset

##### 1. Download the Extreme Summarization data and preprocess it

Follow the instructions [here](https://github.com/EdinburghNLP/XSum) to obtain
the original Extreme Summarization dataset. You should have six files,
{train,validation,test}.{document,summary}.

##### 2. Create a vocabulary and extend it with source position markers

```bash
vocab_size=10000
position_markers=1000
export LC_ALL=C
cat train.document train.summary |
  tr -s '[:space:]' '\n' |
  sort |
  uniq -c |
  sort -k1,1bnr -k2 |
  head -n "$((vocab_size - 4))" |
  awk '{ print $2 " " $1 }' >dict.pg.txt
python3 -c "[print('<unk-{}> 0'.format(n)) for n in range($position_markers)]" >>dict.pg.txt
```

This creates the file dict.pg.txt that contains the 10k most frequent words,
followed by 1k source position markers:

```
the 4954867
. 4157552
, 3439668
to 2212159
a 1916857
of 1916820
and 1823350
...
<unk-0> 0
<unk-1> 0
<unk-2> 0
<unk-3> 0
<unk-4> 0
...
```

##### 3. Preprocess the text data

```bash
./preprocess.py --source train.document --target train.summary --vocab <(cut -d' ' -f1 dict.pg.txt) --source-out train.pg.src --target-out train.pg.tgt
./preprocess.py --source validation.document --target validation.summary --vocab <(cut -d' ' -f1 dict.pg.txt) --source-out valid.pg.src --target-out valid.pg.tgt
./preprocess.py --source test.document --vocab <(cut -d' ' -f1 dict.pg.txt) --source-out test.pg.src
```

The data should now contain `<unk-N>` tokens in place of out-of-vocabulary words.
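For intuition, here is a minimal editorial sketch of the source-side replacement rule implied by the worked example at the end of this README; the actual `preprocess.py` is not included in this diff and also rewrites the target files. An out-of-vocabulary source token at position N simply becomes `<unk-N>`:

```python
# Editorial sketch only, not the preprocess.py shipped with this example.
def replace_source_oovs(tokens, vocab):
    """Replace each out-of-vocabulary source token with its position marker."""
    return [
        tok if tok in vocab else "<unk-{}>".format(pos)
        for pos, tok in enumerate(tokens)
    ]

vocab = {"de", "moved", "to"}  # toy vocabulary
print(" ".join(replace_source_oovs("de roon moved to teesside".split(), vocab)))
# de <unk-1> moved to <unk-4>
```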

##### 4. Binarize the dataset

```bash
fairseq-preprocess \
    --source-lang src \
    --target-lang tgt \
    --trainpref train.pg \
    --validpref valid.pg \
    --destdir bin \
    --workers 60 \
    --srcdict dict.pg.txt \
    --joined-dictionary
```

##### 5. Train a model

```bash
total_updates=20000
warmup_updates=500
lr=0.001
max_tokens=4096
update_freq=4
pointer_layer=-2

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 fairseq-train bin \
    --user-dir examples/pointer_generator/pointer_generator_src \
    --max-tokens "$max_tokens" \
    --task translation \
    --source-lang src --target-lang tgt \
    --truncate-source \
    --layernorm-embedding \
    --share-all-embeddings \
    --encoder-normalize-before \
    --decoder-normalize-before \
    --required-batch-size-multiple 1 \
    --arch transformer_pointer_generator \
    --alignment-layer "$pointer_layer" \
    --alignment-heads 1 \
    --source-position-markers 1000 \
    --criterion label_smoothed_cross_entropy \
    --label-smoothing 0.1 \
    --dropout 0.1 --attention-dropout 0.1 \
    --weight-decay 0.01 --optimizer adam --adam-betas "(0.9, 0.999)" --adam-eps 1e-08 \
    --clip-norm 0.1 \
    --lr-scheduler inverse_sqrt --lr "$lr" --max-update "$total_updates" --warmup-updates "$warmup_updates" \
    --update-freq "$update_freq" \
    --skip-invalid-size-inputs-valid-test
```

Above we specify that our dictionary contains 1000 source position markers, and
that we want to use one attention head from the penultimate decoder layer for
pointing. Training should take around 5.5 hours on one node with eight 32GB V100
GPUs. The logged messages confirm that dictionary indices 10000 and above will be
mapped to the `<unk>` embedding:

```
2020-09-24 20:43:53 | INFO | fairseq.tasks.translation | [src] dictionary: 11000 types
2020-09-24 20:43:53 | INFO | fairseq.tasks.translation | [tgt] dictionary: 11000 types
2020-09-24 20:43:53 | INFO | fairseq.data.data_utils | loaded 11332 examples from: bin/valid.src-tgt.src
2020-09-24 20:43:53 | INFO | fairseq.data.data_utils | loaded 11332 examples from: bin/valid.src-tgt.tgt
2020-09-24 20:43:53 | INFO | fairseq.tasks.translation | bin valid src-tgt 11332 examples
2020-09-24 20:43:53 | INFO | fairseq.models.transformer_pg | dictionary indices from 10000 to 10999 will be mapped to 3
```

##### 6. Summarize the test sequences

```bash
batch_size=32
beam_size=6
max_length=60
length_penalty=1.0

fairseq-interactive bin \
    --user-dir examples/pointer_generator/pointer_generator_src \
    --batch-size "$batch_size" \
    --task translation \
    --source-lang src --target-lang tgt \
    --path checkpoints/checkpoint_last.pt \
    --input test.pg.src \
    --buffer-size 200 \
    --max-len-a 0 \
    --max-len-b "$max_length" \
    --lenpen "$length_penalty" \
    --beam "$beam_size" \
    --skip-invalid-size-inputs-valid-test |
    tee generate.out
grep ^H generate.out | cut -f 3- >generate.hyp
```

Now you should have the generated sequences in `generate.hyp`. They contain
`<unk-N>` tokens that the model has copied from the source sequence. In order to
retrieve the original words, we need the unprocessed source sequences from
`test.document`.

##### 7. Process the generated output

Since we skipped inputs that were too long when producing `generate.hyp`, we also
have to skip those over-long sequences when reading `test.document`.

```bash
./postprocess.py \
    --source <(awk 'NF<1024' test.document) \
    --target generate.hyp \
    --target-out generate.hyp.processed
```

You'll now find the final sequences in `generate.hyp.processed`, with every
`<unk-N>` replaced by the original word from the source sequence.

##### An example of a summarized sequence

The original source document in `test.document`:

> de roon moved to teesside in june 2016 for an initial # 8.8 m fee and played 33 premier league games last term . the netherlands international , 26 , scored five goals in 36 league and cup games during his spell at boro . meanwhile , manager garry monk confirmed the championship club 's interest in signing chelsea midfielder lewis baker . `` he 's a target and one of many that we 've had throughout the summer months , '' said monk . find all the latest football transfers on our dedicated page .

The preprocessed source document in `test.pg.src`:

> de \<unk-1> moved to \<unk-4> in june 2016 for an initial # \<unk-12> m fee and played 33 premier league games last term . the netherlands international , 26 , scored five goals in 36 league and cup games during his spell at boro . meanwhile , manager garry monk confirmed the championship club 's interest in signing chelsea midfielder lewis baker . `` he 's a target and one of many that we 've had throughout the summer months , '' said monk . find all the latest football transfers on our dedicated page .

The generated summary in `generate.hyp`:

> middlesbrough striker \<unk> de \<unk-1> has joined spanish side \<unk> on a season-long loan .

The generated summary after postprocessing in `generate.hyp.processed`:

> middlesbrough striker \<unk> de roon has joined spanish side \<unk> on a season-long loan .
fairseq/examples/pointer_generator/pointer_generator_src/__init__.py
ADDED
@@ -0,0 +1,6 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from . import transformer_pg  # noqa
fairseq/examples/pointer_generator/pointer_generator_src/transformer_pg.py
ADDED
@@ -0,0 +1,518 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
from typing import Any, Dict, Optional, List, Tuple

import torch
import torch.nn as nn
from fairseq import utils
from fairseq.models import register_model, register_model_architecture
from fairseq.models.transformer import (
    DEFAULT_MAX_SOURCE_POSITIONS,
    DEFAULT_MAX_TARGET_POSITIONS,
    TransformerDecoder,
    TransformerEncoder,
    TransformerModel,
    base_architecture,
)
from torch import Tensor


logger = logging.getLogger(__name__)


@register_model("transformer_pointer_generator")
class TransformerPointerGeneratorModel(TransformerModel):
    """
    Transformer model from `"Attention Is All You Need" (Vaswani et al, 2017)
    <https://arxiv.org/abs/1706.03762>`_, augmented with a pointer-generator
    network from `"Get To The Point: Summarization with Pointer-Generator
    Networks" (See et al, 2017) <https://arxiv.org/abs/1704.04368>`_.

    Args:
        encoder (TransformerPointerGeneratorEncoder): the encoder
        decoder (TransformerPointerGeneratorDecoder): the decoder

    The Transformer pointer-generator model provides the following named
    architectures and command-line arguments:

    .. argparse::
        :ref: fairseq.models.transformer_pointer_generator_parser
        :prog:
    """

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        # fmt: off
        TransformerModel.add_args(parser)
        parser.add_argument('--alignment-heads', type=int, metavar='N',
                            help='number of attention heads to be used for '
                                 'pointing')
        parser.add_argument('--alignment-layer', type=int, metavar='I',
                            help='layer number to be used for pointing (0 '
                                 'corresponding to the bottommost layer)')
        parser.add_argument('--source-position-markers', type=int, metavar='N',
                            help='dictionary includes N additional items that '
                                 'represent an OOV token at a particular input '
                                 'position')
        parser.add_argument('--force-generation', type=float, metavar='P',
                            default=None,
                            help='set the vocabulary distribution weight to P, '
                                 'instead of predicting it from the input (1.0 '
                                 'corresponding to generation, 0.0 to pointing)')
        # fmt: on

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance."""

        # make sure all arguments are present in older models
        base_architecture(args)

        if args.encoder_layers_to_keep:
            args.encoder_layers = len(args.encoder_layers_to_keep.split(","))
        if args.decoder_layers_to_keep:
            args.decoder_layers = len(args.decoder_layers_to_keep.split(","))

        if getattr(args, "max_source_positions", None) is None:
            args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS
        if getattr(args, "max_target_positions", None) is None:
            args.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS
        if getattr(args, "source_position_markers", None) is None:
            args.source_position_markers = args.max_source_positions

        src_dict, tgt_dict = task.source_dictionary, task.target_dictionary
        if src_dict != tgt_dict:
            raise ValueError("Pointer-generator requires a joined dictionary")

        def build_embedding(dictionary, embed_dim, path=None):
            # The dictionary may include additional items that can be used in
            # place of the normal OOV token and that all map to the same
            # embedding. Using a different token for each input position allows
            # one to restore the word identities from the original source text.
            num_embeddings = len(dictionary) - args.source_position_markers
            padding_idx = dictionary.pad()
            unk_idx = dictionary.unk()
            logger.info(
                "dictionary indices from {0} to {1} will be mapped to {2}".format(
                    num_embeddings, len(dictionary) - 1, unk_idx
                )
            )
            emb = Embedding(num_embeddings, embed_dim, padding_idx, unk_idx)
            # if provided, load from preloaded dictionaries
            if path:
                embed_dict = utils.parse_embedding(path)
                utils.load_embedding(embed_dict, dictionary, emb)
            return emb

        if args.share_all_embeddings:
            if args.encoder_embed_dim != args.decoder_embed_dim:
                raise ValueError(
                    "--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim"
                )
            if args.decoder_embed_path and (
                args.decoder_embed_path != args.encoder_embed_path
            ):
                raise ValueError(
                    "--share-all-embeddings not compatible with --decoder-embed-path"
                )
            encoder_embed_tokens = build_embedding(
                src_dict, args.encoder_embed_dim, args.encoder_embed_path
            )
            decoder_embed_tokens = encoder_embed_tokens
            args.share_decoder_input_output_embed = True
        else:
            encoder_embed_tokens = build_embedding(
                src_dict, args.encoder_embed_dim, args.encoder_embed_path
            )
            decoder_embed_tokens = build_embedding(
                tgt_dict, args.decoder_embed_dim, args.decoder_embed_path
            )

        encoder = cls.build_encoder(args, src_dict, encoder_embed_tokens)
        decoder = cls.build_decoder(args, tgt_dict, decoder_embed_tokens)
        return cls(args, encoder, decoder)

    @classmethod
    def build_encoder(cls, args, src_dict, embed_tokens):
        return TransformerPointerGeneratorEncoder(args, src_dict, embed_tokens)

    @classmethod
    def build_decoder(cls, args, tgt_dict, embed_tokens):
        return TransformerPointerGeneratorDecoder(args, tgt_dict, embed_tokens)


class TransformerPointerGeneratorEncoder(TransformerEncoder):
    """
    Transformer encoder consisting of *args.encoder_layers* layers. Each layer
    is a :class:`TransformerEncoderLayer`. The pointer-generator variant adds
    the source tokens to the encoder output as these are otherwise not passed
    to the decoder.
    """

    def forward(
        self,
        src_tokens,
        src_lengths: Optional[Tensor] = None,
        return_all_hiddens: bool = False,
        token_embeddings: Optional[Tensor] = None
    ):
        """
        Runs the `forward()` method of the parent Transformer class. Then adds
        the source tokens into the encoder output tuple.

        While it might be more elegant that the model would pass the source
        tokens to the `forward()` method of the decoder too, this would require
        changes to `SequenceGenerator`.

        Args:
            src_tokens (torch.LongTensor): tokens in the source language of
                shape `(batch, src_len)`
            src_lengths (torch.LongTensor): lengths of each source sentence of
                shape `(batch)`
            return_all_hiddens (bool, optional): also return all of the
                intermediate hidden states (default: False).
            token_embeddings (torch.Tensor, optional): precomputed embeddings
                default `None` will recompute embeddings

        Returns:
            namedtuple:
                - **encoder_out** (Tensor): the last encoder layer's output of
                  shape `(src_len, batch, embed_dim)`
                - **encoder_padding_mask** (ByteTensor): the positions of
                  padding elements of shape `(batch, src_len)`
                - **encoder_embedding** (Tensor): the (scaled) embedding lookup
                  of shape `(batch, src_len, embed_dim)`
                - **encoder_states** (List[Tensor]): all intermediate
                  hidden states of shape `(src_len, batch, embed_dim)`.
                  Only populated if *return_all_hiddens* is True.
                - **src_tokens** (Tensor): input token ids of shape
                  `(batch, src_len)`
        """
        encoder_out = self.forward_scriptable(src_tokens,
                                              src_lengths,
                                              return_all_hiddens,
                                              token_embeddings)

        # The Pytorch Mobile lite interpreter does not supports returning NamedTuple in
        # `forward` so we use a dictionary instead.
        # TorchScript does not support mixed values so the values are all lists.
        # The empty list is equivalent to None.
        return {
            "encoder_out": encoder_out["encoder_out"],  # T x B x C
            "encoder_padding_mask": encoder_out["encoder_padding_mask"],  # B x T
            "encoder_embedding": encoder_out["encoder_embedding"],  # B x T x C
            "encoder_states": encoder_out["encoder_states"],  # List[T x B x C]
            "src_tokens": [src_tokens],  # B x T
            "src_lengths": [],
        }


class TransformerPointerGeneratorDecoder(TransformerDecoder):
    """
    Transformer decoder consisting of *args.decoder_layers* layers. Each layer
    is a :class:`TransformerDecoderLayer`. The pointer-generator variant mixes
    the output probabilities with an attention distribution in the output layer.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        dictionary (~fairseq.data.Dictionary): decoding dictionary
        embed_tokens (torch.nn.Embedding): output embedding
    """

    def __init__(self, args, dictionary, embed_tokens):
        super().__init__(args, dictionary, embed_tokens, no_encoder_attn=False)

        # In the pointer-generator model these arguments define the decoder
        # layer and the number of attention heads that will be averaged to
        # create the alignment for pointing.
        self.alignment_heads = args.alignment_heads
        self.alignment_layer = args.alignment_layer

        input_embed_dim = embed_tokens.embedding_dim

        # Generation probabilities / interpolation coefficients are predicted
        # from the current decoder input embedding and the decoder output, which
        # is the size of output_embed_dim.
        p_gen_input_size = input_embed_dim + self.output_embed_dim
        self.project_p_gens = nn.Linear(p_gen_input_size, 1)
        nn.init.zeros_(self.project_p_gens.bias)

        # The dictionary may include a separate entry for an OOV token in each
        # input position, so that their identity can be restored from the
        # original source text.
        self.num_types = len(dictionary)
        self.num_oov_types = args.source_position_markers
        self.num_embeddings = self.num_types - self.num_oov_types
        self.force_p_gen = args.force_generation

    def forward(
        self,
        prev_output_tokens,
        encoder_out: Optional[Dict[str, List[Tensor]]] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        features_only: bool = False,
        alignment_layer: Optional[int] = 0,
        alignment_heads: Optional[int] = 1,
        src_lengths: Optional[Any] = None,
        return_all_hiddens: bool = False,
    ):
        """
        Args:
            prev_output_tokens (LongTensor): previous decoder outputs of shape
                `(batch, tgt_len)`, for teacher forcing
            encoder_out (optional): output from the encoder, used for
                encoder-side attention
            incremental_state (dict, optional): dictionary used for storing
                state during :ref:`Incremental decoding`
            features_only (bool, optional): only return features without
                applying output layer (default: False)
            alignment_layer (int, optional): 0-based index of the layer to be
                used for pointing (default: 0)
            alignment_heads (int, optional): number of attention heads to be
                used for pointing (default: 1)

        Returns:
            tuple:
                - the decoder's output of shape `(batch, tgt_len, vocab)`
                - a dictionary with any model-specific outputs
        """
        # The normal Transformer model doesn't pass the alignment_layer and
        # alignment_heads parameters correctly. We use our local variables.
        x, extra = self.extract_features(
            prev_output_tokens,
            encoder_out=encoder_out,
            incremental_state=incremental_state,
            alignment_layer=self.alignment_layer,
            alignment_heads=self.alignment_heads,
        )
        if not features_only:
            # Embedding the tokens again for generation probability prediction,
            # so that we don't have to reimplement the whole extract_features()
            # method.
            if incremental_state is not None:
                prev_output_tokens = prev_output_tokens[:, -1:]
            prev_output_embed = self.embed_tokens(prev_output_tokens)
            prev_output_embed *= self.embed_scale
            predictors = torch.cat((prev_output_embed, x), 2)
            p_gens = self.project_p_gens(predictors)
            p_gens = torch.sigmoid(p_gens.float())
            # Torchscript complains if encoder_out or attn are None because
            # `output_layer()` signature expects tensors instead
            attn: Optional[Tensor] = extra["attn"][0]
            assert encoder_out is not None
            assert attn is not None
            x = self.output_layer(x, attn, encoder_out["src_tokens"][0], p_gens)
        return x, extra

    def output_layer(
        self,
        features: Tensor,
        attn: Tensor,
        src_tokens: Tensor,
        p_gens: Tensor
    ) -> Tensor:
        """
        Project features to the vocabulary size and mix with the attention
        distributions.
        """
        if self.force_p_gen is not None:
            p_gens = self.force_p_gen

        # project back to size of vocabulary
        if self.adaptive_softmax is None:
            logits = self.output_projection(features)
        else:
            logits = features

        batch_size = logits.shape[0]
        output_length = logits.shape[1]
        assert logits.shape[2] == self.num_embeddings
        assert src_tokens.shape[0] == batch_size
        src_length = src_tokens.shape[1]

        # The final output distribution will be a mixture of the normal output
        # distribution (softmax of logits) and attention weights.
        gen_dists = self.get_normalized_probs_scriptable(
            (logits, None), log_probs=False, sample=None
        )
        gen_dists = torch.mul(gen_dists, p_gens)
        padding_size = (batch_size, output_length, self.num_oov_types)
        padding = gen_dists.new_zeros(padding_size)
        gen_dists = torch.cat((gen_dists, padding), 2)
        assert gen_dists.shape[2] == self.num_types

        # Scatter attention distributions to distributions over the extended
        # vocabulary in a tensor of shape [batch_size, output_length,
        # vocab_size]. Each attention weight will be written into a location
        # that is for other dimensions the same as in the index tensor, but for
        # the third dimension it's the value of the index tensor (the token ID).
        attn = torch.mul(attn.float(), 1 - p_gens)
        index = src_tokens[:, None, :]
        index = index.expand(batch_size, output_length, src_length)
        attn_dists_size = (batch_size, output_length, self.num_types)
        attn_dists = attn.new_zeros(attn_dists_size)
        attn_dists.scatter_add_(2, index, attn.float())

        # Final distributions, [batch_size, output_length, num_types].
        return gen_dists + attn_dists

    def get_normalized_probs(
        self,
        net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
        log_probs: bool,
        sample: Optional[Dict[str, Tensor]] = None,
    ):
        """
        Get normalized probabilities (or log probs) from a net's output.
        Pointer-generator network output is already normalized.
        """
        probs = net_output[0]
        # Make sure the probabilities are greater than zero when returning log
        # probabilities.
        return probs.clamp(1e-10, 1.0).log() if log_probs else probs


class Embedding(nn.Embedding):
    r"""A simple lookup table that stores embeddings of a fixed dictionary and size.
    This module is often used to store word embeddings and retrieve them using indices.
    The input to the module is a list of indices, and the output is the corresponding
    word embeddings. This subclass differs from the standard PyTorch Embedding class by
    allowing additional vocabulary entries that will be mapped to the unknown token
    embedding.
    Args:
        num_embeddings (int): size of the dictionary of embeddings
        embedding_dim (int): the size of each embedding vector
        padding_idx (int): Pads the output with the embedding vector at :attr:`padding_idx`
            (initialized to zeros) whenever it encounters the index.
        unk_idx (int): Maps all token indices that are greater than or equal to
            num_embeddings to this index.
    Attributes:
        weight (Tensor): the learnable weights of the module of shape (num_embeddings, embedding_dim)
            initialized from :math:`\mathcal{N}(0, 1)`
    Shape:
        - Input: :math:`(*)`, LongTensor of arbitrary shape containing the indices to extract
        - Output: :math:`(*, H)`, where `*` is the input shape and :math:`H=\text{embedding\_dim}`
    .. note::
        Keep in mind that only a limited number of optimizers support
        sparse gradients: currently it's :class:`optim.SGD` (`CUDA` and `CPU`),
        :class:`optim.SparseAdam` (`CUDA` and `CPU`) and :class:`optim.Adagrad` (`CPU`)
    .. note::
        With :attr:`padding_idx` set, the embedding vector at
        :attr:`padding_idx` is initialized to all zeros. However, note that this
        vector can be modified afterwards, e.g., using a customized
        initialization method, and thus changing the vector used to pad the
        output. The gradient for this vector from :class:`~torch.nn.Embedding`
        is always zero.
    """
    __constants__ = ["unk_idx"]

    # Torchscript: Inheriting from Embedding class produces an error when exporting to Torchscript
    # -> RuntimeError: Unable to cast Python instance to C++ type (compile in debug mode for details
    # It's happening because max_norm attribute from nn.Embedding is None by default and it cannot be
    # cast to a C++ type
    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int],
        unk_idx: int,
        max_norm: Optional[float] = float("inf"),
    ):
        super().__init__(num_embeddings, embedding_dim, padding_idx=padding_idx, max_norm=max_norm)
        self.unk_idx = unk_idx
        nn.init.normal_(self.weight, mean=0, std=embedding_dim ** -0.5)
        nn.init.constant_(self.weight[padding_idx], 0)

    def forward(self, input):
        input = torch.where(
            input >= self.num_embeddings, torch.ones_like(input) * self.unk_idx, input
        )
        return nn.functional.embedding(
            input, self.weight, self.padding_idx, self.max_norm,
            self.norm_type, self.scale_grad_by_freq, self.sparse
        )


@register_model_architecture(
    "transformer_pointer_generator", "transformer_pointer_generator"
)
def transformer_pointer_generator(args):
    args.alignment_heads = getattr(args, "alignment_heads", 1)
    args.alignment_layer = getattr(args, "alignment_layer", -1)
    base_architecture(args)
    if args.alignment_layer < 0:
        args.alignment_layer = args.decoder_layers + args.alignment_layer


@register_model_architecture(
    "transformer_pointer_generator", "transformer_pointer_generator_iwslt_de_en"
)
def transformer_pointer_generator_iwslt_de_en(args):
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 1024)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4)
    args.encoder_layers = getattr(args, "encoder_layers", 6)
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512)
    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 1024)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4)
    args.decoder_layers = getattr(args, "decoder_layers", 6)
    transformer_pointer_generator(args)


@register_model_architecture(
    "transformer_pointer_generator", "transformer_pointer_generator_wmt_en_de"
)
def transformer_pointer_generator_wmt_en_de(args):
    transformer_pointer_generator(args)


# Transformer pointer-generator with the base Transformer parameters as used in
# the "Attention Is All You Need" paper (Vaswani et al., 2017)
@register_model_architecture(
    "transformer_pointer_generator",
    "transformer_pointer_generator_vaswani_wmt_en_de_big",
)
def transformer_pointer_generator_vaswani_wmt_en_de_big(args):
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
    args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1024)
    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 4096)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
    args.dropout = getattr(args, "dropout", 0.3)
    transformer_pointer_generator(args)


@register_model_architecture(
    "transformer_pointer_generator",
    "transformer_pointer_generator_vaswani_wmt_en_fr_big",
)
def transformer_pointer_generator_vaswani_wmt_en_fr_big(args):
    args.dropout = getattr(args, "dropout", 0.1)
    transformer_pointer_generator_vaswani_wmt_en_de_big(args)


@register_model_architecture(
    "transformer_pointer_generator", "transformer_pointer_generator_wmt_en_de_big"
)
def transformer_pointer_generator_wmt_en_de_big(args):
    args.attention_dropout = getattr(args, "attention_dropout", 0.1)
    transformer_pointer_generator_vaswani_wmt_en_de_big(args)


# default parameters used in tensor2tensor implementation
@register_model_architecture(
    "transformer_pointer_generator", "transformer_pointer_generator_wmt_en_de_big_t2t"
)
def transformer_pointer_generator_wmt_en_de_big_t2t(args):
    args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True)
    args.decoder_normalize_before = getattr(args, "decoder_normalize_before", True)
    args.attention_dropout = getattr(args, "attention_dropout", 0.1)
    args.activation_dropout = getattr(args, "activation_dropout", 0.1)
    transformer_pointer_generator_vaswani_wmt_en_de_big(args)
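For readers skimming the diff, here is a small self-contained editorial sketch (toy numbers, not part of the commit) of the mixture that `output_layer` above computes for one decoding step: the generator share `p_gen * softmax(logits)` is padded with zeros for the `<unk-N>` slots, and the pointer share `(1 - p_gen) * attn` is scattered onto the source token ids.

```python
import torch

batch_size, output_length, src_length = 1, 1, 3
num_embeddings = 5   # "real" vocabulary entries
num_oov_types = 2    # source position markers <unk-0>, <unk-1>
num_types = num_embeddings + num_oov_types

p_gen = torch.tensor([[[0.7]]])                     # predicted generation probability
logits = torch.randn(batch_size, output_length, num_embeddings)
gen_dists = torch.softmax(logits, dim=-1) * p_gen   # generator share
gen_dists = torch.cat(
    [gen_dists, gen_dists.new_zeros(batch_size, output_length, num_oov_types)], dim=2
)

# Attention over the three source tokens; id 5 is an OOV position marker.
attn = torch.tensor([[[0.2, 0.5, 0.3]]]) * (1 - p_gen)
src_tokens = torch.tensor([[4, 5, 2]])
index = src_tokens[:, None, :].expand(batch_size, output_length, src_length)
attn_dists = attn.new_zeros(batch_size, output_length, num_types)
attn_dists.scatter_add_(2, index, attn)

final = gen_dists + attn_dists  # already normalized over the extended vocabulary
print(final.sum())              # ~1.0
```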
fairseq/examples/pointer_generator/postprocess.py
ADDED
@@ -0,0 +1,96 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import re
import sys


class OOVIndexError(IndexError):
    def __init__(self, pos, source_seq, target_seq):
        super(OOVIndexError, self).__init__(
            "A <unk-N> tag in the target sequence refers to a position that is "
            "outside the source sequence. Most likely there was a mismatch in "
            "provided source and target sequences. Otherwise this would mean that "
            "the pointing mechanism somehow attended to a position that is past "
            "the actual sequence end."
        )
        self.source_pos = pos
        self.source_seq = source_seq
        self.target_seq = target_seq


def replace_oovs(source_in, target_in, target_out):
    """Replaces <unk-N> tokens in the target text with the corresponding word in
    the source text.
    """

    oov_re = re.compile("^<unk-([0-9]+)>$")

    for source_seq, target_seq in zip(source_in, target_in):
        target_seq_out = []

        pos_to_word = source_seq.strip().split()
        for token in target_seq.strip().split():
            m = oov_re.match(token)
            if m:
                pos = int(m.group(1))
                if pos >= len(pos_to_word):
                    raise OOVIndexError(pos, source_seq, target_seq)
                token_out = pos_to_word[pos]
            else:
                token_out = token
            target_seq_out.append(token_out)
        target_out.write(" ".join(target_seq_out) + "\n")


def main():
    parser = argparse.ArgumentParser(
        description="Replaces <unk-N> tokens in target sequences with words from "
        "the corresponding position in the source sequence."
    )
    parser.add_argument(
        "--source", type=str, help="text file with source sequences", required=True
    )
    parser.add_argument(
        "--target", type=str, help="text file with target sequences", required=True
    )
    parser.add_argument(
        "--target-out",
        type=str,
        help="where to write target sequences without <unk-N> " "entries",
        required=True,
    )
    args = parser.parse_args()

    target_in = (
        open(args.target, "r", encoding="utf-8") if args.target is not None else None
    )
    target_out = (
        open(args.target_out, "w", encoding="utf-8")
        if args.target_out is not None
        else None
    )
    with open(args.source, "r", encoding="utf-8") as source_in, open(
        args.target, "r", encoding="utf-8"
    ) as target_in, open(args.target_out, "w", encoding="utf-8") as target_out:
        replace_oovs(source_in, target_in, target_out)


if __name__ == "__main__":
    try:
        main()
    except OOVIndexError as e:
        print(e, file=sys.stderr)
        print("Source sequence:", e.source_seq.strip(), file=sys.stderr)
        print("Target sequence:", e.target_seq.strip(), file=sys.stderr)
        print(
            "Source sequence length:",
            len(e.source_seq.strip().split()),
            file=sys.stderr,
        )
        print("The offending tag points to:", e.source_pos)
        sys.exit(2)
fairseq/examples/quant_noise/transformer_quantization_config.yaml
ADDED
@@ -0,0 +1,33 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# This file defines example configuration arguments for quantizing
# a transformer model with product quantization

# Number of Centroids for Product Quantization, by default 256 (byte-aligned)
n_centroids:
    Linear:
        key: in_features
        value: {"*": 256}
    Embedding:
        key: embedding_dim
        value: {"*": 256}

# Block Sizes for Product Quantization
# We suggest: 8 for FFN, 4 for ATTN, 4 for embedding projections, 8 for embeddings
block_sizes:
    Linear:
        key: fuzzy_name
        value: {fc: 8, attn: 4, emb: 4}
    Embedding:
        key: fuzzy_name
        value: {emb: 8}

# Layers to Quantize Sequentially
# We suggest: first FFN, then EMB, then ATTN
layers_to_quantize:
    - decoder\\.layers\\.\d+\\.fc[12]
    - decoder\\.embed_tokens\\.embeddings\\.[012]\\.[01]
    - decoder\\.layers\\.\d+\\.self_attn\\.(k_proj|v_proj|q_proj|out_proj)
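As an editorial aside (this sketch is not fairseq's config loader), each entry under `layers_to_quantize` is a regular expression over module names that selects the layers to quantize in one stage; with the YAML-level double backslashes collapsed, the first entry behaves like this:

```python
import re

# Hypothetical module names, for illustration only.
pattern = re.compile(r"decoder\.layers\.\d+\.fc[12]")
module_names = [
    "decoder.layers.0.fc1",
    "decoder.layers.0.self_attn.k_proj",
    "decoder.layers.11.fc2",
]
print([name for name in module_names if pattern.search(name)])
# ['decoder.layers.0.fc1', 'decoder.layers.11.fc2']
```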
fairseq/examples/roberta/README.custom_classification.md
ADDED
@@ -0,0 +1,168 @@
# Finetuning RoBERTa on a custom classification task

This example shows how to finetune RoBERTa on the IMDB dataset, but it should illustrate the process for most classification tasks.

### 1) Get the data

```bash
wget http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz
tar zxvf aclImdb_v1.tar.gz
```


### 2) Format data

The `IMDB` data has one sample per file; the Python snippet below merges them into a single file each for train and valid, for ease of processing.
```python
import argparse
import os
import random
from glob import glob

random.seed(0)

def main(args):
    for split in ['train', 'test']:
        samples = []
        for class_label in ['pos', 'neg']:
            fnames = glob(os.path.join(args.datadir, split, class_label) + '/*.txt')
            for fname in fnames:
                with open(fname) as fin:
                    line = fin.readline()
                    samples.append((line, 1 if class_label == 'pos' else 0))
        random.shuffle(samples)
        out_fname = 'train' if split == 'train' else 'dev'
        f1 = open(os.path.join(args.datadir, out_fname + '.input0'), 'w')
        f2 = open(os.path.join(args.datadir, out_fname + '.label'), 'w')
        for sample in samples:
            f1.write(sample[0] + '\n')
            f2.write(str(sample[1]) + '\n')
        f1.close()
        f2.close()

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--datadir', default='aclImdb')
    args = parser.parse_args()
    main(args)
```


### 3) BPE encode

Run `multiprocessing_bpe_encoder`; you could also BPE-encode each sample in the previous step, but that might be slower.
```bash
# Download encoder.json and vocab.bpe
wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json'
wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe'

for SPLIT in train dev; do
    python -m examples.roberta.multiprocessing_bpe_encoder \
        --encoder-json encoder.json \
        --vocab-bpe vocab.bpe \
        --inputs "aclImdb/$SPLIT.input0" \
        --outputs "aclImdb/$SPLIT.input0.bpe" \
        --workers 60 \
        --keep-empty
done
```


### 4) Preprocess data

```bash
# Download fairseq dictionary.
wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt'

fairseq-preprocess \
    --only-source \
    --trainpref "aclImdb/train.input0.bpe" \
    --validpref "aclImdb/dev.input0.bpe" \
    --destdir "IMDB-bin/input0" \
    --workers 60 \
    --srcdict dict.txt

fairseq-preprocess \
    --only-source \
    --trainpref "aclImdb/train.label" \
    --validpref "aclImdb/dev.label" \
    --destdir "IMDB-bin/label" \
    --workers 60
```


### 5) Run training

```bash
TOTAL_NUM_UPDATES=7812  # 10 epochs through IMDB for bsz 32
WARMUP_UPDATES=469      # 6 percent of the number of updates
LR=1e-05                # Peak LR for polynomial LR scheduler.
HEAD_NAME=imdb_head     # Custom name for the classification head.
NUM_CLASSES=2           # Number of classes for the classification task.
MAX_SENTENCES=8         # Batch size.
ROBERTA_PATH=/path/to/roberta.large/model.pt

CUDA_VISIBLE_DEVICES=0 fairseq-train IMDB-bin/ \
    --restore-file $ROBERTA_PATH \
    --max-positions 512 \
    --batch-size $MAX_SENTENCES \
    --max-tokens 4400 \
    --task sentence_prediction \
    --reset-optimizer --reset-dataloader --reset-meters \
    --required-batch-size-multiple 1 \
    --init-token 0 --separator-token 2 \
    --arch roberta_large \
    --criterion sentence_prediction \
    --classification-head-name $HEAD_NAME \
    --num-classes $NUM_CLASSES \
    --dropout 0.1 --attention-dropout 0.1 \
    --weight-decay 0.1 --optimizer adam --adam-betas "(0.9, 0.98)" --adam-eps 1e-06 \
    --clip-norm 0.0 \
    --lr-scheduler polynomial_decay --lr $LR --total-num-update $TOTAL_NUM_UPDATES --warmup-updates $WARMUP_UPDATES \
    --fp16 --fp16-init-scale 4 --threshold-loss-scale 1 --fp16-scale-window 128 \
    --max-epoch 10 \
    --best-checkpoint-metric accuracy --maximize-best-checkpoint-metric \
    --shorten-method "truncate" \
    --find-unused-parameters \
    --update-freq 4
```

The above command will finetune RoBERTa-large with an effective batch size of 32
sentences (`--batch-size=8 --update-freq=4`). The expected
`best-validation-accuracy` after 10 epochs is ~96.5%.

If you run out of GPU memory, try decreasing `--batch-size` and increasing
`--update-freq` to compensate.
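As a rough illustration (these particular values are not from the original instructions), the following keeps the effective batch size at 32 sentences while using a quarter of the per-step GPU memory:

```bash
MAX_SENTENCES=2   # per-step batch size (was 8)
UPDATE_FREQ=16    # gradient accumulation steps; 2 x 16 = 32 sentences per update
```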

### 6) Load model using hub interface

Now we can load the trained model checkpoint using the RoBERTa hub interface.

Assuming your checkpoints are stored in `checkpoints/`:
```python
from fairseq.models.roberta import RobertaModel
roberta = RobertaModel.from_pretrained(
    'checkpoints',
    checkpoint_file='checkpoint_best.pt',
    data_name_or_path='IMDB-bin'
)
roberta.eval()  # disable dropout
```

Finally you can make predictions using the `imdb_head` (or whatever you set
`--classification-head-name` to during training):
```python
label_fn = lambda label: roberta.task.label_dictionary.string(
    [label + roberta.task.label_dictionary.nspecial]
)

tokens = roberta.encode('Best movie this year')
pred = label_fn(roberta.predict('imdb_head', tokens).argmax().item())
assert pred == '1'  # positive

tokens = roberta.encode('Worst movie ever')
pred = label_fn(roberta.predict('imdb_head', tokens).argmax().item())
assert pred == '0'  # negative
```
fairseq/examples/roberta/README.glue.md
ADDED
@@ -0,0 +1,64 @@
# Finetuning RoBERTa on GLUE tasks

### 1) Download the data from the GLUE website (https://gluebenchmark.com/tasks) using the following commands:
```bash
wget https://gist.githubusercontent.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e/raw/17b8dd0d724281ed7c3b2aeeda662b92809aadd5/download_glue_data.py
python download_glue_data.py --data_dir glue_data --tasks all
```

### 2) Preprocess GLUE task data:
```bash
./examples/roberta/preprocess_GLUE_tasks.sh glue_data <glue_task_name>
```
`glue_task_name` is one of the following:
`{ALL, QQP, MNLI, QNLI, MRPC, RTE, STS-B, SST-2, CoLA}`
Use `ALL` to preprocess all the GLUE tasks.

### 3) Fine-tuning on a GLUE task:
Example fine-tuning command for the `RTE` task:
```bash
ROBERTA_PATH=/path/to/roberta/model.pt

CUDA_VISIBLE_DEVICES=0 fairseq-hydra-train --config-dir examples/roberta/config/finetuning --config-name rte \
task.data=RTE-bin checkpoint.restore_file=$ROBERTA_PATH
```

There are additional config files for each of the GLUE tasks in the `examples/roberta/config/finetuning` directory.
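For example, to fine-tune on `MNLI` instead, point the same command at the MNLI config and data (this extra example is not in the original instructions and assumes the preprocessing step produced `MNLI-bin`):

```bash
CUDA_VISIBLE_DEVICES=0 fairseq-hydra-train --config-dir examples/roberta/config/finetuning --config-name mnli \
task.data=MNLI-bin checkpoint.restore_file=$ROBERTA_PATH
```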

**Note:**

a) The above command-line arguments and hyperparameters were tested on a single Nvidia `V100` GPU with `32gb` of memory for each task. Depending on the GPU memory available to you, you can increase `--update-freq` and reduce `--batch-size`.

b) The settings above are suggestions based on our hyperparameter search within a fixed search space (for careful comparison across models). You might be able to find better metrics with a wider hyperparameter search.

### Inference on a GLUE task
After training the model as described in the previous step, you can perform inference with the checkpoints in the `checkpoints/` directory using the following Python snippet:

```python
from fairseq.models.roberta import RobertaModel

roberta = RobertaModel.from_pretrained(
    'checkpoints/',
    checkpoint_file='checkpoint_best.pt',
    data_name_or_path='RTE-bin'
)

label_fn = lambda label: roberta.task.label_dictionary.string(
    [label + roberta.task.label_dictionary.nspecial]
)
ncorrect, nsamples = 0, 0
roberta.cuda()
roberta.eval()
with open('glue_data/RTE/dev.tsv') as fin:
    fin.readline()
    for index, line in enumerate(fin):
        tokens = line.strip().split('\t')
        sent1, sent2, target = tokens[1], tokens[2], tokens[3]
        tokens = roberta.encode(sent1, sent2)
        prediction = roberta.predict('sentence_classification_head', tokens).argmax().item()
        prediction_label = label_fn(prediction)
        ncorrect += int(prediction_label == target)
        nsamples += 1
print('| Accuracy: ', float(ncorrect)/float(nsamples))
```
fairseq/examples/roberta/README.md
ADDED
@@ -0,0 +1,296 @@
# RoBERTa: A Robustly Optimized BERT Pretraining Approach

https://arxiv.org/abs/1907.11692

## Introduction

RoBERTa iterates on BERT's pretraining procedure, including training the model longer, with bigger batches over more data; removing the next sentence prediction objective; training on longer sequences; and dynamically changing the masking pattern applied to the training data. See the associated paper for more details.

### What's New:

- December 2020: German model (GottBERT) is available: [GottBERT](https://github.com/pytorch/fairseq/tree/main/examples/gottbert).
- January 2020: Italian model (UmBERTo) is available from Musixmatch Research: [UmBERTo](https://github.com/musixmatchresearch/umberto).
- November 2019: French model (CamemBERT) is available: [CamemBERT](https://github.com/pytorch/fairseq/tree/main/examples/camembert).
- November 2019: Multilingual encoder (XLM-RoBERTa) is available: [XLM-R](https://github.com/pytorch/fairseq/tree/main/examples/xlmr).
- September 2019: TensorFlow and TPU support via the [transformers library](https://github.com/huggingface/transformers).
- August 2019: RoBERTa is now supported in the [pytorch-transformers library](https://github.com/huggingface/pytorch-transformers).
- August 2019: Added [tutorial for finetuning on WinoGrande](https://github.com/pytorch/fairseq/tree/main/examples/roberta/wsc#roberta-training-on-winogrande-dataset).
- August 2019: Added [tutorial for pretraining RoBERTa using your own data](README.pretraining.md).

## Pre-trained models

Model | Description | # params | Download
---|---|---|---
`roberta.base` | RoBERTa using the BERT-base architecture | 125M | [roberta.base.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/roberta.base.tar.gz)
`roberta.large` | RoBERTa using the BERT-large architecture | 355M | [roberta.large.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/roberta.large.tar.gz)
`roberta.large.mnli` | `roberta.large` finetuned on [MNLI](http://www.nyu.edu/projects/bowman/multinli) | 355M | [roberta.large.mnli.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/roberta.large.mnli.tar.gz)
`roberta.large.wsc` | `roberta.large` finetuned on [WSC](wsc/README.md) | 355M | [roberta.large.wsc.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/roberta.large.wsc.tar.gz)

## Results

**[GLUE (Wang et al., 2019)](https://gluebenchmark.com/)**
_(dev set, single model, single-task finetuning)_

Model | MNLI | QNLI | QQP | RTE | SST-2 | MRPC | CoLA | STS-B
---|---|---|---|---|---|---|---|---
`roberta.base` | 87.6 | 92.8 | 91.9 | 78.7 | 94.8 | 90.2 | 63.6 | 91.2
`roberta.large` | 90.2 | 94.7 | 92.2 | 86.6 | 96.4 | 90.9 | 68.0 | 92.4
`roberta.large.mnli` | 90.2 | - | - | - | - | - | - | -

**[SuperGLUE (Wang et al., 2019)](https://super.gluebenchmark.com/)**
_(dev set, single model, single-task finetuning)_

Model | BoolQ | CB | COPA | MultiRC | RTE | WiC | WSC
---|---|---|---|---|---|---|---
`roberta.large` | 86.9 | 98.2 | 94.0 | 85.7 | 89.5 | 75.6 | -
`roberta.large.wsc` | - | - | - | - | - | - | 91.3

**[SQuAD (Rajpurkar et al., 2018)](https://rajpurkar.github.io/SQuAD-explorer/)**
_(dev set, no additional data used)_

Model | SQuAD 1.1 EM/F1 | SQuAD 2.0 EM/F1
---|---|---
`roberta.large` | 88.9/94.6 | 86.5/89.4

**[RACE (Lai et al., 2017)](http://www.qizhexie.com/data/RACE_leaderboard.html)**
_(test set)_

Model | Accuracy | Middle | High
---|---|---|---
`roberta.large` | 83.2 | 86.5 | 81.3

**[HellaSwag (Zellers et al., 2019)](https://rowanzellers.com/hellaswag/)**
_(test set)_

Model | Overall | In-domain | Zero-shot | ActivityNet | WikiHow
---|---|---|---|---|---
`roberta.large` | 85.2 | 87.3 | 83.1 | 74.6 | 90.9

**[Commonsense QA (Talmor et al., 2019)](https://www.tau-nlp.org/commonsenseqa)**
_(test set)_

Model | Accuracy
---|---
`roberta.large` (single model) | 72.1
`roberta.large` (ensemble) | 72.5

**[Winogrande (Sakaguchi et al., 2019)](https://arxiv.org/abs/1907.10641)**
_(test set)_

Model | Accuracy
---|---
`roberta.large` | 78.1

**[XNLI (Conneau et al., 2018)](https://arxiv.org/abs/1809.05053)**
_(TRANSLATE-TEST)_

Model | en | fr | es | de | el | bg | ru | tr | ar | vi | th | zh | hi | sw | ur
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---
`roberta.large.mnli` | 91.3 | 82.91 | 84.27 | 81.24 | 81.74 | 83.13 | 78.28 | 76.79 | 76.64 | 74.17 | 74.05 | 77.5 | 70.9 | 66.65 | 66.81

## Example usage

##### Load RoBERTa from torch.hub (PyTorch >= 1.1):
```python
import torch
roberta = torch.hub.load('pytorch/fairseq', 'roberta.large')
roberta.eval()  # disable dropout (or leave in train mode to finetune)
```

##### Load RoBERTa (for PyTorch 1.0 or custom models):
```bash
# Download the roberta.large model
wget https://dl.fbaipublicfiles.com/fairseq/models/roberta.large.tar.gz
tar -xzvf roberta.large.tar.gz
```
```python
# Load the model in fairseq
from fairseq.models.roberta import RobertaModel
roberta = RobertaModel.from_pretrained('/path/to/roberta.large', checkpoint_file='model.pt')
roberta.eval()  # disable dropout (or leave in train mode to finetune)
```

##### Apply Byte-Pair Encoding (BPE) to input text:
```python
tokens = roberta.encode('Hello world!')
assert tokens.tolist() == [0, 31414, 232, 328, 2]
roberta.decode(tokens)  # 'Hello world!'
```

##### Extract features from RoBERTa:
```python
# Extract the last layer's features
last_layer_features = roberta.extract_features(tokens)
assert last_layer_features.size() == torch.Size([1, 5, 1024])

# Extract all layer's features (layer 0 is the embedding layer)
all_layers = roberta.extract_features(tokens, return_all_hiddens=True)
assert len(all_layers) == 25
assert torch.all(all_layers[-1] == last_layer_features)
```

##### Use RoBERTa for sentence-pair classification tasks:
```python
# Download RoBERTa already finetuned for MNLI
roberta = torch.hub.load('pytorch/fairseq', 'roberta.large.mnli')
roberta.eval()  # disable dropout for evaluation

# Encode a pair of sentences and make a prediction
tokens = roberta.encode('Roberta is a heavily optimized version of BERT.', 'Roberta is not very optimized.')
roberta.predict('mnli', tokens).argmax()  # 0: contradiction

# Encode another pair of sentences
tokens = roberta.encode('Roberta is a heavily optimized version of BERT.', 'Roberta is based on BERT.')
roberta.predict('mnli', tokens).argmax()  # 2: entailment
```

##### Register a new (randomly initialized) classification head:
```python
roberta.register_classification_head('new_task', num_classes=3)
logprobs = roberta.predict('new_task', tokens)  # tensor([[-1.1050, -1.0672, -1.1245]], grad_fn=<LogSoftmaxBackward>)
```

##### Batched prediction:
```python
import torch
from fairseq.data.data_utils import collate_tokens

roberta = torch.hub.load('pytorch/fairseq', 'roberta.large.mnli')
roberta.eval()

batch_of_pairs = [
    ['Roberta is a heavily optimized version of BERT.', 'Roberta is not very optimized.'],
    ['Roberta is a heavily optimized version of BERT.', 'Roberta is based on BERT.'],
    ['potatoes are awesome.', 'I like to run.'],
    ['Mars is very far from earth.', 'Mars is very close.'],
]

batch = collate_tokens(
    [roberta.encode(pair[0], pair[1]) for pair in batch_of_pairs], pad_idx=1
)

logprobs = roberta.predict('mnli', batch)
print(logprobs.argmax(dim=1))
# tensor([0, 2, 1, 0])
```

##### Using the GPU:
```python
roberta.cuda()
roberta.predict('new_task', tokens)  # tensor([[-1.1050, -1.0672, -1.1245]], device='cuda:0', grad_fn=<LogSoftmaxBackward>)
```

## Advanced usage

#### Filling masks:

RoBERTa can be used to fill `<mask>` tokens in the input. Some examples from the
[Natural Questions dataset](https://ai.google.com/research/NaturalQuestions/):
```python
roberta.fill_mask('The first Star wars movie came out in <mask>', topk=3)
# [('The first Star wars movie came out in 1977', 0.9504708051681519, ' 1977'), ('The first Star wars movie came out in 1978', 0.009986862540245056, ' 1978'), ('The first Star wars movie came out in 1979', 0.009574787691235542, ' 1979')]

roberta.fill_mask('Vikram samvat calender is official in <mask>', topk=3)
# [('Vikram samvat calender is official in India', 0.21878819167613983, ' India'), ('Vikram samvat calender is official in Delhi', 0.08547237515449524, ' Delhi'), ('Vikram samvat calender is official in Gujarat', 0.07556215673685074, ' Gujarat')]

roberta.fill_mask('<mask> is the common currency of the European Union', topk=3)
# [('Euro is the common currency of the European Union', 0.9456493854522705, 'Euro'), ('euro is the common currency of the European Union', 0.025748178362846375, 'euro'), ('€ is the common currency of the European Union', 0.011183084920048714, '€')]
```

#### Pronoun disambiguation (Winograd Schema Challenge):

RoBERTa can be used to disambiguate pronouns. First install spaCy and download the English-language model:
```bash
pip install spacy
python -m spacy download en_core_web_lg
```

Next load the `roberta.large.wsc` model and call the `disambiguate_pronoun`
function. The pronoun should be surrounded by square brackets (`[]`) and the
query referent surrounded by underscores (`_`), or left blank to return the
predicted candidate text directly:
```python
roberta = torch.hub.load('pytorch/fairseq', 'roberta.large.wsc', user_dir='examples/roberta/wsc')
roberta.cuda()  # use the GPU (optional)

roberta.disambiguate_pronoun('The _trophy_ would not fit in the brown suitcase because [it] was too big.')
# True
roberta.disambiguate_pronoun('The trophy would not fit in the brown _suitcase_ because [it] was too big.')
# False

roberta.disambiguate_pronoun('The city councilmen refused the demonstrators a permit because [they] feared violence.')
# 'The city councilmen'
roberta.disambiguate_pronoun('The city councilmen refused the demonstrators a permit because [they] advocated violence.')
# 'demonstrators'
```

See the [RoBERTa Winograd Schema Challenge (WSC) README](wsc/README.md) for more details on how to train this model.

#### Extract features aligned to words:

By default RoBERTa outputs one feature vector per BPE token. You can instead
realign the features to match [spaCy's word-level tokenization](https://spacy.io/usage/linguistic-features#tokenization)
with the `extract_features_aligned_to_words` method. This will compute a
weighted average of the BPE-level features for each word and expose them in
spaCy's `Token.vector` attribute:
```python
doc = roberta.extract_features_aligned_to_words('I said, "hello RoBERTa."')
assert len(doc) == 10
for tok in doc:
    print('{:10}{} (...)'.format(str(tok), tok.vector[:5]))
# <s>       tensor([-0.1316, -0.0386, -0.0832, -0.0477, 0.1943], grad_fn=<SliceBackward>) (...)
# I         tensor([ 0.0559, 0.1541, -0.4832, 0.0880, 0.0120], grad_fn=<SliceBackward>) (...)
# said      tensor([-0.1565, -0.0069, -0.8915, 0.0501, -0.0647], grad_fn=<SliceBackward>) (...)
# ,         tensor([-0.1318, -0.0387, -0.0834, -0.0477, 0.1944], grad_fn=<SliceBackward>) (...)
# "         tensor([-0.0486, 0.1818, -0.3946, -0.0553, 0.0981], grad_fn=<SliceBackward>) (...)
# hello     tensor([ 0.0079, 0.1799, -0.6204, -0.0777, -0.0923], grad_fn=<SliceBackward>) (...)
# RoBERTa   tensor([-0.2339, -0.1184, -0.7343, -0.0492, 0.5829], grad_fn=<SliceBackward>) (...)
# .         tensor([-0.1341, -0.1203, -0.1012, -0.0621, 0.1892], grad_fn=<SliceBackward>) (...)
# "         tensor([-0.1341, -0.1203, -0.1012, -0.0621, 0.1892], grad_fn=<SliceBackward>) (...)
# </s>      tensor([-0.0930, -0.0392, -0.0821, 0.0158, 0.0649], grad_fn=<SliceBackward>) (...)
```

#### Evaluating the `roberta.large.mnli` model:

Example python code snippet to evaluate accuracy on the MNLI `dev_matched` set.
```python
label_map = {0: 'contradiction', 1: 'neutral', 2: 'entailment'}
ncorrect, nsamples = 0, 0
roberta.cuda()
roberta.eval()
with open('glue_data/MNLI/dev_matched.tsv') as fin:
    fin.readline()
    for index, line in enumerate(fin):
        tokens = line.strip().split('\t')
        sent1, sent2, target = tokens[8], tokens[9], tokens[-1]
        tokens = roberta.encode(sent1, sent2)
        prediction = roberta.predict('mnli', tokens).argmax().item()
        prediction_label = label_map[prediction]
        ncorrect += int(prediction_label == target)
        nsamples += 1
print('| Accuracy: ', float(ncorrect)/float(nsamples))
# Expected output: 0.9060
```

## Finetuning

- [Finetuning on GLUE](README.glue.md)
- [Finetuning on custom classification tasks (e.g., IMDB)](README.custom_classification.md)
- [Finetuning on Winograd Schema Challenge (WSC)](wsc/README.md)
- [Finetuning on Commonsense QA (CQA)](commonsense_qa/README.md)

## Pretraining using your own data

See the [tutorial for pretraining RoBERTa using your own data](README.pretraining.md).

## Citation

```bibtex
@article{liu2019roberta,
    title = {RoBERTa: A Robustly Optimized BERT Pretraining Approach},
    author = {Yinhan Liu and Myle Ott and Naman Goyal and Jingfei Du and
              Mandar Joshi and Danqi Chen and Omer Levy and Mike Lewis and
              Luke Zettlemoyer and Veselin Stoyanov},
    journal = {arXiv preprint arXiv:1907.11692},
    year = {2019},
}
```
fairseq/examples/roberta/README.pretraining.md
ADDED
@@ -0,0 +1,84 @@
# Pretraining RoBERTa using your own data

This tutorial will walk you through pretraining RoBERTa over your own data.

### 1) Preprocess the data

Data should be preprocessed following the [language modeling format](/examples/language_model), i.e. each document should be separated by an empty line (only useful with `--sample-break-mode complete_doc`). Lines will be concatenated as a 1D text stream during training.

We'll use the [WikiText-103 dataset](https://www.salesforce.com/products/einstein/ai-research/the-wikitext-dependency-language-modeling-dataset/)
to demonstrate how to preprocess raw text data with the GPT-2 BPE. Of course
this dataset is quite small, so the resulting pretrained model will perform
poorly, but it gives the general idea.

First download the dataset:
```bash
wget https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-raw-v1.zip
unzip wikitext-103-raw-v1.zip
```

Next encode it with the GPT-2 BPE:
```bash
mkdir -p gpt2_bpe
wget -O gpt2_bpe/encoder.json https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json
wget -O gpt2_bpe/vocab.bpe https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe
for SPLIT in train valid test; do \
    python -m examples.roberta.multiprocessing_bpe_encoder \
        --encoder-json gpt2_bpe/encoder.json \
        --vocab-bpe gpt2_bpe/vocab.bpe \
        --inputs wikitext-103-raw/wiki.${SPLIT}.raw \
        --outputs wikitext-103-raw/wiki.${SPLIT}.bpe \
        --keep-empty \
        --workers 60; \
done
```

Finally preprocess/binarize the data using the GPT-2 fairseq dictionary:
```bash
wget -O gpt2_bpe/dict.txt https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt
fairseq-preprocess \
    --only-source \
    --srcdict gpt2_bpe/dict.txt \
    --trainpref wikitext-103-raw/wiki.train.bpe \
    --validpref wikitext-103-raw/wiki.valid.bpe \
    --testpref wikitext-103-raw/wiki.test.bpe \
    --destdir data-bin/wikitext-103 \
    --workers 60
```

### 2) Train RoBERTa base
```bash
DATA_DIR=data-bin/wikitext-103

fairseq-hydra-train -m --config-dir examples/roberta/config/pretraining \
--config-name base task.data=$DATA_DIR
```

**Note:** You can optionally resume training the released RoBERTa base model by
adding `checkpoint.restore_file=/path/to/roberta.base/model.pt`.

**Note:** The above command assumes training on 8x32GB V100 GPUs. Each GPU uses
a batch size of 16 sequences (`dataset.batch_size`) and accumulates gradients to
further increase the batch size by 16x (`optimization.update_freq`), for a total batch size
of 2048 sequences. If you have fewer GPUs or GPUs with less memory, you may need
to reduce `dataset.batch_size` and increase `optimization.update_freq` to compensate.
Alternatively, if you have more GPUs you can decrease `optimization.update_freq` accordingly
to increase training speed.
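For example, a minimal sketch (illustrative values, not a tuned recipe): if each of the 8 GPUs only fits 8 sequences, doubling the gradient accumulation keeps the total batch size at 2048 sequences (8 GPUs x 8 sequences x 32 accumulated updates):

```bash
# Illustrative overrides only; adjust to your hardware.
fairseq-hydra-train -m --config-dir examples/roberta/config/pretraining \
--config-name base task.data=$DATA_DIR \
dataset.batch_size=8 optimization.update_freq='[32]'
```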
**Note:** The learning rate and batch size are tightly connected and need to be
adjusted together. We generally recommend increasing the learning rate as you
increase the batch size according to the following table (although it's also
dataset dependent, so don't rely on the following values too closely):

batch size | peak learning rate
---|---
256 | 0.0001
2048 | 0.0005
8192 | 0.0007
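As a sketch of how the two are overridden together (values taken directly from the table above, not separately tuned), pretraining with a total batch size of 8192 sequences on the same 8-GPU setup (8 GPUs x 16 sequences x 64 accumulated updates) would look like:

```bash
# Illustrative: scale the peak learning rate along with the effective batch size.
fairseq-hydra-train -m --config-dir examples/roberta/config/pretraining \
--config-name base task.data=$DATA_DIR \
optimization.update_freq='[64]' optimization.lr='[0.0007]'
```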
### 3) Load your pretrained model
```python
import torch
from fairseq.models.roberta import RobertaModel
roberta = RobertaModel.from_pretrained('checkpoints', 'checkpoint_best.pt', 'path/to/data')
assert isinstance(roberta.model, torch.nn.Module)
```
fairseq/examples/roberta/README.race.md
ADDED
@@ -0,0 +1,68 @@
# Finetuning RoBERTa on RACE tasks

### 1) Download the data from the RACE website (http://www.cs.cmu.edu/~glai1/data/race/)

### 2) Preprocess RACE data:
```bash
python ./examples/roberta/preprocess_RACE.py --input-dir <input-dir> --output-dir <extracted-data-dir>
./examples/roberta/preprocess_RACE.sh <extracted-data-dir> <output-dir>
```

### 3) Fine-tuning on RACE:

```bash
MAX_EPOCH=5           # Number of training epochs.
LR=1e-05              # Peak LR for fixed LR scheduler.
NUM_CLASSES=4
MAX_SENTENCES=1       # Batch size per GPU.
UPDATE_FREQ=8         # Accumulate gradients to simulate training on 8 GPUs.
DATA_DIR=/path/to/race-output-dir
ROBERTA_PATH=/path/to/roberta/model.pt

CUDA_VISIBLE_DEVICES=0,1 fairseq-train $DATA_DIR --ddp-backend=legacy_ddp \
    --restore-file $ROBERTA_PATH \
    --reset-optimizer --reset-dataloader --reset-meters \
    --best-checkpoint-metric accuracy --maximize-best-checkpoint-metric \
    --task sentence_ranking \
    --num-classes $NUM_CLASSES \
    --init-token 0 --separator-token 2 \
    --max-option-length 128 \
    --max-positions 512 \
    --shorten-method "truncate" \
    --arch roberta_large \
    --dropout 0.1 --attention-dropout 0.1 --weight-decay 0.01 \
    --criterion sentence_ranking \
    --optimizer adam --adam-betas '(0.9, 0.98)' --adam-eps 1e-06 \
    --clip-norm 0.0 \
    --lr-scheduler fixed --lr $LR \
    --fp16 --fp16-init-scale 4 --threshold-loss-scale 1 --fp16-scale-window 128 \
    --batch-size $MAX_SENTENCES \
    --required-batch-size-multiple 1 \
    --update-freq $UPDATE_FREQ \
    --max-epoch $MAX_EPOCH
```

**Note:**

a) As contexts in RACE are relatively long, we use a smaller batch size per GPU while increasing `--update-freq` to achieve a larger effective batch size.

b) The above command and hyperparameters were tested on one Nvidia `V100` GPU with `32gb` of memory for each task. Depending on the GPU memory available to you, you can increase `--update-freq` and reduce `--batch-size`.

c) The settings in the above command are based on our hyperparameter search within a fixed search space (for careful comparison across models). You might be able to find better metrics with a wider hyperparameter search.

### 4) Evaluation:

```bash
DATA_DIR=/path/to/race-output-dir       # data directory used during training
MODEL_PATH=/path/to/checkpoint_best.pt  # path to the finetuned model checkpoint
PREDS_OUT=preds.tsv                     # output file path to save prediction
TEST_SPLIT=test                         # can be test (Middle) or test1 (High)
fairseq-validate \
    $DATA_DIR \
    --valid-subset $TEST_SPLIT \
    --path $MODEL_PATH \
    --batch-size 1 \
    --task sentence_ranking \
    --criterion sentence_ranking \
    --save-predictions $PREDS_OUT
```
fairseq/examples/roberta/commonsense_qa/README.md
ADDED
@@ -0,0 +1,99 @@
# Finetuning RoBERTa on Commonsense QA

We follow a similar approach to [finetuning RACE](../README.race.md). Specifically
for each question we construct five inputs, one for each of the five candidate
answer choices. Each input is constructed by concatenating the question and
candidate answer. We then encode each input and pass the resulting "[CLS]"
representations through a fully-connected layer to predict the correct answer.
We train with a standard cross-entropy loss.

We also found it helpful to prepend a prefix of `Q:` to the question and `A:` to
the answer. The complete input format is:
```
<s> Q: Where would I not want a fox? </s> A: hen house </s>
```
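A quick way to reproduce this format (a sketch that assumes a `roberta` model is already loaded, as in the evaluation snippet in step 3 below): pass the prefixed question and answer to `roberta.encode` as a sentence pair with `no_separator=True`.

```python
# Builds one candidate input in the `<s> Q: ... </s> A: ... </s>` layout shown above.
tokens = roberta.encode('Q: Where would I not want a fox?', 'A: hen house', no_separator=True)
```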
Our final submission is based on a hyperparameter search over the learning rate
(1e-5, 2e-5, 3e-5), batch size (8, 16), number of training steps (2000, 3000,
4000) and random seed. We selected the model with the best performance on the
development set after 100 trials.

### 1) Download data from the Commonsense QA website (https://www.tau-nlp.org/commonsenseqa)
```bash
bash examples/roberta/commonsense_qa/download_cqa_data.sh
```

### 2) Finetune

```bash
MAX_UPDATES=3000      # Number of training steps.
WARMUP_UPDATES=150    # Linearly increase LR over this many steps.
LR=1e-05              # Peak LR for polynomial LR scheduler.
MAX_SENTENCES=16      # Batch size.
SEED=1                # Random seed.
ROBERTA_PATH=/path/to/roberta/model.pt
DATA_DIR=data/CommonsenseQA

# we use the --user-dir option to load the task from
# the examples/roberta/commonsense_qa directory:
FAIRSEQ_PATH=/path/to/fairseq
FAIRSEQ_USER_DIR=${FAIRSEQ_PATH}/examples/roberta/commonsense_qa

CUDA_VISIBLE_DEVICES=0 fairseq-train --fp16 --ddp-backend=legacy_ddp \
    $DATA_DIR \
    --user-dir $FAIRSEQ_USER_DIR \
    --restore-file $ROBERTA_PATH \
    --reset-optimizer --reset-dataloader --reset-meters \
    --no-epoch-checkpoints --no-last-checkpoints --no-save-optimizer-state \
    --best-checkpoint-metric accuracy --maximize-best-checkpoint-metric \
    --task commonsense_qa --init-token 0 --bpe gpt2 \
    --arch roberta_large --max-positions 512 \
    --dropout 0.1 --attention-dropout 0.1 --weight-decay 0.01 \
    --criterion sentence_ranking --num-classes 5 \
    --optimizer adam --adam-betas '(0.9, 0.98)' --adam-eps 1e-06 --clip-norm 0.0 \
    --lr-scheduler polynomial_decay --lr $LR \
    --warmup-updates $WARMUP_UPDATES --total-num-update $MAX_UPDATES \
    --batch-size $MAX_SENTENCES \
    --max-update $MAX_UPDATES \
    --log-format simple --log-interval 25 \
    --seed $SEED
```

The above command assumes training on 1 GPU with 32GB of RAM. For GPUs with
less memory, decrease `--batch-size` and increase `--update-freq`
accordingly to compensate.
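For instance (a sketch with illustrative, untuned values), halving the batch size and accumulating gradients over two steps keeps the effective batch size at 16:

```bash
# Same command as above, shortened to the flags that change;
# '...' stands for the unchanged arguments.
CUDA_VISIBLE_DEVICES=0 fairseq-train ... \
    --batch-size 8 --update-freq 2
```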
### 3) Evaluate
```python
import json
import torch
from fairseq.models.roberta import RobertaModel
from examples.roberta import commonsense_qa  # load the Commonsense QA task
roberta = RobertaModel.from_pretrained('checkpoints', 'checkpoint_best.pt', 'data/CommonsenseQA')
roberta.eval()  # disable dropout
roberta.cuda()  # use the GPU (optional)
nsamples, ncorrect = 0, 0
with open('data/CommonsenseQA/valid.jsonl') as h:
    for line in h:
        example = json.loads(line)
        scores = []
        for choice in example['question']['choices']:
            input = roberta.encode(
                'Q: ' + example['question']['stem'],
                'A: ' + choice['text'],
                no_separator=True
            )
            score = roberta.predict('sentence_classification_head', input, return_logits=True)
            scores.append(score)
        pred = torch.cat(scores).argmax()
        answer = ord(example['answerKey']) - ord('A')
        nsamples += 1
        if pred == answer:
            ncorrect += 1

print('Accuracy: ' + str(ncorrect / float(nsamples)))
# Accuracy: 0.7846027846027847
```

The above snippet is not batched, which makes it quite slow. See [instructions
for batched prediction with RoBERTa](https://github.com/pytorch/fairseq/tree/main/examples/roberta#batched-prediction).
fairseq/examples/roberta/commonsense_qa/__init__.py
ADDED
@@ -0,0 +1,6 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from . import commonsense_qa_task  # noqa
fairseq/examples/roberta/commonsense_qa/commonsense_qa_task.py
ADDED
@@ -0,0 +1,190 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import json
import os

import numpy as np
import torch
from fairseq.data import (
    Dictionary,
    IdDataset,
    ListDataset,
    NestedDictionaryDataset,
    NumelDataset,
    NumSamplesDataset,
    RawLabelDataset,
    RightPadDataset,
    SortDataset,
    data_utils,
    encoders,
)
from fairseq.tasks import LegacyFairseqTask, register_task


@register_task("commonsense_qa")
class CommonsenseQATask(LegacyFairseqTask):
    """Task to finetune RoBERTa for Commonsense QA."""

    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        parser.add_argument(
            "data", metavar="DIR", help="path to data directory; we load <split>.jsonl"
        )
        parser.add_argument(
            "--init-token",
            type=int,
            default=None,
            help="add token at the beginning of each batch item",
        )
        parser.add_argument("--num-classes", type=int, default=5)

    def __init__(self, args, vocab):
        super().__init__(args)
        self.vocab = vocab
        self.mask = vocab.add_symbol("<mask>")

        self.bpe = encoders.build_bpe(args)

    @classmethod
    def load_dictionary(cls, filename):
        """Load the dictionary from the filename

        Args:
            filename (str): the filename
        """
        dictionary = Dictionary.load(filename)
        dictionary.add_symbol("<mask>")
        return dictionary

    @classmethod
    def setup_task(cls, args, **kwargs):
        assert (
            args.criterion == "sentence_ranking"
        ), "Must set --criterion=sentence_ranking"

        # load data and label dictionaries
        vocab = cls.load_dictionary(os.path.join(args.data, "dict.txt"))
        print("| dictionary: {} types".format(len(vocab)))

        return cls(args, vocab)

    def load_dataset(
        self, split, epoch=1, combine=False, data_path=None, return_only=False, **kwargs
    ):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """

        def binarize(s, append_bos=False):
            if self.bpe is not None:
                s = self.bpe.encode(s)
            tokens = self.vocab.encode_line(
                s,
                append_eos=True,
                add_if_not_exist=False,
            ).long()
            if append_bos and self.args.init_token is not None:
                tokens = torch.cat([tokens.new([self.args.init_token]), tokens])
            return tokens

        if data_path is None:
            data_path = os.path.join(self.args.data, split + ".jsonl")
        if not os.path.exists(data_path):
            raise FileNotFoundError("Cannot find data: {}".format(data_path))

        src_tokens = [[] for i in range(self.args.num_classes)]
        src_lengths = [[] for i in range(self.args.num_classes)]
        labels = []

        with open(data_path) as h:
            for line in h:
                example = json.loads(line.strip())
                if "answerKey" in example:
                    label = ord(example["answerKey"]) - ord("A")
                    labels.append(label)
                question = example["question"]["stem"]
                assert len(example["question"]["choices"]) == self.args.num_classes
                # format: `<s> Q: Where would I not want a fox? </s> A: hen house </s>`
                question = "Q: " + question
                question_toks = binarize(question, append_bos=True)
                for i, choice in enumerate(example["question"]["choices"]):
                    src = "A: " + choice["text"]
                    src_bin = torch.cat([question_toks, binarize(src)])
                    src_tokens[i].append(src_bin)
                    src_lengths[i].append(len(src_bin))
        assert all(
            len(src_tokens[0]) == len(src_tokens[i])
            for i in range(self.args.num_classes)
        )
        assert len(src_tokens[0]) == len(src_lengths[0])
        assert len(labels) == 0 or len(labels) == len(src_tokens[0])

        for i in range(self.args.num_classes):
            src_lengths[i] = np.array(src_lengths[i])
            src_tokens[i] = ListDataset(src_tokens[i], src_lengths[i])
            src_lengths[i] = ListDataset(src_lengths[i])

        dataset = {
            "id": IdDataset(),
            "nsentences": NumSamplesDataset(),
            "ntokens": NumelDataset(src_tokens[0], reduce=True),
        }

        for i in range(self.args.num_classes):
            dataset.update(
                {
                    "net_input{}".format(i + 1): {
                        "src_tokens": RightPadDataset(
                            src_tokens[i],
                            pad_idx=self.source_dictionary.pad(),
                        ),
                        "src_lengths": src_lengths[i],
                    }
                }
            )

        if len(labels) > 0:
            dataset.update({"target": RawLabelDataset(labels)})

        dataset = NestedDictionaryDataset(
            dataset,
            sizes=[np.maximum.reduce([src_token.sizes for src_token in src_tokens])],
        )

        with data_utils.numpy_seed(self.args.seed):
            dataset = SortDataset(
                dataset,
                # shuffle
                sort_order=[np.random.permutation(len(dataset))],
            )

        print("| Loaded {} with {} samples".format(split, len(dataset)))

        self.datasets[split] = dataset
        return self.datasets[split]

    def build_model(self, args, from_checkpoint=False):
        from fairseq import models

        model = models.build_model(args, self)

        model.register_classification_head(
            "sentence_classification_head",
            num_classes=1,
        )

        return model

    @property
    def source_dictionary(self):
        return self.vocab

    @property
    def target_dictionary(self):
        return self.vocab
fairseq/examples/roberta/commonsense_qa/download_cqa_data.sh
ADDED
@@ -0,0 +1,14 @@
#!/bin/bash
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

OUTDIR=data/CommonsenseQA

mkdir -p $OUTDIR

wget -O $OUTDIR/train.jsonl https://s3.amazonaws.com/commensenseqa/train_rand_split.jsonl
wget -O $OUTDIR/valid.jsonl https://s3.amazonaws.com/commensenseqa/dev_rand_split.jsonl
wget -O $OUTDIR/test.jsonl https://s3.amazonaws.com/commensenseqa/test_rand_split_no_answers.jsonl
wget -O $OUTDIR/dict.txt https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt
fairseq/examples/roberta/config/finetuning/cola.yaml
ADDED
@@ -0,0 +1,59 @@
1 |
+
# @package _group_
|
2 |
+
|
3 |
+
common:
|
4 |
+
fp16: true
|
5 |
+
fp16_init_scale: 4
|
6 |
+
threshold_loss_scale: 1
|
7 |
+
fp16_scale_window: 128
|
8 |
+
log_format: json
|
9 |
+
log_interval: 200
|
10 |
+
|
11 |
+
task:
|
12 |
+
_name: sentence_prediction
|
13 |
+
data: ???
|
14 |
+
init_token: 0
|
15 |
+
separator_token: 2
|
16 |
+
num_classes: 2
|
17 |
+
max_positions: 512
|
18 |
+
|
19 |
+
checkpoint:
|
20 |
+
restore_file: ???
|
21 |
+
reset_optimizer: true
|
22 |
+
reset_dataloader: true
|
23 |
+
reset_meters: true
|
24 |
+
best_checkpoint_metric: accuracy
|
25 |
+
maximize_best_checkpoint_metric: true
|
26 |
+
no_epoch_checkpoints: true
|
27 |
+
|
28 |
+
distributed_training:
|
29 |
+
find_unused_parameters: true
|
30 |
+
distributed_world_size: 1
|
31 |
+
|
32 |
+
criterion:
|
33 |
+
_name: sentence_prediction
|
34 |
+
|
35 |
+
dataset:
|
36 |
+
batch_size: 16
|
37 |
+
required_batch_size_multiple: 1
|
38 |
+
max_tokens: 4400
|
39 |
+
|
40 |
+
optimizer:
|
41 |
+
_name: adam
|
42 |
+
weight_decay: 0.1
|
43 |
+
adam_betas: (0.9,0.98)
|
44 |
+
adam_eps: 1e-06
|
45 |
+
|
46 |
+
lr_scheduler:
|
47 |
+
_name: polynomial_decay
|
48 |
+
warmup_updates: 320
|
49 |
+
|
50 |
+
optimization:
|
51 |
+
clip_norm: 0.0
|
52 |
+
lr: [1e-05]
|
53 |
+
max_update: 5336
|
54 |
+
max_epoch: 10
|
55 |
+
|
56 |
+
model:
|
57 |
+
_name: roberta
|
58 |
+
dropout: 0.1
|
59 |
+
attention_dropout: 0.1
|
fairseq/examples/roberta/config/finetuning/mnli.yaml
ADDED
@@ -0,0 +1,59 @@
1 |
+
# @package _group_
|
2 |
+
|
3 |
+
common:
|
4 |
+
fp16: true
|
5 |
+
fp16_init_scale: 4
|
6 |
+
threshold_loss_scale: 1
|
7 |
+
fp16_scale_window: 128
|
8 |
+
log_format: json
|
9 |
+
log_interval: 200
|
10 |
+
|
11 |
+
task:
|
12 |
+
_name: sentence_prediction
|
13 |
+
data: ???
|
14 |
+
init_token: 0
|
15 |
+
separator_token: 2
|
16 |
+
num_classes: 3
|
17 |
+
max_positions: 512
|
18 |
+
|
19 |
+
checkpoint:
|
20 |
+
restore_file: ???
|
21 |
+
reset_optimizer: true
|
22 |
+
reset_dataloader: true
|
23 |
+
reset_meters: true
|
24 |
+
best_checkpoint_metric: accuracy
|
25 |
+
maximize_best_checkpoint_metric: true
|
26 |
+
no_epoch_checkpoints: true
|
27 |
+
|
28 |
+
distributed_training:
|
29 |
+
find_unused_parameters: true
|
30 |
+
distributed_world_size: 1
|
31 |
+
|
32 |
+
criterion:
|
33 |
+
_name: sentence_prediction
|
34 |
+
|
35 |
+
dataset:
|
36 |
+
batch_size: 32
|
37 |
+
required_batch_size_multiple: 1
|
38 |
+
max_tokens: 4400
|
39 |
+
|
40 |
+
optimizer:
|
41 |
+
_name: adam
|
42 |
+
weight_decay: 0.1
|
43 |
+
adam_betas: (0.9,0.98)
|
44 |
+
adam_eps: 1e-06
|
45 |
+
|
46 |
+
lr_scheduler:
|
47 |
+
_name: polynomial_decay
|
48 |
+
warmup_updates: 7432
|
49 |
+
|
50 |
+
optimization:
|
51 |
+
clip_norm: 0.0
|
52 |
+
lr: [1e-05]
|
53 |
+
max_update: 123873
|
54 |
+
max_epoch: 10
|
55 |
+
|
56 |
+
model:
|
57 |
+
_name: roberta
|
58 |
+
dropout: 0.1
|
59 |
+
attention_dropout: 0.1
|
fairseq/examples/roberta/config/finetuning/mrpc.yaml
ADDED
@@ -0,0 +1,59 @@
1 |
+
# @package _group_
|
2 |
+
|
3 |
+
common:
|
4 |
+
fp16: true
|
5 |
+
fp16_init_scale: 4
|
6 |
+
threshold_loss_scale: 1
|
7 |
+
fp16_scale_window: 128
|
8 |
+
log_format: json
|
9 |
+
log_interval: 200
|
10 |
+
|
11 |
+
task:
|
12 |
+
_name: sentence_prediction
|
13 |
+
data: ???
|
14 |
+
init_token: 0
|
15 |
+
separator_token: 2
|
16 |
+
num_classes: 2
|
17 |
+
max_positions: 512
|
18 |
+
|
19 |
+
checkpoint:
|
20 |
+
restore_file: ???
|
21 |
+
reset_optimizer: true
|
22 |
+
reset_dataloader: true
|
23 |
+
reset_meters: true
|
24 |
+
best_checkpoint_metric: accuracy
|
25 |
+
maximize_best_checkpoint_metric: true
|
26 |
+
no_epoch_checkpoints: true
|
27 |
+
|
28 |
+
distributed_training:
|
29 |
+
find_unused_parameters: true
|
30 |
+
distributed_world_size: 1
|
31 |
+
|
32 |
+
criterion:
|
33 |
+
_name: sentence_prediction
|
34 |
+
|
35 |
+
dataset:
|
36 |
+
batch_size: 16
|
37 |
+
required_batch_size_multiple: 1
|
38 |
+
max_tokens: 4400
|
39 |
+
|
40 |
+
optimizer:
|
41 |
+
_name: adam
|
42 |
+
weight_decay: 0.1
|
43 |
+
adam_betas: (0.9,0.98)
|
44 |
+
adam_eps: 1e-06
|
45 |
+
|
46 |
+
lr_scheduler:
|
47 |
+
_name: polynomial_decay
|
48 |
+
warmup_updates: 137
|
49 |
+
|
50 |
+
optimization:
|
51 |
+
clip_norm: 0.0
|
52 |
+
lr: [1e-05]
|
53 |
+
max_update: 2296
|
54 |
+
max_epoch: 10
|
55 |
+
|
56 |
+
model:
|
57 |
+
_name: roberta
|
58 |
+
dropout: 0.1
|
59 |
+
attention_dropout: 0.1
|
fairseq/examples/roberta/config/finetuning/qnli.yaml
ADDED
@@ -0,0 +1,59 @@
1 |
+
# @package _group_
|
2 |
+
|
3 |
+
common:
|
4 |
+
fp16: true
|
5 |
+
fp16_init_scale: 4
|
6 |
+
threshold_loss_scale: 1
|
7 |
+
fp16_scale_window: 128
|
8 |
+
log_format: json
|
9 |
+
log_interval: 200
|
10 |
+
|
11 |
+
task:
|
12 |
+
_name: sentence_prediction
|
13 |
+
data: ???
|
14 |
+
init_token: 0
|
15 |
+
separator_token: 2
|
16 |
+
num_classes: 2
|
17 |
+
max_positions: 512
|
18 |
+
|
19 |
+
checkpoint:
|
20 |
+
restore_file: ???
|
21 |
+
reset_optimizer: true
|
22 |
+
reset_dataloader: true
|
23 |
+
reset_meters: true
|
24 |
+
best_checkpoint_metric: accuracy
|
25 |
+
maximize_best_checkpoint_metric: true
|
26 |
+
no_epoch_checkpoints: true
|
27 |
+
|
28 |
+
distributed_training:
|
29 |
+
find_unused_parameters: true
|
30 |
+
distributed_world_size: 1
|
31 |
+
|
32 |
+
criterion:
|
33 |
+
_name: sentence_prediction
|
34 |
+
|
35 |
+
dataset:
|
36 |
+
batch_size: 32
|
37 |
+
required_batch_size_multiple: 1
|
38 |
+
max_tokens: 4400
|
39 |
+
|
40 |
+
optimizer:
|
41 |
+
_name: adam
|
42 |
+
weight_decay: 0.1
|
43 |
+
adam_betas: (0.9,0.98)
|
44 |
+
adam_eps: 1e-06
|
45 |
+
|
46 |
+
lr_scheduler:
|
47 |
+
_name: polynomial_decay
|
48 |
+
warmup_updates: 1986
|
49 |
+
|
50 |
+
optimization:
|
51 |
+
clip_norm: 0.0
|
52 |
+
lr: [1e-05]
|
53 |
+
max_update: 33112
|
54 |
+
max_epoch: 10
|
55 |
+
|
56 |
+
model:
|
57 |
+
_name: roberta
|
58 |
+
dropout: 0.1
|
59 |
+
attention_dropout: 0.1
|
fairseq/examples/roberta/config/finetuning/qqp.yaml
ADDED
@@ -0,0 +1,59 @@
1 |
+
# @package _group_
|
2 |
+
|
3 |
+
common:
|
4 |
+
fp16: true
|
5 |
+
fp16_init_scale: 4
|
6 |
+
threshold_loss_scale: 1
|
7 |
+
fp16_scale_window: 128
|
8 |
+
log_format: json
|
9 |
+
log_interval: 200
|
10 |
+
|
11 |
+
task:
|
12 |
+
_name: sentence_prediction
|
13 |
+
data: ???
|
14 |
+
init_token: 0
|
15 |
+
separator_token: 2
|
16 |
+
num_classes: 2
|
17 |
+
max_positions: 512
|
18 |
+
|
19 |
+
checkpoint:
|
20 |
+
restore_file: ???
|
21 |
+
reset_optimizer: true
|
22 |
+
reset_dataloader: true
|
23 |
+
reset_meters: true
|
24 |
+
best_checkpoint_metric: accuracy
|
25 |
+
maximize_best_checkpoint_metric: true
|
26 |
+
no_epoch_checkpoints: true
|
27 |
+
|
28 |
+
distributed_training:
|
29 |
+
find_unused_parameters: true
|
30 |
+
distributed_world_size: 1
|
31 |
+
|
32 |
+
criterion:
|
33 |
+
_name: sentence_prediction
|
34 |
+
|
35 |
+
dataset:
|
36 |
+
batch_size: 32
|
37 |
+
required_batch_size_multiple: 1
|
38 |
+
max_tokens: 4400
|
39 |
+
|
40 |
+
optimizer:
|
41 |
+
_name: adam
|
42 |
+
weight_decay: 0.1
|
43 |
+
adam_betas: (0.9,0.98)
|
44 |
+
adam_eps: 1e-06
|
45 |
+
|
46 |
+
lr_scheduler:
|
47 |
+
_name: polynomial_decay
|
48 |
+
warmup_updates: 28318
|
49 |
+
|
50 |
+
optimization:
|
51 |
+
clip_norm: 0.0
|
52 |
+
lr: [1e-05]
|
53 |
+
max_update: 113272
|
54 |
+
max_epoch: 10
|
55 |
+
|
56 |
+
model:
|
57 |
+
_name: roberta
|
58 |
+
dropout: 0.1
|
59 |
+
attention_dropout: 0.1
|
fairseq/examples/roberta/config/finetuning/rte.yaml
ADDED
@@ -0,0 +1,59 @@
# @package _group_

common:
  fp16: true
  fp16_init_scale: 4
  threshold_loss_scale: 1
  fp16_scale_window: 128
  log_format: json
  log_interval: 200

task:
  _name: sentence_prediction
  data: ???
  init_token: 0
  separator_token: 2
  num_classes: 2
  max_positions: 512

checkpoint:
  restore_file: ???
  reset_optimizer: true
  reset_dataloader: true
  reset_meters: true
  best_checkpoint_metric: accuracy
  maximize_best_checkpoint_metric: true
  no_epoch_checkpoints: true

distributed_training:
  find_unused_parameters: true
  distributed_world_size: 1

criterion:
  _name: sentence_prediction

dataset:
  batch_size: 16
  required_batch_size_multiple: 1
  max_tokens: 4400

optimizer:
  _name: adam
  weight_decay: 0.1
  adam_betas: (0.9,0.98)
  adam_eps: 1e-06

lr_scheduler:
  _name: polynomial_decay
  warmup_updates: 122

optimization:
  clip_norm: 0.0
  lr: [2e-05]
  max_update: 2036
  max_epoch: 10

model:
  _name: roberta
  dropout: 0.1
  attention_dropout: 0.1
fairseq/examples/roberta/config/finetuning/run_config/local.yaml
ADDED
@@ -0,0 +1,15 @@
1 |
+
# @package _global_
|
2 |
+
hydra:
|
3 |
+
sweep:
|
4 |
+
dir: ${env:PWD}/tmp_dbg/${now:%H-%M-%S}
|
5 |
+
|
6 |
+
distributed_training:
|
7 |
+
distributed_world_size: 1
|
8 |
+
nprocs_per_node: 1
|
9 |
+
distributed_port: -1
|
10 |
+
|
11 |
+
common:
|
12 |
+
log_interval: 1
|
13 |
+
|
14 |
+
dataset:
|
15 |
+
num_workers: 0
|
fairseq/examples/roberta/config/finetuning/run_config/slurm_1g.yaml
ADDED
@@ -0,0 +1,28 @@
1 |
+
|
2 |
+
# @package _global_
|
3 |
+
|
4 |
+
hydra:
|
5 |
+
job:
|
6 |
+
config:
|
7 |
+
override_dirname:
|
8 |
+
kv_sep: '_'
|
9 |
+
item_sep: '/'
|
10 |
+
exclude_keys:
|
11 |
+
- run_config
|
12 |
+
- distributed_training.distributed_port
|
13 |
+
sweep:
|
14 |
+
dir: /checkpoint/${env:USER}/roberta_ft/${env:PREFIX}/${hydra.job.config_name}/${env:SUFFIX}
|
15 |
+
subdir: ${hydra.job.num}
|
16 |
+
launcher:
|
17 |
+
submitit_folder: ${hydra.sweep.dir}/submitit
|
18 |
+
timeout_min: 1000
|
19 |
+
cpus_per_task: 8
|
20 |
+
gpus_per_node: 1
|
21 |
+
tasks_per_node: 1
|
22 |
+
mem_gb: 60
|
23 |
+
nodes: 1
|
24 |
+
name: ${env:PREFIX}_${hydra.job.config_name}
|
25 |
+
partition: devlab,learnlab,learnfair,scavenge
|
26 |
+
constraint: volta32gb
|
27 |
+
max_num_timeout: 30
|
28 |
+
exclude: learnfair1381,learnfair5192,learnfair2304
|
fairseq/examples/roberta/config/finetuning/run_config/slurm_1g_aws.yaml
ADDED
@@ -0,0 +1,25 @@
1 |
+
# @package _global_
|
2 |
+
|
3 |
+
hydra:
|
4 |
+
job:
|
5 |
+
config:
|
6 |
+
override_dirname:
|
7 |
+
kv_sep: '_'
|
8 |
+
item_sep: '/'
|
9 |
+
exclude_keys:
|
10 |
+
- run_config
|
11 |
+
- distributed_training.distributed_port
|
12 |
+
sweep:
|
13 |
+
dir: /fsx-wav2vec/${env:USER}/roberta_ft/${env:PREFIX}/${hydra.job.config_name}/${env:SUFFIX}
|
14 |
+
subdir: ${hydra.job.num}
|
15 |
+
launcher:
|
16 |
+
submitit_folder: ${hydra.sweep.dir}/submitit
|
17 |
+
timeout_min: 1000
|
18 |
+
cpus_per_task: 8
|
19 |
+
gpus_per_node: 1
|
20 |
+
tasks_per_node: 1
|
21 |
+
mem_gb: 0
|
22 |
+
nodes: 1
|
23 |
+
name: ${env:PREFIX}_${hydra.job.config_name}
|
24 |
+
partition: learnfair,wav2vec
|
25 |
+
max_num_timeout: 30
|
fairseq/examples/roberta/config/finetuning/sst_2.yaml
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @package _group_
|
2 |
+
|
3 |
+
common:
|
4 |
+
fp16: true
|
5 |
+
fp16_init_scale: 4
|
6 |
+
threshold_loss_scale: 1
|
7 |
+
fp16_scale_window: 128
|
8 |
+
log_format: json
|
9 |
+
log_interval: 200
|
10 |
+
|
11 |
+
task:
|
12 |
+
_name: sentence_prediction
|
13 |
+
data: ???
|
14 |
+
init_token: 0
|
15 |
+
separator_token: 2
|
16 |
+
num_classes: 2
|
17 |
+
max_positions: 512
|
18 |
+
|
19 |
+
checkpoint:
|
20 |
+
restore_file: ???
|
21 |
+
reset_optimizer: true
|
22 |
+
reset_dataloader: true
|
23 |
+
reset_meters: true
|
24 |
+
best_checkpoint_metric: accuracy
|
25 |
+
maximize_best_checkpoint_metric: true
|
26 |
+
no_epoch_checkpoints: true
|
27 |
+
|
28 |
+
distributed_training:
|
29 |
+
find_unused_parameters: true
|
30 |
+
distributed_world_size: 1
|
31 |
+
|
32 |
+
criterion:
|
33 |
+
_name: sentence_prediction
|
34 |
+
|
35 |
+
dataset:
|
36 |
+
batch_size: 32
|
37 |
+
required_batch_size_multiple: 1
|
38 |
+
max_tokens: 4400
|
39 |
+
|
40 |
+
optimizer:
|
41 |
+
_name: adam
|
42 |
+
weight_decay: 0.1
|
43 |
+
adam_betas: (0.9,0.98)
|
44 |
+
adam_eps: 1e-06
|
45 |
+
|
46 |
+
lr_scheduler:
|
47 |
+
_name: polynomial_decay
|
48 |
+
warmup_updates: 1256
|
49 |
+
|
50 |
+
optimization:
|
51 |
+
clip_norm: 0.0
|
52 |
+
lr: [1e-05]
|
53 |
+
max_update: 20935
|
54 |
+
max_epoch: 10
|
55 |
+
|
56 |
+
model:
|
57 |
+
_name: roberta
|
58 |
+
dropout: 0.1
|
59 |
+
attention_dropout: 0.1
|
fairseq/examples/roberta/config/finetuning/sts_b.yaml
ADDED
@@ -0,0 +1,58 @@
1 |
+
# @package _group_
|
2 |
+
|
3 |
+
common:
|
4 |
+
fp16: true
|
5 |
+
fp16_init_scale: 4
|
6 |
+
threshold_loss_scale: 1
|
7 |
+
fp16_scale_window: 128
|
8 |
+
log_format: json
|
9 |
+
log_interval: 200
|
10 |
+
|
11 |
+
task:
|
12 |
+
_name: sentence_prediction
|
13 |
+
data: ???
|
14 |
+
init_token: 0
|
15 |
+
separator_token: 2
|
16 |
+
num_classes: 1
|
17 |
+
max_positions: 512
|
18 |
+
|
19 |
+
checkpoint:
|
20 |
+
restore_file: ???
|
21 |
+
reset_optimizer: true
|
22 |
+
reset_dataloader: true
|
23 |
+
reset_meters: true
|
24 |
+
no_epoch_checkpoints: true
|
25 |
+
|
26 |
+
distributed_training:
|
27 |
+
find_unused_parameters: true
|
28 |
+
distributed_world_size: 1
|
29 |
+
|
30 |
+
criterion:
|
31 |
+
_name: sentence_prediction
|
32 |
+
regression_target: true
|
33 |
+
|
34 |
+
dataset:
|
35 |
+
batch_size: 16
|
36 |
+
required_batch_size_multiple: 1
|
37 |
+
max_tokens: 4400
|
38 |
+
|
39 |
+
optimizer:
|
40 |
+
_name: adam
|
41 |
+
weight_decay: 0.1
|
42 |
+
adam_betas: (0.9,0.98)
|
43 |
+
adam_eps: 1e-06
|
44 |
+
|
45 |
+
lr_scheduler:
|
46 |
+
_name: polynomial_decay
|
47 |
+
warmup_updates: 214
|
48 |
+
|
49 |
+
optimization:
|
50 |
+
clip_norm: 0.0
|
51 |
+
lr: [2e-05]
|
52 |
+
max_update: 3598
|
53 |
+
max_epoch: 10
|
54 |
+
|
55 |
+
model:
|
56 |
+
_name: roberta
|
57 |
+
dropout: 0.1
|
58 |
+
attention_dropout: 0.1
|
fairseq/examples/roberta/config/pretraining/base.yaml
ADDED
@@ -0,0 +1,42 @@
# @package _group_
common:
  fp16: true
  log_format: json
  log_interval: 200

checkpoint:
  no_epoch_checkpoints: true

task:
  _name: masked_lm
  data: ???
  sample_break_mode: complete
  tokens_per_sample: 512

criterion: masked_lm

dataset:
  batch_size: 16
  ignore_unused_valid_subsets: true

optimizer:
  _name: adam
  weight_decay: 0.01
  adam_betas: (0.9,0.98)
  adam_eps: 1e-06

lr_scheduler:
  _name: polynomial_decay
  warmup_updates: 10000

optimization:
  clip_norm: 0
  lr: [0.0005]
  max_update: 125000
  update_freq: [16]

model:
  _name: roberta
  max_positions: 512
  dropout: 0.1
  attention_dropout: 0.1
fairseq/examples/roberta/config/pretraining/run_config/local.yaml
ADDED
@@ -0,0 +1,15 @@
1 |
+
# @package _global_
|
2 |
+
hydra:
|
3 |
+
sweep:
|
4 |
+
dir: ${env:PWD}/tmp_dbg/${now:%H-%M-%S}
|
5 |
+
|
6 |
+
distributed_training:
|
7 |
+
distributed_world_size: 1
|
8 |
+
nprocs_per_node: 1
|
9 |
+
distributed_port: -1
|
10 |
+
|
11 |
+
common:
|
12 |
+
log_interval: 1
|
13 |
+
|
14 |
+
dataset:
|
15 |
+
num_workers: 0
|
fairseq/examples/roberta/config/pretraining/run_config/slurm_2.yaml
ADDED
@@ -0,0 +1,37 @@
1 |
+
# @package _global_
|
2 |
+
|
3 |
+
hydra:
|
4 |
+
job:
|
5 |
+
config:
|
6 |
+
override_dirname:
|
7 |
+
kv_sep: ':'
|
8 |
+
item_sep: '/'
|
9 |
+
exclude_keys:
|
10 |
+
- run_config
|
11 |
+
- distributed_training.distributed_port
|
12 |
+
- distributed_training.distributed_world_size
|
13 |
+
- model.pretrained_model_path
|
14 |
+
- model.target_network_path
|
15 |
+
- next_script
|
16 |
+
- task.cache_in_scratch
|
17 |
+
- task.data
|
18 |
+
- checkpoint.save_interval_updates
|
19 |
+
- checkpoint.keep_interval_updates
|
20 |
+
- checkpoint.save_on_overflow
|
21 |
+
- common.log_interval
|
22 |
+
- common.user_dir
|
23 |
+
sweep:
|
24 |
+
dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
|
25 |
+
subdir: ''
|
26 |
+
launcher:
|
27 |
+
submitit_folder: ${hydra.sweep.dir}
|
28 |
+
timeout_min: 4320
|
29 |
+
cpus_per_task: 80
|
30 |
+
gpus_per_node: 8
|
31 |
+
tasks_per_node: 1
|
32 |
+
mem_gb: 450
|
33 |
+
nodes: 2
|
34 |
+
name: ${env:PREFIX}_${hydra.job.config_name}
|
35 |
+
partition: devlab,learnlab,learnfair,scavenge
|
36 |
+
constraint: volta32gb,ib4
|
37 |
+
max_num_timeout: 30
|
fairseq/examples/roberta/config/pretraining/run_config/slurm_2_aws.yaml
ADDED
@@ -0,0 +1,39 @@
+# @package _global_
+
+hydra:
+  job:
+    config:
+      override_dirname:
+        kv_sep: ':'
+        item_sep: '/'
+        exclude_keys:
+          - run_config
+          - distributed_training.distributed_port
+          - distributed_training.distributed_world_size
+          - model.pretrained_model_path
+          - model.target_network_path
+          - next_script
+          - task.cache_in_scratch
+          - task.local_cache_path
+          - task.data
+          - task.post_save_script
+          - checkpoint.save_interval_updates
+          - checkpoint.keep_interval_updates
+          - checkpoint.save_on_overflow
+          - common.log_interval
+          - common.user_dir
+          - model.model_path
+  sweep:
+    dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+    subdir: ''
+  launcher:
+    submitit_folder: ${hydra.sweep.dir}
+    timeout_min: 4320
+    cpus_per_task: 10
+    gpus_per_node: 8
+    tasks_per_node: 8
+    mem_gb: 0
+    nodes: 2
+    name: ${env:PREFIX}_${hydra.job.config_name}
+    partition: wav2vec
+    max_num_timeout: 30
fairseq/examples/roberta/config/pretraining/run_config/slurm_3.yaml
ADDED
@@ -0,0 +1,36 @@
+# @package _global_
+
+hydra:
+  job:
+    config:
+      override_dirname:
+        kv_sep: ':'
+        item_sep: '/'
+        exclude_keys:
+          - run_config
+          - distributed_training.distributed_port
+          - distributed_training.distributed_world_size
+          - model.pretrained_model_path
+          - model.target_network_path
+          - next_script
+          - task.cache_in_scratch
+          - task.data
+          - checkpoint.save_interval_updates
+          - checkpoint.keep_interval_updates
+          - checkpoint.save_on_overflow
+          - common.log_interval
+  sweep:
+    dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+    subdir: ''
+  launcher:
+    submitit_folder: ${hydra.sweep.dir}
+    timeout_min: 4320
+    cpus_per_task: 10
+    gpus_per_node: 8
+    tasks_per_node: 8
+    mem_gb: 450
+    nodes: 3
+    name: ${env:PREFIX}_${hydra.job.config_name}
+    partition: devlab,learnlab,learnfair,scavenge
+    constraint: volta32gb,ib4
+    max_num_timeout: 30
fairseq/examples/roberta/config/pretraining/run_config/slurm_4.yaml
ADDED
@@ -0,0 +1,36 @@
+# @package _global_
+
+hydra:
+  job:
+    config:
+      override_dirname:
+        kv_sep: ':'
+        item_sep: '/'
+        exclude_keys:
+          - run_config
+          - distributed_training.distributed_port
+          - distributed_training.distributed_world_size
+          - model.pretrained_model_path
+          - model.target_network_path
+          - next_script
+          - task.cache_in_scratch
+          - task.data
+          - checkpoint.save_interval_updates
+          - checkpoint.keep_interval_updates
+          - checkpoint.save_on_overflow
+          - common.log_interval
+  sweep:
+    dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+    subdir: ''
+  launcher:
+    submitit_folder: ${hydra.sweep.dir}
+    timeout_min: 4320
+    cpus_per_task: 10
+    gpus_per_node: 8
+    tasks_per_node: 8
+    mem_gb: 450
+    nodes: 4
+    name: ${env:PREFIX}_${hydra.job.config_name}
+    partition: devlab,learnlab,learnfair,scavenge
+    constraint: volta32gb,ib4
+    max_num_timeout: 30
fairseq/examples/roberta/fb_multilingual/README.multilingual.pretraining.md
ADDED
@@ -0,0 +1,26 @@
+# Multilingual pretraining RoBERTa
+
+This tutorial will walk you through pretraining multilingual RoBERTa.
+
+### 1) Preprocess the data
+
+```bash
+DICTIONARY="/private/home/namangoyal/dataset/XLM/wiki/17/175k/vocab"
+DATA_LOCATION="/private/home/namangoyal/dataset/XLM/wiki/17/175k"
+
+for LANG in en es it
+do
+  fairseq-preprocess \
+    --only-source \
+    --srcdict $DICTIONARY \
+    --trainpref "$DATA_LOCATION/train.$LANG" \
+    --validpref "$DATA_LOCATION/valid.$LANG" \
+    --testpref "$DATA_LOCATION/test.$LANG" \
+    --destdir "wiki_17-bin/$LANG" \
+    --workers 60;
+done
+```
+
+### 2) Train RoBERTa base
+
+[COMING UP...]
fairseq/examples/roberta/multiprocessing_bpe_encoder.py
ADDED
@@ -0,0 +1,130 @@
+#!/usr/bin/env python
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+import contextlib
+import sys
+from collections import Counter
+from multiprocessing import Pool
+
+from fairseq.data.encoders.gpt2_bpe import get_encoder
+
+
+def main():
+    """
+    Helper script to encode raw text with the GPT-2 BPE using multiple processes.
+
+    The encoder.json and vocab.bpe files can be obtained here:
+    - https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json
+    - https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe
+    """
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--encoder-json",
+        help="path to encoder.json",
+    )
+    parser.add_argument(
+        "--vocab-bpe",
+        type=str,
+        help="path to vocab.bpe",
+    )
+    parser.add_argument(
+        "--inputs",
+        nargs="+",
+        default=["-"],
+        help="input files to filter/encode",
+    )
+    parser.add_argument(
+        "--outputs",
+        nargs="+",
+        default=["-"],
+        help="path to save encoded outputs",
+    )
+    parser.add_argument(
+        "--keep-empty",
+        action="store_true",
+        help="keep empty lines",
+    )
+    parser.add_argument("--workers", type=int, default=20)
+    args = parser.parse_args()
+
+    assert len(args.inputs) == len(
+        args.outputs
+    ), "number of input and output paths should match"
+
+    with contextlib.ExitStack() as stack:
+        inputs = [
+            stack.enter_context(open(input, "r", encoding="utf-8"))
+            if input != "-"
+            else sys.stdin
+            for input in args.inputs
+        ]
+        outputs = [
+            stack.enter_context(open(output, "w", encoding="utf-8"))
+            if output != "-"
+            else sys.stdout
+            for output in args.outputs
+        ]
+
+        encoder = MultiprocessingEncoder(args)
+        pool = Pool(args.workers, initializer=encoder.initializer)
+        encoded_lines = pool.imap(encoder.encode_lines, zip(*inputs), 100)
+
+        stats = Counter()
+        for i, (filt, enc_lines) in enumerate(encoded_lines, start=1):
+            if filt == "PASS":
+                for enc_line, output_h in zip(enc_lines, outputs):
+                    print(enc_line, file=output_h)
+            else:
+                stats["num_filtered_" + filt] += 1
+            if i % 10000 == 0:
+                print("processed {} lines".format(i), file=sys.stderr)
+
+        for k, v in stats.most_common():
+            print("[{}] filtered {} lines".format(k, v), file=sys.stderr)
+
+
+class MultiprocessingEncoder(object):
+    def __init__(self, args):
+        self.args = args
+
+    def initializer(self):
+        global bpe
+        bpe = get_encoder(self.args.encoder_json, self.args.vocab_bpe)
+
+    def encode(self, line):
+        global bpe
+        ids = bpe.encode(line)
+        return list(map(str, ids))
+
+    def decode(self, tokens):
+        global bpe
+        return bpe.decode(tokens)
+
+    def encode_lines(self, lines):
+        """
+        Encode a set of lines. All lines will be encoded together.
+        """
+        enc_lines = []
+        for line in lines:
+            line = line.strip()
+            if len(line) == 0 and not self.args.keep_empty:
+                return ["EMPTY", None]
+            tokens = self.encode(line)
+            enc_lines.append(" ".join(tokens))
+        return ["PASS", enc_lines]
+
+    def decode_lines(self, lines):
+        dec_lines = []
+        for line in lines:
+            tokens = map(int, line.strip().split())
+            dec_lines.append(self.decode(tokens))
+        return ["PASS", dec_lines]
+
+
+if __name__ == "__main__":
+    main()
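The encoder above is the same one invoked by the preprocessing scripts later in this diff. A minimal standalone sketch, assuming hypothetical `wiki.train.raw`/`wiki.train.bpe` file names and the encoder.json/vocab.bpe files named in the script's docstring:

```bash
wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json'
wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe'

python -m examples.roberta.multiprocessing_bpe_encoder \
    --encoder-json encoder.json \
    --vocab-bpe vocab.bpe \
    --inputs wiki.train.raw \
    --outputs wiki.train.bpe \
    --keep-empty \
    --workers 20
```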
fairseq/examples/roberta/preprocess_GLUE_tasks.sh
ADDED
@@ -0,0 +1,185 @@
+#!/bin/bash
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+# raw glue data as downloaded by glue download script (https://gist.github.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e)
+if [[ $# -ne 2 ]]; then
+  echo "Run as following:"
+  echo "./examples/roberta/preprocess_GLUE_tasks.sh <glud_data_folder> <task_name>"
+  exit 1
+fi
+
+GLUE_DATA_FOLDER=$1
+
+# download bpe encoder.json, vocabulary and fairseq dictionary
+wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json'
+wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe'
+wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt'
+
+TASKS=$2 # QQP
+
+if [ "$TASKS" = "ALL" ]
+then
+  TASKS="QQP MNLI QNLI MRPC RTE STS-B SST-2 CoLA"
+fi
+
+for TASK in $TASKS
+do
+  echo "Preprocessing $TASK"
+
+  TASK_DATA_FOLDER="$GLUE_DATA_FOLDER/$TASK"
+  echo "Raw data as downloaded from glue website: $TASK_DATA_FOLDER"
+
+  SPLITS="train dev test"
+  INPUT_COUNT=2
+  if [ "$TASK" = "QQP" ]
+  then
+    INPUT_COLUMNS=( 4 5 )
+    TEST_INPUT_COLUMNS=( 2 3 )
+    LABEL_COLUMN=6
+  elif [ "$TASK" = "MNLI" ]
+  then
+    SPLITS="train dev_matched dev_mismatched test_matched test_mismatched"
+    INPUT_COLUMNS=( 9 10 )
+    TEST_INPUT_COLUMNS=( 9 10 )
+    DEV_LABEL_COLUMN=16
+    LABEL_COLUMN=12
+  elif [ "$TASK" = "QNLI" ]
+  then
+    INPUT_COLUMNS=( 2 3 )
+    TEST_INPUT_COLUMNS=( 2 3 )
+    LABEL_COLUMN=4
+  elif [ "$TASK" = "MRPC" ]
+  then
+    INPUT_COLUMNS=( 4 5 )
+    TEST_INPUT_COLUMNS=( 4 5 )
+    LABEL_COLUMN=1
+  elif [ "$TASK" = "RTE" ]
+  then
+    INPUT_COLUMNS=( 2 3 )
+    TEST_INPUT_COLUMNS=( 2 3 )
+    LABEL_COLUMN=4
+  elif [ "$TASK" = "STS-B" ]
+  then
+    INPUT_COLUMNS=( 8 9 )
+    TEST_INPUT_COLUMNS=( 8 9 )
+    LABEL_COLUMN=10
+  # Following are single sentence tasks.
+  elif [ "$TASK" = "SST-2" ]
+  then
+    INPUT_COLUMNS=( 1 )
+    TEST_INPUT_COLUMNS=( 2 )
+    LABEL_COLUMN=2
+    INPUT_COUNT=1
+  elif [ "$TASK" = "CoLA" ]
+  then
+    INPUT_COLUMNS=( 4 )
+    TEST_INPUT_COLUMNS=( 2 )
+    LABEL_COLUMN=2
+    INPUT_COUNT=1
+  fi
+
+  # Strip out header and filter lines that don't have expected number of fields.
+  rm -rf "$TASK_DATA_FOLDER/processed"
+  mkdir -p "$TASK_DATA_FOLDER/processed"
+  for SPLIT in $SPLITS
+  do
+    # CoLA train and dev doesn't have header.
+    if [[ ( "$TASK" = "CoLA") && ( "$SPLIT" != "test" ) ]]
+    then
+      cp "$TASK_DATA_FOLDER/$SPLIT.tsv" "$TASK_DATA_FOLDER/processed/$SPLIT.tsv.temp";
+    else
+      tail -n +2 "$TASK_DATA_FOLDER/$SPLIT.tsv" > "$TASK_DATA_FOLDER/processed/$SPLIT.tsv.temp";
+    fi
+
+    # Remove unformatted lines from train and dev files for QQP dataset.
+    if [[ ( "$TASK" = "QQP") && ( "$SPLIT" != "test" ) ]]
+    then
+      awk -F '\t' -v NUM_FIELDS=6 'NF==NUM_FIELDS{print}{}' "$TASK_DATA_FOLDER/processed/$SPLIT.tsv.temp" > "$TASK_DATA_FOLDER/processed/$SPLIT.tsv";
+    else
+      cp "$TASK_DATA_FOLDER/processed/$SPLIT.tsv.temp" "$TASK_DATA_FOLDER/processed/$SPLIT.tsv";
+    fi
+    rm "$TASK_DATA_FOLDER/processed/$SPLIT.tsv.temp";
+  done
+
+  # Split into input0, input1 and label
+  for SPLIT in $SPLITS
+  do
+    for INPUT_TYPE in $(seq 0 $((INPUT_COUNT-1)))
+    do
+      if [[ "$SPLIT" != test* ]]
+      then
+        COLUMN_NUMBER=${INPUT_COLUMNS[$INPUT_TYPE]}
+      else
+        COLUMN_NUMBER=${TEST_INPUT_COLUMNS[$INPUT_TYPE]}
+      fi
+      cut -f"$COLUMN_NUMBER" "$TASK_DATA_FOLDER/processed/$SPLIT.tsv" > "$TASK_DATA_FOLDER/processed/$SPLIT.raw.input$INPUT_TYPE";
+    done
+
+    if [[ "$SPLIT" != test* ]]
+    then
+      if [ "$TASK" = "MNLI" ] && [ "$SPLIT" != "train" ]
+      then
+        cut -f"$DEV_LABEL_COLUMN" "$TASK_DATA_FOLDER/processed/$SPLIT.tsv" > "$TASK_DATA_FOLDER/processed/$SPLIT.label";
+      else
+        cut -f"$LABEL_COLUMN" "$TASK_DATA_FOLDER/processed/$SPLIT.tsv" > "$TASK_DATA_FOLDER/processed/$SPLIT.label";
+      fi
+    fi
+
+    # BPE encode.
+    for INPUT_TYPE in $(seq 0 $((INPUT_COUNT-1)))
+    do
+      LANG="input$INPUT_TYPE"
+      echo "BPE encoding $SPLIT/$LANG"
+      python -m examples.roberta.multiprocessing_bpe_encoder \
+        --encoder-json encoder.json \
+        --vocab-bpe vocab.bpe \
+        --inputs "$TASK_DATA_FOLDER/processed/$SPLIT.raw.$LANG" \
+        --outputs "$TASK_DATA_FOLDER/processed/$SPLIT.$LANG" \
+        --workers 60 \
+        --keep-empty;
+    done
+  done
+
+  # Remove output directory.
+  rm -rf "$TASK-bin"
+
+  DEVPREF="$TASK_DATA_FOLDER/processed/dev.LANG"
+  TESTPREF="$TASK_DATA_FOLDER/processed/test.LANG"
+  if [ "$TASK" = "MNLI" ]
+  then
+    DEVPREF="$TASK_DATA_FOLDER/processed/dev_matched.LANG,$TASK_DATA_FOLDER/processed/dev_mismatched.LANG"
+    TESTPREF="$TASK_DATA_FOLDER/processed/test_matched.LANG,$TASK_DATA_FOLDER/processed/test_mismatched.LANG"
+  fi
+
+  # Run fairseq preprocessing:
+  for INPUT_TYPE in $(seq 0 $((INPUT_COUNT-1)))
+  do
+    LANG="input$INPUT_TYPE"
+    fairseq-preprocess \
+      --only-source \
+      --trainpref "$TASK_DATA_FOLDER/processed/train.$LANG" \
+      --validpref "${DEVPREF//LANG/$LANG}" \
+      --testpref "${TESTPREF//LANG/$LANG}" \
+      --destdir "$TASK-bin/$LANG" \
+      --workers 60 \
+      --srcdict dict.txt;
+  done
+  if [[ "$TASK" != "STS-B" ]]
+  then
+    fairseq-preprocess \
+      --only-source \
+      --trainpref "$TASK_DATA_FOLDER/processed/train.label" \
+      --validpref "${DEVPREF//LANG/label}" \
+      --destdir "$TASK-bin/label" \
+      --workers 60;
+  else
+    # For STS-B output range is converted to be between: [0.0, 1.0]
+    mkdir -p "$TASK-bin/label"
+    awk '{print $1 / 5.0 }' "$TASK_DATA_FOLDER/processed/train.label" > "$TASK-bin/label/train.label"
+    awk '{print $1 / 5.0 }' "$TASK_DATA_FOLDER/processed/dev.label" > "$TASK-bin/label/valid.label"
+  fi
+done
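For reference, the script's own usage message implies an invocation like the following, where `glue_data` is whatever folder the GLUE download script produced and the second argument is any listed task name or `ALL`:

```bash
./examples/roberta/preprocess_GLUE_tasks.sh glue_data QQP
```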
fairseq/examples/roberta/preprocess_RACE.py
ADDED
@@ -0,0 +1,102 @@
+#!/usr/bin/env python
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+import json
+import os
+import re
+
+
+class InputExample:
+    def __init__(self, paragraph, qa_list, label):
+        self.paragraph = paragraph
+        self.qa_list = qa_list
+        self.label = label
+
+
+def get_examples(data_dir, set_type):
+    """
+    Extract paragraph and question-answer list from each json file
+    """
+    examples = []
+
+    levels = ["middle", "high"]
+    set_type_c = set_type.split("-")
+    if len(set_type_c) == 2:
+        levels = [set_type_c[1]]
+        set_type = set_type_c[0]
+    for level in levels:
+        cur_dir = os.path.join(data_dir, set_type, level)
+        for filename in os.listdir(cur_dir):
+            cur_path = os.path.join(cur_dir, filename)
+            with open(cur_path, "r") as f:
+                cur_data = json.load(f)
+                answers = cur_data["answers"]
+                options = cur_data["options"]
+                questions = cur_data["questions"]
+                context = cur_data["article"].replace("\n", " ")
+                context = re.sub(r"\s+", " ", context)
+                for i in range(len(answers)):
+                    label = ord(answers[i]) - ord("A")
+                    qa_list = []
+                    question = questions[i]
+                    for j in range(4):
+                        option = options[i][j]
+                        if "_" in question:
+                            qa_cat = question.replace("_", option)
+                        else:
+                            qa_cat = " ".join([question, option])
+                        qa_cat = re.sub(r"\s+", " ", qa_cat)
+                        qa_list.append(qa_cat)
+                    examples.append(InputExample(context, qa_list, label))
+
+    return examples
+
+
+def main():
+    """
+    Helper script to extract paragraphs questions and answers from RACE datasets.
+    """
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--input-dir",
+        help="input directory for downloaded RACE dataset",
+    )
+    parser.add_argument(
+        "--output-dir",
+        help="output directory for extracted data",
+    )
+    args = parser.parse_args()
+
+    if not os.path.exists(args.output_dir):
+        os.makedirs(args.output_dir, exist_ok=True)
+
+    for set_type in ["train", "dev", "test-middle", "test-high"]:
+        examples = get_examples(args.input_dir, set_type)
+        qa_file_paths = [
+            os.path.join(args.output_dir, set_type + ".input" + str(i + 1))
+            for i in range(4)
+        ]
+        qa_files = [open(qa_file_path, "w") for qa_file_path in qa_file_paths]
+        outf_context_path = os.path.join(args.output_dir, set_type + ".input0")
+        outf_label_path = os.path.join(args.output_dir, set_type + ".label")
+        outf_context = open(outf_context_path, "w")
+        outf_label = open(outf_label_path, "w")
+        for example in examples:
+            outf_context.write(example.paragraph + "\n")
+            for i in range(4):
+                qa_files[i].write(example.qa_list[i] + "\n")
+            outf_label.write(str(example.label) + "\n")
+
+        for f in qa_files:
+            f.close()
+        outf_label.close()
+        outf_context.close()
+
+
+if __name__ == "__main__":
+    main()
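A minimal sketch of running the extractor above, assuming the downloaded RACE data was unpacked into a hypothetical `RACE/` directory containing `train/`, `dev/` and `test/` subfolders with `middle/` and `high/` inside each, as the script expects:

```bash
# Hypothetical directory names; only --input-dir and --output-dir are required.
python ./examples/roberta/preprocess_RACE.py \
    --input-dir RACE \
    --output-dir race_raw
```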
fairseq/examples/roberta/preprocess_RACE.sh
ADDED
@@ -0,0 +1,59 @@
+#!/bin/bash
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+# data should be downloaded and processed with reprocess_RACE.py
+if [[ $# -ne 2 ]]; then
+  echo "Run as following:"
+  echo "./examples/roberta/preprocess_RACE.sh <race_data_folder> <output_folder>"
+  exit 1
+fi
+
+RACE_DATA_FOLDER=$1
+OUT_DATA_FOLDER=$2
+
+# download bpe encoder.json, vocabulary and fairseq dictionary
+wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json'
+wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe'
+wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt'
+
+SPLITS="train dev test-middle test-high"
+INPUT_TYPES="input0 input1 input2 input3 input4"
+for INPUT_TYPE in $INPUT_TYPES
+do
+  for SPLIT in $SPLITS
+  do
+    echo "BPE encoding $SPLIT/$INPUT_TYPE"
+    python -m examples.roberta.multiprocessing_bpe_encoder \
+      --encoder-json encoder.json \
+      --vocab-bpe vocab.bpe \
+      --inputs "$RACE_DATA_FOLDER/$SPLIT.$INPUT_TYPE" \
+      --outputs "$RACE_DATA_FOLDER/$SPLIT.$INPUT_TYPE.bpe" \
+      --workers 10 \
+      --keep-empty;
+
+  done
+done
+
+for INPUT_TYPE in $INPUT_TYPES
+do
+  LANG="input$INPUT_TYPE"
+  fairseq-preprocess \
+    --only-source \
+    --trainpref "$RACE_DATA_FOLDER/train.$INPUT_TYPE.bpe" \
+    --validpref "$RACE_DATA_FOLDER/dev.$INPUT_TYPE.bpe" \
+    --testpref "$RACE_DATA_FOLDER/test-middle.$INPUT_TYPE.bpe,$RACE_DATA_FOLDER/test-high.$INPUT_TYPE.bpe" \
+    --destdir "$OUT_DATA_FOLDER/$INPUT_TYPE" \
+    --workers 10 \
+    --srcdict dict.txt;
+done
+
+rm -rf "$OUT_DATA_FOLDER/label"
+mkdir -p "$OUT_DATA_FOLDER/label"
+cp "$RACE_DATA_FOLDER/train.label" "$OUT_DATA_FOLDER/label/"
+cp "$RACE_DATA_FOLDER/dev.label" "$OUT_DATA_FOLDER/label/valid.label"
+cp "$RACE_DATA_FOLDER/test-middle.label" "$OUT_DATA_FOLDER/label/test.label"
+cp "$RACE_DATA_FOLDER/test-high.label" "$OUT_DATA_FOLDER/label/test1.label"
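Per its usage message, the shell script above is run on the output of `preprocess_RACE.py`; the directory names below are hypothetical, with the first argument being the extracted data and the second the destination for the binarized output:

```bash
./examples/roberta/preprocess_RACE.sh race_raw RACE-bin
```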
fairseq/examples/roberta/wsc/README.md
ADDED
@@ -0,0 +1,125 @@
+# Finetuning RoBERTa on Winograd Schema Challenge (WSC) data
+
+The following instructions can be used to finetune RoBERTa on the WSC training
+data provided by [SuperGLUE](https://super.gluebenchmark.com/).
+
+Note that there is high variance in the results. For our GLUE/SuperGLUE
+submission we swept over the learning rate (1e-5, 2e-5, 3e-5), batch size (16,
+32, 64) and total number of updates (500, 1000, 2000, 3000), as well as the
+random seed. Out of ~100 runs we chose the best 7 models and ensembled them.
+
+**Approach:** The instructions below use a slightly different loss function than
+what's described in the original RoBERTa arXiv paper. In particular,
+[Kocijan et al. (2019)](https://arxiv.org/abs/1905.06290) introduce a margin
+ranking loss between `(query, candidate)` pairs with tunable hyperparameters
+alpha and beta. This is supported in our code as well with the `--wsc-alpha` and
+`--wsc-beta` arguments. However, we achieved slightly better (and more robust)
+results on the development set by instead using a single cross entropy loss term
+over the log-probabilities for the query and all mined candidates. **The
+candidates are mined using spaCy from each input sentence in isolation, so the
+approach remains strictly pointwise.** This reduces the number of
+hyperparameters and our best model achieved 92.3% development set accuracy,
+compared to ~90% accuracy for the margin loss. Later versions of the RoBERTa
+arXiv paper will describe this updated formulation.
+
+### 1) Download the WSC data from the SuperGLUE website:
+```bash
+wget https://dl.fbaipublicfiles.com/glue/superglue/data/v2/WSC.zip
+unzip WSC.zip
+
+# we also need to copy the RoBERTa dictionary into the same directory
+wget -O WSC/dict.txt https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt
+```
+
+### 2) Finetune over the provided training data:
+```bash
+TOTAL_NUM_UPDATES=2000  # Total number of training steps.
+WARMUP_UPDATES=250      # Linearly increase LR over this many steps.
+LR=2e-05                # Peak LR for polynomial LR scheduler.
+MAX_SENTENCES=16        # Batch size per GPU.
+SEED=1                  # Random seed.
+ROBERTA_PATH=/path/to/roberta/model.pt
+
+# we use the --user-dir option to load the task and criterion
+# from the examples/roberta/wsc directory:
+FAIRSEQ_PATH=/path/to/fairseq
+FAIRSEQ_USER_DIR=${FAIRSEQ_PATH}/examples/roberta/wsc
+
+CUDA_VISIBLE_DEVICES=0,1,2,3 fairseq-train WSC/ \
+    --restore-file $ROBERTA_PATH \
+    --reset-optimizer --reset-dataloader --reset-meters \
+    --no-epoch-checkpoints --no-last-checkpoints --no-save-optimizer-state \
+    --best-checkpoint-metric accuracy --maximize-best-checkpoint-metric \
+    --valid-subset val \
+    --fp16 --ddp-backend legacy_ddp \
+    --user-dir $FAIRSEQ_USER_DIR \
+    --task wsc --criterion wsc --wsc-cross-entropy \
+    --arch roberta_large --bpe gpt2 --max-positions 512 \
+    --dropout 0.1 --attention-dropout 0.1 --weight-decay 0.01 \
+    --optimizer adam --adam-betas '(0.9, 0.98)' --adam-eps 1e-06 \
+    --lr-scheduler polynomial_decay --lr $LR \
+    --warmup-updates $WARMUP_UPDATES --total-num-update $TOTAL_NUM_UPDATES \
+    --batch-size $MAX_SENTENCES \
+    --max-update $TOTAL_NUM_UPDATES \
+    --log-format simple --log-interval 100 \
+    --seed $SEED
+```
+
+The above command assumes training on 4 GPUs, but you can achieve the same
+results on a single GPU by adding `--update-freq=4`.
+
+### 3) Evaluate
+```python
+from fairseq.models.roberta import RobertaModel
+from examples.roberta.wsc import wsc_utils  # also loads WSC task and criterion
+roberta = RobertaModel.from_pretrained('checkpoints', 'checkpoint_best.pt', 'WSC/')
+roberta.cuda()
+nsamples, ncorrect = 0, 0
+for sentence, label in wsc_utils.jsonl_iterator('WSC/val.jsonl', eval=True):
+    pred = roberta.disambiguate_pronoun(sentence)
+    nsamples += 1
+    if pred == label:
+        ncorrect += 1
+print('Accuracy: ' + str(ncorrect / float(nsamples)))
+# Accuracy: 0.9230769230769231
+```
+
+## RoBERTa training on WinoGrande dataset
+We have also provided `winogrande` task and criterion for finetuning on the
+[WinoGrande](https://mosaic.allenai.org/projects/winogrande) like datasets
+where there are always two candidates and one is correct.
+It's more efficient implementation for such subcases.
+
+```bash
+TOTAL_NUM_UPDATES=23750 # Total number of training steps.
+WARMUP_UPDATES=2375     # Linearly increase LR over this many steps.
+LR=1e-05                # Peak LR for polynomial LR scheduler.
+MAX_SENTENCES=32        # Batch size per GPU.
+SEED=1                  # Random seed.
+ROBERTA_PATH=/path/to/roberta/model.pt
+
+# we use the --user-dir option to load the task and criterion
+# from the examples/roberta/wsc directory:
+FAIRSEQ_PATH=/path/to/fairseq
+FAIRSEQ_USER_DIR=${FAIRSEQ_PATH}/examples/roberta/wsc
+
+cd fairseq
+CUDA_VISIBLE_DEVICES=0 fairseq-train winogrande_1.0/ \
+    --restore-file $ROBERTA_PATH \
+    --reset-optimizer --reset-dataloader --reset-meters \
+    --no-epoch-checkpoints --no-last-checkpoints --no-save-optimizer-state \
+    --best-checkpoint-metric accuracy --maximize-best-checkpoint-metric \
+    --valid-subset val \
+    --fp16 --ddp-backend legacy_ddp \
+    --user-dir $FAIRSEQ_USER_DIR \
+    --task winogrande --criterion winogrande \
+    --wsc-margin-alpha 5.0 --wsc-margin-beta 0.4 \
+    --arch roberta_large --bpe gpt2 --max-positions 512 \
+    --dropout 0.1 --attention-dropout 0.1 --weight-decay 0.01 \
+    --optimizer adam --adam-betas '(0.9, 0.98)' --adam-eps 1e-06 \
+    --lr-scheduler polynomial_decay --lr $LR \
+    --warmup-updates $WARMUP_UPDATES --total-num-update $TOTAL_NUM_UPDATES \
+    --batch-size $MAX_SENTENCES \
+    --max-update $TOTAL_NUM_UPDATES \
+    --log-format simple --log-interval 100
+```
fairseq/examples/roberta/wsc/__init__.py
ADDED
@@ -0,0 +1,7 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from . import wsc_criterion  # noqa
+from . import wsc_task  # noqa
fairseq/examples/roberta/wsc/wsc_criterion.py
ADDED
@@ -0,0 +1,167 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+
+import torch
+import torch.nn.functional as F
+from fairseq import utils
+from fairseq.criterions import LegacyFairseqCriterion, register_criterion
+from fairseq.data import encoders
+
+
+@register_criterion("wsc")
+class WSCCriterion(LegacyFairseqCriterion):
+    def __init__(self, args, task):
+        super().__init__(args, task)
+        if self.args.save_predictions is not None:
+            self.prediction_h = open(self.args.save_predictions, "w")
+        else:
+            self.prediction_h = None
+        self.bpe = encoders.build_bpe(args.bpe)
+        self.tokenizer = encoders.build_tokenizer(args.tokenizer)
+
+    def __del__(self):
+        if self.prediction_h is not None:
+            self.prediction_h.close()
+
+    @staticmethod
+    def add_args(parser):
+        """Add criterion-specific arguments to the parser."""
+        parser.add_argument("--wsc-margin-alpha", type=float, metavar="A", default=1.0)
+        parser.add_argument("--wsc-margin-beta", type=float, metavar="B", default=0.0)
+        parser.add_argument(
+            "--wsc-cross-entropy",
+            action="store_true",
+            help="use cross entropy formulation instead of margin loss",
+        )
+        parser.add_argument(
+            "--save-predictions", metavar="FILE", help="file to save predictions to"
+        )
+
+    def get_masked_input(self, tokens, mask):
+        masked_tokens = tokens.clone()
+        masked_tokens[mask] = self.task.mask
+        return masked_tokens
+
+    def get_lprobs(self, model, tokens, mask):
+        logits, _ = model(src_tokens=self.get_masked_input(tokens, mask))
+        lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float)
+        scores = lprobs.gather(2, tokens.unsqueeze(-1)).squeeze(-1)
+        mask = mask.type_as(scores)
+        scores = (scores * mask).sum(dim=-1) / mask.sum(dim=-1)
+        return scores
+
+    def get_loss(self, query_lprobs, cand_lprobs):
+        if self.args.wsc_cross_entropy:
+            return F.cross_entropy(
+                torch.cat([query_lprobs, cand_lprobs]).unsqueeze(0),
+                query_lprobs.new([0]).long(),
+            )
+        else:
+            return (
+                -query_lprobs
+                + self.args.wsc_margin_alpha
+                * (cand_lprobs - query_lprobs + self.args.wsc_margin_beta).clamp(min=0)
+            ).sum()
+
+    def forward(self, model, sample, reduce=True):
+        # compute loss and accuracy
+        loss, nloss = 0.0, 0
+        ncorrect, nqueries = 0, 0
+
+        for i, label in enumerate(sample["labels"]):
+            query_lprobs = self.get_lprobs(
+                model,
+                sample["query_tokens"][i].unsqueeze(0),
+                sample["query_masks"][i].unsqueeze(0),
+            )
+            cand_lprobs = self.get_lprobs(
+                model,
+                sample["candidate_tokens"][i],
+                sample["candidate_masks"][i],
+            )
+
+            pred = (query_lprobs >= cand_lprobs).all().item()
+
+            if label is not None:
+                label = 1 if label else 0
+                ncorrect += 1 if pred == label else 0
+                nqueries += 1
+
+                if label:
+                    # only compute a loss for positive instances
+                    nloss += 1
+                    loss += self.get_loss(query_lprobs, cand_lprobs)
+
+            id = sample["id"][i].item()
+            if self.prediction_h is not None:
+                print("{}\t{}\t{}".format(id, pred, label), file=self.prediction_h)
+
+        if nloss == 0:
+            loss = torch.tensor(0.0, requires_grad=True)
+
+        sample_size = nqueries if nqueries > 0 else 1
+        logging_output = {
+            "loss": utils.item(loss.data) if reduce else loss.data,
+            "ntokens": sample["ntokens"],
+            "nsentences": sample["nsentences"],
+            "sample_size": sample_size,
+            "ncorrect": ncorrect,
+            "nqueries": nqueries,
+        }
+        return loss, sample_size, logging_output
+
+    @staticmethod
+    def aggregate_logging_outputs(logging_outputs):
+        """Aggregate logging outputs from data parallel training."""
+        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
+        ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
+        nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
+        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
+
+        agg_output = {
+            "loss": loss_sum / sample_size / math.log(2),
+            "ntokens": ntokens,
+            "nsentences": nsentences,
+            "sample_size": sample_size,
+        }
+
+        ncorrect = sum(log.get("ncorrect", 0) for log in logging_outputs)
+        nqueries = sum(log.get("nqueries", 0) for log in logging_outputs)
+        if nqueries > 0:
+            agg_output["accuracy"] = ncorrect / float(nqueries)
+
+        return agg_output
+
+
+@register_criterion("winogrande")
+class WinograndeCriterion(WSCCriterion):
+    def forward(self, model, sample, reduce=True):
+        # compute loss and accuracy
+        query_lprobs = self.get_lprobs(
+            model,
+            sample["query_tokens"],
+            sample["query_masks"],
+        )
+        cand_lprobs = self.get_lprobs(
+            model,
+            sample["candidate_tokens"],
+            sample["candidate_masks"],
+        )
+        pred = query_lprobs >= cand_lprobs
+        loss = self.get_loss(query_lprobs, cand_lprobs)
+
+        sample_size = sample["query_tokens"].size(0)
+        ncorrect = pred.sum().item()
+        logging_output = {
+            "loss": utils.item(loss.data) if reduce else loss.data,
+            "ntokens": sample["ntokens"],
+            "nsentences": sample["nsentences"],
+            "sample_size": sample_size,
+            "ncorrect": ncorrect,
+            "nqueries": sample_size,
+        }
+        return loss, sample_size, logging_output
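As a reading of `get_loss` above: when `--wsc-cross-entropy` is not set, the criterion reduces to the margin formulation referenced in the WSC README, summed over the mined candidates $c$ for a query $q$, with $\alpha$ = `--wsc-margin-alpha` and $\beta$ = `--wsc-margin-beta`:

$$
\mathcal{L}(q) = \sum_{c} \Bigl[ -\log p(q) + \alpha \,\max\bigl(0,\; \log p(c) - \log p(q) + \beta\bigr) \Bigr]
$$

where $\log p(\cdot)$ is the length-normalized masked-LM log-probability computed by `get_lprobs`.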
fairseq/examples/roberta/wsc/wsc_task.py
ADDED
@@ -0,0 +1,401 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import json
+import os
+import tempfile
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from fairseq import utils
+from fairseq.data import (
+    Dictionary,
+    IdDataset,
+    ListDataset,
+    NestedDictionaryDataset,
+    NumelDataset,
+    NumSamplesDataset,
+    PadDataset,
+    SortDataset,
+    data_utils,
+    encoders,
+)
+from fairseq.tasks import LegacyFairseqTask, register_task
+
+from . import wsc_utils
+
+
+@register_task("wsc")
+class WSCTask(LegacyFairseqTask):
+    """Task to finetune RoBERTa for Winograd Schemas."""
+
+    @staticmethod
+    def add_args(parser):
+        """Add task-specific arguments to the parser."""
+        parser.add_argument(
+            "data", metavar="DIR", help="path to data directory; we load <split>.jsonl"
+        )
+        parser.add_argument(
+            "--init-token",
+            type=int,
+            default=None,
+            help="add token at the beginning of each batch item",
+        )
+
+    def __init__(self, args, vocab):
+        super().__init__(args)
+        self.vocab = vocab
+        self.mask = vocab.add_symbol("<mask>")
+
+        self.bpe = encoders.build_bpe(args)
+        self.tokenizer = encoders.build_tokenizer(args)
+
+        # hack to handle GPT-2 BPE, which includes leading spaces
+        if args.bpe == "gpt2":
+            self.leading_space = True
+            self.trailing_space = False
+        else:
+            self.leading_space = False
+            self.trailing_space = True
+
+    @classmethod
+    def load_dictionary(cls, filename):
+        """Load the dictionary from the filename
+
+        Args:
+            filename (str): the filename
+        """
+        dictionary = Dictionary.load(filename)
+        dictionary.add_symbol("<mask>")
+        return dictionary
+
+    @classmethod
+    def setup_task(cls, args, **kwargs):
+        assert args.criterion == "wsc", "Must set --criterion=wsc"
+
+        # load data and label dictionaries
+        vocab = cls.load_dictionary(os.path.join(args.data, "dict.txt"))
+        print("| dictionary: {} types".format(len(vocab)))
+
+        return cls(args, vocab)
+
+    def binarize(self, s: str, append_eos: bool = False):
+        if self.tokenizer is not None:
+            s = self.tokenizer.encode(s)
+        if self.bpe is not None:
+            s = self.bpe.encode(s)
+        tokens = self.vocab.encode_line(
+            s,
+            append_eos=append_eos,
+            add_if_not_exist=False,
+        ).long()
+        if self.args.init_token is not None:
+            tokens = torch.cat([tokens.new([self.args.init_token]), tokens])
+        return tokens
+
+    def binarize_with_mask(self, txt, prefix, suffix, leading_space, trailing_space):
+        toks = self.binarize(
+            prefix + leading_space + txt + trailing_space + suffix,
+            append_eos=True,
+        )
+        mask = torch.zeros_like(toks, dtype=torch.bool)
+        mask_start = len(self.binarize(prefix))
+        mask_size = len(self.binarize(leading_space + txt))
+        mask[mask_start : mask_start + mask_size] = 1
+        return toks, mask
+
+    def load_dataset(
+        self, split, epoch=1, combine=False, data_path=None, return_only=False, **kwargs
+    ):
+        """Load a given dataset split.
+
+        Args:
+            split (str): name of the split (e.g., train, valid, test)
+        """
+        if data_path is None:
+            data_path = os.path.join(self.args.data, split + ".jsonl")
+        if not os.path.exists(data_path):
+            raise FileNotFoundError("Cannot find data: {}".format(data_path))
+
+        query_tokens = []
+        query_masks = []
+        query_lengths = []
+        candidate_tokens = []
+        candidate_masks = []
+        candidate_lengths = []
+        labels = []
+
+        for sentence, pronoun_span, query, label in wsc_utils.jsonl_iterator(data_path):
+            prefix = sentence[: pronoun_span.start].text
+            suffix = sentence[pronoun_span.end :].text_with_ws
+
+            # spaCy spans include trailing spaces, but we need to know about
+            # leading spaces for the GPT-2 BPE
+            leading_space = (
+                " " if sentence[: pronoun_span.start].text_with_ws.endswith(" ") else ""
+            )
+            trailing_space = " " if pronoun_span.text_with_ws.endswith(" ") else ""
+
+            # get noun phrases, excluding pronouns and anything overlapping with the query
+            cand_spans = wsc_utils.filter_noun_chunks(
+                wsc_utils.extended_noun_chunks(sentence),
+                exclude_pronouns=True,
+                exclude_query=query,
+                exact_match=False,
+            )
+
+            if query is not None:
+                query_toks, query_mask = self.binarize_with_mask(
+                    query, prefix, suffix, leading_space, trailing_space
+                )
+                query_len = len(query_toks)
+            else:
+                query_toks, query_mask, query_len = None, None, 0
+
+            query_tokens.append(query_toks)
+            query_masks.append(query_mask)
+            query_lengths.append(query_len)
+
+            cand_toks, cand_masks = [], []
+            for cand_span in cand_spans:
+                toks, mask = self.binarize_with_mask(
+                    cand_span.text,
+                    prefix,
+                    suffix,
+                    leading_space,
+                    trailing_space,
+                )
+                cand_toks.append(toks)
+                cand_masks.append(mask)
+
+            # collate candidates
+            cand_toks = data_utils.collate_tokens(cand_toks, pad_idx=self.vocab.pad())
+            cand_masks = data_utils.collate_tokens(cand_masks, pad_idx=0)
+            assert cand_toks.size() == cand_masks.size()
+
+            candidate_tokens.append(cand_toks)
+            candidate_masks.append(cand_masks)
+            candidate_lengths.append(cand_toks.size(1))
+
+            labels.append(label)
+
+        query_lengths = np.array(query_lengths)
+        query_tokens = ListDataset(query_tokens, query_lengths)
+        query_masks = ListDataset(query_masks, query_lengths)
+
+        candidate_lengths = np.array(candidate_lengths)
+        candidate_tokens = ListDataset(candidate_tokens, candidate_lengths)
+        candidate_masks = ListDataset(candidate_masks, candidate_lengths)
+
+        labels = ListDataset(labels, [1] * len(labels))
+
+        dataset = {
+            "id": IdDataset(),
+            "query_tokens": query_tokens,
+            "query_masks": query_masks,
+            "candidate_tokens": candidate_tokens,
+            "candidate_masks": candidate_masks,
+            "labels": labels,
+            "nsentences": NumSamplesDataset(),
+            "ntokens": NumelDataset(query_tokens, reduce=True),
+        }
+
+        nested_dataset = NestedDictionaryDataset(
+            dataset,
+            sizes=[query_lengths],
+        )
+
+        with data_utils.numpy_seed(self.args.seed):
+            shuffle = np.random.permutation(len(query_tokens))
+        dataset = SortDataset(
+            nested_dataset,
+            # shuffle
+            sort_order=[shuffle],
+        )
+
+        if return_only:
+            return dataset
+
+        self.datasets[split] = dataset
+        return self.datasets[split]
+
+    def build_dataset_for_inference(self, sample_json):
+        with tempfile.NamedTemporaryFile(buffering=0) as h:
+            h.write((json.dumps(sample_json) + "\n").encode("utf-8"))
+            dataset = self.load_dataset(
+                "disambiguate_pronoun",
+                data_path=h.name,
+                return_only=True,
+            )
+        return dataset
+
+    def disambiguate_pronoun(self, model, sentence, use_cuda=False):
+        sample_json = wsc_utils.convert_sentence_to_json(sentence)
+        dataset = self.build_dataset_for_inference(sample_json)
+        sample = dataset.collater([dataset[0]])
+        if use_cuda:
+            sample = utils.move_to_cuda(sample)
+
+        def get_masked_input(tokens, mask):
+            masked_tokens = tokens.clone()
+            masked_tokens[mask.bool()] = self.mask
+            return masked_tokens
+
+        def get_lprobs(tokens, mask):
+            logits, _ = model(src_tokens=get_masked_input(tokens, mask))
+            lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float)
+            scores = lprobs.gather(2, tokens.unsqueeze(-1)).squeeze(-1)
+            mask = mask.type_as(scores)
+            scores = (scores * mask).sum(dim=-1) / mask.sum(dim=-1)
+            return scores
+
+        cand_lprobs = get_lprobs(
+            sample["candidate_tokens"][0],
+            sample["candidate_masks"][0],
+        )
+        if sample["query_tokens"][0] is not None:
+            query_lprobs = get_lprobs(
+                sample["query_tokens"][0].unsqueeze(0),
+                sample["query_masks"][0].unsqueeze(0),
+            )
+            return (query_lprobs >= cand_lprobs).all().item() == 1
+        else:
+            best_idx = cand_lprobs.argmax().item()
+            full_cand = sample["candidate_tokens"][0][best_idx]
+            mask = sample["candidate_masks"][0][best_idx]
+            toks = full_cand[mask.bool()]
+            return self.bpe.decode(self.source_dictionary.string(toks)).strip()
+
+    @property
+    def source_dictionary(self):
+        return self.vocab
+
+    @property
+    def target_dictionary(self):
+        return self.vocab
+
+
+@register_task("winogrande")
+class WinograndeTask(WSCTask):
+    """
+    Task for WinoGrande dataset. Efficient implementation for Winograd schema
+    tasks with exactly two candidates, one of which is correct.
+    """
+
+    @classmethod
+    def setup_task(cls, args, **kwargs):
+        assert args.criterion == "winogrande", "Must set --criterion=winogrande"
+
+        # load data and label dictionaries
+        vocab = cls.load_dictionary(os.path.join(args.data, "dict.txt"))
+        print("| dictionary: {} types".format(len(vocab)))
+
+        return cls(args, vocab)
+
+    def load_dataset(
+        self, split, epoch=1, combine=False, data_path=None, return_only=False, **kwargs
+    ):
+        """Load a given dataset split.
+
+        Args:
+            split (str): name of the split (e.g., train, valid, test)
+        """
+        if data_path is None:
+            data_path = os.path.join(self.args.data, split + ".jsonl")
+        if not os.path.exists(data_path):
+            raise FileNotFoundError("Cannot find data: {}".format(data_path))
+
+        query_tokens = []
+        query_masks = []
+        query_lengths = []
+        candidate_tokens = []
+        candidate_masks = []
+        candidate_lengths = []
+
+        itr = wsc_utils.winogrande_jsonl_iterator(data_path, eval=(split == "test"))
+
+        for sample in itr:
+            sentence, pronoun_span, query, cand_text = sample
+            prefix = sentence[: pronoun_span[0]].rstrip()
+            suffix = sentence[pronoun_span[1] :]
+
+            leading_space = " " if sentence[: pronoun_span[0]].endswith(" ") else ""
+            trailing_space = ""
+
+            if query is not None:
+                query_toks, query_mask = self.binarize_with_mask(
+                    query,
+                    prefix,
+                    suffix,
+                    leading_space,
+                    trailing_space,
+                )
+                query_len = len(query_toks)
+            else:
+                query_toks, query_mask, query_len = None, None, 0
+
+            query_tokens.append(query_toks)
+            query_masks.append(query_mask)
+            query_lengths.append(query_len)
+
+            cand_toks, cand_mask = self.binarize_with_mask(
+                cand_text,
+                prefix,
+                suffix,
+                leading_space,
+                trailing_space,
+            )
+
+            candidate_tokens.append(cand_toks)
+            candidate_masks.append(cand_mask)
+            candidate_lengths.append(cand_toks.size(0))
+
+        query_lengths = np.array(query_lengths)
+
+        def get_pad_dataset_fn(tokens, length, pad_idx):
+            return PadDataset(
+                ListDataset(tokens, length),
+                pad_idx=pad_idx,
+                left_pad=False,
+            )
+
+        query_tokens = get_pad_dataset_fn(query_tokens, query_lengths, self.vocab.pad())
+        query_masks = get_pad_dataset_fn(query_masks, query_lengths, 0)
+
+        candidate_lengths = np.array(candidate_lengths)
+        candidate_tokens = get_pad_dataset_fn(
+            candidate_tokens, candidate_lengths, self.vocab.pad()
+        )
+        candidate_masks = get_pad_dataset_fn(candidate_masks, candidate_lengths, 0)
+
+        dataset = {
+            "id": IdDataset(),
+            "query_tokens": query_tokens,
+            "query_masks": query_masks,
+            "candidate_tokens": candidate_tokens,
+            "candidate_masks": candidate_masks,
+            "nsentences": NumSamplesDataset(),
+            "ntokens": NumelDataset(query_tokens, reduce=True),
+        }
+
+        nested_dataset = NestedDictionaryDataset(
+            dataset,
+            sizes=[query_lengths],
+        )
+
+        with data_utils.numpy_seed(self.args.seed):
+            shuffle = np.random.permutation(len(query_tokens))
+        dataset = SortDataset(
+            nested_dataset,
+            # shuffle
+            sort_order=[shuffle],
+        )
+
+        if return_only:
+            return dataset
+
+        self.datasets[split] = dataset
+        return self.datasets[split]
fairseq/examples/roberta/wsc/wsc_utils.py
ADDED
@@ -0,0 +1,241 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import json
from functools import lru_cache


def convert_sentence_to_json(sentence):
    if "_" in sentence:
        prefix, rest = sentence.split("_", 1)
        query, rest = rest.split("_", 1)
        query_index = len(prefix.rstrip().split(" "))
    else:
        query, query_index = None, None

    prefix, rest = sentence.split("[", 1)
    pronoun, rest = rest.split("]", 1)
    pronoun_index = len(prefix.rstrip().split(" "))

    sentence = sentence.replace("_", "").replace("[", "").replace("]", "")

    return {
        "idx": 0,
        "text": sentence,
        "target": {
            "span1_index": query_index,
            "span1_text": query,
            "span2_index": pronoun_index,
            "span2_text": pronoun,
        },
    }


def extended_noun_chunks(sentence):
    noun_chunks = {(np.start, np.end) for np in sentence.noun_chunks}
    np_start, cur_np = 0, "NONE"
    for i, token in enumerate(sentence):
        np_type = token.pos_ if token.pos_ in {"NOUN", "PROPN"} else "NONE"
        if np_type != cur_np:
            if cur_np != "NONE":
                noun_chunks.add((np_start, i))
            if np_type != "NONE":
                np_start = i
            cur_np = np_type
    if cur_np != "NONE":
        noun_chunks.add((np_start, len(sentence)))
    return [sentence[s:e] for (s, e) in sorted(noun_chunks)]


def find_token(sentence, start_pos):
    found_tok = None
    for tok in sentence:
        if tok.idx == start_pos:
            found_tok = tok
            break
    return found_tok


def find_span(sentence, search_text, start=0):
    search_text = search_text.lower()
    for tok in sentence[start:]:
        remainder = sentence[tok.i :].text.lower()
        if remainder.startswith(search_text):
            len_to_consume = len(search_text)
            start_idx = tok.idx
            for next_tok in sentence[tok.i :]:
                end_idx = next_tok.idx + len(next_tok.text)
                if end_idx - start_idx == len_to_consume:
                    span = sentence[tok.i : next_tok.i + 1]
                    return span
    return None


@lru_cache(maxsize=1)
def get_detokenizer():
    from sacremoses import MosesDetokenizer

    detok = MosesDetokenizer(lang="en")
    return detok


@lru_cache(maxsize=1)
def get_spacy_nlp():
    import en_core_web_lg

    nlp = en_core_web_lg.load()
    return nlp


def jsonl_iterator(input_fname, positive_only=False, ngram_order=3, eval=False):
    detok = get_detokenizer()
    nlp = get_spacy_nlp()

    with open(input_fname) as fin:
        for line in fin:
            sample = json.loads(line.strip())

            if positive_only and "label" in sample and not sample["label"]:
                # only consider examples where the query is correct
                continue

            target = sample["target"]

            # clean up the query
            query = target["span1_text"]
            if query is not None:
                if "\n" in query:
                    continue
                if query.endswith(".") or query.endswith(","):
                    query = query[:-1]

            # split tokens
            tokens = sample["text"].split(" ")

            def strip_pronoun(x):
                return x.rstrip('.,"')

            # find the pronoun
            pronoun_idx = target["span2_index"]
            pronoun = strip_pronoun(target["span2_text"])
            if strip_pronoun(tokens[pronoun_idx]) != pronoun:
                # hack: sometimes the index is misaligned
                if strip_pronoun(tokens[pronoun_idx + 1]) == pronoun:
                    pronoun_idx += 1
                else:
                    raise Exception("Misaligned pronoun!")
            assert strip_pronoun(tokens[pronoun_idx]) == pronoun

            # split tokens before and after the pronoun
            before = tokens[:pronoun_idx]
            after = tokens[pronoun_idx + 1 :]

            # the GPT BPE attaches leading spaces to tokens, so we keep track
            # of whether we need spaces before or after the pronoun
            leading_space = " " if pronoun_idx > 0 else ""
            trailing_space = " " if len(after) > 0 else ""

            # detokenize
            before = detok.detokenize(before, return_str=True)
            pronoun = detok.detokenize([pronoun], return_str=True)
            after = detok.detokenize(after, return_str=True)

            # hack: when the pronoun ends in a period (or comma), move the
            # punctuation to the "after" part
            if pronoun.endswith(".") or pronoun.endswith(","):
                after = pronoun[-1] + trailing_space + after
                pronoun = pronoun[:-1]

            # hack: when the "after" part begins with a comma or period, remove
            # the trailing space
            if after.startswith(".") or after.startswith(","):
                trailing_space = ""

            # parse sentence with spacy
            sentence = nlp(before + leading_space + pronoun + trailing_space + after)

            # find pronoun span
            start = len(before + leading_space)
            first_pronoun_tok = find_token(sentence, start_pos=start)
            pronoun_span = find_span(sentence, pronoun, start=first_pronoun_tok.i)
            assert pronoun_span.text == pronoun

            if eval:
                # convert to format where pronoun is surrounded by "[]" and
                # query is surrounded by "_"
                query_span = find_span(sentence, query)
                query_with_ws = "_{}_{}".format(
                    query_span.text,
                    (" " if query_span.text_with_ws.endswith(" ") else ""),
                )
                pronoun_with_ws = "[{}]{}".format(
                    pronoun_span.text,
                    (" " if pronoun_span.text_with_ws.endswith(" ") else ""),
                )
                if query_span.start < pronoun_span.start:
                    first = (query_span, query_with_ws)
                    second = (pronoun_span, pronoun_with_ws)
                else:
                    first = (pronoun_span, pronoun_with_ws)
                    second = (query_span, query_with_ws)
                sentence = (
                    sentence[: first[0].start].text_with_ws
                    + first[1]
                    + sentence[first[0].end : second[0].start].text_with_ws
                    + second[1]
                    + sentence[second[0].end :].text
                )
                yield sentence, sample.get("label", None)
            else:
                yield sentence, pronoun_span, query, sample.get("label", None)


def winogrande_jsonl_iterator(input_fname, eval=False):
    with open(input_fname) as fin:
        for line in fin:
            sample = json.loads(line.strip())
            sentence, option1, option2 = (
                sample["sentence"],
                sample["option1"],
                sample["option2"],
            )

            pronoun_span = (sentence.index("_"), sentence.index("_") + 1)

            if eval:
                query, cand = option1, option2
            else:
                query = option1 if sample["answer"] == "1" else option2
                cand = option2 if sample["answer"] == "1" else option1
            yield sentence, pronoun_span, query, cand


def filter_noun_chunks(
    chunks, exclude_pronouns=False, exclude_query=None, exact_match=False
):
    if exclude_pronouns:
        chunks = [
            np
            for np in chunks
            if (np.lemma_ != "-PRON-" and not all(tok.pos_ == "PRON" for tok in np))
        ]

    if exclude_query is not None:
        excl_txt = [exclude_query.lower()]
        filtered_chunks = []
        for chunk in chunks:
            lower_chunk = chunk.text.lower()
            found = False
            for excl in excl_txt:
                if (
                    not exact_match and (lower_chunk in excl or excl in lower_chunk)
                ) or lower_chunk == excl:
                    found = True
                    break
            if not found:
                filtered_chunks.append(chunk)
        chunks = filtered_chunks

    return chunks
fairseq/examples/rxf/README.md
ADDED
@@ -0,0 +1,52 @@
[Better Fine-Tuning by Reducing Representational Collapse](https://arxiv.org/abs/2008.03156)
=====================
This repo contains the code to replicate all experiments from the _Better Fine-Tuning by Reducing Representational Collapse_ paper, excluding the probing results.

The R3F sentence prediction criterion is registered as `sentence_prediction_r3f`, while the label-smoothed version is implemented as `label_smoothed_cross_entropy_r3f`. The R4F version of the sentence prediction criterion can be obtained by applying spectral norm to the classification head via the `--spectral-norm-classification-head` parameter.

## Hyper-parameters
Our methods introduce three new hyper-parameters: `--eps`, which sets the standard deviation or range of the distribution we sample noise from; `--r3f-lambda`, which controls how the logistic loss and the noisy KL loss are combined; and `--noise-type`, which selects the parametric noise distribution (`normal` or `uniform`).

For example, to run R3F on RTE from GLUE:

```
TOTAL_NUM_UPDATES=3120
WARMUP_UPDATES=187
LR=1e-05
NUM_CLASSES=2
MAX_SENTENCES=8        # Batch size.
ROBERTA_PATH=/path/to/roberta/model.pt

CUDA_VISIBLE_DEVICES=0 fairseq-train RTE-bin \
    --restore-file $ROBERTA_PATH \
    --max-positions 512 \
    --max-sentences $MAX_SENTENCES \
    --max-tokens 4400 \
    --task sentence_prediction \
    --reset-optimizer --reset-dataloader --reset-meters \
    --required-batch-size-multiple 1 \
    --init-token 0 --separator-token 2 \
    --arch roberta_large \
    --criterion sentence_prediction_r3f \
    --num-classes $NUM_CLASSES \
    --dropout 0.1 --attention-dropout 0.1 \
    --weight-decay 0.1 --optimizer adam --adam-betas "(0.9, 0.98)" --adam-eps 1e-06 \
    --clip-norm 0.0 \
    --lr-scheduler polynomial_decay --lr $LR --total-num-update $TOTAL_NUM_UPDATES --warmup-updates $WARMUP_UPDATES \
    --fp16 --fp16-init-scale 4 --threshold-loss-scale 1 --fp16-scale-window 128 \
    --max-epoch 10 \
    --find-unused-parameters \
    --best-checkpoint-metric accuracy --maximize-best-checkpoint-metric \
    --noise-type uniform --r3f-lambda 0.7 \
    --user-dir examples/rxf/rxf_src
```
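
The R4F variant mentioned above differs only by spectral normalization of the classification head. A sketch of the same RTE run as R4F (every flag except the added `--spectral-norm-classification-head` is identical to the R3F command above):

```bash
# R4F sketch: same as the R3F command above, plus spectral norm on the
# classification head via --spectral-norm-classification-head.
CUDA_VISIBLE_DEVICES=0 fairseq-train RTE-bin \
    --restore-file $ROBERTA_PATH \
    --max-positions 512 \
    --max-sentences $MAX_SENTENCES \
    --max-tokens 4400 \
    --task sentence_prediction \
    --reset-optimizer --reset-dataloader --reset-meters \
    --required-batch-size-multiple 1 \
    --init-token 0 --separator-token 2 \
    --arch roberta_large \
    --criterion sentence_prediction_r3f \
    --num-classes $NUM_CLASSES \
    --dropout 0.1 --attention-dropout 0.1 \
    --weight-decay 0.1 --optimizer adam --adam-betas "(0.9, 0.98)" --adam-eps 1e-06 \
    --clip-norm 0.0 \
    --lr-scheduler polynomial_decay --lr $LR --total-num-update $TOTAL_NUM_UPDATES --warmup-updates $WARMUP_UPDATES \
    --fp16 --fp16-init-scale 4 --threshold-loss-scale 1 --fp16-scale-window 128 \
    --max-epoch 10 \
    --find-unused-parameters \
    --best-checkpoint-metric accuracy --maximize-best-checkpoint-metric \
    --noise-type uniform --r3f-lambda 0.7 \
    --spectral-norm-classification-head \
    --user-dir examples/rxf/rxf_src
```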

## Citation
```bibtex
@article{aghajanyan2020better,
  title={Better Fine-Tuning by Reducing Representational Collapse},
  author={Aghajanyan, Armen and Shrivastava, Akshat and Gupta, Anchit and Goyal, Naman and Zettlemoyer, Luke and Gupta, Sonal},
  journal={arXiv preprint arXiv:2008.03156},
  year={2020}
}
```
fairseq/examples/rxf/__init__.py
ADDED
@@ -0,0 +1,6 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from . import rxf_src  # noqa
fairseq/examples/rxf/rxf_src/__init__.py
ADDED
@@ -0,0 +1,6 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from . import label_smoothed_cross_entropy_r3f, sentence_prediction_r3f  # noqa
fairseq/examples/rxf/rxf_src/label_smoothed_cross_entropy_r3f.py
ADDED
@@ -0,0 +1,158 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math

import torch
import torch.nn.functional as F
from fairseq import utils
from fairseq.logging import metrics
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.criterions.label_smoothed_cross_entropy import label_smoothed_nll_loss


@register_criterion("label_smoothed_cross_entropy_r3f")
class LabelSmoothedCrossEntropyR3FCriterion(FairseqCriterion):
    def __init__(
        self, task, sentence_avg, label_smoothing, eps, r3f_lambda, noise_type
    ):
        super().__init__(task)
        self.sentence_avg = sentence_avg
        self.label_smoothing = label_smoothing
        self.eps = eps
        self.r3f_lambda = r3f_lambda
        self.noise_type = noise_type
        if self.noise_type in {"normal"}:
            self.noise_sampler = torch.distributions.normal.Normal(
                loc=0.0, scale=self.eps
            )
        elif self.noise_type == "uniform":
            self.noise_sampler = torch.distributions.uniform.Uniform(
                low=-self.eps, high=self.eps
            )
        else:
            raise Exception(f"unrecognized noise type {self.noise_type}")

    @staticmethod
    def add_args(parser):
        """Add criterion-specific arguments to the parser."""
        # fmt: off
        parser.add_argument('--label-smoothing', default=0., type=float, metavar='D',
                            help='epsilon for label smoothing, 0 means no label smoothing')
        parser.add_argument('--eps', type=float, default=1e-5,
                            help='noise eps')
        parser.add_argument('--r3f-lambda', type=float, default=1.0,
                            help='lambda for combining logistic loss and noisy KL loss')
        parser.add_argument('--noise-type', type=str, default='normal',
                            choices=['normal', 'uniform'],
                            help='type of noises')
        # fmt: on

    def _get_symm_kl(self, noised_logits, input_logits):
        return (
            F.kl_div(
                F.log_softmax(noised_logits, dim=-1, dtype=torch.float32),
                F.softmax(input_logits, dim=-1, dtype=torch.float32),
                None,
                None,
                "sum",
            )
            + F.kl_div(
                F.log_softmax(input_logits, dim=-1, dtype=torch.float32),
                F.softmax(noised_logits, dim=-1, dtype=torch.float32),
                None,
                None,
                "sum",
            )
        ) / noised_logits.size(0)

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        token_embeddings = model.encoder.embed_tokens(sample["net_input"]["src_tokens"])
        input_logits, extra = model(**sample["net_input"])
        loss, nll_loss = self.compute_loss(
            model, (input_logits, extra), sample, reduce=reduce
        )
        sample_size = (
            sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
        )

        if model.training:
            noise = self.noise_sampler.sample(sample_shape=token_embeddings.shape).to(
                token_embeddings
            )
            noised_embeddings = token_embeddings.clone() + noise

            noised_logits, _ = model(
                **sample["net_input"], token_embeddings=noised_embeddings
            )
            symm_kl = self._get_symm_kl(noised_logits, input_logits)

        if model.training:
            symm_kl = symm_kl * sample_size
            loss = loss + self.r3f_lambda * symm_kl

        logging_output = {
            "loss": loss.data,
            "nll_loss": nll_loss.data,
            "ntokens": sample["ntokens"],
            "nsentences": sample["target"].size(0),
            "sample_size": sample_size,
        }

        if model.training:
            logging_output.update(
                symm_kl=utils.item(symm_kl.data) if reduce else symm_kl.data
            )

        return loss, sample_size, logging_output

    def compute_loss(self, model, net_output, sample, reduce=True):
        lprobs = model.get_normalized_probs(net_output, log_probs=True)
        lprobs = lprobs.view(-1, lprobs.size(-1))
        target = model.get_targets(sample, net_output).view(-1, 1)
        loss, nll_loss = label_smoothed_nll_loss(
            lprobs,
            target,
            self.label_smoothing,
            ignore_index=self.padding_idx,
            reduce=reduce,
        )
        return loss, nll_loss

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""
        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
        nll_loss_sum = sum(log.get("nll_loss", 0) for log in logging_outputs)
        ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
        symm_kl_sum = sum(log.get("symm_kl", 0) for log in logging_outputs)

        metrics.log_scalar("symm_kl", symm_kl_sum / sample_size, sample_size, round=3)
        metrics.log_scalar(
            "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
        )
        metrics.log_scalar(
            "nll_loss", nll_loss_sum / ntokens / math.log(2), ntokens, round=3
        )
        metrics.log_derived(
            "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
        )

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
+
"""
|
154 |
+
Whether the logging outputs returned by `forward` can be summed
|
155 |
+
across workers prior to calling `reduce_metrics`. Setting this
|
156 |
+
to True will improves distributed training speed.
|
157 |
+
"""
|
158 |
+
return True
|
fairseq/examples/rxf/rxf_src/sentence_prediction_r3f.py
ADDED
@@ -0,0 +1,171 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math

import torch
import torch.nn.functional as F
from fairseq import utils
from fairseq.criterions import FairseqCriterion, register_criterion


@register_criterion("sentence_prediction_r3f")
class SentencePredictionR3F(FairseqCriterion):
    def __init__(
        self,
        task,
        eps,
        r3f_lambda,
        noise_type,
        classification_head_name,
        regression_target,
    ):
        super().__init__(task)
        self.eps = eps
        self.r3f_lambda = r3f_lambda
        self.noise_type = noise_type
        self.classification_head_name = classification_head_name
        self.regression_target = regression_target
        if self.noise_type in {"normal"}:
            self.noise_sampler = torch.distributions.normal.Normal(
                loc=0.0, scale=self.eps
            )
        elif self.noise_type == "uniform":
            self.noise_sampler = torch.distributions.uniform.Uniform(
                low=-self.eps, high=self.eps
            )
        else:
            raise Exception(f"unrecognized noise type {self.noise_type}")

    @staticmethod
    def add_args(parser):
        # fmt: off
        parser.add_argument('--eps', type=float, default=1e-5,
                            help='noise eps')
        parser.add_argument('--r3f-lambda', type=float, default=1.0,
                            help='lambda for combining logistic loss and noisy KL loss')
        parser.add_argument('--noise-type', type=str, default='uniform',
                            choices=['normal', 'uniform'],
                            help='type of noises for RXF methods')
        parser.add_argument('--classification-head-name',
                            default='sentence_classification_head',
                            help='name of the classification head to use')
        parser.add_argument('--regression-target', action='store_true')
        # fmt: on

    def _get_symm_kl(self, noised_logits, input_logits):
        return (
            F.kl_div(
                F.log_softmax(noised_logits, dim=-1, dtype=torch.float32),
                F.softmax(input_logits, dim=-1, dtype=torch.float32),
                None,
                None,
                "sum",
            )
            + F.kl_div(
                F.log_softmax(input_logits, dim=-1, dtype=torch.float32),
                F.softmax(noised_logits, dim=-1, dtype=torch.float32),
                None,
                None,
                "sum",
            )
        ) / noised_logits.size(0)

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        assert (
            hasattr(model, "classification_heads")
            and self.classification_head_name in model.classification_heads
        ), "model must provide sentence classification head for --criterion=sentence_prediction"

        token_embeddings = model.encoder.sentence_encoder.embed_tokens(
            sample["net_input"]["src_tokens"]
        )
        input_logits, _ = model(
            **sample["net_input"],
            features_only=True,
            classification_head_name=self.classification_head_name,
            token_embeddings=token_embeddings,
        )
        if model.training and self.noise_sampler:
            noise = self.noise_sampler.sample(sample_shape=token_embeddings.shape).to(
                token_embeddings
            )
            noised_embeddings = token_embeddings.detach().clone() + noise

            noised_logits, _ = model(
                **sample["net_input"],
                features_only=True,
                classification_head_name=self.classification_head_name,
                token_embeddings=noised_embeddings,
            )
            symm_kl = self._get_symm_kl(noised_logits, input_logits)
        else:
            symm_kl = 0

        targets = model.get_targets(sample, [input_logits]).view(-1)
        sample_size = targets.numel()

        if not self.regression_target:
            loss = F.nll_loss(
                F.log_softmax(input_logits, dim=-1, dtype=torch.float32),
                targets,
                reduction="sum",
            )
            if model.training:
                symm_kl = symm_kl * sample_size
                loss = loss + self.r3f_lambda * symm_kl
        else:
            logits = input_logits.squeeze().float()
            targets = targets.float()
            loss = F.mse_loss(logits, targets, reduction="sum")

        logging_output = {
            "loss": utils.item(loss.data) if reduce else loss.data,
            "ntokens": sample["ntokens"],
            "nsentences": sample_size,
            "sample_size": sample_size,
        }

        if not self.regression_target:
            preds = input_logits.max(dim=1)[1]
            logging_output.update(ncorrect=(preds == targets).sum().item())

            if model.training and self.noise_sampler:
                logging_output.update(
                    symm_kl=utils.item(symm_kl.data) if reduce else symm_kl.data
                )
        return loss, sample_size, logging_output

    @staticmethod
    def aggregate_logging_outputs(logging_outputs):
        """Aggregate logging outputs from data parallel training."""
        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
        symm_kl_sum = sum(log.get("symm_kl", 0) for log in logging_outputs)
        ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
        nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)

        agg_output = {
            "loss": loss_sum / sample_size / math.log(2),
            "symm_kl": symm_kl_sum / sample_size,
            "ntokens": ntokens,
            "nsentences": nsentences,
            "sample_size": sample_size,
        }

        if len(logging_outputs) > 0 and "ncorrect" in logging_outputs[0]:
            ncorrect = sum(log.get("ncorrect", 0) for log in logging_outputs)
            agg_output.update(accuracy=ncorrect / nsentences)

        if sample_size != ntokens:
            agg_output["nll_loss"] = loss_sum / ntokens / math.log(2)
        return agg_output
fairseq/examples/scaling_nmt/README.md
ADDED
@@ -0,0 +1,114 @@
# Scaling Neural Machine Translation (Ott et al., 2018)

This page includes instructions for reproducing results from the paper [Scaling Neural Machine Translation (Ott et al., 2018)](https://arxiv.org/abs/1806.00187).

## Pre-trained models

Model | Description | Dataset | Download
---|---|---|---
`transformer.wmt14.en-fr` | Transformer <br> ([Ott et al., 2018](https://arxiv.org/abs/1806.00187)) | [WMT14 English-French](http://statmt.org/wmt14/translation-task.html#Download) | model: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt14.en-fr.joined-dict.transformer.tar.bz2) <br> newstest2014: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt14.en-fr.joined-dict.newstest2014.tar.bz2)
`transformer.wmt16.en-de` | Transformer <br> ([Ott et al., 2018](https://arxiv.org/abs/1806.00187)) | [WMT16 English-German](https://drive.google.com/uc?export=download&id=0B_bZck-ksdkpM25jRUN2X2UxMm8) | model: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt16.en-de.joined-dict.transformer.tar.bz2) <br> newstest2014: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt16.en-de.joined-dict.newstest2014.tar.bz2)

## Training a new model on WMT'16 En-De

First download the [preprocessed WMT'16 En-De data provided by Google](https://drive.google.com/uc?export=download&id=0B_bZck-ksdkpM25jRUN2X2UxMm8).

Then:

##### 1. Extract the WMT'16 En-De data
```bash
TEXT=wmt16_en_de_bpe32k
mkdir -p $TEXT
tar -xzvf wmt16_en_de.tar.gz -C $TEXT
```

##### 2. Preprocess the dataset with a joined dictionary
```bash
fairseq-preprocess \
    --source-lang en --target-lang de \
    --trainpref $TEXT/train.tok.clean.bpe.32000 \
    --validpref $TEXT/newstest2013.tok.bpe.32000 \
    --testpref $TEXT/newstest2014.tok.bpe.32000 \
    --destdir data-bin/wmt16_en_de_bpe32k \
    --nwordssrc 32768 --nwordstgt 32768 \
    --joined-dictionary \
    --workers 20
```

##### 3. Train a model
```bash
fairseq-train \
    data-bin/wmt16_en_de_bpe32k \
    --arch transformer_vaswani_wmt_en_de_big --share-all-embeddings \
    --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
    --lr 0.0005 --lr-scheduler inverse_sqrt --warmup-updates 4000 --warmup-init-lr 1e-07 \
    --dropout 0.3 --weight-decay 0.0 \
    --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
    --max-tokens 3584 \
    --fp16
```

Note that the `--fp16` flag requires you have CUDA 9.1 or greater and a Volta GPU or newer.

***IMPORTANT:*** You will get better performance by training with big batches and
increasing the learning rate. If you want to train the above model with big batches
(assuming your machine has 8 GPUs):
- add `--update-freq 16` to simulate training on 8x16=128 GPUs
- increase the learning rate; 0.001 works well for big batches
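
Putting those two changes together, the big-batch run might look like the following (a sketch: every flag other than `--update-freq` and `--lr` is unchanged from step 3):

```bash
# Big-batch variant of the step 3 command (sketch): same flags, plus
# --update-freq 16 to simulate 8x16=128 GPUs and a higher peak learning rate.
fairseq-train \
    data-bin/wmt16_en_de_bpe32k \
    --arch transformer_vaswani_wmt_en_de_big --share-all-embeddings \
    --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
    --lr 0.001 --lr-scheduler inverse_sqrt --warmup-updates 4000 --warmup-init-lr 1e-07 \
    --dropout 0.3 --weight-decay 0.0 \
    --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
    --max-tokens 3584 \
    --update-freq 16 \
    --fp16
```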

##### 4. Evaluate

Now we can evaluate our trained model.

Note that the original [Attention Is All You Need](https://arxiv.org/abs/1706.03762)
paper used a couple of tricks to achieve better BLEU scores. We use these same tricks in
the Scaling NMT paper, so it's important to apply them when reproducing our results.

First, use the [average_checkpoints.py](/scripts/average_checkpoints.py) script to
average the last few checkpoints. Averaging the last 5-10 checkpoints is usually
good, but you may need to adjust this depending on how long you've trained:
```bash
python scripts/average_checkpoints.py \
    --inputs /path/to/checkpoints \
    --num-epoch-checkpoints 10 \
    --output checkpoint.avg10.pt
```

Next, generate translations using a beam width of 4 and length penalty of 0.6:
```bash
fairseq-generate \
    data-bin/wmt16_en_de_bpe32k \
    --path checkpoint.avg10.pt \
    --beam 4 --lenpen 0.6 --remove-bpe > gen.out
```

Finally, we apply the ["compound splitting" script](/scripts/compound_split_bleu.sh) to
add spaces around dashes. For example "Café-Liebhaber" would become three tokens:
"Café - Liebhaber". This typically results in larger BLEU scores, but it is not
appropriate to compare these inflated scores to work which does not include this trick.
This trick was used in the [original AIAYN code](https://github.com/tensorflow/tensor2tensor/blob/fc9335c0203685cbbfe2b30c92db4352d8f60779/tensor2tensor/utils/get_ende_bleu.sh),
so we used it in the Scaling NMT paper as well. That said, it's strongly advised to
report [sacrebleu](https://github.com/mjpost/sacrebleu) scores instead.

To compute "compound split" tokenized BLEU (not recommended!):
```bash
bash scripts/compound_split_bleu.sh gen.out
# BLEU4 = 29.29, 60.3/35.0/22.8/15.3 (BP=1.000, ratio=1.004, syslen=64763, reflen=64496)
```

To compute detokenized BLEU with sacrebleu (preferred):
```bash
bash scripts/sacrebleu.sh wmt14/full en de gen.out
# BLEU+case.mixed+lang.en-de+numrefs.1+smooth.exp+test.wmt14/full+tok.13a+version.1.4.3 = 28.6 59.3/34.3/22.1/14.9 (BP = 1.000 ratio = 1.016 hyp_len = 63666 ref_len = 62688)
```

## Citation

```bibtex
@inproceedings{ott2018scaling,
  title = {Scaling Neural Machine Translation},
  author = {Ott, Myle and Edunov, Sergey and Grangier, David and Auli, Michael},
  booktitle = {Proceedings of the Third Conference on Machine Translation (WMT)},
  year = 2018,
}
```
fairseq/examples/shuffled_word_order/README.finetuning.md
ADDED
@@ -0,0 +1,135 @@
# Fine-tuning details

For each task (GLUE and PAWS), we perform a hyper-parameter search for each model and report the mean and standard deviation across 5 seeds of the best model. First, get the datasets following the instructions in the [RoBERTa fine-tuning README](../roberta/README.glue.md). Alternatively, you can use [huggingface datasets](https://huggingface.co/docs/datasets/) to get the task data:

```python
from datasets import load_dataset
import pandas as pd
from pathlib import Path

key2file = {
    "paws": {
        "loc": "paws_data",
        "columns": ["id", "sentence1", "sentence2", "label"],
        "train": "train.tsv",
        "validation": "dev.tsv",
        "test": "test.tsv"
    }
}

task_data = load_dataset("paws", "labeled_final")
task_config = key2file["paws"]
save_path = Path(task_config["loc"])
save_path.mkdir(exist_ok=True, parents=True)
for key, fl in task_config.items():
    if key in ["loc", "columns"]:
        continue
    print(f"Reading {key}")
    columns = task_config["columns"]
    df = pd.DataFrame(task_data[key])
    print(df.columns)
    df = df[columns]
    print(f"Got {len(df)} records")
    save_loc = save_path / fl
    print(f"Saving to : {save_loc}")
    df.to_csv(save_loc, sep="\t", header=None, index=None)
```

- Preprocess using the RoBERTa GLUE preprocessing script, while keeping in mind the column numbers for `sentence1`, `sentence2` and `label` (which are 0, 1, 2 if you save the data according to the above example); a sketch of this step is shown after the fine-tuning command below.
- Then, fine-tuning is performed similarly to RoBERTa (for example, in the case of RTE):

```bash
TOTAL_NUM_UPDATES=30875  # 10 epochs through RTE for bsz 16
WARMUP_UPDATES=1852      # 6 percent of the number of updates
LR=2e-05                 # Peak LR for polynomial LR scheduler.
NUM_CLASSES=2
MAX_SENTENCES=16         # Batch size.
SHUFFLED_ROBERTA_PATH=/path/to/shuffled_roberta/model.pt

CUDA_VISIBLE_DEVICES=0 fairseq-train RTE-bin/ \
    --restore-file $SHUFFLED_ROBERTA_PATH \
    --max-positions 512 \
    --batch-size $MAX_SENTENCES \
    --max-tokens 4400 \
    --task sentence_prediction \
    --reset-optimizer --reset-dataloader --reset-meters \
    --required-batch-size-multiple 1 \
    --init-token 0 --separator-token 2 \
    --arch roberta_large \
    --criterion sentence_prediction \
    --num-classes $NUM_CLASSES \
    --dropout 0.1 --attention-dropout 0.1 \
    --weight-decay 0.1 --optimizer adam --adam-betas "(0.9, 0.98)" --adam-eps 1e-06 \
    --clip-norm 0.0 \
    --lr-scheduler polynomial_decay --lr $LR --total-num-update $TOTAL_NUM_UPDATES --warmup-updates $WARMUP_UPDATES \
    --fp16 --fp16-init-scale 4 --threshold-loss-scale 1 --fp16-scale-window 128 \
    --max-epoch 10 \
    --find-unused-parameters \
    --best-checkpoint-metric accuracy --maximize-best-checkpoint-metric;
```
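
A minimal sketch of the preprocessing step referenced in the first bullet above, for the PAWS files saved earlier. This is not part of the original recipe: it assumes the GPT-2 BPE files (`encoder.json`, `vocab.bpe`) and RoBERTa's `dict.txt` have been downloaded as in the RoBERTa GLUE README, and that the `cut` column indices match how your TSVs were written.

```bash
# Sketch only: column indices and file names are assumptions; adjust them to
# match the TSVs you actually saved (here: id, sentence1, sentence2, label).
for SPLIT in train dev; do
    cut -f2 paws_data/$SPLIT.tsv > paws_data/$SPLIT.raw.input0   # sentence1
    cut -f3 paws_data/$SPLIT.tsv > paws_data/$SPLIT.raw.input1   # sentence2
    cut -f4 paws_data/$SPLIT.tsv > paws_data/$SPLIT.label        # label
    for INPUT in input0 input1; do
        python examples/roberta/multiprocessing_bpe_encoder.py \
            --encoder-json encoder.json \
            --vocab-bpe vocab.bpe \
            --inputs paws_data/$SPLIT.raw.$INPUT \
            --outputs paws_data/$SPLIT.$INPUT.bpe \
            --workers 8 --keep-empty
    done
done

# Binarize inputs with RoBERTa's dictionary, and labels with their own.
for INPUT in input0 input1; do
    fairseq-preprocess --only-source \
        --trainpref paws_data/train.$INPUT.bpe \
        --validpref paws_data/dev.$INPUT.bpe \
        --destdir PAWS-bin/$INPUT \
        --srcdict dict.txt --workers 8
done
fairseq-preprocess --only-source \
    --trainpref paws_data/train.label \
    --validpref paws_data/dev.label \
    --destdir PAWS-bin/label \
    --workers 8
```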

- `TOTAL_NUM_UPDATES` is computed based on the `--batch_size` value and the dataset size.
- `WARMUP_UPDATES` is computed as 6% of `TOTAL_NUM_UPDATES`.
- The best hyper-parameters for `--lr` and `--batch_size` are reported below:

## `--lr`

|     | name         |   RTE |  MRPC | SST-2 |  CoLA |   QQP |  QNLI |  MNLI |  PAWS |
| --: | :----------- | ----: | ----: | ----: | ----: | ----: | ----: | ----: | ----: |
|   0 | original     | 2e-05 | 2e-05 | 1e-05 | 2e-05 | 1e-05 | 1e-05 | 1e-05 | 2e-05 |
|   1 | n_1          | 2e-05 | 1e-05 | 1e-05 | 1e-05 | 3e-05 | 1e-05 | 2e-05 | 2e-05 |
|   2 | n_2          | 2e-05 | 2e-05 | 1e-05 | 1e-05 | 2e-05 | 1e-05 | 1e-05 | 3e-05 |
|   3 | n_3          | 3e-05 | 1e-05 | 2e-05 | 2e-05 | 3e-05 | 1e-05 | 1e-05 | 2e-05 |
|   4 | n_4          | 3e-05 | 1e-05 | 2e-05 | 2e-05 | 2e-05 | 1e-05 | 1e-05 | 2e-05 |
|   5 | r512         | 1e-05 | 3e-05 | 2e-05 | 2e-05 | 3e-05 | 2e-05 | 3e-05 | 2e-05 |
|   6 | rand_corpus  | 2e-05 | 1e-05 | 3e-05 | 1e-05 | 3e-05 | 3e-05 | 3e-05 | 2e-05 |
|   7 | rand_uniform | 2e-05 | 1e-05 | 3e-05 | 2e-05 | 3e-05 | 3e-05 | 3e-05 | 1e-05 |
|   8 | rand_init    | 1e-05 | 1e-05 | 3e-05 | 1e-05 | 1e-05 | 1e-05 | 2e-05 | 1e-05 |
|   9 | no_pos       | 1e-05 | 3e-05 | 2e-05 | 1e-05 | 1e-05 | 1e-05 | 1e-05 | 1e-05 |

## `--batch_size`

|     | name         | RTE | MRPC | SST-2 | CoLA | QQP | QNLI | MNLI | PAWS |
| --: | :----------- | --: | ---: | ----: | ---: | --: | ---: | ---: | ---: |
|   0 | orig         |  16 |   16 |    32 |   16 |  16 |   32 |   32 |   16 |
|   1 | n_1          |  32 |   32 |    16 |   32 |  32 |   16 |   32 |   16 |
|   2 | n_2          |  32 |   16 |    32 |   16 |  32 |   32 |   16 |   32 |
|   3 | n_3          |  32 |   32 |    16 |   32 |  32 |   16 |   32 |   32 |
|   4 | n_4          |  32 |   16 |    32 |   16 |  32 |   32 |   32 |   32 |
|   5 | r512         |  32 |   16 |    16 |   32 |  32 |   16 |   16 |   16 |
|   6 | rand_corpus  |  16 |   16 |    16 |   16 |  32 |   16 |   16 |   32 |
|   7 | rand_uniform |  16 |   32 |    16 |   16 |  32 |   16 |   16 |   16 |
|   8 | rand_init    |  16 |   16 |    32 |   16 |  16 |   16 |   32 |   16 |
|   9 | no_pos       |  16 |   32 |    16 |   16 |  32 |   16 |   16 |   16 |

- Perform inference similarly to RoBERTa as well:

```python
from fairseq.models.roberta import RobertaModel

roberta = RobertaModel.from_pretrained(
    'checkpoints/',
    checkpoint_file='checkpoint_best.pt',
    data_name_or_path='PAWS-bin'
)

label_fn = lambda label: roberta.task.label_dictionary.string(
    [label + roberta.task.label_dictionary.nspecial]
)
ncorrect, nsamples = 0, 0
roberta.cuda()
roberta.eval()
with open('paws_data/dev.tsv') as fin:
    fin.readline()
    for index, line in enumerate(fin):
        tokens = line.strip().split('\t')
        sent1, sent2, target = tokens[0], tokens[1], tokens[2]
        tokens = roberta.encode(sent1, sent2)
        prediction = roberta.predict('sentence_classification_head', tokens).argmax().item()
        prediction_label = label_fn(prediction)
        ncorrect += int(prediction_label == target)
        nsamples += 1
print('| Accuracy: ', float(ncorrect)/float(nsamples))
```