from typing import Tuple

from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, DataCollatorForLanguageModeling


def build_dataloaders(dataset_name: str,
                      tokenizer_name: str,
                      batch_size: int,
                      val_split: float = 0.05,
                      block_size: int = 512,
                      num_workers: int = 2) -> Tuple:
    """Load a dataset, tokenize it for causal LM training, and return
    (tokenizer, train_loader, val_loader)."""
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    # Many causal-LM tokenizers (e.g. GPT-2) have no pad token; reuse EOS so padding works.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # load_dataset returns a DatasetDict; use the "train" split if present,
    # otherwise fall back to the first available split.
    raw = load_dataset(dataset_name)
    ds = raw['train'] if 'train' in raw else raw[next(iter(raw))]

    # Carve a validation set out of the chosen split.
    split = ds.train_test_split(test_size=val_split, seed=42)
    train_ds, val_ds = split['train'], split['test']

    def text_key(example):
        # Return the name of the first column holding a string (the text column).
        for k, v in example.items():
            if isinstance(v, str):
                return k
        return None

    tkey = text_key(train_ds[0]) or 'text'

    def tokenize(batch):
        return tokenizer(batch[tkey], truncation=True,
                         padding='max_length', max_length=block_size)

    train_tok = train_ds.map(tokenize, batched=True, remove_columns=train_ds.column_names)
    val_tok = val_ds.map(tokenize, batched=True, remove_columns=val_ds.column_names)

    def labelize(batch):
        # Labels are the input ids with padding positions masked out (-100)
        # so the cross-entropy loss ignores them.
        batch['labels'] = [
            [(-100 if token == tokenizer.pad_token_id else token) for token in ids]
            for ids in batch['input_ids']
        ]
        return batch

    train_tok = train_tok.map(labelize, batched=True)
    val_tok = val_tok.map(labelize, batched=True)

    # mlm=False gives a causal-LM collator; it rebuilds `labels` from `input_ids`
    # at collation time (pad positions set to -100), matching what labelize produced.
    collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
    train_loader = DataLoader(train_tok, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers, collate_fn=collator)
    val_loader = DataLoader(val_tok, batch_size=max(2, batch_size), shuffle=False,
                            num_workers=num_workers, collate_fn=collator)

    return tokenizer, train_loader, val_loader
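

# A minimal usage sketch, not part of the original function. The dataset and
# tokenizer names below are illustrative assumptions: "ag_news" is assumed to
# expose a string "text" column and a "train" split, and "gpt2" is a causal-LM
# tokenizer without a pad token, so both fallback paths above get exercised.
if __name__ == "__main__":
    tokenizer, train_loader, val_loader = build_dataloaders(
        dataset_name="ag_news",   # assumed example dataset with a text column
        tokenizer_name="gpt2",    # assumed example tokenizer; pad token falls back to EOS
        batch_size=8,
        block_size=128,
    )
    batch = next(iter(train_loader))
    # Each batch carries fixed-length input_ids, attention_mask, and labels tensors.
    print({k: tuple(v.shape) for k, v in batch.items()})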