from datasets import load_dataset
from transformers import AutoTokenizer, DataCollatorForLanguageModeling
from torch.utils.data import DataLoader
from typing import Tuple
def build_dataloaders(dataset_name: str, tokenizer_name: str, batch_size: int,
                      val_split: float = 0.05, block_size: int = 512,
                      num_workers: int = 2) -> Tuple:
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    # Causal-LM tokenizers (e.g. GPT-2) often ship without a pad token; reuse EOS.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # load_dataset returns a DatasetDict; use the 'train' split if it exists,
    # otherwise fall back to the first available split.
    raw = load_dataset(dataset_name)
    if 'train' in raw:
        ds = raw['train']
    else:
        ds = raw[next(iter(raw))]

    # Carve out a small validation set with a fixed seed for reproducibility.
    split = ds.train_test_split(test_size=val_split, seed=42)
    train_ds, val_ds = split['train'], split['test']

    def text_key(example):
        # Return the first column holding a string; assume that is the text field.
        for k in example.keys():
            if example[k] is not None and isinstance(example[k], str):
                return k
        return None

    sample = train_ds[0]
    tkey = text_key(sample) or 'text'

    def tokenize(ex):
        return tokenizer(ex[tkey], truncation=True, padding='max_length',
                         max_length=block_size)

    train_tok = train_ds.map(tokenize, batched=True, remove_columns=train_ds.column_names)
    val_tok = val_ds.map(tokenize, batched=True, remove_columns=val_ds.column_names)

    def labelize(batch):
        # Copy input_ids into labels, masking pad positions with -100 so the
        # cross-entropy loss ignores them.
        labels = [[-100 if token == tokenizer.pad_token_id else token for token in ids]
                  for ids in batch['input_ids']]
        batch['labels'] = labels
        return batch

    train_tok = train_tok.map(labelize, batched=True)
    val_tok = val_tok.map(labelize, batched=True)

    # mlm=False makes the collator produce causal-LM (next-token) labels at collate time.
    collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
    train_loader = DataLoader(train_tok, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers, collate_fn=collator)
    val_loader = DataLoader(val_tok, batch_size=max(2, batch_size), shuffle=False,
                            num_workers=num_workers, collate_fn=collator)
    return tokenizer, train_loader, val_loader
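

# Minimal usage sketch. The dataset and checkpoint names below ("imdb", "gpt2")
# are illustrative assumptions, not values required by build_dataloaders; any
# text dataset with a string column and any causal-LM tokenizer should work.
if __name__ == "__main__":
    tokenizer, train_loader, val_loader = build_dataloaders(
        dataset_name="imdb",      # assumed example dataset with a 'text' column
        tokenizer_name="gpt2",    # assumed example tokenizer checkpoint
        batch_size=8,
        val_split=0.05,
        block_size=512,
    )
    # Inspect one collated batch: input_ids and labels are (batch, block_size) tensors.
    batch = next(iter(train_loader))
    print(batch["input_ids"].shape, batch["labels"].shape)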