FZH1996 committed
Commit e7d695a
Parent(s): fe45bc3

update fed-lora

This view is limited to 50 files because it contains too many changes.
- examples/NLG/eval/GenerationEval/bleurt +1 -0
- examples/NLG/eval/GenerationEval/metrics/bleurt +1 -0
- examples/NLG/eval/e2e/metrics/__pycache__/__init__.cpython-37.pyc +0 -0
- examples/NLG/eval/e2e/metrics/__pycache__/pymteval.cpython-37.pyc +0 -0
- examples/NLG/eval/e2e/pycocoevalcap/__pycache__/__init__.cpython-37.pyc +0 -0
- examples/NLG/eval/e2e/pycocoevalcap/__pycache__/eval.cpython-37.pyc +0 -0
- examples/NLG/eval/e2e/pycocoevalcap/bleu/__pycache__/__init__.cpython-37.pyc +0 -0
- examples/NLG/eval/e2e/pycocoevalcap/bleu/__pycache__/bleu.cpython-37.pyc +0 -0
- examples/NLG/eval/e2e/pycocoevalcap/bleu/__pycache__/bleu_scorer.cpython-37.pyc +0 -0
- examples/NLG/eval/e2e/pycocoevalcap/cider/__pycache__/__init__.cpython-37.pyc +0 -0
- examples/NLG/eval/e2e/pycocoevalcap/cider/__pycache__/cider.cpython-37.pyc +0 -0
- examples/NLG/eval/e2e/pycocoevalcap/cider/__pycache__/cider_scorer.cpython-37.pyc +0 -0
- examples/NLG/eval/e2e/pycocoevalcap/meteor/__pycache__/__init__.cpython-37.pyc +0 -0
- examples/NLG/eval/e2e/pycocoevalcap/meteor/__pycache__/meteor.cpython-37.pyc +0 -0
- examples/NLG/eval/e2e/pycocoevalcap/rouge/__pycache__/__init__.cpython-37.pyc +0 -0
- examples/NLG/eval/e2e/pycocoevalcap/rouge/__pycache__/rouge.cpython-37.pyc +0 -0
- examples/NLG/eval/e2e/pycocoevalcap/tokenizer/__pycache__/__init__.cpython-37.pyc +0 -0
- examples/NLG/eval/e2e/pycocoevalcap/tokenizer/__pycache__/ptbtokenizer.cpython-37.pyc +0 -0
- examples/NLG/eval/e2e/pycocotools/__pycache__/__init__.cpython-36.pyc +0 -0
- examples/NLG/eval/e2e/pycocotools/__pycache__/__init__.cpython-37.pyc +0 -0
- examples/NLG/eval/e2e/pycocotools/__pycache__/coco.cpython-36.pyc +0 -0
- examples/NLG/eval/e2e/pycocotools/__pycache__/coco.cpython-37.pyc +0 -0
- examples/NLG/src/.DS_Store +0 -0
- examples/NLG/src/__pycache__/data_utils.cpython-310.pyc +0 -0
- examples/NLG/src/__pycache__/data_utils.cpython-36.pyc +0 -0
- examples/NLG/src/__pycache__/data_utils.cpython-37.pyc +0 -0
- examples/NLG/src/__pycache__/encoder.cpython-37.pyc +0 -0
- examples/NLG/src/__pycache__/exp_utils.cpython-310.pyc +0 -0
- examples/NLG/src/__pycache__/exp_utils.cpython-37.pyc +0 -0
- examples/NLG/src/__pycache__/gpu.cpython-310.pyc +0 -0
- examples/NLG/src/__pycache__/gpu.cpython-36.pyc +0 -0
- examples/NLG/src/__pycache__/gpu.cpython-37.pyc +0 -0
- examples/NLG/src/__pycache__/model.cpython-310.pyc +0 -0
- examples/NLG/src/__pycache__/model.cpython-36.pyc +0 -0
- examples/NLG/src/__pycache__/model.cpython-37.pyc +0 -0
- examples/NLG/src/__pycache__/optimizer.cpython-36.pyc +0 -0
- examples/NLG/src/__pycache__/optimizer.cpython-37.pyc +0 -0
- examples/NLG/src/data_utils.py +282 -0
- examples/NLG/src/encoder.py +132 -0
- examples/NLG/src/exp_utils.py +46 -0
- examples/NLG/src/format_converting_dart.py +43 -0
- examples/NLG/src/format_converting_e2e.py +20 -0
- examples/NLG/src/format_converting_webnlg.py +68 -0
- examples/NLG/src/gpt2_beam.py +419 -0
- examples/NLG/src/gpt2_decode.py +187 -0
- examples/NLG/src/gpt2_encode.py +70 -0
- examples/NLG/src/gpt2_ft.py +385 -0
- examples/NLG/src/gpu.py +129 -0
- examples/NLG/src/model.log +698 -0
- examples/NLG/src/model.py +460 -0
examples/NLG/eval/GenerationEval/bleurt ADDED
@@ -0,0 +1 @@
Subproject commit cebe7e6f996b40910cfaa520a63db47807e3bf5c

examples/NLG/eval/GenerationEval/metrics/bleurt ADDED
@@ -0,0 +1 @@
Subproject commit cebe7e6f996b40910cfaa520a63db47807e3bf5c
examples/NLG/eval/e2e/metrics/__pycache__/__init__.cpython-37.pyc ADDED (binary file, 169 Bytes)
examples/NLG/eval/e2e/metrics/__pycache__/pymteval.cpython-37.pyc ADDED (binary file, 12.9 kB)
examples/NLG/eval/e2e/pycocoevalcap/__pycache__/__init__.cpython-37.pyc ADDED (binary file, 195 Bytes)
examples/NLG/eval/e2e/pycocoevalcap/__pycache__/eval.cpython-37.pyc ADDED (binary file, 2.57 kB)
examples/NLG/eval/e2e/pycocoevalcap/bleu/__pycache__/__init__.cpython-37.pyc ADDED (binary file, 200 Bytes)
examples/NLG/eval/e2e/pycocoevalcap/bleu/__pycache__/bleu.cpython-37.pyc ADDED (binary file, 1.24 kB)
examples/NLG/eval/e2e/pycocoevalcap/bleu/__pycache__/bleu_scorer.cpython-37.pyc ADDED (binary file, 8.07 kB)
examples/NLG/eval/e2e/pycocoevalcap/cider/__pycache__/__init__.cpython-37.pyc ADDED (binary file, 201 Bytes)
examples/NLG/eval/e2e/pycocoevalcap/cider/__pycache__/cider.cpython-37.pyc ADDED (binary file, 1.67 kB)
examples/NLG/eval/e2e/pycocoevalcap/cider/__pycache__/cider_scorer.cpython-37.pyc ADDED (binary file, 7.85 kB)
examples/NLG/eval/e2e/pycocoevalcap/meteor/__pycache__/__init__.cpython-37.pyc ADDED (binary file, 202 Bytes)
examples/NLG/eval/e2e/pycocoevalcap/meteor/__pycache__/meteor.cpython-37.pyc ADDED (binary file, 2.75 kB)
examples/NLG/eval/e2e/pycocoevalcap/rouge/__pycache__/__init__.cpython-37.pyc ADDED (binary file, 203 Bytes)
examples/NLG/eval/e2e/pycocoevalcap/rouge/__pycache__/rouge.cpython-37.pyc ADDED (binary file, 3.75 kB)
examples/NLG/eval/e2e/pycocoevalcap/tokenizer/__pycache__/__init__.cpython-37.pyc ADDED (binary file, 205 Bytes)
examples/NLG/eval/e2e/pycocoevalcap/tokenizer/__pycache__/ptbtokenizer.cpython-37.pyc ADDED (binary file, 2.18 kB)
examples/NLG/eval/e2e/pycocotools/__pycache__/__init__.cpython-36.pyc ADDED (binary file, 189 Bytes)
examples/NLG/eval/e2e/pycocotools/__pycache__/__init__.cpython-37.pyc ADDED (binary file, 193 Bytes)
examples/NLG/eval/e2e/pycocotools/__pycache__/coco.cpython-36.pyc ADDED (binary file, 13.4 kB)
examples/NLG/eval/e2e/pycocotools/__pycache__/coco.cpython-37.pyc ADDED (binary file, 13.4 kB)
examples/NLG/src/.DS_Store ADDED (binary file, 6.15 kB)
examples/NLG/src/__pycache__/data_utils.cpython-310.pyc ADDED (binary file, 8.49 kB)
examples/NLG/src/__pycache__/data_utils.cpython-36.pyc ADDED (binary file, 8.58 kB)
examples/NLG/src/__pycache__/data_utils.cpython-37.pyc ADDED (binary file, 8.58 kB)
examples/NLG/src/__pycache__/encoder.cpython-37.pyc ADDED (binary file, 5.1 kB)
examples/NLG/src/__pycache__/exp_utils.cpython-310.pyc ADDED (binary file, 1.49 kB)
examples/NLG/src/__pycache__/exp_utils.cpython-37.pyc ADDED (binary file, 1.44 kB)
examples/NLG/src/__pycache__/gpu.cpython-310.pyc ADDED (binary file, 3.58 kB)
examples/NLG/src/__pycache__/gpu.cpython-36.pyc ADDED (binary file, 3.53 kB)
examples/NLG/src/__pycache__/gpu.cpython-37.pyc ADDED (binary file, 3.54 kB)
examples/NLG/src/__pycache__/model.cpython-310.pyc ADDED (binary file, 13.3 kB)
examples/NLG/src/__pycache__/model.cpython-36.pyc ADDED (binary file, 13.7 kB)
examples/NLG/src/__pycache__/model.cpython-37.pyc ADDED (binary file, 13.5 kB)
examples/NLG/src/__pycache__/optimizer.cpython-36.pyc ADDED (binary file, 11.4 kB)
examples/NLG/src/__pycache__/optimizer.cpython-37.pyc ADDED (binary file, 11.4 kB)
examples/NLG/src/data_utils.py ADDED
@@ -0,0 +1,282 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import os, sys
import glob
import random
from collections import Counter, OrderedDict
import numpy as np
import torch
import json

import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader


class LMOrderedIterator(object):
    def __init__(self, data, bsz, bptt, eval_len=None, device='cpu', world_size=1, rank=0):
        """
            data -- LongTensor -- the LongTensor is strictly ordered
        """
        self.data = data
        self.bsz = bsz
        self.world_size = world_size
        self.rank = rank
        self.bptt = bptt # tgt_len
        # existing len.
        self.eval_len = bptt if eval_len is None else eval_len

        self.device = device

        self.global_bsz = bsz * world_size
        # Work out how cleanly we can divide the dataset into bsz parts.
        self.n_step = len(data) // self.global_bsz # bsz

        self.split_data = torch.tensor(
            data[rank * self.n_step * bsz : (rank + 1) * self.n_step * bsz],
            dtype=torch.long, device=self.device
        ) # data.view(-1)

        self.split_data = self.split_data.view(bsz, -1)

    def __iter__(self):
        return self.get_fixlen_iter()

    def get_batch(self, i, bptt, eval_len):
        beg_idx = i
        end_idx = i + bptt # seq_len

        # batch_size, length;
        _input = self.split_data[:, beg_idx : end_idx].contiguous()
        _target = self.split_data[:, beg_idx+1 : end_idx+1].contiguous()

        _msk = torch.cat(
            [
                torch.zeros(bptt-eval_len, dtype=torch.float, device=self.device),
                torch.ones(eval_len, dtype=torch.float, device=self.device)
            ]
        )
        _msk = _msk.unsqueeze(0).expand_as(_input) # .unsqueeze(-1) # length, 1;
        return _input, _target, _msk

    def get_fixlen_iter(self, start=0):
        self.data_len = self.split_data.size(1)
        _eval_cursor = 0
        for i in range(start, self.data_len - 1, self.eval_len):
            bptt = min(self.bptt, self.data_len - i - 1)
            _end_idx = i + bptt
            yield self.get_batch(i, bptt, _end_idx - _eval_cursor)
            _eval_cursor = _end_idx


class Corpus(object):
    def __init__(self, path):
        self.path = path
        self.num_words = 0
        self.tokens = []
        with open(self.path, "r") as reader:
            for line in reader:
                items = json.loads(line.strip())
                book = items['book']
                tokens = items['tokens']
                num_words = items['num_words']

                self.num_words += num_words
                self.tokens.extend(tokens)


class BinLMOrderedIterator(object):
    def __init__(self, corpus, bsz, bptt, eval_len=None, device='cpu', world_size=1, rank=0):
        """
            data -- LongTensor -- the LongTensor is strictly ordered
        """
        self.corpus = corpus
        self.bsz = bsz
        self.world_size = world_size
        self.rank = rank
        self.bptt = bptt # tgt_len
        # existing len.
        self.eval_len = bptt if eval_len is None else eval_len
        self.device = device
        self.global_bsz = bsz * world_size
        # Work out how cleanly we can divide the dataset into bsz parts.
        self.n_step = corpus.length // self.global_bsz # bsz

        self.offset = [(rank * bsz + _b) * self.n_step for _b in range(bsz)]

    def __iter__(self):
        return self.get_fixlen_iter()

    def get_batch(self, i, bptt, eval_len):
        # batch_size, length;
        _inputs = []
        _targets = []
        for _b in range(0, self.bsz):
            _input = self.corpus.get_tokens(self.offset[_b] + i, bptt)
            _target = self.corpus.get_tokens(self.offset[_b] + i + 1, bptt)

            _inputs.append(_input)
            _targets.append(_target)

        _input = torch.tensor(_inputs, dtype=torch.int64, device=self.device).contiguous()
        _target = torch.tensor(_targets, dtype=torch.int64, device=self.device).contiguous()

        _msk = torch.cat(
            [
                torch.zeros(bptt-eval_len, dtype=torch.float, device=self.device),
                torch.ones(eval_len, dtype=torch.float, device=self.device)
            ]
        )
        _msk = _msk.unsqueeze(0).expand_as(_input) # .unsqueeze(-1) # length, 1;
        return _input, _target, _msk

    def get_fixlen_iter(self, start=0):
        #self.data_len = self.split_data.size(1)
        _eval_cursor = 0
        for i in range(start, self.n_step - 1, self.eval_len):
            bptt = min(self.bptt, self.n_step - i - 1)
            _end_idx = i + bptt
            yield self.get_batch(i, bptt, _end_idx - _eval_cursor)
            _eval_cursor = _end_idx


class BinCorpus(object):
    def __init__(self, path):
        self.path = path

        self.book_token_span = []
        self.book_token_span.append(0)
        tokens_sum = 0
        self.num_words = 0

        with open(path+'.info', 'r') as info_reader:
            for line in info_reader:
                items = json.loads(line.strip())
                book = items['book']
                num_tokens = items['num_subtokens']
                num_words = items['num_words']

                tokens_sum += num_tokens
                self.book_token_span.append(tokens_sum)
                self.num_words += num_words

        self.length = self.book_token_span[-1]
        self.bin_reader = open(path+'.bin', 'rb')

    def get_tokens(self, offset, count):
        INT64_SIZE = 8
        self.bin_reader.seek(offset * INT64_SIZE)
        # int64 matches INT64_SIZE above (np.int was removed in modern numpy)
        x = np.fromfile(self.bin_reader, count=count, dtype=np.int64)
        return x


def get_lm_corpus(data):
    print('Producing dataset {}...'.format(data))
    corpus = Corpus(data)
    return corpus


def padding_tokens(tokens, max_seq_length, pad_token, direct, max_context_length=0):

    if max_context_length == 0:
        max_context_length = max_seq_length

    if len(tokens) > max_context_length:
        if direct > 0:
            pad_tokens = tokens[:max_context_length]
        else:
            pad_tokens = tokens[-max_context_length:]
    else:
        pad_tokens = tokens
    token_len = len(pad_tokens)
    pad_tokens = pad_tokens + [pad_token for _ in range(max_seq_length - token_len)]
    return pad_tokens, token_len


class FT_Dataset(Dataset):
    def __init__(self, ft_file, batch_size, max_seq_length,
                 max_eval_length=0, joint_lm=False, prefix_len=0, infix_len=0,
                 prefix_cursor=1000000, infix_cursor=2000000):
        self.ft_file = ft_file
        self.ft_samples = self.read_ft_file(ft_file)
        self.batch_size = batch_size
        self.num_examples = len(self.ft_samples)
        self.max_seq_length = max_seq_length
        self.max_eval_length = max_eval_length
        self.rng = random.Random(911)
        self.joint_lm = joint_lm

        self.num_batches = int((self.num_examples + self.batch_size - 1) / self.batch_size)

        self.prefix_len = prefix_len
        self.infix_len = infix_len
        self.prefix_cursor = prefix_cursor
        self.infix_cursor = infix_cursor

    def __len__(self):
        return self.num_batches * self.batch_size

    def __getitem__(self, item):
        if(item >= self.num_examples):
            item = self.rng.randint(0, self.num_examples - 1)

        example = self.ft_samples[item]
        context = example[0]
        completion = example[1]

        pretokens = [i + self.prefix_cursor for i in range(0, self.prefix_len)]
        intokens = [i + self.infix_cursor for i in range(0, self.infix_len)]

        conditions = pretokens + context + intokens
        _input, _input_len = padding_tokens(conditions + completion, self.max_seq_length, 0, 1)

        pad_targets = [0 for i in range(0, self.prefix_len)] + context + [0 for i in range(0, self.infix_len)] + completion
        _target, _ = padding_tokens(pad_targets[1:], self.max_seq_length, 0, 1)

        if not self.joint_lm:
            _msk = [0.0] * (len(conditions) - 1) + [1.0] * (_input_len - len(conditions))
        else:
            _msk = [1.0] * (_input_len - 1)

        _msk, _ = padding_tokens(_msk, self.max_seq_length, 0.0, 1)

        output = {}
        output["id"] = torch.tensor(item, dtype=torch.long)

        _query, _query_len = padding_tokens(
            conditions, self.max_seq_length, 0, -1,
            max_context_length = self.max_seq_length - self.max_eval_length
        )
        output["query"] = torch.tensor(_query, dtype=torch.long)
        output["query_len"] = torch.tensor(_query_len, dtype=torch.long)

        output["input"] = torch.tensor(_input, dtype=torch.long)
        output["target"] = torch.tensor(_target, dtype=torch.long)

        output["mask"] = torch.tensor(_msk, dtype=torch.float)
        return output

    def read_ft_file(self, ft_file):
        ft_samples = []
        with open(ft_file, 'r') as reader:
            for line in reader:
                items = json.loads(line.strip())
                context = items['context']
                completion = items['completion']
                ft_samples.append([context, completion])
        return ft_samples

    def get_item_list(self, start, interval):
        start = min(start, self.num_examples-1)
        start = max(0, start)
        if(start + interval >= self.num_examples):
            end = self.num_examples
        else:
            end = start + interval
        samples = []
        for index in range(start, end):
            output = self.__getitem__(index)
            samples.append(output)
        return samples
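
As an aside on the dataset class above: a minimal sketch of how FT_Dataset is typically consumed. The data path and sizes here are illustrative placeholders, not values from this commit.

# Usage sketch for FT_Dataset (hypothetical path and sizes).
# Each jsonl line holds token-id lists: {"context": [...], "completion": [...]}.
from torch.utils.data import DataLoader
from data_utils import FT_Dataset

train_data = FT_Dataset('./data/e2e/train.jsonl', batch_size=8, max_seq_length=512)
train_loader = DataLoader(train_data, batch_size=8, shuffle=False)

batch = next(iter(train_loader))
# batch['input']  : (8, 512) long tensor, context + completion, right-padded with 0
# batch['target'] : (8, 512) long tensor, the input shifted left by one position
# batch['mask']   : (8, 512) float tensor, 1.0 only on completion positions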
examples/NLG/src/encoder.py ADDED
@@ -0,0 +1,132 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import os
import json
import regex as re
from functools import lru_cache


@lru_cache()
def bytes_to_unicode():
    """
    Returns list of utf-8 byte and a corresponding list of unicode strings.
    The reversible bpe codes work on unicode strings.
    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
    This is a significant percentage of your normal, say, 32K bpe vocab.
    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
    And avoids mapping to whitespace/control characters the bpe code barfs on.
    """
    bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8+n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))


def get_pairs(word):
    """Return set of symbol pairs in a word.
    Word is represented as tuple of symbols (symbols being variable-length strings).
    """
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs


class Encoder:

    def __init__(self, encoder, bpe_merges, errors='replace'):
        self.encoder = encoder
        self.decoder = {v:k for k,v in self.encoder.items()}
        self.errors = errors # how to handle errors in decoding
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v:k for k, v in self.byte_encoder.items()}
        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
        self.cache = {}
        # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
        try:
            import regex as re
            self.re = re
        except ImportError:
            raise ImportError('Please install regex with: pip install regex')

        self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")

    def bpe(self, token):
        if token in self.cache:
            return self.cache[token]
        word = tuple(token)
        pairs = get_pairs(word)

        if not pairs:
            return token

        while True:
            bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except ValueError:  # `first` does not occur in the rest of the word
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word)-1 and word[i+1] == second:
                    new_word.append(first+second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        bpe_tokens = []
        tokens = []
        for token in re.findall(self.pat, text):
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
            if token:
                tokens.append(token)
        return bpe_tokens, tokens

    def decode(self, tokens):
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
        return text


def get_encoder(models_dir):
    with open(os.path.join(models_dir, 'encoder.json'), 'r') as f:
        encoder = json.load(f)
    with open(os.path.join(models_dir, 'vocab.bpe'), 'r', encoding="utf-8") as f:
        bpe_data = f.read()
    bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]]
    return Encoder(
        encoder=encoder,
        bpe_merges=bpe_merges,
    )
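
The Encoder above reproduces GPT-2's byte-level BPE; a short round-trip sketch follows. The './vocab' directory, assumed to hold a GPT-2 encoder.json and vocab.bpe, is a hypothetical path.

import encoder

enc = encoder.get_encoder('./vocab')            # hypothetical vocab directory
ids, byte_tokens = enc.encode('name : The Vaults | Type : pub')
assert enc.decode(ids) == 'name : The Vaults | Type : pub'   # lossless round trip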
examples/NLG/src/exp_utils.py ADDED
@@ -0,0 +1,46 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import functools
import os, shutil
import numpy as np

import torch


def logging(s, log_path, print_=True, log_=True):
    if print_:
        print(s)
    if log_:
        with open(log_path, 'a+') as f_log:
            f_log.write(s + '\n')


def get_logger(log_path, **kwargs):
    return functools.partial(logging, log_path=log_path, **kwargs)


def create_exp_dir(dir_path, scripts_to_save=None, debug=False):
    if debug:
        print('Debug Mode : no experiment dir created')
        return functools.partial(logging, log_path=None, log_=False)

    if not os.path.exists(dir_path):
        os.makedirs(dir_path)

    print('Experiment dir : {}'.format(dir_path))
    if scripts_to_save is not None:
        script_path = os.path.join(dir_path, 'scripts')
        if not os.path.exists(script_path):
            os.makedirs(script_path)
        for script in scripts_to_save:
            dst_file = os.path.join(dir_path, 'scripts', os.path.basename(script))
            shutil.copyfile(script, dst_file)

    return get_logger(log_path=os.path.join(dir_path, 'log.txt'))


def save_checkpoint(model, optimizer, path, epoch):
    torch.save(model, os.path.join(path, 'model_{}.pt'.format(epoch)))
    torch.save(optimizer.state_dict(), os.path.join(path, 'optimizer_{}.pt'.format(epoch)))
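
For orientation, the logging helpers above compose like this; the work directory name is an illustrative placeholder.

from exp_utils import create_exp_dir

logging = create_exp_dir('./trained_models/demo')   # illustrative path
logging('starting run')   # prints and appends the line to ./trained_models/demo/log.txt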
examples/NLG/src/format_converting_dart.py ADDED
@@ -0,0 +1,43 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import sys
import io
import json


with open(sys.argv[1], 'r', encoding='utf8') as reader, \
     open(sys.argv[2], 'w', encoding='utf8') as writer:
    lines_dict = json.load(reader)

    full_rela_lst = []
    full_src_lst = []
    full_tgt_lst = []
    unique_src = 0

    for example in lines_dict:
        rela_lst = []
        temp_triples = ''
        for i, tripleset in enumerate(example['tripleset']):
            subj, rela, obj = tripleset
            rela = rela.lower()
            rela_lst.append(rela)
            if i > 0:
                temp_triples += ' | '
            temp_triples += '{} : {} : {}'.format(subj, rela, obj)

        unique_src += 1

        for sent in example['annotations']:
            full_tgt_lst.append(sent['text'])
            full_src_lst.append(temp_triples)
            full_rela_lst.append(rela_lst)

    print('unique source is', unique_src)

    for src, tgt in zip(full_src_lst, full_tgt_lst):
        x = {}
        x['context'] = src # context #+ '||'
        x['completion'] = tgt #completion
        writer.write(json.dumps(x)+'\n')
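
Each line written above is a flat JSON record pairing the linearized triples with one reference sentence. An illustrative output line (made-up data):

{"context": "Mars : moon : Phobos | Mars : moon : Deimos", "completion": "Mars has two moons, Phobos and Deimos."}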
examples/NLG/src/format_converting_e2e.py ADDED
@@ -0,0 +1,20 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import sys
import io
import json


with open(sys.argv[1], 'r', encoding='utf8') as reader, \
     open(sys.argv[2], 'w', encoding='utf8') as writer:
    for line in reader:
        items = line.strip().split('||')
        context = items[0]
        completion = items[1].strip('\n')
        x = {}
        x['context'] = context #+ '||'
        x['completion'] = completion
        writer.write(json.dumps(x)+'\n')
examples/NLG/src/format_converting_webnlg.py ADDED
@@ -0,0 +1,68 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import sys
import io
import json


with open(sys.argv[1], 'r', encoding='utf8') as reader, \
     open(sys.argv[2], 'w', encoding='utf8') as writer:
    lines_dict = json.load(reader)

    full_rela_lst = []
    full_src_lst = []
    full_tgt_lst = []
    full_cate_lst = []

    seen = [
        'Airport',
        'Astronaut',
        'Building',
        'City',
        'ComicsCharacter',
        'Food',
        'Monument',
        'SportsTeam',
        'University',
        'WrittenWork'
    ]

    cate_dict = {}
    for i, example in enumerate(lines_dict['entries']):
        sents = example[str(i+1)]['lexicalisations']
        triples = example[str(i + 1)]['modifiedtripleset']
        cate = example[str(i + 1)]['category']

        if cate not in cate_dict:
            cate_dict[cate] = 0
        cate_dict[cate] += 1

        rela_lst = []
        temp_triples = ''
        for i, tripleset in enumerate(triples):
            subj, rela, obj = tripleset['subject'], tripleset['property'], tripleset['object']
            rela_lst.append(rela)
            if i > 0:
                temp_triples += ' | '
            temp_triples += '{} : {} : {}'.format(subj, rela, obj)

        for sent in sents:
            if sent["comment"] == 'good':
                full_tgt_lst.append(sent['lex'])
                full_src_lst.append(temp_triples)
                full_rela_lst.append(rela_lst)
                full_cate_lst.append(cate)

    for cate in cate_dict:
        print('cate', cate, cate_dict[cate])

    #edited_sents = []
    for src, tgt, cate in zip(full_src_lst, full_tgt_lst, full_cate_lst):
        x = {}
        x['context'] = src # context #+ '||'
        x['completion'] = tgt #completion
        x['cate'] = cate in seen
        writer.write(json.dumps(x)+'\n')
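
Unlike the DART converter, each record here carries a boolean 'cate' flag marking whether the category was seen during training, which gpt2_decode.py's --filter option later uses. An illustrative output line (made-up data):

{"context": "Alan Shepard : occupation : test pilot", "completion": "Alan Shepard worked as a test pilot.", "cate": true}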
examples/NLG/src/gpt2_beam.py ADDED
@@ -0,0 +1,419 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------

# python -m torch.distributed.launch --nproc_per_node=1 src/gpt2_beam.py \
#     --data ./data/e2e/test.jsonl \
#     --batch_size 1 \
#     --seq_len 512 \
#     --eval_len 64 \
#     --model_card gpt2.md \
#     --platform local \
#     --beam 10 \
#     --length_penalty 0.8 \
#     --no_repeat_ngram_size 4 \
#     --repetition_penalty 1.0 \
#     --eos_token_id 628 \
#     --lora_dim 4 \
#     --lora_alpha 32 \
#     --work_dir ./trained_models/GPT2_M/e2e \
#     --output_file predict.26290.jsonl \
#     --init_checkpoint ./trained_models/GPT2_M/e2e/model.26290.pt


import argparse
import time
import math
import os, sys
import json
import itertools
from typing import Callable, Dict, Iterable, List, Optional, Tuple

import torch
from torch import Tensor, device, dtype, nn
from torch.nn import CrossEntropyLoss
from torch.nn import functional as F
from torch.utils.data import DataLoader
import torch.nn.functional as F
torch.set_printoptions(threshold=100000)

import numpy as np

from gpu import (
    add_gpu_params,
    parse_gpu,
    distributed_opt,
    distributed_gather,
    distributed_sync,
    cleanup
)

from exp_utils import create_exp_dir

from data_utils import FT_Dataset
from model import GPT2Config, GPT2LMModel


parser = argparse.ArgumentParser(description='PyTorch GPT2 beam decoding')

add_gpu_params(parser)

parser.add_argument('--data', type=str, default='../data/wikitext-103',
                    help='location of the data corpus')

parser.add_argument('--batch_size', type=int, default=10,
                    help='batch size')

parser.add_argument('--seq_len', type=int, default=512,
                    help='number of tokens to predict')

parser.add_argument('--eval_len', type=int, default=256,
                    help='evaluation length')

parser.add_argument('--min_length', type=int, default=0,
                    help='minimum generation length')

parser.add_argument('--model_card', default='gpt2.sm', choices=['gpt2.sm', 'gpt2.md', 'gpt2.lg'],
                    help='model names')

parser.add_argument('--init_checkpoint', default=None, type=str, help='initial checkpoint')

parser.add_argument('--lora_dim', type=int, default=0, help='lora attn dimension')

parser.add_argument('--lora_alpha', type=int, default=128, help='lora attn alpha')

parser.add_argument('--work_dir', type=str, default=os.getenv('PT_OUTPUT_DIR', 'gpt2_model'),
                    help='working folder')

parser.add_argument('--beam', type=int, default=1, help='beam search size')

parser.add_argument('--length_penalty', type=float, default=1.0, help='length penalty')

parser.add_argument('--no_repeat_ngram_size', type=int, default=4, help='no_repeat_ngram_size')

parser.add_argument('--repetition_penalty', type=float, default=1.0, help='repetition_penalty')

parser.add_argument('--eos_token_id', action='append', type=int, default=[50256],
                    help='eos token id')

parser.add_argument('--output_file', type=str, default='beam_prediction.jsonl',
                    help='output file name')


def print_args(args):
    if args.rank == 0:
        print('=' * 100)
        for k, v in args.__dict__.items():
            print('    - {} : {}'.format(k, v))
        print('=' * 100)


def _reorder_cache(past: Tuple, beam_idx: Tensor) -> Tuple[Tensor, ...]:
    return tuple(layer_past.index_select(1, beam_idx).contiguous().detach() for layer_past in past)


def _calc_banned_ngram_tokens(
    prev_input_ids: Tensor,
    num_hypos: int,
    no_repeat_ngram_size: int,
    cur_len: int
) -> List[List[int]]:
    """Copied from fairseq for no_repeat_ngram in beam_search"""
    if cur_len + 1 < no_repeat_ngram_size:
        # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
        return [[] for _ in range(num_hypos)]

    generated_ngrams = [{} for _ in range(num_hypos)]
    for idx in range(num_hypos):
        gen_tokens = prev_input_ids[idx].tolist()
        generated_ngram = generated_ngrams[idx]
        for ngram in zip(*[gen_tokens[i:] for i in range(no_repeat_ngram_size)]):
            prev_ngram_tuple = tuple(ngram[:-1])
            generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]

    def _get_generated_ngrams(hypo_idx):
        # Before decoding the next token, prevent decoding of ngrams that have already appeared
        start_idx = cur_len + 1 - no_repeat_ngram_size
        ngram_idx = tuple(prev_input_ids[hypo_idx, start_idx:cur_len].tolist())
        return generated_ngrams[hypo_idx].get(ngram_idx, [])

    banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)]
    return banned_tokens


def _enforce_repetition_penalty_(
    lprobs,
    batch_size,
    num_beams,
    prev_output_tokens,
    repetition_penalty
):
    """repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858). """

    for i in range(batch_size * num_beams):
        print('prev_output_tokens.shape', prev_output_tokens.shape)
        print('prev_output_tokens[i].shape', prev_output_tokens[i].shape)

        for previous_token in set(prev_output_tokens[i].tolist()):
            # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
            if lprobs[i, previous_token] < 0:
                lprobs[i, previous_token] *= repetition_penalty
            else:
                lprobs[i, previous_token] /= repetition_penalty


def _postprocess_next_token_scores(
    scores,
    history,
    cur_len,
    batch_size,
    num_beams,
    repetition_penalty=1.0,
    no_repeat_ngram_size=4,
    bad_words_ids=None,
    min_length=0,
    max_length=100,
    eos_token_id=None,
):
    # repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
    if repetition_penalty != 1.0 and history is not None:
        _enforce_repetition_penalty_(scores, batch_size, num_beams, history, repetition_penalty)

    # score: batch_size * beam, vocab
    # set eos token prob to zero if min_length is not reached
    if eos_token_id is not None and cur_len < min_length:
        for eos in eos_token_id:
            scores[:, eos] = -float("inf")

    if no_repeat_ngram_size > 0 and history is not None:
        # calculate a list of banned tokens to prevent repetitively generating the same ngrams
        num_batch_hypotheses = batch_size * num_beams
        # from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
        banned_batch_tokens = _calc_banned_ngram_tokens(
            history, num_batch_hypotheses, no_repeat_ngram_size, cur_len
        )

        for i, banned_tokens in enumerate(banned_batch_tokens):
            scores[i, banned_tokens] = -float("inf")

    return scores


def _add_beam_candidate(
    best_score,
    best_sequence,
    batch_size,
    num_beams,
    beam_scores,
    history,
    eos_token_id=None
):
    last_tokens = history[:, -1]
    for _i in range(batch_size * num_beams):
        if eos_token_id is None or last_tokens[_i] in eos_token_id:
            cur_len = history.shape[-1]
            # note: reads the module-level args for length_penalty
            _score = beam_scores.view(-1)[_i] / cur_len ** args.length_penalty

            batch_id = _i // num_beams

            if not batch_id in best_score or best_score[batch_id] < _score:
                best_score[batch_id] = _score
                best_sequence[batch_id][:cur_len] = history[_i]

            beam_scores.view(-1)[_i] = -float("inf")


def beam(model, data_iter, args):
    model.eval()
    total_loss = 0.
    start_time = time.time()

    all_predictions = {}
    with torch.no_grad():
        for idx, data in enumerate(data_iter):
            data = {key: value for key, value in data.items()}

            _id = data['id'].to(args.device)
            _query = data['query'].to(args.device)
            _query_len = data['query_len'].to(args.device)

            ## local adaptation start.

            ## local adaptation end.


            output = None
            score = None

            batch_size = _id.size(0)
            num_beams = args.beam
            length_penalty = args.length_penalty

            _batch = torch.arange(0, _id.size(0), device=args.device, dtype=torch.long)

            past = None
            len_past = None

            _query = _query.repeat(1, num_beams).view(batch_size * num_beams, -1)
            _query_len = _query_len.unsqueeze(-1).repeat(1, num_beams).view(-1)

            _bbatch = _batch.unsqueeze(-1).repeat(1, num_beams).view(-1)

            # scores for each sentence in the beam
            beam_scores = torch.zeros(
                (batch_size, num_beams), dtype=torch.float, device=_query.device
            )

            best_sequence = torch.zeros(
                (batch_size, args.eval_len), dtype=torch.long, device=_query.device
            )
            best_score = {}

            history = None
            with torch.no_grad():
                for i in range(0, args.eval_len):
                    if i == 0:
                        logits, past = model(_query)
                        logits = logits[_bbatch, (_query_len-1).long(), :] # batch_size * beam, vocab
                    else:
                        #print('token_id.shape', token_id.shape, token_id)
                        #print('past.shape', past[0].shape)
                        #print('len_past.shape', len_past.shape, len_past)

                        logits, past = model(token_id, past=past, len_past=len_past)
                        logits = logits[:, -1, :]    # batch_size * beam, vocab

                    logits = _postprocess_next_token_scores(
                        logits,
                        history,
                        i,
                        batch_size,
                        num_beams,
                        repetition_penalty=args.repetition_penalty,
                        no_repeat_ngram_size=args.no_repeat_ngram_size,
                        min_length=args.min_length,
                        eos_token_id=args.eos_token_id,
                    )

                    softmax_probs = F.softmax(logits, dim=-1)
                    ##_prob, _w_idx = torch.topk(softmax_probs, num_beams) # batch_size, beam

                    vocab_size = softmax_probs.shape[-1]


                    _logprob = torch.log(softmax_probs) # batch_size * beam, vocab
                    if i == 0:
                        next_scores = _logprob.view(batch_size, num_beams, -1)[:, 0, :] # batch_size, vocab

                    else:
                        next_scores = beam_scores.unsqueeze(-1) + _logprob.view(batch_size, num_beams, -1)
                        next_scores = next_scores.view(batch_size, -1) # batch_size, beam * vocab

                    next_scores, next_tokens = torch.topk(
                        next_scores, num_beams, dim=1, largest=True, sorted=True
                    )     # batch_size, num_beams

                    beam_id = (next_tokens // vocab_size).view(-1)    # batch_size * num_beams
                    token_id = (next_tokens % vocab_size).view(-1).unsqueeze(-1) # batch_size, num_beams

                    beam_idx = beam_id.view(batch_size, num_beams) + (_batch * num_beams).unsqueeze(-1)
                    past = _reorder_cache(past, beam_idx.view(-1))
                    beam_scores = next_scores # batch_size, num_beams
                    len_past = (_query_len + i).long()

                    if history is None:
                        history = token_id.detach()
                    else:
                        history = torch.cat((history[beam_idx.view(-1)], token_id.detach()), dim=1).detach()

                    _add_beam_candidate(
                        best_score, best_sequence, batch_size, num_beams, beam_scores, history,
                        eos_token_id=args.eos_token_id
                    )

                _add_beam_candidate(
                    best_score, best_sequence, batch_size, num_beams, beam_scores, history
                )


            with torch.no_grad():
                _id = distributed_gather(args, _id)
                output = distributed_gather(args, best_sequence)
                #score = distributed_gather(args, score)
                distributed_sync(args)

            if args.rank == 0:
                _id = _id.view(-1).cpu()
                output = output.view(-1, output.shape[-1]).cpu()
                #score = score.view(-1, score.shape[-1]).cpu()

                for _b in range(0, _id.shape[-1]):
                    _i = int(_id[_b].item())
                    all_predictions[_i] = {}
                    all_predictions[_i]['id'] = _i
                    all_predictions[_i]['predict'] = output[_b].tolist()
                    #all_predictions[_i]['score'] = score[_b].tolist()

                if idx % 10 == 0:
                    print('inference samples', idx)
                # pred_file = os.path.join(args.work_dir, args.output_file)
                # print('saving prediction file', pred_file)
                # with open(pred_file, 'w') as writer:
                #     for _i in all_predictions:
                #         writer.write(json.dumps(all_predictions[_i]) + '\n')

    if args.rank == 0:
        pred_file = os.path.join(args.work_dir, args.output_file)
        print('saving prediction file', pred_file)
        with open(pred_file, 'w') as writer:
            for _i in all_predictions:
                writer.write(json.dumps(all_predictions[_i]) + '\n')


if __name__ == '__main__':
    args = parser.parse_args()
    parse_gpu(args)
    print_args(args)

    if args.rank == 0:
        args.logging = create_exp_dir(args.work_dir)

    valid_data = FT_Dataset(
        args.data, args.batch_size, args.seq_len, args.eval_len,
    )
    valid_data = valid_data.get_item_list(0, 1000)
    valid_sampler = torch.utils.data.distributed.DistributedSampler(valid_data)
    valid_loader = DataLoader(
        valid_data, batch_size=args.batch_size, num_workers=0, shuffle=False,
        pin_memory=False, drop_last=False, sampler=valid_sampler
    )

    if args.model_card == 'gpt2.sm':
        config = GPT2Config(
            n_embd=768, n_layer=12, n_head=12,
            lora_attn_dim=args.lora_dim, lora_attn_alpha=args.lora_alpha,
        )
    elif args.model_card == 'gpt2.md':
        config = GPT2Config(
            n_embd=1024, n_layer=24, n_head=16,
            lora_attn_dim=args.lora_dim, lora_attn_alpha=args.lora_alpha,
        )
    elif args.model_card == 'gpt2.lg':
        config = GPT2Config(
            n_embd=1280, n_layer=36, n_head=20,
            lora_attn_dim=args.lora_dim, lora_attn_alpha=args.lora_alpha,
        )

    lm_net = GPT2LMModel(config)
    if args.init_checkpoint is not None:
        print('loading model pretrained weight.')
        cp = torch.load(args.init_checkpoint, map_location=torch.device('cpu'))
        lm_net.load_weight(cp)
    lm_net = lm_net.cuda()
    print(lm_net.transformer.h[0].mlp)

    print('model sampling ...')
    beam(lm_net, valid_loader, args)
    distributed_sync(args)
    print('cleanup dist ...')
    cleanup(args)
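
One step in beam() above that is easy to misread is the flattened top-k bookkeeping: each flat index into the (batch, beam * vocab) score matrix encodes a (beam, token) pair, which the integer division and modulo recover. A toy check, with made-up sizes:

import torch

batch_size, num_beams, vocab_size = 2, 3, 5            # toy values
scores = torch.randn(batch_size, num_beams * vocab_size)
top_scores, flat_idx = torch.topk(scores, num_beams, dim=1)

beam_id = flat_idx // vocab_size     # which beam each survivor extends
token_id = flat_idx % vocab_size     # which vocabulary item it appends
assert torch.equal(beam_id * vocab_size + token_id, flat_idx)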
examples/NLG/src/gpt2_decode.py
ADDED
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ------------------------------------------------------------------------------------------
|
2 |
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
3 |
+
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
4 |
+
# ------------------------------------------------------------------------------------------
|
5 |
+
|
6 |
+
# python -m torch.distributed.launch --nproc_per_node=1 src/gpt2_beam.py \
|
7 |
+
# --data ./data/e2e/test.jsonl \
|
8 |
+
# --batch_size 1 \
|
9 |
+
# --seq_len 512 \
|
10 |
+
# --eval_len 64 \
|
11 |
+
# --model_card gpt2.md \
|
12 |
+
# --platform local \
|
13 |
+
# --beam 10 \
|
14 |
+
# --length_penalty 0.8 \
|
15 |
+
# --no_repeat_ngram_size 4 \
|
16 |
+
# --repetition_penalty 1.0 \
|
17 |
+
# --eos_token_id 628 \
|
18 |
+
# --lora_dim 4 \
|
19 |
+
# --lora_alpha 32 \
|
20 |
+
# --work_dir ./trained_models/GPT2_M/e2e \
|
21 |
+
# --output_file predict.26290.jsonl \
|
22 |
+
# --init_checkpoint ./trained_models/GPT2_M/e2e/model.26290.pt
|
23 |
+
|
24 |
+
|
25 |
+
import json
|
26 |
+
import numpy as np
|
27 |
+
import argparse
|
28 |
+
import os
|
29 |
+
import sys
|
30 |
+
import re
|
31 |
+
import json
|
32 |
+
|
33 |
+
import torch
|
34 |
+
import torch.nn as nn
|
35 |
+
import torch.nn.parallel
|
36 |
+
import torch.backends.cudnn as cudnn
|
37 |
+
import torch.optim as optim
|
38 |
+
import torch.utils.data
|
39 |
+
|
40 |
+
import encoder
|
41 |
+
|
42 |
+
|
43 |
+
parser = argparse.ArgumentParser()
|
44 |
+
|
45 |
+
parser.add_argument('--vocab', type=str, default=None, help='vocab path')
|
46 |
+
|
47 |
+
parser.add_argument('--sample_file', default=None, type=str, help='ft sample file')
|
48 |
+
parser.add_argument('--input_file', default=None, type=str, help='ft input file')
|
49 |
+
|
50 |
+
+parser.add_argument('--output_ref_file', default=None, type=str, help='output reference file')
+parser.add_argument('--output_pred_file', default=None, type=str, help='output prediction file')
+
+parser.add_argument('--ref_unique_file', default=None, type=str, help='reference unique id file')
+
+parser.add_argument('--ref_type', default='e2e', choices=['e2e', 'webnlg', 'dart'],
+                    help='reference file format: e2e, webnlg, or dart style')
+parser.add_argument('--ref_num', default=4, type=int, help='number of references.')
+
+
+parser.add_argument('--tokenize', action='store_true', help='whitespace-tokenize the outputs before writing')
+parser.add_argument('--lower', action='store_true', help='lowercase the outputs before writing')
+
+parser.add_argument('--filter', default='all', choices=['all', 'seen', 'unseen'],
+                    help='for webnlg only, filter categories that are seen during training, unseen, or all')
+
+args = parser.parse_args()
+
+
+def standard_tokenize(sent):
+    sent = ' '.join(re.split(r'(\W)', sent))
+    sent = sent.split()
+    sent = ' '.join(sent)
+    return sent
+
+
+def post_process(sent, is_tokenize, is_lower):
+    if is_lower:
+        sent = sent.lower()
+    if is_tokenize:
+        sent = standard_tokenize(sent)
+
+    return sent
+
+
+if __name__ == "__main__":
+    enc = encoder.get_encoder(args.vocab)
+
+    ref_unique = None
+
+    if args.ref_unique_file is not None:
+        print('reading ref_unique_file.')
+        ref_unique = []
+        uniques = {}
+        with open(args.ref_unique_file, 'r') as ref_unique_reader:
+            for line in ref_unique_reader:
+                _id = int(line.strip())
+                ref_unique.append(_id)
+                uniques[_id] = 1
+        print('len refer dict', len(ref_unique), 'unique', len(uniques))
+
+    with open(args.sample_file, 'r') as sample_reader, \
+         open(args.input_file, 'r', encoding='utf8') as input_reader, \
+         open(args.output_pred_file, 'w', encoding='utf8') as pred_writer:
+
+        refer_dict = {}
+        context_list = []
+        line_id = 0
+        for line in input_reader:
+            items = json.loads(line.strip())
+            context = items['context']
+            completion = items['completion']
+
+            context_list.append(context)
+
+            keep = False
+
+            if args.filter == 'all':
+                keep = True
+            if args.filter == 'seen' and items['cate']:
+                keep = True
+            if args.filter == 'unseen' and not items['cate']:
+                keep = True
+
+            if ref_unique is None:
+                _key = context
+            else:
+                _key = ref_unique[line_id]
+
+            if keep:
+                if _key not in refer_dict:
+                    refer_dict[_key] = {}
+                    refer_dict[_key]['references'] = []
+                refer_dict[_key]['references'].append(completion.split('<|endoftext|>')[0].split('\n\n')[0].strip())
+
+            line_id += 1
+            if line_id == 1000:
+                break
+
+        print('unique refer dict', len(refer_dict))
+
+        for line in sample_reader:
+            items = json.loads(line.strip())
+            _id = items['id']
+            _pred_tokens = items['predict']
+
+            if ref_unique is None:
+                _key = context_list[_id]
+            else:
+                _key = ref_unique[_id]
+
+            #assert _key in refer_dict
+            # if _key in refer_dict:
+            if _key not in refer_dict:
+                refer_dict[_key] = {}
+                refer_dict[_key]['sample'] = []
+            refer_dict[_key]['sample'] = enc.decode(_pred_tokens).split('<|endoftext|>')[0].split('\n\n')[0].strip()
+
+        references = [refer_dict[s]['references'] for s in refer_dict]
+        hypothesis = [refer_dict[s]['sample'] for s in refer_dict]
+
+        if args.ref_type == 'e2e':
+            with open(args.output_ref_file, 'w', encoding='utf8') as ref_writer:
+                for ref, hyp in zip(references, hypothesis):
+                    for r in ref:
+                        ref_writer.write(post_process(r, args.tokenize, args.lower) + '\n')
+                    ref_writer.write('\n')
+                    pred_writer.write(post_process(hyp, args.tokenize, args.lower) + '\n')
+
+        elif args.ref_type in ['webnlg', 'dart']:
+            if not os.path.exists(args.output_ref_file):
+                os.makedirs(args.output_ref_file)
+
+            reference_writers = [
+                open(os.path.join(args.output_ref_file, f'reference{fid}'), 'w', encoding='utf8')
+                for fid in range(0, args.ref_num)
+            ]
+
+            for ref, hyp in zip(references, hypothesis):
+                for fid in range(0, args.ref_num):
+                    if len(ref) > fid:
+                        reference_writers[fid].write(post_process(ref[fid], args.tokenize, args.lower) + '\n')
+                    else:
+                        reference_writers[fid].write(post_process(ref[0], args.tokenize, args.lower) + '\n')
+                pred_writer.write(post_process(hyp, args.tokenize, args.lower) + '\n')
+
+            for writer in reference_writers:
+                writer.close()
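
The two helpers above (standard_tokenize and post_process) control how references and predictions are normalized before they are written out for scoring. A minimal illustration of their combined effect, using an invented E2E-style sentence (not actual dataset output):

import re

sent = "The Vaults is a French restaurant near the river."
sent = sent.lower()                        # what --lower does
sent = ' '.join(re.split(r'(\W)', sent))   # what --tokenize does: space out punctuation
sent = ' '.join(sent.split())              # collapse repeated whitespace
print(sent)  # -> "the vaults is a french restaurant near the river ."
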
examples/NLG/src/gpt2_encode.py
ADDED
@@ -0,0 +1,70 @@
+# ------------------------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
+# ------------------------------------------------------------------------------------------
+import json
+import numpy as np
+
+import encoder
+
+import argparse
+import os
+import random
+import torch
+import torch.nn as nn
+import torch.nn.parallel
+import torch.backends.cudnn as cudnn
+import torch.optim as optim
+import torch.utils.data
+
+import numpy
+import io
+import sys
+import threading
+import math
+import random
+
+import json
+import collections
+from collections import Counter
+from collections import OrderedDict
+from progress.bar import Bar as Bar
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--input', default=None, type=str, help='ft input file')
+parser.add_argument('--vocab', type=str, default=None, help='vocab path')
+parser.add_argument('--output', default=None, type=str, help='ft output file')
+parser.add_argument('--add_bos', action='store_true', help='append the <|endoftext|> id (50256) after the context')
+parser.add_argument('--add_eos', action='store_true', help='append the <|endoftext|> id (50256) after the completion')
+args = parser.parse_args()
+
+
+if __name__ == "__main__":
+    enc = encoder.get_encoder(args.vocab)
+
+    writer = open(args.output, 'w')
+
+    with open(args.input, 'r') as reader:
+        line_idx = 0
+        for line in reader:
+            items = json.loads(line.strip())
+            context = items['context']
+            completion = items['completion']
+
+            bos = 50256
+            eos = 50256
+            context_bpes, _ = enc.encode(context)
+            context_bpes += [bos] if args.add_bos else []
+
+            completion_bpes, _ = enc.encode(' ' + completion)
+            completion_bpes += [eos] if args.add_eos else []
+
+            ft_json = {}
+            ft_json['context'] = context_bpes
+            ft_json['completion'] = completion_bpes
+            writer.write(json.dumps(ft_json) + '\n')
+
+            line_idx += 1
+
+    writer.close()
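
A minimal sketch of the per-line transformation gpt2_encode.py performs. The StubEncoder below stands in for encoder.get_encoder(args.vocab), and the token ids it emits are placeholders, not real GPT-2 BPE ids; only the JSONL shape and the bos/eos handling mirror the script above.

import json

class StubEncoder:
    def encode(self, text):
        # enc.encode returns (token_ids, tokens); fake sequential ids here
        toks = text.split()
        return list(range(100, 100 + len(toks))), toks

enc = StubEncoder()
bos = eos = 50256  # GPT-2 reuses the <|endoftext|> id for both, as in the script above
raw = {"context": "name : The Vaults | food : French", "completion": "The Vaults serves French food."}
context_bpes, _ = enc.encode(raw['context'])
completion_bpes, _ = enc.encode(' ' + raw['completion'])
print(json.dumps({'context': context_bpes + [bos], 'completion': completion_bpes + [eos]}))
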
examples/NLG/src/gpt2_ft.py
ADDED
@@ -0,0 +1,385 @@
+# ------------------------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
+# ------------------------------------------------------------------------------------------
+import argparse
+import time
+import math
+import os, sys
+import warnings
+import numpy as np
+import itertools
+
+import torch
+import random
+from torch.utils.data import DataLoader
+torch.set_printoptions(threshold=100000)
+
+from gpu import (
+    add_gpu_params,
+    parse_gpu,
+    distributed_opt,
+    distributed_gather,
+    distributed_sync,
+    cleanup
+)
+from optimizer import (
+    create_adam_optimizer,
+    create_optimizer_scheduler,
+    add_optimizer_params,
+    create_adam_optimizer_from_args
+)
+
+from data_utils import FT_Dataset
+from model import GPT2Config, GPT2LMModel
+from exp_utils import create_exp_dir
+
+import loralib as lora
+
+parser = argparse.ArgumentParser(description='PyTorch GPT2 ft script')
+
+add_gpu_params(parser)
+add_optimizer_params(parser)
+
+parser.add_argument('--train_data', required=True, help='location of training data corpus')
+
+parser.add_argument('--valid_data', required=True, help='location of validation data corpus')
+
+parser.add_argument('--train_batch_size', type=int, default=8, help='training batch size')
+
+parser.add_argument('--valid_batch_size', type=int, default=4, help='validation batch size')
+
+parser.add_argument('--grad_acc', type=int, default=1, help='gradient accumulation steps')
+
+parser.add_argument('--clip', type=float, default=0.0, help='gradient clip')
+
+parser.add_argument('--seq_len', type=int, default=512, help='number of tokens to predict.')
+
+parser.add_argument('--model_card', default='gpt2.md', choices=['gpt2.sm', 'gpt2.md', 'gpt2.lg'],
+                    help='model names')
+
+parser.add_argument('--init_checkpoint', default=None, help='pretrained checkpoint path')
+
+parser.add_argument('--fp16', action='store_true', help='train model with fp16')
+
+parser.add_argument('--log_interval', type=int, default=100, help='log interval')
+
+parser.add_argument('--eval_interval', type=int, default=2000, help='eval interval')
+
+parser.add_argument('--save_interval', type=int, default=500, help='save interval')
+
+parser.add_argument('--work_dir', type=str, default=os.getenv('PT_OUTPUT_DIR', 'gpt2_model'),
+                    help='working folder.')
+
+parser.add_argument('--lora_dim', type=int, default=0, help='lora attn dimension')
+
+parser.add_argument('--lora_alpha', type=int, default=128, help='lora attn alpha')
+
+parser.add_argument('--obj', default='clm', choices=['jlm', 'clm'],
+                    help='language model training objective')
+
+parser.add_argument('--lora_dropout', default=0.0, type=float,
+                    help='dropout probability for lora layers')
+
+parser.add_argument('--label_smooth', default=0.0, type=float, help='label smoothing')
+
+parser.add_argument('--roll_interval', type=int, default=-1, help='rolling interval')
+
+parser.add_argument('--roll_lr', type=float, default=0.00001, help='rolling learning rate')
+
+parser.add_argument('--roll_step', type=int, default=100, help='rolling step')
+
+parser.add_argument('--eval_epoch', type=int, default=1, help='eval per number of epochs')
+
+# influence model, calculate the influence score between two samples.
+def print_args(args):
+    if args.rank == 0:
+        print('=' * 100)
+        for k, v in args.__dict__.items():
+            print(f'    - {k} : {v}')
+        print('=' * 100)
+
+
+class AverageMeter(object):
+    """Computes and stores the average and current value
+       Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
+    """
+    def __init__(self):
+        self.reset()
+
+    def reset(self):
+        self.val = 0
+        self.avg = 0
+        self.sum = 0
+        self.count = 0
+
+    def update(self, val, n=1):
+        self.val = val
+        self.sum += val * n
+        self.count += n
+        self.avg = self.sum / self.count
+
+
+def optimizer_step(_loss, _optimizer, _model, _schedule, args, is_update=True):
+    if args.fp16:
+        with amp.scale_loss(_loss, _optimizer) as _scaled_loss:
+            _scaled_loss.backward()
+    else:
+        _loss.backward()
+
+    # for name, param in _model.named_parameters():
+    #     if param.requires_grad and param.grad is not None:
+    #         print(f"Parameter name: {name}")
+    #         print(f"Gradient value: {param.grad}")
+
+    if is_update:
+        if args.clip > 0:
+            if args.fp16:
+                torch.nn.utils.clip_grad_norm_(amp.master_params(_optimizer), args.clip)
+            else:
+                torch.nn.utils.clip_grad_norm_(_model.parameters(), args.clip)
+
+        _optimizer.step()
+        _optimizer.zero_grad()
+
+    if _schedule is not None:
+        _schedule.step()
+
+    # print(f"query[0].lora_B = {_model.module.transformer.h[0].attn.c_attn.lora_B}")
+
+
+def evaluate(model, valid_loader, args):
+    model.eval()
+    total_loss = 0.
+    start_time = time.time()
+
+    avg_lm_loss = AverageMeter()
+
+    with torch.no_grad():
+        for idx, data in enumerate(valid_loader):
+            data = {key: value for key, value in data.items()}
+
+            _input = data['input'].to(args.device)
+            _target = data['target'].to(args.device)
+            _msk = data['mask'].to(args.device)
+
+            _lm_logits, _loss = model(_input, lm_labels=_target, lm_mask=_msk)
+            loss = _loss.mean()
+            # print(f"logits={_lm_logits}, _loss={_loss}")
+
+            avg_lm_loss.update(loss.item())
+
+            if idx % 100 == 0:
+                print('eval samples:', idx, 'loss:', loss.float())
+
+        total_time = time.time() - start_time
+        print('average loss', avg_lm_loss.avg)
+    return avg_lm_loss.avg, math.exp(avg_lm_loss.avg)
+
+
+def train_validate(
+    model,
+    optimizer,
+    scheduler,
+    train_loader,
+    valid_loader,
+    args,
+    train_step=0,
+    epoch=0
+):
+    model.train()
+    avg_lm_loss = AverageMeter()
+    print('start to train the model................', epoch)
+    log_start_time = time.time()
+    best_val_ppl = None
+
+    # train_loader.sampler.set_epoch(epoch)
+
+    for idx, data in enumerate(train_loader):
+        data = {key: value for key, value in data.items()}
+
+        _input = data['input'].to(args.device)
+        _target = data['target'].to(args.device)
+        _msk = data['mask'].to(args.device)
+
+        _lm_logits, _lm_loss = model(
+            _input, lm_labels=_target, lm_mask=_msk, label_smooth=args.label_smooth
+        )
+        # print(_input[0])
+
+        _lm_loss = _lm_loss.mean()
+
+        train_step += 1
+        is_update = True if train_step % args.grad_acc == 0 else False
+        avg_lm_loss.update(_lm_loss.item())
+        optimizer_step(
+            _lm_loss/(args.grad_acc), optimizer, model, scheduler, args, is_update=is_update
+        )
+
+        if train_step % args.log_interval == 0:
+            print(f"_lm_loss = {_lm_loss}")
+            print(f"layer[0].lora_A = {model.module.transformer.h[0].attn.c_attn.lora_A[0,:100]}")
+            elapsed = time.time() - log_start_time
+            lr = optimizer.param_groups[0]['lr']
+            log_str = f'| epoch {epoch:3d} step {train_step:>8d} | {idx + 1:>6d} batches | ' \
+                      f'lr {lr:.3g} | ms/batch {elapsed * 1000 / args.log_interval:5.2f} | ' \
+                      f'loss {avg_lm_loss.val:5.2f} | avg loss {avg_lm_loss.avg:5.2f} | ' \
+                      f'ppl {math.exp(avg_lm_loss.avg):5.2f}'
+
+            if args.rank == 0:
+                print(log_str)
+            log_start_time = time.time()
+            avg_lm_loss.reset()
+
+        if train_step % args.save_interval == 0:
+            if args.rank == 0:
+                model_path = os.path.join(args.work_dir, f'model.{train_step}.pt')
+                print('saving checkpoint', model_path)
+                # intermediate checkpoints keep only the LoRA weights
+                torch.save({'model_state_dict': lora.lora_state_dict(model)}, model_path)
+            distributed_sync(args)
+
+        # evaluation interval
+        if train_step % args.eval_interval == 0:
+            eval_start_time = time.time()
+
+            valid_loss, valid_ppl = evaluate(model, valid_loader, args)
+
+            if best_val_ppl is None or valid_ppl < best_val_ppl:
+                best_val_ppl = valid_ppl
+
+            log_str = f'| Eval {train_step // args.eval_interval:3d} at step {train_step:>8d} | ' \
+                      f'time: {time.time() - eval_start_time:5.2f}s | valid loss {valid_loss:5.2f} | ' \
+                      f'valid ppl {valid_ppl:5.2f} | best ppl {best_val_ppl:5.2f} '
+
+            if args.rank == 0:
+                print('-' * 100)
+                print(log_str)
+                print('-' * 100)
+
+            model.train()
+            distributed_sync(args)
+
+        if train_step == args.max_step:
+            break
+
+    if args.rank == 0:
+        model_path = os.path.join(args.work_dir, f'model.{train_step}.pt')
+        print('saving checkpoint', model_path)
+        # the end-of-epoch checkpoint stores the full state dict (base + LoRA)
+        torch.save({'model_state_dict': model.state_dict()}, model_path)
+    distributed_sync(args)
+    return train_step
+
+
+if __name__ == '__main__':
+    args = parser.parse_args()
+    parse_gpu(args)
+    print_args(args)
+
+    if args.fp16:
+        try:
+            from apex import amp
+        except Exception as e:
+            warnings.warn('Could not import amp, apex may not be installed')
+
+    torch.manual_seed(args.random_seed)
+    random.seed(args.random_seed)
+
+    if args.rank == 0:
+        args.logging = create_exp_dir(args.work_dir)
+
+    train_data = FT_Dataset(
+        args.train_data, args.train_batch_size, args.seq_len,
+        joint_lm=args.obj=='jlm'
+    )
+
+    valid_data = FT_Dataset(
+        args.valid_data, args.valid_batch_size, args.seq_len,
+    )
+
+    train_loader = DataLoader(
+        train_data, batch_size=args.train_batch_size, num_workers=0,
+        shuffle=False, pin_memory=False, drop_last=True,
+        # sampler=torch.utils.data.distributed.DistributedSampler(train_data, seed=args.random_seed)
+    )
+
+    valid_loader = DataLoader(
+        valid_data, batch_size=args.valid_batch_size, num_workers=0,
+        shuffle=False, pin_memory=False, drop_last=False,
+        # sampler=torch.utils.data.distributed.DistributedSampler(valid_data, seed=args.random_seed)
+    )
+    print(f"train_loader={len(train_loader)}, train_data={len(train_data)}")
+    print(f"valid_loader={len(valid_loader)}, valid_data={len(valid_data)}")
+
+    if args.model_card == 'gpt2.sm':
+        config = GPT2Config(
+            n_embd=768, n_layer=12, n_head=12,
+            lora_attn_dim=args.lora_dim,
+            lora_attn_alpha=args.lora_alpha,
+            lora_dropout=args.lora_dropout,
+        )
+    elif args.model_card == 'gpt2.md':
+        config = GPT2Config(
+            n_embd=1024, n_layer=24, n_head=16,
+            lora_attn_dim=args.lora_dim,
+            lora_attn_alpha=args.lora_alpha,
+            lora_dropout=args.lora_dropout,
+        )
+    elif args.model_card == 'gpt2.lg':
+        config = GPT2Config(
+            n_embd=1280, n_layer=36, n_head=20,
+            lora_attn_dim=args.lora_dim,
+            lora_attn_alpha=args.lora_alpha,
+            lora_dropout=args.lora_dropout,
+        )
+
+    lm_net = GPT2LMModel(config)
+    if args.init_checkpoint is not None:
+        print('loading model pretrained weight.')
+        lm_net.load_weight(torch.load(args.init_checkpoint))
+
+    lm_net = lm_net.cuda()
+
+    if args.lora_dim > 0:
+        lora.mark_only_lora_as_trainable(lm_net)
+
+    print(lm_net)
+    print(lm_net.transformer.h[0].attn.c_attn.weight.shape)
+    print(lm_net.transformer.h[0].attn.c_attn.lora_A.shape)
+    print(lm_net.transformer.h[0].attn.c_attn.lora_B.shape)
+    config_dict = vars(config)
+    for param, value in config_dict.items():
+        print(f"{param}: {value}")
+    print(args)
+    optimizer = create_adam_optimizer_from_args(lm_net, args)
+    print("optimizer: " + str(optimizer))
+
+    if args.max_step is None:
+        args.max_step = (args.max_epoch * train_data.num_batches + args.world_size - 1) // args.world_size
+        print('set max_step:', args.max_step)
+        print('train_data.num_batches:', train_data.num_batches)
+
+    scheduler = create_optimizer_scheduler(optimizer, args)
+    if args.fp16:
+        lm_net, optimizer = amp.initialize(lm_net, optimizer, opt_level="O1")
+    lm_net, optimizer = distributed_opt(args, lm_net, optimizer, grad_acc=args.grad_acc)
+
+    try:
+        train_step = 0
+        for epoch in itertools.count(start=1):
+            train_step = train_validate(
+                lm_net, optimizer, scheduler, train_loader, valid_loader, args,
+                train_step=train_step, epoch=epoch
+            )
+
+            if train_step >= args.max_step or (args.max_epoch is not None and epoch >= args.max_epoch):
+                if args.rank == 0:
+                    print('-' * 100)
+                    print('End of training')
+                break
+    except KeyboardInterrupt:
+        if args.rank == 0:
+            print('-' * 100)
+            print('Exiting from training early')
+
+    distributed_sync(args)
+    print('cleanup dist ...')
+    cleanup(args)
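
Note the asymmetry in the saves above: intermediate checkpoints hold only lora.lora_state_dict(model) (the small lora_A/lora_B tensors), while the end-of-epoch save stores the full state dict. A minimal sketch of restoring such a LoRA-only checkpoint later; resume_lora and the paths are hypothetical, GPT2LMModel.load_weight is the helper used above, and strict=False plus the 'module.' stripping follow loralib's documented load pattern for checkpoints saved from a DDP-wrapped model:

import torch
import loralib as lora

def resume_lora(model, pretrained_path, lora_ckpt_path):
    # hypothetical helper: load base GPT-2 weights first, then overlay LoRA
    model.load_weight(torch.load(pretrained_path))
    ckpt = torch.load(lora_ckpt_path, map_location='cpu')
    # checkpoints saved from the DDP-wrapped model carry a 'module.' key prefix
    state = {k.replace('module.', '', 1): v for k, v in ckpt['model_state_dict'].items()}
    # strict=False because the checkpoint holds only the lora_A/lora_B entries
    model.load_state_dict(state, strict=False)
    lora.mark_only_lora_as_trainable(model)  # keep the base model frozen
    return model
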
examples/NLG/src/gpu.py
ADDED
@@ -0,0 +1,129 @@
+# ------------------------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
+# ------------------------------------------------------------------------------------------
+import argparse
+import time
+import math
+import os, sys
+import itertools
+
+import numpy as np
+
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import torch.distributed as dist
+
+
+gpu_offset = 4  # 0
+
+def add_gpu_params(parser: argparse.ArgumentParser):
+    parser.add_argument("--platform", default='k8s', type=str, help='platform cloud')
+    parser.add_argument("--local_rank", default=0, type=int, help='local rank')
+    parser.add_argument("--rank", default=0, type=int, help='rank')
+    parser.add_argument("--device", default=0, type=int, help='device')
+    parser.add_argument("--world_size", default=0, type=int, help='world size')
+    parser.add_argument("--random_seed", default=10, type=int, help='random seed')
+
+
+def distributed_opt(args, model, opt, grad_acc=1):
+    if args.platform == 'azure':
+        args.hvd.broadcast_parameters(model.state_dict(), root_rank=0)
+        opt = args.hvd.DistributedOptimizer(
+            opt, named_parameters=model.named_parameters(), backward_passes_per_step=grad_acc
+        )
+    elif args.platform == 'philly' or args.platform == 'k8s' or args.platform == 'local':
+        model = torch.nn.parallel.DistributedDataParallel(
+            model, device_ids=[args.local_rank+gpu_offset], output_device=args.local_rank+gpu_offset,  # change
+            find_unused_parameters=False, broadcast_buffers=False
+        )
+    return model, opt
+
+
+def distributed_gather(args, tensor):
+    g_y = [torch.zeros_like(tensor) for _ in range(args.world_size)]
+    torch.distributed.all_gather(g_y, tensor, async_op=False)
+    return torch.stack(g_y)
+
+
+def distributed_sync(args):
+    if args.platform == 'azure':
+        args.hvd.allreduce(torch.tensor(0), name='barrier')
+    else:
+        args.dist.barrier()
+
+
+def parse_gpu(args):
+    torch.manual_seed(args.random_seed)
+
+    if args.platform == 'local':
+        dist.init_process_group(backend='nccl')
+        local_rank = torch.distributed.get_rank()
+        torch.cuda.set_device(local_rank+gpu_offset)  # change
+        device = torch.device('cuda', local_rank+gpu_offset)  # change
+        args.rank = local_rank
+        args.device = device
+        args.world_size = torch.distributed.get_world_size()
+        args.dist = dist
+
+    elif args.platform == 'azure':
+        import horovod.torch as hvd
+        hvd.init()
+        print('azure hvd rank', hvd.rank(), 'local rank', hvd.local_rank())
+        local_rank = hvd.local_rank()
+        torch.cuda.set_device(local_rank)
+        device = torch.device('cuda', local_rank)
+        rank = hvd.rank()
+        world_size = hvd.size()
+
+        args.local_rank = local_rank
+        args.rank = rank
+        args.device = device
+        args.world_size = world_size
+        args.hvd = hvd
+
+    elif args.platform == 'philly':
+        local_rank = args.local_rank
+        torch.cuda.set_device(local_rank)
+        dist.init_process_group(backend='nccl')
+        rank = dist.get_rank()
+        world_size = torch.distributed.get_world_size()
+        device = torch.device('cuda', local_rank)
+
+        args.rank = rank
+        args.device = device
+        args.world_size = world_size
+        args.dist = dist
+    elif args.platform == 'k8s':
+        master_uri = f"tcp://{os.environ['MASTER_ADDR']}:{os.environ['MASTER_PORT']}"
+        local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
+        args.local_rank = local_rank
+        world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
+        world_rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
+        rank = world_rank
+        torch.cuda.set_device(local_rank)
+
+        dist.init_process_group(
+            backend='nccl',
+            init_method=master_uri,
+            world_size=world_size,
+            rank=world_rank,
+        )
+        device = torch.device("cuda", local_rank)
+        args.rank = rank
+        args.device = device
+        args.world_size = world_size
+        args.dist = dist
+    print(
+        'myrank:', args.rank,
+        'local_rank:', args.local_rank,
+        'device_count:', torch.cuda.device_count(),
+        'world_size:', args.world_size,
+        'device:', device
+    )
+
+
+def cleanup(args):
+    if args.platform == 'k8s' or args.platform == 'philly':
+        args.dist.destroy_process_group()
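
The hard-coded gpu_offset = 4 above (the lines marked with "# change" comments) shifts every local rank onto a later GPU, which is why the model.log below reports device: cuda:4 for rank 0. A minimal sketch of the resulting mapping, assuming a hypothetical 4-process job on an 8-GPU node:

gpu_offset = 4
for local_rank in range(4):  # hypothetical 4-process job
    print(f'local_rank {local_rank} -> cuda:{local_rank + gpu_offset}')
# local_rank 0 -> cuda:4, ..., local_rank 3 -> cuda:7
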
examples/NLG/src/model.log
ADDED
@@ -0,0 +1,698 @@
+myrank: 0 local_rank: 0 device_count: 8 world_size: 1 device: cuda:4
+====================================================================================================
+    - platform : local
+    - local_rank : 0
+    - rank : 0
+    - device : cuda:4
+    - world_size : 1
+    - random_seed : 110
+    - lr : 0.0002
+    - weight_decay : 0.01
+    - correct_bias : True
+    - adam_epislon : 1e-06
+    - no_decay_bias : False
+    - adam_beta1 : 0.9
+    - adam_beta2 : 0.999
+    - scheduler : linear
+    - max_step : None
+    - max_epoch : 5
+    - warmup_step : 500
+    - i_steps : 0
+    - i_lrs : 0.00025
+    - train_data : ./data/e2e/train.jsonl
+    - valid_data : ./data/e2e/valid.jsonl
+    - train_batch_size : 8
+    - valid_batch_size : 4
+    - grad_acc : 1
+    - clip : 0.0
+    - seq_len : 512
+    - model_card : gpt2.md
+    - init_checkpoint : ./pretrained_checkpoints/gpt2-medium-pytorch_model.bin
+    - fp16 : False
+    - log_interval : 100
+    - eval_interval : 2000
+    - save_interval : 1000
+    - work_dir : ./trained_models/GPT2_M/e2e
+    - lora_dim : 4
+    - lora_alpha : 32
+    - obj : clm
+    - lora_dropout : 0.1
+    - label_smooth : 0.1
+    - roll_interval : -1
+    - roll_lr : 1e-05
+    - roll_step : 100
+    - eval_epoch : 1
+    - dist : <module 'torch.distributed' from '/home/inc/miniconda3/envs/fedadp-new/lib/python3.7/site-packages/torch/distributed/__init__.py'>
+====================================================================================================
+Experiment dir : ./trained_models/GPT2_M/e2e
+train_loader=5258, train_data=42064
+valid_loader=1168, valid_data=4672
+scaling = 8.0
+loading model pretrained weight.
+GPT2LMModel(
+  (transformer): GPT2Model(
+    (wte): Embedding(50257, 1024)
+    (wpe): Embedding(1024, 1024)
+    (h): ModuleList(
+      (0): Block(
+        (ln_1): LayerNorm()
+        (attn): Attention(
+          (c_attn): MergedLinear(
+            in_features=1024, out_features=3072, bias=True
+            (lora_dropout): Dropout(p=0.1, inplace=False)
+          )
+          (c_proj): Conv1D()
+        )
+        (ln_2): LayerNorm()
+        (mlp): MLP(
+          (c_fc): Conv1D()
+          (c_proj): Conv1D()
+        )
+      )
+      (1-23): 23 x Block( same structure as (0) )
+    )
+    (ln_f): LayerNorm()
+  )
+  (lm_head): GPT2LMHead(
+    (decoder): Linear(in_features=1024, out_features=50257, bias=False)
+  )
+)
+vocab_size: 50257
+n_ctx: 1024
+n_positions: 1024
+n_embd: 1024
+n_layer: 24
+n_head: 16
+layer_norm_epsilon: 1e-05
+initializer_range: 0.02
+lora_attn_dim: 4
+lora_attn_alpha: 32
+lora_dropout: 0.1
+lora_r_dropout: 0.0
+fix_dropout: 0.0
+Namespace(adam_beta1=0.9, adam_beta2=0.999, adam_epislon=1e-06, clip=0.0, correct_bias=True, device=device(type='cuda', index=4), dist=<module 'torch.distributed' from '/home/inc/miniconda3/envs/fedadp-new/lib/python3.7/site-packages/torch/distributed/__init__.py'>, eval_epoch=1, eval_interval=2000, fp16=False, grad_acc=1, i_lrs='0.00025', i_steps='0', init_checkpoint='./pretrained_checkpoints/gpt2-medium-pytorch_model.bin', label_smooth=0.1, local_rank=0, log_interval=100, logging=functools.partial(<function logging at 0x7f90cac2ae60>, log_path='./trained_models/GPT2_M/e2e/log.txt'), lora_alpha=32, lora_dim=4, lora_dropout=0.1, lr=0.0002, max_epoch=5, max_step=None, model_card='gpt2.md', no_decay_bias=False, obj='clm', platform='local', random_seed=110, rank=0, roll_interval=-1, roll_lr=1e-05, roll_step=100, save_interval=1000, scheduler='linear', seq_len=512, train_batch_size=8, train_data='./data/e2e/train.jsonl', valid_batch_size=4, valid_data='./data/e2e/valid.jsonl', warmup_step=500, weight_decay=0.01, work_dir='./trained_models/GPT2_M/e2e', world_size=1)
+optimizer: AdamW (
+Parameter Group 0
+    betas: (0.9, 0.999)
+    correct_bias: True
+    eps: 1e-06
+    lr: 0.0002
+    weight_decay: 0.01
+)
+set max_step: 26290
+train_data.num_batches: 5258
+start to train the model................ 1
+/home/inc/Documents/fzh/python/LoRA-main/examples/NLG/src/optimizer.py:117: UserWarning: This overload of addcdiv_ is deprecated:
+	addcdiv_(Number value, Tensor tensor1, Tensor tensor2)
+Consider using one of the following signatures instead:
+	addcdiv_(Tensor tensor1, Tensor tensor2, *, Number value) (Triggered internally at ../torch/csrc/utils/python_arg_parser.cpp:1050.)
+  p.data.addcdiv_(-step_size, exp_avg, denom)
+
+
+| epoch 1 step 100 | 100 batches | lr 4e-05 | ms/batch 612.69 | loss 5.06 | avg loss 5.52 | ppl 250.72
+| epoch 1 step 200 | 200 batches | lr 8e-05 | ms/batch 608.52 | loss 3.21 | avg loss 3.70 | ppl 40.58
+| epoch 1 step 300 | 300 batches | lr 0.00012 | ms/batch 609.77 | loss 2.98 | avg loss 3.08 | ppl 21.74
+| epoch 1 step 400 | 400 batches | lr 0.00016 | ms/batch 610.18 | loss 3.11 | avg loss 2.98 | ppl 19.63
+| epoch 1 step 500 | 500 batches | lr 0.0002 | ms/batch 610.03 | loss 2.84 | avg loss 2.89 | ppl 18.03
+| epoch 1 step 600 | 600 batches | lr 0.000199 | ms/batch 608.84 | loss 2.77 | avg loss 2.83 | ppl 16.93
+| epoch 1 step 700 | 700 batches | lr 0.000198 | ms/batch 611.37 | loss 2.88 | avg loss 2.80 | ppl 16.37
+| epoch 1 step 800 | 800 batches | lr 0.000198 | ms/batch 611.10 | loss 2.48 | avg loss 2.76 | ppl 15.76
+| epoch 1 step 900 | 900 batches | lr 0.000197 | ms/batch 610.61 | loss 2.50 | avg loss 2.75 | ppl 15.59
+| epoch 1 step 1000 | 1000 batches | lr 0.000196 | ms/batch 610.44 | loss 3.19 | avg loss 2.77 | ppl 15.95
+saving checkpoint ./trained_models/GPT2_M/e2e/model.1000.pt
+| epoch 1 step 1100 | 1100 batches | lr 0.000195 | ms/batch 612.14 | loss 2.76 | avg loss 2.73 | ppl 15.41
+| epoch 1 step 1200 | 1200 batches | lr 0.000195 | ms/batch 608.16 | loss 3.02 | avg loss 2.76 | ppl 15.84
+| epoch 1 step 1300 | 1300 batches | lr 0.000194 | ms/batch 610.06 | loss 2.55 | avg loss 2.75 | ppl 15.62
+| epoch 1 step 1400 | 1400 batches | lr 0.000193 | ms/batch 609.24 | loss 2.35 | avg loss 2.70 | ppl 14.93
+| epoch 1 step 1500 | 1500 batches | lr 0.000192 | ms/batch 607.91 | loss 2.53 | avg loss 2.72 | ppl 15.24
+| epoch 1 step 1600 | 1600 batches | lr 0.000191 | ms/batch 608.62 | loss 2.53 | avg loss 2.67 | ppl 14.50
+| epoch 1 step 1700 | 1700 batches | lr 0.000191 | ms/batch 608.92 | loss 2.66 | avg loss 2.71 | ppl 14.99
+| epoch 1 step 1800 | 1800 batches | lr 0.00019 | ms/batch 608.44 | loss 2.55 | avg loss 2.69 | ppl 14.75
+| epoch 1 step 1900 | 1900 batches | lr 0.000189 | ms/batch 609.27 | loss 2.43 | avg loss 2.66 | ppl 14.31
+| epoch 1 step 2000 | 2000 batches | lr 0.000188 | ms/batch 607.05 | loss 2.71 | avg loss 2.66 | ppl 14.36
+saving checkpoint ./trained_models/GPT2_M/e2e/model.2000.pt
+/home/inc/miniconda3/envs/fedadp-new/lib/python3.7/site-packages/torch/nn/_reduction.py:42: UserWarning: size_average and reduce args will be deprecated, please use reduction='none' instead.
+  warnings.warn(warning.format(ret))
+eval samples: 0 loss: tensor(1.1374, device='cuda:4')
+eval samples: 100 loss: tensor(1.0985, device='cuda:4')
+eval samples: 200 loss: tensor(1.2215, device='cuda:4')
+eval samples: 300 loss: tensor(1.2918, device='cuda:4')
+eval samples: 400 loss: tensor(1.6716, device='cuda:4')
+eval samples: 500 loss: tensor(1.9854, device='cuda:4')
+eval samples: 600 loss: tensor(1.2216, device='cuda:4')
+eval samples: 700 loss: tensor(1.0347, device='cuda:4')
+eval samples: 800 loss: tensor(1.5289, device='cuda:4')
+eval samples: 900 loss: tensor(1.5743, device='cuda:4')
+eval samples: 1000 loss: tensor(1.3339, device='cuda:4')
+eval samples: 1100 loss: tensor(1.3198, device='cuda:4')
+average loss 1.3344345796496084
+----------------------------------------------------------------------------------------------------
+| Eval 1 at step 2000 | time: 137.89s | valid loss 1.33 | valid ppl 3.80 | best ppl 3.80
+----------------------------------------------------------------------------------------------------
+| epoch 1 step 2100 | 2100 batches | lr 0.000188 | ms/batch 1988.14 | loss 2.64 | avg loss 2.68 | ppl 14.57
+| epoch 1 step 2200 | 2200 batches | lr 0.000187 | ms/batch 608.77 | loss 2.45 | avg loss 2.66 | ppl 14.34
+| epoch 1 step 2300 | 2300 batches | lr 0.000186 | ms/batch 610.52 | loss 2.60 | avg loss 2.67 | ppl 14.38
+| epoch 1 step 2400 | 2400 batches | lr 0.000185 | ms/batch 608.14 | loss 2.70 | avg loss 2.67 | ppl 14.49
+| epoch 1 step 2500 | 2500 batches | lr 0.000184 | ms/batch 607.87 | loss 2.52 | avg loss 2.64 | ppl 14.05
+| epoch 1 step 2600 | 2600 batches | lr 0.000184 | ms/batch 608.44 | loss 2.54 | avg loss 2.70 | ppl 14.85
+| epoch 1 step 2700 | 2700 batches | lr 0.000183 | ms/batch 608.49 | loss 2.87 | avg loss 2.69 | ppl 14.72
+| epoch 1 step 2800 | 2800 batches | lr 0.000182 | ms/batch 608.82 | loss 2.44 | avg loss 2.66 | ppl 14.26
+| epoch 1 step 2900 | 2900 batches | lr 0.000181 | ms/batch 609.19 | loss 2.69 | avg loss 2.68 | ppl 14.52
+| epoch 1 step 3000 | 3000 batches | lr 0.000181 | ms/batch 609.05 | loss 2.73 | avg loss 2.64 | ppl 13.99
+saving checkpoint ./trained_models/GPT2_M/e2e/model.3000.pt
+| epoch 1 step 3100 | 3100 batches | lr 0.00018 | ms/batch 609.17 | loss 2.63 | avg loss 2.64 | ppl 14.04
+| epoch 1 step 3200 | 3200 batches | lr 0.000179 | ms/batch 609.50 | loss 2.57 | avg loss 2.66 | ppl 14.28
+| epoch 1 step 3300 | 3300 batches | lr 0.000178 | ms/batch 607.31 | loss 2.47 | avg loss 2.62 | ppl 13.76
+| epoch 1 step 3400 | 3400 batches | lr 0.000178 | ms/batch 604.83 | loss 2.54 | avg loss 2.60 | ppl 13.49
+| epoch 1 step 3500 | 3500 batches | lr 0.000177 | ms/batch 607.92 | loss 2.62 | avg loss 2.63 | ppl 13.90
+| epoch 1 step 3600 | 3600 batches | lr 0.000176 | ms/batch 608.49 | loss 2.41 | avg loss 2.62 | ppl 13.78
+| epoch 1 step 3700 | 3700 batches | lr 0.000175 | ms/batch 605.91 | loss 2.58 | avg loss 2.59 | ppl 13.36
+| epoch 1 step 3800 | 3800 batches | lr 0.000174 | ms/batch 607.54 | loss 2.46 | avg loss 2.64 | ppl 13.97
+| epoch 1 step 3900 | 3900 batches | lr 0.000174 | ms/batch 610.01 | loss 2.68 | avg loss 2.66 | ppl 14.24
+| epoch 1 step 4000 | 4000 batches | lr 0.000173 | ms/batch 607.98 | loss 2.78 | avg loss 2.64 | ppl 14.04
+saving checkpoint ./trained_models/GPT2_M/e2e/model.4000.pt
+eval samples: 0 loss: tensor(1.1133, device='cuda:4')
+eval samples: 100 loss: tensor(1.0210, device='cuda:4')
+eval samples: 200 loss: tensor(1.1742, device='cuda:4')
+eval samples: 300 loss: tensor(1.2072, device='cuda:4')
+eval samples: 400 loss: tensor(1.6256, device='cuda:4')
+eval samples: 500 loss: tensor(1.9378, device='cuda:4')
+eval samples: 600 loss: tensor(1.0971, device='cuda:4')
+eval samples: 700 loss: tensor(1.0210, device='cuda:4')
+eval samples: 800 loss: tensor(1.4538, device='cuda:4')
+eval samples: 900 loss: tensor(1.5298, device='cuda:4')
+eval samples: 1000 loss: tensor(1.2354, device='cuda:4')
+eval samples: 1100 loss: tensor(1.2567, device='cuda:4')
+average loss 1.2714025441506138
+----------------------------------------------------------------------------------------------------
+| Eval 2 at step 4000 | time: 138.19s | valid loss 1.27 | valid ppl 3.57 | best ppl 3.57
+----------------------------------------------------------------------------------------------------
+| epoch 1 step 4100 | 4100 batches | lr 0.000172 | ms/batch 1990.32 | loss 2.81 | avg loss 2.62 | ppl 13.78
+| epoch 1 step 4200 | 4200 batches | lr 0.000171 | ms/batch 608.76 | loss 3.11 | avg loss 2.61 | ppl 13.57
+| epoch 1 step 4300 | 4300 batches | lr 0.000171 | ms/batch 610.45 | loss 2.46 | avg loss 2.61 | ppl 13.63
+| epoch 1 step 4400 | 4400 batches | lr 0.00017 | ms/batch 610.84 | loss 2.96 | avg loss 2.62 | ppl 13.74
+| epoch 1 step 4500 | 4500 batches | lr 0.000169 | ms/batch 611.36 | loss 2.78 | avg loss 2.61 | ppl 13.58
+| epoch 1 step 4600 | 4600 batches | lr 0.000168 | ms/batch 612.08 | loss 2.81 | avg loss 2.57 | ppl 13.07
+| epoch 1 step 4700 | 4700 batches | lr 0.000167 | ms/batch 615.36 | loss 2.90 | avg loss 2.63 | ppl 13.91
+| epoch 1 step 4800 | 4800 batches | lr 0.000167 | ms/batch 611.17 | loss 2.99 | avg loss 2.61 | ppl 13.55
+| epoch 1 step 4900 | 4900 batches | lr 0.000166 | ms/batch 608.81 | loss 2.73 | avg loss 2.60 | ppl 13.47
+| epoch 1 step 5000 | 5000 batches | lr 0.000165 | ms/batch 609.73 | loss 2.50 | avg loss 2.58 | ppl 13.26
+saving checkpoint ./trained_models/GPT2_M/e2e/model.5000.pt
+| epoch 1 step 5100 | 5100 batches | lr 0.000164 | ms/batch 609.36 | loss 2.27 | avg loss 2.59 | ppl 13.33
+| epoch 1 step 5200 | 5200 batches | lr 0.000164 | ms/batch 611.66 | loss 2.39 | avg loss 2.62 | ppl 13.78
+saving checkpoint ./trained_models/GPT2_M/e2e/model.5258.pt
+start to train the model................ 2
+| epoch 2 step 5300 | 42 batches | lr 0.000163 | ms/batch 256.06 | loss 2.41 | avg loss 2.61 | ppl 13.53
+| epoch 2 step 5400 | 142 batches | lr 0.000162 | ms/batch 609.01 | loss 2.63 | avg loss 2.61 | ppl 13.58
+| epoch 2 step 5500 | 242 batches | lr 0.000161 | ms/batch 612.10 | loss 2.45 | avg loss 2.59 | ppl 13.30
+| epoch 2 step 5600 | 342 batches | lr 0.00016 | ms/batch 611.07 | loss 2.67 | avg loss 2.59 | ppl 13.27
+| epoch 2 step 5700 | 442 batches | lr 0.00016 | ms/batch 611.19 | loss 2.52 | avg loss 2.64 | ppl 13.95
+| epoch 2 step 5800 | 542 batches | lr 0.000159 | ms/batch 611.61 | loss 2.87 | avg loss 2.57 | ppl 13.10
+| epoch 2 step 5900 | 642 batches | lr 0.000158 | ms/batch 612.67 | loss 3.17 | avg loss 2.58 | ppl 13.25
+| epoch 2 step 6000 | 742 batches | lr 0.000157 | ms/batch 610.88 | loss 2.45 | avg loss 2.59 | ppl 13.32
+saving checkpoint ./trained_models/GPT2_M/e2e/model.6000.pt
+eval samples: 0 loss: tensor(1.0454, device='cuda:4')
+eval samples: 100 loss: tensor(0.9909, device='cuda:4')
+eval samples: 200 loss: tensor(1.1352, device='cuda:4')
+eval samples: 300 loss: tensor(1.1335, device='cuda:4')
+eval samples: 400 loss: tensor(1.5766, device='cuda:4')
+eval samples: 500 loss: tensor(2.0034, device='cuda:4')
+eval samples: 600 loss: tensor(1.1043, device='cuda:4')
+eval samples: 700 loss: tensor(0.9965, device='cuda:4')
+eval samples: 800 loss: tensor(1.4912, device='cuda:4')
+eval samples: 900 loss: tensor(1.5128, device='cuda:4')
+eval samples: 1000 loss: tensor(1.1385, device='cuda:4')
+eval samples: 1100 loss: tensor(1.2201, device='cuda:4')
+average loss 1.239899498908079
+----------------------------------------------------------------------------------------------------
+| Eval 3 at step 6000 | time: 138.83s | valid loss 1.24 | valid ppl 3.46 | best ppl 3.46
+----------------------------------------------------------------------------------------------------
+| epoch 2 step 6100 | 842 batches | lr 0.000157 | ms/batch 1999.78 | loss 2.55 | avg loss 2.61 | ppl 13.54
+| epoch 2 step 6200 | 942 batches | lr 0.000156 | ms/batch 612.01 | loss 2.72 | avg loss 2.60 | ppl 13.48
+| epoch 2 step 6300 | 1042 batches | lr 0.000155 | ms/batch 611.75 | loss 2.61 | avg loss 2.58 | ppl 13.26
+| epoch 2 step 6400 | 1142 batches | lr 0.000154 | ms/batch 612.29 | loss 2.48 | avg loss 2.58 | ppl 13.15
+| epoch 2 step 6500 | 1242 batches | lr 0.000153 | ms/batch 613.03 | loss 2.90 | avg loss 2.62 | ppl 13.67
+| epoch 2 step 6600 | 1342 batches | lr 0.000153 | ms/batch 611.04 | loss 3.07 | avg loss 2.58 | ppl 13.16
+| epoch 2 step 6700 | 1442 batches | lr 0.000152 | ms/batch 611.17 | loss 2.79 | avg loss 2.56 | ppl 12.96
+| epoch 2 step 6800 | 1542 batches | lr 0.000151 | ms/batch 614.47 | loss 2.50 | avg loss 2.56 | ppl 12.95
+| epoch 2 step 6900 | 1642 batches | lr 0.00015 | ms/batch 610.47 | loss 2.71 | avg loss 2.56 | ppl 12.99
+| epoch 2 step 7000 | 1742 batches | lr 0.00015 | ms/batch 608.59 | loss 2.56 | avg loss 2.59 | ppl 13.37
+saving checkpoint ./trained_models/GPT2_M/e2e/model.7000.pt
+| epoch 2 step 7100 | 1842 batches | lr 0.000149 | ms/batch 610.96 | loss 2.32 | avg loss 2.57 | ppl 13.01
+| epoch 2 step 7200 | 1942 batches | lr 0.000148 | ms/batch 610.97 | loss 2.41 | avg loss 2.53 | ppl 12.50
+| epoch 2 step 7300 | 2042 batches | lr 0.000147 | ms/batch 611.57 | loss 2.48 | avg loss 2.57 | ppl 13.10
+| epoch 2 step 7400 | 2142 batches | lr 0.000146 | ms/batch 610.40 | loss 2.39 | avg loss 2.56 | ppl 12.89
+| epoch 2 step 7500 | 2242 batches | lr 0.000146 | ms/batch 610.66 | loss 2.63 | avg loss 2.57 | ppl 13.04
+| epoch 2 step 7600 | 2342 batches | lr 0.000145 | ms/batch 610.52 | loss 2.63 | avg loss 2.58 | ppl 13.26
+| epoch 2 step 7700 | 2442 batches | lr 0.000144 | ms/batch 608.69 | loss 2.22 | avg loss 2.54 | ppl 12.73
+| epoch 2 step 7800 | 2542 batches | lr 0.000143 | ms/batch 609.99 | loss 2.35 | avg loss 2.57 | ppl 13.07
+| epoch 2 step 7900 | 2642 batches | lr 0.000143 | ms/batch 609.05 | loss 2.72 | avg loss 2.60 | ppl 13.47
+| epoch 2 step 8000 | 2742 batches | lr 0.000142 | ms/batch 609.02 | loss 2.57 | avg loss 2.59 | ppl 13.30
+saving checkpoint ./trained_models/GPT2_M/e2e/model.8000.pt
+eval samples: 0 loss: tensor(1.0535, device='cuda:4')
+eval samples: 100 loss: tensor(0.9691, device='cuda:4')
+eval samples: 200 loss: tensor(1.1137, device='cuda:4')
+eval samples: 300 loss: tensor(1.1214, device='cuda:4')
+eval samples: 400 loss: tensor(1.5688, device='cuda:4')
+eval samples: 500 loss: tensor(1.9425, device='cuda:4')
+eval samples: 600 loss: tensor(1.0476, device='cuda:4')
+eval samples: 700 loss: tensor(0.9898, device='cuda:4')
+eval samples: 800 loss: tensor(1.4776, device='cuda:4')
+eval samples: 900 loss: tensor(1.5046, device='cuda:4')
+eval samples: 1000 loss: tensor(1.1689, device='cuda:4')
+eval samples: 1100 loss: tensor(1.1641, device='cuda:4')
+average loss 1.2270236368456933
+----------------------------------------------------------------------------------------------------
+| Eval 4 at step 8000 | time: 138.04s | valid loss 1.23 | valid ppl 3.41 | best ppl 3.41
+----------------------------------------------------------------------------------------------------
+| epoch 2 step 8100 | 2842 batches | lr 0.000141 | ms/batch 1991.53 | loss 2.46 | avg loss 2.56 | ppl 12.98
+| epoch 2 step 8200 | 2942 batches | lr 0.00014 | ms/batch 609.84 | loss 2.50 | avg loss 2.60 | ppl 13.49
+| epoch 2 step 8300 | 3042 batches | lr 0.00014 | ms/batch 610.87 | loss 2.47 | avg loss 2.54 | ppl 12.72
+| epoch 2 step 8400 | 3142 batches | lr 0.000139 | ms/batch 610.92 | loss 2.41 | avg loss 2.57 | ppl 13.03
+| epoch 2 step 8500 | 3242 batches | lr 0.000138 | ms/batch 611.04 | loss 2.81 | avg loss 2.56 | ppl 12.89
+| epoch 2 step 8600 | 3342 batches | lr 0.000137 | ms/batch 612.82 | loss 2.40 | avg loss 2.55 | ppl 12.87
+| epoch 2 step 8700 | 3442 batches | lr 0.000136 | ms/batch 611.25 | loss 2.47 | avg loss 2.52 | ppl 12.43
+| epoch 2 step 8800 | 3542 batches | lr 0.000136 | ms/batch 611.59 | loss 2.57 | avg loss 2.55 | ppl 12.86
+| epoch 2 step 8900 | 3642 batches | lr 0.000135 | ms/batch 611.43 | loss 2.33 | avg loss 2.54 | ppl 12.62
+| epoch 2 step 9000 | 3742 batches | lr 0.000134 | ms/batch 610.78 | loss 2.96 | avg loss 2.55 | ppl 12.78
+saving checkpoint ./trained_models/GPT2_M/e2e/model.9000.pt
+| epoch 2 step 9100 | 3842 batches | lr 0.000133 | ms/batch 608.39 | loss 2.67 | avg loss 2.55 | ppl 12.81
+| epoch 2 step 9200 | 3942 batches | lr 0.000133 | ms/batch 611.72 | loss 2.65 | avg loss 2.58 | ppl 13.17
+| epoch 2 step 9300 | 4042 batches | lr 0.000132 | ms/batch 611.24 | loss 2.60 | avg loss 2.58 | ppl 13.15
+| epoch 2 step 9400 | 4142 batches | lr 0.000131 | ms/batch 613.45 | loss 2.58 | avg loss 2.56 | ppl 12.95
+| epoch 2 step 9500 | 4242 batches | lr 0.00013 | ms/batch 611.51 | loss 2.40 | avg loss 2.54 | ppl 12.71
+| epoch 2 step 9600 | 4342 batches | lr 0.000129 | ms/batch 613.03 | loss 2.62 | avg loss 2.53 | ppl 12.55
|
629 |
+
| epoch 2 step 9700 | 4442 batches | lr 0.000129 | ms/batch 612.45 | loss 2.26 | avg loss 2.54 | ppl 12.74
|
630 |
+
| epoch 2 step 9800 | 4542 batches | lr 0.000128 | ms/batch 610.95 | loss 2.78 | avg loss 2.55 | ppl 12.82
|
631 |
+
| epoch 2 step 9900 | 4642 batches | lr 0.000127 | ms/batch 608.32 | loss 2.61 | avg loss 2.52 | ppl 12.37
|
632 |
+
| epoch 2 step 10000 | 4742 batches | lr 0.000126 | ms/batch 610.72 | loss 2.45 | avg loss 2.54 | ppl 12.73
|
633 |
+
saving checkpoint ./trained_models/GPT2_M/e2e/model.10000.pt
|
634 |
+
eval samples: 0 loss: tensor(1.0123, device='cuda:4')
|
635 |
+
eval samples: 100 loss: tensor(1.0022, device='cuda:4')
|
636 |
+
eval samples: 200 loss: tensor(1.0972, device='cuda:4')
|
637 |
+
eval samples: 300 loss: tensor(1.1317, device='cuda:4')
|
638 |
+
eval samples: 400 loss: tensor(1.5788, device='cuda:4')
|
639 |
+
eval samples: 500 loss: tensor(1.9430, device='cuda:4')
|
640 |
+
eval samples: 600 loss: tensor(1.0426, device='cuda:4')
|
641 |
+
eval samples: 700 loss: tensor(0.9720, device='cuda:4')
|
642 |
+
eval samples: 800 loss: tensor(1.4556, device='cuda:4')
|
643 |
+
eval samples: 900 loss: tensor(1.4790, device='cuda:4')
|
644 |
+
eval samples: 1000 loss: tensor(1.1323, device='cuda:4')
|
645 |
+
eval samples: 1100 loss: tensor(1.1691, device='cuda:4')
|
646 |
+
average loss 1.2222425683006033
|
647 |
+
----------------------------------------------------------------------------------------------------
|
648 |
+
| Eval 5 at step 10000 | time: 139.05s | valid loss 1.22 | valid ppl 3.39 | best ppl 3.39
|
649 |
+
----------------------------------------------------------------------------------------------------
|
650 |
+
| epoch 2 step 10100 | 4842 batches | lr 0.000126 | ms/batch 2003.85 | loss 2.46 | avg loss 2.55 | ppl 12.79
|
651 |
+
| epoch 2 step 10200 | 4942 batches | lr 0.000125 | ms/batch 609.56 | loss 2.62 | avg loss 2.56 | ppl 12.88
|
652 |
+
| epoch 2 step 10300 | 5042 batches | lr 0.000124 | ms/batch 610.36 | loss 2.85 | avg loss 2.51 | ppl 12.28
|
653 |
+
| epoch 2 step 10400 | 5142 batches | lr 0.000123 | ms/batch 610.63 | loss 2.40 | avg loss 2.57 | ppl 13.05
|
654 |
+
| epoch 2 step 10500 | 5242 batches | lr 0.000122 | ms/batch 613.64 | loss 2.43 | avg loss 2.52 | ppl 12.45
|
655 |
+
saving checkpoint ./trained_models/GPT2_M/e2e/model.10516.pt
|
656 |
+
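A minimal sketch of reloading the checkpoint saved above through model.py's load_weight (the file added below in this commit). The GPT-2-medium/LoRA hyperparameters here are illustrative assumptions, not read from this log; the real values come from the gpt2_ft.py flags:

import torch
from model import GPT2Config, GPT2LMModel  # model.py is added below in this commit

# Illustrative GPT-2-medium sizes with a LoRA rank of 4 (assumed, not logged).
lm_net = GPT2LMModel(GPT2Config(n_embd=1024, n_layer=24, n_head=16,
                                lora_attn_dim=4, lora_attn_alpha=32))
cp = torch.load('./trained_models/GPT2_M/e2e/model.10516.pt', map_location='cpu')
lm_net.load_weight(cp)  # unwraps 'model_state_dict' and renames legacy keys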
start to train the model................ 3
| epoch 3 step 10600 | 84 batches | lr 0.000122 | ms/batch 510.61 | loss 2.63 | avg loss 2.53 | ppl 12.61
| epoch 3 step 10700 | 184 batches | lr 0.000121 | ms/batch 613.48 | loss 2.67 | avg loss 2.56 | ppl 13.00
| epoch 3 step 10800 | 284 batches | lr 0.00012 | ms/batch 608.43 | loss 2.48 | avg loss 2.52 | ppl 12.39
| epoch 3 step 10900 | 384 batches | lr 0.000119 | ms/batch 611.59 | loss 2.69 | avg loss 2.56 | ppl 12.91

Running MS-COCO evaluator...
creating index...
index created!
Loading and preparing results...
DONE (t=0.00s)
creating index...
index created!
tokenization...
PTBTokenizer tokenized 22530 tokens at 184928.37 tokens per second.
PTBTokenizer tokenized 2122 tokens at 21442.98 tokens per second.
setting up scorers...
computing METEOR score...
METEOR: 0.485
computing Rouge score...
ROUGE_L: 0.761
computing CIDEr score...
CIDEr: 3.314
Running Py-MTEval metrics...
SCORES:
==============
BLEU: 0.7401
NIST: 8.6766
METEOR: 0.4851
ROUGE_L: 0.7614
CIDEr: 3.3144
=== lora.Linear, model.5258.pt ===
BLEU: 0.7905
NIST: 9.1684
METEOR: 0.5016
ROUGE_L: 0.7865
CIDEr: 3.4686
=== lora.MergedLinear, model.26290.pt ===
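A quick sanity check on the log's numbers (not part of the log itself): the ppl columns are consistent with ppl = exp(loss), with small differences explained by the rounding of the printed loss:

import math

# valid ppl at Eval 5: exp(1.2222425683006033) ~= 3.39, exactly as logged
print(math.exp(1.2222425683006033))
# training ppl, e.g. step 10000: exp(2.54) ~= 12.68 vs logged 12.73,
# because the printed avg loss 2.54 is itself rounded
print(math.exp(2.54))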
examples/NLG/src/model.py
ADDED
@@ -0,0 +1,460 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import logging
import math
import os
from collections import OrderedDict
import copy

import torch
from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss
import torch.nn.functional as F
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR
from torch.nn.parameter import Parameter

import loralib as lora

def gelu(x):
    return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))


def gelu_fast(x):
    return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x)))


def gelu_new(x):
    """ Implementation of the gelu activation function currently in Google Bert repo (identical to OpenAI GPT).
        Also see https://arxiv.org/abs/1606.08415
    """
    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))


def swish(x):
    return x * torch.sigmoid(x)


def _gelu_python(x):
    """ Original Implementation of the gelu activation function in Google Bert repo when initially created.
        For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
        0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
        This is now written in C in torch.nn.functional
        Also see https://arxiv.org/abs/1606.08415
    """
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))

class LayerNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-12):
        """Construct a layernorm module in the TF style (epsilon inside the square root)."""
        super(LayerNorm, self).__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        u = x.mean(-1, keepdim=True)
        s = (x - u).pow(2).mean(-1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.variance_epsilon)
        return self.weight * x + self.bias


class Conv1D(nn.Module):
    def __init__(self, nf, nx):
        super(Conv1D, self).__init__()
        self.nf = nf
        w = torch.empty(nx, nf)
        nn.init.normal_(w, std=0.02)
        self.weight = Parameter(w)
        self.bias = Parameter(torch.zeros(nf))

    def forward(self, x):
        size_out = x.size()[:-1] + (self.nf,)
        x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
        x = x.view(*size_out)
        return x

class Attention(nn.Module):
    def __init__(self, nx, n_ctx, config, scale=False):
        super(Attention, self).__init__()
        n_state = nx  # in Attention: n_state=768 (nx=n_embd)
        # [switch nx => n_state from Block to Attention to keep identical to TF implem]

        assert n_state % config.n_head == 0
        self.register_buffer("bias", torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))
        self.n_head = config.n_head
        self.split_size = n_state
        self.scale = scale
        # LoRA-adapted qkv projection: the low-rank update is applied only to the
        # query and value slices (enable_lora=[True, False, True]).
        self.c_attn = lora.MergedLinear(
            nx, n_state * 3,
            r=config.lora_attn_dim,
            lora_alpha=config.lora_attn_alpha,
            lora_dropout=config.lora_dropout,
            enable_lora=[True, False, True],
            fan_in_fan_out=True,
            merge_weights=False
        )
        # self.c_attn = lora.Linear(
        #     nx, n_state * 3,
        #     r=config.lora_attn_dim,
        #     lora_alpha=config.lora_attn_alpha,
        #     lora_dropout=config.lora_dropout,
        #     fan_in_fan_out=True,
        #     merge_weights=False
        # )
        print(f"scaling = {config.lora_attn_alpha / config.lora_attn_dim}")
        self.c_proj = Conv1D(n_state, nx)

        self.config = config

    def _attn(self, q, k, v, len_kv=None):
        w = torch.matmul(q, k)
        if self.scale:
            w = w / math.sqrt(v.size(-1))
        nd, ns = w.size(-2), w.size(-1)
        b = self.bias[:, :, ns-nd:ns, :ns]
        w = w * b - 1e10 * (1 - b)

        # q : (batch, head, q_seq_length, head_features)
        # k : (batch, head, head_features, kv_seq_length)
        # w : (batch, head, q_seq_length, kv_seq_length)
        # v : (batch, head, kv_seq_length, head_features)
        if len_kv is not None:
            _len = torch.arange(k.size(-1), device=k.device)
            _input_msk = _len[None, :] >= (len_kv)[:, None]
            w = w.masked_fill(_input_msk.unsqueeze(1).unsqueeze(2), -1.0e10)

        w = nn.Softmax(dim=-1)(w)
        return torch.matmul(w, v)

    def merge_heads(self, x):
        x = x.permute(0, 2, 1, 3).contiguous()
        new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
        return x.view(*new_x_shape)  # in Tensorflow implem: fct merge_states

    def split_heads(self, x, k=False):
        new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
        x = x.view(*new_x_shape)  # in Tensorflow implem: fct split_states
        if k:
            return x.permute(0, 2, 3, 1).contiguous()  # (batch, head, head_features, seq_length)
        else:
            return x.permute(0, 2, 1, 3).contiguous()  # (batch, head, seq_length, head_features)

    def forward(self, x, history=None, layer_past=None, len_past=None):
        hidden_states = x

        x = self.c_attn(x)
        query, key, value = x.split(self.split_size, dim=2)

        query = self.split_heads(query)
        key = self.split_heads(key, k=True)
        value = self.split_heads(value)

        len_kv = None

        if layer_past is not None:
            # key : (batch, head, head_features, seq_length)
            # value : (batch, head, seq_length, head_features)
            # layer_past, key : (batch, head, seq_length, head_features)
            if len_past is None:
                past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1]  # transpose back cf below
                key = torch.cat((past_key, key), dim=-1)
                value = torch.cat((past_value, value), dim=-2)
            else:
                key_seq = key.shape[-1]
                assert key_seq == 1

                _batch = torch.arange(0, key.shape[0], dtype=torch.long, device=key.device)

                past_key, past_value = layer_past[0], layer_past[1]

                past_key[_batch, :, len_past, :] = key.squeeze(-1)
                past_value[_batch, :, len_past, :] = value.squeeze(-2)

                key = past_key.transpose(-2, -1)
                value = past_value

                len_kv = len_past + 1

        present = torch.stack((key.transpose(-2, -1), value))  # transpose to have same shapes for stacking
        a = self._attn(query, key, value, len_kv=len_kv)
        a = self.merge_heads(a)
        a = self.c_proj(a)
        # logging.info(f"attention forward: {a[0,0,:100]}, present: {present[0,0,0,:]}")
        return a, present

class MLP(nn.Module):
    def __init__(self, n_state, config):  # in MLP: n_state=3072 (4 * n_embd)
        super(MLP, self).__init__()
        nx = config.n_embd
        self.c_fc = Conv1D(n_state, nx)
        self.c_proj = Conv1D(nx, n_state)
        self.act = gelu

    def forward(self, x):
        h = self.act(self.c_fc(x))
        h2 = self.c_proj(h)
        return h2


class Block(nn.Module):
    def __init__(self, n_ctx, config, scale=False):
        super(Block, self).__init__()
        nx = config.n_embd
        self.ln_1 = LayerNorm(nx, eps=config.layer_norm_epsilon)
        self.attn = Attention(nx, n_ctx, config, scale)
        self.ln_2 = LayerNorm(nx, eps=config.layer_norm_epsilon)
        self.mlp = MLP(4 * nx, config)

    def forward(self, x, layer_past=None, len_past=None):
        a, present = self.attn(self.ln_1(x), layer_past=layer_past, len_past=len_past)
        x = x + a
        m = self.mlp(self.ln_2(x))
        x = x + m
        return x, present

class GPT2Model(nn.Module):
    def __init__(self, config):
        super(GPT2Model, self).__init__()
        self.n_layer = config.n_layer
        self.n_embd = config.n_embd
        self.n_vocab = config.vocab_size

        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.wpe = nn.Embedding(config.n_positions, config.n_embd)
        block = Block(config.n_ctx, config, scale=True)
        self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)])
        self.ln_f = LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

        self.config = config

    def forward(
        self,
        input_ids,
        position_ids=None,
        token_type_ids=None,
        past=None,
        len_past=None
    ):
        if past is None:
            past_length = 0
            past = [None] * len(self.h)
        elif len_past is None:
            # equal size for past. []
            past_length = past[0][0].size(-2)

        if position_ids is None and len_past is None:
            position_ids = torch.arange(
                past_length, input_ids.size(-1) + past_length,
                dtype=torch.long, device=input_ids.device
            )
            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
        elif len_past is not None:
            position_ids = (len_past).unsqueeze(1)  # .long()

        input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_ids.size(-1))
        position_ids = position_ids.view(-1, position_ids.size(-1))

        inputs_embeds = self.wte(input_ids)

        position_embeds = self.wpe(position_ids)

        if token_type_ids is not None:
            token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1))
            token_type_embeds = self.wte(token_type_ids)
        else:
            token_type_embeds = 0
        hidden_states = inputs_embeds + position_embeds + token_type_embeds
        presents = []
        for block, layer_past in zip(self.h, past):
            hidden_states, present = block(hidden_states, layer_past=layer_past, len_past=len_past)
            presents.append(present)
        hidden_states = self.ln_f(hidden_states)
        output_shape = input_shape + (hidden_states.size(-1),)
        return hidden_states.view(*output_shape), presents

class GPT2LMHead(nn.Module):
    def __init__(self, model_embeddings_weights, config):
        super(GPT2LMHead, self).__init__()
        self.n_embd = config.n_embd
        self.set_embeddings_weights(model_embeddings_weights)

    def set_embeddings_weights(self, model_embeddings_weights):
        embed_shape = model_embeddings_weights.shape
        self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False)
        self.decoder.weight = model_embeddings_weights  # Tied weights

    def forward(self, hidden_state):
        # Truncated Language modeling logits (we remove the last token)
        # h_trunc = h[:, :-1].contiguous().view(-1, self.n_embd)
        lm_logits = self.decoder(hidden_state)
        return lm_logits


class GPT2Config(object):
    def __init__(
        self,
        vocab_size_or_config_json_file=50257,
        n_positions=1024,
        n_ctx=1024,
        n_embd=768,
        n_layer=12,
        n_head=12,
        layer_norm_epsilon=1e-5,
        initializer_range=0.02,
        lora_attn_dim=0,
        lora_attn_alpha=128,
        lora_dropout=0.0,
        lora_r_dropout=0.0,
        fix_dropout=0.0,
    ):
        self.vocab_size = vocab_size_or_config_json_file
        self.n_ctx = n_ctx
        self.n_positions = n_positions
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_head = n_head
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range
        self.lora_attn_dim = lora_attn_dim
        self.lora_attn_alpha = lora_attn_alpha
        self.lora_dropout = lora_dropout
        self.lora_r_dropout = lora_r_dropout

        self.fix_dropout = fix_dropout

class GPT2LMModel(nn.Module):
    def __init__(self, config):
        super(GPT2LMModel, self).__init__()
        self.transformer = GPT2Model(config)
        self.lm_head = GPT2LMHead(self.transformer.wte.weight, config)
        self.apply(self._init_weights)

    def set_tied(self):
        """ Make sure we are sharing the embeddings """
        self.lm_head.set_embeddings_weights(self.transformer.wte.weight)

    def forward(
        self,
        input_ids,
        lm_labels=None,
        lm_mask=None,
        past=None,
        len_past=None,
        label_smooth=0.0,
        is_report_accuracy=False
    ):
        _batch, _len = input_ids.shape
        hidden_states, presents = self.transformer(input_ids, past=past, len_past=len_past)

        # batch, seq, vocab
        lm_logits = self.lm_head(hidden_states)

        if lm_labels is not None:

            if is_report_accuracy:
                _pred_token = torch.argmax(lm_logits, dim=-1)
                _hit = (_pred_token == lm_labels) * lm_mask

                _t1_acc = torch.zeros(_batch, dtype=torch.float, device=input_ids.device)
                _all_acc = torch.zeros(_batch, dtype=torch.float, device=input_ids.device)

                for _b in range(0, _batch):
                    # first-token accuracy: check only the first masked position
                    for _i in range(0, _len):
                        if lm_mask[_b, _i] >= 1.0:
                            if _hit[_b, _i] > 0:
                                _t1_acc[_b] = 1.0
                            break

                    # whole-sequence accuracy: every masked position must be a hit
                    _is_succ = True
                    for _i in range(0, _len):
                        if lm_mask[_b, _i] >= 1.0:
                            if _hit[_b, _i] <= 0:
                                _is_succ = False
                                break

                    if _is_succ:
                        _all_acc[_b] = 1.0

                # _t1_acc = _t1_acc * 1.0 / _batch
                # _all_acc = _all_acc * 1.0 / _batch

            if label_smooth > 0.0001:
                logprobs = torch.nn.functional.log_softmax(lm_logits.view(-1, lm_logits.size(-1)), dim=-1)
                nll_loss = -logprobs.gather(dim=-1, index=lm_labels.view(-1).unsqueeze(1))
                nll_loss = nll_loss.squeeze(1)
                smooth_loss = -logprobs.mean(dim=-1)
                loss = (1.0 - label_smooth) * nll_loss + label_smooth * smooth_loss
                loss = loss.view(_batch, _len)
            else:
                loss_fct = nn.CrossEntropyLoss(ignore_index=-1, reduction='none')
                loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), lm_labels.view(-1)).view(_batch, _len)

            if lm_mask is None:
                lm_mask = torch.ones(loss.shape, dtype=loss.dtype, device=loss.device)
            loss = loss * lm_mask

            loss = loss.sum() / (lm_mask.sum() + 0.0001)

            if is_report_accuracy:
                return lm_logits, loss, _t1_acc, _all_acc
            else:
                return lm_logits, loss
        return lm_logits, presents

    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def load_weight(self, state_dict):
        if 'model_state_dict' in state_dict:
            state_dict = state_dict['model_state_dict']

        state_dict_tmp = copy.deepcopy(state_dict)
        old_keys = []
        new_keys = []
        for key in state_dict_tmp:
            new_key = None
            if key.endswith(".g"):
                new_key = key[:-2] + ".weight"
            elif key.endswith(".b"):
                new_key = key[:-2] + ".bias"
            elif key.endswith(".w"):
                new_key = key[:-2] + ".weight"

            if key.startswith("module.transformer."):
                new_key = key[len("module.transformer."):]

            if new_key:
                old_keys.append(key)
                new_keys.append(new_key)

        for old_key, new_key in zip(old_keys, new_keys):
            state_dict[new_key] = state_dict.pop(old_key)

        for n, p in self.transformer.named_parameters():
            if n not in state_dict:
                state_dict[n] = p

        self.transformer.load_state_dict(state_dict, strict=False)
        self.set_tied()
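A minimal end-to-end sketch of how this module is driven with loralib. Paths and hyperparameters below are illustrative assumptions (the repo's actual entry point is gpt2_ft.py); mark_only_lora_as_trainable and lora_state_dict are real loralib helpers:

import torch
import loralib as lora

from model import GPT2Config, GPT2LMModel

# Assumed GPT-2-medium sizes with a rank-4 LoRA on the attention q/v projections.
config = GPT2Config(
    n_embd=1024, n_layer=24, n_head=16,
    lora_attn_dim=4, lora_attn_alpha=32, lora_dropout=0.1,
)
lm_net = GPT2LMModel(config)
# Hypothetical path to a pretrained GPT-2 checkpoint.
lm_net.load_weight(torch.load('./pretrained_checkpoints/gpt2-medium-pytorch_model.bin'))

# Freeze all base weights; only the LoRA A/B matrices inside lora.MergedLinear stay trainable.
lora.mark_only_lora_as_trainable(lm_net)

# ... training loop over (input_ids, lm_labels, lm_mask) batches ...

# If desired, only the (small) LoRA weights need to be checkpointed.
torch.save(lora.lora_state_dict(lm_net), 'lora_only.pt')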