# gem-1o/utils/data_preprocessing.py
import torch
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset
from transformers import AutoTokenizer
def train_tokenizer(texts, vocab_size=50000, min_frequency=2):
    # Train a new tokenizer on the corpus, reusing the GPT-2 tokenizer's configuration.
    # Extra kwargs such as min_frequency are forwarded to the underlying trainer.
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer = tokenizer.train_new_from_iterator(texts, vocab_size=vocab_size, min_frequency=min_frequency)
    # GPT-2 has no padding token, so add one for batched training.
    if tokenizer.pad_token is None:
        tokenizer.add_special_tokens({'pad_token': '[PAD]'})
    tokenizer.save_pretrained("./tokenizer")
    return tokenizer

def load_tokenizer():
    tokenizer = AutoTokenizer.from_pretrained("./tokenizer")
    if tokenizer.pad_token is None:
        tokenizer.add_special_tokens({'pad_token': '[PAD]'})
    return tokenizer

class TextDataset(Dataset):
    """Wraps a list of raw text strings and tokenizes each one on access."""

    def __init__(self, texts, tokenizer, max_length):
        self.texts = texts
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        encodings = self.tokenizer(text, truncation=True, padding='max_length', max_length=self.max_length)
        return torch.tensor(encodings['input_ids'])

def get_dataloader(dataset_name, config_name, tokenizer, max_length, batch_size):
    dataset = load_dataset(dataset_name, config_name)
    texts = dataset['train']['text'][:50]  # small debugging subset; drop the slice to train on the full split with the full vocab size
    dataset = TextDataset(texts, tokenizer, max_length)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    return dataloader
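

if __name__ == "__main__":
    # Minimal usage sketch tying the helpers together. The dataset name, config,
    # and hyperparameters below are illustrative placeholders, not values taken
    # from the actual GEM_1o training run.
    raw = load_dataset("wikitext", "wikitext-2-raw-v1")
    tokenizer = train_tokenizer(raw['train']['text'], vocab_size=50000, min_frequency=2)
    dataloader = get_dataloader("wikitext", "wikitext-2-raw-v1", tokenizer, max_length=128, batch_size=8)
    batch = next(iter(dataloader))
    print(batch.shape)  # expected: (batch_size, max_length)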