| """ | |
| TinyLlama Training Script | |
| This script provides functionality to fine-tune a TinyLlama model on custom data. | |
| """ | |
import os
import argparse
import logging

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    DataCollatorForLanguageModeling,
)
from datasets import load_dataset

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def parse_args():
    parser = argparse.ArgumentParser(description="Fine-tune a TinyLlama model")
    parser.add_argument(
        "--model_name",
        type=str,
        default="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        help="Model identifier from HuggingFace",
    )
    parser.add_argument(
        "--data_path",
        type=str,
        required=True,
        help="Path to the dataset file (JSON or CSV)",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="./results",
        help="Directory to save model checkpoints",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=3,
        help="Number of training epochs",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=4,
        help="Training batch size",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=5e-5,
        help="Learning rate",
    )
    parser.add_argument(
        "--text_column",
        type=str,
        default="text",
        help="Column name containing the text data",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=4,
        help="Number of update steps to accumulate before updating weights",
    )
    parser.add_argument(
        "--max_length",
        type=int,
        default=512,
        help="Maximum sequence length",
    )
    parser.add_argument(
        "--fp16",
        action="store_true",
        help="Use mixed precision training",
    )
    return parser.parse_args()


def preprocess_function(examples, tokenizer, max_length, text_column="text"):
    """Tokenize a batch of text examples for causal language modeling."""
    result = tokenizer(
        examples[text_column],
        truncation=True,
        max_length=max_length,
        padding="max_length",
    )
    # For causal LM fine-tuning the labels are simply a copy of the input ids.
    result["labels"] = result["input_ids"].copy()
    return result


def main():
    args = parse_args()

    logger.info(f"Loading model: {args.model_name}")
    model = AutoModelForCausalLM.from_pretrained(args.model_name)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)

    # Llama-style tokenizers ship without a pad token, so reuse the EOS token.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Load the dataset; the file extension ("json" or "csv") selects the loader.
    logger.info(f"Loading dataset from: {args.data_path}")
    data_extension = os.path.splitext(args.data_path)[1].replace(".", "")
    dataset = load_dataset(data_extension, data_files=args.data_path)

    # Preprocess the dataset
    logger.info("Preprocessing dataset")
    tokenized_dataset = dataset.map(
        lambda examples: preprocess_function(
            examples, tokenizer, args.max_length, args.text_column
        ),
        batched=True,
        remove_columns=dataset["train"].column_names,
    )

    # Set up training arguments
    training_args = TrainingArguments(
        output_dir=args.output_dir,
        overwrite_output_dir=True,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        learning_rate=args.learning_rate,
        weight_decay=0.01,
        logging_dir=os.path.join(args.output_dir, "logs"),
        logging_steps=10,
        save_steps=500,
        save_total_limit=2,
        fp16=args.fp16,
    )

    # Set up the trainer with a causal LM collator (mlm=False disables masked LM)
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False,
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset["train"],
        data_collator=data_collator,
    )

    # Start training
    logger.info("Starting training...")
    trainer.train()

    # Save the fine-tuned model and tokenizer
    logger.info(f"Saving model to {args.output_dir}")
    trainer.save_model(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)
    logger.info("Training complete!")


if __name__ == "__main__":
    main()
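
# After training, the saved checkpoint can be reloaded for inference. A minimal
# sketch (assuming the default --output_dir of ./results and a hypothetical prompt):
#
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#
#   model = AutoModelForCausalLM.from_pretrained("./results")
#   tokenizer = AutoTokenizer.from_pretrained("./results")
#   inputs = tokenizer("Hello, how are you?", return_tensors="pt")
#   outputs = model.generate(**inputs, max_new_tokens=50)
#   print(tokenizer.decode(outputs[0], skip_special_tokens=True))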