from typing import Optional
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader
from fengshen.data.mmap_index_dataset import MMapIndexDataset


class MMapDataModule(LightningDataModule):
    """LightningDataModule that serves memory-mapped (mmap) indexed datasets for train/valid/test splits."""

    @staticmethod
    def add_data_specific_args(parent_args):
        parser = parent_args.add_argument_group('MMAP DataModule')
        parser.add_argument('--num_workers', default=8, type=int)
        parser.add_argument('--train_batchsize', default=32, type=int)
        parser.add_argument('--eval_batchsize', default=32, type=int)
        parser.add_argument('--test_batchsize', default=32, type=int)
        parser.add_argument('--train_datas', default=[
            './train_datas'
        ], type=str, nargs='+')
        parser.add_argument('--valid_datas', default=[
            './valid_datas'
        ], type=str, nargs='+')
        parser.add_argument('--test_datas', default=[
            './test_datas'
        ], type=str, nargs='+')
        parser.add_argument('--input_tensor_name', default=['input_ids'], type=str, nargs='+')
        return parent_args

    def __init__(
        self,
        collate_fn,
        args,
        **kwargs,
    ):
        super().__init__()
        self.collate_fn = collate_fn
        # Build one memory-mapped indexed dataset per split from the configured file lists.
        self.train_dataset = MMapIndexDataset(args.train_datas, args.input_tensor_name)
        self.valid_dataset = MMapIndexDataset(args.valid_datas, args.input_tensor_name)
        self.test_dataset = MMapIndexDataset(args.test_datas, args.input_tensor_name)
        self.save_hyperparameters(args)

    def setup(self, stage: Optional[str] = None) -> None:
        return super().setup(stage)

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.hparams.train_batchsize,
            shuffle=True,
            num_workers=self.hparams.num_workers,
            collate_fn=self.collate_fn,
        )

    def val_dataloader(self):
        return DataLoader(
            self.valid_dataset,
            batch_size=self.hparams.eval_batchsize,
            # Validation data is not shuffled so evaluation order stays deterministic.
            shuffle=False,
            num_workers=self.hparams.num_workers,
            collate_fn=self.collate_fn,
        )

    def test_dataloader(self):
        return DataLoader(
            self.test_dataset,
            batch_size=self.hparams.test_batchsize,
            # Test data is not shuffled so evaluation order stays deterministic.
            shuffle=False,
            num_workers=self.hparams.num_workers,
            collate_fn=self.collate_fn,
        )
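

# Usage sketch (not part of the original module): shows how the argparse group
# defined above and a user-supplied collate_fn are typically wired together.
# The pass-through collate function below is a hypothetical placeholder, and the
# default data paths ('./train_datas', etc.) are assumed to point at real
# memory-mapped index data, so this only runs end to end against such data.
if __name__ == '__main__':
    import argparse

    def _passthrough_collate(batch):
        # Hypothetical placeholder; real callers pad/stack the 'input_ids' tensors here.
        return batch

    parser = argparse.ArgumentParser()
    parser = MMapDataModule.add_data_specific_args(parser)
    args = parser.parse_args()
    datamodule = MMapDataModule(collate_fn=_passthrough_collate, args=args)
    # Pull one training batch to confirm the mmap-backed DataLoader works.
    batch = next(iter(datamodule.train_dataloader()))
    print(type(batch), len(batch))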