Shilpaj committed
Commit a1fcadc · verified · 1 Parent(s): ad18062

Upload cosmopedia_datamodule.py

Files changed (1)
  1. cosmopedia_datamodule.py  +120 −0
cosmopedia_datamodule.py ADDED
@@ -0,0 +1,120 @@
#!/usr/bin/env python
"""
Data module for Cosmopedia dataset
Author: Shilpaj Bhalerao
Date: 2025-01-20
"""
# Standard Library Imports
from typing import Optional

# Third-Party Imports
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import GPT2Tokenizer

# Local Imports
from config import DataConfig


class CosmopediaDataModule(pl.LightningDataModule):
    """
    Data module for Cosmopedia dataset
    """
    def __init__(
        self,
        batch_size: int = DataConfig.batch_size,
        num_workers: int = DataConfig.num_workers,
        shuffle_buffer_size: int = DataConfig.shuffle_buffer_size,
        max_length: int = DataConfig.max_length,
    ):
        """
        Constructor
        :param batch_size: Batch size for dataloaders
        :param num_workers: Number of workers for dataloaders
        :param shuffle_buffer_size: Size of buffer for shuffling streaming data
        :param max_length: Maximum sequence length for tokenized text
        """
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.shuffle_buffer_size = shuffle_buffer_size
        self.max_length = max_length

        # Dataset path on HuggingFace
        self.dataset_path = DataConfig.dataset_path
        self.dataset_name = DataConfig.dataset_name

        # Initialize tokenizer
        self.tokenizer = GPT2Tokenizer.from_pretrained(DataConfig.tokenizer_path)
        self.tokenizer.pad_token = self.tokenizer.eos_token

    def setup(self, stage: Optional[str] = None):
        """
        Setup datasets for training and validation
        """
        # Load dataset in streaming mode
        self.dataset = load_dataset(
            self.dataset_path,
            self.dataset_name,
            split="train",  # Only train split is available
            streaming=DataConfig.streaming
        )

        # Shuffle the streaming dataset
        self.dataset = self.dataset.shuffle(buffer_size=self.shuffle_buffer_size)

        # Create train/val split using configured validation split
        val_size = int(DataConfig.validation_split * self.shuffle_buffer_size)
        self.train_dataset = self.dataset.skip(val_size)
        self.val_dataset = self.dataset.take(val_size)

    def collate_fn(self, batch):
        """
        Tokenize and pad the texts in the batch
        """
        texts = [item['text'] for item in batch]

        # Tokenize all texts in the batch
        encodings = self.tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt'
        )

        # Prepare inputs and labels for language modeling
        input_ids = encodings['input_ids'][:, :-1]
        labels = encodings['input_ids'][:, 1:]
        attention_mask = encodings['attention_mask'][:, :-1]

        return {
            'input_ids': input_ids,
            'labels': labels,
            'attention_mask': attention_mask
        }

    def train_dataloader(self):
        """
        Return train dataloader
        """
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=DataConfig.pin_memory,
            collate_fn=self.collate_fn
        )

    def val_dataloader(self):
        """
        Return validation dataloader
        """
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=DataConfig.pin_memory,
            collate_fn=self.collate_fn
        )
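
Note: DataConfig is imported from a local config.py that is not part of this commit. The sketch below reconstructs a minimal, hypothetical config exposing the attributes this module reads, plus a typical wire-up with a Lightning Trainer. The attribute names come from the usages above; every concrete value is an illustrative assumption, not the author's actual setting (the Cosmopedia dataset itself is normally loaded from the HuggingFaceTB/cosmopedia repo on the Hub).

# config.py — hypothetical sketch, values are placeholders only
class DataConfig:
    dataset_path = "HuggingFaceTB/cosmopedia"  # Hub repo of the Cosmopedia dataset (assumed)
    dataset_name = "stories"                   # one of the Cosmopedia subsets (assumed)
    tokenizer_path = "gpt2"                    # any GPT-2 tokenizer checkpoint
    streaming = True                           # stream instead of downloading the full dataset
    validation_split = 0.05                    # fraction of the shuffle buffer held out for validation
    shuffle_buffer_size = 10_000
    batch_size = 8
    num_workers = 2
    max_length = 1024
    pin_memory = True

# train.py — hypothetical usage sketch (`model` is any LightningModule, not defined here)
# import pytorch_lightning as pl
# from cosmopedia_datamodule import CosmopediaDataModule
#
# datamodule = CosmopediaDataModule(batch_size=4)
# trainer = pl.Trainer(max_steps=1_000, accelerator="auto")
# trainer.fit(model, datamodule=datamodule)

Because the dataset is streamed, setup() simply carves the validation set out of the head of the shuffled stream with take(val_size) and trains on the remainder via skip(val_size); there is no fixed, pre-materialised held-out split.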