'''
@File : t5_gen_datasets.py
@Time : 2022/10/24 19:29
@Author : He Junqing
@Version : 1.0
@Contact : [email protected]
@License : (C)Copyright 2022-2023, CCNL-IDEA
'''

from transformers import (
    BertTokenizer,
    MT5Config,
    MT5Tokenizer,
    MT5ForConditionalGeneration,
)
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import pytorch_lightning as pl
import numpy as np
import sys

sys.path.append("../../")

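# Special markers that delimit the dialogue context ([CTSTART]/[CTEND]) and the
# grounding knowledge snippet ([KNSTART]/[KNEND]) in the encoder input; [SEP] is
# also registered. They are added to the tokenizer in DialogDataset below.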
special_token_dict = {
    "additional_special_tokens": [
        "[CTSTART]",
        "[CTEND]",
        "[SEP]",
        "[KNSTART]",
        "[KNEND]",
    ]
}

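
# Dataset for knowledge-grounded dialogue generation. Each raw sample is expected
# to provide a "context" field (a list of alternating utterances), a "knowledge"
# string and a "target" reply, which regular_tokenize turns into encoder inputs
# and decoder labels. A hypothetical sample, for illustration only:
#   {"context": ["hello", "hi, how can I help you?"],
#    "knowledge": "some supporting fact",
#    "target": "a grounded reply"}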
class DialogDataset(Dataset):
    def __init__(self, data_path, args, data, load_data_type=1) -> None:
        super().__init__()

        if args.tokenizer_type == "t5_tokenizer":
            self.tokenizer = MT5Tokenizer.from_pretrained(
                args.pretrained_model_path)
            if len(self.tokenizer) == 32596:
                # Presumably 32596 is the vocabulary size before the dialogue
                # special tokens were added, so the tokens are added (and the
                # model resized and saved) only once.
                self.tokenizer.add_special_tokens(special_token_dict)
                print(
                    "add special tokens to tokenizer, vocab size:",
                    len(self.tokenizer)
                )
                self.model = MT5ForConditionalGeneration.from_pretrained(
                    args.pretrained_model_path
                )
                self.model.resize_token_embeddings(len(self.tokenizer))
                self.model.save_pretrained(args.new_vocab_path)
                self.tokenizer.save_pretrained(args.new_vocab_path)
        else:
            self.tokenizer = BertTokenizer.from_pretrained(
                args.pretrained_model_path)

        self.load_data_type = load_data_type
        self.data_split = data
        self.num_workers = args.preprocessing_num_workers
        self.max_seq_length = args.max_seq_length
        self.max_knowledge_length = args.max_knowledge_length
        self.max_target_length = args.max_target_length

        self.config = MT5Config.from_pretrained(args.pretrained_model_path)
        self.decoder_start_token_id = self.config.decoder_start_token_id
        self.eos_token_id = self.config.eos_token_id
        self.vocab_size = self.config.vocab_size

        # load_data_type: 0 = raw json file, 1 = packed fengshen dataset,
        # anything else = data_path is already a tokenized dataset.
        if self.load_data_type == 0:
            self.data = self.load_data(data_path)
        elif self.load_data_type == 1:
            self.data = self.load_packed_data(data_path)
        else:
            self.data = data_path
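
    # load_packed_data reads a pre-packaged fengshen dataset (fengshen.data.fs_datasets),
    # selects the requested split and tokenizes it with regular_tokenize.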
    def load_packed_data(self, data_path):
        from fengshen.data.fs_datasets import load_dataset

        samples = load_dataset(data_path,
                               num_proc=self.num_workers)[self.data_split]
        tokenized_samples = samples.map(
            self.regular_tokenize, batched=False,
            num_proc=self.num_workers
        )

        return tokenized_samples

    def load_data(self, data_path):
        """
        Load raw samples from a json file and tokenize them with regular_tokenize.
        """
        from datasets import load_dataset

        ds = load_dataset("json", data_files=data_path)['train']
        samples = ds.map(
            self.regular_tokenize, batched=False, num_proc=self.num_workers
        )
        return samples

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)
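
    # regular_tokenize builds the encoder input as
    #   [CTSTART] <flattened context, truncated from the left> [CTEND]
    #   [KNSTART] <knowledge, truncated> [KNEND]
    # with token_types 0/1 marking the two speakers and 2 marking knowledge;
    # the target reply (plus EOS) becomes the labels.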
    def regular_tokenize(self, sample):
        context_ids = self.tokenizer(
            sample["context"],
            add_special_tokens=True,
            return_attention_mask=False,
            return_token_type_ids=True,
        )

        context_types = self.get_token_type(
            sample["context"], context_ids["token_type_ids"]
        )

        knowledge_ids = self.tokenizer.encode(
            sample["knowledge"], add_special_tokens=False
        )
        if isinstance(knowledge_ids, int):
            knowledge_ids = [knowledge_ids]
        target_ids = self.tokenizer.encode(
            sample["target"],
            add_special_tokens=False,
            max_length=self.max_target_length - 1,
            truncation=True,
        )

        knowledge_ids = (
            [self.tokenizer.convert_tokens_to_ids("[KNSTART]")]
            + knowledge_ids[: self.max_knowledge_length - 2]
            + [self.tokenizer.convert_tokens_to_ids("[KNEND]")]
        )
        l_kn = len(knowledge_ids)
        knowledge_types = [2] * l_kn

        # Flatten the per-utterance ids and keep the most recent tokens that fit.
        flatten_context = []
        for line in context_ids["input_ids"]:
            flatten_context.extend(line)
        l_ct = min(len(flatten_context), self.max_seq_length - l_kn - 2)
        context_ids = (
            [self.tokenizer.convert_tokens_to_ids("[CTSTART]")]
            + flatten_context[-l_ct:]
            + [self.tokenizer.convert_tokens_to_ids("[CTEND]")]
        )

        # Keep the token types of the retained context tokens and extend them to
        # cover the [CTSTART] and [CTEND] markers.
        context_types = context_types[-l_ct:] + [0]
        context_types.insert(0, context_types[0])
        assert len(context_ids) == len(
            context_types
        ), "len of context ids and token types do not match, context:{},ids:{} types:{},len {}:{}".format(
            sample["context"],
            context_ids,
            context_types,
            len(context_ids),
            len(context_types),
        )

        try:
            target_ids = target_ids + [self.eos_token_id]
        except Exception:
            print(sample["target"], target_ids, self.eos_token_id)

        tokenized = {}
        tokenized["input_ids"] = np.array(context_ids + knowledge_ids, dtype=np.int32)
        tokenized["token_types"] = np.array(
            context_types + knowledge_types, dtype=np.int32
        )
        tokenized["attention_mask"] = np.ones(
            len(context_types + knowledge_types), dtype=np.int8
        )
        tokenized["labels"] = np.array(target_ids, dtype=np.int32)

        return tokenized
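
    # get_token_type assigns speaker-turn type ids to the context: utterances at
    # even positions get 0 and odd positions get 1, using the tokenizer's
    # token_type_ids for the per-utterance lengths when they are available.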
    def get_token_type(self, context, tokentypes=None):
        context_token_types = []
        for i, line in enumerate(context):
            if tokentypes:
                if i % 2 == 0:
                    token_type = [0] * len(tokentypes[i])
                else:
                    token_type = [1] * len(tokentypes[i])
            else:
                if i % 2 == 0:
                    token_type = [0] * (1 + len(line))
                else:
                    token_type = [1] * (1 + len(line))

            context_token_types.extend(token_type)

        return context_token_types
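

# LightningDataModule that wraps DialogDataset: it pads batches, derives
# decoder_input_ids from the labels, and builds distributed-aware train/val/predict
# dataloaders.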
class DialogDataModel(pl.LightningDataModule):
    @staticmethod
    def add_data_specific_args(parent_args):
        parser = parent_args.add_argument_group("SuperviseT5DataModel")
        parser.add_argument("--dataset_num_workers", default=8, type=int)
        parser.add_argument("--dataloader_num_workers", default=4, type=int)
        parser.add_argument("--train_data_path", default="dialog_4g_test", type=str)
        parser.add_argument(
            "--valid_data_path", default="wudao_180g_mt5_tokenized", type=str
        )
        parser.add_argument("--train_batchsize", default=2, type=int)
        parser.add_argument("--valid_batchsize", default=2, type=int)
        parser.add_argument("--max_seq_length", default=512, type=int)
        parser.add_argument("--max_knowledge_length", default=128, type=int)
        parser.add_argument("--max_target_length", default=128, type=int)

        return parent_args

    def __init__(self, args):
        super().__init__()
        self.save_hyperparameters(args)
        self.load_data(args)
        self.epochs = args.max_epochs

    def load_data(self, args):
        if args.train_split_size is not None:
            from fengshen.data.fs_datasets import load_dataset

            data_splits = load_dataset(
                args.train_data_path, num_proc=args.dataset_num_workers
            )
            train_split = data_splits['train']
            test_split = data_splits['test']
            print('train:', train_split, '\ntest_data:', test_split)
            self.train_dataset = DialogDataset(
                args.train_data_path, args, load_data_type=1, data="train"
            )
            self.test_dataset = DialogDataset(
                args.train_data_path, args, load_data_type=1, data="test"
            )
        else:
            # train_dataloader reads self.train_dataset, so keep the attribute
            # name consistent; "train" is assumed here as the default split.
            self.train_dataset = DialogDataset(
                args.train_data_path, args, data="train", load_data_type=1
            )

        self.config = MT5Config.from_pretrained(args.pretrained_model_path)
        self.pad_token_id = self.config.pad_token_id
        self.decoder_start_token_id = self.config.decoder_start_token_id
        print("bos id:", self.decoder_start_token_id)
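
    # collate_fn pads input_ids/token_types/attention_mask with pad_token_id,
    # pads labels with -100 (ignored by the loss), and builds decoder_input_ids
    # by shifting the labels one position to the right.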
    def collate_fn(self, samples):
        batch = {
            k: [
                torch.tensor(samples[i][k], dtype=torch.int64)
                for i in range(len(samples))
            ]
            for k in ["input_ids", "token_types", "attention_mask", "labels"]
        }

        for k, v in batch.items():
            if k != "labels":
                batch[k] = pad_sequence(
                    v, batch_first=True, padding_value=self.pad_token_id
                )
            else:
                batch[k] = pad_sequence(v, batch_first=True, padding_value=-100)
        batch["decoder_input_ids"] = torch.tensor(
            self.shift_tokens_right(
                batch["labels"], self.pad_token_id, self.decoder_start_token_id
            ),
            dtype=torch.long,
        )
        return batch
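
    # Standard encoder-decoder label shifting: prepend decoder_start_token_id and
    # replace the -100 label padding with pad_token_id so the decoder never sees -100.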
    def shift_tokens_right(
        self, input_ids: np.array, pad_token_id: int, decoder_start_token_id: int
    ) -> np.ndarray:
        """
        Shift input ids one token to the right.
        """
        shifted_input_ids = np.zeros_like(input_ids)
        shifted_input_ids[:, 1:] = input_ids[:, :-1]
        shifted_input_ids[:, 0] = decoder_start_token_id

        shifted_input_ids = np.where(
            shifted_input_ids == -100, pad_token_id, shifted_input_ids
        )
        return shifted_input_ids
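
    # The training dataloader uses fengshen's PretrainingRandomSampler so that a
    # resumed run skips the samples already consumed and each data-parallel rank
    # draws its own shard.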
    def train_dataloader(self):
        from fengshen.data.universal_datamodule.universal_sampler import (
            PretrainingRandomSampler,
        )
        from fengshen.data.universal_datamodule.universal_datamodule import (
            get_consume_samples,
        )

        consumed_samples = get_consume_samples(self)
        batch_sampler = PretrainingRandomSampler(
            epoch=self.epochs,
            total_samples=len(self.train_dataset),
            consumed_samples=consumed_samples,
            micro_batch_size=self.hparams.train_batchsize,
            data_parallel_rank=self.trainer.global_rank,
            data_parallel_size=self.trainer.world_size,
        )
        return DataLoader(
            self.train_dataset,
            batch_sampler=batch_sampler,
            pin_memory=True,
            num_workers=self.hparams.dataloader_num_workers,
            collate_fn=self.collate_fn,
        )

    def val_dataloader(self):
        sampler = torch.utils.data.distributed.DistributedSampler(
            self.test_dataset, shuffle=False
        )
        return DataLoader(
            self.test_dataset,
            sampler=sampler,
            shuffle=False,
            batch_size=self.hparams.valid_batchsize,
            pin_memory=True,
            num_workers=self.hparams.dataloader_num_workers,
            collate_fn=self.collate_fn,
        )

    def predict_dataloader(self):
        sampler = torch.utils.data.distributed.DistributedSampler(
            self.test_dataset, shuffle=False
        )
        return DataLoader(
            self.test_dataset,
            sampler=sampler,
            shuffle=False,
            batch_size=self.hparams.valid_batchsize,
            pin_memory=True,
            num_workers=self.hparams.dataloader_num_workers,
            collate_fn=self.collate_fn,
        )
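

# Quick smoke test: build the datamodule from CLI arguments and print a few batches.
# Note that the default pretrained_model_path/new_vocab_path point to internal
# cluster paths and will need to be adjusted to a local checkpoint.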
if __name__ == "__main__":

    import argparse

    total_parser = argparse.ArgumentParser("DATASET parser")
    total_parser.add_argument(
        "--tokenizer_type",
        default="t5_tokenizer",
        choices=["bert_tokenizer", "t5_tokenizer"],
    )
    total_parser.add_argument("--preprocessing_num_workers", default=10, type=int)
    total_parser.add_argument(
        "--new_vocab_path",
        default="/cognitive_comp/hejunqing/projects/Dialog_pretrain/randeng_t5_newvocab_784M",
        type=str,
    )
    total_parser.add_argument("--train_split_size", default=0.995, type=float)
    total_parser.add_argument(
        "--pretrained_model_path",
        default="/cognitive_comp/hejunqing/projects/Dialog_pretrain/randeng_t5_newvocab_784M",
    )
    # max_epochs is read by DialogDataModel.__init__; it is normally supplied by
    # the Trainer arguments, so a small default is added here for the smoke test.
    total_parser.add_argument("--max_epochs", default=1, type=int)
    total_parser = DialogDataModel.add_data_specific_args(total_parser)
    args = total_parser.parse_args()
    dl = DialogDataModel(args)

    for i in range(5):
        for batch in dl.train_dataloader():
            print(batch)
            print(batch["input_ids"])
            print(batch["token_types"])
            print(batch["decoder_input_ids"])
            print(batch["labels"])

    print("test finish")