import torch
from datasets import load_dataset
from transformers import (
    Trainer,
    T5Config,
    T5TokenizerFast,
    TrainingArguments,
    DataCollatorForSeq2Seq,
    T5ForConditionalGeneration
)

# Path config
base_model = "t5-small"  # reference architecture only; the model below is trained from scratch
data_path = "src/data/clean_corpus.jsonl"
tokeniser_path = "src/tokeniser/"
output_dir = "checkpoints/"

# Load tokeniser
tokeniser = T5TokenizerFast.from_pretrained(tokeniser_path)
vocab_size = tokeniser.vocab_size
pad_token_id = tokeniser.pad_token_id

# Use custom vocab size for the model
config = T5Config(
    vocab_size = vocab_size,
    d_model = 512,
    d_ff = 2048,
    num_layers = 6,
    num_heads = 8,
    pad_token_id = pad_token_id,
    decoder_start_token_id = pad_token_id
)

model = T5ForConditionalGeneration(config)

def tokenise_function(example: dict) -> dict:
    """
    Tokenise a batch of source/target pairs, masking label padding
    so it is ignored by the loss.
    """
    inputs = [f"Cyrillic2Latin: {item['src']}" for item in example["transliteration"]]
    targets = [item["tgt"] for item in example["transliteration"]]
    model_inputs = tokeniser(
        inputs,
        max_length = 128,
        truncation = True,
        padding = "max_length"
    )
    labels = tokeniser(
        targets,
        max_length = 128,
        truncation = True,
        padding = "max_length"
    )["input_ids"]
    # Replace padding token ids with -100 so the loss skips padded positions
    model_inputs["labels"] = [
        [(token if token != pad_token_id else -100) for token in seq]
        for seq in labels
    ]
    return model_inputs

# Load dataset
dataset = load_dataset("json", data_files = data_path, split = "train")

# Split dataset into train and validation sets (75/25 split)
dataset_split = dataset.train_test_split(test_size = 0.25)
train_dataset = dataset_split["train"]
val_dataset = dataset_split["test"]

# Tokenise datasets
tokenised_train = train_dataset.map(tokenise_function, batched = True, remove_columns = ["transliteration"])
tokenised_eval = val_dataset.map(tokenise_function, batched = True, remove_columns = ["transliteration"])

# Data collator
data_collator = DataCollatorForSeq2Seq(tokenizer = tokeniser, model = model)

# Training args
training_args = TrainingArguments(
    output_dir = output_dir,
    overwrite_output_dir = True,
    num_train_epochs = 2,
    per_device_train_batch_size = 32,
    gradient_accumulation_steps = 2,
    save_strategy = "steps",
    save_steps = 500,
    save_total_limit = 3,
    eval_strategy = "epoch",
    logging_dir = "logs",
    fp16 = torch.cuda.is_available()
)

# Trainer
trainer = Trainer(
    model = model,
    args = training_args,
    train_dataset = tokenised_train,
    eval_dataset = tokenised_eval,
    data_collator = data_collator,
    processing_class = tokeniser
)

# Train
trainer.train()
model.save_pretrained(output_dir)
tokeniser.save_pretrained(output_dir)

print("DalaT5 training complete.")
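
# --- Optional post-training sanity check (a minimal sketch, not part of the
# original pipeline; the Cyrillic sample string below is a hypothetical input,
# not drawn from the training corpus) ---
# Runs one greedy generation pass to confirm the checkpoint produces output
# in the expected "Cyrillic2Latin" task format.
model.eval()
sample = "Cyrillic2Latin: сәлем әлем"  # hypothetical example input
sample_inputs = tokeniser(sample, return_tensors = "pt").to(model.device)
with torch.no_grad():
    generated_ids = model.generate(**sample_inputs, max_length = 128)
print(tokeniser.decode(generated_ids[0], skip_special_tokens = True))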