# training_bench/data.py — dataset loading, tokenization, and DataLoader construction.
from datasets import load_dataset
from transformers import AutoTokenizer, DataCollatorForLanguageModeling
from torch.utils.data import DataLoader
from typing import Tuple
def build_dataloaders(dataset_name: str, tokenizer_name: str, batch_size: int, val_split: float = 0.05, block_size: int = 512, num_workers: int = 2) -> Tuple:
    """Build tokenizer and train/validation DataLoaders for causal-LM training.

    Args:
        dataset_name: Hugging Face hub dataset identifier passed to `load_dataset`.
        tokenizer_name: Hub model/tokenizer identifier for `AutoTokenizer`.
        batch_size: Training batch size (validation uses at least 2).
        val_split: Fraction of the chosen split held out for validation.
        block_size: Max sequence length; shorter texts are padded, longer truncated.
        num_workers: DataLoader worker processes.

    Returns:
        (tokenizer, train_loader, val_loader).
    """
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    # Causal-LM tokenizers (e.g. GPT-2) often ship without a pad token; reuse EOS.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    raw = load_dataset(dataset_name)
    # `load_dataset` returns a DatasetDict of splits. Prefer 'train'; otherwise
    # fall back to the first available split. (The previous code wrapped the
    # DatasetDict in a plain dict and then unwrapped it again, ending up with a
    # DatasetDict that has no `train_test_split` and breaking the mapping below.)
    if 'train' in raw:
        ds = raw['train']
    else:
        ds = raw[next(iter(raw))]
    split = ds.train_test_split(test_size=val_split, seed=42)
    train_ds, val_ds = split['train'], split['test']

    def text_key(example):
        # Heuristic: first string-valued column is treated as the text field.
        for k in example.keys():
            if example[k] is not None and isinstance(example[k], str):
                return k
        return None

    sample = train_ds[0]
    tkey = text_key(sample) or 'text'

    def tokenize(ex):
        return tokenizer(ex[tkey], truncation=True, padding='max_length', max_length=block_size)

    train_tok = train_ds.map(tokenize, batched=True, remove_columns=train_ds.column_names)
    val_tok = val_ds.map(tokenize, batched=True, remove_columns=val_ds.column_names)

    # No explicit labels column is needed: DataCollatorForLanguageModeling with
    # mlm=False clones input_ids into labels at collation time and masks pad
    # positions to -100 itself, so a precomputed labels column would be
    # overwritten anyway. NOTE(review): because pad_token may equal eos_token,
    # real EOS positions are also masked — a known caveat of this setup.
    collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
    train_loader = DataLoader(train_tok, batch_size=batch_size, shuffle=True, num_workers=num_workers, collate_fn=collator)
    val_loader = DataLoader(val_tok, batch_size=max(2, batch_size), shuffle=False, num_workers=num_workers, collate_fn=collator)
    return tokenizer, train_loader, val_loader