from typing import Tuple

from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, DataCollatorForLanguageModeling


def build_dataloaders(dataset_name: str, tokenizer_name: str, batch_size: int,
                      val_split: float = 0.05, block_size: int = 512,
                      num_workers: int = 2) -> Tuple:
    """Load a Hub dataset, tokenize it for causal LM training, and return
    (tokenizer, train_loader, val_loader)."""
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    # GPT-style tokenizers often lack a pad token; reuse EOS for padding.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # load_dataset usually returns a DatasetDict; use the train split if present.
    raw = load_dataset(dataset_name)
    ds = raw["train"] if "train" in raw else raw

    # Hold out a validation slice when the dataset supports splitting.
    if hasattr(ds, "train_test_split"):
        split = ds.train_test_split(test_size=val_split, seed=42)
        train_ds, val_ds = split["train"], split["test"]
    else:
        train_ds, val_ds = ds, ds

    def text_key(example):
        # Use the first string-valued column as the text field.
        for k, v in example.items():
            if isinstance(v, str):
                return k
        return None

    tkey = text_key(train_ds[0]) or "text"

    def tokenize(ex):
        return tokenizer(ex[tkey], truncation=True, padding="max_length",
                         max_length=block_size)

    train_tok = train_ds.map(tokenize, batched=True, remove_columns=train_ds.column_names)
    val_tok = val_ds.map(tokenize, batched=True, remove_columns=val_ds.column_names)

    def labelize(batch):
        # Copy input_ids into labels, masking pad positions with -100 so the loss
        # ignores them. (DataCollatorForLanguageModeling with mlm=False rebuilds
        # these labels at collate time, so this mainly keeps the tokenized datasets
        # usable on their own as well.)
        batch["labels"] = [
            [-100 if tok == tokenizer.pad_token_id else tok for tok in ids]
            for ids in batch["input_ids"]
        ]
        return batch

    train_tok = train_tok.map(labelize, batched=True)
    val_tok = val_tok.map(labelize, batched=True)

    # mlm=False -> causal LM collation: labels are input_ids with pads set to -100.
    collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
    train_loader = DataLoader(train_tok, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers, collate_fn=collator)
    val_loader = DataLoader(val_tok, batch_size=max(2, batch_size), shuffle=False,
                            num_workers=num_workers, collate_fn=collator)
    return tokenizer, train_loader, val_loader
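

# --- Usage sketch (not part of the original snippet) ---
# A minimal smoke test of build_dataloaders. The dataset ("roneneldan/TinyStories")
# and tokenizer ("gpt2") names are assumptions chosen only for illustration; gpt2
# also exercises the pad_token fallback branch. Swap in whatever dataset and
# tokenizer the surrounding project actually uses.
if __name__ == "__main__":
    tokenizer, train_loader, val_loader = build_dataloaders(
        dataset_name="roneneldan/TinyStories",  # assumed example dataset
        tokenizer_name="gpt2",                  # assumed example tokenizer
        batch_size=8,
        block_size=128,
        num_workers=0,
    )
    batch = next(iter(train_loader))
    # Expect input_ids, attention_mask, and labels, each shaped (batch, block_size).
    print({k: tuple(v.shape) for k, v in batch.items()})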