import torch
from torch.utils.data import DataLoader, Dataset

from datasets import load_dataset
from transformers import AutoTokenizer


def train_tokenizer(texts, vocab_size=50000, min_frequency=2):
    # Start from the GPT-2 tokenizer and retrain its BPE vocabulary on the given texts.
    base_tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer = base_tokenizer.train_new_from_iterator(
        texts, vocab_size=vocab_size, min_frequency=min_frequency
    )
    # GPT-2 defines no pad token, so add one for fixed-length batching.
    if tokenizer.pad_token is None:
        tokenizer.add_special_tokens({'pad_token': '[PAD]'})
    tokenizer.save_pretrained("./tokenizer")
    return tokenizer


def load_tokenizer():
    # Reload the tokenizer saved by train_tokenizer() and make sure it has a pad token.
    tokenizer = AutoTokenizer.from_pretrained("./tokenizer")
    if tokenizer.pad_token is None:
        tokenizer.add_special_tokens({'pad_token': '[PAD]'})
    return tokenizer


class TextDataset(Dataset):
    """Wraps a list of raw texts and tokenizes each example on the fly."""

    def __init__(self, texts, tokenizer, max_length):
        self.texts = texts
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        # Truncate or pad every example to exactly max_length tokens.
        text = self.texts[idx]
        encodings = self.tokenizer(
            text, truncation=True, padding='max_length', max_length=self.max_length
        )
        return torch.tensor(encodings['input_ids'])


def get_dataloader(dataset_name, config_name, tokenizer, max_length, batch_size):
    dataset = load_dataset(dataset_name, config_name)
    # Keep only the first 50 training texts so the example runs quickly.
    texts = dataset['train']['text'][:50]
    dataset = TextDataset(texts, tokenizer, max_length)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    return dataloader
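

# Minimal usage sketch (illustrative addition, not part of the original module).
# It assumes the Hugging Face "wikitext" dataset with the "wikitext-2-raw-v1" config;
# swap in your own dataset name, config, and hyperparameters as needed.
if __name__ == "__main__":
    raw = load_dataset("wikitext", "wikitext-2-raw-v1")
    tokenizer = train_tokenizer(raw["train"]["text"], vocab_size=50000, min_frequency=2)

    dataloader = get_dataloader(
        "wikitext", "wikitext-2-raw-v1", tokenizer, max_length=128, batch_size=8
    )
    for batch in dataloader:
        print(batch.shape)  # expected: torch.Size([8, 128])
        break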